├── .gitignore
├── .readthedocs.yaml
├── CHANGELOG.md
├── LICENSE.md
├── MANIFEST.in
├── README.md
├── docs
│   ├── .gitignore
│   ├── Makefile
│   ├── README.md
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   ├── custom.css
│       │   ├── dark-logo.svg
│       │   ├── light-logo.svg
│       │   └── switcher.json
│       ├── _templates
│       │   ├── custom_sidebar.html
│       │   └── navbar-mid.html
│       ├── conf.py
│       ├── index.md
│       ├── llm-api
│       │   ├── datasets.rst
│       │   ├── evaluators.rst
│       │   ├── exporter.rst
│       │   ├── logger.rst
│       │   └── types.rst
│       └── ml-api
│           ├── exporter.rst
│           └── logger.rst
├── pyproject.toml
├── src
│   └── arize
│       ├── __init__.py
│       ├── api.py
│       ├── bounded_executor.py
│       ├── experimental
│       │   ├── __init__.py
│       │   ├── datasets
│       │   │   ├── __init__.py
│       │   │   ├── core
│       │   │   │   ├── __init__.py
│       │   │   │   ├── client.py
│       │   │   │   └── session.py
│       │   │   ├── experiments
│       │   │   │   ├── __init__.py
│       │   │   │   ├── evaluators
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── base.py
│       │   │   │   │   ├── exceptions.py
│       │   │   │   │   ├── executors.py
│       │   │   │   │   ├── rate_limiters.py
│       │   │   │   │   └── utils.py
│       │   │   │   ├── functions.py
│       │   │   │   ├── tracing.py
│       │   │   │   └── types.py
│       │   │   ├── utils
│       │   │   │   ├── __init__.py
│       │   │   │   ├── constants.py
│       │   │   │   └── experiment_utils.py
│       │   │   └── validation
│       │   │       ├── __init__.py
│       │   │       ├── errors.py
│       │   │       └── validator.py
│       │   ├── integrations
│       │   │   ├── whylabs
│       │   │   │   ├── __init__.py
│       │   │   │   ├── client.py
│       │   │   │   ├── generator.py
│       │   │   │   └── test.py
│       │   │   ├── whylabs_vanguard_governance
│       │   │   │   ├── __init__.py
│       │   │   │   ├── client.py
│       │   │   │   └── test.py
│       │   │   └── whylabs_vanguard_ingestion
│       │   │       ├── __init__.py
│       │   │       ├── client.py
│       │   │       ├── generator.py
│       │   │       └── test.py
│       │   ├── online_tasks
│       │   │   ├── __init__.py
│       │   │   └── dataframe_preprocessor.py
│       │   └── prompt_hub
│       │       ├── __init__.py
│       │       ├── client.py
│       │       ├── constants.py
│       │       └── prompts
│       │           ├── __init__.py
│       │           ├── base.py
│       │           ├── open_ai.py
│       │           └── vertex_ai.py
│       ├── exporter
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── core
│       │   │   ├── __init__.py
│       │   │   ├── client.py
│       │   │   ├── query.py
│       │   │   └── session.py
│       │   ├── publicexporter_pb2.py
│       │   └── utils
│       │       ├── __init__.py
│       │       ├── constants.py
│       │       ├── errors.py
│       │       ├── schema_parser.py
│       │       ├── tracing.py
│       │       └── validation.py
│       ├── pandas
│       │   ├── __init__.py
│       │   ├── embeddings
│       │   │   ├── __init__.py
│       │   │   ├── auto_generator.py
│       │   │   ├── base_generators.py
│       │   │   ├── constants.py
│       │   │   ├── cv_generators.py
│       │   │   ├── errors.py
│       │   │   ├── nlp_generators.py
│       │   │   ├── tabular_generators.py
│       │   │   └── usecases.py
│       │   ├── etl
│       │   │   ├── __init__.py
│       │   │   ├── casting.py
│       │   │   └── errors.py
│       │   ├── generative
│       │   │   ├── __init__.py
│       │   │   ├── llm_evaluation
│       │   │   │   ├── __init__.py
│       │   │   │   ├── constants.py
│       │   │   │   └── hf_metrics.py
│       │   │   └── nlp_metrics
│       │   │       ├── __init__.py
│       │   │       └── hf_metrics.py
│       │   ├── logger.py
│       │   ├── proto
│       │   │   ├── __init__.py
│       │   │   └── requests_pb2.py
│       │   ├── surrogate_explainer
│       │   │   ├── __init__.py
│       │   │   └── mimic.py
│       │   ├── tracing
│       │   │   ├── __init__.py
│       │   │   ├── columns.py
│       │   │   ├── constants.py
│       │   │   ├── types.py
│       │   │   ├── utils.py
│       │   │   └── validation
│       │   │       ├── __init__.py
│       │   │       ├── annotations
│       │   │       │   ├── __init__.py
│       │   │       │   ├── annotations_validation.py
│       │   │       │   ├── dataframe_form_validation.py
│       │   │       │   └── value_validation.py
│       │   │       ├── common
│       │   │       │   ├── __init__.py
│       │   │       │   ├── argument_validation.py
│       │   │       │   ├── dataframe_form_validation.py
│       │   │       │   ├── errors.py
│       │   │       │   └── value_validation.py
│       │   │       ├── evals
│       │   │       │   ├── __init__.py
│       │   │       │   ├── dataframe_form_validation.py
│       │   │       │   ├── evals_validation.py
│       │   │       │   └── value_validation.py
│       │   │       ├── metadata
│       │   │       │   ├── __init__.py
│       │   │       │   ├── argument_validation.py
│       │   │       │   ├── dataframe_form_validation.py
│       │   │       │   └── value_validation.py
│       │   │       └── spans
│       │   │           ├── __init__.py
│       │   │           ├── dataframe_form_validation.py
│       │   │           ├── spans_validation.py
│       │   │           └── value_validation.py
│       │   └── validation
│       │       ├── __init__.py
│       │       ├── errors.py
│       │       └── validator.py
│       ├── public_pb2.py
│       ├── single_log
│       │   ├── __init__.py
│       │   ├── casting.py
│       │   └── errors.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── constants.py
│       │   ├── errors.py
│       │   ├── logging.py
│       │   ├── model_mapping.json
│       │   ├── proto.py
│       │   ├── types.py
│       │   └── utils.py
│       └── version.py
└── tests
    ├── __init__.py
    ├── experimental
    │   ├── __init__.py
    │   └── datasets
    │       ├── __init__.py
    │       ├── experiments
    │       │   ├── __init__.py
    │       │   └── test_experiments.py
    │       └── validation
    │           ├── __init__.py
    │           └── test_validator.py
    ├── exporter
    │   ├── __init__.py
    │   ├── test_exporter.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   └── test_schema_parser.py
    │   └── validations
    │       ├── test_validator_invalid_types.py
    │       └── test_validator_invalid_values.py
    ├── fixtures
    │   ├── __init__.py
    │   └── mpg.csv
    ├── pandas
    │   ├── etl
    │   │   └── test_casting_config.py
    │   ├── generative
    │   │   └── nlp_metrics
    │   │       └── test_nlp_metrics.py
    │   ├── logger
    │   │   ├── test_log_annotations.py
    │   │   └── test_log_metadata.py
    │   ├── surrogate_explainer
    │   │   └── test_surrogate_explainer.py
    │   ├── test_pandas_logger.py
    │   ├── tracing
    │   │   ├── test_columns.py
    │   │   ├── test_logger_tracing.py
    │   │   └── validation
    │   │       ├── test_invalid_annotation_arguments.py
    │   │       ├── test_invalid_annotation_values.py
    │   │       ├── test_invalid_arguments.py
    │   │       ├── test_invalid_metadata_arguments.py
    │   │       ├── test_invalid_metadata_values.py
    │   │       └── test_invalid_values.py
    │   └── validation
    │       ├── test_pandas_validator_error_messages.py
    │       ├── test_pandas_validator_invalid_params.py
    │       ├── test_pandas_validator_invalid_records.py
    │       ├── test_pandas_validator_invalid_reserved_columns.py
    │       ├── test_pandas_validator_invalid_shap_suffix.py
    │       ├── test_pandas_validator_invalid_types.py
    │       ├── test_pandas_validator_invalid_values.py
    │       ├── test_pandas_validator_missing_columns.py
    │       ├── test_pandas_validator_modelmapping.py
    │       ├── test_pandas_validator_ranking_param_checks.py
    │       ├── test_pandas_validator_ranking_record_checks.py
    │       ├── test_pandas_validator_ranking_type_checks.py
    │       ├── test_pandas_validator_ranking_value_checks.py
    │       └── test_pandas_validator_required_checks.py
    ├── single_log
    │   └── test_casting.py
    ├── test_api.py
    ├── test_utils.py
    └── types
        ├── test_embedding_types.py
        ├── test_multi_class_types.py
        └── test_type_helpers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 |
3 | # C extensions
4 | *.so
5 |
6 | # IDEs
7 | .idea/
8 | .vscode/
9 |
10 | # Packages
11 | *.egg
12 | *.egg-info
13 | .eggs/
14 | # build
15 | parts
16 | bin
17 | var
18 | sdist
19 | dist
20 | develop-eggs
21 | .installed.cfg
22 | lib
23 | lib64
24 |
25 | # Installer logs
26 | pip-log.txt
27 |
28 | # Unit test / coverage reports
29 | .coverage
30 | .tox
31 | nosetests.xml
32 |
33 | # Complexity
34 | output/*.html
35 | output/*/index.html
36 |
37 | # Sphinx
38 | docs/_build
39 |
40 | # Cookiecutter
41 | output/
42 |
43 | .workon
44 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for Sphinx projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version and other tools you might need
8 | build:
9 |   os: ubuntu-22.04
10 |   tools:
11 |     python: "3.12"
12 |     # You can also specify other tool versions:
13 |     # nodejs: "20"
14 |     # rust: "1.70"
15 |     # golang: "1.20"
16 | 
17 | # Build documentation in the "docs/" directory with Sphinx
18 | sphinx:
19 |   configuration: docs/source/conf.py
20 |   # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs
21 |   # builder: "dirhtml"
22 |   # Fail on all warnings to avoid broken references
23 |   # fail_on_warning: true
24 | 
25 | # Optionally build your docs in additional formats such as PDF and ePub
26 | # formats:
27 | #   - pdf
28 | #   - epub
29 | 
30 | # Optional but recommended, declare the Python requirements required
31 | # to build your documentation
32 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
33 | python:
34 |   install:
35 |     - requirements: docs/requirements.txt
36 |     - method: pip
37 |       path: .
38 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020, Arize AI
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md requirements.txt LICENSE.md arize/utils/*.json
2 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SPHINXPROJ = arize
9 | SOURCEDIR = source
10 | BUILDDIR = build
11 |
12 | # Put it first so that "make" without argument is like "make help".
13 | help:
14 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
15 |
16 | .PHONY: help Makefile
17 |
18 | # Catch-all target: route all unknown targets to Sphinx using the new
19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
20 | %: Makefile
21 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Maintenance README for Arize Sphinx API Documentation
2 |
3 | This API reference provides comprehensive details for Arize's API. The documentation covers only the public, user-facing APIs offered by Arize.
4 |
5 | Maintaining the API reference consists of two parts:
6 |
7 | 1. Building the documentation with Sphinx
8 | 2. Hosting and CI with readthedocs
9 |
10 | ## TL;DR
11 | ```
12 | uv venv --python=python3.11
13 | uv pip install -r requirements.txt
14 | make clean html
15 | # then open build/html/index.html in your browser
16 | # currently, the build/html directory is copied over as the static site for arize-docs.onrender.com
17 | ```
18 |
19 | ## Files
20 | - conf.py: All sphinx-related configuration is done here and is necessary to run Sphinx.
21 | - index.md: Main entrypoint for the API reference. This file must be in the `source` directory. For documentation to show up on the API reference, there must be a path (does not have to be direct) defined in index.md to the target documentation file.
22 | - requirements.txt: This file is necessary for management of dependencies on the readthedocs platform and its build process.
23 | - make files: Not required but useful in generating static HTML pages locally.
24 |
25 | ## Useful references
26 | - https://pydata-sphinx-theme.readthedocs.io/
27 | - https://sphinx-design.readthedocs.io/en/latest/
28 | - https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html
29 | - https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html
30 | - https://docs.readthedocs.io/en/stable/automation-rules.html
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=arize
13 |
14 | %SPHINXBUILD% >NUL 2>NUL
15 | if errorlevel 9009 (
16 | 	echo.
17 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
18 | 	echo.installed, then set the SPHINXBUILD environment variable to point
19 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
20 | 	echo.may add the Sphinx directory to PATH.
21 | 	echo.
22 | 	echo.If you don't have Sphinx installed, grab it from
23 | 	echo.https://www.sphinx-doc.org/
24 | 	exit /b 1
25 | )
26 |
27 | if "%1" == "" goto help
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | myst_parser
2 | sphinx==7.3.7
3 | pydata-sphinx-theme
4 | linkify-it-py
5 | sphinx_design
6 |
7 | -e ".[Datasets]"
8 |
--------------------------------------------------------------------------------
/docs/source/_static/custom.css:
--------------------------------------------------------------------------------
1 | @import "pydata_sphinx_theme.scss";
2 |
3 | /* -- index ------------------------------------------------------------------ */
4 |
5 | .bd-article h1 {
6 |   display: flex;
7 |   justify-content: center;
8 | }
9 | 
10 | /* .bd-article #api-definition h2 {
11 |   display: flex;
12 |   justify-content: center;
13 | } */
14 | 
15 | #main-content > .bd-article > img {
16 |   display: flex;
17 |   justify-content: center;
18 | }
19 | 
20 | #arize-api-reference > #external-links {
21 |   display: flex;
22 |   flex-direction: row;
23 |   justify-content: space-around;
24 | }
25 | 
26 | #arize-api-reference > #api-definition > .toctree-wrapper > ul {
27 |   display: flex;
28 |   flex-direction: row;
29 |   justify-content: space-between;
30 |   flex-wrap: wrap;
31 | }
32 | 
33 | #arize-api-reference > #api-definition > .toctree-wrapper > ul > li {
34 |   padding: 1em 1em 1em 1em;
35 | }
36 | 
37 | /* -- navbar display --------------------------------------------------------- */
38 | 
39 | .navbar-nav .nav-link {
40 |   text-transform: capitalize;
41 | }
42 | 
43 | /* -- sidebar display -------------------------------------------------------- */
44 | 
45 | /* -- signature display ------------------------------------------------------ */
46 | 
47 | .sig.sig-object.py {
48 |   font-style: normal !important;
49 |   font-feature-settings: "kern";
50 |   font-family: "Roboto Mono", "Courier New", Courier, monospace;
51 |   background: var(--pst-color-surface);
52 | }
53 | 
54 | .sig.sig-object.py .sig-param,
55 | .sig-paren {
56 |   font-weight: normal;
57 | }
58 | 
59 | .sig.sig-object.py em {
60 |   font-style: normal;
61 | }
62 |
--------------------------------------------------------------------------------
/docs/source/_static/dark-logo.svg:
--------------------------------------------------------------------------------
1 |
11 |
--------------------------------------------------------------------------------
/docs/source/_static/light-logo.svg:
--------------------------------------------------------------------------------
1 |
11 |
--------------------------------------------------------------------------------
/docs/source/_static/switcher.json:
--------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "version": "latest",
4 |     "url": "https://arize-client-python.readthedocs.io/en/latest/",
5 |     "preferred": true
6 |   },
7 |   {
8 |     "version": "v7.43.0",
9 |     "url": "https://arize-client-python.readthedocs.io/en/v7.43.0/"
10 |   },
11 |   {
12 |     "version": "v7.42.0",
13 |     "url": "https://arize-client-python.readthedocs.io/en/v7.42.0/"
14 |   },
15 |   {
16 |     "version": "v7.41.0",
17 |     "url": "https://arize-client-python.readthedocs.io/en/v7.41.0/"
18 |   },
19 |   {
20 |     "version": "v7.40.0",
21 |     "url": "https://arize-client-python.readthedocs.io/en/v7.40.0/"
22 |   },
23 |   {
24 |     "version": "v7.39.0",
25 |     "url": "https://arize-client-python.readthedocs.io/en/v7.39.0/"
26 |   },
27 |   {
28 |     "version": "v7.38.0",
29 |     "url": "https://arize-client-python.readthedocs.io/en/v7.38.0/"
30 |   },
31 |   {
32 |     "version": "v7.37.0",
33 |     "url": "https://arize-client-python.readthedocs.io/en/v7.37.0/"
34 |   },
35 |   {
36 |     "version": "v7.36.0",
37 |     "url": "https://arize-client-python.readthedocs.io/en/v7.36.0/"
38 |   },
39 |   {
40 |     "version": "v7.35.0",
41 |     "url": "https://arize-client-python.readthedocs.io/en/v7.35.0/"
42 |   },
43 |   {
44 |     "version": "v7.34.0",
45 |     "url": "https://arize-client-python.readthedocs.io/en/v7.34.0/"
46 |   },
47 |   {
48 |     "version": "v7.33.0",
49 |     "url": "https://arize-client-python.readthedocs.io/en/v7.33.0/"
50 |   },
51 |   {
52 |     "version": "v7.32.0",
53 |     "url": "https://arize-client-python.readthedocs.io/en/v7.32.0/"
54 |   },
55 |   {
56 |     "version": "v7.31.0",
57 |     "url": "https://arize-client-python.readthedocs.io/en/v7.31.0/"
58 |   },
59 |   {
60 |     "version": "v7.30.0",
61 |     "url": "https://arize-client-python.readthedocs.io/en/v7.30.0/"
62 |   },
63 |   {
64 |     "version": "v7.29.0",
65 |     "url": "https://arize-client-python.readthedocs.io/en/v7.29.0/"
66 |   },
67 |   {
68 |     "version": "v7.28.0",
69 |     "url": "https://arize-client-python.readthedocs.io/en/v7.28.0/"
70 |   },
71 |   {
72 |     "version": "v7.27.1",
73 |     "url": "https://arize-client-python.readthedocs.io/en/v7.27.1/"
74 |   },
75 |   {
76 |     "version": "v7.26.0",
77 |     "url": "https://arize-client-python.readthedocs.io/en/v7.26.0/"
78 |   },
79 |   {
80 |     "version": "v7.25.0",
81 |     "url": "https://arize-client-python.readthedocs.io/en/v7.25.0/"
82 |   }
83 | ]
84 |
--------------------------------------------------------------------------------
/docs/source/_templates/custom_sidebar.html:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_templates/navbar-mid.html:
--------------------------------------------------------------------------------
1 | {# Displays links to the top-level TOCtree elements, in the header navbar. #}
2 |
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | myst:
3 |   html_meta:
4 |     "description lang=en": |
5 |       Top-level documentation for arize,
6 |       with links to the rest of the site.
7 |   html_theme.sidebar_secondary.remove: true
8 | ---
9 |
10 | # Arize Python SDK Reference
11 |
12 |
13 |
27 |
28 |
29 | Below are the classes, functions, and parameters for tracing and evaluating your application with the Arize Python SDK. For a complete guide to using Arize, including tutorials, quickstarts, and concept explanations, read our [docs](https://docs.arize.com/arize). This reference primarily covers the LLM SDK methods; documentation for the rest of the SDK, spanning ML and CV use cases, is coming soon!
30 |
31 | ```{seealso}
32 | Check out our [Slack community](https://arize-ai.slack.com/join/shared_invite/zt-1px8dcmlf-fmThhDFD_V_48oU7ALan4Q#/shared-invite/email) and [GitHub repository](https://github.com/arize-ai/client_python)!
33 | ```
34 |
35 | ::::{grid} 2
36 | :::{grid-item}
37 | ```{toctree}
38 | :maxdepth: 1
39 | :caption: LLM API
40 | llm-api/datasets
41 | llm-api/evaluators
42 | llm-api/exporter
43 | llm-api/logger
44 | llm-api/types
45 | ```
46 | :::
47 | :::{grid-item}
48 | ```{toctree}
49 | :maxdepth: 1
50 | :caption: ML API
51 | ml-api/exporter
52 | ml-api/logger
53 | ```
54 | :::
55 | ::::
56 | 
57 | 
58 | 
59 | 
60 | 
--------------------------------------------------------------------------------
/docs/source/llm-api/datasets.rst:
--------------------------------------------------------------------------------
1 | datasets & experiments
2 | ======================
3 | Use datasets to curate spans from your LLM applications for testing. Run experiments to test different models, prompts, and parameters for your LLM apps. `Read our quickstart guide for more information `_.
4 |
5 | To use in your code, import the following:
6 |
7 | ``from arize.experimental.datasets import ArizeDatasetsClient``
8 |
9 | .. autoclass:: arize.experimental.datasets.ArizeDatasetsClient
10 |    :members:
11 |    :exclude-members: port, host, otlp_endpoint, __init__, api_key, developer_key, scheme, session
12 |
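13 | Example (a minimal sketch; the keys, space ID, and dataframe below are hypothetical placeholders):
14 | 
15 | .. code-block:: python
16 | 
17 |    import pandas as pd
18 | 
19 |    from arize.experimental.datasets import ArizeDatasetsClient
20 |    from arize.experimental.datasets.utils.constants import GENERATIVE
21 | 
22 |    client = ArizeDatasetsClient(api_key="YOUR_API_KEY", developer_key="YOUR_DEVELOPER_KEY")
23 |    df = pd.DataFrame({"id": ["1", "2"], "input": ["hello", "world"]})
24 |    dataset_id = client.create_dataset(
25 |        space_id="YOUR_SPACE_ID",
26 |        dataset_name="my-dataset",
27 |        dataset_type=GENERATIVE,
28 |        data=df,
29 |    )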
--------------------------------------------------------------------------------
/docs/source/llm-api/evaluators.rst:
--------------------------------------------------------------------------------
1 | evaluators
2 | ^^^^^^^^^^
3 | Use these base classes to create your own evaluators. `See our docs for more information `_.
4 |
5 | To import evaluators, use the following:
6 | ``from arize.experimental.datasets.experiments.evaluators.base import ...``
7 |
8 | .. autoclass:: arize.experimental.datasets.experiments.evaluators.base.Evaluator
9 |    :members: evaluate, async_evaluate
10 |    :exclude-members: kind, name, _name
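11 | 
12 | Example (a minimal sketch; the evaluator name and pass criterion are hypothetical):
13 | 
14 | .. code-block:: python
15 | 
16 |    from arize.experimental.datasets.experiments.evaluators.base import Evaluator
17 |    from arize.experimental.datasets.experiments.types import EvaluationResult
18 | 
19 |    class ContainsGreeting(Evaluator):
20 |        def evaluate(self, *, output=None, **kwargs) -> EvaluationResult:
21 |            # Score 1.0 when the task output contains a greeting.
22 |            contains = "hello" in str(output).lower()
23 |            return EvaluationResult(
24 |                score=float(contains),
25 |                label="greeting" if contains else "no_greeting",
26 |            )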
--------------------------------------------------------------------------------
/docs/source/llm-api/exporter.rst:
--------------------------------------------------------------------------------
1 | exporter
2 | ========
3 | Use this to export data from Arize. `Read this guide for more information `_.
4 |
5 | To use in your code, import the following:
6 |
7 | ``from arize.exporter import ArizeExportClient``
8 |
9 | .. automodule:: arize.exporter
10 |    :members:
11 |    :exclude-members: api_key, arize_config_path, arize_profile, get_progress_bar, host, port, scheme, session, __init__, get_arize_schema
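12 | 
13 | Example (a minimal sketch; the space ID, model ID, and time range are hypothetical placeholders):
14 | 
15 | .. code-block:: python
16 | 
17 |    from datetime import datetime, timedelta
18 | 
19 |    from arize.exporter import ArizeExportClient
20 |    from arize.utils.types import Environments
21 | 
22 |    client = ArizeExportClient(api_key="YOUR_API_KEY")
23 |    df = client.export_model_to_df(
24 |        space_id="YOUR_SPACE_ID",
25 |        model_id="my-llm-app",
26 |        environment=Environments.TRACING,
27 |        start_time=datetime.now() - timedelta(days=7),
28 |        end_time=datetime.now(),
29 |    )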
--------------------------------------------------------------------------------
/docs/source/llm-api/logger.rst:
--------------------------------------------------------------------------------
1 | logger
2 | ======
3 | Use this to log evaluations or spans to Arize in bulk. Read our quickstart guide for `logging evaluations to Arize `_.
4 |
5 | To use in your code, import the following:
6 |
7 | ``from arize.pandas.logger import Client``
8 |
9 | .. automodule:: arize.pandas.logger
10 |    :members:
11 |    :exclude-members: __init__, log, InvalidSessionError, FlightSession
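12 | 
13 | Example (a minimal sketch; the keys, project name, and dataframe are hypothetical placeholders):
14 | 
15 | .. code-block:: python
16 | 
17 |    from arize.pandas.logger import Client
18 | 
19 |    client = Client(space_id="YOUR_SPACE_ID", api_key="YOUR_API_KEY")
20 |    # evals_df is assumed to be a dataframe keyed by span_id with columns
21 |    # such as eval.correctness.label / eval.correctness.score
22 |    response = client.log_evaluations_sync(evals_df, "my-llm-project")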
--------------------------------------------------------------------------------
/docs/source/llm-api/types.rst:
--------------------------------------------------------------------------------
1 | types
2 | ^^^^^
3 | These are the classes used across the experiment functions.
4 |
5 | To import types, use the following:
6 | ``from arize.experimental.datasets.experiments.types import ...``
7 |
8 | .. autoclass:: arize.experimental.datasets.experiments.types.Example
9 |    :exclude-members: id, updated_at, input, output, metadata, dataset_row, from_dict
10 | 
11 | .. autoclass:: arize.experimental.datasets.experiments.types.EvaluationResult
12 |    :exclude-members: score, label, explanation, metadata, from_dict
13 | 
14 | .. autoclass:: arize.experimental.datasets.experiments.types.ExperimentRun
15 |    :exclude-members: start_time, end_time, experiment_id, dataset_example_id, repetition_number, output, error, id, trace_id, from_dict
16 | 
17 | .. autoclass:: arize.experimental.datasets.experiments.types.ExperimentEvaluationRun
18 |    :exclude-members: start_time, end_time, experiment_id, dataset_example_id, repetition_number, output, error, id, trace_id, annotator_kind, experiment_run_id, from_dict, name, result
19 |
--------------------------------------------------------------------------------
/docs/source/ml-api/exporter.rst:
--------------------------------------------------------------------------------
1 | exporter
2 | ========
3 | Use this to export data from Arize. `Read this guide for more information `_.
4 |
5 | To use in your code, import the following:
6 |
7 | ``from arize.exporter import ArizeExportClient``
8 |
9 | .. automodule:: arize.exporter
10 |    :members:
11 |    :exclude-members: api_key, arize_config_path, arize_profile, get_progress_bar, host, port, scheme, session, __init__, get_arize_schema
--------------------------------------------------------------------------------
/docs/source/ml-api/logger.rst:
--------------------------------------------------------------------------------
1 | logger
2 | ======
3 |
4 | client
5 | ^^^^^^
6 | Use this to log inferences from your model (ML, CV, NLP, etc.) to Arize.
7 |
8 | To use in your code, import the following: ``from arize.pandas.logger import Client``
9 |
10 | .. automodule:: arize.pandas.logger
11 |    :members:
12 |    :exclude-members: __init__, log_spans, log_evaluations, log_evaluations_sync, FlightSession, InvalidSessionError
13 |
14 |
15 | client (single record)
16 | ^^^^^^^^^^^^^^^^^^^^^^^
17 | To use in your code, import the following: ``from arize.api import Client``
18 |
19 | .. automodule:: arize.api
20 |    :members:
21 |    :exclude-members: __init__
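22 | 
23 | Example (a minimal sketch; the keys, model ID, and column names are hypothetical placeholders):
24 | 
25 | .. code-block:: python
26 | 
27 |    import pandas as pd
28 | 
29 |    from arize.pandas.logger import Client
30 |    from arize.utils.types import Environments, ModelTypes, Schema
31 | 
32 |    client = Client(space_id="YOUR_SPACE_ID", api_key="YOUR_API_KEY")
33 |    df = pd.DataFrame(
34 |        {
35 |            "prediction_id": ["1", "2"],
36 |            "prediction": ["fraud", "not_fraud"],
37 |            "actual": ["fraud", "fraud"],
38 |        }
39 |    )
40 |    schema = Schema(
41 |        prediction_id_column_name="prediction_id",
42 |        prediction_label_column_name="prediction",
43 |        actual_label_column_name="actual",
44 |    )
45 |    response = client.log(
46 |        dataframe=df,
47 |        schema=schema,
48 |        model_id="fraud-model",
49 |        model_type=ModelTypes.BINARY_CLASSIFICATION,
50 |        environment=Environments.PRODUCTION,
51 |    )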
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "arize"
3 | description = "A helper library to interact with Arize AI APIs"
4 | readme = "README.md"
5 | requires-python = ">=3.6"
6 | license = { text = "BSD" }
7 | keywords = [
8 | "Arize",
9 | "Observability",
10 | "Monitoring",
11 | "Explainability",
12 | "Tracing",
13 | "LLM",
14 | "Evaluations",
15 | ]
16 | authors = [
17 | { name = "Arize AI", email = "support@arize.com" },
18 | ]
19 | maintainers = [
20 | { name = "Arize AI", email = "support@arize.com" },
21 | ]
22 | classifiers = [
23 | "Development Status :: 5 - Production/Stable",
24 | "Programming Language :: Python",
25 | "Programming Language :: Python :: 3.6",
26 | "Programming Language :: Python :: 3.7",
27 | "Programming Language :: Python :: 3.8",
28 | "Programming Language :: Python :: 3.9",
29 | "Programming Language :: Python :: 3.10",
30 | "Programming Language :: Python :: 3.11",
31 | "Programming Language :: Python :: 3.12",
32 | "Programming Language :: Python :: 3.13",
33 | "Intended Audience :: Developers",
34 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
35 | "Topic :: Software Development :: Libraries :: Python Modules",
36 | "Topic :: System :: Logging",
37 | "Topic :: System :: Monitoring",
38 | ]
39 | dependencies = [
40 | "requests_futures==1.0.0",
41 | "googleapis_common_protos>=1.51.0,<2",
42 | "protobuf>=4.21.0,<6",
43 | "pandas>=0.25.3,<3",
44 | "pyarrow>=0.15.0",
45 | "tqdm>=4.60.0,<5",
46 | "pydantic>=2.0.0,<3",
47 | ]
48 | dynamic = ["version"]
49 |
50 | [project.urls]
51 | Homepage = "https://arize.com"
52 | Documentation = "https://docs.arize.com/arize"
53 | Issues = "https://github.com/Arize-ai/client_python/issues"
54 | Source = "https://github.com/Arize-ai/client_python"
55 | Changelog = "https://github.com/Arize-ai/client_python/blob/main/CHANGELOG.md"
56 |
57 | [project.optional-dependencies]
58 | dev = [
59 | "pytest==8.3.3",
60 | "ruff==0.6.9",
61 | ]
62 | MimicExplainer = [
63 | "interpret-community[mimic]>=0.22.0,<1",
64 | ]
65 | AutoEmbeddings = [
66 | "transformers>=4.25, <5",
67 | "tokenizers>=0.13, <1",
68 | "datasets>=2.8, <3, !=2.14.*",
69 | "torch>=1.13, <3",
70 | "Pillow>=8.4.0, <11",
71 | ]
72 | NLP_Metrics = [
73 | "nltk>=3.0.0, <4",
74 | "sacrebleu>=2.3.1, <3",
75 | "rouge-score>=0.1.2, <1",
76 | "evaluate>=0.3, <1",
77 | "datasets!=2.14.*",
78 | ]
79 | LLM_Evaluation = [
80 | # To be removed in version 8 in favor of NLP_Metrics
81 | "nltk>=3.0.0, <4",
82 | "sacrebleu>=2.3.1, <3",
83 | "rouge-score>=0.1.2, <1",
84 | "evaluate>=0.3, <1",
85 | "datasets!=2.14.*",
86 | ]
87 | Tracing = [
88 | "opentelemetry-semantic-conventions>=0.43b0, <1",
89 | "openinference-semantic-conventions>=0.1.12, <1",
90 | "deprecated", #opentelemetry-semantic-conventions requires it
91 | ]
92 | Datasets = [
93 | "typing-extensions>=4, <5",
94 | "wrapt>=1.12.1, <2",
95 | "opentelemetry-semantic-conventions>=0.43b0, <1",
96 | "openinference-semantic-conventions>=0.1.6, <1",
97 | "opentelemetry-sdk>=1.25.0, <2",
98 | "opentelemetry-exporter-otlp>=1.25.0, <2",
99 | "deprecated", #opentelemetry-semantic-conventions requires it
100 | ]
101 | PromptHub = [
102 | "gql>=3.0.0",
103 | "requests_toolbelt>=1.0.0",
104 | ]
105 | PromptHub_VertexAI = [
106 | "google-cloud-aiplatform>=1.0.0",
107 | ]
108 |
109 | [build-system]
110 | requires = ["hatchling"]
111 | build-backend = "hatchling.build"
112 |
113 | [tool.hatch.version]
114 | path = "src/arize/version.py"
115 |
116 | [tool.hatch.build]
117 | only-packages = true
118 |
119 | [tool.hatch.build.targets.wheel]
120 | packages = ["src/arize"]
121 |
122 | [tool.hatch.build.targets.sdist]
123 | exclude = [
124 | "src/arize/examples",
125 | "tests",
126 | "docs",
127 | ]
128 |
129 |
130 | [tool.black]
131 | include = '\.pyi?$'
132 | exclude = '(_pb2\.py$|docs/source/.*\.py)'
133 |
134 | [tool.ruff]
135 | target-version = "py37"
136 | line-length = 80
137 | exclude = [
138 | "dist/",
139 | "__pycache__",
140 | "*_pb2.py*",
141 | "*_pb2_grpc.py*",
142 | "*.pyi",
143 | "docs/",
144 | ]
145 | [tool.ruff.format]
146 | docstring-code-format = true
147 | line-ending = "native"
148 |
149 | [tool.ruff.lint]
150 | select = [
151 | # pycodestyle Error
152 | "E",
153 | # pycodestyle Warning
154 | "W",
155 | # Pyflakes
156 | "F",
157 | # pyupgrade
158 | "UP",
159 | # flake8-bugbear
160 | "B",
161 | # flake8-simplify
162 | "SIM",
163 | # isort
164 | "I",
165 | # TODO: Enable pydocstyle when ready for API docs
166 | # # pydocstyle
167 | # "D",
168 | ]
169 | ignore = [
170 | "D203", # Do not use a blank line to separate the docstring from the class definition,
171 | "D212", # The summary line should be located on the second physical line of the docstring
172 | ]
173 |
174 | [tool.ruff.lint.isort]
175 | force-wrap-aliases = true
176 |
177 | [tool.ruff.lint.pycodestyle]
178 | max-doc-length = 110
179 | max-line-length = 110
180 |
181 | [tool.ruff.lint.pydocstyle]
182 | convention = "google"
183 |
184 | [tool.ruff.lint.pyupgrade]
185 | keep-runtime-typing = true
186 |
--------------------------------------------------------------------------------
/src/arize/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 |
3 | __all__ = [
4 | "__version__",
5 | ]
6 |
--------------------------------------------------------------------------------
/src/arize/bounded_executor.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ThreadPoolExecutor
2 | from threading import BoundedSemaphore
3 |
4 |
5 | class BoundedExecutor:
6 | """
7 | BoundedExecutor behaves as a ThreadPoolExecutor which will block on
8 | calls to submit() once the limit given as "bound" work items are queued for
9 | execution.
10 | :param bound: Integer - the maximum number of items in the work queue
11 | :param max_workers: Integer - the size of the thread pool
12 | """
13 |
14 | def __init__(self, bound, max_workers):
15 | self.executor = ThreadPoolExecutor(max_workers=max_workers)
16 | self.semaphore = BoundedSemaphore(bound + max_workers)
17 |
18 | """See concurrent.futures.Executor#submit"""
19 |
20 | def submit(self, fn, *args, **kwargs):
21 | self.semaphore.acquire()
22 | try:
23 | future = self.executor.submit(fn, *args, **kwargs)
24 | except Exception:
25 | self.semaphore.release()
26 | raise
27 | else:
28 | future.add_done_callback(lambda x: self.semaphore.release())
29 | return future
30 |
31 | """See concurrent.futures.Executor#shutdown"""
32 |
33 | def shutdown(self, wait=True):
34 | self.executor.shutdown(wait)
35 |
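36 | # Example usage (a sketch; `some_function` is a placeholder):
37 | #     executor = BoundedExecutor(bound=100, max_workers=8)
38 | #     future = executor.submit(some_function)  # blocks while the queue is full
39 | #     executor.shutdown(wait=True)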
--------------------------------------------------------------------------------
/src/arize/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .core.client import ArizeDatasetsClient
2 |
3 | __all__ = ["ArizeDatasetsClient"]
4 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/core/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/core/session.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from dataclasses import dataclass, field
3 |
4 | from pyarrow import flight
5 |
6 | from arize.utils.logging import logger
7 | from arize.utils.utils import get_python_version
8 | from arize.version import __version__
9 |
10 | from ..utils.constants import DEFAULT_PACKAGE_NAME
11 | from ..validation.errors import InvalidSessionError
12 |
13 |
14 | @dataclass
15 | class Session:
16 |     api_key: str
17 |     host: str
18 |     port: int
19 |     scheme: str
20 |     session_name: str = field(init=False)
21 |     call_options: flight.FlightCallOptions = field(init=False)
22 | 
23 |     def __post_init__(self):
24 |         self.session_name = f"python-sdk-{DEFAULT_PACKAGE_NAME}-{uuid.uuid4()}"
25 |         logger.debug(f"Creating named session as '{self.session_name}'.")
26 |         if self.api_key is None:
27 |             logger.error(InvalidSessionError.error_message())
28 |             raise InvalidSessionError
29 | 
30 |         logger.debug(
31 |             f"Created session with Arize API Key '{self.api_key}' at '{self.host}':'{self.port}'"
32 |         )
33 |         self._set_headers()
34 | 
35 |     def connect(self) -> flight.FlightClient:
36 |         """
37 |         Connects to Arize Flight server public endpoint with the
38 |         provided api key.
39 |         """
40 |         try:
41 |             disable_cert = self.host.lower() == "localhost"
42 |             client = flight.FlightClient(
43 |                 location=f"{self.scheme}://{self.host}:{self.port}",
44 |                 disable_server_verification=disable_cert,
45 |             )
46 |             self.call_options = flight.FlightCallOptions(headers=self._headers)
47 |             return client
48 |         except Exception:
49 |             logger.error(
50 |                 "There was an error trying to connect to the Arize Flight Endpoint"
51 |             )
52 |             raise
53 | 
54 |     def _set_headers(self) -> None:
55 |         self._headers = [
56 |             (b"origin", b"arize-python-datasets-client"),
57 |             (b"auth-token-bin", f"{self.api_key}".encode()),
58 |             (b"sdk-language", b"python"),
59 |             (b"language-version", get_python_version().encode("utf-8")),
60 |             (b"sdk-version", __version__.encode("utf-8")),
61 |             (b"arize-interface", b"flight"),
62 |         ]
63 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/experiments/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/experiments/evaluators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/experiments/evaluators/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/experiments/evaluators/exceptions.py:
--------------------------------------------------------------------------------
1 | class ArizeException(Exception):
2 |     pass
3 | 
4 | 
5 | class ArizeContextLimitExceeded(ArizeException):
6 |     pass
7 | 
8 | 
9 | class ArizeTemplateMappingError(ArizeException):
10 |     pass
11 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/utils/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/utils/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from openinference.semconv import trace
4 |
5 | from arize.pandas.proto import requests_pb2 as pb2
6 |
7 | INFERENCES = pb2.INFERENCES
8 | GENERATIVE = pb2.GENERATIVE
9 |
10 | """Internal Use"""
11 |
12 | # Default API endpoint when not provided through env variable nor profile
13 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com"
14 | DEFAULT_ARIZE_FLIGHT_PORT = 443
15 | DEFAULT_ARIZE_OTLP_ENDPOINT = "https://otlp.arize.com/v1"
16 |
17 | # Name of the current package.
18 | DEFAULT_PACKAGE_NAME = "arize_python_datasets_client"
19 |
20 | # Default headers to trace and help identify requests. For debugging.
21 | DEFAULT_ARIZE_SESSION_ID = "x-arize-session-id" # Generally the session name.
22 | DEFAULT_ARIZE_TRACE_ID = "x-arize-trace-id"
23 | DEFAULT_PACKAGE_VERSION = "x-package-version"
24 |
25 | # Default initial wait time for retries in seconds.
26 | DEFAULT_RETRY_INITIAL_WAIT_TIME = 0.25
27 |
28 | # Default maximum wait time for retries in seconds.
29 | DEFAULT_RETRY_MAX_WAIT_TIME = 10.0
30 |
31 | # Default to use grpc + tls scheme.
32 | DEFAULT_TRANSPORT_SCHEME = "grpc+tls"
33 |
34 |
35 | class FlightActionKey(Enum):
36 |     GET_DATASET_VERSION = "get_dataset_version"
37 |     LIST_DATASETS = "list_datasets"
38 |     DELETE_DATASET = "delete_dataset"
39 |     CREATE_EXPERIMENT_DB_ENTRY = "create_experiment_db_entry"
40 |     DELETE_EXPERIMENT = "delete_experiment"
41 | 
42 | 
43 | OPEN_INFERENCE_JSON_STR_TYPES = frozenset(
44 |     [
45 |         trace.DocumentAttributes.DOCUMENT_METADATA,
46 |         trace.SpanAttributes.LLM_FUNCTION_CALL,
47 |         trace.SpanAttributes.LLM_INVOCATION_PARAMETERS,
48 |         trace.SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES,
49 |         trace.MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
50 |         trace.SpanAttributes.METADATA,
51 |         trace.SpanAttributes.TOOL_PARAMETERS,
52 |         trace.ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
53 |     ]
54 | )
55 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/utils/experiment_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import datetime
3 | import functools
4 | from enum import Enum
5 | from pathlib import Path
6 | from typing import Any, Callable, Mapping, Sequence, Union
7 |
8 | try:
9 |     from typing import get_args, get_origin  # Python 3.8+
10 | except ImportError:
11 |     from typing_extensions import get_args, get_origin  # For Python <3.8
12 | 
13 | 
14 | import numpy as np
15 | 
16 | 
17 | def get_func_name(fn: Callable[..., Any]) -> str:
18 |     """
19 |     Makes a best-effort attempt to get the name of the function.
20 |     """
21 |     if isinstance(fn, functools.partial):
22 |         return fn.func.__qualname__
23 |     if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith("<lambda>"):
24 |         return fn.__qualname__.split(".<locals>.")[-1]
25 |     return str(fn)
26 | 
27 | 
28 | def jsonify(obj: Any) -> Any:
29 |     """
30 |     Coerce object to be json serializable.
31 |     """
32 |     if isinstance(obj, Enum):
33 |         return jsonify(obj.value)
34 |     if isinstance(obj, (str, int, float, bool)) or obj is None:
35 |         return obj
36 |     if isinstance(obj, (list, set, frozenset, Sequence)):
37 |         return [jsonify(v) for v in obj]
38 |     if isinstance(obj, (dict, Mapping)):
39 |         return {jsonify(k): jsonify(v) for k, v in obj.items()}
40 |     if dataclasses.is_dataclass(obj):
41 |         result = {}
42 |         for field in dataclasses.fields(obj):
43 |             k = field.name
44 |             v = getattr(obj, k)
45 |             if not (
46 |                 v is None
47 |                 and get_origin(field.type) is Union
48 |                 and type(None) in get_args(field.type)
49 |             ):
50 |                 result[k] = jsonify(v)
51 |         return result
52 |     if isinstance(obj, (datetime.date, datetime.datetime, datetime.time)):
53 |         return obj.isoformat()
54 |     if isinstance(obj, datetime.timedelta):
55 |         return obj.total_seconds()
56 |     if isinstance(obj, Path):
57 |         return str(obj)
58 |     if isinstance(obj, BaseException):
59 |         return str(obj)
60 |     if isinstance(obj, np.ndarray):
61 |         return [jsonify(v) for v in obj]
62 |     if hasattr(obj, "__float__"):
63 |         return float(obj)
64 |     if hasattr(obj, "model_dump") and callable(obj.model_dump):
65 |         # pydantic v2
66 |         try:
67 |             d = obj.model_dump()
68 |             assert isinstance(d, dict)
69 |         except BaseException:
70 |             pass
71 |         else:
72 |             return jsonify(d)
73 |     if hasattr(obj, "dict") and callable(obj.dict):
74 |         # pydantic v1
75 |         try:
76 |             d = obj.dict()
77 |             assert isinstance(d, dict)
78 |         except BaseException:
79 |             pass
80 |         else:
81 |             return jsonify(d)
82 |     cls = obj.__class__
83 |     return f"<{cls.__module__}.{cls.__name__} object>"
84 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/validation/__init__.py
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/validation/errors.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class DatasetError(Exception, ABC):
5 |     def __str__(self) -> str:
6 |         return self.error_message()
7 | 
8 |     @abstractmethod
9 |     def __repr__(self) -> str:
10 |         pass
11 | 
12 |     @abstractmethod
13 |     def error_message(self) -> str:
14 |         pass
15 | 
16 | 
17 | class InvalidSessionError(DatasetError):
18 |     @staticmethod
19 |     def error_message() -> str:
20 |         return (
21 |             "Credentials not provided or invalid. Please pass in the correct api_key when "
22 |             "initiating a new ArizeExportClient. Alternatively, you can set up credentials "
23 |             "in a profile or as an environment variable"
24 |         )
25 | 
26 | 
27 | class InvalidConfigFileError(DatasetError):
28 |     @staticmethod
29 |     def error_message() -> str:
30 |         return "Invalid/Misconfigured Configuration File"
31 | 
32 | 
33 | class IDColumnUniqueConstraintError(DatasetError):
34 |     @staticmethod
35 |     def error_message() -> str:
36 |         return "'id' column must contain unique values"
37 | 
38 | 
39 | class RequiredColumnsError(DatasetError):
40 |     def __init__(self, missing_columns: set) -> None:
41 |         self.missing_columns = missing_columns
42 | 
43 |     def error_message(self) -> str:
44 |         return f"Missing required columns: {self.missing_columns}"
45 | 
46 |     def __repr__(self) -> str:
47 |         return f"RequiredColumnsError({self.missing_columns})"
48 |
--------------------------------------------------------------------------------
/src/arize/experimental/datasets/validation/validator.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pandas as pd
4 |
5 | from . import errors as err
6 |
7 |
8 | class Validator:
9 |     @staticmethod
10 |     def validate(
11 |         df: pd.DataFrame,
12 |     ) -> List[err.DatasetError]:
13 |         # check all required columns are present
14 |         required_columns_errors = Validator._check_required_columns(df)
15 |         if required_columns_errors:
16 |             return required_columns_errors
17 | 
18 |         # check id column is unique
19 |         id_column_unique_constraint_error = (
20 |             Validator._check_id_column_is_unique(df)
21 |         )
22 |         if id_column_unique_constraint_error:
23 |             return id_column_unique_constraint_error
24 | 
25 |         return []
26 | 
27 |     @staticmethod
28 |     def _check_required_columns(df: pd.DataFrame) -> List[err.DatasetError]:
29 |         required_columns = ["id", "created_at", "updated_at"]
30 |         missing_columns = set(required_columns) - set(df.columns)
31 |         if missing_columns:
32 |             return [err.RequiredColumnsError(missing_columns)]
33 |         return []
34 | 
35 |     @staticmethod
36 |     def _check_id_column_is_unique(df: pd.DataFrame) -> List[err.DatasetError]:
37 |         if not df["id"].is_unique:
38 |             return [err.IDColumnUniqueConstraintError]
39 |         return []
40 |
--------------------------------------------------------------------------------
/src/arize/experimental/integrations/whylabs/__init__.py:
--------------------------------------------------------------------------------
1 | from arize.experimental.integrations.whylabs.client import IntegrationClient
2 | from arize.experimental.integrations.whylabs.generator import (
3 |     WhylabsProfileAdapter,
4 | )
5 |
6 | __all__ = ["WhylabsProfileAdapter", "IntegrationClient"]
7 |
--------------------------------------------------------------------------------
/src/arize/experimental/integrations/whylabs_vanguard_governance/__init__.py:
--------------------------------------------------------------------------------
1 | from arize.experimental.integrations.whylabs_vanguard_governance.client import (
2 |     IntegrationClient,
3 | )
4 |
5 | __all__ = ["IntegrationClient"]
6 |
--------------------------------------------------------------------------------
/src/arize/experimental/integrations/whylabs_vanguard_governance/client.py:
--------------------------------------------------------------------------------
1 | # type: ignore[pb2]
2 | from datetime import datetime, timedelta
3 | from typing import Dict, Optional, Union
4 |
5 | import pandas as pd
6 |
7 | from arize.pandas.logger import Client
8 | from arize.utils.types import (
9 |     Environments,
10 |     ModelTypes,
11 |     Schema,
12 | )
13 | 
14 | 
15 | class IntegrationClient:
16 |     def __init__(
17 |         self,
18 |         api_key: Optional[str] = None,
19 |         space_id: Optional[str] = None,
20 |         space_key: Optional[str] = None,
21 |         uri: Optional[str] = "https://api.arize.com/v1",
22 |         additional_headers: Optional[Dict[str, str]] = None,
23 |         request_verify: Union[bool, str] = True,
24 |         developer_key: Optional[str] = None,
25 |         host: Optional[str] = None,
26 |         port: Optional[int] = None,
27 |     ) -> None:
28 |         """
29 |         Wrapper for the Arize Client specific to WhyLabs profiles.
30 |         """
31 |         self._client = Client(
32 |             api_key=api_key,
33 |             space_id=space_id,
34 |             space_key=space_key,
35 |             uri=uri,
36 |             additional_headers=additional_headers,
37 |             request_verify=request_verify,
38 |             developer_key=developer_key,
39 |             host=host,
40 |             port=port,
41 |         )
42 | 
43 |     def create_model(
44 |         self,
45 |         model_id: str,
46 |         model_type: ModelTypes,
47 |         environment: Environments = Environments.PRODUCTION,
48 |     ) -> None:
49 |         df = pd.DataFrame(
50 |             {
51 |                 "timestamp": [datetime.now() - timedelta(days=365)],
52 |                 "ARIZE_PLACEHOLDER_STRING": "ARIZE_DUMMY",
53 |                 "ARIZE_PLACEHOLDER_FLOAT": 1,
54 |             }
55 |         )
56 | 
57 |         if model_type == ModelTypes.RANKING:
58 |             schema = Schema(
59 |                 timestamp_column_name="timestamp",
60 |                 prediction_group_id_column_name="ARIZE_PLACEHOLDER_FLOAT",
61 |                 rank_column_name="ARIZE_PLACEHOLDER_FLOAT",
62 |             )
63 |         elif model_type == ModelTypes.MULTI_CLASS:
64 |             df["ARIZE_PLACEHOLDER_MAP"] = [
65 |                 [
66 |                     {
67 |                         "class_name": "ARIZE_DUMMY",
68 |                         "score": 1.0,
69 |                     }
70 |                 ]
71 |             ]
72 |             schema = Schema(
73 |                 timestamp_column_name="timestamp",
74 |                 prediction_score_column_name="ARIZE_PLACEHOLDER_MAP",
75 |             )
76 |         elif (
77 |             model_type == ModelTypes.NUMERIC
78 |             or model_type == ModelTypes.REGRESSION
79 |         ):
80 |             schema = Schema(
81 |                 timestamp_column_name="timestamp",
82 |                 prediction_label_column_name="ARIZE_PLACEHOLDER_FLOAT",
83 |                 prediction_score_column_name="ARIZE_PLACEHOLDER_FLOAT",
84 |                 actual_label_column_name="ARIZE_PLACEHOLDER_FLOAT",
85 |                 actual_score_column_name="ARIZE_PLACEHOLDER_FLOAT",
86 |             )
87 |         else:
88 |             schema = Schema(
89 |                 timestamp_column_name="timestamp",
90 |                 prediction_label_column_name="ARIZE_PLACEHOLDER_STRING",
91 |                 prediction_score_column_name="ARIZE_PLACEHOLDER_FLOAT",
92 |                 actual_label_column_name="ARIZE_PLACEHOLDER_STRING",
93 |                 actual_score_column_name="ARIZE_PLACEHOLDER_FLOAT",
94 |             )
95 | 
96 |         return self._client.log(
97 |             dataframe=df,
98 |             schema=schema,
99 |             environment=environment,
100 |             model_id=model_id,
101 |             model_type=model_type,
102 |         )
103 |
--------------------------------------------------------------------------------
/src/arize/experimental/integrations/whylabs_vanguard_governance/test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from arize.experimental.integrations.whylabs_vanguard_governance import (
4 |     IntegrationClient,
5 | )
6 | from arize.utils.types import Environments, ModelTypes
7 |
8 | os.environ["ARIZE_SPACE_ID"] = "REPLACE_ME"
9 | os.environ["ARIZE_API_KEY"] = "REPLACE_ME"
10 | os.environ["ARIZE_DEVELOPER_KEY"] = "REPLACE_ME"
11 |
12 |
13 | def test_create_model():
14 |     client = IntegrationClient(
15 |         api_key=os.environ["ARIZE_API_KEY"],
16 |         space_id=os.environ["ARIZE_SPACE_ID"],
17 |         developer_key=os.environ["ARIZE_DEVELOPER_KEY"],
18 |     )
19 |     client.create_model(
20 |         model_id="test-create-binary-classification",
21 |         environment=Environments.PRODUCTION,
22 |         model_type=ModelTypes.BINARY_CLASSIFICATION,
23 |     )
24 |
25 |
26 | if __name__ == "__main__":
27 |     test_create_model()
28 |
--------------------------------------------------------------------------------
/src/arize/experimental/integrations/whylabs_vanguard_ingestion/__init__.py:
--------------------------------------------------------------------------------
1 | from arize.experimental.integrations.whylabs_vanguard_ingestion.client import (
2 |     IntegrationClient,
3 | )
4 | from arize.experimental.integrations.whylabs_vanguard_ingestion.generator import (
5 |     WhylabsVanguardProfileAdapter,
6 | )
7 | 
8 | __all__ = ["WhylabsVanguardProfileAdapter", "IntegrationClient"]
9 |
--------------------------------------------------------------------------------
/src/arize/experimental/online_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataframe_preprocessor import extract_nested_data_to_column
2 |
3 | __all__ = ["extract_nested_data_to_column"]
4 |
--------------------------------------------------------------------------------
/src/arize/experimental/prompt_hub/__init__.py:
--------------------------------------------------------------------------------
1 | from .client import ArizePromptClient
2 | from .prompts import LLMProvider, Prompt, prompt_to_llm_input
3 |
4 | __all__ = ["ArizePromptClient", "Prompt", "LLMProvider", "prompt_to_llm_input"]
5 |
--------------------------------------------------------------------------------
/src/arize/experimental/prompt_hub/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | Constants for the Arize Prompt Hub.
3 | """
4 |
5 | # Mapping from internal model names to external model names
6 | ARIZE_INTERNAL_MODEL_MAPPING = {
7 |     # ----------------------
8 |     # OpenAI Models
9 |     # ----------------------
10 |     "GPT_4o_MINI": "gpt-4o-mini",
11 |     "GPT_4o_MINI_2024_07_18": "gpt-4o-mini-2024-07-18",
12 |     "GPT_4o": "gpt-4o",
13 |     "GPT_4o_2024_05_13": "gpt-4o-2024-05-13",
14 |     "GPT_4o_2024_08_06": "gpt-4o-2024-08-06",
15 |     "CHATGPT_4o_LATEST": "chatgpt-4o-latest",
16 |     "O1_PREVIEW": "o1-preview",
17 |     "O1_PREVIEW_2024_09_12": "o1-preview-2024-09-12",
18 |     "O1_MINI": "o1-mini",
19 |     "O1_MINI_2024_09_12": "o1-mini-2024-09-12",
20 |     "GPT_4_TURBO": "gpt-4-turbo",
21 |     "GPT_4_TURBO_2024_04_09": "gpt-4-turbo-2024-04-09",
22 |     "GPT_4_TURBO_PREVIEW": "gpt-4-turbo-preview",
23 |     "GPT_4_0125_PREVIEW": "gpt-4-0125-preview",
24 |     "GPT_4_1106_PREVIEW": "gpt-4-1106-preview",
25 |     "GPT_4": "gpt-4",
26 |     "GPT_4_32k": "gpt-4-32k",
27 |     "GPT_4_0613": "gpt-4-0613",
28 |     "GPT_4_0314": "gpt-4-0314",
29 |     "GPT_4_VISION_PREVIEW": "gpt-4-vision-preview",
30 |     "GPT_3_5_TURBO": "gpt-3.5-turbo",
31 |     "GPT_3_5_TURBO_1106": "gpt-3.5-turbo-1106",
32 |     "GPT_3_5_TURBO_INSTRUCT": "gpt-3.5-turbo-instruct",
33 |     "GPT_3_5_TURBO_0125": "gpt-3.5-turbo-0125",
34 |     "TEXT_DAVINCI_003": "text-davinci-003",
35 |     "BABBAGE_002": "babbage-002",
36 |     "DAVINCI_002": "davinci-002",
37 |     "O1_2024_12_17": "o1-2024-12-17",
38 |     "O1": "o1",
39 |     "O3_MINI": "o3-mini",
40 |     "O3_MINI_2025_01_31": "o3-mini-2025-01-31",
41 |     # ----------------------
42 |     # Vertex AI Models
43 |     # ----------------------
44 |     "GEMINI_1_0_PRO": "gemini-1.0-pro",
45 |     "GEMINI_1_0_PRO_VISION": "gemini-1.0-pro-vision",
46 |     "GEMINI_1_5_FLASH": "gemini-1.5-flash",
47 |     "GEMINI_1_5_FLASH_002": "gemini-1.5-flash-002",
48 |     "GEMINI_1_5_FLASH_8B": "gemini-1.5-flash-8b",
49 |     "GEMINI_1_5_FLASH_LATEST": "gemini-1.5-flash-latest",
50 |     "GEMINI_1_5_FLASH_8B_LATEST": "gemini-1.5-flash-8b-latest",
51 |     "GEMINI_1_5_PRO": "gemini-1.5-pro",
52 |     "GEMINI_1_5_PRO_001": "gemini-1.5-pro-001",
53 |     "GEMINI_1_5_PRO_002": "gemini-1.5-pro-002",
54 |     "GEMINI_1_5_PRO_LATEST": "gemini-1.5-pro-latest",
55 |     "GEMINI_2_0_FLASH": "gemini-2.0-flash",
56 |     "GEMINI_2_0_FLASH_001": "gemini-2.0-flash-001",
57 |     "GEMINI_2_0_FLASH_EXP": "gemini-2.0-flash-exp",
58 |     "GEMINI_2_0_FLASH_LITE_PREVIEW_02_05": "gemini-2.0-flash-lite-preview-02-05",
59 |     "GEMINI_PRO": "gemini-pro",
60 | }
61 | 
62 | # Reverse mapping from external model names to internal model names
63 | ARIZE_EXTERNAL_MODEL_MAPPING = {
64 |     v: k for k, v in ARIZE_INTERNAL_MODEL_MAPPING.items()
65 | }
66 |
--------------------------------------------------------------------------------
/src/arize/experimental/prompt_hub/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt implementations for different LLM providers.
3 | """
4 |
5 | from arize.experimental.prompt_hub.prompts.base import (
6 |     FormattedPrompt,
7 |     LLMProvider,
8 |     Prompt,
9 |     PromptInputVariableFormat,
10 |     prompt_to_llm_input,
11 | )
12 | 
13 | __all__ = [
14 |     "FormattedPrompt",
15 |     "Prompt",
16 |     "PromptInputVariableFormat",
17 |     "LLMProvider",
18 |     "prompt_to_llm_input",
19 | ]
20 |
--------------------------------------------------------------------------------
/src/arize/experimental/prompt_hub/prompts/open_ai.py:
--------------------------------------------------------------------------------
1 | """
2 | OpenAI-specific prompt implementations.
3 | """
4 |
5 | from dataclasses import dataclass
6 | from typing import Any, Dict, List, Mapping, Sequence
7 |
8 | from arize.experimental.prompt_hub.prompts.base import FormattedPrompt, Prompt
9 |
10 |
11 | @dataclass(frozen=True)
12 | class OpenAIPrompt(FormattedPrompt):
13 |     """
14 |     OpenAI-specific formatted prompt.
15 | 
16 |     Contains fully formatted messages and additional parameters (e.g. model, temperature, etc.)
17 |     required by the OpenAI API.
18 |     """
19 | 
20 |     messages: Sequence[Dict[str, str]]
21 |     kwargs: Dict[str, Any]
22 | 
23 | 
24 | def to_openai_prompt(
25 |     prompt: Prompt, variables: Mapping[str, Any]
26 | ) -> OpenAIPrompt:
27 |     """
28 |     Convert a Prompt to an OpenAI-specific formatted prompt.
29 | 
30 |     Args:
31 |         prompt: The prompt to format.
32 |         variables: A mapping of variable names to values.
33 | 
34 |     Returns:
35 |         An OpenAI-specific formatted prompt.
36 |     """
37 |     return OpenAIPrompt(
38 |         messages=format_openai_prompt(prompt, variables),
39 |         kwargs=openai_kwargs(prompt),
40 |     )
41 | 
42 | 
43 | def format_openai_prompt(
44 |     prompt: Prompt, variables: Mapping[str, Any]
45 | ) -> List[Dict[str, str]]:
46 |     """
47 |     Format a Prompt's messages for the OpenAI API.
48 | 
49 |     Args:
50 |         prompt: The prompt to format.
51 |         variables: A mapping of variable names to values.
52 | 
53 |     Returns:
54 |         A list of formatted message dictionaries.
55 |     """
56 |     formatted_messages = []
57 |     for message in prompt.messages:
58 |         formatted_message = message.copy()
59 |         formatted_message["content"] = message["content"].format(**variables)
60 |         formatted_messages.append(formatted_message)
61 |     return formatted_messages
62 | 
63 | 
64 | def openai_kwargs(prompt: Prompt) -> Dict[str, Any]:
65 |     """
66 |     Generate kwargs for the OpenAI API based on the prompt.
67 | 
68 |     Args:
69 |         prompt: The prompt to generate kwargs for.
70 | 
71 |     Returns:
72 |         A dictionary of kwargs for the OpenAI API.
73 |     """
74 |     return {"model": prompt.model_name}
75 |
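A minimal usage sketch for the helpers above. The real `Prompt` class lives in `base.py` (not shown here); the stand-in below is hypothetical and only needs the `messages` and `model_name` attributes that `format_openai_prompt` and `openai_kwargs` actually read:

```
from types import SimpleNamespace

from arize.experimental.prompt_hub.prompts.open_ai import to_openai_prompt

# Hypothetical stand-in for a Prompt; only `messages` and `model_name` are used here.
prompt = SimpleNamespace(
    messages=[{"role": "user", "content": "Summarize: {text}"}],
    model_name="gpt-3.5-turbo-0125",
)
formatted = to_openai_prompt(prompt, {"text": "Arize exports data over Arrow Flight."})
print(formatted.messages)  # [{'role': 'user', 'content': 'Summarize: Arize exports ...'}]
print(formatted.kwargs)    # {'model': 'gpt-3.5-turbo-0125'}
```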
--------------------------------------------------------------------------------
/src/arize/experimental/prompt_hub/prompts/vertex_ai.py:
--------------------------------------------------------------------------------
1 | """
2 | Vertex AI-specific prompt implementations.
3 | """
4 |
5 | from dataclasses import dataclass
6 | from typing import Any, List, Mapping, Sequence
7 |
8 | from vertexai.generative_models import Content, Part
9 |
10 | from arize.experimental.prompt_hub.prompts.base import FormattedPrompt, Prompt
11 |
12 |
13 | @dataclass(frozen=True)
14 | class VertexAIPrompt(FormattedPrompt):
15 | """
16 | Vertex AI-specific formatted prompt.
17 |
18 | Contains fully formatted messages and additional parameters (e.g. model, temperature, etc.)
19 | required by the Vertex AI API.
20 | """
21 |
22 | messages: Sequence[Content]
23 | model_name: str
24 |
25 |
26 | def to_vertexai_prompt(
27 | prompt: Prompt, variables: Mapping[str, Any]
28 | ) -> VertexAIPrompt:
29 | """
30 | Convert a Prompt to a Vertex AI-specific formatted prompt.
31 |
32 | Args:
33 | prompt: The prompt to format.
34 | variables: A mapping of variable names to values.
35 |
36 | Returns:
37 | A Vertex AI-specific formatted prompt.
38 |
39 | Raises:
40 | ImportError: If Vertex AI dependencies are not available.
41 | """
42 | return VertexAIPrompt(
43 | messages=format_vertexai_prompt(prompt, variables),
44 | model_name=prompt.model_name,
45 | )
46 |
47 |
48 | def format_vertexai_prompt(
49 | prompt: Prompt, variables: Mapping[str, Any]
50 | ) -> List[Content]:
51 | """
52 | Format a Prompt's messages for the Vertex AI API.
53 |
54 | Args:
55 | prompt: The prompt to format.
56 | variables: A mapping of variable names to values.
57 |
58 | Returns:
59 | A list of formatted Content objects.
60 |
61 | Raises:
62 | ImportError: If Vertex AI dependencies are not available.
63 | """
64 | formatted_messages = []
65 | for message in prompt.messages:
66 | formatted_message = Content(
67 | role=message["role"],
68 | parts=[Part.from_text(message["content"].format(**variables))],
69 | )
70 | formatted_messages.append(formatted_message)
71 | return formatted_messages
72 |
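The Vertex AI path mirrors the OpenAI one, but yields `Content` objects instead of plain dicts. A sketch under the same stand-in assumption (and assuming the `vertexai` package is installed, since it is imported at module load):

```
from types import SimpleNamespace

from arize.experimental.prompt_hub.prompts.vertex_ai import to_vertexai_prompt

# Hypothetical stand-in for a Prompt, as in the OpenAI sketch above.
prompt = SimpleNamespace(
    messages=[{"role": "user", "content": "Summarize: {text}"}],
    model_name="gemini-1.5-pro",
)
formatted = to_vertexai_prompt(prompt, {"text": "Arize exports data over Arrow Flight."})
# formatted.messages is a list of vertexai Content objects;
# formatted.model_name carries the model id for the Vertex AI call.
```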
--------------------------------------------------------------------------------
/src/arize/exporter/README.md:
--------------------------------------------------------------------------------
1 | ## Arize Python Exporter Client - User Guide
2 |
3 | ### Step 1: Pip install `arize` and set up `api_key` and `space_id`
4 | ```
5 | ! pip install -q arize
6 | ```
7 | ```
8 | api_key = ''
9 | space_id = ''
10 | ```
11 | - You can get your `space_id` by visiting [app.arize.com](https://app.arize.com). The URL will be in this format: `https://app.arize.com/organizations/:org_id/spaces/:space_id`
12 |   **NOTE: this is not the same as the space key used to send data using the SDK.**
13 |
14 | - To get your `api_key`, you must have Developer Access to your space. Visit the [arize docs](https://docs.arize.com/arize/integrations/graphql-api/getting-started-with-programmatic-access) for more details.
15 |   **NOTE: this is not the same as the API key in Space Settings.**
16 |
17 | ### Step 2: Instantiate an `ArizeExportClient`
18 |
19 | ```
20 | from arize.exporter import ArizeExportClient
21 |
22 | client = ArizeExportClient(api_key=api_key)
23 | ```
24 |
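Alternatively, the client's session can pick up credentials from the `ARIZE_API_KEY` environment variable or from a `profiles.ini` file under `~/.arize` (see `exporter/core/session.py`). A minimal sketch, assuming the environment-variable fallback is used when `api_key` is not passed:

```
import os

os.environ["ARIZE_API_KEY"] = "YOUR_API_KEY"

from arize.exporter import ArizeExportClient

client = ArizeExportClient()
```
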
25 | ### Step 3: Export production data with predictions only to a pandas dataframe
26 |
27 | ```
28 | from arize.utils.types import Environments
29 | from datetime import datetime
30 |
31 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0)
32 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0)
33 |
34 | df = client.export_model_to_df(
35 | space_id=space_id,
36 | model_id='arize-demo-fraud-use-case',
37 | environment=Environments.PRODUCTION,
38 | start_time=start_time,
39 | end_time=end_time,
40 | model_version='', #optional field
41 | batch_id='', #optional field
42 | )
43 | ```
44 |
45 | ### Export production data with predictions and actuals to a pandas dataframe
46 |
47 | ```
48 | from arize.utils.types import Environments
49 | from datetime import datetime
50 |
51 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0)
52 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0)
53 |
54 | df = client.export_model_to_df(
55 |     space_id=space_id,
56 | model_id='arize-demo-fraud-use-case',
57 | environment=Environments.PRODUCTION,
58 | start_time=start_time,
59 | end_time=end_time,
60 | include_actuals=True, #optional field
61 | model_version='', #optional field
62 | batch_id='', #optional field
63 | )
64 | ```
65 |
66 |
67 | ### Export training data to a pandas dataframe
68 |
69 | ```
70 | from arize.utils.types import Environments
71 | from datetime import datetime
72 |
73 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0)
74 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0)
75 |
76 | df = client.export_model_to_df(
77 |     space_id=space_id,
78 | model_id='arize-demo-fraud-use-case',
79 | environment=Environments.TRAINING,
80 | start_time=start_time,
81 | end_time=end_time,
82 | model_version='', #optional field
83 | batch_id='', #optional field
84 | )
85 | ```
86 | ### Export validation data to a pandas dataframe
87 |
88 | ```
89 | from arize.utils.types import Environments
90 | from datetime import datetime
91 |
92 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0)
93 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0)
94 |
95 | df = client.export_model_to_df(
96 |     space_id=space_id,
97 | model_id='arize-demo-fraud-use-case',
98 | environment=Environments.VALIDATION,
99 | start_time=start_time,
100 | end_time=end_time,
101 | model_version='', #optional field
102 | batch_id='', #optional field
103 | )
104 | ```
105 |
106 | ### Export production data to a parquet file
107 |
108 | ```
109 | from arize.utils.types import Environments
110 | from datetime import datetime
111 | import pandas as pd
112 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0)
113 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0)
114 |
115 | client.export_model_to_parquet(
116 |     path="example.parquet",
117 |     space_id=space_id,
118 | model_id='arize-demo-fraud-use-case',
119 | environment=Environments.PRODUCTION,
120 | start_time=start_time,
121 | end_time=end_time,
122 | model_version='', #optional field
123 | batch_id='', #optional field
124 | )
125 |
126 | df = pd.read_parquet("example.parquet")
127 | ```
128 |
--------------------------------------------------------------------------------
/src/arize/exporter/__init__.py:
--------------------------------------------------------------------------------
1 | from .core.client import ArizeExportClient
2 | from .utils.schema_parser import get_arize_schema
3 |
4 | __all__ = ["ArizeExportClient", "get_arize_schema"]
5 |
--------------------------------------------------------------------------------
/src/arize/exporter/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/exporter/core/__init__.py
--------------------------------------------------------------------------------
/src/arize/exporter/core/query.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | from google.protobuf import json_format
5 | from pyarrow import flight
6 |
7 | from arize.utils.logging import logger
8 |
9 | from .. import publicexporter_pb2 as exp_pb2
10 |
11 |
12 | @dataclass(frozen=True)
13 | class Query:
14 | query_descriptor: exp_pb2.RecordQueryDescriptor
15 |
16 | def execute(
17 | self,
18 | client: flight.FlightClient,
19 | call_options: flight.FlightCallOptions,
20 |     ) -> Tuple[Optional[flight.FlightStreamReader], int]:
21 | try:
22 | flight_info = client.get_flight_info(
23 | flight.FlightDescriptor.for_command(
24 | json_format.MessageToJson(self.query_descriptor) # type: ignore
25 | ),
26 | call_options,
27 | )
28 | logger.info("Fetching data...")
29 |
30 | if flight_info.total_records == 0:
31 | logger.info("Query returns no data")
32 | return None, 0
33 | logger.debug("Ticket: %s", flight_info.endpoints[0].ticket)
34 |
35 | # Retrieve the result set as flight stream reader
36 | reader = client.do_get(
37 | flight_info.endpoints[0].ticket, call_options
38 | )
39 | logger.info("Starting exporting...")
40 | return reader, flight_info.total_records
41 |
42 | except Exception:
43 | logger.error(
44 | "There was an error trying to get the data from the endpoint"
45 | )
46 | raise
47 |
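For orientation, a hypothetical sketch of how `Query` is wired together with a `Session` (defined below in `core/session.py`); in practice `ArizeExportClient` does this internally:

```
from arize.exporter import publicexporter_pb2 as exp_pb2
from arize.exporter.core.query import Query
from arize.exporter.core.session import Session

session = Session(
    api_key="YOUR_API_KEY",
    arize_profile="default",
    arize_config_path="~/.arize",
    host="flight.arize.com",
    port=443,
    scheme="grpc+tls",
)
client = session.connect()  # also populates session.call_options

# Describe what to export; real queries also set start_time/end_time, etc.
descriptor = exp_pb2.RecordQueryDescriptor(
    space_id="YOUR_SPACE_ID",
    model_id="YOUR_MODEL_ID",
    environment=exp_pb2.RecordQueryDescriptor.PRODUCTION,
)
reader, total_records = Query(descriptor).execute(client, session.call_options)
```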
--------------------------------------------------------------------------------
/src/arize/exporter/core/session.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import uuid
4 | from dataclasses import dataclass, field
5 | from typing import Optional
6 |
7 | from pyarrow import flight
8 |
9 | from arize.utils.logging import logger
10 | from arize.utils.utils import get_python_version
11 | from arize.version import __version__
12 |
13 | from ..utils.constants import (
14 | ARIZE_API_KEY,
15 | DEFAULT_ARIZE_API_KEY_CONFIG_KEY,
16 | DEFAULT_PACKAGE_NAME,
17 | PROFILE_FILE_NAME,
18 | )
19 | from ..utils.errors import InvalidConfigFileError, InvalidSessionError
20 |
21 |
22 | @dataclass
23 | class Session:
24 | api_key: Optional[str]
25 | arize_profile: str
26 | arize_config_path: str
27 | host: str
28 | port: int
29 | scheme: str
30 | session_name: str = field(init=False)
31 | call_options: flight.FlightCallOptions = field(init=False)
32 |
33 | def __post_init__(self):
34 | self.session_name = f"python-sdk-{DEFAULT_PACKAGE_NAME}-{uuid.uuid4()}"
35 | logger.info(f"Creating named session as '{self.session_name}'.")
36 |         # If api_key is not passed, try reading it from an environment variable;
37 |         # if it is not set there either, read it from the config file
38 | self.api_key = self.api_key or ARIZE_API_KEY or self._read_config()
39 | if self.api_key is None:
40 | logger.error(InvalidSessionError.error_message())
41 | raise InvalidSessionError
42 |
43 | if self.host.startswith(("http://", "https://")):
44 | scheme, host = self.host.split("://")
45 | logger.warning(
46 | f"The host '{self.host}' starts with '{scheme}' and it will be stripped, "
47 | "remove it or consider using the `scheme` parameter instead."
48 | )
49 | self.host = host
50 |
51 | logger.debug(
52 | f"Created session with Arize API Token '{self.api_key}' at '{self.host}':'{self.port}'"
53 | )
54 | self._set_headers()
55 |
56 | def _read_config(self) -> Optional[str]:
57 | config_parser = Session._get_config_parser()
58 | file_path = os.path.join(self.arize_config_path, PROFILE_FILE_NAME)
59 | logger.debug(
60 | f"No provided connection details. Looking up session values from '{self.arize_profile}' in "
61 | f"'{file_path}'."
62 | )
63 | try:
64 | config_parser.read(file_path)
65 | return config_parser.get(
66 | self.arize_profile, DEFAULT_ARIZE_API_KEY_CONFIG_KEY
67 | )
68 | except configparser.NoSectionError as err:
69 | # Missing api key error is raised in the __post_init__ method
70 | logger.warning(
71 | f"Can't extract API key from config file. {err.message}"
72 | )
73 | return None
74 | except Exception as err:
75 | logger.error(InvalidConfigFileError.error_message())
76 | raise InvalidConfigFileError from err
77 |
78 | @staticmethod
79 | def _get_config_parser() -> configparser.ConfigParser:
80 | return configparser.ConfigParser()
81 |
82 | def connect(self) -> flight.FlightClient:
83 | """
84 | Connects to Arize Flight server public endpoint with the
85 | provided api key.
86 | """
87 | try:
88 | disable_cert = self.host.lower() == "localhost"
89 | client = flight.FlightClient(
90 | location=f"{self.scheme}://{self.host}:{self.port}",
91 | disable_server_verification=disable_cert,
92 | )
93 | self.call_options = flight.FlightCallOptions(headers=self._headers)
94 | return client
95 | except Exception:
96 | logger.error(
97 | "There was an error trying to connect to the Arize Flight Endpoint"
98 | )
99 | raise
100 |
101 | def _set_headers(self) -> None:
102 | self._headers = [
103 | (b"origin", b"arize-python-exporter"),
104 | (b"auth-token-bin", f"{self.api_key}".encode()),
105 | (b"sdk-language", b"python"),
106 | (b"language-version", get_python_version().encode("utf-8")),
107 | (b"sdk-version", __version__.encode("utf-8")),
108 | (b"arize-interface", b"flight"),
109 | ]
110 |
--------------------------------------------------------------------------------
/src/arize/exporter/publicexporter_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: publicexporter.proto
4 | """Generated protocol buffer code."""
5 | from google.protobuf import descriptor as _descriptor
6 | from google.protobuf import descriptor_pool as _descriptor_pool
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | # @@protoc_insertion_point(imports)
11 |
12 | _sym_db = _symbol_database.Default()
13 |
14 |
15 | from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
16 |
17 |
18 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14publicexporter.proto\x12\x0epublicexporter\x1a\x1fgoogle/protobuf/timestamp.proto\"\xfa\x03\n\x15RecordQueryDescriptor\x12\x10\n\x08space_id\x18\x01 \x01(\t\x12\x10\n\x08model_id\x18\x02 \x01(\t\x12\x46\n\x0b\x65nvironment\x18\x03 \x01(\x0e\x32\x31.publicexporter.RecordQueryDescriptor.Environment\x12\x15\n\rmodel_version\x18\x04 \x01(\t\x12\x10\n\x08\x62\x61tch_id\x18\x05 \x01(\t\x12.\n\nstart_time\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x17\n\x0finclude_actuals\x18\x08 \x01(\x08\x12\x19\n\x11\x66ilter_expression\x18\t \x01(\t\x12H\n\x18similarity_search_params\x18\n \x01(\x0b\x32&.publicexporter.SimilaritySearchParams\x12\x19\n\x11projected_columns\x18\x0b \x03(\t\"U\n\x0b\x45nvironment\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08TRAINING\x10\x01\x12\x0e\n\nVALIDATION\x10\x02\x12\x0e\n\nPRODUCTION\x10\x03\x12\x0b\n\x07TRACING\x10\x04\"\x9b\x02\n\x16SimilaritySearchParams\x12\x44\n\nreferences\x18\x01 \x03(\x0b\x32\x30.publicexporter.SimilaritySearchParams.Reference\x12\x1a\n\x12search_column_name\x18\x02 \x01(\t\x12\x11\n\tthreshold\x18\x03 \x01(\x01\x1a\x8b\x01\n\tReference\x12\x15\n\rprediction_id\x18\x01 \x01(\t\x12\x1d\n\x15reference_column_name\x18\x02 \x01(\t\x12\x38\n\x14prediction_timestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0e\n\x06vector\x18\x04 \x03(\x01\x42GZEgithub.com/Arize-ai/arize/go/pkg/flightserver/protocol/publicexporterb\x06proto3')
19 |
20 |
21 |
22 | _RECORDQUERYDESCRIPTOR = DESCRIPTOR.message_types_by_name['RecordQueryDescriptor']
23 | _SIMILARITYSEARCHPARAMS = DESCRIPTOR.message_types_by_name['SimilaritySearchParams']
24 | _SIMILARITYSEARCHPARAMS_REFERENCE = _SIMILARITYSEARCHPARAMS.nested_types_by_name['Reference']
25 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT = _RECORDQUERYDESCRIPTOR.enum_types_by_name['Environment']
26 | RecordQueryDescriptor = _reflection.GeneratedProtocolMessageType('RecordQueryDescriptor', (_message.Message,), {
27 | 'DESCRIPTOR' : _RECORDQUERYDESCRIPTOR,
28 | '__module__' : 'publicexporter_pb2'
29 | # @@protoc_insertion_point(class_scope:publicexporter.RecordQueryDescriptor)
30 | })
31 | _sym_db.RegisterMessage(RecordQueryDescriptor)
32 |
33 | SimilaritySearchParams = _reflection.GeneratedProtocolMessageType('SimilaritySearchParams', (_message.Message,), {
34 |
35 | 'Reference' : _reflection.GeneratedProtocolMessageType('Reference', (_message.Message,), {
36 | 'DESCRIPTOR' : _SIMILARITYSEARCHPARAMS_REFERENCE,
37 | '__module__' : 'publicexporter_pb2'
38 | # @@protoc_insertion_point(class_scope:publicexporter.SimilaritySearchParams.Reference)
39 | })
40 | ,
41 | 'DESCRIPTOR' : _SIMILARITYSEARCHPARAMS,
42 | '__module__' : 'publicexporter_pb2'
43 | # @@protoc_insertion_point(class_scope:publicexporter.SimilaritySearchParams)
44 | })
45 | _sym_db.RegisterMessage(SimilaritySearchParams)
46 | _sym_db.RegisterMessage(SimilaritySearchParams.Reference)
47 |
48 | if _descriptor._USE_C_DESCRIPTORS == False:
49 |
50 | DESCRIPTOR._options = None
51 | DESCRIPTOR._serialized_options = b'ZEgithub.com/Arize-ai/arize/go/pkg/flightserver/protocol/publicexporter'
52 | _RECORDQUERYDESCRIPTOR._serialized_start=74
53 | _RECORDQUERYDESCRIPTOR._serialized_end=580
54 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT._serialized_start=495
55 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT._serialized_end=580
56 | _SIMILARITYSEARCHPARAMS._serialized_start=583
57 | _SIMILARITYSEARCHPARAMS._serialized_end=866
58 | _SIMILARITYSEARCHPARAMS_REFERENCE._serialized_start=727
59 | _SIMILARITYSEARCHPARAMS_REFERENCE._serialized_end=866
60 | # @@protoc_insertion_point(module_scope)
61 |
--------------------------------------------------------------------------------
/src/arize/exporter/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/exporter/utils/__init__.py
--------------------------------------------------------------------------------
/src/arize/exporter/utils/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | """Environmental configuration"""
5 |
6 | # Override ARIZE Default Profile when reading from the config-file in a session.
7 | ARIZE_PROFILE = os.getenv("ARIZE_PROFILE")
8 |
9 | # Override ARIZE API Token when creating a session.
10 | ARIZE_API_KEY = os.getenv("ARIZE_API_KEY")
11 |
12 |
13 | """Internal Use"""
14 |
15 | # Default API endpoint when not provided through env variable or profile
16 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com"
17 | DEFAULT_ARIZE_FLIGHT_PORT = 443
18 |
19 | # Name of the current package.
20 | DEFAULT_PACKAGE_NAME = "arize_python_export_client"
21 |
22 | # Default config keys for the Arize config file. Created via the CLI.
23 | DEFAULT_ARIZE_API_KEY_CONFIG_KEY = "api_key"
24 |
25 | # Default headers to trace and help identify requests. For debugging.
26 | DEFAULT_ARIZE_SESSION_ID = "x-arize-session-id" # Generally the session name.
27 | DEFAULT_ARIZE_TRACE_ID = "x-arize-trace-id"
28 | DEFAULT_PACKAGE_VERSION = "x-package-version"
29 |
30 | # File name for profile configuration.
31 | PROFILE_FILE_NAME = "profiles.ini"
32 |
33 | # Default profile to be used.
34 | DEFAULT_PROFILE_NAME = "default"
35 |
36 | # Default path where any configuration files are written.
37 | DEFAULT_CONFIG_PATH = os.path.join(str(Path.home()), ".arize")
38 |
39 | # Default initial wait time for retries in seconds.
40 | DEFAULT_RETRY_INITIAL_WAIT_TIME = 0.25
41 |
42 | # Default maximum wait time for retries in seconds.
43 | DEFAULT_RETRY_MAX_WAIT_TIME = 10.0
44 |
45 | # Default to use grpc + tls scheme.
46 | DEFAULT_TRANSPORT_SCHEME = "grpc+tls"
47 |
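Putting these constants together, the config file read by `Session._read_config` would look roughly like this (hypothetical contents; the section name is the profile, and the key matches `DEFAULT_ARIZE_API_KEY_CONFIG_KEY`):

```
# ~/.arize/profiles.ini
[default]
api_key = YOUR_ARIZE_API_KEY
```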
--------------------------------------------------------------------------------
/src/arize/exporter/utils/errors.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class ExportingError(Exception, ABC):
5 | def __str__(self) -> str:
6 | return self.error_message()
7 |
8 | @abstractmethod
9 | def error_message(self) -> str:
10 | pass
11 |
12 |
13 | class InvalidSessionError(ExportingError):
14 | @staticmethod
15 | def error_message() -> str:
16 | return (
17 | "Credentials not provided or invalid. Please pass in the correct api_key when "
18 | "initiating a new ArizeExportClient. Alternatively, you can set up credentials "
19 | "in a profile or as an environment variable"
20 | )
21 |
22 |
23 | class InvalidConfigFileError(Exception):
24 | @staticmethod
25 | def error_message() -> str:
26 | return "WENDY TO WRITE APPROPRIATE ERROR MESSAGE"
27 |
--------------------------------------------------------------------------------
/src/arize/exporter/utils/tracing.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 | from typing import List
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from arize.utils.logging import logger
9 |
10 | try:
11 | oic_spec = importlib.util.find_spec("openinference.semconv")
12 | except Exception:
13 | oic_spec = None
14 |
15 | if oic_spec is not None:
16 | from arize.pandas.tracing.columns import (
17 | SPAN_ATTRIBUTES_EMBEDDING_EMBEDDINGS_COL,
18 | SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL,
19 | SPAN_ATTRIBUTES_LLM_INVOCATION_PARAMETERS_COL,
20 | SPAN_ATTRIBUTES_LLM_OUTPUT_MESSAGES_COL,
21 | SPAN_ATTRIBUTES_LLM_PROMPT_TEMPLATE_VARIABLES_COL,
22 | SPAN_ATTRIBUTES_LLM_TOOLS_COL,
23 | SPAN_ATTRIBUTES_METADATA,
24 | SPAN_ATTRIBUTES_RETRIEVAL_DOCUMENTS_COL,
25 | SPAN_ATTRIBUTES_TOOL_PARAMETERS_COL,
26 | SPAN_END_TIME_COL,
27 | SPAN_START_TIME_COL,
28 | )
29 |
30 |
31 | # Transforms OTel tracing data into types and values that are more ergonomic for users
32 | # who need to interact with the data in Python. This class is intended for use by Arize,
33 | # not by end users. Any errors encountered are unexpected, since Arize also controls the
34 | # data types returned from the platform, but the resulting error messages clarify how each
35 | # error affects the data; an error should not prevent a user from continuing to use the data.
36 | class OtelTracingDataTransformer:
37 | def transform(self, df: pd.DataFrame) -> pd.DataFrame:
38 | errors: List[str] = []
39 |
40 |         # Convert columns of lists of JSON-serializable strings to lists of dictionaries
41 |         # for more convenient data processing in Python
42 | list_of_json_string_column_names: List[str] = [
43 | col.name
44 | for col in [
45 | SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL,
46 | SPAN_ATTRIBUTES_LLM_OUTPUT_MESSAGES_COL,
47 | SPAN_ATTRIBUTES_EMBEDDING_EMBEDDINGS_COL,
48 | SPAN_ATTRIBUTES_RETRIEVAL_DOCUMENTS_COL,
49 | SPAN_ATTRIBUTES_LLM_TOOLS_COL,
50 | ]
51 | if col.name in df.columns
52 | ]
53 | for col_name in list_of_json_string_column_names:
54 | try:
55 | df[col_name] = df[col_name].apply(
56 | self._transform_value_to_list_of_dict
57 | )
58 | except Exception as e:
59 | errors.append(
60 | f"Unable to transform json string data to a Python dict in column '{col_name}'; "
61 | f"May encounter issues when importing data back into Arize; Error: {e}"
62 | )
63 |
64 | json_string_column_names: List[str] = [
65 | col.name
66 | for col in [
67 | SPAN_ATTRIBUTES_LLM_PROMPT_TEMPLATE_VARIABLES_COL,
68 | SPAN_ATTRIBUTES_METADATA,
69 | ]
70 | if col.name in df.columns
71 | ]
72 | for col_name in json_string_column_names:
73 | try:
74 | df[col_name] = df[col_name].apply(self._transform_json_to_dict)
75 | except Exception as e:
76 | errors.append(
77 | f"Unable to transform json string data to a Python dict in column '{col_name}'; "
78 | f"May encounter issues when importing data back into Arize; Error: {e}"
79 | )
80 |
81 | # Clean json string columns since empty strings are equivalent here to None but are not valid json
82 | dirty_string_column_names: List[str] = [
83 | col.name
84 | for col in [
85 | SPAN_ATTRIBUTES_LLM_INVOCATION_PARAMETERS_COL,
86 | SPAN_ATTRIBUTES_TOOL_PARAMETERS_COL,
87 | ]
88 | if col.name in df.columns
89 | ]
90 | for col_name in dirty_string_column_names:
91 | df[col_name] = df[col_name].apply(self._clean_json_string)
92 |
93 | # Convert timestamp columns to datetime objects
94 | timestamp_column_names: List[str] = [
95 | col.name
96 | for col in [
97 | SPAN_START_TIME_COL,
98 | SPAN_END_TIME_COL,
99 | ]
100 | if col.name in df.columns
101 | ]
102 | for col_name in timestamp_column_names:
103 | df[col_name] = df[col_name].apply(
104 | self._convert_timestamp_to_datetime
105 | )
106 |
107 | for err in errors:
108 | logger.warning(err)
109 |
110 | return df
111 |
112 | def _transform_value_to_list_of_dict(self, value):
113 | if value is None:
114 | return None
115 |
116 | if isinstance(value, (list, np.ndarray)):
117 | return [
118 | self._deserialize_json_string_to_dict(i)
119 | for i in value
120 | if self._is_non_empty_string(i)
121 | ]
122 | elif self._is_non_empty_string(value):
123 | return [self._deserialize_json_string_to_dict(value)]
124 |
125 | def _transform_json_to_dict(self, value):
126 | if value is None:
127 | return None
128 |
129 | if self._is_non_empty_string(value):
130 | return self._deserialize_json_string_to_dict(value)
131 |
132 | if isinstance(value, str) and value == "":
133 | # transform empty string to None
134 | return None
135 |
136 | def _is_non_empty_string(self, value):
137 | return isinstance(value, str) and value != ""
138 |
139 | def _deserialize_json_string_to_dict(self, value: str):
140 | try:
141 | return json.loads(value)
142 | except json.JSONDecodeError as e:
143 | raise ValueError(f"Invalid JSON string: {value}") from e
144 |
145 | def _clean_json_string(self, value):
146 | return value if self._is_non_empty_string(value) else None
147 |
148 | def _convert_timestamp_to_datetime(self, value):
149 | return (
150 | pd.Timestamp(value, unit="ns")
151 | if value and isinstance(value, (int, float, np.int64))
152 | else value
153 | )
154 |
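A minimal sketch of the transformer on a one-row frame, using the column-name constants imported above so the example stays aligned with the OpenInference schema (assumes `openinference.semconv` is installed):

```
import pandas as pd

from arize.exporter.utils.tracing import OtelTracingDataTransformer
from arize.pandas.tracing.columns import SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL

df = pd.DataFrame(
    {
        SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL.name: [
            ['{"role": "user", "content": "hello"}']  # JSON strings, as exported
        ]
    }
)
df = OtelTracingDataTransformer().transform(df)
# The column now holds [{'role': 'user', 'content': 'hello'}] instead of raw JSON strings.
```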
--------------------------------------------------------------------------------
/src/arize/exporter/utils/validation.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 |
4 | class Validator:
5 | @staticmethod
6 | def validate_input_type(input, input_name: str, input_type: type) -> None:
7 | if not isinstance(input, input_type) and input is not None:
8 | raise TypeError(
9 | f"{input_name} {input} is type {type(input)}, but must be a {input_type.__name__}"
10 | )
11 |
12 | @staticmethod
13 | def validate_input_value(input, input_name: str, choices: tuple) -> None:
14 | if input not in choices:
15 | raise ValueError(
16 | f"{input_name} is {input}, but must be one of {', '.join(choices)}"
17 | )
18 |
19 | @staticmethod
20 |     def validate_start_end_time(start_time: datetime, end_time: datetime) -> None:
21 | if start_time >= end_time:
22 | raise ValueError("start_time must be before end_time")
23 |
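For illustration, the three checks in action (a sketch; these are presumably invoked by the export client before issuing a query):

```
from datetime import datetime

from arize.exporter.utils.validation import Validator

Validator.validate_input_type("my-model", "model_id", str)             # passes
Validator.validate_input_value("prod", "env", ("prod", "staging"))     # passes
Validator.validate_start_end_time(datetime(2023, 4, 10), datetime(2023, 4, 15))  # passes
Validator.validate_input_type(123, "model_id", str)                    # raises TypeError
```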
--------------------------------------------------------------------------------
/src/arize/pandas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_generator import EmbeddingGenerator
2 | from .usecases import UseCases
3 |
4 | __all__ = ["EmbeddingGenerator", "UseCases"]
5 |
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/auto_generator.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from . import constants
4 | from .base_generators import BaseEmbeddingGenerator
5 | from .constants import (
6 | CV_PRETRAINED_MODELS,
7 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
8 | DEFAULT_CV_OBJECT_DETECTION_MODEL,
9 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
10 | DEFAULT_NLP_SUMMARIZATION_MODEL,
11 | DEFAULT_TABULAR_MODEL,
12 | NLP_PRETRAINED_MODELS,
13 | )
14 | from .cv_generators import (
15 | EmbeddingGeneratorForCVImageClassification,
16 | EmbeddingGeneratorForCVObjectDetection,
17 | )
18 | from .nlp_generators import (
19 | EmbeddingGeneratorForNLPSequenceClassification,
20 | EmbeddingGeneratorForNLPSummarization,
21 | )
22 | from .tabular_generators import EmbeddingGeneratorForTabularFeatures
23 | from .usecases import UseCases
24 |
25 |
26 | class EmbeddingGenerator:
27 | def __init__(self, **kwargs: str):
28 | raise OSError(
29 | f"{self.__class__.__name__} is designed to be instantiated using the "
30 | f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
31 | )
32 |
33 | @staticmethod
34 | def from_use_case(use_case: str, **kwargs: str) -> BaseEmbeddingGenerator:
35 | if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
36 | return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
37 | elif use_case == UseCases.NLP.SUMMARIZATION:
38 | return EmbeddingGeneratorForNLPSummarization(**kwargs)
39 | elif use_case == UseCases.CV.IMAGE_CLASSIFICATION:
40 | return EmbeddingGeneratorForCVImageClassification(**kwargs)
41 | elif use_case == UseCases.CV.OBJECT_DETECTION:
42 | return EmbeddingGeneratorForCVObjectDetection(**kwargs)
43 | elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
44 | return EmbeddingGeneratorForTabularFeatures(**kwargs)
45 | else:
46 | raise ValueError(f"Invalid use case {use_case}")
47 |
48 | @classmethod
49 | def list_default_models(cls) -> pd.DataFrame:
50 | df = pd.DataFrame(
51 | {
52 | "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
53 | "Usecase": [
54 | UseCases.NLP.SEQUENCE_CLASSIFICATION.name,
55 | UseCases.NLP.SUMMARIZATION.name,
56 | UseCases.CV.IMAGE_CLASSIFICATION.name,
57 | UseCases.CV.OBJECT_DETECTION.name,
58 | UseCases.STRUCTURED.TABULAR_EMBEDDINGS.name,
59 | ],
60 | "Model Name": [
61 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
62 | DEFAULT_NLP_SUMMARIZATION_MODEL,
63 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
64 | DEFAULT_CV_OBJECT_DETECTION_MODEL,
65 | DEFAULT_TABULAR_MODEL,
66 | ],
67 | }
68 | )
69 | df.sort_values(
70 | by=[col for col in df.columns], ascending=True, inplace=True
71 | )
72 | return df.reset_index(drop=True)
73 |
74 | @classmethod
75 | def list_pretrained_models(cls) -> pd.DataFrame:
76 | data = {
77 | "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
78 | + ["CV" for _ in CV_PRETRAINED_MODELS],
79 | "Architecture": [
80 | cls.__parse_model_arch(model)
81 | for model in NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS
82 | ],
83 | "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
84 | }
85 | df = pd.DataFrame(data)
86 | df.sort_values(
87 | by=[col for col in df.columns], ascending=True, inplace=True
88 | )
89 | return df.reset_index(drop=True)
90 |
91 | @staticmethod
92 | def __parse_model_arch(model_name: str) -> str:
93 | if constants.GPT.lower() in model_name.lower():
94 | return constants.GPT
95 | elif constants.BERT.lower() in model_name.lower():
96 | return constants.BERT
97 | elif constants.VIT.lower() in model_name.lower():
98 | return constants.VIT
99 | else:
100 | raise ValueError("Invalid model_name, unknown architecture.")
101 |
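A short sketch of the factory pattern above (requires the `arize[AutoEmbeddings]` extra, per `IMPORT_ERROR_MESSAGE` in `constants.py`):

```
from arize.pandas.embeddings import EmbeddingGenerator, UseCases

# Browse the bundled defaults and supported pretrained checkpoints
print(EmbeddingGenerator.list_default_models())
print(EmbeddingGenerator.list_pretrained_models())

# Instantiation goes through the factory; calling EmbeddingGenerator(...) raises
generator = EmbeddingGenerator.from_use_case(
    UseCases.NLP.SEQUENCE_CLASSIFICATION,
    model_name="distilbert-base-uncased",
)
```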
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased"
2 | DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased"
3 | DEFAULT_TABULAR_MODEL = "distilbert-base-uncased"
4 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL = "google/vit-base-patch32-224-in21k"
5 | DEFAULT_CV_OBJECT_DETECTION_MODEL = "facebook/detr-resnet-101"
6 | NLP_PRETRAINED_MODELS = [
7 | "bert-base-cased",
8 | "bert-base-uncased",
9 | "bert-large-cased",
10 | "bert-large-uncased",
11 | "distilbert-base-cased",
12 | "distilbert-base-uncased",
13 | "xlm-roberta-base",
14 | "xlm-roberta-large",
15 | ]
16 |
17 | CV_PRETRAINED_MODELS = [
18 | "google/vit-base-patch16-224-in21k",
19 | "google/vit-base-patch16-384",
20 | "google/vit-base-patch32-224-in21k",
21 | "google/vit-base-patch32-384",
22 | "google/vit-large-patch16-224-in21k",
23 | "google/vit-large-patch16-384",
24 | "google/vit-large-patch32-224-in21k",
25 | "google/vit-large-patch32-384",
26 | ]
27 | IMPORT_ERROR_MESSAGE = (
28 | "To enable embedding generation, the arize module must be installed with "
29 | "extra dependencies. Run: pip install 'arize[AutoEmbeddings]'."
30 | )
31 |
32 | GPT = "GPT"
33 | BERT = "BERT"
34 | VIT = "ViT"
35 |
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/cv_generators.py:
--------------------------------------------------------------------------------
1 | from .base_generators import CVEmbeddingGenerator
2 | from .constants import (
3 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
4 | DEFAULT_CV_OBJECT_DETECTION_MODEL,
5 | )
6 | from .usecases import UseCases
7 |
8 |
9 | class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
10 | def __init__(
11 | self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs
12 | ):
13 | super().__init__(
14 | use_case=UseCases.CV.IMAGE_CLASSIFICATION,
15 | model_name=model_name,
16 | **kwargs,
17 | )
18 |
19 |
20 | class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
21 | def __init__(
22 | self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs
23 | ):
24 | super().__init__(
25 | use_case=UseCases.CV.OBJECT_DETECTION,
26 | model_name=model_name,
27 | **kwargs,
28 | )
29 |
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/errors.py:
--------------------------------------------------------------------------------
1 | class InvalidIndexError(Exception):
2 | def __repr__(self) -> str:
3 | return "Invalid_Index_Error"
4 |
5 | def __str__(self) -> str:
6 | return self.error_message()
7 |
8 | def __init__(self, field_name: str) -> None:
9 | self.field_name = field_name
10 |
11 | def error_message(self) -> str:
12 | if self.field_name == "DataFrame":
13 | return (
14 | f"The index of the {self.field_name} is invalid; "
15 | f"reset the index by using df.reset_index(drop=True, inplace=True)"
16 | )
17 | else:
18 | return (
19 | f"The index of the Series given by the column '{self.field_name}' is invalid; "
20 | f"reset the index by using df.reset_index(drop=True, inplace=True)"
21 | )
22 |
23 |
24 | class HuggingFaceRepositoryNotFound(Exception):
25 | def __repr__(self) -> str:
26 | return "HuggingFace_Repository_Not_Found_Error"
27 |
28 | def __str__(self) -> str:
29 | return self.error_message()
30 |
31 | def __init__(self, model_name: str) -> None:
32 | self.model_name = model_name
33 |
34 | def error_message(self) -> str:
35 | return (
36 | f"The given model name '{self.model_name}' is not a valid model identifier listed on "
37 | "'https://huggingface.co/models'. "
38 | "If this is a private repository, log in with `huggingface-cli login` or importing "
39 | "`login` from `huggingface_hub` if you are using a notebook. "
40 | "Learn more in https://huggingface.co/docs/huggingface_hub/quick-start#login"
41 | )
42 |
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/nlp_generators.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, cast
3 |
4 | import pandas as pd
5 |
6 | from arize.utils.logging import logger
7 |
8 | from .base_generators import NLPEmbeddingGenerator
9 | from .constants import (
10 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
11 | DEFAULT_NLP_SUMMARIZATION_MODEL,
12 | IMPORT_ERROR_MESSAGE,
13 | )
14 | from .usecases import UseCases
15 |
16 | try:
17 | from datasets import Dataset
18 | except ImportError:
19 | raise ImportError(IMPORT_ERROR_MESSAGE) from None
20 |
21 |
22 | class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
23 | def __init__(
24 | self,
25 | model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
26 | **kwargs,
27 | ):
28 | super().__init__(
29 | use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
30 | model_name=model_name,
31 | **kwargs,
32 | )
33 |
34 | def generate_embeddings(
35 | self,
36 | text_col: pd.Series,
37 | class_label_col: Optional[pd.Series] = None,
38 | ) -> pd.Series:
39 | """
40 | Obtain embedding vectors from your text data using pre-trained large language models.
41 |
42 | :param text_col: a pandas Series containing the different pieces of text.
43 |         :param class_label_col: if this column is passed, the sentence "The classification label
44 |             is <label>." will be prepended to the text in the `text_col`.
45 | :return: a pandas Series containing the embedding vectors.
46 | """
47 | if not isinstance(text_col, pd.Series):
48 | raise TypeError("text_col must be a pandas Series")
49 |
50 | self.check_invalid_index(field=text_col)
51 |
52 | if class_label_col is not None:
53 | if not isinstance(class_label_col, pd.Series):
54 | raise TypeError("class_label_col must be a pandas Series")
55 | df = pd.concat(
56 | {"text": text_col, "class_label": class_label_col}, axis=1
57 | )
58 | prepared_text_col = df.apply(
59 | lambda row: f" The classification label is {row['class_label']}. {row['text']}",
60 | axis=1,
61 | )
62 | ds = Dataset.from_dict({"text": prepared_text_col})
63 | else:
64 | ds = Dataset.from_dict({"text": text_col})
65 |
66 | ds.set_transform(partial(self.tokenize, text_feat_name="text"))
67 | logger.info("Generating embedding vectors")
68 | ds = ds.map(
69 | lambda batch: self._get_embedding_vector(batch, "cls_token"),
70 | batched=True,
71 | batch_size=self.batch_size,
72 | )
73 | return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
74 |
75 |
76 | class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
77 | def __init__(
78 | self, model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL, **kwargs
79 | ):
80 | super().__init__(
81 | use_case=UseCases.NLP.SUMMARIZATION,
82 | model_name=model_name,
83 | **kwargs,
84 | )
85 |
86 | def generate_embeddings(
87 | self,
88 | text_col: pd.Series,
89 | ) -> pd.Series:
90 | """
91 | Obtain embedding vectors from your text data using pre-trained large language models.
92 |
93 | :param text_col: a pandas Series containing the different pieces of text.
94 | :return: a pandas Series containing the embedding vectors.
95 | """
96 | if not isinstance(text_col, pd.Series):
97 | raise TypeError("text_col must be a pandas Series")
98 | self.check_invalid_index(field=text_col)
99 |
100 | ds = Dataset.from_dict({"text": text_col})
101 |
102 | ds.set_transform(partial(self.tokenize, text_feat_name="text"))
103 | logger.info("Generating embedding vectors")
104 | ds = ds.map(
105 | lambda batch: self._get_embedding_vector(batch, "cls_token"),
106 | batched=True,
107 | batch_size=self.batch_size,
108 | )
109 | return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
110 |
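End to end, generating embeddings for a text column might look like this (a sketch; the model weights are downloaded from the Hugging Face Hub on first use):

```
import pandas as pd

from arize.pandas.embeddings import EmbeddingGenerator, UseCases

df = pd.DataFrame({"text": ["good product", "terrible support"], "label": ["pos", "neg"]})
generator = EmbeddingGenerator.from_use_case(UseCases.NLP.SEQUENCE_CLASSIFICATION)
df["text_vector"] = generator.generate_embeddings(
    text_col=df["text"],
    class_label_col=df["label"],  # optional; prepends the label sentence to each text
)
```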
--------------------------------------------------------------------------------
/src/arize/pandas/embeddings/usecases.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum, auto, unique
3 |
4 |
5 | @unique
6 | class NLPUseCases(Enum):
7 | SEQUENCE_CLASSIFICATION = auto()
8 | SUMMARIZATION = auto()
9 |
10 |
11 | @unique
12 | class CVUseCases(Enum):
13 | IMAGE_CLASSIFICATION = auto()
14 | OBJECT_DETECTION = auto()
15 |
16 |
17 | @unique
18 | class TabularUsecases(Enum):
19 | TABULAR_EMBEDDINGS = auto()
20 |
21 |
22 | @dataclass
23 | class UseCases:
24 | NLP = NLPUseCases
25 | CV = CVUseCases
26 | STRUCTURED = TabularUsecases
27 |
--------------------------------------------------------------------------------
/src/arize/pandas/etl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/etl/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/etl/errors.py:
--------------------------------------------------------------------------------
1 | from arize.utils.logging import log_a_list
2 | from arize.utils.types import TypedColumns
3 |
4 |
5 | class ColumnCastingError(Exception):
6 | def __str__(self) -> str:
7 | return self.error_message()
8 |
9 | def __init__(
10 | self,
11 | error_msg: str,
12 | attempted_columns: str,
13 | attempted_type: TypedColumns,
14 | ) -> None:
15 | self.error_msg = error_msg
16 | self.attempted_casting_columns = attempted_columns
17 | self.attempted_casting_type = attempted_type
18 |
19 | def error_message(self) -> str:
20 | return (
21 | f"Failed to cast to type {self.attempted_casting_type} "
22 | f"for columns: {log_a_list(self.attempted_casting_columns, 'and')}. "
23 | f"Error: {self.error_msg}"
24 | )
25 |
26 |
27 | class InvalidTypedColumnsError(Exception):
28 | def __str__(self) -> str:
29 | return self.error_message()
30 |
31 | def __init__(self, field_name: str, reason: str) -> None:
32 | self.field_name = field_name
33 | self.reason = reason
34 |
35 | def error_message(self) -> str:
36 | return f"The {self.field_name} TypedColumns object {self.reason}."
37 |
38 |
39 | class InvalidSchemaFieldTypeError(Exception):
40 | def __str__(self) -> str:
41 | return self.error_message()
42 |
43 | def __init__(self, msg: str) -> None:
44 | self.msg = msg
45 |
46 | def error_message(self) -> str:
47 | return self.msg
48 |
--------------------------------------------------------------------------------
/src/arize/pandas/generative/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/generative/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/generative/llm_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf_metrics import bleu, google_bleu, meteor, rouge, sacre_bleu
2 |
3 | __all__ = [
4 | "bleu",
5 | "sacre_bleu",
6 | "meteor",
7 | "rouge",
8 | "google_bleu",
9 | ]
10 |
--------------------------------------------------------------------------------
/src/arize/pandas/generative/llm_evaluation/constants.py:
--------------------------------------------------------------------------------
1 | IMPORT_ERROR_MESSAGE = (
2 | "To enable evaluation of language models, the arize module must be installed with "
3 | "extra dependencies. Run: pip install 'arize[NLP_Metrics]'."
4 | )
5 |
--------------------------------------------------------------------------------
/src/arize/pandas/generative/nlp_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf_metrics import bleu, google_bleu, meteor, rouge, sacre_bleu
2 |
3 | __all__ = [
4 | "bleu",
5 | "google_bleu",
6 | "meteor",
7 | "rouge",
8 | "sacre_bleu",
9 | ]
10 |
--------------------------------------------------------------------------------
/src/arize/pandas/proto/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/proto/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/surrogate_explainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/surrogate_explainer/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/surrogate_explainer/mimic.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 | from dataclasses import replace
4 | from typing import Callable, Tuple
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from interpret_community.mimic.mimic_explainer import (
9 | LGBMExplainableModel,
10 | MimicExplainer,
11 | )
12 | from sklearn.preprocessing import LabelEncoder
13 |
14 | from arize.pandas.logger import Schema
15 | from arize.utils.types import (
16 | CATEGORICAL_MODEL_TYPES,
17 | NUMERIC_MODEL_TYPES,
18 | ModelTypes,
19 | )
20 |
21 |
22 | class Mimic:
23 | _testing = False
24 |
25 | def __init__(self, X: pd.DataFrame, model_func: Callable):
26 | self.explainer = MimicExplainer(
27 | model_func,
28 | X,
29 | LGBMExplainableModel,
30 | augment_data=False,
31 | is_function=True,
32 | )
33 |
34 | def explain(self, X: pd.DataFrame) -> pd.DataFrame:
35 | return pd.DataFrame(
36 | self.explainer.explain_local(X).local_importance_values,
37 | columns=X.columns,
38 | index=X.index,
39 | )
40 |
41 | @staticmethod
42 | def augment(
43 | df: pd.DataFrame, schema: Schema, model_type: ModelTypes
44 | ) -> Tuple[pd.DataFrame, Schema]:
45 | features = schema.feature_column_names
46 | X = df[features]
47 |
48 | if X.shape[1] == 0:
49 | return df, schema
50 |
51 | if model_type in CATEGORICAL_MODEL_TYPES:
52 | if not schema.prediction_score_column_name:
53 | raise ValueError(
54 | "To calculate surrogate explainability, "
55 | f"prediction_score_column_name must be specified in schema for {model_type}."
56 | )
57 |
58 | y_col_name = schema.prediction_score_column_name
59 | y = df[y_col_name].to_numpy()
60 |
61 | _min, _max = np.min(y), np.max(y)
62 | if not 0 <= _min <= 1 or not 0 <= _max <= 1:
63 | raise ValueError(
64 | f"To calculate surrogate explainability for {model_type}, "
65 | f"prediction scores must be between 0 and 1, but current "
66 | f"prediction scores range from {_min} to {_max}."
67 | )
68 |
69 | # model func requires 1 positional argument
70 | def model_func(_): # type: ignore
71 | return np.column_stack((1 - y, y))
72 |
73 | elif model_type in NUMERIC_MODEL_TYPES:
74 | y_col_name = schema.prediction_label_column_name
75 | if schema.prediction_score_column_name is not None:
76 | y_col_name = schema.prediction_score_column_name
77 | y = df[y_col_name].to_numpy()
78 |
79 | _finite_count = np.isfinite(y).sum()
80 | if len(y) - _finite_count:
81 | raise ValueError(
82 | f"To calculate surrogate explainability for {model_type}, "
83 | f"predictions must not contain NaN or infinite values, but "
84 | f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}."
85 | )
86 |
87 | # model func requires 1 positional argument
88 | def model_func(_): # type: ignore
89 | return y
90 |
91 | else:
92 | raise ValueError(
93 | "Surrogate explainability is not supported for the specified "
94 | f"model type {model_type}."
95 | )
96 |
97 | # Column name mapping between features and feature importance values.
98 | # This is used to augment the schema.
99 | col_map = {
100 | ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
101 | for ft in features
102 | }
103 | aug_schema = replace(schema, shap_values_column_names=col_map)
104 |
105 | # Limit the total number of "cells" to 20M, unless it results in too few or
106 | # too many rows. This is done to keep the runtime low. Records not sampled
107 | # have feature importance values set to 0.
108 | samp_size = min(
109 | len(X), min(100_000, max(1_000, 20_000_000 // X.shape[1]))
110 | )
111 |
112 | if samp_size < len(X):
113 | _mask = np.zeros(len(X), dtype=int)
114 | _mask[:samp_size] = 1
115 | np.random.shuffle(_mask)
116 | _mask = _mask.astype(bool)
117 | X = X[_mask]
118 | y = y[_mask]
119 |
120 | # Replace all pd.NA values with np.nan values
121 | for col in X.columns:
122 | if X[col].isna().any():
123 | X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)
124 |
125 | # Apply integer encoding to non-numeric columns.
126 |         # Currently the training and explaining datasets are the same, but
127 | # this can be changed in the future. The student model can be
128 | # fitted on a much larger dataset since it takes a lot less time.
129 | X = pd.concat(
130 | [
131 | X.select_dtypes(exclude=[object, "string"]),
132 | pd.DataFrame(
133 | {
134 | name: LabelEncoder().fit_transform(data)
135 | for name, data in X.select_dtypes(
136 | include=[object, "string"]
137 | ).items()
138 | },
139 | index=X.index,
140 | ),
141 | ],
142 | axis=1,
143 | )
144 |
145 | aug_df = pd.concat(
146 | [
147 | df,
148 | Mimic(X, model_func).explain(X).rename(col_map, axis=1),
149 | ],
150 | axis=1,
151 | )
152 |
153 |         # Fill nulls with zero so they're not counted as missing records by the server
154 | if not Mimic._testing:
155 | aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True)
156 |
157 | return (
158 | aug_df,
159 | aug_schema,
160 | )
161 |
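For orientation, a hypothetical invocation of `Mimic.augment` (normally called from the pandas logger when surrogate explainability is requested; `df` and `schema` stand for whatever you would pass to `arize.pandas.logger`, and `ModelTypes.SCORE_CATEGORICAL` is one of the categorical model types accepted above):

```
from arize.pandas.surrogate_explainer.mimic import Mimic
from arize.utils.types import ModelTypes

aug_df, aug_schema = Mimic.augment(
    df=df,                                    # your predictions dataframe (assumed defined)
    schema=schema,                            # Schema with feature_column_names set (assumed defined)
    model_type=ModelTypes.SCORE_CATEGORICAL,  # must be a categorical or numeric model type
)
# aug_df gains one randomly named column per feature holding local importance values,
# and aug_schema maps them via shap_values_column_names.
```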
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/constants.py:
--------------------------------------------------------------------------------
1 | # The default format used to parse datetime objects from strings
2 | DEFAULT_DATETIME_FMT = "%Y-%m-%dT%H:%M:%S.%f+00:00"
3 | # Minimum/Maximum number of characters for span/trace/parent ids in spans
4 | SPAN_ID_MIN_STR_LENGTH = 12
5 | SPAN_ID_MAX_STR_LENGTH = 128
6 | # Minimum/Maximum number of characters for span name
7 | SPAN_NAME_MIN_STR_LENGTH = 0
8 | SPAN_NAME_MAX_STR_LENGTH = 50
9 | # Minimum/Maximum number of characters for span status message
10 | SPAN_STATUS_MSG_MIN_STR_LENGTH = 0
11 | SPAN_STATUS_MSG_MAX_STR_LENGTH = 10_000
12 | # Maximum number of characters for span event name
13 | SPAN_EVENT_NAME_MAX_STR_LENGTH = 100
14 | # Maximum number of characters for span event attributes
15 | SPAN_EVENT_ATTRS_MAX_STR_LENGTH = 10_000
16 | # Maximum number of characters for span kind
17 | SPAN_KIND_MAX_STR_LENGTH = 100
18 | SPAN_EXCEPTION_TYPE_MAX_STR_LENGTH = 100
19 | SPAN_EXCEPTION_MESSAGE_MAX_STR_LENGTH = 100
20 | SPAN_EXCEPTION_STACK_TRACE_MAX_STR_LENGTH = 10_000
21 | SPAN_IO_VALUE_MAX_STR_LENGTH = 4_000_000
22 | SPAN_IO_MIME_TYPE_MAX_STR_LENGTH = 100
23 | SPAN_EMBEDDING_NAME_MAX_STR_LENGTH = 100
24 | SPAN_EMBEDDING_TEXT_MAX_STR_LENGTH = 4_000_000
25 | SPAN_LLM_MODEL_NAME_MAX_STR_LENGTH = 100
26 | SPAN_LLM_MESSAGE_ROLE_MAX_STR_LENGTH = 100
27 | SPAN_LLM_MESSAGE_CONTENT_MAX_STR_LENGTH = 4_000_000
28 | SPAN_LLM_TOOL_CALL_FUNCTION_NAME_MAX_STR_LENGTH = 500
29 | SPAN_LLM_PROMPT_TEMPLATE_MAX_STR_LENGTH = 4_000_000
30 | SPAN_LLM_PROMPT_TEMPLATE_VARIABLES_MAX_STR_LENGTH = 10_000
31 | SPAN_LLM_PROMPT_TEMPLATE_VERSION_MAX_STR_LENGTH = 100
32 | SPAN_TOOL_NAME_MAX_STR_LENGTH = 100
33 | SPAN_TOOL_DESCRIPTION_MAX_STR_LENGTH = 1_000
34 | SPAN_TOOL_PARAMETERS_MAX_STR_LENGTH = 1_000
35 | SPAN_RERANKER_QUERY_MAX_STR_LENGTH = 10_000
36 | SPAN_RERANKER_MODEL_NAME_MAX_STR_LENGTH = 100
37 | SPAN_DOCUMENT_ID_MAX_STR_LENGTH = 100
38 | SPAN_DOCUMENT_CONTENT_MAX_STR_LENGTH = 4_000_000
39 | JSON_STRING_MAX_STR_LENGTH = 4_000_000
40 | # Eval related constants
41 | EVAL_LABEL_MIN_STR_LENGTH = 1 # we do not accept empty strings
42 | EVAL_LABEL_MAX_STR_LENGTH = 100
43 | EVAL_EXPLANATION_MAX_STR_LENGTH = 10_000
44 |
45 | # Annotation related constants
46 | ANNOTATION_LABEL_MIN_STR_LENGTH = 1
47 | ANNOTATION_LABEL_MAX_STR_LENGTH = 100 # Max length for annotation label string
48 | ANNOTATION_UPDATED_BY_MAX_STR_LENGTH = 100
49 | ANNOTATION_NOTES_MAX_STR_LENGTH = (
50 | 10_000 # Max length for annotation note string
51 | )
52 |
53 | # Maximum number of characters for session/user ids in spans
54 | SESSION_ID_MAX_STR_LENGTH = 128
55 | USER_ID_MAX_STR_LENGTH = 128
56 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 |
3 |
4 | @unique
5 | class StatusCodes(Enum):
6 | UNSET = 0
7 | OK = 1
8 | ERROR = 2
9 |
10 | @classmethod
11 | def list_codes(cls):
12 | return [t.name for t in cls]
13 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime
3 | from typing import Any, Dict, Iterable, List, Optional, Union
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from arize.utils.logging import logger
9 |
10 | from .columns import SPAN_OPENINFERENCE_COLUMNS, SpanColumnDataType
11 |
12 |
13 | def convert_timestamps(df: pd.DataFrame, fmt: str = "") -> pd.DataFrame:
14 | time_cols = [
15 | col
16 | for col in SPAN_OPENINFERENCE_COLUMNS
17 | if col.data_type == SpanColumnDataType.TIMESTAMP
18 | ]
19 | for col in time_cols:
20 | df[col.name] = df[col.name].apply(lambda dt: _datetime_to_ns(dt, fmt))
21 | return df
22 |
23 |
24 | def _datetime_to_ns(dt: Union[str, datetime], fmt: str) -> int:
25 | if isinstance(dt, str):
26 | try:
27 | ts = int(datetime.timestamp(datetime.strptime(dt, fmt)) * 1e9)
28 | except Exception as e:
29 | logger.error(
30 | f"Error parsing string '{dt}' to timestamp in nanoseconds "
31 | f"using the format '{fmt}': {e}"
32 | )
33 | raise e
34 | return ts
35 | elif isinstance(dt, datetime):
36 | try:
37 | ts = int(datetime.timestamp(dt) * 1e9)
38 | except Exception as e:
39 | logger.error(
40 | f"Error converting datetime object to nanoseconds: {e}"
41 | )
42 | raise e
43 | return ts
44 | elif isinstance(dt, (pd.Timestamp, pd.DatetimeIndex)):
45 | try:
46 | ts = int(datetime.timestamp(dt.to_pydatetime()) * 1e9)
47 | except Exception as e:
48 | logger.error(
49 | f"Error converting pandas Timestamp to nanoseconds: {e}"
50 | )
51 | raise e
52 | return ts
53 | elif isinstance(dt, (int, float)):
54 | # Assume value already in nanoseconds,
55 | # validate timestamps in validate_values
56 | return dt
57 | else:
58 | e = TypeError(f"Cannot convert type {type(dt)} to nanoseconds")
59 | logger.error(f"Error converting pandas Timestamp to nanoseconds: {e}")
60 | raise e
61 |
62 |
63 | def jsonify_dictionaries(df: pd.DataFrame) -> pd.DataFrame:
64 | # NOTE: numpy arrays are not json serializable. Hence, we assume the
65 | # embeddings come as lists, not arrays
66 | dict_cols = [
67 | col
68 | for col in SPAN_OPENINFERENCE_COLUMNS
69 | if col.data_type == SpanColumnDataType.DICT
70 | ]
71 | list_of_dict_cols = [
72 | col
73 | for col in SPAN_OPENINFERENCE_COLUMNS
74 | if col.data_type == SpanColumnDataType.LIST_DICT
75 | ]
76 | for col in dict_cols:
77 | col_name = col.name
78 | if col_name not in df.columns:
79 | logger.debug(f"passing on {col_name}")
80 | continue
81 | logger.debug(f"jsonifying {col_name}")
82 | df[col_name] = df[col_name].apply(lambda d: _jsonify_dict(d))
83 |
84 | for col in list_of_dict_cols:
85 | col_name = col.name
86 | if col_name not in df.columns:
87 | logger.debug(f"passing on {col_name}")
88 | continue
89 | logger.debug(f"jsonifying {col_name}")
90 | df[col_name] = df[col_name].apply(
91 | lambda list_of_dicts: _jsonify_list_of_dicts(list_of_dicts)
92 | )
93 | return df
94 |
95 |
96 | def _jsonify_list_of_dicts(
97 | list_of_dicts: Optional[Iterable[Dict[str, Any]]],
98 | ) -> Optional[List[str]]:
99 | if not isinstance(list_of_dicts, Iterable) and isMissingValue(
100 | list_of_dicts
101 | ):
102 | return None
103 | list_of_json = []
104 | for d in list_of_dicts:
105 | list_of_json.append(_jsonify_dict(d))
106 | return list_of_json
107 |
108 |
109 | def _jsonify_dict(d: Optional[Dict[str, Any]]) -> Optional[str]:
110 | if d is None:
111 | return
112 | if isMissingValue(d):
113 | return None
114 | d = d.copy() # avoid side effects
115 | for k, v in d.items():
116 | if isinstance(v, np.ndarray):
117 | d[k] = v.tolist()
118 | if isinstance(v, dict):
119 | d[k] = _jsonify_dict(v)
120 | return json.dumps(d, ensure_ascii=False)
121 |
122 |
123 | # Defines what is considered a missing value
124 | def isMissingValue(value: Any) -> bool:
125 | assumed_missing_values = (
126 | np.inf,
127 | -np.inf,
128 | )
129 | return value in assumed_missing_values or pd.isna(value)
130 |
131 |
132 | def extract_project_name_from_params(
133 | model_id: Optional[str] = None, project_name: Optional[str] = None
134 | ) -> Optional[str]:
135 | project_name_param = model_id
136 |
137 | # Prefer project_name over model_id
138 | if project_name and project_name.strip():
139 | project_name_param = project_name
140 |
141 | return project_name_param
142 |
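A minimal sketch of the nanosecond conversion implemented by _datetime_to_ns above (the format string and values are illustrative):

    from datetime import datetime

    import pandas as pd

    fmt = "%Y-%m-%d %H:%M:%S"

    # String timestamps are parsed with the caller-supplied strptime format
    dt = datetime.strptime("2024-01-01 00:00:00", fmt)
    ns = int(datetime.timestamp(dt) * 1e9)  # seconds -> nanoseconds

    # pandas Timestamps take the same path via to_pydatetime()
    ts = pd.Timestamp("2024-01-01 00:00:00")
    assert int(datetime.timestamp(ts.to_pydatetime()) * 1e9) == ns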
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/annotations/__init__.py:
--------------------------------------------------------------------------------
1 | # Empty file
2 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/annotations/annotations_validation.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import List
3 |
4 | import pandas as pd
5 |
6 | # Keep common validation imports
7 | from arize.pandas.tracing.columns import SPAN_SPAN_ID_COL
8 |
9 | # Annotation-specific dataframe form validation
10 | from arize.pandas.tracing.validation.annotations import (
11 | dataframe_form_validation as df_validation,
12 | )
13 |
14 | # Annotation-specific value validation
15 | from arize.pandas.tracing.validation.annotations import value_validation
16 | from arize.pandas.tracing.validation.common import (
17 | argument_validation as common_arg_validation,
18 | )
19 | from arize.pandas.tracing.validation.common import (
20 | dataframe_form_validation as common_df_validation,
21 | )
22 | from arize.pandas.tracing.validation.common import (
23 | value_validation as common_value_validation,
24 | )
25 | from arize.pandas.validation import errors as err
26 |
27 |
28 | def validate_argument_types(
29 | annotations_dataframe: pd.DataFrame,
30 | project_name: str,
31 | ) -> List[err.ValidationError]:
32 | """Validates argument types for log_annotations."""
33 | checks = chain(
34 | common_arg_validation._check_field_convertible_to_str(project_name),
35 | common_arg_validation._check_dataframe_type(
36 | annotations_dataframe
37 |         ),
38 | )
39 | return list(checks)
40 |
41 |
42 | def validate_dataframe_form(
43 | annotations_dataframe: pd.DataFrame,
44 | ) -> List[err.ValidationError]:
45 | """Validates the form/structure of the annotation dataframe."""
46 |     # Log any extra columns that will be ignored
47 | df_validation._log_info_dataframe_extra_column_names(annotations_dataframe)
48 | checks = chain(
49 | # Common checks remain the same
50 | common_df_validation._check_dataframe_index(annotations_dataframe),
51 | common_df_validation._check_dataframe_required_column_set(
52 | annotations_dataframe, required_columns=[SPAN_SPAN_ID_COL.name]
53 | ),
54 | common_df_validation._check_dataframe_for_duplicate_columns(
55 | annotations_dataframe
56 | ),
57 |         # Annotation-specific column content type check
58 | df_validation._check_dataframe_column_content_type(
59 | annotations_dataframe
60 | ),
61 | )
62 | return list(checks)
63 |
64 |
65 | def validate_values(
66 | annotations_dataframe: pd.DataFrame,
67 | project_name: str,
68 | ) -> List[err.ValidationError]:
69 | """Validates the values within the annotation dataframe."""
70 | checks = chain(
71 | # Common checks remain the same
72 | common_value_validation._check_invalid_project_name(project_name),
73 | # Call annotation-specific value checks from the imported module
74 | value_validation._check_annotation_cols(annotations_dataframe),
75 | value_validation._check_annotation_columns_null_values(
76 | annotations_dataframe
77 | ),
78 | value_validation._check_annotation_notes_column(annotations_dataframe),
79 | )
80 | return list(checks)
81 |
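A minimal sketch of chaining the three validators above before logging annotations; the span ID column name ("context.span_id") is an assumption based on SPAN_SPAN_ID_COL:

    import pandas as pd

    annotations_df = pd.DataFrame(
        {
            "context.span_id": ["span_1", "span_2"],
            "annotation.quality.label": ["good", "bad"],
            "annotation.quality.score": [0.9, 0.1],
        }
    )
    # Fail fast: stop at the first stage that reports errors
    errors = (
        validate_argument_types(annotations_df, project_name="my-project")
        or validate_dataframe_form(annotations_df)
        or validate_values(annotations_df, project_name="my-project")
    )
    if errors:
        raise RuntimeError("; ".join(e.error_message() for e in errors))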
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/annotations/value_validation.py:
--------------------------------------------------------------------------------
1 | import re
2 | from itertools import chain
3 | from typing import List
4 |
5 | import pandas as pd
6 |
7 | import arize.pandas.tracing.constants as tracing_constants
8 |
9 | # Import annotation-specific and common column constants
10 | from arize.pandas.tracing.columns import (
11 | ANNOTATION_COLUMN_PREFIX,
12 | ANNOTATION_LABEL_SUFFIX,
13 | ANNOTATION_NAME_PATTERN,
14 | ANNOTATION_NOTES_COLUMN_NAME,
15 | ANNOTATION_SCORE_SUFFIX,
16 | ANNOTATION_UPDATED_AT_SUFFIX,
17 | ANNOTATION_UPDATED_BY_SUFFIX,
18 | )
19 |
20 | # Import common validation errors and functions
21 | from arize.pandas.tracing.validation.common import errors as tracing_err
22 | from arize.pandas.tracing.validation.common import (
23 | value_validation as common_value_validation,
24 | )
25 | from arize.pandas.validation import errors as err
26 |
27 |
28 | def _check_annotation_cols(
29 | dataframe: pd.DataFrame,
30 | ) -> List[err.ValidationError]:
31 | """Checks value length and validity for columns matching annotation patterns."""
32 | checks = []
33 | for col in dataframe.columns:
34 | if col.endswith(ANNOTATION_LABEL_SUFFIX):
35 | checks.append(
36 | common_value_validation._check_string_column_value_length(
37 | df=dataframe,
38 | col_name=col,
39 | min_len=tracing_constants.ANNOTATION_LABEL_MIN_STR_LENGTH,
40 | max_len=tracing_constants.ANNOTATION_LABEL_MAX_STR_LENGTH,
41 |                     is_required=False,  # optional; the null-value check handles completeness
42 | )
43 | )
44 | elif col.endswith(ANNOTATION_SCORE_SUFFIX):
45 | checks.append(
46 | common_value_validation._check_float_column_valid_numbers(
47 | df=dataframe,
48 | col_name=col,
49 | )
50 | )
51 | elif col.endswith(ANNOTATION_UPDATED_BY_SUFFIX):
52 | checks.append(
53 | common_value_validation._check_string_column_value_length(
54 | df=dataframe,
55 | col_name=col,
56 | min_len=1,
57 | max_len=tracing_constants.ANNOTATION_UPDATED_BY_MAX_STR_LENGTH,
58 | is_required=False,
59 | )
60 | )
61 | elif col.endswith(ANNOTATION_UPDATED_AT_SUFFIX):
62 | checks.append(
63 | common_value_validation._check_value_timestamp(
64 | df=dataframe,
65 | col_name=col,
66 | is_required=False, # updated_at is not strictly required per row
67 | )
68 | )
69 | # No check for ANNOTATION_NOTES_COLUMN_NAME here, handled by _check_annotation_notes_column
70 | return list(chain(*checks))
71 |
72 |
73 | def _check_annotation_columns_null_values(
74 | dataframe: pd.DataFrame,
75 | ) -> List[err.ValidationError]:
76 | """Checks that for a given annotation name, at least one of label or score is non-null per row."""
77 | invalid_annotation_names = []
78 | annotation_names = set()
79 | # Find all unique annotation names from column headers
80 | for col in dataframe.columns:
81 | match = re.match(ANNOTATION_NAME_PATTERN, col)
82 | if match:
83 | annotation_names.add(match.group(1))
84 |
85 | for ann_name in annotation_names:
86 | label_col = (
87 | f"{ANNOTATION_COLUMN_PREFIX}{ann_name}{ANNOTATION_LABEL_SUFFIX}"
88 | )
89 | score_col = (
90 | f"{ANNOTATION_COLUMN_PREFIX}{ann_name}{ANNOTATION_SCORE_SUFFIX}"
91 | )
92 |
93 | label_exists = label_col in dataframe.columns
94 | score_exists = score_col in dataframe.columns
95 |
96 | # Check only if both label and score columns exist for this name
97 | # If only one exists, its presence is sufficient
98 | if label_exists and score_exists:
99 | # Find rows where BOTH label and score are null
100 | condition = (
101 | dataframe[label_col].isnull() & dataframe[score_col].isnull()
102 | )
103 | if condition.any():
104 | invalid_annotation_names.append(ann_name)
105 | # Check if only label exists but it's always null
106 | elif label_exists and not score_exists:
107 | if dataframe[label_col].isnull().all():
108 | invalid_annotation_names.append(ann_name)
109 | # Check if only score exists but it's always null
110 | elif not label_exists and score_exists:
111 | if dataframe[score_col].isnull().all():
112 | invalid_annotation_names.append(ann_name)
113 |
114 | # Use set to report each name only once
115 |     unique_invalid_names = sorted(set(invalid_annotation_names))
116 | if unique_invalid_names:
117 | return [
118 | tracing_err.InvalidNullAnnotationLabelAndScore(
119 | annotation_names=unique_invalid_names
120 | )
121 | ]
122 | return []
123 |
124 |
125 | def _check_annotation_notes_column(
126 | dataframe: pd.DataFrame,
127 | ) -> List[err.ValidationError]:
128 | """Checks the value length for the optional annotation.notes column (raw string)."""
129 | col_name = ANNOTATION_NOTES_COLUMN_NAME
130 | if col_name in dataframe.columns:
131 | # Validate the length of the raw string
132 | return list(
133 | chain(
134 | *common_value_validation._check_string_column_value_length(
135 | df=dataframe,
136 | col_name=col_name,
137 | min_len=0, # Allow empty notes
138 | max_len=tracing_constants.ANNOTATION_NOTES_MAX_STR_LENGTH,
139 | is_required=False,
140 | )
141 | )
142 | )
143 | return []
144 |
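A worked example of the null rule enforced by _check_annotation_columns_null_values: when both a label and a score column exist for an annotation name, every row must have at least one of the two non-null.

    import pandas as pd

    # Each row has a label or a score for "quality" -> no errors
    ok = pd.DataFrame(
        {
            "annotation.quality.label": ["good", None],
            "annotation.quality.score": [None, 0.2],
        }
    )
    assert not _check_annotation_columns_null_values(ok)

    # Row 1 is null for BOTH label and score -> "quality" is reported
    bad = pd.DataFrame(
        {
            "annotation.quality.label": ["good", None],
            "annotation.quality.score": [0.9, None],
        }
    )
    assert _check_annotation_columns_null_values(bad)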
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/common/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/common/argument_validation.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | import pandas as pd
4 |
5 | from arize.pandas.tracing.validation.common import errors as tracing_err
6 | from arize.pandas.validation import errors as err
7 |
8 |
9 | def _check_field_convertible_to_str(
10 | project_name: str,
11 | model_version: Optional[str] = None,
12 | ) -> List[err.InvalidFieldTypeConversion]:
13 | wrong_fields = []
14 | if project_name is not None and not isinstance(project_name, str):
15 | try:
16 | str(project_name)
17 | except Exception:
18 | wrong_fields.append("project_name")
19 | if model_version is not None and not isinstance(model_version, str):
20 | try:
21 | str(model_version)
22 | except Exception:
23 | wrong_fields.append("model_version")
24 |
25 | if wrong_fields:
26 | return [err.InvalidFieldTypeConversion(wrong_fields, "string")]
27 | return []
28 |
29 |
30 | def _check_dataframe_type(
31 | dataframe,
32 | ) -> List[tracing_err.InvalidTypeArgument]:
33 | if not isinstance(dataframe, pd.DataFrame):
34 | return [
35 | tracing_err.InvalidTypeArgument(
36 | wrong_arg=dataframe,
37 | arg_name="dataframe",
38 | arg_type="pandas DataFrame",
39 | )
40 | ]
41 | return []
42 |
43 |
44 | def _check_datetime_format_type(
45 | dt_fmt: Any,
46 | ) -> List[tracing_err.InvalidTypeArgument]:
47 | if not isinstance(dt_fmt, str):
48 | return [
49 | tracing_err.InvalidTypeArgument(
50 | wrong_arg=dt_fmt,
51 | arg_name="dateTime format",
52 | arg_type="string",
53 | )
54 | ]
55 | return []
56 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/common/dataframe_form_validation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pandas as pd
4 |
5 | from arize.pandas.tracing.validation.common import errors as tracing_err
6 | from arize.pandas.validation import errors as err
7 |
8 |
9 | def _check_dataframe_index(
10 | dataframe: pd.DataFrame,
11 | ) -> List[err.InvalidDataFrameIndex]:
12 | if (dataframe.index != dataframe.reset_index(drop=True).index).any():
13 | return [err.InvalidDataFrameIndex()]
14 | return []
15 |
16 |
17 | def _check_dataframe_required_column_set(
18 | df: pd.DataFrame,
19 | required_columns: List[str],
20 | ) -> List[tracing_err.InvalidDataFrameMissingColumns]:
21 | existing_columns = set(df.columns)
22 | missing_cols = []
23 | for col in required_columns:
24 | if col not in existing_columns:
25 | missing_cols.append(col)
26 |
27 | if missing_cols:
28 | return [
29 | tracing_err.InvalidDataFrameMissingColumns(
30 | missing_cols=missing_cols
31 | )
32 | ]
33 | return []
34 |
35 |
36 | def _check_dataframe_for_duplicate_columns(
37 | df: pd.DataFrame,
38 | ) -> List[tracing_err.InvalidDataFrameDuplicateColumns]:
39 | # Get the duplicated column names from the dataframe
40 | duplicate_columns = df.columns[df.columns.duplicated()]
41 | if not duplicate_columns.empty:
42 | return [tracing_err.InvalidDataFrameDuplicateColumns(duplicate_columns)]
43 | return []
44 |
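A small example of the index check above: _check_dataframe_index flags any dataframe whose index differs from the default RangeIndex (common after filtering), and reset_index(drop=True) clears the error.

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3]})
    filtered = df[df["a"] > 1]  # index is now [1, 2], not [0, 1]
    assert _check_dataframe_index(filtered)  # non-empty error list
    assert not _check_dataframe_index(filtered.reset_index(drop=True))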
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/evals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/evals/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/evals/dataframe_form_validation.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List
3 |
4 | import pandas as pd
5 |
6 | from arize.pandas.tracing.columns import (
7 | EVAL_COLUMN_PATTERN,
8 | EVAL_EXPLANATION_PATTERN,
9 | EVAL_LABEL_PATTERN,
10 | EVAL_SCORE_PATTERN,
11 | SPAN_SPAN_ID_COL,
12 | )
13 | from arize.pandas.tracing.utils import isMissingValue
14 | from arize.pandas.tracing.validation.common import errors as tracing_err
15 | from arize.utils.logging import log_a_list, logger
16 |
17 |
18 | def _log_info_dataframe_extra_column_names(
19 | df: pd.DataFrame,
20 | ) -> None:
21 | if df is None:
22 | return None
23 | irrelevant_columns = [
24 | col
25 | for col in df.columns
26 | if not (
27 | pd.Series(col).str.match(EVAL_COLUMN_PATTERN).any()
28 | or col == SPAN_SPAN_ID_COL.name
29 | )
30 | ]
31 | if irrelevant_columns:
32 |         logger.info(
33 |             "The following columns do not follow the evaluation column naming convention "
34 |             f"and will be ignored: {log_a_list(list_of_str=irrelevant_columns, join_word='and')}. "
35 |             "Evaluation columns must be named as follows: "
36 |             "eval.<eval_name>.label, "
37 |             "eval.<eval_name>.score, or "
38 |             "eval.<eval_name>.explanation"
39 |         )
40 | return None
41 |
42 |
43 | def _check_dataframe_column_content_type(
44 | df: pd.DataFrame,
45 | ) -> List[tracing_err.InvalidDataFrameColumnContentTypes]:
46 | wrong_labels_cols = []
47 | wrong_scores_cols = []
48 | wrong_explanations_cols = []
49 | errors = []
50 | eval_label_re = re.compile(EVAL_LABEL_PATTERN)
51 | eval_score_re = re.compile(EVAL_SCORE_PATTERN)
52 | eval_explanation_re = re.compile(EVAL_EXPLANATION_PATTERN)
53 | for column in df.columns:
54 | if column == SPAN_SPAN_ID_COL.name and not all(
55 | isinstance(value, str) for value in df[column]
56 | ):
57 | errors.append(
58 | tracing_err.InvalidDataFrameColumnContentTypes(
59 | invalid_type_cols=[SPAN_SPAN_ID_COL.name],
60 | expected_type="string",
61 | ),
62 | )
63 | if eval_label_re.match(column):
64 | if not all(
65 | isinstance(value, str) or isMissingValue(value)
66 | for value in df[column]
67 | ):
68 | wrong_labels_cols.append(column)
69 | elif eval_score_re.match(column):
70 | if not all(
71 | isinstance(value, (int, float)) or isMissingValue(value)
72 | for value in df[column]
73 | ):
74 | wrong_scores_cols.append(column)
75 | elif eval_explanation_re.match(column) and not all(
76 | isinstance(value, str) or isMissingValue(value)
77 | for value in df[column]
78 | ):
79 | wrong_explanations_cols.append(column)
80 |
81 | if wrong_labels_cols:
82 | errors.append(
83 | tracing_err.InvalidDataFrameColumnContentTypes(
84 | invalid_type_cols=wrong_labels_cols,
85 | expected_type="strings",
86 | ),
87 | )
88 | if wrong_scores_cols:
89 | errors.append(
90 | tracing_err.InvalidDataFrameColumnContentTypes(
91 | invalid_type_cols=wrong_scores_cols,
92 | expected_type="ints or floats",
93 | ),
94 | )
95 | if wrong_explanations_cols:
96 | errors.append(
97 | tracing_err.InvalidDataFrameColumnContentTypes(
98 | invalid_type_cols=wrong_explanations_cols,
99 | expected_type="strings",
100 | ),
101 | )
102 | return errors
103 |
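For reference, a frame where every column is either the span ID or a well-formed eval column, so nothing is ignored and no content-type errors are produced; the span ID column name ("context.span_id") is an assumption based on SPAN_SPAN_ID_COL:

    import pandas as pd

    evals_df = pd.DataFrame(
        {
            "context.span_id": ["span_1"],          # required ID column, strings
            "eval.correctness.label": ["correct"],  # labels must be strings
            "eval.correctness.score": [1.0],        # scores must be ints or floats
            "eval.correctness.explanation": ["matches the reference"],
        }
    )
    assert not _check_dataframe_column_content_type(evals_df)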
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/evals/evals_validation.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import List, Optional
3 |
4 | import pandas as pd
5 |
6 | from arize.pandas.tracing.columns import SPAN_SPAN_ID_COL
7 | from arize.pandas.tracing.validation.common import (
8 | argument_validation as common_arg_validation,
9 | )
10 | from arize.pandas.tracing.validation.common import (
11 | dataframe_form_validation as common_df_validation,
12 | )
13 | from arize.pandas.tracing.validation.common import (
14 | value_validation as common_value_validation,
15 | )
16 | from arize.pandas.tracing.validation.evals import (
17 | dataframe_form_validation as df_validation,
18 | )
19 | from arize.pandas.tracing.validation.evals import value_validation
20 | from arize.pandas.validation import errors as err
21 |
22 |
23 | def validate_argument_types(
24 | evals_dataframe: pd.DataFrame,
25 | project_name: str,
26 | model_version: Optional[str] = None,
27 | ) -> List[err.ValidationError]:
28 | checks = chain(
29 | common_arg_validation._check_field_convertible_to_str(
30 | project_name, model_version
31 | ),
32 | common_arg_validation._check_dataframe_type(evals_dataframe),
33 | )
34 | return list(checks)
35 |
36 |
37 | def validate_dataframe_form(
38 | evals_dataframe: pd.DataFrame,
39 | ) -> List[err.ValidationError]:
40 | df_validation._log_info_dataframe_extra_column_names(evals_dataframe)
41 | checks = chain(
42 | # Common
43 | common_df_validation._check_dataframe_index(evals_dataframe),
44 | common_df_validation._check_dataframe_required_column_set(
45 | evals_dataframe, required_columns=[SPAN_SPAN_ID_COL.name]
46 | ),
47 | common_df_validation._check_dataframe_for_duplicate_columns(
48 | evals_dataframe
49 | ),
50 | # Eval specific
51 | df_validation._check_dataframe_column_content_type(evals_dataframe),
52 | )
53 | return list(checks)
54 |
55 |
56 | def validate_values(
57 | evals_dataframe: pd.DataFrame,
58 | project_name: str,
59 | model_version: Optional[str] = None,
60 | ) -> List[err.ValidationError]:
61 | checks = chain(
62 | # Common
63 | common_value_validation._check_invalid_project_name(project_name),
64 | common_value_validation._check_invalid_model_version(model_version),
65 | # Eval specific
66 | value_validation._check_eval_cols(evals_dataframe),
67 | value_validation._check_eval_columns_null_values(evals_dataframe),
68 | )
69 | return list(checks)
70 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/evals/value_validation.py:
--------------------------------------------------------------------------------
1 | import re
2 | from itertools import chain
3 | from typing import List
4 |
5 | import pandas as pd
6 |
7 | import arize.pandas.tracing.constants as tracing_constants
8 | from arize.pandas.tracing.columns import (
9 | EVAL_COLUMN_PREFIX,
10 | EVAL_EXPLANATION_SUFFIX,
11 | EVAL_LABEL_SUFFIX,
12 | EVAL_NAME_PATTERN,
13 | EVAL_SCORE_SUFFIX,
14 | )
15 | from arize.pandas.tracing.validation.common import errors as tracing_err
16 | from arize.pandas.tracing.validation.common import value_validation
17 | from arize.pandas.validation import errors as err
18 |
19 |
20 | def _check_eval_cols(
21 | dataframe: pd.DataFrame,
22 | ) -> List[err.ValidationError]:
23 | checks = []
24 | for col in dataframe.columns:
25 | if col.endswith(EVAL_LABEL_SUFFIX):
26 | checks.append(
27 | value_validation._check_string_column_value_length(
28 | df=dataframe,
29 | col_name=col,
30 | min_len=tracing_constants.EVAL_LABEL_MIN_STR_LENGTH,
31 | max_len=tracing_constants.EVAL_LABEL_MAX_STR_LENGTH,
32 | is_required=False,
33 | )
34 | )
35 | elif col.endswith(EVAL_SCORE_SUFFIX):
36 | checks.append(
37 | value_validation._check_float_column_valid_numbers(
38 | df=dataframe,
39 | col_name=col,
40 | )
41 | )
42 | elif col.endswith(EVAL_EXPLANATION_SUFFIX):
43 | checks.append(
44 | value_validation._check_string_column_value_length(
45 | df=dataframe,
46 | col_name=col,
47 | min_len=0,
48 | max_len=tracing_constants.EVAL_EXPLANATION_MAX_STR_LENGTH,
49 | is_required=False,
50 | )
51 | )
52 | return list(chain(*checks))
53 |
54 |
55 | # Evals are valid if they are entirely null (no label, score, or explanation), since this
56 | # represents a span without an eval. Evals are also valid if at least one of label or score
57 | # is not null.
58 | def _check_eval_columns_null_values(
59 | dataframe: pd.DataFrame,
60 | ) -> List[err.ValidationError]:
61 | invalid_eval_names = []
62 | eval_names = set()
63 | for col in dataframe.columns:
64 | match = re.match(EVAL_NAME_PATTERN, col)
65 | if match:
66 | eval_names.add(match.group(1))
67 |
68 | for eval_name in eval_names:
69 | label_col = f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_LABEL_SUFFIX}"
70 | score_col = f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_SCORE_SUFFIX}"
71 | explanation_col = (
72 | f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_EXPLANATION_SUFFIX}"
73 | )
74 | columns_to_check = []
75 |
76 | if label_col in dataframe.columns:
77 | columns_to_check.append(label_col)
78 | if score_col in dataframe.columns:
79 | columns_to_check.append(score_col)
80 |
81 |         # Explanations cannot be orphaned: each one requires a non-null label or score
82 | if explanation_col in dataframe.columns:
83 | condition = (
84 | dataframe[columns_to_check].isnull().all(axis=1)
85 | & ~dataframe[explanation_col].isnull()
86 | )
87 | if condition.any():
88 | invalid_eval_names.append(eval_name)
89 |
90 | if invalid_eval_names:
91 | return [
92 | tracing_err.InvalidNullEvalLabelAndScore(
93 | eval_names=invalid_eval_names
94 | )
95 | ]
96 | return []
97 |
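A worked example of the orphaned-explanation rule above: the frame fails because row 0 carries an explanation while both its label and score are null.

    import pandas as pd

    bad = pd.DataFrame(
        {
            "eval.quality.label": [None, "good"],
            "eval.quality.score": [None, 0.9],
            "eval.quality.explanation": ["orphaned", None],
        }
    )
    assert _check_eval_columns_null_values(bad)  # non-empty error list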
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/metadata/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/metadata/argument_validation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pandas as pd
4 |
5 | from ....validation.errors import ValidationError
6 |
7 |
8 | class MetadataArgumentError(ValidationError):
9 | def __init__(self, message: str, resolution: str) -> None:
10 | self.message = message
11 | self.resolution = resolution
12 |
13 | def __repr__(self) -> str:
14 | return "Metadata_Argument_Error"
15 |
16 | def error_message(self) -> str:
17 | return f"{self.message} {self.resolution}"
18 |
19 |
20 | def validate_argument_types(
21 | metadata_dataframe, project_name
22 | ) -> List[ValidationError]:
23 | """
24 | Validates the types of arguments passed to update_spans_metadata.
25 |
26 | Args:
27 | metadata_dataframe: DataFrame with span IDs and patch documents
28 | project_name: Name of the project
29 |
30 | Returns:
31 | A list of validation errors, empty if none found
32 | """
33 | errors = []
34 |
35 | # Check metadata_dataframe type
36 | if not isinstance(metadata_dataframe, pd.DataFrame):
37 | errors.append(
38 | MetadataArgumentError(
39 | "metadata_dataframe must be a pandas DataFrame",
40 | "The metadata_dataframe argument must be a pandas DataFrame.",
41 | )
42 | )
43 |
44 | # Check project_name
45 | if not isinstance(project_name, str) or not project_name.strip():
46 | errors.append(
47 | MetadataArgumentError(
48 | "project_name must be a non-empty string",
49 | "The project_name argument must be a non-empty string.",
50 | )
51 | )
52 |
53 | return errors
54 |
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/metadata/dataframe_form_validation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from ....tracing.columns import SPAN_SPAN_ID_COL
4 | from ....validation.errors import ValidationError
5 |
6 |
7 | class MetadataFormError(ValidationError):
8 | def __init__(self, message: str, resolution: str) -> None:
9 | self.message = message
10 | self.resolution = resolution
11 |
12 | def __repr__(self) -> str:
13 | return "Metadata_Form_Error"
14 |
15 | def error_message(self) -> str:
16 | return f"{self.message} {self.resolution}"
17 |
18 |
19 | def validate_dataframe_form(
20 | metadata_dataframe, patch_document_column_name="patch_document"
21 | ) -> List[ValidationError]:
22 | """
23 | Validates the structure of the metadata update dataframe.
24 |
25 | Args:
26 | metadata_dataframe: DataFrame with span IDs and patch documents or attributes.metadata.* columns
27 | patch_document_column_name: Name of the column containing patch documents
28 |
29 | Returns:
30 | A list of validation errors, empty if none found
31 | """
32 | errors = []
33 |
34 | # Check for empty dataframe
35 | if metadata_dataframe.empty:
36 | errors.append(
37 | MetadataFormError(
38 | "metadata_dataframe is empty",
39 | "The metadata_dataframe is empty. No data to send.",
40 | )
41 | )
42 | return errors
43 |
44 | # Check for required span_id column
45 | if SPAN_SPAN_ID_COL.name not in metadata_dataframe.columns:
46 | errors.append(
47 | MetadataFormError(
48 | f"Missing required column: {SPAN_SPAN_ID_COL.name}",
49 | f"The metadata_dataframe must contain the span ID column: {SPAN_SPAN_ID_COL.name}.",
50 | )
51 | )
52 | return errors
53 |
54 | # Check for metadata columns - either patch_document or attributes.metadata.* columns
55 | has_patch_document = (
56 | patch_document_column_name in metadata_dataframe.columns
57 | )
58 | metadata_prefix = "attributes.metadata."
59 | metadata_columns = [
60 | col
61 | for col in metadata_dataframe.columns
62 | if col.startswith(metadata_prefix)
63 | ]
64 | has_metadata_fields = len(metadata_columns) > 0
65 |
66 | if not has_patch_document and not has_metadata_fields:
67 | errors.append(
68 | MetadataFormError(
69 | "Missing metadata columns",
70 | f"The metadata_dataframe must contain either the patch document column "
71 | f"'{patch_document_column_name}' or at least one column with the prefix "
72 | f"'{metadata_prefix}'.",
73 | )
74 | )
75 | return errors
76 |
77 | # Check for null values in required columns
78 | null_columns = []
79 |
80 | # Span ID cannot be null
81 | if metadata_dataframe[SPAN_SPAN_ID_COL.name].isna().any():
82 | null_columns.append(SPAN_SPAN_ID_COL.name)
83 |
84 | # If using patch_document, it cannot be null
85 | if (
86 | has_patch_document
87 | and metadata_dataframe[patch_document_column_name].isna().any()
88 | ):
89 | null_columns.append(patch_document_column_name)
90 |
91 | # If using metadata fields, check each one
92 | if has_metadata_fields:
93 | for col in metadata_columns:
94 | if (
95 | metadata_dataframe[col].isna().all()
96 | ): # All values in column are null
97 | null_columns.append(col)
98 |
99 | if null_columns:
100 | errors.append(
101 | MetadataFormError(
102 | f"Columns with null values: {', '.join(null_columns)}",
103 | f"The following columns cannot contain null values: {', '.join(null_columns)}.",
104 | )
105 | )
106 |
107 | return errors
108 |
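A minimal sketch of the two accepted shapes; the span ID column name ("context.span_id") is an assumption based on SPAN_SPAN_ID_COL:

    import pandas as pd

    # Shape 1: one JSON patch document per span
    patch_df = pd.DataFrame(
        {
            "context.span_id": ["span_1"],
            "patch_document": ['{"reviewed": true}'],
        }
    )

    # Shape 2: individual attributes.metadata.* columns
    fields_df = pd.DataFrame(
        {
            "context.span_id": ["span_1"],
            "attributes.metadata.reviewed": [True],
        }
    )
    assert not validate_dataframe_form(patch_df)
    assert not validate_dataframe_form(fields_df)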
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/spans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/spans/__init__.py
--------------------------------------------------------------------------------
/src/arize/pandas/tracing/validation/spans/spans_validation.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import List, Optional
3 |
4 | import pandas as pd
5 |
6 | from arize.pandas.tracing.columns import SPAN_OPENINFERENCE_REQUIRED_COLUMNS
7 | from arize.pandas.tracing.validation.common import (
8 | argument_validation as common_arg_validation,
9 | )
10 | from arize.pandas.tracing.validation.common import (
11 | dataframe_form_validation as common_df_validation,
12 | )
13 | from arize.pandas.tracing.validation.common import (
14 | value_validation as common_value_validation,
15 | )
16 | from arize.pandas.tracing.validation.spans import (
17 | dataframe_form_validation as df_validation,
18 | )
19 | from arize.pandas.tracing.validation.spans import value_validation
20 | from arize.pandas.validation import errors as err
21 |
22 |
23 | def validate_argument_types(
24 | spans_dataframe: pd.DataFrame,
25 | project_name: str,
26 | dt_fmt: str,
27 | model_version: Optional[str] = None,
28 | ) -> List[err.ValidationError]:
29 | checks = chain(
30 | common_arg_validation._check_field_convertible_to_str(
31 | project_name, model_version
32 | ),
33 | common_arg_validation._check_dataframe_type(spans_dataframe),
34 | common_arg_validation._check_datetime_format_type(dt_fmt),
35 | )
36 | return list(checks)
37 |
38 |
39 | def validate_dataframe_form(
40 | spans_dataframe: pd.DataFrame,
41 | ) -> List[err.ValidationError]:
42 | df_validation._log_info_dataframe_extra_column_names(spans_dataframe)
43 | checks = chain(
44 | # Common
45 | common_df_validation._check_dataframe_index(spans_dataframe),
46 | common_df_validation._check_dataframe_required_column_set(
47 | spans_dataframe,
48 | required_columns=[
49 | col.name for col in SPAN_OPENINFERENCE_REQUIRED_COLUMNS
50 | ],
51 | ),
52 | common_df_validation._check_dataframe_for_duplicate_columns(
53 | spans_dataframe
54 | ),
55 | # Spans specific
56 | df_validation._check_dataframe_column_content_type(spans_dataframe),
57 | )
58 | return list(checks)
59 |
60 |
61 | def validate_values(
62 | spans_dataframe: pd.DataFrame,
63 | project_name: str,
64 | model_version: Optional[str] = None,
65 | ) -> List[err.ValidationError]:
66 | checks = chain(
67 | # Common
68 | common_value_validation._check_invalid_project_name(project_name),
69 | common_value_validation._check_invalid_model_version(model_version),
70 | # Spans specific
71 | value_validation._check_span_root_field_values(spans_dataframe),
72 | value_validation._check_span_attributes_values(spans_dataframe),
73 | )
74 | return list(checks)
75 |
--------------------------------------------------------------------------------
/src/arize/pandas/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/validation/__init__.py
--------------------------------------------------------------------------------
/src/arize/single_log/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/single_log/__init__.py
--------------------------------------------------------------------------------
/src/arize/single_log/casting.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Union
3 |
4 | from arize.utils.types import ArizeTypes, TypedValue
5 |
6 | from .errors import CastingError
7 |
8 |
9 | def cast_dictionary(d: dict) -> Union[dict, None]:
10 | if not d:
11 | return None
12 | cast_dict = {}
13 | for k, v in d.items():
14 | if isinstance(v, TypedValue):
15 | v = cast_value(v)
16 | cast_dict[k] = v
17 | return cast_dict
18 |
19 |
20 | def cast_value(
21 | typed_value: TypedValue,
22 | ) -> Union[str, int, float, List[str], None]:
23 | """
24 | Casts a TypedValue to its provided type, preserving all null values as None or float('nan').
25 |
26 | Arguments:
27 | ---------
28 | typed_value: TypedValue
29 | The TypedValue to cast.
30 |
31 | Returns:
32 | -------
33 | Union[str, int, float, List[str], None]
34 | The cast value.
35 |
36 | Raises:
37 | ------
38 | CastingError
39 | If the value cannot be cast to the provided type.
40 |
41 | """
42 | if typed_value.value is None:
43 | return None
44 |
45 | if typed_value.type == ArizeTypes.FLOAT:
46 | return _cast_to_float(typed_value)
47 | elif typed_value.type == ArizeTypes.INT:
48 | return _cast_to_int(typed_value)
49 | elif typed_value.type == ArizeTypes.STR:
50 | return _cast_to_str(typed_value)
51 | else:
52 | raise CastingError("Unknown casting type", typed_value)
53 |
54 |
55 | def _cast_to_float(typed_value: TypedValue) -> Union[float, None]:
56 | try:
57 | return float(typed_value.value)
58 | except Exception as e:
59 | raise CastingError(str(e), typed_value) from e
60 |
61 |
62 | def _cast_to_int(typed_value: TypedValue) -> Union[int, None]:
63 | # a NaN float can't be cast to an int. Proactively return None instead.
64 | if isinstance(typed_value.value, float) and math.isnan(typed_value.value):
65 | return None
66 | # If the value is a float, to avoid losing data precision,
67 | # we can only cast to an int if it is equivalent to an integer (e.g. 7.0).
68 | if (
69 | isinstance(typed_value.value, float)
70 | and not typed_value.value.is_integer()
71 | ):
72 | raise CastingError(
73 | "Cannot convert float with non-zero fractional part to int",
74 | typed_value,
75 | )
76 | try:
77 | return int(typed_value.value)
78 | except Exception as e:
79 | raise CastingError(str(e), typed_value) from e
80 |
81 |
82 | def _cast_to_str(typed_value: TypedValue) -> Union[str, None]:
83 | # a NaN float can't be cast to a string. Proactively return None instead.
84 | if isinstance(typed_value.value, float) and math.isnan(typed_value.value):
85 | return None
86 | try:
87 | return str(typed_value.value)
88 | except Exception as e:
89 | raise CastingError(str(e), typed_value) from e
90 |
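A usage sketch; it assumes TypedValue is constructed with value and type fields (see arize.utils.types):

    from arize.utils.types import ArizeTypes, TypedValue

    cast_value(TypedValue(value="3.5", type=ArizeTypes.FLOAT))  # -> 3.5
    cast_value(TypedValue(value=7.0, type=ArizeTypes.INT))      # -> 7
    cast_value(TypedValue(value=None, type=ArizeTypes.STR))     # -> None

    # Raises CastingError: 7.5 has a non-zero fractional part
    # cast_value(TypedValue(value=7.5, type=ArizeTypes.INT))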
--------------------------------------------------------------------------------
/src/arize/single_log/errors.py:
--------------------------------------------------------------------------------
1 | from arize.utils.types import TypedValue
2 |
3 |
4 | class CastingError(Exception):
5 | def __str__(self) -> str:
6 | return self.error_message()
7 |
8 | def __init__(self, error_msg: str, typed_value: TypedValue) -> None:
9 | self.error_msg = error_msg
10 | self.typed_value = typed_value
11 |
12 | def error_message(self) -> str:
13 | return (
14 | f"Failed to cast value {self.typed_value.value} of type {type(self.typed_value.value)} "
15 | f"to type {self.typed_value.type}. "
16 | f"Error: {self.error_msg}."
17 | )
18 |
--------------------------------------------------------------------------------
/src/arize/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/utils/__init__.py
--------------------------------------------------------------------------------
/src/arize/utils/constants.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | MAX_BYTES_PER_BULK_RECORD = 100000
5 | MAX_DAYS_WITHIN_RANGE = 365
6 | MIN_PREDICTION_ID_LEN = 1
7 | MAX_PREDICTION_ID_LEN = 512
8 | MIN_DOCUMENT_ID_LEN = 1
9 | MAX_DOCUMENT_ID_LEN = 128
10 | # The maximum number of characters for tag values
11 | MAX_TAG_LENGTH = 20_000
12 | MAX_TAG_LENGTH_TRUNCATION = 1_000
13 | # The maximum number of characters for embedding raw data
14 | MAX_RAW_DATA_CHARACTERS = 2_000_000
15 | MAX_RAW_DATA_CHARACTERS_TRUNCATION = 5_000
16 | # The maximum number of acceptable years in the past from current time for prediction_timestamps
17 | MAX_PAST_YEARS_FROM_CURRENT_TIME = 5
18 | # The maximum number of acceptable years in the future from current time for prediction_timestamps
19 | MAX_FUTURE_YEARS_FROM_CURRENT_TIME = 1
20 | # The maximum number of characters for LLM model name
21 | MAX_LLM_MODEL_NAME_LENGTH = 20_000
22 | MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION = 50
23 | # The maximum number of characters for prompt template
24 | MAX_PROMPT_TEMPLATE_LENGTH = 50_000
25 | MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION = 5_000
26 | # The maximum number of characters for prompt template version
27 | MAX_PROMPT_TEMPLATE_VERSION_LENGTH = 20_000
28 | MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION = 50
29 | # The maximum number of embeddings
30 | MAX_NUMBER_OF_EMBEDDINGS = 30
31 | MAX_EMBEDDING_DIMENSIONALITY = 20_000
32 | # The maximum number of classes for multi class
33 | MAX_NUMBER_OF_MULTI_CLASS_CLASSES = 300
34 | MAX_MULTI_CLASS_NAME_LENGTH = 100
35 | # The maximum number of references in embedding similarity search params
36 | MAX_NUMBER_OF_SIMILARITY_REFERENCES = 10
37 |
38 | # Arize generated columns
39 | GENERATED_PREDICTION_LABEL_COL = "arize_generated_prediction_label"
40 | GENERATED_LLM_PARAMS_JSON_COL = "arize_generated_llm_params_json"
41 |
42 | # reserved columns for LLM run metadata
43 | LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME = "total_token_count"
44 | LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME = "prompt_token_count"
45 | LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME = "response_token_count"
46 | LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME = "response_latency_ms"
47 |
48 | # all reserved tags
49 | RESERVED_TAG_COLS = [
50 | LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME,
51 | LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
52 | LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME,
53 | LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME,
54 | ]
55 |
56 |
57 | # Authentication via environment variables
58 | SPACE_KEY_ENVVAR_NAME = "ARIZE_SPACE_KEY"
59 | API_KEY_ENVVAR_NAME = "ARIZE_API_KEY"
60 | DEVELOPER_KEY_ENVVAR_NAME = "ARIZE_DEVELOPER_KEY"
61 | SPACE_ID_ENVVAR_NAME = "ARIZE_SPACE_ID"
62 |
63 | # Default public Flight endpoint when not provided through an env variable or a profile
64 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com"
65 | DEFAULT_ARIZE_FLIGHT_PORT = 443
66 |
67 | path = Path(__file__).with_name("model_mapping.json")
68 | with path.open("r") as f:
69 | MODEL_MAPPING_CONFIG = json.load(f)
70 |
--------------------------------------------------------------------------------
/src/arize/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from typing import Any, List
4 |
5 |
6 | class CustomLogFormatter(logging.Formatter):
7 | grey = "\x1b[38;21m"
8 | blue = "\x1b[38;5;39m"
9 | yellow = "\x1b[33m"
10 | red = "\x1b[38;5;196m"
11 | bold_red = "\x1b[31;1m"
12 | reset = "\x1b[0m"
13 |
14 | def __init__(self, fmt):
15 | super().__init__()
16 | self.fmt = fmt
17 | self.FORMATS = {
18 | logging.DEBUG: self.blue + self.fmt + self.reset,
19 | logging.INFO: self.grey + self.fmt + self.reset,
20 | logging.WARNING: self.yellow + self.fmt + self.reset,
21 | logging.ERROR: self.red + self.fmt + self.reset,
22 | logging.CRITICAL: self.bold_red + self.fmt + self.reset,
23 | }
24 |
25 | def format(self, record):
26 | log_fmt = self.FORMATS.get(record.levelno)
27 | formatter = logging.Formatter(log_fmt)
28 | return formatter.format(record)
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 | logger.propagate = False
33 | if logger.hasHandlers():
34 | logger.handlers.clear()
35 | logger.setLevel(logging.INFO)
36 | fmt = " %(name)s | %(levelname)s | %(message)s"
37 | if hasattr(sys, "ps1"): # for python interactive mode
38 | handler = logging.StreamHandler(sys.stdout)
39 | handler.setLevel(logging.INFO)
40 | handler.setFormatter(CustomLogFormatter(fmt))
41 | logger.addHandler(handler)
42 |
43 |
44 | def get_truncation_warning_message(instance, limit) -> str:
45 | return (
46 | f"Attention: {instance} exceeding the {limit} character limit will be "
47 | "automatically truncated upon ingestion into the Arize platform. Should you require "
48 | "a higher limit, please reach out to our support team at support@arize.com"
49 | )
50 |
51 |
52 | def log_a_list(list_of_str: List[Any], join_word: str) -> str:
53 | if list_of_str is None or len(list_of_str) == 0:
54 | return ""
55 | if len(list_of_str) == 1:
56 | return list_of_str[0]
57 | return (
58 | f"{', '.join(map(str, list_of_str[:-1]))} {join_word} {list_of_str[-1]}"
59 | )
60 |
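For reference, the outputs log_a_list produces for each input size:

    log_a_list([], "and")               # -> ""
    log_a_list(["a"], "and")            # -> "a"
    log_a_list(["a", "b"], "and")       # -> "a and b"
    log_a_list(["a", "b", "c"], "and")  # -> "a, b and c"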
--------------------------------------------------------------------------------
/src/arize/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "7.43.1"
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/__init__.py
--------------------------------------------------------------------------------
/tests/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/__init__.py
--------------------------------------------------------------------------------
/tests/experimental/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/__init__.py
--------------------------------------------------------------------------------
/tests/experimental/datasets/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/experiments/__init__.py
--------------------------------------------------------------------------------
/tests/experimental/datasets/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/validation/__init__.py
--------------------------------------------------------------------------------
/tests/experimental/datasets/validation/test_validator.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pytest
4 |
5 | if sys.version_info < (3, 8):
6 | pytest.skip("Requires Python 3.8 or higher", allow_module_level=True)
7 |
8 | import pandas as pd
9 |
10 | from arize.experimental.datasets.core.client import (
11 | ArizeDatasetsClient,
12 | _convert_default_columns_to_json_str,
13 | )
14 | from arize.experimental.datasets.validation.errors import (
15 | IDColumnUniqueConstraintError,
16 | RequiredColumnsError,
17 | )
18 | from arize.experimental.datasets.validation.validator import Validator
19 |
20 |
21 | def test_happy_path():
22 | df = pd.DataFrame(
23 | {
24 | "user_data": [1, 2, 3],
25 | }
26 | )
27 |
28 | df_new = ArizeDatasetsClient._set_default_columns_for_dataset(df)
29 | differences = set(df_new.columns) ^ {
30 | "id",
31 | "created_at",
32 | "updated_at",
33 | "user_data",
34 | }
35 | assert not differences
36 |
37 | validation_errors = Validator.validate(df)
38 | assert len(validation_errors) == 0
39 |
40 |
41 | def test_missing_columns():
42 | df = pd.DataFrame(
43 | {
44 | "user_data": [1, 2, 3],
45 | }
46 | )
47 |
48 | validation_errors = Validator.validate(df)
49 | assert len(validation_errors) == 1
50 | assert type(validation_errors[0]) is RequiredColumnsError
51 |
52 |
53 | def test_non_unique_id_column():
54 | df = pd.DataFrame(
55 | {
56 | "id": [1, 1, 2],
57 | "user_data": [1, 2, 3],
58 | }
59 | )
60 | df_new = ArizeDatasetsClient._set_default_columns_for_dataset(df)
61 |
62 | validation_errors = Validator.validate(df_new)
63 | assert len(validation_errors) == 1
64 |     assert type(validation_errors[0]) is IDColumnUniqueConstraintError
65 |
66 |
67 | @pytest.mark.skipif(
68 | sys.version_info < (3, 8), reason="Requires Python 3.8 or higher"
69 | )
70 | def test_dict_to_json_conversion() -> None:
71 | df = pd.DataFrame(
72 | {
73 | "id": [1, 2, 3],
74 | "eval.MyEvaluator.metadata": [
75 | {"key": "value"},
76 | {"key": "value"},
77 | {"key": "value"},
78 | ],
79 | "not_converted_dict_col": [
80 | {"key": "value"},
81 | {"key": "value"},
82 | {"key": "value"},
83 | ],
84 | }
85 | )
86 | # before conversion, the column with the evaluator name is a dict
87 | assert type(df["eval.MyEvaluator.metadata"][0]) is dict
88 | assert type(df["not_converted_dict_col"][0]) is dict
89 |
90 | # Check that only the column with the evaluator name is converted to JSON
91 | converted_df = _convert_default_columns_to_json_str(df)
92 | assert type(converted_df["eval.MyEvaluator.metadata"][0]) is str
93 | assert type(converted_df["not_converted_dict_col"][0]) is dict
94 |
--------------------------------------------------------------------------------
/tests/exporter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/__init__.py
--------------------------------------------------------------------------------
/tests/exporter/test_exporter.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/test_exporter.py
--------------------------------------------------------------------------------
/tests/exporter/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/utils/__init__.py
--------------------------------------------------------------------------------
/tests/exporter/validations/test_validator_invalid_types.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from datetime import datetime
3 |
4 | from arize.exporter.utils.validation import Validator
5 |
6 |
7 | class MyTestCase(unittest.TestCase):
8 | def test_zero_error(self):
9 | try:
10 | for key, val in valid_inputs.items():
11 | Validator.validate_input_type(val[0], key, val[1])
12 | except TypeError:
13 | self.fail("validate_input_type raised TypeError unexpectedly")
14 |
15 | def test_invalid_space_id(self):
16 | space_id = invalid_inputs["space_id"][0]
17 | valid_type = invalid_inputs["space_id"][1]
18 | with self.assertRaisesRegex(
19 | TypeError,
20 | f"space_id {space_id} is type {type(space_id)}, but must be a str",
21 | ):
22 | Validator.validate_input_type(space_id, "space_id", valid_type)
23 |
24 | def test_invalid_model_name(self):
25 | model_name = invalid_inputs["model_name"][0]
26 | valid_type = invalid_inputs["model_name"][1]
27 | with self.assertRaisesRegex(
28 | TypeError,
29 | f"model_name {model_name} is type {type(model_name)}, but must be a str",
30 | ):
31 | Validator.validate_input_type(model_name, "model_name", valid_type)
32 |
33 | def test_invalid_data_type(self):
34 | data_type = invalid_inputs["data_type"][0]
35 | valid_type = invalid_inputs["data_type"][1]
36 | with self.assertRaisesRegex(
37 | TypeError,
38 | f"data_type {data_type} is type {type(data_type)}, but must be a str",
39 | ):
40 | Validator.validate_input_type(data_type, "data_type", valid_type)
41 |
42 | def test_invalid_start_or_end_time(self):
43 | start_time = invalid_inputs["start_time"][0]
44 | valid_type = invalid_inputs["start_time"][1]
45 | with self.assertRaisesRegex(
46 | TypeError,
47 | f"start_time {start_time} is type {type(start_time)}, but must be a datetime",
48 | ):
49 | Validator.validate_input_type(start_time, "start_time", valid_type)
50 |
51 | def test_invalid_path(self):
52 | path = invalid_inputs["path"][0]
53 | valid_type = invalid_inputs["path"][1]
54 | with self.assertRaisesRegex(
55 | TypeError,
56 | f"path {path} is type {type(path)}, but must be a str",
57 | ):
58 | Validator.validate_input_type(path, "path", valid_type)
59 |
60 |
61 | valid_inputs = {
62 | "space_id": ("abc123", str),
63 | "model_name": ("test_model", str),
64 | "data_type": ("predictions", str),
65 | "start_time": (datetime(2023, 4, 1, 0, 0, 0, 0), datetime),
66 | "end_time": (datetime(2023, 4, 15, 0, 0, 0, 0), datetime),
67 | "path": ("example.parquet", str),
68 | }
69 |
70 | invalid_inputs = {
71 | "space_id": (123, str),
72 | "model_name": (123, str),
73 | "data_type": (0.2, str),
74 | "start_time": ("2022-10-10", datetime),
75 | "end_time": ("2022-10-15", datetime),
76 | "path": (0.2, str),
77 | }
78 |
79 | if __name__ == "__main__":
80 | unittest.main()
81 |
--------------------------------------------------------------------------------
/tests/exporter/validations/test_validator_invalid_values.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from datetime import datetime
3 |
4 | from arize.exporter.utils.validation import Validator
5 |
6 |
7 | class MyTestCase(unittest.TestCase):
8 | def test_valid_data_type(self):
9 | try:
10 | Validator.validate_input_value(
11 | valid_data_type.upper(), "data_type", data_types
12 | )
13 |         except ValueError:
14 |             self.fail("validate_input_value raised ValueError unexpectedly")
15 |
16 | def test_invalid_data_type(self):
17 | with self.assertRaisesRegex(
18 | ValueError,
19 | f"data_type is {invalid_data_type.upper()}, but must be one of PREDICTIONS, "
20 | f"CONCLUSIONS, EXPLANATIONS, PREPRODUCTION",
21 | ):
22 | Validator.validate_input_value(
23 | invalid_data_type.upper(), "data_type", data_types
24 | )
25 |
26 | def test_invalid_start_end_time(self):
27 | with self.assertRaisesRegex(
28 | ValueError,
29 | "start_time must be before end_time",
30 | ):
31 | Validator.validate_start_end_time(start_time, end_time)
32 |
33 |
34 | valid_data_type = "preproduction"
35 | invalid_data_type = "hello"
36 | data_types = ("PREDICTIONS", "CONCLUSIONS", "EXPLANATIONS", "PREPRODUCTION")
37 | start_time = datetime(2023, 6, 15, 10, 30)
38 | end_time = datetime(2023, 6, 10, 10, 30)
39 |
40 | if __name__ == "__main__":
41 | unittest.main()
42 |
--------------------------------------------------------------------------------
/tests/fixtures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/fixtures/__init__.py
--------------------------------------------------------------------------------
/tests/pandas/generative/nlp_metrics/test_nlp_metrics.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from arize.pandas.generative.nlp_metrics import (
5 | bleu,
6 | google_bleu,
7 | meteor,
8 | rouge,
9 | sacre_bleu,
10 | )
11 |
12 |
13 | def get_text_df() -> pd.DataFrame:
14 | return pd.DataFrame(
15 | {
16 | "response": [
17 | "The cat is on the mat.",
18 | "The NASA Opportunity rover is battling a massive dust storm on Mars.",
19 | ],
20 | "references": [
21 | ["The cat is on the blue mat."],
22 | [
23 | "The Opportunity rover is combating a big sandstorm on Mars.",
24 | "A NASA rover is fighting a massive storm on Mars.",
25 | ],
26 | ],
27 | }
28 | )
29 |
30 |
31 | def get_results_df() -> pd.DataFrame:
32 | return pd.DataFrame(
33 | {
34 | "bleu": [0.6129752413741056, 0.32774568052975916],
35 | "sacrebleu": [61.29752413741059, 32.774568052975916],
36 | "google_bleu": [0.6538461538461539, 0.3695652173913043],
37 | "rouge1": [0.923076923076923, 0.7272727272727272],
38 | "rouge2": [0.7272727272727272, 0.39999999999999997],
39 | "rougeL": [0.923076923076923, 0.7272727272727272],
40 | "rougeLsum": [0.923076923076923, 0.7272727272727272],
41 | "meteor": [0.8757427021441488, 0.7682980599647267],
42 | }
43 | )
44 |
45 |
46 | def test_bleu_score() -> None:
47 | texts = get_text_df()
48 | results = get_results_df()
49 |
50 | try:
51 | bleu_scores = bleu(
52 | response_col=texts["response"], references_col=texts["references"]
53 | )
54 | except Exception as e:
55 | raise AssertionError("Unexpected Error") from e
56 |
57 | assert (bleu_scores == results["bleu"]).all(), "BLEU scores should match" # type:ignore
58 |
59 |
60 | def test_sacrebleu_score() -> None:
61 | texts = get_text_df()
62 | results = get_results_df()
63 |
64 | try:
65 | sacrebleu_scores = sacre_bleu(
66 | response_col=texts["response"], references_col=texts["references"]
67 | )
68 | except Exception as e:
69 | raise AssertionError("Unexpected Error") from e
70 |
71 | assert (
72 | sacrebleu_scores == results["sacrebleu"]
73 | ).all(), "SacreBLEU scores should match" # type:ignore
74 |
75 |
76 | def test_google_bleu_score() -> None:
77 | texts = get_text_df()
78 | results = get_results_df()
79 |
80 | try:
81 | gbleu_scores = google_bleu(
82 | response_col=texts["response"], references_col=texts["references"]
83 | )
84 | except Exception as e:
85 | raise AssertionError("Unexpected Error") from e
86 |
87 | assert (
88 | gbleu_scores == results["google_bleu"]
89 | ).all(), "Google BLEU scores should match" # type:ignore
90 |
91 |
92 | def test_rouge_score() -> None:
93 | texts = get_text_df()
94 | results = get_results_df()
95 |
96 | try:
97 | rouge_scores = rouge(
98 | response_col=texts["response"], references_col=texts["references"]
99 | )
100 | except Exception as e:
101 | raise AssertionError("Unexpected Error") from e
102 |
103 | # Check that only default rouge scores are returned, and they match
104 | assert isinstance(rouge_scores, dict)
105 | assert len(rouge_scores.keys()) == 1
106 | assert list(rouge_scores.keys())[0] == "rougeL" # type:ignore
107 | assert (
108 | rouge_scores["rougeL"] == results["rougeL"]
109 | ).all(), "ROUGE scores should match" # type: ignore
110 |
111 | rouge_types = [
112 | "rouge1",
113 | "rouge2",
114 | "rougeL",
115 | "rougeLsum",
116 | ]
117 | try:
118 | rouge_scores = rouge(
119 | response_col=texts["response"],
120 | references_col=texts["references"],
121 | rouge_types=rouge_types,
122 | )
123 | except Exception as e:
124 | raise AssertionError("Unexpected Error") from e
125 |
126 | # Check that all rouge scores are returned, and they match
127 | assert isinstance(rouge_scores, dict)
128 | assert list(rouge_scores.keys()) == rouge_types
129 | for rtype in rouge_types:
130 | assert (
131 | rouge_scores[rtype] == results[rtype]
132 | ).all(), f"ROUGE scores ({rtype}) should match" # type: ignore
133 |
134 |
135 | def test_meteor_score() -> None:
136 | texts = get_text_df()
137 | results = get_results_df()
138 |
139 | try:
140 | meteor_scores = meteor(
141 | response_col=texts["response"], references_col=texts["references"]
142 | )
143 | except Exception as e:
144 | raise AssertionError("Unexpected Error") from e
145 |
146 | assert (
147 | meteor_scores == results["meteor"]
148 | ).all(), "METEOR scores should match" # type:ignore
149 |
150 |
151 | if __name__ == "__main__":
152 | raise SystemExit(pytest.main([__file__]))
153 |
--------------------------------------------------------------------------------
/tests/pandas/tracing/test_columns.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 |
4 | import pytest
5 |
6 | if sys.version_info >= (3, 8):
7 | from arize.pandas.tracing.columns import (
8 | ANNOTATION_COLUMN_PATTERN,
9 | ANNOTATION_NAME_PATTERN,
10 | EVAL_COLUMN_PATTERN,
11 | EVAL_EXPLANATION_PATTERN,
12 | EVAL_LABEL_PATTERN,
13 | EVAL_NAME_PATTERN,
14 | EVAL_SCORE_PATTERN,
15 | )
16 |
17 |
18 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
19 | def test_eval_column_pattern():
20 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.label")
21 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name with spaces.label")
22 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.score")
23 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.explanation")
24 | assert not re.match(EVAL_COLUMN_PATTERN, "eval.name")
25 | assert not re.match(EVAL_COLUMN_PATTERN, "name.label")
26 |
27 |
28 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
29 | def test_eval_label_pattern():
30 | assert re.match(EVAL_LABEL_PATTERN, "eval.name.label")
31 | assert re.match(EVAL_LABEL_PATTERN, "eval.name with spaces.label")
32 | assert not re.match(EVAL_LABEL_PATTERN, "eval.name.score")
33 |
34 |
35 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
36 | def test_eval_score_pattern():
37 | assert re.match(EVAL_SCORE_PATTERN, "eval.name.score")
38 | assert re.match(EVAL_SCORE_PATTERN, "eval.name with spaces.score")
39 | assert not re.match(EVAL_SCORE_PATTERN, "eval.name.label")
40 |
41 |
42 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
43 | def test_eval_explanation_pattern():
44 | assert re.match(EVAL_EXPLANATION_PATTERN, "eval.name.explanation")
45 | assert re.match(
46 | EVAL_EXPLANATION_PATTERN, "eval.name with spaces.explanation"
47 | )
48 | assert not re.match(EVAL_EXPLANATION_PATTERN, "eval_name.label")
49 |
50 |
51 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
52 | def test_eval_name_capture():
53 | matches = [
54 | ("eval.name_part.label", "name_part"),
55 | ("eval.name with spaces.label", "name with spaces"),
56 | ]
57 | for test_str, expected_name in matches:
58 | match = re.match(EVAL_NAME_PATTERN, test_str)
59 | assert match is not None, f"Failed to match '{test_str}'"
60 | assert (
61 | match.group(1) == expected_name
62 | ), f"Incorrect name captured for '{test_str}'"
63 |
64 | non_matches = [
65 | "evalname.", # Missing dot
66 | "name_part.", # Missing prefix
67 | "eval.name", # Missing suffix
68 | ]
69 | for test_str in non_matches:
70 | assert (
71 | re.match(EVAL_NAME_PATTERN, test_str) is None
72 | ), f"Incorrectly matched '{test_str}'"
73 |
74 |
75 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
76 | def test_annotation_column_pattern():
77 | """Test the regex pattern for matching valid annotation column names."""
78 | valid_patterns = [
79 | "annotation.quality.label",
80 | "annotation.toxicity_score.score",
81 | "annotation.needs review.updated_by",
82 | "annotation.timestamp_col.updated_at",
83 | "annotation.numeric_name123.label",
84 | ]
85 | invalid_patterns = [
86 | "annotation.quality", # Missing suffix
87 | "annotations.quality.label", # Wrong prefix
88 | "annotation.quality.Score", # Incorrect suffix case
89 | "annotation..label", # Empty name part
90 | "quality.label", # Missing prefix
91 | "annotation.notes", # Specific reserved name (tested separately if needed)
92 | "eval.quality.label", # Different prefix
93 | ]
94 |
95 | for pattern in valid_patterns:
96 | assert re.match(
97 | ANNOTATION_COLUMN_PATTERN, pattern
98 | ), f"Should match: {pattern}"
99 | for pattern in invalid_patterns:
100 | assert not re.match(
101 | ANNOTATION_COLUMN_PATTERN, pattern
102 | ), f"Should NOT match: {pattern}"
103 |
104 |
105 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
106 | def test_annotation_name_capture():
107 | """Test capturing the annotation name from column strings."""
108 | matches = [
109 | ("annotation.quality.label", "quality"),
110 | ("annotation.name with spaces.score", "name with spaces"),
111 | ("annotation.name_123.updated_by", "name_123"),
112 | ]
113 | for test_str, expected_name in matches:
114 | match = re.match(ANNOTATION_NAME_PATTERN, test_str)
115 | assert match is not None, f"Failed to match '{test_str}'"
116 | assert (
117 | match.group(1) == expected_name
118 | ), f"Incorrect name captured for '{test_str}'"
119 |
120 | non_matches = [
121 | "annotationname.label", # Missing dot after prefix
122 | "annotation..score", # Empty name part
123 | "annotation.name", # No suffix
124 | "quality.label", # Missing prefix
125 | ]
126 | for test_str in non_matches:
127 | assert (
128 | re.match(ANNOTATION_NAME_PATTERN, test_str) is None
129 | ), f"Incorrectly matched '{test_str}'"
130 |
131 |
132 | if __name__ == "__main__":
133 | raise SystemExit(pytest.main([__file__]))
134 |
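The patterns tested above are typically used to route DataFrame columns by name and to pull out the eval name; a small sketch follows (the column names are hypothetical, and the group(1) capture is confirmed by test_eval_name_capture above).

# Illustrative sketch: routing hypothetical eval column names with the
# patterns exercised above. group(1) of EVAL_NAME_PATTERN is the eval name.
import re

from arize.pandas.tracing.columns import EVAL_NAME_PATTERN, EVAL_SCORE_PATTERN

for col in ["eval.hallucination.score", "eval.qa correctness.label", "name"]:
    name_match = re.match(EVAL_NAME_PATTERN, col)
    if name_match is None:
        continue  # not an eval column
    eval_name = name_match.group(1)
    is_score = re.match(EVAL_SCORE_PATTERN, col) is not None
    print(eval_name, "(score)" if is_score else "(label/explanation)")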
--------------------------------------------------------------------------------
/tests/pandas/tracing/validation/test_invalid_annotation_arguments.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import uuid
4 |
5 | import pandas as pd
6 | import pytest
7 |
8 | if sys.version_info >= (3, 8):
9 | from arize.pandas.tracing.columns import (
10 | ANNOTATION_LABEL_SUFFIX,
11 | ANNOTATION_NOTES_COLUMN_NAME,
12 | ANNOTATION_SCORE_SUFFIX,
13 | ANNOTATION_UPDATED_AT_SUFFIX,
14 | ANNOTATION_UPDATED_BY_SUFFIX,
15 | SPAN_SPAN_ID_COL,
16 | )
17 | from arize.pandas.tracing.validation.annotations import (
18 | annotations_validation,
19 | )
20 |
21 |
22 | def get_valid_annotation_df(num_rows=2):
23 | """Helper to create a DataFrame with the correct type for the notes column."""
24 | span_ids = [str(uuid.uuid4()) for _ in range(num_rows)]
25 | current_time_ms = int(time.time() * 1000)
26 |
27 | # Initialize notes with empty lists for all rows to ensure consistent list type
28 | notes_list = [[] for _ in range(num_rows)]
29 | if num_rows > 0:
30 | notes_list[0] = [
31 | '{"text": "Note 1"}'
32 | ] # Assign the actual note to the first row
33 |
34 | df = pd.DataFrame(
35 | {
36 | SPAN_SPAN_ID_COL.name: span_ids,
37 | f"annotation.quality{ANNOTATION_LABEL_SUFFIX}": ["good", "bad"][
38 | :num_rows
39 | ],
40 | f"annotation.quality{ANNOTATION_SCORE_SUFFIX}": [0.9, 0.1][
41 | :num_rows
42 | ],
43 | f"annotation.quality{ANNOTATION_UPDATED_BY_SUFFIX}": [
44 | "user1",
45 | "user2",
46 | ][:num_rows],
47 | f"annotation.quality{ANNOTATION_UPDATED_AT_SUFFIX}": [
48 | current_time_ms - 1000,
49 | current_time_ms,
50 | ][:num_rows],
51 | # Use the pre-initialized list with consistent type
52 | ANNOTATION_NOTES_COLUMN_NAME: pd.Series(notes_list, dtype=object),
53 | }
54 | )
55 | return df
56 |
57 |
58 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
59 | def test_valid_annotation_column_types():
60 | """Tests that a DataFrame with correct types passes validation."""
61 | annotations_dataframe = get_valid_annotation_df()
62 | errors = annotations_validation.validate_dataframe_form(
63 | annotations_dataframe=annotations_dataframe
64 | )
65 | assert len(errors) == 0, "Expected no validation errors for valid types"
66 |
67 |
68 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
69 | def test_invalid_annotation_label_type():
70 | """Tests error for non-string label column."""
71 | annotations_dataframe = get_valid_annotation_df()
72 | annotations_dataframe[f"annotation.quality{ANNOTATION_LABEL_SUFFIX}"] = [
73 | 1,
74 | 2,
75 | ]
76 | errors = annotations_validation.validate_dataframe_form(
77 | annotations_dataframe=annotations_dataframe
78 | )
79 | assert len(errors) > 0
80 |
81 |
82 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
83 | def test_invalid_annotation_score_type():
84 | """Tests error for non-numeric score column."""
85 | annotations_dataframe = get_valid_annotation_df()
86 | annotations_dataframe[f"annotation.quality{ANNOTATION_SCORE_SUFFIX}"] = [
87 | "high",
88 | "low",
89 | ]
90 | errors = annotations_validation.validate_dataframe_form(
91 | annotations_dataframe=annotations_dataframe
92 | )
93 | assert len(errors) > 0
94 |
95 |
96 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
97 | def test_invalid_annotation_updated_by_type():
98 | """Tests error for non-string updated_by column."""
99 | annotations_dataframe = get_valid_annotation_df()
100 | annotations_dataframe[
101 | f"annotation.quality{ANNOTATION_UPDATED_BY_SUFFIX}"
102 | ] = [100, 200]
103 | errors = annotations_validation.validate_dataframe_form(
104 | annotations_dataframe=annotations_dataframe
105 | )
106 | assert len(errors) > 0
107 |
108 |
109 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
110 | def test_invalid_annotation_updated_at_type():
111 | """Tests error for non-numeric updated_at column."""
112 | annotations_dataframe = get_valid_annotation_df()
113 | annotations_dataframe[
114 | f"annotation.quality{ANNOTATION_UPDATED_AT_SUFFIX}"
115 | ] = ["yesterday", "today"]
116 | errors = annotations_validation.validate_dataframe_form(
117 | annotations_dataframe=annotations_dataframe
118 | )
119 | assert len(errors) > 0
120 |
121 |
122 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
123 | def test_invalid_annotation_notes_type():
124 | """Tests error for non-list notes column."""
125 | annotations_dataframe = get_valid_annotation_df()
126 | annotations_dataframe[ANNOTATION_NOTES_COLUMN_NAME] = "just a string"
127 | errors = annotations_validation.validate_dataframe_form(
128 | annotations_dataframe=annotations_dataframe
129 | )
130 | assert len(errors) > 0
131 |
132 |
133 | if __name__ == "__main__":
134 | raise SystemExit(pytest.main([__file__]))
135 |
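For reference, a minimal well-formed annotations frame in sketch form. This assumes the form check only type-checks the annotation columns that are actually present; the full optional column set is shown in get_valid_annotation_df above.

# Illustrative sketch -- assumes validate_dataframe_form only inspects
# the annotation columns present in the frame.
import uuid

import pandas as pd

from arize.pandas.tracing.columns import (
    ANNOTATION_LABEL_SUFFIX,
    SPAN_SPAN_ID_COL,
)
from arize.pandas.tracing.validation.annotations import annotations_validation

df = pd.DataFrame(
    {
        SPAN_SPAN_ID_COL.name: [str(uuid.uuid4())],
        f"annotation.quality{ANNOTATION_LABEL_SUFFIX}": ["good"],
    }
)
errors = annotations_validation.validate_dataframe_form(
    annotations_dataframe=df
)
print(errors)  # expected to be empty for a well-formed frame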
--------------------------------------------------------------------------------
/tests/pandas/tracing/validation/test_invalid_arguments.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pandas as pd
4 | import pytest
5 |
6 | if sys.version_info >= (3, 8):
7 | from arize.pandas.tracing.validation.evals import evals_validation
8 |
9 | valid_spans_dataframe = pd.DataFrame(
10 | {
11 | "context.span_id": ["span_id_11111111", "span_id_22222222"],
12 | "context.trace_id": ["trace_id_11111111", "trace_id_22222222"],
13 | "name": ["name_1", "name_2"],
14 | "start_time": [
15 | "2024-01-18T18:28:27.429383+00:00",
16 | "2024-01-18T18:28:27.429383+00:00",
17 | ],
18 | "end_time": [
19 | "2024-01-18T18:28:27.429383+00:00",
20 | "2024-01-18T18:28:27.429383+00:00",
21 | ],
22 | }
23 | )
24 |
25 |
26 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
27 | def test_valid_eval_column_types():
28 | evals_dataframe = pd.DataFrame(
29 | {
30 | "context.span_id": ["span_id_1", "span_id_2"],
31 | "eval.eval_1.label": ["relevant", "irrelevant"],
32 | "eval.eval_1.score": [1.0, None],
33 | "eval.eval_1.explanation": [
34 | "explanation for relevant",
35 | "explanation for irrelevant",
36 | ],
37 | }
38 | )
39 | errors = evals_validation.validate_dataframe_form(
40 | evals_dataframe=evals_dataframe
41 | )
42 | assert len(errors) == 0, "Expected no validation errors for all columns"
43 |
44 |
45 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
46 | def test_invalid_label_columns_type():
47 | evals_dataframe = pd.DataFrame(
48 | {
49 | "context.span_id": ["span_id_1", "span_id_2"],
50 | "eval.eval_1.label": [
51 | 1,
52 | 2,
53 | ],
54 | }
55 | )
56 | errors = evals_validation.validate_dataframe_form(
57 | evals_dataframe=evals_dataframe
58 | )
59 | assert (
60 | len(errors) > 0
61 | ), "Expected validation errors for label columns with incorrect type"
62 |
63 |
64 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
65 | def test_invalid_score_columns_type():
66 | evals_dataframe = pd.DataFrame(
67 | {
68 | "context.span_id": ["span_id_1", "span_id_2"],
69 | "eval.eval_1.score": [
70 | "1.0",
71 | "None",
72 | ],
73 | }
74 | )
75 | errors = evals_validation.validate_dataframe_form(
76 | evals_dataframe=evals_dataframe
77 | )
78 | assert (
79 | len(errors) > 0
80 | ), "Expected validation errors for score columns with incorrect type"
81 |
82 |
83 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
84 | def test_invalid_explanation_columns_type():
85 | evals_dataframe = pd.DataFrame(
86 | {
87 | "context.span_id": ["span_id_1", "span_id_2"],
88 | "eval.eval_1.explanation": [
89 | 1,
90 | 2,
91 | ],
92 | }
93 | )
94 | errors = evals_validation.validate_dataframe_form(
95 | evals_dataframe=evals_dataframe
96 | )
97 | assert (
98 | len(errors) > 0
99 | ), "Expected validation errors for explanation columns with incorrect type"
100 |
101 |
102 | if __name__ == "__main__":
103 | raise SystemExit(pytest.main([__file__]))
104 |
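These tests pin down the eval column convention: columns named eval.<name>.<label|score|explanation>, keyed on context.span_id, with string labels/explanations and numeric scores. A minimal sketch mirroring the valid case (the eval name "hallucination" is hypothetical):

# Illustrative sketch of a well-formed evals frame.
import pandas as pd

from arize.pandas.tracing.validation.evals import evals_validation

evals_dataframe = pd.DataFrame(
    {
        "context.span_id": ["span_id_1"],
        "eval.hallucination.label": ["factual"],
        "eval.hallucination.score": [1.0],
    }
)
errors = evals_validation.validate_dataframe_form(
    evals_dataframe=evals_dataframe
)
print(errors)  # expected to be empty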
--------------------------------------------------------------------------------
/tests/pandas/validation/test_pandas_validator_error_messages.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import arize.pandas.validation.errors as err
4 | from arize.utils.types import Environments, ModelTypes
5 |
6 | # ----------------
7 | # Parameter checks
8 | # ----------------
9 |
10 |
11 | def test_missing_columns():
12 | err_msg = str(err.MissingColumns(["genotype", "phenotype"]))
13 | assert "genotype" in err_msg
14 | assert "phenotype" in err_msg
15 |
16 |
17 | def test_invalid_model_type():
18 | err_msg = str(err.InvalidModelType())
19 | assert all(mt.name in err_msg for mt in ModelTypes)
20 |
21 |
22 | def test_invalid_environment():
23 | err_msg = str(err.InvalidEnvironment())
24 | assert all(env.name in err_msg for env in Environments)
25 |
26 |
27 | # -----------
28 | # Type checks
29 | # -----------
30 |
31 |
32 | def test_invalid_type():
33 | err_msg = str(err.InvalidType("123", ["456", "789"], "112"))
34 | assert "123" in err_msg
35 | assert "456" in err_msg
36 | assert "789" in err_msg
37 | assert "112" in err_msg
38 |
39 |
40 | def test_invalid_type_features():
41 | err_msg = str(
42 | err.InvalidTypeFeatures(
43 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"]
44 | )
45 | )
46 | assert "genotype" in err_msg
47 | assert "phenotype" in err_msg
48 | assert "Triceratops" in err_msg
49 | assert "Archaeopteryx" in err_msg
50 |
51 |
52 | def test_invalid_type_tags():
53 | err_msg = str(
54 | err.InvalidTypeTags(
55 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"]
56 | )
57 | )
58 | assert "genotype" in err_msg
59 | assert "phenotype" in err_msg
60 | assert "Triceratops" in err_msg
61 | assert "Archaeopteryx" in err_msg
62 |
63 |
64 | def test_invalid_type_shap_values():
65 | err_msg = str(
66 | err.InvalidTypeShapValues(
67 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"]
68 | )
69 | )
70 | assert "genotype" in err_msg
71 | assert "phenotype" in err_msg
72 | assert "Triceratops" in err_msg
73 | assert "Archaeopteryx" in err_msg
74 |
75 |
76 | # ------------
77 | # Value checks
78 | # ------------
79 |
80 |
81 | def test_invalid_timestamp():
82 | err_msg = str(err.InvalidValueTimestamp("Spinosaurus"))
83 | assert "Spinosaurus" in err_msg
84 |
85 |
86 | def test_invalid_missing_value():
87 | err_msg = str(err.InvalidValueMissingValue("Stegosaurus", "missing"))
88 | assert "Stegosaurus" in err_msg
89 | assert "missing" in err_msg
90 |
91 |
92 | def test_invalid_infinite_value():
93 | err_msg = str(err.InvalidValueMissingValue("Stegosaurus", "infinite"))
94 | assert "Stegosaurus" in err_msg
95 | assert "infinite" in err_msg
96 |
97 |
98 | if __name__ == "__main__":
99 | raise SystemExit(pytest.main([__file__]))
100 |
--------------------------------------------------------------------------------
/tests/pandas/validation/test_pandas_validator_invalid_shap_suffix.py:
--------------------------------------------------------------------------------
1 | from collections import ChainMap
2 |
3 | import pandas as pd
4 | import pytest
5 |
6 | from arize.pandas.logger import Schema
7 | from arize.pandas.validation.errors import InvalidShapSuffix
8 | from arize.pandas.validation.validator import Validator
9 | from arize.utils.types import Environments, ModelTypes
10 |
11 |
12 | def test_invalid_feature_columns():
13 | errors = Validator.validate_params(
14 | **ChainMap(
15 | {
16 | "dataframe": kwargs["dataframe"].assign(feat_shap=[0]),
17 | "schema": Schema(
18 | prediction_id_column_name="prediction_id",
19 | prediction_label_column_name="prediction_label",
20 | feature_column_names=["feat_shap"],
21 | ),
22 | },
23 | kwargs,
24 | ),
25 | )
26 | assert len(errors) == 1
27 | assert type(errors[0]) is InvalidShapSuffix
28 |
29 |
30 | def test_invalid_tag_columns():
31 | errors = Validator.validate_params(
32 | **ChainMap(
33 | {
34 | "dataframe": kwargs["dataframe"].assign(tag_shap=[0]),
35 | "schema": Schema(
36 | prediction_id_column_name="prediction_id",
37 | prediction_label_column_name="prediction_label",
38 | tag_column_names=["tag_shap"],
39 | ),
40 | },
41 | kwargs,
42 | ),
43 | )
44 | assert len(errors) == 1
45 | assert type(errors[0]) is InvalidShapSuffix
46 |
47 |
48 | def test_invalid_shap_columns():
49 | errors = Validator.validate_params(
50 | **ChainMap(
51 | {
52 | "dataframe": kwargs["dataframe"].assign(shap=[0]),
53 | "schema": Schema(
54 | prediction_id_column_name="prediction_id",
55 | prediction_label_column_name="prediction_label",
56 | shap_values_column_names={"feat_shap": "shap"},
57 | ),
58 | },
59 | kwargs,
60 | ),
61 | )
62 | assert len(errors) == 1
63 | assert type(errors[0]) is InvalidShapSuffix
64 |
65 |
66 | def test_invalid_multiple():
67 | errors = Validator.validate_params(
68 | **ChainMap(
69 | {
70 | "dataframe": kwargs["dataframe"].assign(
71 | feat_shap=[0], tag_shap=[0], shap=[0]
72 | ),
73 | "schema": Schema(
74 | prediction_id_column_name="prediction_id",
75 | prediction_label_column_name="prediction_label",
76 | feature_column_names=["feat_shap"],
77 | tag_column_names=["tag_shap"],
78 | shap_values_column_names={"feat_shap": "shap"},
79 | ),
80 | },
81 | kwargs,
82 | ),
83 | )
84 | assert len(errors) == 1
85 | assert type(errors[0]) is InvalidShapSuffix
86 |
87 |
88 | kwargs = {
89 | "model_id": "fraud",
90 | "model_version": "v1.0",
91 | "model_type": ModelTypes.SCORE_CATEGORICAL,
92 | "environment": Environments.PRODUCTION,
93 | "dataframe": pd.DataFrame(
94 | {
95 | "prediction_id": pd.Series(["0"]),
96 | "prediction_label": pd.Series(["fraud"]),
97 | "prediction_score": pd.Series([1]),
98 | "actual_label": pd.Series(["not fraud"]),
99 | "actual_score": pd.Series([0]),
100 | }
101 | ),
102 | "schema": Schema(
103 | prediction_id_column_name="prediction_id",
104 | prediction_label_column_name="prediction_label",
105 | ),
106 | }
107 |
108 | if __name__ == "__main__":
109 | raise SystemExit(pytest.main([__file__]))
110 |
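The ChainMap idiom used throughout this file merges a per-test override dict over the shared module-level kwargs: lookups resolve left to right, so the override shadows the base without mutating it. A standard-library sketch:

# ChainMap resolves keys left to right: the first mapping wins and the
# shared base dict is never mutated.
from collections import ChainMap

base = {"model_id": "fraud", "model_version": "v1.0"}
merged = ChainMap({"model_version": "v2.0"}, base)
assert merged["model_version"] == "v2.0"  # override wins
assert merged["model_id"] == "fraud"  # falls through to base
assert base["model_version"] == "v1.0"  # base unchanged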
--------------------------------------------------------------------------------
/tests/pandas/validation/test_pandas_validator_ranking_param_checks.py:
--------------------------------------------------------------------------------
1 | from collections import ChainMap
2 | from datetime import datetime, timedelta
3 |
4 | import pandas as pd
5 | import pytest
6 |
7 | import arize.pandas.validation.errors as err
8 | from arize.pandas.logger import Schema
9 | from arize.pandas.validation.validator import Validator
10 | from arize.utils.types import Environments, ModelTypes
11 |
12 | kwargs = {
13 | "model_id": "rank",
14 | "model_type": ModelTypes.RANKING,
15 | "environment": Environments.PRODUCTION,
16 | "dataframe": pd.DataFrame(
17 | {
18 | "prediction_timestamp": pd.Series(
19 | [
20 | datetime.now(),
21 | datetime.now() + timedelta(days=1),
22 | datetime.now() - timedelta(days=364),
23 | datetime.now() + timedelta(days=364),
24 | ]
25 | ),
26 | "prediction_id": pd.Series(["x_1", "x_2", "y_1", "y_2"]),
27 | "prediction_group_id": pd.Series(["X", "X", "Y", "Y"]),
28 | "item_type": pd.Series(["toy", "game", "game", "pens"]),
29 | "ranking_rank": pd.Series([1, 2, 1, 2]),
30 | "ranking_category": pd.Series(
31 | [
32 | ["click", "purchase"],
33 | ["click", "favor"],
34 | ["favor"],
35 | ["click"],
36 | ]
37 | ),
38 | "ranking_relevance": pd.Series([1, 0, 2, 0]),
39 | }
40 | ),
41 | "schema": Schema(
42 | prediction_id_column_name="prediction_id",
43 | prediction_group_id_column_name="prediction_group_id",
44 | timestamp_column_name="prediction_timestamp",
45 | feature_column_names=["item_type"],
46 | rank_column_name="ranking_rank",
47 | actual_score_column_name="ranking_relevance",
48 | actual_label_column_name="ranking_category",
49 | ),
50 | }
51 |
52 |
53 | def test_ranking_param_check_happy_path():
54 | errors = Validator.validate_params(**kwargs)
55 | assert len(errors) == 0
56 |
57 |
58 | def test_ranking_param_check_allow_delayed_actuals():
59 | # The following Schema has no prediction information
60 | # we should allow delayed actuals
61 | errors = Validator.validate_params(
62 | **ChainMap(
63 | {
64 | "schema": Schema(
65 | prediction_id_column_name="prediction_id",
66 | timestamp_column_name="prediction_timestamp",
67 | feature_column_names=["item_type"],
68 | actual_score_column_name="ranking_relevance",
69 | actual_label_column_name="ranking_category",
70 | )
71 | },
72 | kwargs,
73 | )
74 | )
75 | assert len(errors) == 0
76 |
77 |
78 | def test_ranking_param_check_missing_prediction_group_id():
79 | errors = Validator.validate_params(
80 | **ChainMap(
81 | {
82 | "schema": Schema(
83 | prediction_id_column_name="prediction_id",
84 | prediction_group_id_column_name=None,
85 | timestamp_column_name="prediction_timestamp",
86 | feature_column_names=["item_type"],
87 | rank_column_name="ranking_rank",
88 | actual_score_column_name="ranking_relevance",
89 | actual_label_column_name="ranking_category",
90 | )
91 | },
92 | kwargs,
93 | )
94 | )
95 | assert len(errors) == 1
96 | assert type(errors[0]) is err.MissingRequiredColumnsForRankingModel
97 |
98 |
99 | def test_ranking_param_check_missing_rank():
100 | errors = Validator.validate_params(
101 | **ChainMap(
102 | {
103 | "schema": Schema(
104 | prediction_id_column_name="prediction_id",
105 | prediction_group_id_column_name="prediction_group_id",
106 | timestamp_column_name="prediction_timestamp",
107 | feature_column_names=["item_type"],
108 | rank_column_name=None,
109 | actual_score_column_name="ranking_relevance",
110 | actual_label_column_name="ranking_category",
111 | )
112 | },
113 | kwargs,
114 | )
115 | )
116 | assert len(errors) == 1
117 | assert type(errors[0]) is err.MissingRequiredColumnsForRankingModel
118 |
119 |
120 | def test_ranking_param_check_missing_category():
121 | errors = Validator.validate_params(
122 | **ChainMap(
123 | {
124 | "schema": Schema(
125 | prediction_id_column_name="prediction_id",
126 | prediction_group_id_column_name="prediction_group_id",
127 | timestamp_column_name="prediction_timestamp",
128 | feature_column_names=["item_type"],
129 | rank_column_name="ranking_rank",
130 | actual_score_column_name=None,
131 | actual_label_column_name=None,
132 | )
133 | },
134 | kwargs,
135 | )
136 | )
137 | assert len(errors) == 0
138 |
139 |
140 | if __name__ == "__main__":
141 | raise SystemExit(pytest.main([__file__]))
142 |
--------------------------------------------------------------------------------
/tests/single_log/test_casting.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from dataclasses import dataclass
4 | from typing import List, Union
5 |
6 | import pytest
7 |
8 | from arize.single_log.casting import cast_value
9 | from arize.single_log.errors import CastingError
10 | from arize.utils.types import ArizeTypes, TypedValue
11 |
12 |
13 | @dataclass
14 | class SingleLogTestCase:
15 | typed_value: TypedValue
16 | expected_value: Union[int, float, str, None]
17 |     expected_error: Union[CastingError, None] = None
18 |
19 |
20 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
21 | def test_string_to_float_error():
22 | tv = TypedValue(value="fruit", type=ArizeTypes.FLOAT)
23 | tc = SingleLogTestCase(
24 | typed_value=tv,
25 | expected_value=None,
26 | expected_error=CastingError(
27 | error_msg="could not convert string to float: 'fruit'",
28 | typed_value=tv,
29 | ),
30 | )
31 | table_test([tc])
32 |
33 |
34 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
35 | def test_numeric_string_to_int_error():
36 | tv = TypedValue(value="10.2", type=ArizeTypes.INT)
37 | tc = SingleLogTestCase(
38 | typed_value=tv,
39 | expected_value=None,
40 | expected_error=CastingError(
41 | error_msg="invalid literal for int() with base 10: '10.2'",
42 | typed_value=tv,
43 | ),
44 | )
45 | table_test([tc])
46 |
47 |
48 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
49 | def test_empty_string_to_numeric_error():
50 | tv = TypedValue(value="", type=ArizeTypes.FLOAT)
51 | tc = SingleLogTestCase(
52 | typed_value=tv,
53 | expected_value=None,
54 | expected_error=CastingError(
55 | error_msg="could not convert string to float: ''",
56 | typed_value=tv,
57 | ),
58 | )
59 | table_test([tc])
60 |
61 |
62 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
63 | def test_float_to_int_error():
64 | tv = TypedValue(value=4.4, type=ArizeTypes.INT)
65 | tc = SingleLogTestCase(
66 | typed_value=tv,
67 | expected_value=None,
68 | expected_error=CastingError(
69 | # this is our custom error;
70 | # native python float->int casting succeeds by taking the floor of the float.
71 | error_msg="Cannot convert float with non-zero fractional part to int",
72 | typed_value=tv,
73 | ),
74 | )
75 | table_test([tc])
76 |
77 |
78 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
79 | def test_cast_to_float_no_error():
80 | tc1 = SingleLogTestCase(
81 | typed_value=TypedValue(value=1, type=ArizeTypes.FLOAT),
82 | expected_value=1.0,
83 | )
84 | tc2 = SingleLogTestCase(
85 | typed_value=TypedValue(value="7.7", type=ArizeTypes.FLOAT),
86 | expected_value=7.7,
87 | )
88 | tc3 = SingleLogTestCase(
89 | typed_value=TypedValue(value="NaN", type=ArizeTypes.FLOAT),
90 | expected_value=float("NaN"),
91 | )
92 | table_test([tc1, tc2, tc3])
93 |
94 |
95 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
96 | def test_cast_to_int_no_error():
97 | tc1 = SingleLogTestCase(
98 | typed_value=TypedValue(value=1.0, type=ArizeTypes.INT), expected_value=1
99 | )
100 | tc2 = SingleLogTestCase(
101 | typed_value=TypedValue(value="7", type=ArizeTypes.INT), expected_value=7
102 | )
103 | tc3 = SingleLogTestCase(
104 | typed_value=TypedValue(value=None, type=ArizeTypes.INT),
105 | expected_value=None,
106 | )
107 | table_test([tc1, tc2, tc3])
108 |
109 |
110 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8")
111 | def test_cast_to_string_no_error():
112 | tc1 = SingleLogTestCase(
113 | typed_value=TypedValue(value=1.0, type=ArizeTypes.STR),
114 | expected_value="1.0",
115 | )
116 | tc2 = SingleLogTestCase(
117 | typed_value=TypedValue(value=float("NaN"), type=ArizeTypes.STR),
118 | expected_value=None,
119 | )
120 | tc3 = SingleLogTestCase(
121 | typed_value=TypedValue(value=None, type=ArizeTypes.STR),
122 | expected_value=None,
123 | )
124 | table_test([tc1, tc2, tc3])
125 |
126 |
127 | def table_test(test_cases: List[SingleLogTestCase]):
128 | for test_case in test_cases:
129 | try:
130 | v = cast_value(test_case.typed_value)
131 | except Exception as e:
132 |             if test_case.expected_error is None:
133 |                 pytest.fail(f"Unexpected error: {e}")
134 | else:
135 | assert isinstance(e, CastingError)
136 | assert e.typed_value == test_case.expected_error.typed_value
137 | assert e.error_msg == test_case.expected_error.error_msg
138 | else:
139 | if test_case.expected_value is None:
140 | assert v is None
141 | elif not isinstance(test_case.expected_value, str) and math.isnan(
142 | test_case.expected_value
143 | ):
144 | assert math.isnan(v)
145 | else:
146 | assert test_case.expected_value == v
147 |
148 |
149 | if __name__ == "__main__":
150 | raise SystemExit(pytest.main([__file__]))
151 |
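In sketch form, the casting contract these cases pin down (both calls are taken directly from the test cases above): cast_value() returns the converted value or raises CastingError carrying the original TypedValue and a human-readable message.

from arize.single_log.casting import cast_value
from arize.single_log.errors import CastingError
from arize.utils.types import ArizeTypes, TypedValue

print(cast_value(TypedValue(value="7.7", type=ArizeTypes.FLOAT)))  # 7.7

try:
    cast_value(TypedValue(value=4.4, type=ArizeTypes.INT))
except CastingError as e:
    print(e.error_msg)  # non-zero fractional part cannot be cast to int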
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import namedtuple
3 |
4 | import pytest
5 |
6 | from arize.utils import utils
7 |
8 |
9 | def test_response_url():
10 |     response = namedtuple("Response", ["content"])(
11 | json.dumps(
12 | {
13 | "realTimeIngestionUri": (
14 | "https://app.dev.arize.com/"
15 | "organizations/test-hmac-org/"
16 | "spaces/test-hmac-space/"
17 | "models/modelName/"
18 | "z-upload-classification-data-with-arize?"
19 | "selectedTab=overview"
20 | )
21 | }
22 | ).encode()
23 | )
24 | expected = (
25 | "https://app.dev.arize.com/"
26 | "organizations/test-hmac-org/"
27 | "spaces/test-hmac-space/models/"
28 | "modelName/"
29 | "z-upload-classification-data-with-arize?"
30 | "selectedTab=overview"
31 | )
32 |     assert utils.reconstruct_url(response) == expected
33 |
34 |
35 | if __name__ == "__main__":
36 | raise SystemExit(pytest.main([__file__]))
37 |
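Per the test above, reconstruct_url() accepts any object exposing a .content attribute of JSON bytes that include "realTimeIngestionUri", and for well-formed input it appears to return that URI. A sketch with a hypothetical URI:

# Illustrative sketch; the URI below is hypothetical.
import json
from collections import namedtuple

from arize.utils import utils

Response = namedtuple("Response", ["content"])
resp = Response(
    json.dumps(
        {"realTimeIngestionUri": "https://app.arize.com/models/my-model"}
    ).encode()
)
print(utils.reconstruct_url(resp))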
--------------------------------------------------------------------------------
/tests/types/test_embedding_types.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 |
8 | from arize.utils.constants import MAX_RAW_DATA_CHARACTERS
9 | from arize.utils.types import Embedding
10 |
11 |
12 | def random_string(N: int) -> str:
13 | return "".join(random.choices(string.ascii_uppercase + string.digits, k=N))
14 |
15 |
16 | long_raw_data_string = random_string(MAX_RAW_DATA_CHARACTERS + 1)  # one over the limit
17 | long_raw_data_token_array = [random_string(7) for _ in range(11000)]
18 | input_embeddings = {
19 | "correct:complete:list_vector": Embedding(
20 | vector=[1.0, 2, 3],
21 | data="this is a test sentence",
22 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
23 | ),
24 | "correct:complete:ndarray_vector+list_data": Embedding(
25 | vector=np.array([1.0, 2, 3]),
26 | data=["This", "is", "a", "test", "token", "array"],
27 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
28 | ),
29 | "correct:complete:pdSeries_vector+ndarray_data": Embedding(
30 | vector=pd.Series([1.0, 2, 3]),
31 | data=np.array(["This", "is", "a", "test", "token", "array"]),
32 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
33 | ),
34 | "correct:complete:ndarray_vector+pdSeries_data": Embedding(
35 | vector=np.array([1.0, 2, 3]),
36 | data=pd.Series(["This", "is", "a", "test", "token", "array"]),
37 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
38 | ),
39 | "correct:missing:data": Embedding(
40 | vector=np.array([1.0, 2, 3]),
41 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
42 | ),
43 | "correct:missing:link_to_data": Embedding(
44 | vector=pd.Series([1.0, 2, 3]),
45 | data=["This", "is", "a", "test", "token", "array"],
46 | ),
47 | "correct:empty_vector": Embedding(
48 | vector=np.array([]),
49 | data=["This", "is", "a", "test", "token", "array"],
50 | ),
51 | "wrong_type:vector": Embedding(
52 | vector=pd.DataFrame([1.0, 2, 3]),
53 | data=2,
54 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
55 | ),
56 | "wrong_type:data_num": Embedding(
57 | vector=pd.Series([1.0, 2, 3]),
58 | data=2,
59 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
60 | ),
61 | "wrong_type:data_dataframe": Embedding(
62 | vector=pd.Series([1.0, 2, 3]),
63 | data=pd.DataFrame(["This", "is", "a", "test", "token", "array"]),
64 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
65 | ),
66 | "wrong_type:link_to_data": Embedding(
67 | vector=np.array([1.0, 2, 3]),
68 | data=["This", "is", "a", "test", "token", "array"],
69 | link_to_data=True,
70 | ),
71 | "wrong_value:size_1_vector": Embedding(
72 | vector=np.array([1.0]),
73 | data=["This", "is", "a", "test", "token", "array"],
74 | ),
75 | "wrong_value:raw_data_string_too_long": Embedding(
76 | vector=pd.Series([1.0, 2, 3]),
77 | data=long_raw_data_string,
78 | ),
79 | "wrong_value:raw_data_token_array_too_long": Embedding(
80 | vector=pd.Series([1.0, 2, 3]),
81 | data=long_raw_data_token_array,
82 | ),
83 | }
84 |
85 |
86 | def test_correct_embeddings():
87 | keys = [key for key in input_embeddings if "correct:" in key]
88 | assert len(keys) > 0, "Test configuration error: keys must not be empty"
89 |
90 | for key in keys:
91 | embedding = input_embeddings[key]
92 | try:
93 | embedding.validate(key)
94 | except Exception as err:
95 | raise AssertionError(
96 | f"Correct embeddings should give no errors. Failing key = {key:s}. "
97 | f"Error = {err}"
98 | ) from None
99 |
100 |
101 | def test_wrong_value_fields():
102 | keys = [key for key in input_embeddings if "wrong_value:" in key]
103 | assert len(keys) > 0, "Test configuration error: keys must not be empty"
104 |
105 | for key in keys:
106 | embedding = input_embeddings[key]
107 |         try:
108 |             embedding.validate(key)
109 |         except ValueError:
110 |             pass  # expected: wrong field values raise ValueError
111 |         else:
112 |             pytest.fail(f"'{key}' should have raised a ValueError")
113 |
114 |
115 | def test_wrong_type_fields():
116 | keys = [key for key in input_embeddings if "wrong_type:" in key]
117 | assert len(keys) > 0, "Test configuration error: keys must not be empty"
118 |
119 | for key in keys:
120 | embedding = input_embeddings[key]
121 |         try:
122 |             embedding.validate(key)
123 |         except TypeError:
124 |             pass  # expected: wrong field types raise TypeError
125 |         else:
126 |             pytest.fail(f"'{key}' should have raised a TypeError")
127 |
128 |
129 | if __name__ == "__main__":
130 | raise SystemExit(pytest.main([__file__]))
131 |
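The happy path these fixtures encode, in sketch form; the string passed to validate() is assumed to be the field name used in error messages.

# Illustrative sketch of a well-formed Embedding; per the tests above,
# validate() raises TypeError for wrong field types and ValueError for
# wrong field values.
import numpy as np

from arize.utils.types import Embedding

emb = Embedding(
    vector=np.array([1.0, 2.0, 3.0]),
    data="this is a test sentence",
    link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png",
)
emb.validate("my_embedding")  # no exception for a valid embedding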
--------------------------------------------------------------------------------
/tests/types/test_multi_class_types.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from arize.utils.constants import (
4 | MAX_MULTI_CLASS_NAME_LENGTH,
5 | MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
6 | )
7 | from arize.utils.types import MultiClassActualLabel, MultiClassPredictionLabel
8 |
9 | # One entry more than the allowed maximum number of classes.
10 | overMaxClasses = {
11 |     f"class_{i}": 0.2 for i in range(MAX_NUMBER_OF_MULTI_CLASS_CLASSES + 1)
12 | }
13 | # One character longer than the allowed maximum class-name length.
14 | overMaxClassLen = "a" * (MAX_MULTI_CLASS_NAME_LENGTH + 1)
15 |
16 | input_labels = {
17 | "correct:prediction_scores": MultiClassPredictionLabel(
18 | prediction_scores={"class1": 0.1, "class2": 0.2},
19 | ),
20 | "correct:threshold_scores": MultiClassPredictionLabel(
21 | prediction_scores={"class1": 0.1, "class2": 0.2},
22 | threshold_scores={"class1": 0.1, "class2": 0.2},
23 | ),
24 | "invalid:wrong_pred_dictionary_type": MultiClassPredictionLabel(
25 | prediction_scores={"class1": "score", "class2": "score2"},
26 | ),
27 | "invalid:no_prediction_scores": MultiClassPredictionLabel(
28 | prediction_scores={},
29 | ),
30 |     "invalid:too_many_prediction_scores": MultiClassPredictionLabel(
31 | prediction_scores=overMaxClasses,
32 | ),
33 | "invalid:pred_empty_class_name": MultiClassPredictionLabel(
34 | prediction_scores={"": 1.1, "class2": 0.2},
35 | ),
36 | "invalid:pred_class_name_too_long": MultiClassPredictionLabel(
37 | prediction_scores={overMaxClassLen: 1.1, "class2": 0.2},
38 | ),
39 | "invalid:pred_score_over_1": MultiClassPredictionLabel(
40 | prediction_scores={"class1": 1.1, "class2": 0.2},
41 | ),
42 | "invalid:wrong_thresh_dictionary_type": MultiClassPredictionLabel(
43 | prediction_scores={"class1": 0.1, "class2": 0.2},
44 | threshold_scores={"class1": "score", "class2": 0.2},
45 | ),
46 | "invalid:pred_thresh_not_same_num_scores": MultiClassPredictionLabel(
47 | prediction_scores={"class1": 0.1, "class2": 0.2},
48 | threshold_scores={"class1": 0.1},
49 | ),
50 | "invalid:pred_thresh_not_same_classes": MultiClassPredictionLabel(
51 | prediction_scores={"class1": 0.1, "class2": 0.2},
52 | threshold_scores={"class1": 0.1, "class3": 0.1},
53 | ),
54 | "invalid:thresh_score_under_0": MultiClassPredictionLabel(
55 | prediction_scores={"class1": 0.1, "class2": 0.2},
56 | threshold_scores={"class1": -1, "class2": 0.2},
57 | ),
58 | "correct:actual_scores": MultiClassActualLabel(
59 | actual_scores={"class1": 0, "class2": 1},
60 | ),
61 | "correct:actual_scores_multi_1": MultiClassActualLabel(
62 | actual_scores={"class1": 1, "class2": 1},
63 | ),
64 | "invalid:wrong_actual_dictionary_type": MultiClassActualLabel(
65 | actual_scores={"class1": "score", "class2": 0},
66 | ),
67 | "invalid:no_actual_scores": MultiClassActualLabel(
68 | actual_scores={},
69 | ),
70 | "invalid:too_many_actual_scores": MultiClassActualLabel(
71 | actual_scores=overMaxClasses,
72 | ),
73 | "invalid:actual_score_empty_class_name": MultiClassActualLabel(
74 | actual_scores={"": 1, "class2": 0},
75 | ),
76 | "invalid:act_class_name_too_long": MultiClassActualLabel(
77 | actual_scores={overMaxClassLen: 1.1, "class2": 0.2},
78 | ),
79 | "invalid:actual_score_not_0_or_1": MultiClassActualLabel(
80 | actual_scores={"class1": 0.7, "class2": 0.2},
81 | ),
82 | }
83 |
84 |
85 | def test_correct_multi_class_label():
86 | keys = [key for key in input_labels if "correct:" in key]
87 | assert len(keys) > 0, "Test configuration error: keys must not be empty"
88 |
89 | for key in keys:
90 | multi_class_label = input_labels[key]
91 | try:
92 | multi_class_label.validate()
93 | except Exception as err:
94 | raise AssertionError(
95 |                 f"Correct multi-class label should give no errors. Failing key = {key:s}. "
96 | f"Error = {err}"
97 | ) from None
98 |
99 |
100 | def test_invalid_scores():
101 | keys = [key for key in input_labels if "invalid:" in key]
102 | assert len(keys) > 0, "Test configuration error: keys must not be empty"
103 |
104 | for key in keys:
105 | multi_class_label = input_labels[key]
106 | with pytest.raises(ValueError) as e:
107 | multi_class_label.validate()
108 | assert isinstance(
109 |             e.value, ValueError
110 | ), "Invalid values should raise value errors"
111 |
112 |
113 | if __name__ == "__main__":
114 | raise SystemExit(pytest.main([__file__]))
115 |
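The invalid fixtures above imply the constraints sketched below: score maps must be non-empty and within the class-count and name-length limits, prediction scores must lie in [0, 1], threshold scores must cover the same classes, and actual scores must be exactly 0 or 1.

# Illustrative sketch; validate() raises ValueError when the constraints
# noted above are violated.
from arize.utils.types import MultiClassActualLabel, MultiClassPredictionLabel

pred = MultiClassPredictionLabel(
    prediction_scores={"class1": 0.1, "class2": 0.2},
    threshold_scores={"class1": 0.1, "class2": 0.2},  # same classes required
)
pred.validate()

actual = MultiClassActualLabel(actual_scores={"class1": 0, "class2": 1})
actual.validate()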
--------------------------------------------------------------------------------
/tests/types/test_type_helpers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from arize.utils.types import is_dict_of, is_list_of
4 |
5 |
6 | def test_is_dict_of():
7 | # Assert only key
8 | assert (
9 | is_dict_of(
10 | {"class1": 0.1, "class2": 0.2},
11 | key_allowed_types=str,
12 | )
13 | is True
14 | )
15 | # Assert key and value exact types
16 | assert (
17 | is_dict_of(
18 | {"class1": 0.1, "class2": 0.2},
19 | key_allowed_types=str,
20 | value_allowed_types=float,
21 | )
22 | is True
23 | )
24 | assert (
25 | is_dict_of(
26 | {"class1": 0.1, "class2": 0.2},
27 | key_allowed_types=str,
28 | value_allowed_types=int,
29 | )
30 | is False
31 | )
32 | # Assert key and value union types
33 | assert (
34 | is_dict_of(
35 | {"class1": 0.1, "class2": 0.2},
36 | key_allowed_types=str,
37 | value_allowed_types=(str, float),
38 | )
39 | is True
40 | )
41 | # Assert key and exact list of value types
42 | assert (
43 | is_dict_of(
44 | {"class1": [1, 2], "class2": [3, 4]},
45 | key_allowed_types=str,
46 | value_list_allowed_types=int,
47 | )
48 | is True
49 | )
50 | # Assert key and exact list of value types
51 | assert (
52 | is_dict_of(
53 | {"class1": [1, 2], "class2": [3, 4]},
54 | key_allowed_types=str,
55 | value_list_allowed_types=str,
56 | )
57 | is False
58 | )
59 | # Assert key and union list of value types
60 | assert (
61 | is_dict_of(
62 | {"class1": [1, 2], "class2": [3, 4]},
63 | key_allowed_types=str,
64 | value_list_allowed_types=(str, int),
65 | )
66 | is True
67 | )
68 | assert (
69 | is_dict_of(
70 | {"class1": [1, 2], "class2": ["a", "b"]},
71 | key_allowed_types=str,
72 | value_list_allowed_types=(str, int),
73 | )
74 | is True
75 | )
76 | # Assert key and value and list of value types
77 | assert (
78 | is_dict_of(
79 | {"class1": 1, "class2": ["a", "b"], "class3": [0.4, 0.7]},
80 | key_allowed_types=str,
81 | value_allowed_types=int,
82 | value_list_allowed_types=(str, float),
83 | )
84 | is True
85 | )
86 | assert (
87 | is_dict_of(
88 | {"class1": 1, "class2": ["a", "b"], "class3": [0.4, 0.7]},
89 | key_allowed_types=str,
90 | value_allowed_types=str,
91 | value_list_allowed_types=(str, float),
92 | )
93 | is False
94 | )
95 |
96 |
97 | def test_is_list_of():
98 | assert is_list_of([1, 2], int) is True
99 | assert is_list_of([1, 2], float) is False
100 | assert is_list_of(["x", 2], int) is False
101 | assert is_list_of(["x", 2], (str, int)) is True
102 |
103 |
104 | if __name__ == "__main__":
105 | raise SystemExit(pytest.main([__file__]))
106 |
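The mixed-value case above pins down the helper's semantics: scalar values are checked against value_allowed_types, while list values are checked element-wise against value_list_allowed_types. A compact sketch:

# Illustrative sketch of the scalar-vs-list split in is_dict_of().
from arize.utils.types import is_dict_of

mixed = {"count": 1, "tokens": ["a", "b"]}
assert is_dict_of(
    mixed,
    key_allowed_types=str,
    value_allowed_types=int,  # applies to the scalar 1
    value_list_allowed_types=str,  # applies to elements of ["a", "b"]
)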
--------------------------------------------------------------------------------