├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE.md ├── MANIFEST.in ├── README.md ├── docs ├── .gitignore ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── custom.css │ ├── dark-logo.svg │ ├── light-logo.svg │ └── switcher.json │ ├── _templates │ ├── custom_sidebar.html │ └── navbar-mid.html │ ├── conf.py │ ├── index.md │ ├── llm-api │ ├── datasets.rst │ ├── evaluators.rst │ ├── exporter.rst │ ├── logger.rst │ └── types.rst │ └── ml-api │ ├── exporter.rst │ └── logger.rst ├── pyproject.toml ├── src └── arize │ ├── __init__.py │ ├── api.py │ ├── bounded_executor.py │ ├── experimental │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ └── session.py │ │ ├── experiments │ │ │ ├── __init__.py │ │ │ ├── evaluators │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── exceptions.py │ │ │ │ ├── executors.py │ │ │ │ ├── rate_limiters.py │ │ │ │ └── utils.py │ │ │ ├── functions.py │ │ │ ├── tracing.py │ │ │ └── types.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ └── experiment_utils.py │ │ └── validation │ │ │ ├── __init__.py │ │ │ ├── errors.py │ │ │ └── validator.py │ ├── integrations │ │ ├── whylabs │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ ├── generator.py │ │ │ └── test.py │ │ ├── whylabs_vanguard_governance │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ └── test.py │ │ └── whylabs_vanguard_ingestion │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ ├── generator.py │ │ │ └── test.py │ ├── online_tasks │ │ ├── __init__.py │ │ └── dataframe_preprocessor.py │ └── prompt_hub │ │ ├── __init__.py │ │ ├── client.py │ │ ├── constants.py │ │ └── prompts │ │ ├── __init__.py │ │ ├── base.py │ │ ├── open_ai.py │ │ └── vertex_ai.py │ ├── exporter │ ├── README.md │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── client.py │ │ ├── query.py │ │ └── session.py │ ├── publicexporter_pb2.py │ └── utils │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── errors.py │ │ ├── schema_parser.py │ │ ├── tracing.py │ │ └── validation.py │ ├── pandas │ ├── __init__.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── auto_generator.py │ │ ├── base_generators.py │ │ ├── constants.py │ │ ├── cv_generators.py │ │ ├── errors.py │ │ ├── nlp_generators.py │ │ ├── tabular_generators.py │ │ └── usecases.py │ ├── etl │ │ ├── __init__.py │ │ ├── casting.py │ │ └── errors.py │ ├── generative │ │ ├── __init__.py │ │ ├── llm_evaluation │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ └── hf_metrics.py │ │ └── nlp_metrics │ │ │ ├── __init__.py │ │ │ └── hf_metrics.py │ ├── logger.py │ ├── proto │ │ ├── __init__.py │ │ └── requests_pb2.py │ ├── surrogate_explainer │ │ ├── __init__.py │ │ └── mimic.py │ ├── tracing │ │ ├── __init__.py │ │ ├── columns.py │ │ ├── constants.py │ │ ├── types.py │ │ ├── utils.py │ │ └── validation │ │ │ ├── __init__.py │ │ │ ├── annotations │ │ │ ├── __init__.py │ │ │ ├── annotations_validation.py │ │ │ ├── dataframe_form_validation.py │ │ │ └── value_validation.py │ │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── argument_validation.py │ │ │ ├── dataframe_form_validation.py │ │ │ ├── errors.py │ │ │ └── value_validation.py │ │ │ ├── evals │ │ │ ├── __init__.py │ │ │ ├── dataframe_form_validation.py │ │ │ ├── evals_validation.py │ │ │ └── value_validation.py │ │ │ ├── metadata │ │ │ ├── __init__.py │ │ │ ├── argument_validation.py │ │ │ ├── dataframe_form_validation.py │ │ │ └── value_validation.py │ │ │ └── spans │ │ │ ├── __init__.py │ │ │ ├── 
dataframe_form_validation.py │ │ │ ├── spans_validation.py │ │ │ └── value_validation.py │ └── validation │ │ ├── __init__.py │ │ ├── errors.py │ │ └── validator.py │ ├── public_pb2.py │ ├── single_log │ ├── __init__.py │ ├── casting.py │ └── errors.py │ ├── utils │ ├── __init__.py │ ├── constants.py │ ├── errors.py │ ├── logging.py │ ├── model_mapping.json │ ├── proto.py │ ├── types.py │ └── utils.py │ └── version.py └── tests ├── __init__.py ├── experimental ├── __init__.py └── datasets │ ├── __init__.py │ ├── experiments │ ├── __init__.py │ └── test_experiments.py │ └── validation │ ├── __init__.py │ └── test_validator.py ├── exporter ├── __init__.py ├── test_exporter.py ├── utils │ ├── __init__.py │ └── test_schema_parser.py └── validations │ ├── test_validator_invalid_types.py │ └── test_validator_invalid_values.py ├── fixtures ├── __init__.py └── mpg.csv ├── pandas ├── etl │ └── test_casting_config.py ├── generative │ └── nlp_metrics │ │ └── test_nlp_metrics.py ├── logger │ ├── test_log_annotations.py │ └── test_log_metadata.py ├── surrogate_explainer │ └── test_surrogate_explainer.py ├── test_pandas_logger.py ├── tracing │ ├── test_columns.py │ ├── test_logger_tracing.py │ └── validation │ │ ├── test_invalid_annotation_arguments.py │ │ ├── test_invalid_annotation_values.py │ │ ├── test_invalid_arguments.py │ │ ├── test_invalid_metadata_arguments.py │ │ ├── test_invalid_metadata_values.py │ │ └── test_invalid_values.py └── validation │ ├── test_pandas_validator_error_messages.py │ ├── test_pandas_validator_invalid_params.py │ ├── test_pandas_validator_invalid_records.py │ ├── test_pandas_validator_invalid_reserved_columns.py │ ├── test_pandas_validator_invalid_shap_suffix.py │ ├── test_pandas_validator_invalid_types.py │ ├── test_pandas_validator_invalid_values.py │ ├── test_pandas_validator_missing_columns.py │ ├── test_pandas_validator_modelmapping.py │ ├── test_pandas_validator_ranking_param_checks.py │ ├── test_pandas_validator_ranking_record_checks.py │ ├── test_pandas_validator_ranking_type_checks.py │ ├── test_pandas_validator_ranking_value_checks.py │ └── test_pandas_validator_required_checks.py ├── single_log └── test_casting.py ├── test_api.py ├── test_utils.py └── types ├── test_embedding_types.py ├── test_multi_class_types.py └── test_type_helpers.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # IDEs 7 | .idea/ 8 | .vscode/ 9 | 10 | # Packages 11 | *.egg 12 | *.egg-info 13 | .eggs/ 14 | # build 15 | parts 16 | bin 17 | var 18 | sdist 19 | dist 20 | develop-eggs 21 | .installed.cfg 22 | lib 23 | lib64 24 | 25 | # Installer logs 26 | pip-log.txt 27 | 28 | # Unit test / coverage reports 29 | .coverage 30 | .tox 31 | nosetests.xml 32 | 33 | # Complexity 34 | output/*.html 35 | output/*/index.html 36 | 37 | # Sphinx 38 | docs/_build 39 | 40 | # Cookiecutter 41 | output/ 42 | 43 | .workon 44 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | # You can also specify other tool versions: 13 | # nodejs: "20" 14 | # rust: "1.70" 15 | # golang: 
"1.20" 16 | 17 | # Build documentation in the "docs/" directory with Sphinx 18 | sphinx: 19 | configuration: docs/source/conf.py 20 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 21 | # builder: "dirhtml" 22 | # Fail on all warnings to avoid broken references 23 | # fail_on_warning: true 24 | 25 | # Optionally build your docs in additional formats such as PDF and ePub 26 | # formats: 27 | # - pdf 28 | # - epub 29 | 30 | # Optional but recommended, declare the Python requirements required 31 | # to build your documentation 32 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 33 | python: 34 | install: 35 | - requirements: docs/requirements.txt 36 | - method: pip 37 | path: . 38 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Arize AI 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md requirements.txt LICENSE.md arize/utils/*.json 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SPHINXPROJ = arize 9 | SOURCEDIR = source 10 | BUILDDIR = build 11 | 12 | # Put it first so that "make" without argument is like "make help". 
13 | help: 14 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 15 | 16 | .PHONY: help Makefile 17 | 18 | # Catch-all target: route all unknown targets to Sphinx using the new 19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 20 | %: Makefile 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Maintenance README for Arize Sphinx API Documentation 2 | 3 | This API reference provides comprehensive details for Arize's API. The documentation covers only public, user-facing API endpoints offered in Arize. 4 | 5 | Maintaining the API reference consists of two parts: 6 | 7 | 1. Building the documentation with Sphinx 8 | 2. Hosting and CI with readthedocs 9 | 10 | ## TL;DR 11 | ``` 12 | uv venv --python=python3.11 13 | uv pip install -r requirements.txt 14 | make clean html 15 | # then open build/html/index.html in your browser 16 | # currently, the build/html directory is copied over as the static site for arize-docs.onrender.com 17 | ``` 18 | 19 | ## Files 20 | - conf.py: All sphinx-related configuration is done here and is necessary to run Sphinx. 21 | - index.md: Main entrypoint for the API reference. This file must be in the `source` directory. For documentation to show up on the API reference, there must be a path (does not have to be direct) defined in index.md to the target documentation file. 22 | - requirements.txt: This file is necessary for management of dependencies on the readthedocs platform and its build process. 23 | - make files: Not required but useful in generating static HTML pages locally. 24 | 25 | ## Useful references 26 | https://pydata-sphinx-theme.readthedocs.io/ 27 | https://sphinx-design.readthedocs.io/en/latest/ 28 | https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 29 | https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html 30 | https://docs.readthedocs.io/en/stable/automation-rules.html -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=arize 13 | 14 | %SPHINXBUILD% >NUL 2>NUL 15 | if errorlevel 9009 ( 16 | echo. 17 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 18 | echo.installed, then set the SPHINXBUILD environment variable to point 19 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 20 | echo.may add the Sphinx directory to PATH. 21 | echo. 
22 | echo.If you don't have Sphinx installed, grab it from 23 | echo.https://www.sphinx-doc.org/ 24 | exit /b 1 25 | ) 26 | 27 | if "%1" == "" goto help 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | myst_parser 2 | sphinx==7.3.7 3 | pydata-sphinx-theme 4 | linkify-it-py 5 | sphinx_design 6 | 7 | -e ".[Datasets]" 8 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | @import "pydata_sphinx_theme.scss"; 2 | 3 | /* -- index ------------------------------------------------------------------ */ 4 | 5 | .bd-article h1 { 6 | display: flex; 7 | justify-content: center; 8 | } 9 | 10 | /* .bd-article #api-definition h2 { 11 | display: flex; 12 | justify-content: center; 13 | } */ 14 | 15 | #main-content > .bd-article > img { 16 | display: flex; 17 | justify-content: center; 18 | } 19 | 20 | #arize-api-reference > #external-links { 21 | display: flex; 22 | flex-direction: row; 23 | justify-content: space-around; 24 | } 25 | 26 | #arize-api-reference > #api-definition > .toctree-wrapper > ul { 27 | display: flex; 28 | flex-direction: row; 29 | justify-content: space-between; 30 | flex-wrap: wrap; 31 | } 32 | 33 | #arize-api-reference > #api-definition > .toctree-wrapper > ul > li { 34 | padding: 1em 1em 1em 1em; 35 | } 36 | 37 | /* -- navbar display --------------------------------------------------------- */ 38 | 39 | .navbar-nav .nav-link { 40 | text-transform: capitalize; 41 | } 42 | 43 | /* -- sidebar display -------------------------------------------------------- */ 44 | 45 | /* -- signature display ------------------------------------------------------ */ 46 | 47 | .sig.sig-object.py { 48 | font-style: normal !important; 49 | font-feature-settings: "kern"; 50 | font-family: "Roboto Mono", "Courier New", Courier, monospace; 51 | background: var(--pst-color-surface); 52 | } 53 | 54 | .sig.sig-object.py .sig-param, 55 | .sig-paren { 56 | font-weight: normal; 57 | } 58 | 59 | .sig.sig-object.py em { 60 | font-style: normal; 61 | } 62 | -------------------------------------------------------------------------------- /docs/source/_static/dark-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/_static/light-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/_static/switcher.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "version": "latest", 4 | "url": "https://arize-client-python.readthedocs.io/en/latest/", 5 | "preferred": true 6 | }, 7 | { 8 | "version": "v7.43.0", 9 | "url": "https://arize-client-python.readthedocs.io/en/v7.43.0/" 10 | }, 11 | { 12 | "version": "v7.42.0", 13 | "url": "https://arize-client-python.readthedocs.io/en/v7.42.0/" 14 | }, 15 | { 16 | "version": "v7.41.0", 17 | "url": 
"https://arize-client-python.readthedocs.io/en/v7.41.0/" 18 | }, 19 | { 20 | "version": "v7.40.0", 21 | "url": "https://arize-client-python.readthedocs.io/en/v7.40.0/" 22 | }, 23 | { 24 | "version": "v7.39.0", 25 | "url": "https://arize-client-python.readthedocs.io/en/v7.39.0/" 26 | }, 27 | { 28 | "version": "v7.38.0", 29 | "url": "https://arize-client-python.readthedocs.io/en/v7.38.0/" 30 | }, 31 | { 32 | "version": "v7.37.0", 33 | "url": "https://arize-client-python.readthedocs.io/en/v7.37.0/" 34 | }, 35 | { 36 | "version": "v7.36.0", 37 | "url": "https://arize-client-python.readthedocs.io/en/v7.36.0/" 38 | }, 39 | { 40 | "version": "v7.35.0", 41 | "url": "https://arize-client-python.readthedocs.io/en/v7.35.0/" 42 | }, 43 | { 44 | "version": "v7.34.0", 45 | "url": "https://arize-client-python.readthedocs.io/en/v7.34.0/" 46 | }, 47 | { 48 | "version": "v7.33.0", 49 | "url": "https://arize-client-python.readthedocs.io/en/v7.33.0/" 50 | }, 51 | { 52 | "version": "v7.32.0", 53 | "url": "https://arize-client-python.readthedocs.io/en/v7.32.0/" 54 | }, 55 | { 56 | "version": "v7.31.0", 57 | "url": "https://arize-client-python.readthedocs.io/en/v7.31.0/" 58 | }, 59 | { 60 | "version": "v7.30.0", 61 | "url": "https://arize-client-python.readthedocs.io/en/v7.30.0/" 62 | }, 63 | { 64 | "version": "v7.29.0", 65 | "url": "https://arize-client-python.readthedocs.io/en/v7.29.0/" 66 | }, 67 | { 68 | "version": "v7.28.0", 69 | "url": "https://arize-client-python.readthedocs.io/en/v7.28.0/" 70 | }, 71 | { 72 | "version": "v7.27.1", 73 | "url": "https://arize-client-python.readthedocs.io/en/v7.27.1/" 74 | }, 75 | { 76 | "version": "v7.26.0", 77 | "url": "https://arize-client-python.readthedocs.io/en/v7.26.0/" 78 | }, 79 | { 80 | "version": "v7.25.0", 81 | "url": "https://arize-client-python.readthedocs.io/en/v7.25.0/" 82 | } 83 | ] 84 | -------------------------------------------------------------------------------- /docs/source/_templates/custom_sidebar.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/_templates/navbar-mid.html: -------------------------------------------------------------------------------- 1 | {# Displays links to the top-level TOCtree elements, in the header navbar. #} 2 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | myst: 3 | html_meta: 4 | "description lang=en": | 5 | Top-level documentation for arize, 6 | with links to the rest of the site.. 7 | html_theme.sidebar_secondary.remove: true 8 | --- 9 | 10 | # Arize Python SDK Reference 11 | 12 |
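To follow the examples in this reference, you will need the SDK installed. A minimal setup sketch is below; the `Datasets` extra shown is one of the optional dependency groups declared in `pyproject.toml` and is assumed here for the datasets/experiments APIs.

```
pip install arize
# optional extra for the datasets & experiments APIs
pip install 'arize[Datasets]'
```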
28 | 29 | Below you can find the classes, functions, and parameters for tracing and evaluating your application using the Arize Python SDK. For a complete guide on how to use Arize, including tutorials, quickstarts, and concept explanations, read our [docs](https://docs.arize.com/arize). This reference primarily covers the LLM SDK methods; documentation for the rest of the SDK, spanning ML and CV use cases, is coming soon! 30 | 31 | ```{seealso} 32 | Check out our [Slack community](https://arize-ai.slack.com/join/shared_invite/zt-1px8dcmlf-fmThhDFD_V_48oU7ALan4Q#/shared-invite/email) and [GitHub repository](https://github.com/arize-ai/client_python)! 33 | ``` 34 | 35 | ::::{grid} 2 36 | :::{grid-item} 37 | ```{toctree} 38 | :maxdepth: 1 39 | :caption: LLM API 40 | llm-api/datasets 41 | llm-api/evaluators 42 | llm-api/exporter 43 | llm-api/logger 44 | llm-api/types 45 | ``` 46 | ::: 47 | :::{grid-item} 48 | ```{toctree} 49 | :maxdepth: 1 50 | :caption: ML API 51 | ml-api/exporter 52 | ml-api/logger 53 | ``` 54 | ::: 55 | :::: 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /docs/source/llm-api/datasets.rst: -------------------------------------------------------------------------------- 1 | datasets & experiments 2 | ====================== 3 | Use datasets to curate spans from your LLM applications for testing. Run experiments to test different models, prompts, and parameters for your LLM apps. `Read our quickstart guide for more information `_. 4 | 5 | To use in your code, import the following: 6 | 7 | ``from arize.experimental.datasets import ArizeDatasetsClient`` 8 | 9 | .. autoclass:: arize.experimental.datasets.ArizeDatasetsClient 10 | :members: 11 | :exclude-members: port, host, otlp_endpoint, __init__, api_key, developer_key, scheme, session 12 | -------------------------------------------------------------------------------- /docs/source/llm-api/evaluators.rst: -------------------------------------------------------------------------------- 1 | evaluators 2 | ^^^^^^^^^^ 3 | Use these base classes to define evaluators of your own. `See our docs for more information `_. 4 | 5 | To import evaluators, use the following: 6 | ``from arize.experimental.datasets.experiments.evaluators.base import ...`` 7 | 8 | .. autoclass:: arize.experimental.datasets.experiments.evaluators.base.Evaluator 9 | :members: evaluate, async_evaluate 10 | :exclude-members: kind, name, _name -------------------------------------------------------------------------------- /docs/source/llm-api/exporter.rst: -------------------------------------------------------------------------------- 1 | exporter 2 | ======== 3 | Use this to export data from Arize. `Read this guide for more information `_. 4 | 5 | To use in your code, import the following: 6 | 7 | ``from arize.exporter import ArizeExportClient`` 8 | 9 | .. automodule:: arize.exporter 10 | :members: 11 | :exclude-members: api_key, arize_config_path, arize_profile, get_progress_bar, host, port, scheme, session, __init__, get_arize_schema -------------------------------------------------------------------------------- /docs/source/llm-api/logger.rst: -------------------------------------------------------------------------------- 1 | logger 2 | ====== 3 | Use this to log evaluations or spans to Arize in bulk. Read our quickstart guide for `logging evaluations to Arize `_. 4 | 5 | To use in your code, import the following: 6 | 7 | ``from arize.pandas.logger import Client`` 8 | 9 | .. 
automodule:: arize.pandas.logger 10 | :members: 11 | :exclude-members: __init__, log, InvalidSessionError, FlightSession -------------------------------------------------------------------------------- /docs/source/llm-api/types.rst: -------------------------------------------------------------------------------- 1 | types 2 | ^^^^^ 3 | These are the classes used across the experiment functions. 4 | 5 | To import types, use the following: 6 | ``from arize.experimental.datasets.experiments.types import ...`` 7 | 8 | .. autoclass:: arize.experimental.datasets.experiments.types.Example 9 | :exclude-members: id, updated_at, input, output, metadata, dataset_row, from_dict 10 | 11 | .. autoclass:: arize.experimental.datasets.experiments.types.EvaluationResult 12 | :exclude-members: score, label, explanation, metadata, from_dict 13 | 14 | .. autoclass:: arize.experimental.datasets.experiments.types.ExperimentRun 15 | :exclude-members: start_time, end_time, experiment_id, dataset_example_id, repetition_number, output, error, id, trace_id, from_dict 16 | 17 | .. autoclass:: arize.experimental.datasets.experiments.types.ExperimentEvaluationRun 18 | :exclude-members: start_time, end_time, experiment_id, dataset_example_id, repetition_number, output, error, id, trace_id, annotator_kind, experiment_run_id, from_dict, name, result 19 | -------------------------------------------------------------------------------- /docs/source/ml-api/exporter.rst: -------------------------------------------------------------------------------- 1 | exporter 2 | ======== 3 | Use this to export data from Arize. `Read this guide for more information `_. 4 | 5 | To use in your code, import the following: 6 | 7 | ``from arize.exporter import ArizeExportClient`` 8 | 9 | .. automodule:: arize.exporter 10 | :members: 11 | :exclude-members: api_key, arize_config_path, arize_profile, get_progress_bar, host, port, scheme, session, __init__, get_arize_schema -------------------------------------------------------------------------------- /docs/source/ml-api/logger.rst: -------------------------------------------------------------------------------- 1 | logger 2 | ====== 3 | 4 | client 5 | ^^^^^^ 6 | Use this to log inferences from your model (ML, CV, NLP, etc.) to Arize. 7 | 8 | To use in your code, import the following: ``from arize.pandas.logger import Client`` 9 | 10 | .. automodule:: arize.pandas.logger 11 | :members: 12 | :exclude-members: __init__, log_spans, log_evaluations, log_evaluations_sync, FlightSession, InvalidSessionError 13 | 14 | 15 | client (single record) 16 | ^^^^^^^^^^^^^^^^^^^^^^^ 17 | To use in your code, import the following: ``from arize.api import Client`` 18 | 19 | .. 
automodule:: arize.api 20 | :members: 21 | :exclude-members: __init__ -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "arize" 3 | description = "A helper library to interact with Arize AI APIs" 4 | readme = "README.md" 5 | requires-python = ">=3.6" 6 | license = { text = "BSD" } 7 | keywords = [ 8 | "Arize", 9 | "Observability", 10 | "Monitoring", 11 | "Explainability", 12 | "Tracing", 13 | "LLM", 14 | "Evaluations", 15 | ] 16 | authors = [ 17 | { name = "Arize AI", email = "support@arize.com" }, 18 | ] 19 | maintainers = [ 20 | { name = "Arize AI", email = "support@arize.com" }, 21 | ] 22 | classifiers = [ 23 | "Development Status :: 5 - Production/Stable", 24 | "Programming Language :: Python", 25 | "Programming Language :: Python :: 3.6", 26 | "Programming Language :: Python :: 3.7", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | "Programming Language :: Python :: 3.13", 33 | "Intended Audience :: Developers", 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 35 | "Topic :: Software Development :: Libraries :: Python Modules", 36 | "Topic :: System :: Logging", 37 | "Topic :: System :: Monitoring", 38 | ] 39 | dependencies = [ 40 | "requests_futures==1.0.0", 41 | "googleapis_common_protos>=1.51.0,<2", 42 | "protobuf>=4.21.0,<6", 43 | "pandas>=0.25.3,<3", 44 | "pyarrow>=0.15.0", 45 | "tqdm>=4.60.0,<5", 46 | "pydantic>=2.0.0,<3", 47 | ] 48 | dynamic = ["version"] 49 | 50 | [project.urls] 51 | Homepage = "https://arize.com" 52 | Documentation = "https://docs.arize.com/arize" 53 | Issues = "https://github.com/Arize-ai/client_python/issues" 54 | Source = "https://github.com/Arize-ai/client_python" 55 | Changelog = "https://github.com/Arize-ai/client_python/blob/main/CHANGELOG.md" 56 | 57 | [project.optional-dependencies] 58 | dev = [ 59 | "pytest==8.3.3", 60 | "ruff==0.6.9", 61 | ] 62 | MimicExplainer = [ 63 | "interpret-community[mimic]>=0.22.0,<1", 64 | ] 65 | AutoEmbeddings = [ 66 | "transformers>=4.25, <5", 67 | "tokenizers>=0.13, <1", 68 | "datasets>=2.8, <3, !=2.14.*", 69 | "torch>=1.13, <3", 70 | "Pillow>=8.4.0, <11", 71 | ] 72 | NLP_Metrics = [ 73 | "nltk>=3.0.0, <4", 74 | "sacrebleu>=2.3.1, <3", 75 | "rouge-score>=0.1.2, <1", 76 | "evaluate>=0.3, <1", 77 | "datasets!=2.14.*", 78 | ] 79 | LLM_Evaluation = [ 80 | # To be removed in version 8 in favor of NLP_Metrics 81 | "nltk>=3.0.0, <4", 82 | "sacrebleu>=2.3.1, <3", 83 | "rouge-score>=0.1.2, <1", 84 | "evaluate>=0.3, <1", 85 | "datasets!=2.14.*", 86 | ] 87 | Tracing = [ 88 | "opentelemetry-semantic-conventions>=0.43b0, <1", 89 | "openinference-semantic-conventions>=0.1.12, <1", 90 | "deprecated", #opentelemetry-semantic-conventions requires it 91 | ] 92 | Datasets = [ 93 | "typing-extensions>=4, <5", 94 | "wrapt>=1.12.1, <2", 95 | "opentelemetry-semantic-conventions>=0.43b0, <1", 96 | "openinference-semantic-conventions>=0.1.6, <1", 97 | "opentelemetry-sdk>=1.25.0, <2", 98 | "opentelemetry-exporter-otlp>=1.25.0, <2", 99 | "deprecated", #opentelemetry-semantic-conventions requires it 100 | ] 101 | PromptHub = [ 102 | "gql>=3.0.0", 103 | "requests_toolbelt>=1.0.0", 104 | ] 105 | PromptHub_VertexAI = [ 106 | "google-cloud-aiplatform>=1.0.0", 107 | ] 108 | 109 | [build-system] 110 
| requires = ["hatchling"] 111 | build-backend = "hatchling.build" 112 | 113 | [tool.hatch.version] 114 | path = "src/arize/version.py" 115 | 116 | [tool.hatch.build] 117 | only-packages = true 118 | 119 | [tool.hatch.build.targets.wheel] 120 | packages = ["src/arize"] 121 | 122 | [tool.hatch.build.targets.sdist] 123 | exclude = [ 124 | "src/arize/examples", 125 | "tests", 126 | "docs", 127 | ] 128 | 129 | 130 | [tool.black] 131 | include = '\.pyi?$' 132 | exclude = '(_pb2\.py$|docs/source/.*\.py)' 133 | 134 | [tool.ruff] 135 | target-version = "py37" 136 | line-length = 80 137 | exclude = [ 138 | "dist/", 139 | "__pycache__", 140 | "*_pb2.py*", 141 | "*_pb2_grpc.py*", 142 | "*.pyi", 143 | "docs/", 144 | ] 145 | [tool.ruff.format] 146 | docstring-code-format = true 147 | line-ending = "native" 148 | 149 | [tool.ruff.lint] 150 | select = [ 151 | # pycodestyle Error 152 | "E", 153 | # pycodestyle Warning 154 | "W", 155 | # Pyflakes 156 | "F", 157 | # pyupgrade 158 | "UP", 159 | # flake8-bugbear 160 | "B", 161 | # flake8-simplify 162 | "SIM", 163 | # isort 164 | "I", 165 | # TODO: Enable pydocstyle when ready for API docs 166 | # # pydocstyle 167 | # "D", 168 | ] 169 | ignore= [ 170 | "D203", # Do not use a blank line to separate the docstring from the class definition, 171 | "D212", # The summary line should be located on the second physical line of the docstring 172 | ] 173 | 174 | [tool.ruff.lint.isort] 175 | force-wrap-aliases = true 176 | 177 | [tool.ruff.lint.pycodestyle] 178 | max-doc-length = 110 179 | max-line-length = 110 180 | 181 | [tool.ruff.lint.pydocstyle] 182 | convention = "google" 183 | 184 | [tool.ruff.lint.pyupgrade] 185 | keep-runtime-typing = true 186 | -------------------------------------------------------------------------------- /src/arize/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | __all__ = [ 4 | "__version__", 5 | ] 6 | -------------------------------------------------------------------------------- /src/arize/bounded_executor.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from threading import BoundedSemaphore 3 | 4 | 5 | class BoundedExecutor: 6 | """ 7 | BoundedExecutor behaves as a ThreadPoolExecutor which will block on 8 | calls to submit() once the limit given as "bound" work items are queued for 9 | execution. 
10 | :param bound: Integer - the maximum number of items in the work queue 11 | :param max_workers: Integer - the size of the thread pool 12 | """ 13 | 14 | def __init__(self, bound, max_workers): 15 | self.executor = ThreadPoolExecutor(max_workers=max_workers) 16 | self.semaphore = BoundedSemaphore(bound + max_workers) 17 | 18 | """See concurrent.futures.Executor#submit""" 19 | 20 | def submit(self, fn, *args, **kwargs): 21 | self.semaphore.acquire() 22 | try: 23 | future = self.executor.submit(fn, *args, **kwargs) 24 | except Exception: 25 | self.semaphore.release() 26 | raise 27 | else: 28 | future.add_done_callback(lambda x: self.semaphore.release()) 29 | return future 30 | 31 | """See concurrent.futures.Executor#shutdown""" 32 | 33 | def shutdown(self, wait=True): 34 | self.executor.shutdown(wait) 35 | -------------------------------------------------------------------------------- /src/arize/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .core.client import ArizeDatasetsClient 2 | 3 | __all__ = ["ArizeDatasetsClient"] 4 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/core/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/core/session.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from dataclasses import dataclass, field 3 | 4 | from pyarrow import flight 5 | 6 | from arize.utils.logging import logger 7 | from arize.utils.utils import get_python_version 8 | from arize.version import __version__ 9 | 10 | from ..utils.constants import DEFAULT_PACKAGE_NAME 11 | from ..validation.errors import InvalidSessionError 12 | 13 | 14 | @dataclass 15 | class Session: 16 | api_key: str 17 | host: str 18 | port: int 19 | scheme: str 20 | session_name: str = field(init=False) 21 | call_options: flight.FlightCallOptions = field(init=False) 22 | 23 | def __post_init__(self): 24 | self.session_name = f"python-sdk-{DEFAULT_PACKAGE_NAME}-{uuid.uuid4()}" 25 | logger.debug(f"Creating named session as '{self.session_name}'.") 26 | if self.api_key is None: 27 | logger.error(InvalidSessionError.error_message()) 28 | raise InvalidSessionError 29 | 30 | logger.debug( 31 | f"Created session with Arize API Key '{self.api_key}' at '{self.host}':'{self.port}'" 32 | ) 33 | self._set_headers() 34 | 35 | def connect(self) -> flight.FlightClient: 36 | """ 37 | Connects to Arize Flight server public endpoint with the 38 | provided api key. 
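Returns a ``flight.FlightClient``; note that server certificate verification is disabled when the host is ``localhost`` (see the ``disable_cert`` flag in the body below).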
39 | """ 40 | try: 41 | disable_cert = self.host.lower() == "localhost" 42 | client = flight.FlightClient( 43 | location=f"{self.scheme}://{self.host}:{self.port}", 44 | disable_server_verification=disable_cert, 45 | ) 46 | self.call_options = flight.FlightCallOptions(headers=self._headers) 47 | return client 48 | except Exception: 49 | logger.error( 50 | "There was an error trying to connect to the Arize Flight Endpoint" 51 | ) 52 | raise 53 | 54 | def _set_headers(self) -> None: 55 | self._headers = [ 56 | (b"origin", b"arize-python-datasets-client"), 57 | (b"auth-token-bin", f"{self.api_key}".encode()), 58 | (b"sdk-language", b"python"), 59 | (b"language-version", get_python_version().encode("utf-8")), 60 | (b"sdk-version", __version__.encode("utf-8")), 61 | (b"arize-interface", b"flight"), 62 | ] 63 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/experiments/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/experiments/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/experiments/evaluators/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/experiments/evaluators/exceptions.py: -------------------------------------------------------------------------------- 1 | class ArizeException(Exception): 2 | pass 3 | 4 | 5 | class ArizeContextLimitExceeded(ArizeException): 6 | pass 7 | 8 | 9 | class ArizeTemplateMappingError(ArizeException): 10 | pass 11 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/utils/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/utils/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from openinference.semconv import trace 4 | 5 | from arize.pandas.proto import requests_pb2 as pb2 6 | 7 | INFERENCES = pb2.INFERENCES 8 | GENERATIVE = pb2.GENERATIVE 9 | 10 | """Internal Use""" 11 | 12 | # Default API endpoint when not provided through env variable nor profile 13 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com" 14 | DEFAULT_ARIZE_FLIGHT_PORT = 443 15 | DEFAULT_ARIZE_OTLP_ENDPOINT = "https://otlp.arize.com/v1" 16 | 17 | # Name of the current package. 18 | DEFAULT_PACKAGE_NAME = "arize_python_datasets_client" 19 | 20 | # Default headers to trace and help identify requests. For debugging. 21 | DEFAULT_ARIZE_SESSION_ID = "x-arize-session-id" # Generally the session name. 22 | DEFAULT_ARIZE_TRACE_ID = "x-arize-trace-id" 23 | DEFAULT_PACKAGE_VERSION = "x-package-version" 24 | 25 | # Default initial wait time for retries in seconds. 
26 | DEFAULT_RETRY_INITIAL_WAIT_TIME = 0.25 27 | 28 | # Default maximum wait time for retries in seconds. 29 | DEFAULT_RETRY_MAX_WAIT_TIME = 10.0 30 | 31 | # Default to use grpc + tls scheme. 32 | DEFAULT_TRANSPORT_SCHEME = "grpc+tls" 33 | 34 | 35 | class FlightActionKey(Enum): 36 | GET_DATASET_VERSION = "get_dataset_version" 37 | LIST_DATASETS = "list_datasets" 38 | DELETE_DATASET = "delete_dataset" 39 | CREATE_EXPERIMENT_DB_ENTRY = "create_experiment_db_entry" 40 | DELETE_EXPERIMENT = "delete_experiment" 41 | 42 | 43 | OPEN_INFERENCE_JSON_STR_TYPES = frozenset( 44 | [ 45 | trace.DocumentAttributes.DOCUMENT_METADATA, 46 | trace.SpanAttributes.LLM_FUNCTION_CALL, 47 | trace.SpanAttributes.LLM_INVOCATION_PARAMETERS, 48 | trace.SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES, 49 | trace.MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, 50 | trace.SpanAttributes.METADATA, 51 | trace.SpanAttributes.TOOL_PARAMETERS, 52 | trace.ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON, 53 | ] 54 | ) 55 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/utils/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | import functools 4 | from enum import Enum 5 | from pathlib import Path 6 | from typing import Any, Callable, Mapping, Sequence, Union 7 | 8 | try: 9 | from typing import get_args, get_origin # Python 3.8+ 10 | except ImportError: 11 | from typing_extensions import get_args, get_origin # For Python <3.8 12 | 13 | 14 | import numpy as np 15 | 16 | 17 | def get_func_name(fn: Callable[..., Any]) -> str: 18 | """ 19 | Makes a best-effort attempt to get the name of the function. 20 | """ 21 | if isinstance(fn, functools.partial): 22 | return fn.func.__qualname__ 23 | if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith(""): 24 | return fn.__qualname__.split("..")[-1] 25 | return str(fn) 26 | 27 | 28 | def jsonify(obj: Any) -> Any: 29 | """ 30 | Coerce object to be json serializable. 
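Per the branches below: Enums collapse to their values, dates and times to ISO-8601 strings, timedeltas to total seconds, numpy arrays to lists, dataclasses and pydantic models to plain dicts, and anything unrecognized to a ``<module.Class object>`` string.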
31 | """ 32 | if isinstance(obj, Enum): 33 | return jsonify(obj.value) 34 | if isinstance(obj, (str, int, float, bool)) or obj is None: 35 | return obj 36 | if isinstance(obj, (list, set, frozenset, Sequence)): 37 | return [jsonify(v) for v in obj] 38 | if isinstance(obj, (dict, Mapping)): 39 | return {jsonify(k): jsonify(v) for k, v in obj.items()} 40 | if dataclasses.is_dataclass(obj): 41 | result = {} 42 | for field in dataclasses.fields(obj): 43 | k = field.name 44 | v = getattr(obj, k) 45 | if not ( 46 | v is None 47 | and get_origin(field) is Union 48 | and type(None) in get_args(field) 49 | ): 50 | result[k] = jsonify(v) 51 | return result 52 | if isinstance(obj, (datetime.date, datetime.datetime, datetime.time)): 53 | return obj.isoformat() 54 | if isinstance(obj, datetime.timedelta): 55 | return obj.total_seconds() 56 | if isinstance(obj, Path): 57 | return str(obj) 58 | if isinstance(obj, BaseException): 59 | return str(obj) 60 | if isinstance(obj, np.ndarray): 61 | return [jsonify(v) for v in obj] 62 | if hasattr(obj, "__float__"): 63 | return float(obj) 64 | if hasattr(obj, "model_dump") and callable(obj.model_dump): 65 | # pydantic v2 66 | try: 67 | d = obj 68 | assert isinstance(d, dict) 69 | except BaseException: 70 | pass 71 | else: 72 | return jsonify(d) 73 | if hasattr(obj, "dict") and callable(obj.dict): 74 | # pydantic v1 75 | try: 76 | d = obj.dict() 77 | assert isinstance(d, dict) 78 | except BaseException: 79 | pass 80 | else: 81 | return jsonify(d) 82 | cls = obj.__class__ 83 | return f"<{cls.__module__}.{cls.__name__} object>" 84 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/experimental/datasets/validation/__init__.py -------------------------------------------------------------------------------- /src/arize/experimental/datasets/validation/errors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class DatasetError(Exception, ABC): 5 | def __str__(self) -> str: 6 | return self.error_message() 7 | 8 | @abstractmethod 9 | def __repr__(self) -> str: 10 | pass 11 | 12 | @abstractmethod 13 | def error_message(self) -> str: 14 | pass 15 | 16 | 17 | class InvalidSessionError(DatasetError): 18 | @staticmethod 19 | def error_message() -> str: 20 | return ( 21 | "Credentials not provided or invalid. Please pass in the correct api_key when " 22 | "initiating a new ArizeExportClient. 
Alternatively, you can set up credentials " 23 | "in a profile or as an environment variable" 24 | ) 25 | 26 | 27 | class InvalidConfigFileError(DatasetError): 28 | @staticmethod 29 | def error_message() -> str: 30 | return "Invalid/Misconfigured Configuration File" 31 | 32 | 33 | class IDColumnUniqueConstraintError(DatasetError): 34 | @staticmethod 35 | def error_message() -> str: 36 | return "'id' column must contain unique values" 37 | 38 | 39 | class RequiredColumnsError(DatasetError): 40 | def __init__(self, missing_columns: set) -> None: 41 | self.missing_columns = missing_columns 42 | 43 | def error_message(self) -> str: 44 | return f"Missing required columns: {self.missing_columns}" 45 | 46 | def __repr__(self) -> str: 47 | return f"RequiredColumnsError({self.missing_columns})" 48 | -------------------------------------------------------------------------------- /src/arize/experimental/datasets/validation/validator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | 5 | from . import errors as err 6 | 7 | 8 | class Validator: 9 | @staticmethod 10 | def validate( 11 | df: pd.DataFrame, 12 | ) -> List[err.DatasetError]: 13 | # check all required columns are present 14 | required_columns_errors = Validator._check_required_columns(df) 15 | if required_columns_errors: 16 | return required_columns_errors 17 | 18 | # check the id column is unique 19 | id_column_unique_constraint_error = ( 20 | Validator._check_id_column_is_unique(df) 21 | ) 22 | if id_column_unique_constraint_error: 23 | return id_column_unique_constraint_error 24 | 25 | return [] 26 | 27 | @staticmethod 28 | def _check_required_columns(df: pd.DataFrame) -> List[err.DatasetError]: 29 | required_columns = ["id", "created_at", "updated_at"] 30 | missing_columns = set(required_columns) - set(df.columns) 31 | if missing_columns: 32 | return [err.RequiredColumnsError(missing_columns)] 33 | return [] 34 | 35 | @staticmethod 36 | def _check_id_column_is_unique(df: pd.DataFrame) -> List[err.DatasetError]: 37 | if not df["id"].is_unique: 38 | return [err.IDColumnUniqueConstraintError()] 39 | return [] 40 | -------------------------------------------------------------------------------- /src/arize/experimental/integrations/whylabs/__init__.py: -------------------------------------------------------------------------------- 1 | from arize.experimental.integrations.whylabs.client import IntegrationClient 2 | from arize.experimental.integrations.whylabs.generator import ( 3 | WhylabsProfileAdapter, 4 | ) 5 | 6 | __all__ = ["WhylabsProfileAdapter", "IntegrationClient"] 7 | -------------------------------------------------------------------------------- /src/arize/experimental/integrations/whylabs_vanguard_governance/__init__.py: -------------------------------------------------------------------------------- 1 | from arize.experimental.integrations.whylabs_vanguard_governance.client import ( 2 | IntegrationClient, 3 | ) 4 | 5 | __all__ = ["IntegrationClient"] 6 | -------------------------------------------------------------------------------- /src/arize/experimental/integrations/whylabs_vanguard_governance/client.py: -------------------------------------------------------------------------------- 1 | # type: ignore[pb2] 2 | from datetime import datetime, timedelta 3 | from typing import Dict, Optional, Union 4 | 5 | import pandas as pd 6 | 7 | from arize.pandas.logger import Client 8 | from arize.utils.types import ( 9 | Environments, 10 | 
ModelTypes, 11 | Schema, 12 | ) 13 | 14 | 15 | class IntegrationClient: 16 | def __init__( 17 | self, 18 | api_key: Optional[str] = None, 19 | space_id: Optional[str] = None, 20 | space_key: Optional[str] = None, 21 | uri: Optional[str] = "https://api.arize.com/v1", 22 | additional_headers: Optional[Dict[str, str]] = None, 23 | request_verify: Union[bool, str] = True, 24 | developer_key: Optional[str] = None, 25 | host: Optional[str] = None, 26 | port: Optional[int] = None, 27 | ) -> None: 28 | """ 29 | Wrapper for the Arize Client specific to WhyLabs profiles. 30 | """ 31 | self._client = Client( 32 | api_key=api_key, 33 | space_id=space_id, 34 | space_key=space_key, 35 | uri=uri, 36 | additional_headers=additional_headers, 37 | request_verify=request_verify, 38 | developer_key=developer_key, 39 | host=host, 40 | port=port, 41 | ) 42 | 43 | def create_model( 44 | self, 45 | model_id: str, 46 | model_type: ModelTypes, 47 | environment: Environments = Environments.PRODUCTION, 48 | ) -> None: 49 | df = pd.DataFrame( 50 | { 51 | "timestamp": [datetime.now() - timedelta(days=365)], 52 | "ARIZE_PLACEHOLDER_STRING": "ARIZE_DUMMY", 53 | "ARIZE_PLACEHOLDER_FLOAT": 1, 54 | } 55 | ) 56 | 57 | if model_type == ModelTypes.RANKING: 58 | schema = Schema( 59 | timestamp_column_name="timestamp", 60 | prediction_group_id_column_name="ARIZE_PLACEHOLDER_FLOAT", 61 | rank_column_name="ARIZE_PLACEHOLDER_FLOAT", 62 | ) 63 | elif model_type == ModelTypes.MULTI_CLASS: 64 | df["ARIZE_PLACEHOLDER_MAP"] = [ 65 | [ 66 | { 67 | "class_name": "ARIZE_DUMMY", 68 | "score": 1.0, 69 | } 70 | ] 71 | ] 72 | schema = Schema( 73 | timestamp_column_name="timestamp", 74 | prediction_score_column_name="ARIZE_PLACEHOLDER_MAP", 75 | ) 76 | elif ( 77 | model_type == ModelTypes.NUMERIC 78 | or model_type == ModelTypes.REGRESSION 79 | ): 80 | schema = Schema( 81 | timestamp_column_name="timestamp", 82 | prediction_label_column_name="ARIZE_PLACEHOLDER_FLOAT", 83 | prediction_score_column_name="ARIZE_PLACEHOLDER_FLOAT", 84 | actual_label_column_name="ARIZE_PLACEHOLDER_FLOAT", 85 | actual_score_column_name="ARIZE_PLACEHOLDER_FLOAT", 86 | ) 87 | else: 88 | schema = Schema( 89 | timestamp_column_name="timestamp", 90 | prediction_label_column_name="ARIZE_PLACEHOLDER_STRING", 91 | prediction_score_column_name="ARIZE_PLACEHOLDER_FLOAT", 92 | actual_label_column_name="ARIZE_PLACEHOLDER_STRING", 93 | actual_score_column_name="ARIZE_PLACEHOLDER_FLOAT", 94 | ) 95 | 96 | return self._client.log( 97 | dataframe=df, 98 | schema=schema, 99 | environment=environment, 100 | model_id=model_id, 101 | model_type=model_type, 102 | ) 103 | -------------------------------------------------------------------------------- /src/arize/experimental/integrations/whylabs_vanguard_governance/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from arize.experimental.integrations.whylabs_vanguard_governance import ( 4 | IntegrationClient, 5 | ) 6 | from arize.utils.types import Environments, ModelTypes 7 | 8 | os.environ["ARIZE_SPACE_ID"] = "REPLACE_ME" 9 | os.environ["ARIZE_API_KEY"] = "REPLACE_ME" 10 | os.environ["ARIZE_DEVELOPER_KEY"] = "REPLACE_ME" 11 | 12 | 13 | def test_create_model(): 14 | client = IntegrationClient( 15 | api_key=os.environ["ARIZE_API_KEY"], 16 | space_id=os.environ["ARIZE_SPACE_ID"], 17 | developer_key=os.environ["ARIZE_DEVELOPER_KEY"], 18 | ) 19 | client.create_model( 20 | model_id="test-create-binary-classification", 21 | environment=Environments.PRODUCTION, 22 | 
model_type=ModelTypes.BINARY_CLASSIFICATION, 23 | ) 24 | 25 | 26 | if __name__ == "__main__": 27 | test_create_model() 28 | -------------------------------------------------------------------------------- /src/arize/experimental/integrations/whylabs_vanguard_ingestion/__init__.py: -------------------------------------------------------------------------------- 1 | from arize.experimental.integrations.whylabs_vanguard_ingestion.client import ( 2 | IntegrationClient, 3 | ) 4 | from arize.experimental.integrations.whylabs_vanguard_ingestion.generator import ( 5 | WhylabsVanguardProfileAdapter, 6 | ) 7 | 8 | __all__ = ["WhylabsVanguardProfileAdapter", "IntegrationClient"] 9 | -------------------------------------------------------------------------------- /src/arize/experimental/online_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataframe_preprocessor import extract_nested_data_to_column 2 | 3 | __all__ = ["extract_nested_data_to_column"] 4 | -------------------------------------------------------------------------------- /src/arize/experimental/prompt_hub/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import ArizePromptClient 2 | from .prompts import LLMProvider, Prompt, prompt_to_llm_input 3 | 4 | __all__ = ["ArizePromptClient", "Prompt", "LLMProvider", "prompt_to_llm_input"] 5 | -------------------------------------------------------------------------------- /src/arize/experimental/prompt_hub/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants for the Arize Prompt Hub. 3 | """ 4 | 5 | # Mapping from internal model names to external model names 6 | ARIZE_INTERNAL_MODEL_MAPPING = { 7 | # ---------------------- 8 | # OpenAI Models 9 | # ---------------------- 10 | "GPT_4o_MINI": "gpt-4o-mini", 11 | "GPT_4o_MINI_2024_07_18": "gpt-4o-mini-2024-07-18", 12 | "GPT_4o": "gpt-4o", 13 | "GPT_4o_2024_05_13": "gpt-4o-2024-05-13", 14 | "GPT_4o_2024_08_06": "gpt-4o-2024-08-06", 15 | "CHATGPT_4o_LATEST": "chatgpt-4o-latest", 16 | "O1_PREVIEW": "o1-preview", 17 | "O1_PREVIEW_2024_09_12": "o1-preview-2024-09-12", 18 | "O1_MINI": "o1-mini", 19 | "O1_MINI_2024_09_12": "o1-mini-2024-09-12", 20 | "GPT_4_TURBO": "gpt-4-turbo", 21 | "GPT_4_TURBO_2024_04_09": "gpt-4-turbo-2024-04-09", 22 | "GPT_4_TURBO_PREVIEW": "gpt-4-turbo-preview", 23 | "GPT_4_0125_PREVIEW": "gpt-4-0125-preview", 24 | "GPT_4_1106_PREVIEW": "gpt-4-1106-preview", 25 | "GPT_4": "gpt-4", 26 | "GPT_4_32k": "gpt-4-32k", 27 | "GPT_4_0613": "gpt-4-0613", 28 | "GPT_4_0314": "gpt-4-0314", 29 | "GPT_4_VISION_PREVIEW": "gpt-4-vision-preview", 30 | "GPT_3_5_TURBO": "gpt-3.5-turbo", 31 | "GPT_3_5_TURBO_1106": "gpt-3.5-turbo-1106", 32 | "GPT_3_5_TURBO_INSTRUCT": "gpt-3.5-turbo-instruct", 33 | "GPT_3_5_TURBO_0125": "gpt-3.5-turbo-0125", 34 | "TEXT_DAVINCI_003": "text-davinci-003", 35 | "BABBAGE_002": "babbage-002", 36 | "DAVINCI_002": "davinci-002", 37 | "O1_2024_12_17": "o1-2024-12-17", 38 | "O1": "o1", 39 | "O3_MINI": "o3-mini", 40 | "O3_MINI_2025_01_31": "o3-mini-2025-01-31", 41 | # ---------------------- 42 | # Vertex AI Models 43 | # ---------------------- 44 | "GEMINI_1_0_PRO": "gemini-1.0-pro", 45 | "GEMINI_1_0_PRO_VISION": "gemini-1.0-pro-vision", 46 | "GEMINI_1_5_FLASH": "gemini-1.5-flash", 47 | "GEMINI_1_5_FLASH_002": "gemini-1.5-flash-002", 48 | "GEMINI_1_5_FLASH_8B": "gemini-1.5-flash-8b", 49 | "GEMINI_1_5_FLASH_LATEST": "gemini-1.5-flash-latest", 50 | 
"GEMINI_1_5_FLASH_8B_LATEST": "gemini-1.5-flash-8b-latest", 51 | "GEMINI_1_5_PRO": "gemini-1.5-pro", 52 | "GEMINI_1_5_PRO_001": "gemini-1.5-pro-001", 53 | "GEMINI_1_5_PRO_002": "gemini-1.5-pro-002", 54 | "GEMINI_1_5_PRO_LATEST": "gemini-1.5-pro-latest", 55 | "GEMINI_2_0_FLASH": "gemini-2.0-flash", 56 | "GEMINI_2_0_FLASH_001": "gemini-2.0-flash-001", 57 | "GEMINI_2_0_FLASH_EXP": "gemini-2.0-flash-exp", 58 | "GEMINI_2_0_FLASH_LITE_PREVIEW_02_05": "gemini-2.0-flash-lite-preview-02-05", 59 | "GEMINI_PRO": "gemini-pro", 60 | } 61 | 62 | # Reverse mapping from external model names to internal model names 63 | ARIZE_EXTERNAL_MODEL_MAPPING = { 64 | v: k for k, v in ARIZE_INTERNAL_MODEL_MAPPING.items() 65 | } 66 | -------------------------------------------------------------------------------- /src/arize/experimental/prompt_hub/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt implementations for different LLM providers. 3 | """ 4 | 5 | from arize.experimental.prompt_hub.prompts.base import ( 6 | FormattedPrompt, 7 | LLMProvider, 8 | Prompt, 9 | PromptInputVariableFormat, 10 | prompt_to_llm_input, 11 | ) 12 | 13 | __all__ = [ 14 | "FormattedPrompt", 15 | "Prompt", 16 | "PromptInputVariableFormat", 17 | "LLMProvider", 18 | "prompt_to_llm_input", 19 | ] 20 | -------------------------------------------------------------------------------- /src/arize/experimental/prompt_hub/prompts/open_ai.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAI-specific prompt implementations. 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Any, Dict, List, Mapping, Sequence 7 | 8 | from arize.experimental.prompt_hub.prompts.base import FormattedPrompt, Prompt 9 | 10 | 11 | @dataclass(frozen=True) 12 | class OpenAIPrompt(FormattedPrompt): 13 | """ 14 | OpenAI-specific formatted prompt. 15 | 16 | Contains fully formatted messages and additional parameters (e.g. model, temperature, etc.) 17 | required by the OpenAI API. 18 | """ 19 | 20 | messages: Sequence[Dict[str, str]] 21 | kwargs: Dict[str, Any] 22 | 23 | 24 | def to_openai_prompt( 25 | prompt: Prompt, variables: Mapping[str, Any] 26 | ) -> OpenAIPrompt: 27 | """ 28 | Convert a Prompt to an OpenAI-specific formatted prompt. 29 | 30 | Args: 31 | prompt: The prompt to format. 32 | variables: A mapping of variable names to values. 33 | 34 | Returns: 35 | An OpenAI-specific formatted prompt. 36 | """ 37 | return OpenAIPrompt( 38 | messages=format_openai_prompt(prompt, variables), 39 | kwargs=openai_kwargs(prompt), 40 | ) 41 | 42 | 43 | def format_openai_prompt( 44 | prompt: Prompt, variables: Mapping[str, Any] 45 | ) -> List[Dict[str, str]]: 46 | """ 47 | Format a Prompt's messages for the OpenAI API. 48 | 49 | Args: 50 | prompt: The prompt to format. 51 | variables: A mapping of variable names to values. 52 | 53 | Returns: 54 | A list of formatted message dictionaries. 55 | """ 56 | formatted_messages = [] 57 | for message in prompt.messages: 58 | formatted_message = message.copy() 59 | formatted_message["content"] = message["content"].format(**variables) 60 | formatted_messages.append(formatted_message) 61 | return formatted_messages 62 | 63 | 64 | def openai_kwargs(prompt: Prompt) -> Dict[str, Any]: 65 | """ 66 | Generate kwargs for the OpenAI API based on the prompt. 67 | 68 | Args: 69 | prompt: The prompt to generate kwargs for. 70 | 71 | Returns: 72 | A dictionary of kwargs for the OpenAI API. 
73 | """ 74 | return {"model": prompt.model_name} 75 | -------------------------------------------------------------------------------- /src/arize/experimental/prompt_hub/prompts/vertex_ai.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vertex AI-specific prompt implementations. 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Any, List, Mapping, Sequence 7 | 8 | from vertexai.generative_models import Content, Part 9 | 10 | from arize.experimental.prompt_hub.prompts.base import FormattedPrompt, Prompt 11 | 12 | 13 | @dataclass(frozen=True) 14 | class VertexAIPrompt(FormattedPrompt): 15 | """ 16 | Vertex AI-specific formatted prompt. 17 | 18 | Contains fully formatted messages and additional parameters (e.g. model, temperature, etc.) 19 | required by the Vertex AI API. 20 | """ 21 | 22 | messages: Sequence[Content] 23 | model_name: str 24 | 25 | 26 | def to_vertexai_prompt( 27 | prompt: Prompt, variables: Mapping[str, Any] 28 | ) -> VertexAIPrompt: 29 | """ 30 | Convert a Prompt to a Vertex AI-specific formatted prompt. 31 | 32 | Args: 33 | prompt: The prompt to format. 34 | variables: A mapping of variable names to values. 35 | 36 | Returns: 37 | A Vertex AI-specific formatted prompt. 38 | 39 | Raises: 40 | ImportError: If Vertex AI dependencies are not available. 41 | """ 42 | return VertexAIPrompt( 43 | messages=format_vertexai_prompt(prompt, variables), 44 | model_name=prompt.model_name, 45 | ) 46 | 47 | 48 | def format_vertexai_prompt( 49 | prompt: Prompt, variables: Mapping[str, Any] 50 | ) -> List[Content]: 51 | """ 52 | Format a Prompt's messages for the Vertex AI API. 53 | 54 | Args: 55 | prompt: The prompt to format. 56 | variables: A mapping of variable names to values. 57 | 58 | Returns: 59 | A list of formatted Content objects. 60 | 61 | Raises: 62 | ImportError: If Vertex AI dependencies are not available. 63 | """ 64 | formatted_messages = [] 65 | for message in prompt.messages: 66 | formatted_message = Content( 67 | role=message["role"], 68 | parts=[Part.from_text(message["content"].format(**variables))], 69 | ) 70 | formatted_messages.append(formatted_message) 71 | return formatted_messages 72 | -------------------------------------------------------------------------------- /src/arize/exporter/README.md: -------------------------------------------------------------------------------- 1 | ## Arize Python Exporter Client - User Guide 2 | 3 | ### Step 1: Pip install `arize` and set up `api_key` and `space_id`
4 | ``` 5 | ! pip install -q arize 6 | ``` 7 | ``` 8 | api_key = '' 9 | space_id = '' 10 | ``` 11 | - You can get your `space_id` by visiting [app.arize.com](https://app.arize.com). The url will be in this format: `https://app.arize.com/organizations/:org_id/spaces/:space_id`
12 | **NOTE: this is not the same as the space key used to send data using the SDK**
13 | 14 | - To get `api_key`, you must have Developer Access to your space. Visit [arize docs](https://docs.arize.com/arize/integrations/graphql-api/getting-started-with-programmatic-access) for more details.
15 | **NOTE: this is not the same as the api key in Space Settings**
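As an alternative to passing `api_key` explicitly, the export client can also read credentials from the `ARIZE_API_KEY` environment variable, or from a `profiles.ini` config file under `~/.arize` (this fallback order is implemented in `exporter/core/session.py` and `exporter/utils/constants.py` below). A minimal sketch of the environment-variable approach:

```
import os

# ArizeExportClient falls back to this environment variable
# when no api_key argument is passed in.
os.environ["ARIZE_API_KEY"] = api_key
```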
16 | 17 | ### Step 2: Instantiate an `ArizeExportClient`
18 | 19 | ``` 20 | from arize.exporter import ArizeExportClient 21 | 22 | client = ArizeExportClient(api_key=api_key) 23 | ``` 24 | 25 | ### Step 3: Export production data with predictions only to a pandas dataframe 26 | 27 | ``` 28 | from arize.utils.types import Environments 29 | from datetime import datetime 30 | 31 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0) 32 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0) 33 | 34 | df = client.export_model_to_df( 35 | space_id=space_id, 36 | model_id='arize-demo-fraud-use-case', 37 | environment=Environments.PRODUCTION, 38 | start_time=start_time, 39 | end_time=end_time, 40 | model_version='', #optional field 41 | batch_id='', #optional field 42 | ) 43 | ``` 44 | 45 | ### Export production data with predictions and actuals to a pandas dataframe 46 | 47 | ``` 48 | from arize.utils.types import Environments 49 | from datetime import datetime 50 | 51 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0) 52 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0) 53 | 54 | df = client.export_model_to_df( 55 | space_id=space_id, 56 | model_id='arize-demo-fraud-use-case', 57 | environment=Environments.PRODUCTION, 58 | start_time=start_time, 59 | end_time=end_time, 60 | include_actuals=True, #optional field 61 | model_version='', #optional field 62 | batch_id='', #optional field 63 | ) 64 | ``` 65 | 66 | 67 | ### Export training data to a pandas dataframe 68 | 69 | ``` 70 | from arize.utils.types import Environments 71 | from datetime import datetime 72 | 73 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0) 74 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0) 75 | 76 | df = client.export_model_to_df( 77 | space_id=space_id, 78 | model_id='arize-demo-fraud-use-case', 79 | environment=Environments.TRAINING, 80 | start_time=start_time, 81 | end_time=end_time, 82 | model_version='', #optional field 83 | batch_id='', #optional field 84 | ) 85 | ``` 86 | ### Export validation data to a pandas dataframe 87 | 88 | ``` 89 | from arize.utils.types import Environments 90 | from datetime import datetime 91 | 92 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0) 93 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0) 94 | 95 | df = client.export_model_to_df( 96 | space_id=space_id, 97 | model_id='arize-demo-fraud-use-case', 98 | environment=Environments.VALIDATION, 99 | start_time=start_time, 100 | end_time=end_time, 101 | model_version='', #optional field 102 | batch_id='', #optional field 103 | ) 104 | ``` 105 | 106 | ### Export production data to a parquet file 107 | 108 | ``` 109 | from arize.utils.types import Environments 110 | from datetime import datetime 111 | import pandas as pd 112 | 113 | start_time = datetime(2023, 4, 10, 0, 0, 0, 0) 114 | end_time = datetime(2023, 4, 15, 0, 0, 0, 0) 115 | 116 | client.export_model_to_parquet( 117 | path="example.parquet", 118 | space_id=space_id, 119 | model_id='arize-demo-fraud-use-case', 120 | environment=Environments.PRODUCTION, 121 | start_time=start_time, 122 | end_time=end_time, 123 | model_version='', #optional field 124 | batch_id='', #optional field 125 | ) 126 | 127 | df = pd.read_parquet("example.parquet") 128 | ``` 129 | -------------------------------------------------------------------------------- /src/arize/exporter/__init__.py: -------------------------------------------------------------------------------- 1 | from .core.client import ArizeExportClient 2 | from .utils.schema_parser import get_arize_schema 3 | 4 | __all__ = ["ArizeExportClient", "get_arize_schema"] 5 | --------------------------------------------------------------------------------
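Putting the README steps together: a minimal, end-to-end sketch of the export workflow using only the public entry points re-exported above; the space ID, model ID, and time window are placeholders to substitute with your own values.

```
from datetime import datetime

from arize.exporter import ArizeExportClient
from arize.utils.types import Environments

client = ArizeExportClient(api_key="YOUR_API_KEY")  # or rely on ARIZE_API_KEY

df = client.export_model_to_df(
    space_id="YOUR_SPACE_ID",  # from the app.arize.com URL
    model_id="your-model-id",  # placeholder model name
    environment=Environments.PRODUCTION,
    start_time=datetime(2023, 4, 10),
    end_time=datetime(2023, 4, 15),
)
print(df.shape)  # exported rows x columns
```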
/src/arize/exporter/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/exporter/core/__init__.py -------------------------------------------------------------------------------- /src/arize/exporter/core/query.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | from google.protobuf import json_format 5 | from pyarrow import flight 6 | 7 | from arize.utils.logging import logger 8 | 9 | from .. import publicexporter_pb2 as exp_pb2 10 | 11 | 12 | @dataclass(frozen=True) 13 | class Query: 14 | query_descriptor: exp_pb2.RecordQueryDescriptor 15 | 16 | def execute( 17 | self, 18 | client: flight.FlightClient, 19 | call_options: flight.FlightCallOptions, 20 | ) -> Tuple[flight.FlightStreamReader, int]: 21 | try: 22 | flight_info = client.get_flight_info( 23 | flight.FlightDescriptor.for_command( 24 | json_format.MessageToJson(self.query_descriptor) # type: ignore 25 | ), 26 | call_options, 27 | ) 28 | logger.info("Fetching data...") 29 | 30 | if flight_info.total_records == 0: 31 | logger.info("Query returned no data") 32 | return None, 0 33 | logger.debug("Ticket: %s", flight_info.endpoints[0].ticket) 34 | 35 | # Retrieve the result set as flight stream reader 36 | reader = client.do_get( 37 | flight_info.endpoints[0].ticket, call_options 38 | ) 39 | logger.info("Starting export...") 40 | return reader, flight_info.total_records 41 | 42 | except Exception: 43 | logger.error( 44 | "There was an error trying to get the data from the endpoint" 45 | ) 46 | raise 47 | -------------------------------------------------------------------------------- /src/arize/exporter/core/session.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import uuid 4 | from dataclasses import dataclass, field 5 | from typing import Optional 6 | 7 | from pyarrow import flight 8 | 9 | from arize.utils.logging import logger 10 | from arize.utils.utils import get_python_version 11 | from arize.version import __version__ 12 | 13 | from ..utils.constants import ( 14 | ARIZE_API_KEY, 15 | DEFAULT_ARIZE_API_KEY_CONFIG_KEY, 16 | DEFAULT_PACKAGE_NAME, 17 | PROFILE_FILE_NAME, 18 | ) 19 | from ..utils.errors import InvalidConfigFileError, InvalidSessionError 20 | 21 | 22 | @dataclass 23 | class Session: 24 | api_key: Optional[str] 25 | arize_profile: str 26 | arize_config_path: str 27 | host: str 28 | port: int 29 | scheme: str 30 | session_name: str = field(init=False) 31 | call_options: flight.FlightCallOptions = field(init=False) 32 | 33 | def __post_init__(self): 34 | self.session_name = f"python-sdk-{DEFAULT_PACKAGE_NAME}-{uuid.uuid4()}" 35 | logger.info(f"Creating named session as '{self.session_name}'.") 36 | # If api_key is not passed, try reading from environment variable.
37 | # If api_key is also not set as environment variable, read from config file 38 | self.api_key = self.api_key or ARIZE_API_KEY or self._read_config() 39 | if self.api_key is None: 40 | logger.error(InvalidSessionError.error_message()) 41 | raise InvalidSessionError 42 | 43 | if self.host.startswith(("http://", "https://")): 44 | scheme, host = self.host.split("://") 45 | logger.warning( 46 | f"The host '{self.host}' starts with '{scheme}' and it will be stripped, " 47 | "remove it or consider using the `scheme` parameter instead." 48 | ) 49 | self.host = host 50 | 51 | logger.debug( 52 | f"Created session with Arize API Token '{self.api_key}' at '{self.host}':'{self.port}'" 53 | ) 54 | self._set_headers() 55 | 56 | def _read_config(self) -> Optional[str]: 57 | config_parser = Session._get_config_parser() 58 | file_path = os.path.join(self.arize_config_path, PROFILE_FILE_NAME) 59 | logger.debug( 60 | f"No provided connection details. Looking up session values from '{self.arize_profile}' in " 61 | f"'{file_path}'." 62 | ) 63 | try: 64 | config_parser.read(file_path) 65 | return config_parser.get( 66 | self.arize_profile, DEFAULT_ARIZE_API_KEY_CONFIG_KEY 67 | ) 68 | except configparser.NoSectionError as err: 69 | # Missing api key error is raised in the __post_init__ method 70 | logger.warning( 71 | f"Can't extract API key from config file. {err.message}" 72 | ) 73 | return None 74 | except Exception as err: 75 | logger.error(InvalidConfigFileError.error_message()) 76 | raise InvalidConfigFileError from err 77 | 78 | @staticmethod 79 | def _get_config_parser() -> configparser.ConfigParser: 80 | return configparser.ConfigParser() 81 | 82 | def connect(self) -> flight.FlightClient: 83 | """ 84 | Connects to Arize Flight server public endpoint with the 85 | provided api key. 86 | """ 87 | try: 88 | disable_cert = self.host.lower() == "localhost" 89 | client = flight.FlightClient( 90 | location=f"{self.scheme}://{self.host}:{self.port}", 91 | disable_server_verification=disable_cert, 92 | ) 93 | self.call_options = flight.FlightCallOptions(headers=self._headers) 94 | return client 95 | except Exception: 96 | logger.error( 97 | "There was an error trying to connect to the Arize Flight Endpoint" 98 | ) 99 | raise 100 | 101 | def _set_headers(self) -> None: 102 | self._headers = [ 103 | (b"origin", b"arize-python-exporter"), 104 | (b"auth-token-bin", f"{self.api_key}".encode()), 105 | (b"sdk-language", b"python"), 106 | (b"language-version", get_python_version().encode("utf-8")), 107 | (b"sdk-version", __version__.encode("utf-8")), 108 | (b"arize-interface", b"flight"), 109 | ] 110 | -------------------------------------------------------------------------------- /src/arize/exporter/publicexporter_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: publicexporter.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 16 | 17 | 18 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14publicexporter.proto\x12\x0epublicexporter\x1a\x1fgoogle/protobuf/timestamp.proto\"\xfa\x03\n\x15RecordQueryDescriptor\x12\x10\n\x08space_id\x18\x01 \x01(\t\x12\x10\n\x08model_id\x18\x02 \x01(\t\x12\x46\n\x0b\x65nvironment\x18\x03 \x01(\x0e\x32\x31.publicexporter.RecordQueryDescriptor.Environment\x12\x15\n\rmodel_version\x18\x04 \x01(\t\x12\x10\n\x08\x62\x61tch_id\x18\x05 \x01(\t\x12.\n\nstart_time\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x17\n\x0finclude_actuals\x18\x08 \x01(\x08\x12\x19\n\x11\x66ilter_expression\x18\t \x01(\t\x12H\n\x18similarity_search_params\x18\n \x01(\x0b\x32&.publicexporter.SimilaritySearchParams\x12\x19\n\x11projected_columns\x18\x0b \x03(\t\"U\n\x0b\x45nvironment\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08TRAINING\x10\x01\x12\x0e\n\nVALIDATION\x10\x02\x12\x0e\n\nPRODUCTION\x10\x03\x12\x0b\n\x07TRACING\x10\x04\"\x9b\x02\n\x16SimilaritySearchParams\x12\x44\n\nreferences\x18\x01 \x03(\x0b\x32\x30.publicexporter.SimilaritySearchParams.Reference\x12\x1a\n\x12search_column_name\x18\x02 \x01(\t\x12\x11\n\tthreshold\x18\x03 \x01(\x01\x1a\x8b\x01\n\tReference\x12\x15\n\rprediction_id\x18\x01 \x01(\t\x12\x1d\n\x15reference_column_name\x18\x02 \x01(\t\x12\x38\n\x14prediction_timestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0e\n\x06vector\x18\x04 \x03(\x01\x42GZEgithub.com/Arize-ai/arize/go/pkg/flightserver/protocol/publicexporterb\x06proto3') 19 | 20 | 21 | 22 | _RECORDQUERYDESCRIPTOR = DESCRIPTOR.message_types_by_name['RecordQueryDescriptor'] 23 | _SIMILARITYSEARCHPARAMS = DESCRIPTOR.message_types_by_name['SimilaritySearchParams'] 24 | _SIMILARITYSEARCHPARAMS_REFERENCE = _SIMILARITYSEARCHPARAMS.nested_types_by_name['Reference'] 25 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT = _RECORDQUERYDESCRIPTOR.enum_types_by_name['Environment'] 26 | RecordQueryDescriptor = _reflection.GeneratedProtocolMessageType('RecordQueryDescriptor', (_message.Message,), { 27 | 'DESCRIPTOR' : _RECORDQUERYDESCRIPTOR, 28 | '__module__' : 'publicexporter_pb2' 29 | # @@protoc_insertion_point(class_scope:publicexporter.RecordQueryDescriptor) 30 | }) 31 | _sym_db.RegisterMessage(RecordQueryDescriptor) 32 | 33 | SimilaritySearchParams = _reflection.GeneratedProtocolMessageType('SimilaritySearchParams', (_message.Message,), { 34 | 35 | 'Reference' : _reflection.GeneratedProtocolMessageType('Reference', (_message.Message,), { 36 | 'DESCRIPTOR' : _SIMILARITYSEARCHPARAMS_REFERENCE, 37 | '__module__' : 'publicexporter_pb2' 38 | # @@protoc_insertion_point(class_scope:publicexporter.SimilaritySearchParams.Reference) 39 | }) 40 | , 41 | 'DESCRIPTOR' : _SIMILARITYSEARCHPARAMS, 42 | '__module__' : 'publicexporter_pb2' 43 | # @@protoc_insertion_point(class_scope:publicexporter.SimilaritySearchParams) 44 | }) 45 | _sym_db.RegisterMessage(SimilaritySearchParams) 46 | 
_sym_db.RegisterMessage(SimilaritySearchParams.Reference) 47 | 48 | if _descriptor._USE_C_DESCRIPTORS == False: 49 | 50 | DESCRIPTOR._options = None 51 | DESCRIPTOR._serialized_options = b'ZEgithub.com/Arize-ai/arize/go/pkg/flightserver/protocol/publicexporter' 52 | _RECORDQUERYDESCRIPTOR._serialized_start=74 53 | _RECORDQUERYDESCRIPTOR._serialized_end=580 54 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT._serialized_start=495 55 | _RECORDQUERYDESCRIPTOR_ENVIRONMENT._serialized_end=580 56 | _SIMILARITYSEARCHPARAMS._serialized_start=583 57 | _SIMILARITYSEARCHPARAMS._serialized_end=866 58 | _SIMILARITYSEARCHPARAMS_REFERENCE._serialized_start=727 59 | _SIMILARITYSEARCHPARAMS_REFERENCE._serialized_end=866 60 | # @@protoc_insertion_point(module_scope) 61 | -------------------------------------------------------------------------------- /src/arize/exporter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/exporter/utils/__init__.py -------------------------------------------------------------------------------- /src/arize/exporter/utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | """Environmental configuration""" 5 | 6 | # Override ARIZE Default Profile when reading from the config-file in a session. 7 | ARIZE_PROFILE = os.getenv("ARIZE_PROFILE") 8 | 9 | # Override ARIZE API Token when creating a session. 10 | ARIZE_API_KEY = os.getenv("ARIZE_API_KEY") 11 | 12 | 13 | """Internal Use""" 14 | 15 | # Default API endpoint when not provided through env variable nor profile 16 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com" 17 | DEFAULT_ARIZE_FLIGHT_PORT = 443 18 | 19 | # Name of the current package. 20 | DEFAULT_PACKAGE_NAME = "arize_python_export_client" 21 | 22 | # Default config keys for the Arize config file. Created via the CLI. 23 | DEFAULT_ARIZE_API_KEY_CONFIG_KEY = "api_key" 24 | 25 | # Default headers to trace and help identify requests. For debugging. 26 | DEFAULT_ARIZE_SESSION_ID = "x-arize-session-id" # Generally the session name. 27 | DEFAULT_ARIZE_TRACE_ID = "x-arize-trace-id" 28 | DEFAULT_PACKAGE_VERSION = "x-package-version" 29 | 30 | # File name for profile configuration. 31 | PROFILE_FILE_NAME = "profiles.ini" 32 | 33 | # Default profile to be used. 34 | DEFAULT_PROFILE_NAME = "default" 35 | 36 | # Default path where any configuration files are written. 37 | DEFAULT_CONFIG_PATH = os.path.join(str(Path.home()), ".arize") 38 | 39 | # Default initial wait time for retries in seconds. 40 | DEFAULT_RETRY_INITIAL_WAIT_TIME = 0.25 41 | 42 | # Default maximum wait time for retries in seconds. 43 | DEFAULT_RETRY_MAX_WAIT_TIME = 10.0 44 | 45 | # Default to use grpc + tls scheme. 46 | DEFAULT_TRANSPORT_SCHEME = "grpc+tls" 47 | -------------------------------------------------------------------------------- /src/arize/exporter/utils/errors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class ExportingError(Exception, ABC): 5 | def __str__(self) -> str: 6 | return self.error_message() 7 | 8 | @abstractmethod 9 | def error_message(self) -> str: 10 | pass 11 | 12 | 13 | class InvalidSessionError(ExportingError): 14 | @staticmethod 15 | def error_message() -> str: 16 | return ( 17 | "Credentials not provided or invalid. 
Please pass in the correct api_key when " 18 | "instantiating a new ArizeExportClient. Alternatively, you can set up credentials " 19 | "in a profile or as an environment variable" 20 | ) 21 | 22 | 23 | class InvalidConfigFileError(Exception): 24 | @staticmethod 25 | def error_message() -> str: 26 | return "Invalid or unreadable config file. Please verify the contents and format of your Arize profiles file (default: ~/.arize/profiles.ini)" 27 | -------------------------------------------------------------------------------- /src/arize/exporter/utils/tracing.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from arize.utils.logging import logger 9 | 10 | try: 11 | oic_spec = importlib.util.find_spec("openinference.semconv") 12 | except Exception: 13 | oic_spec = None 14 | 15 | if oic_spec is not None: 16 | from arize.pandas.tracing.columns import ( 17 | SPAN_ATTRIBUTES_EMBEDDING_EMBEDDINGS_COL, 18 | SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL, 19 | SPAN_ATTRIBUTES_LLM_INVOCATION_PARAMETERS_COL, 20 | SPAN_ATTRIBUTES_LLM_OUTPUT_MESSAGES_COL, 21 | SPAN_ATTRIBUTES_LLM_PROMPT_TEMPLATE_VARIABLES_COL, 22 | SPAN_ATTRIBUTES_LLM_TOOLS_COL, 23 | SPAN_ATTRIBUTES_METADATA, 24 | SPAN_ATTRIBUTES_RETRIEVAL_DOCUMENTS_COL, 25 | SPAN_ATTRIBUTES_TOOL_PARAMETERS_COL, 26 | SPAN_END_TIME_COL, 27 | SPAN_START_TIME_COL, 28 | ) 29 | 30 | 31 | # Transforms Otel tracing data into types and values that are more ergonomic for users 32 | # who need to interact with the data in Python. This class is intended to be used by 33 | # Arize, not by end users. Any errors encountered are unexpected, since Arize also controls 34 | # the data types returned from the platform, but the resulting error messages clarify the 35 | # effect of each error on the data; errors should not prevent continued use of the data 36 | class OtelTracingDataTransformer: 37 | def transform(self, df: pd.DataFrame) -> pd.DataFrame: 38 | errors: List[str] = [] 39 | 40 | # Convert columns of JSON-serializable strings to lists of dictionaries for more 41 | # convenient data processing in Python 42 | list_of_json_string_column_names: List[str] = [ 43 | col.name 44 | for col in [ 45 | SPAN_ATTRIBUTES_LLM_INPUT_MESSAGES_COL, 46 | SPAN_ATTRIBUTES_LLM_OUTPUT_MESSAGES_COL, 47 | SPAN_ATTRIBUTES_EMBEDDING_EMBEDDINGS_COL, 48 | SPAN_ATTRIBUTES_RETRIEVAL_DOCUMENTS_COL, 49 | SPAN_ATTRIBUTES_LLM_TOOLS_COL, 50 | ] 51 | if col.name in df.columns 52 | ] 53 | for col_name in list_of_json_string_column_names: 54 | try: 55 | df[col_name] = df[col_name].apply( 56 | self._transform_value_to_list_of_dict 57 | ) 58 | except Exception as e: 59 | errors.append( 60 | f"Unable to transform json string data to a Python dict in column '{col_name}'; " 61 | f"May encounter issues when importing data back into Arize; Error: {e}" 62 | ) 63 | 64 | json_string_column_names: List[str] = [ 65 | col.name 66 | for col in [ 67 | SPAN_ATTRIBUTES_LLM_PROMPT_TEMPLATE_VARIABLES_COL, 68 | SPAN_ATTRIBUTES_METADATA, 69 | ] 70 | if col.name in df.columns 71 | ] 72 | for col_name in json_string_column_names: 73 | try: 74 | df[col_name] = df[col_name].apply(self._transform_json_to_dict) 75 | except Exception as e: 76 | errors.append( 77 | f"Unable to transform json string data to a Python dict in column '{col_name}'; " 78 | f"May encounter issues when importing data back into Arize; Error: {e}" 79 | ) 80 | 81 | # Clean json string columns since empty strings are equivalent here to None but are not valid json
82 | dirty_string_column_names: List[str] = [ 83 | col.name 84 | for col in [ 85 | SPAN_ATTRIBUTES_LLM_INVOCATION_PARAMETERS_COL, 86 | SPAN_ATTRIBUTES_TOOL_PARAMETERS_COL, 87 | ] 88 | if col.name in df.columns 89 | ] 90 | for col_name in dirty_string_column_names: 91 | df[col_name] = df[col_name].apply(self._clean_json_string) 92 | 93 | # Convert timestamp columns to datetime objects 94 | timestamp_column_names: List[str] = [ 95 | col.name 96 | for col in [ 97 | SPAN_START_TIME_COL, 98 | SPAN_END_TIME_COL, 99 | ] 100 | if col.name in df.columns 101 | ] 102 | for col_name in timestamp_column_names: 103 | df[col_name] = df[col_name].apply( 104 | self._convert_timestamp_to_datetime 105 | ) 106 | 107 | for err in errors: 108 | logger.warning(err) 109 | 110 | return df 111 | 112 | def _transform_value_to_list_of_dict(self, value): 113 | if value is None: 114 | return None 115 | 116 | if isinstance(value, (list, np.ndarray)): 117 | return [ 118 | self._deserialize_json_string_to_dict(i) 119 | for i in value 120 | if self._is_non_empty_string(i) 121 | ] 122 | elif self._is_non_empty_string(value): 123 | return [self._deserialize_json_string_to_dict(value)] 124 | 125 | def _transform_json_to_dict(self, value): 126 | if value is None: 127 | return None 128 | 129 | if self._is_non_empty_string(value): 130 | return self._deserialize_json_string_to_dict(value) 131 | 132 | if isinstance(value, str) and value == "": 133 | # transform empty string to None 134 | return None 135 | 136 | def _is_non_empty_string(self, value): 137 | return isinstance(value, str) and value != "" 138 | 139 | def _deserialize_json_string_to_dict(self, value: str): 140 | try: 141 | return json.loads(value) 142 | except json.JSONDecodeError as e: 143 | raise ValueError(f"Invalid JSON string: {value}") from e 144 | 145 | def _clean_json_string(self, value): 146 | return value if self._is_non_empty_string(value) else None 147 | 148 | def _convert_timestamp_to_datetime(self, value): 149 | return ( 150 | pd.Timestamp(value, unit="ns") 151 | if value and isinstance(value, (int, float, np.int64)) 152 | else value 153 | ) 154 | -------------------------------------------------------------------------------- /src/arize/exporter/utils/validation.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | class Validator: 5 | @staticmethod 6 | def validate_input_type(input, input_name: str, input_type: type) -> None: 7 | if not isinstance(input, input_type) and input is not None: 8 | raise TypeError( 9 | f"{input_name} {input} is type {type(input)}, but must be a {input_type.__name__}" 10 | ) 11 | 12 | @staticmethod 13 | def validate_input_value(input, input_name: str, choices: tuple) -> None: 14 | if input not in choices: 15 | raise ValueError( 16 | f"{input_name} is {input}, but must be one of {', '.join(choices)}" 17 | ) 18 | 19 | @staticmethod 20 | def validate_start_end_time(start_time, end_time: datetime) -> None: 21 | if start_time >= end_time: 22 | raise ValueError("start_time must be before end_time") 23 | -------------------------------------------------------------------------------- /src/arize/pandas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .auto_generator import EmbeddingGenerator 2 | from .usecases import UseCases 3 | 4 | __all__ = ["EmbeddingGenerator", "UseCases"] 5 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/auto_generator.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from . import constants 4 | from .base_generators import BaseEmbeddingGenerator 5 | from .constants import ( 6 | CV_PRETRAINED_MODELS, 7 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, 8 | DEFAULT_CV_OBJECT_DETECTION_MODEL, 9 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL, 10 | DEFAULT_NLP_SUMMARIZATION_MODEL, 11 | DEFAULT_TABULAR_MODEL, 12 | NLP_PRETRAINED_MODELS, 13 | ) 14 | from .cv_generators import ( 15 | EmbeddingGeneratorForCVImageClassification, 16 | EmbeddingGeneratorForCVObjectDetection, 17 | ) 18 | from .nlp_generators import ( 19 | EmbeddingGeneratorForNLPSequenceClassification, 20 | EmbeddingGeneratorForNLPSummarization, 21 | ) 22 | from .tabular_generators import EmbeddingGeneratorForTabularFeatures 23 | from .usecases import UseCases 24 | 25 | 26 | class EmbeddingGenerator: 27 | def __init__(self, **kwargs: str): 28 | raise OSError( 29 | f"{self.__class__.__name__} is designed to be instantiated using the " 30 | f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method." 31 | ) 32 | 33 | @staticmethod 34 | def from_use_case(use_case: str, **kwargs: str) -> BaseEmbeddingGenerator: 35 | if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION: 36 | return EmbeddingGeneratorForNLPSequenceClassification(**kwargs) 37 | elif use_case == UseCases.NLP.SUMMARIZATION: 38 | return EmbeddingGeneratorForNLPSummarization(**kwargs) 39 | elif use_case == UseCases.CV.IMAGE_CLASSIFICATION: 40 | return EmbeddingGeneratorForCVImageClassification(**kwargs) 41 | elif use_case == UseCases.CV.OBJECT_DETECTION: 42 | return EmbeddingGeneratorForCVObjectDetection(**kwargs) 43 | elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS: 44 | return EmbeddingGeneratorForTabularFeatures(**kwargs) 45 | else: 46 | raise ValueError(f"Invalid use case {use_case}") 47 | 48 | @classmethod 49 | def list_default_models(cls) -> pd.DataFrame: 50 | df = pd.DataFrame( 51 | { 52 | "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"], 53 | "Usecase": [ 54 | UseCases.NLP.SEQUENCE_CLASSIFICATION.name, 55 | UseCases.NLP.SUMMARIZATION.name, 56 | UseCases.CV.IMAGE_CLASSIFICATION.name, 57 | UseCases.CV.OBJECT_DETECTION.name, 58 | UseCases.STRUCTURED.TABULAR_EMBEDDINGS.name, 59 | ], 60 | "Model Name": [ 61 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL, 62 | DEFAULT_NLP_SUMMARIZATION_MODEL, 63 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, 64 | DEFAULT_CV_OBJECT_DETECTION_MODEL, 65 | DEFAULT_TABULAR_MODEL, 66 | ], 67 | } 68 | ) 69 | df.sort_values( 70 | by=[col for col in df.columns], ascending=True, inplace=True 71 | ) 72 | return df.reset_index(drop=True) 73 | 74 | @classmethod 75 | def list_pretrained_models(cls) -> pd.DataFrame: 76 | data = { 77 | "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS] 78 | + ["CV" for _ in CV_PRETRAINED_MODELS], 79 | "Architecture": [ 80 | cls.__parse_model_arch(model) 81 | for model in NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS 82 | ], 83 | "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS, 84 | } 85 | df = pd.DataFrame(data) 86 | df.sort_values( 87 | by=[col for col in df.columns], ascending=True, inplace=True 88 | ) 89 | return 
df.reset_index(drop=True) 90 | 91 | @staticmethod 92 | def __parse_model_arch(model_name: str) -> str: 93 | if constants.GPT.lower() in model_name.lower(): 94 | return constants.GPT 95 | elif constants.BERT.lower() in model_name.lower(): 96 | return constants.BERT 97 | elif constants.VIT.lower() in model_name.lower(): 98 | return constants.VIT 99 | else: 100 | raise ValueError("Invalid model_name, unknown architecture.") 101 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased" 2 | DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased" 3 | DEFAULT_TABULAR_MODEL = "distilbert-base-uncased" 4 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL = "google/vit-base-patch32-224-in21k" 5 | DEFAULT_CV_OBJECT_DETECTION_MODEL = "facebook/detr-resnet-101" 6 | NLP_PRETRAINED_MODELS = [ 7 | "bert-base-cased", 8 | "bert-base-uncased", 9 | "bert-large-cased", 10 | "bert-large-uncased", 11 | "distilbert-base-cased", 12 | "distilbert-base-uncased", 13 | "xlm-roberta-base", 14 | "xlm-roberta-large", 15 | ] 16 | 17 | CV_PRETRAINED_MODELS = [ 18 | "google/vit-base-patch16-224-in21k", 19 | "google/vit-base-patch16-384", 20 | "google/vit-base-patch32-224-in21k", 21 | "google/vit-base-patch32-384", 22 | "google/vit-large-patch16-224-in21k", 23 | "google/vit-large-patch16-384", 24 | "google/vit-large-patch32-224-in21k", 25 | "google/vit-large-patch32-384", 26 | ] 27 | IMPORT_ERROR_MESSAGE = ( 28 | "To enable embedding generation, the arize module must be installed with " 29 | "extra dependencies. Run: pip install 'arize[AutoEmbeddings]'." 30 | ) 31 | 32 | GPT = "GPT" 33 | BERT = "BERT" 34 | VIT = "ViT" 35 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/cv_generators.py: -------------------------------------------------------------------------------- 1 | from .base_generators import CVEmbeddingGenerator 2 | from .constants import ( 3 | DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, 4 | DEFAULT_CV_OBJECT_DETECTION_MODEL, 5 | ) 6 | from .usecases import UseCases 7 | 8 | 9 | class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator): 10 | def __init__( 11 | self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs 12 | ): 13 | super().__init__( 14 | use_case=UseCases.CV.IMAGE_CLASSIFICATION, 15 | model_name=model_name, 16 | **kwargs, 17 | ) 18 | 19 | 20 | class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator): 21 | def __init__( 22 | self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs 23 | ): 24 | super().__init__( 25 | use_case=UseCases.CV.OBJECT_DETECTION, 26 | model_name=model_name, 27 | **kwargs, 28 | ) 29 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/errors.py: -------------------------------------------------------------------------------- 1 | class InvalidIndexError(Exception): 2 | def __repr__(self) -> str: 3 | return "Invalid_Index_Error" 4 | 5 | def __str__(self) -> str: 6 | return self.error_message() 7 | 8 | def __init__(self, field_name: str) -> None: 9 | self.field_name = field_name 10 | 11 | def error_message(self) -> str: 12 | if self.field_name == "DataFrame": 13 | return ( 14 | f"The index of the {self.field_name} is invalid; " 15 | f"reset the index by using df.reset_index(drop=True, inplace=True)" 16 | ) 17 | 
else: 18 | return ( 19 | f"The index of the Series given by the column '{self.field_name}' is invalid; " 20 | f"reset the index by using df.reset_index(drop=True, inplace=True)" 21 | ) 22 | 23 | 24 | class HuggingFaceRepositoryNotFound(Exception): 25 | def __repr__(self) -> str: 26 | return "HuggingFace_Repository_Not_Found_Error" 27 | 28 | def __str__(self) -> str: 29 | return self.error_message() 30 | 31 | def __init__(self, model_name: str) -> None: 32 | self.model_name = model_name 33 | 34 | def error_message(self) -> str: 35 | return ( 36 | f"The given model name '{self.model_name}' is not a valid model identifier listed on " 37 | "'https://huggingface.co/models'. " 38 | "If this is a private repository, log in with `huggingface-cli login`, or import " 39 | "`login` from `huggingface_hub` if you are using a notebook. " 40 | "Learn more at https://huggingface.co/docs/huggingface_hub/quick-start#login" 41 | ) 42 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/nlp_generators.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, cast 3 | 4 | import pandas as pd 5 | 6 | from arize.utils.logging import logger 7 | 8 | from .base_generators import NLPEmbeddingGenerator 9 | from .constants import ( 10 | DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL, 11 | DEFAULT_NLP_SUMMARIZATION_MODEL, 12 | IMPORT_ERROR_MESSAGE, 13 | ) 14 | from .usecases import UseCases 15 | 16 | try: 17 | from datasets import Dataset 18 | except ImportError: 19 | raise ImportError(IMPORT_ERROR_MESSAGE) from None 20 | 21 | 22 | class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator): 23 | def __init__( 24 | self, 25 | model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL, 26 | **kwargs, 27 | ): 28 | super().__init__( 29 | use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION, 30 | model_name=model_name, 31 | **kwargs, 32 | ) 33 | 34 | def generate_embeddings( 35 | self, 36 | text_col: pd.Series, 37 | class_label_col: Optional[pd.Series] = None, 38 | ) -> pd.Series: 39 | """ 40 | Obtain embedding vectors from your text data using pre-trained large language models. 41 | 42 | :param text_col: a pandas Series containing the different pieces of text. 43 | :param class_label_col: if this column is passed, the sentence "The classification label 44 | is <class_label>." will be prepended to the text in the `text_col`. 45 | :return: a pandas Series containing the embedding vectors. 46 | """ 47 | if not isinstance(text_col, pd.Series): 48 | raise TypeError("text_col must be a pandas Series") 49 | 50 | self.check_invalid_index(field=text_col) 51 | 52 | if class_label_col is not None: 53 | if not isinstance(class_label_col, pd.Series): 54 | raise TypeError("class_label_col must be a pandas Series") 55 | df = pd.concat( 56 | {"text": text_col, "class_label": class_label_col}, axis=1 57 | ) 58 | prepared_text_col = df.apply( 59 | lambda row: f" The classification label is {row['class_label']}. 
{row['text']}", 60 | axis=1, 61 | ) 62 | ds = Dataset.from_dict({"text": prepared_text_col}) 63 | else: 64 | ds = Dataset.from_dict({"text": text_col}) 65 | 66 | ds.set_transform(partial(self.tokenize, text_feat_name="text")) 67 | logger.info("Generating embedding vectors") 68 | ds = ds.map( 69 | lambda batch: self._get_embedding_vector(batch, "cls_token"), 70 | batched=True, 71 | batch_size=self.batch_size, 72 | ) 73 | return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"] 74 | 75 | 76 | class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator): 77 | def __init__( 78 | self, model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL, **kwargs 79 | ): 80 | super().__init__( 81 | use_case=UseCases.NLP.SUMMARIZATION, 82 | model_name=model_name, 83 | **kwargs, 84 | ) 85 | 86 | def generate_embeddings( 87 | self, 88 | text_col: pd.Series, 89 | ) -> pd.Series: 90 | """ 91 | Obtain embedding vectors from your text data using pre-trained large language models. 92 | 93 | :param text_col: a pandas Series containing the different pieces of text. 94 | :return: a pandas Series containing the embedding vectors. 95 | """ 96 | if not isinstance(text_col, pd.Series): 97 | raise TypeError("text_col must be a pandas Series") 98 | self.check_invalid_index(field=text_col) 99 | 100 | ds = Dataset.from_dict({"text": text_col}) 101 | 102 | ds.set_transform(partial(self.tokenize, text_feat_name="text")) 103 | logger.info("Generating embedding vectors") 104 | ds = ds.map( 105 | lambda batch: self._get_embedding_vector(batch, "cls_token"), 106 | batched=True, 107 | batch_size=self.batch_size, 108 | ) 109 | return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"] 110 | -------------------------------------------------------------------------------- /src/arize/pandas/embeddings/usecases.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum, auto, unique 3 | 4 | 5 | @unique 6 | class NLPUseCases(Enum): 7 | SEQUENCE_CLASSIFICATION = auto() 8 | SUMMARIZATION = auto() 9 | 10 | 11 | @unique 12 | class CVUseCases(Enum): 13 | IMAGE_CLASSIFICATION = auto() 14 | OBJECT_DETECTION = auto() 15 | 16 | 17 | @unique 18 | class TabularUsecases(Enum): 19 | TABULAR_EMBEDDINGS = auto() 20 | 21 | 22 | @dataclass 23 | class UseCases: 24 | NLP = NLPUseCases 25 | CV = CVUseCases 26 | STRUCTURED = TabularUsecases 27 | -------------------------------------------------------------------------------- /src/arize/pandas/etl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/etl/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/etl/errors.py: -------------------------------------------------------------------------------- 1 | from arize.utils.logging import log_a_list 2 | from arize.utils.types import TypedColumns 3 | 4 | 5 | class ColumnCastingError(Exception): 6 | def __str__(self) -> str: 7 | return self.error_message() 8 | 9 | def __init__( 10 | self, 11 | error_msg: str, 12 | attempted_columns: str, 13 | attempted_type: TypedColumns, 14 | ) -> None: 15 | self.error_msg = error_msg 16 | self.attempted_casting_columns = attempted_columns 17 | self.attempted_casting_type = attempted_type 18 | 19 | def error_message(self) -> str: 20 | return ( 21 | f"Failed to cast to type {self.attempted_casting_type} " 22 | f"for columns: 
{log_a_list(self.attempted_casting_columns, 'and')}. " 23 | f"Error: {self.error_msg}" 24 | ) 25 | 26 | 27 | class InvalidTypedColumnsError(Exception): 28 | def __str__(self) -> str: 29 | return self.error_message() 30 | 31 | def __init__(self, field_name: str, reason: str) -> None: 32 | self.field_name = field_name 33 | self.reason = reason 34 | 35 | def error_message(self) -> str: 36 | return f"The {self.field_name} TypedColumns object {self.reason}." 37 | 38 | 39 | class InvalidSchemaFieldTypeError(Exception): 40 | def __str__(self) -> str: 41 | return self.error_message() 42 | 43 | def __init__(self, msg: str) -> None: 44 | self.msg = msg 45 | 46 | def error_message(self) -> str: 47 | return self.msg 48 | -------------------------------------------------------------------------------- /src/arize/pandas/generative/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/generative/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/generative/llm_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_metrics import bleu, google_bleu, meteor, rouge, sacre_bleu 2 | 3 | __all__ = [ 4 | "bleu", 5 | "sacre_bleu", 6 | "meteor", 7 | "rouge", 8 | "google_bleu", 9 | ] 10 | -------------------------------------------------------------------------------- /src/arize/pandas/generative/llm_evaluation/constants.py: -------------------------------------------------------------------------------- 1 | IMPORT_ERROR_MESSAGE = ( 2 | "To enable evaluation of language models, the arize module must be installed with " 3 | "extra dependencies. Run: pip install 'arize[NLP_Metrics]'." 
4 | ) 5 | -------------------------------------------------------------------------------- /src/arize/pandas/generative/nlp_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_metrics import bleu, google_bleu, meteor, rouge, sacre_bleu 2 | 3 | __all__ = [ 4 | "bleu", 5 | "google_bleu", 6 | "meteor", 7 | "rouge", 8 | "sacre_bleu", 9 | ] 10 | -------------------------------------------------------------------------------- /src/arize/pandas/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/proto/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/surrogate_explainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/surrogate_explainer/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/surrogate_explainer/mimic.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from dataclasses import replace 4 | from typing import Callable, Tuple 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from interpret_community.mimic.mimic_explainer import ( 9 | LGBMExplainableModel, 10 | MimicExplainer, 11 | ) 12 | from sklearn.preprocessing import LabelEncoder 13 | 14 | from arize.pandas.logger import Schema 15 | from arize.utils.types import ( 16 | CATEGORICAL_MODEL_TYPES, 17 | NUMERIC_MODEL_TYPES, 18 | ModelTypes, 19 | ) 20 | 21 | 22 | class Mimic: 23 | _testing = False 24 | 25 | def __init__(self, X: pd.DataFrame, model_func: Callable): 26 | self.explainer = MimicExplainer( 27 | model_func, 28 | X, 29 | LGBMExplainableModel, 30 | augment_data=False, 31 | is_function=True, 32 | ) 33 | 34 | def explain(self, X: pd.DataFrame) -> pd.DataFrame: 35 | return pd.DataFrame( 36 | self.explainer.explain_local(X).local_importance_values, 37 | columns=X.columns, 38 | index=X.index, 39 | ) 40 | 41 | @staticmethod 42 | def augment( 43 | df: pd.DataFrame, schema: Schema, model_type: ModelTypes 44 | ) -> Tuple[pd.DataFrame, Schema]: 45 | features = schema.feature_column_names 46 | X = df[features] 47 | 48 | if X.shape[1] == 0: 49 | return df, schema 50 | 51 | if model_type in CATEGORICAL_MODEL_TYPES: 52 | if not schema.prediction_score_column_name: 53 | raise ValueError( 54 | "To calculate surrogate explainability, " 55 | f"prediction_score_column_name must be specified in schema for {model_type}." 56 | ) 57 | 58 | y_col_name = schema.prediction_score_column_name 59 | y = df[y_col_name].to_numpy() 60 | 61 | _min, _max = np.min(y), np.max(y) 62 | if not 0 <= _min <= 1 or not 0 <= _max <= 1: 63 | raise ValueError( 64 | f"To calculate surrogate explainability for {model_type}, " 65 | f"prediction scores must be between 0 and 1, but current " 66 | f"prediction scores range from {_min} to {_max}." 
67 | ) 68 | 69 | # model func requires 1 positional argument 70 | def model_func(_): # type: ignore 71 | return np.column_stack((1 - y, y)) 72 | 73 | elif model_type in NUMERIC_MODEL_TYPES: 74 | y_col_name = schema.prediction_label_column_name 75 | if schema.prediction_score_column_name is not None: 76 | y_col_name = schema.prediction_score_column_name 77 | y = df[y_col_name].to_numpy() 78 | 79 | _finite_count = np.isfinite(y).sum() 80 | if len(y) - _finite_count: 81 | raise ValueError( 82 | f"To calculate surrogate explainability for {model_type}, " 83 | f"predictions must not contain NaN or infinite values, but " 84 | f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}." 85 | ) 86 | 87 | # model func requires 1 positional argument 88 | def model_func(_): # type: ignore 89 | return y 90 | 91 | else: 92 | raise ValueError( 93 | "Surrogate explainability is not supported for the specified " 94 | f"model type {model_type}." 95 | ) 96 | 97 | # Column name mapping between features and feature importance values. 98 | # This is used to augment the schema. 99 | col_map = { 100 | ft: f"{''.join(random.choices(string.ascii_letters, k=8))}" 101 | for ft in features 102 | } 103 | aug_schema = replace(schema, shap_values_column_names=col_map) 104 | 105 | # Limit the total number of "cells" to 20M, unless it results in too few or 106 | # too many rows. This is done to keep the runtime low. Records not sampled 107 | # have feature importance values set to 0. 108 | samp_size = min( 109 | len(X), min(100_000, max(1_000, 20_000_000 // X.shape[1])) 110 | ) 111 | 112 | if samp_size < len(X): 113 | _mask = np.zeros(len(X), dtype=int) 114 | _mask[:samp_size] = 1 115 | np.random.shuffle(_mask) 116 | _mask = _mask.astype(bool) 117 | X = X[_mask] 118 | y = y[_mask] 119 | 120 | # Replace all pd.NA values with np.nan values 121 | for col in X.columns: 122 | if X[col].isna().any(): 123 | X[col] = X[col].astype(object).where(~X[col].isna(), np.nan) 124 | 125 | # Apply integer encoding to non-numeric columns. 126 | # Currently training and explaining datasets are the same, but 127 | # this can be changed in the future. The student model can be 128 | # fitted on a much larger dataset since it takes a lot less time. 
129 | X = pd.concat( 130 | [ 131 | X.select_dtypes(exclude=[object, "string"]), 132 | pd.DataFrame( 133 | { 134 | name: LabelEncoder().fit_transform(data) 135 | for name, data in X.select_dtypes( 136 | include=[object, "string"] 137 | ).items() 138 | }, 139 | index=X.index, 140 | ), 141 | ], 142 | axis=1, 143 | ) 144 | 145 | aug_df = pd.concat( 146 | [ 147 | df, 148 | Mimic(X, model_func).explain(X).rename(col_map, axis=1), 149 | ], 150 | axis=1, 151 | ) 152 | 153 | # Fill null with zero so they're not counted as missing records by server 154 | if not Mimic._testing: 155 | aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True) 156 | 157 | return ( 158 | aug_df, 159 | aug_schema, 160 | ) 161 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/tracing/constants.py: -------------------------------------------------------------------------------- 1 | # The default format used to parse datetime objects from strings 2 | DEFAULT_DATETIME_FMT = "%Y-%m-%dT%H:%M:%S.%f+00:00" 3 | # Minimum/Maximum number of characters for span/trace/parent ids in spans 4 | SPAN_ID_MIN_STR_LENGTH = 12 5 | SPAN_ID_MAX_STR_LENGTH = 128 6 | # Minimum/Maximum number of characters for span name 7 | SPAN_NAME_MIN_STR_LENGTH = 0 8 | SPAN_NAME_MAX_STR_LENGTH = 50 9 | # Minimum/Maximum number of characters for span status message 10 | SPAN_STATUS_MSG_MIN_STR_LENGTH = 0 11 | SPAN_STATUS_MSG_MAX_STR_LENGTH = 10_000 12 | # Maximum number of characters for span event name 13 | SPAN_EVENT_NAME_MAX_STR_LENGTH = 100 14 | # Maximum number of characters for span event attributes 15 | SPAN_EVENT_ATTRS_MAX_STR_LENGTH = 10_000 16 | # Maximum number of characters for span kind 17 | SPAN_KIND_MAX_STR_LENGTH = 100 18 | SPAN_EXCEPTION_TYPE_MAX_STR_LENGTH = 100 19 | SPAN_EXCEPTION_MESSAGE_MAX_STR_LENGTH = 100 20 | SPAN_EXCEPTION_STACK_TRACE_MAX_STR_LENGTH = 10_000 21 | SPAN_IO_VALUE_MAX_STR_LENGTH = 4_000_000 22 | SPAN_IO_MIME_TYPE_MAX_STR_LENGTH = 100 23 | SPAN_EMBEDDING_NAME_MAX_STR_LENGTH = 100 24 | SPAN_EMBEDDING_TEXT_MAX_STR_LENGTH = 4_000_000 25 | SPAN_LLM_MODEL_NAME_MAX_STR_LENGTH = 100 26 | SPAN_LLM_MESSAGE_ROLE_MAX_STR_LENGTH = 100 27 | SPAN_LLM_MESSAGE_CONTENT_MAX_STR_LENGTH = 4_000_000 28 | SPAN_LLM_TOOL_CALL_FUNCTION_NAME_MAX_STR_LENGTH = 500 29 | SPAN_LLM_PROMPT_TEMPLATE_MAX_STR_LENGTH = 4_000_000 30 | SPAN_LLM_PROMPT_TEMPLATE_VARIABLES_MAX_STR_LENGTH = 10_000 31 | SPAN_LLM_PROMPT_TEMPLATE_VERSION_MAX_STR_LENGTH = 100 32 | SPAN_TOOL_NAME_MAX_STR_LENGTH = 100 33 | SPAN_TOOL_DESCRIPTION_MAX_STR_LENGTH = 1_000 34 | SPAN_TOOL_PARAMETERS_MAX_STR_LENGTH = 1_000 35 | SPAN_RERANKER_QUERY_MAX_STR_LENGTH = 10_000 36 | SPAN_RERANKER_MODEL_NAME_MAX_STR_LENGTH = 100 37 | SPAN_DOCUMENT_ID_MAX_STR_LENGTH = 100 38 | SPAN_DOCUMENT_CONTENT_MAX_STR_LENGTH = 4_000_000 39 | JSON_STRING_MAX_STR_LENGTH = 4_000_000 40 | # Eval related constants 41 | EVAL_LABEL_MIN_STR_LENGTH = 1 # we do not accept empty strings 42 | EVAL_LABEL_MAX_STR_LENGTH = 100 43 | EVAL_EXPLANATION_MAX_STR_LENGTH = 10_000 44 | 45 | # Annotation related constants 46 | ANNOTATION_LABEL_MIN_STR_LENGTH = 1 47 | ANNOTATION_LABEL_MAX_STR_LENGTH = 100 # Max length for annotation label 
string 48 | ANNOTATION_UPDATED_BY_MAX_STR_LENGTH = 100 49 | ANNOTATION_NOTES_MAX_STR_LENGTH = ( 50 | 10_000 # Max length for annotation note string 51 | ) 52 | 53 | # Maximum number of characters for session/user ids in spans 54 | SESSION_ID_MAX_STR_LENGTH = 128 55 | USER_ID_MAX_STR_LENGTH = 128 56 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | 4 | @unique 5 | class StatusCodes(Enum): 6 | UNSET = 0 7 | OK = 1 8 | ERROR = 2 9 | 10 | @classmethod 11 | def list_codes(cls): 12 | return [t.name for t in cls] 13 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from typing import Any, Dict, Iterable, List, Optional, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from arize.utils.logging import logger 9 | 10 | from .columns import SPAN_OPENINFERENCE_COLUMNS, SpanColumnDataType 11 | 12 | 13 | def convert_timestamps(df: pd.DataFrame, fmt: str = "") -> pd.DataFrame: 14 | time_cols = [ 15 | col 16 | for col in SPAN_OPENINFERENCE_COLUMNS 17 | if col.data_type == SpanColumnDataType.TIMESTAMP 18 | ] 19 | for col in time_cols: 20 | df[col.name] = df[col.name].apply(lambda dt: _datetime_to_ns(dt, fmt)) 21 | return df 22 | 23 | 24 | def _datetime_to_ns(dt: Union[str, datetime], fmt: str) -> int: 25 | if isinstance(dt, str): 26 | try: 27 | ts = int(datetime.timestamp(datetime.strptime(dt, fmt)) * 1e9) 28 | except Exception as e: 29 | logger.error( 30 | f"Error parsing string '{dt}' to timestamp in nanoseconds " 31 | f"using the format '{fmt}': {e}" 32 | ) 33 | raise e 34 | return ts 35 | elif isinstance(dt, datetime): 36 | try: 37 | ts = int(datetime.timestamp(dt) * 1e9) 38 | except Exception as e: 39 | logger.error( 40 | f"Error converting datetime object to nanoseconds: {e}" 41 | ) 42 | raise e 43 | return ts 44 | elif isinstance(dt, (pd.Timestamp, pd.DatetimeIndex)): 45 | try: 46 | ts = int(datetime.timestamp(dt.to_pydatetime()) * 1e9) 47 | except Exception as e: 48 | logger.error( 49 | f"Error converting pandas Timestamp to nanoseconds: {e}" 50 | ) 51 | raise e 52 | return ts 53 | elif isinstance(dt, (int, float)): 54 | # Assume value already in nanoseconds, 55 | # validate timestamps in validate_values 56 | return dt 57 | else: 58 | e = TypeError(f"Cannot convert type {type(dt)} to nanoseconds") 59 | logger.error(f"Error converting pandas Timestamp to nanoseconds: {e}") 60 | raise e 61 | 62 | 63 | def jsonify_dictionaries(df: pd.DataFrame) -> pd.DataFrame: 64 | # NOTE: numpy arrays are not json serializable. 
Hence, we assume the 65 | # embeddings come as lists, not arrays 66 | dict_cols = [ 67 | col 68 | for col in SPAN_OPENINFERENCE_COLUMNS 69 | if col.data_type == SpanColumnDataType.DICT 70 | ] 71 | list_of_dict_cols = [ 72 | col 73 | for col in SPAN_OPENINFERENCE_COLUMNS 74 | if col.data_type == SpanColumnDataType.LIST_DICT 75 | ] 76 | for col in dict_cols: 77 | col_name = col.name 78 | if col_name not in df.columns: 79 | logger.debug(f"passing on {col_name}") 80 | continue 81 | logger.debug(f"jsonifying {col_name}") 82 | df[col_name] = df[col_name].apply(lambda d: _jsonify_dict(d)) 83 | 84 | for col in list_of_dict_cols: 85 | col_name = col.name 86 | if col_name not in df.columns: 87 | logger.debug(f"passing on {col_name}") 88 | continue 89 | logger.debug(f"jsonifying {col_name}") 90 | df[col_name] = df[col_name].apply( 91 | lambda list_of_dicts: _jsonify_list_of_dicts(list_of_dicts) 92 | ) 93 | return df 94 | 95 | 96 | def _jsonify_list_of_dicts( 97 | list_of_dicts: Optional[Iterable[Dict[str, Any]]], 98 | ) -> Optional[List[str]]: 99 | if not isinstance(list_of_dicts, Iterable) and isMissingValue( 100 | list_of_dicts 101 | ): 102 | return None 103 | list_of_json = [] 104 | for d in list_of_dicts: 105 | list_of_json.append(_jsonify_dict(d)) 106 | return list_of_json 107 | 108 | 109 | def _jsonify_dict(d: Optional[Dict[str, Any]]) -> Optional[str]: 110 | if d is None: 111 | return 112 | if isMissingValue(d): 113 | return None 114 | d = d.copy() # avoid side effects 115 | for k, v in d.items(): 116 | if isinstance(v, np.ndarray): 117 | d[k] = v.tolist() 118 | if isinstance(v, dict): 119 | d[k] = _jsonify_dict(v) 120 | return json.dumps(d, ensure_ascii=False) 121 | 122 | 123 | # Defines what is considered a missing value 124 | def isMissingValue(value: Any) -> bool: 125 | assumed_missing_values = ( 126 | np.inf, 127 | -np.inf, 128 | ) 129 | return value in assumed_missing_values or pd.isna(value) 130 | 131 | 132 | def extract_project_name_from_params( 133 | model_id: Optional[str] = None, project_name: Optional[str] = None 134 | ): 135 | project_name_param = model_id 136 | 137 | # Prefer project_name over model_id 138 | if project_name and project_name.strip(): 139 | project_name_param = project_name 140 | 141 | return project_name_param 142 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/annotations/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file 2 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/annotations/annotations_validation.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List 3 | 4 | import pandas as pd 5 | 6 | # Common validation imports 7 | from arize.pandas.tracing.columns import SPAN_SPAN_ID_COL 8 | 9 | # Import annotation-specific validation modules 10 | from arize.pandas.tracing.validation.annotations import ( 11 | dataframe_form_validation as df_validation, 12 | ) 13 | 14 | # Import the annotation value 
validation module 15 | from arize.pandas.tracing.validation.annotations import value_validation 16 | from arize.pandas.tracing.validation.common import ( 17 | argument_validation as common_arg_validation, 18 | ) 19 | from arize.pandas.tracing.validation.common import ( 20 | dataframe_form_validation as common_df_validation, 21 | ) 22 | from arize.pandas.tracing.validation.common import ( 23 | value_validation as common_value_validation, 24 | ) 25 | from arize.pandas.validation import errors as err 26 | 27 | 28 | def validate_argument_types( 29 | annotations_dataframe: pd.DataFrame, 30 | project_name: str, 31 | ) -> List[err.ValidationError]: 32 | """Validates argument types for log_annotations.""" 33 | checks = chain( 34 | common_arg_validation._check_field_convertible_to_str(project_name), 35 | common_arg_validation._check_dataframe_type( 36 | annotations_dataframe 37 | ), # Use renamed parameter 38 | ) 39 | return list(checks) 40 | 41 | 42 | def validate_dataframe_form( 43 | annotations_dataframe: pd.DataFrame, 44 | ) -> List[err.ValidationError]: 45 | """Validates the form/structure of the annotation dataframe.""" 46 | # Call annotation-specific function (to be created) 47 | df_validation._log_info_dataframe_extra_column_names(annotations_dataframe) 48 | checks = chain( 49 | # Common checks remain the same 50 | common_df_validation._check_dataframe_index(annotations_dataframe), 51 | common_df_validation._check_dataframe_required_column_set( 52 | annotations_dataframe, required_columns=[SPAN_SPAN_ID_COL.name] 53 | ), 54 | common_df_validation._check_dataframe_for_duplicate_columns( 55 | annotations_dataframe 56 | ), 57 | # Call annotation-specific content type check (to be created) 58 | df_validation._check_dataframe_column_content_type( 59 | annotations_dataframe 60 | ), 61 | ) 62 | return list(checks) 63 | 64 | 65 | def validate_values( 66 | annotations_dataframe: pd.DataFrame, 67 | project_name: str, 68 | ) -> List[err.ValidationError]: 69 | """Validates the values within the annotation dataframe.""" 70 | checks = chain( 71 | # Common checks remain the same 72 | common_value_validation._check_invalid_project_name(project_name), 73 | # Call annotation-specific value checks from the imported module 74 | value_validation._check_annotation_cols(annotations_dataframe), 75 | value_validation._check_annotation_columns_null_values( 76 | annotations_dataframe 77 | ), 78 | value_validation._check_annotation_notes_column(annotations_dataframe), 79 | ) 80 | return list(checks) 81 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/annotations/value_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import chain 3 | from typing import List 4 | 5 | import pandas as pd 6 | 7 | import arize.pandas.tracing.constants as tracing_constants 8 | 9 | # Import annotation-specific and common column constants 10 | from arize.pandas.tracing.columns import ( 11 | ANNOTATION_COLUMN_PREFIX, 12 | ANNOTATION_LABEL_SUFFIX, 13 | ANNOTATION_NAME_PATTERN, 14 | ANNOTATION_NOTES_COLUMN_NAME, 15 | ANNOTATION_SCORE_SUFFIX, 16 | ANNOTATION_UPDATED_AT_SUFFIX, 17 | ANNOTATION_UPDATED_BY_SUFFIX, 18 | ) 19 | 20 | # Import common validation errors and functions 21 | from arize.pandas.tracing.validation.common import errors as tracing_err 22 | from arize.pandas.tracing.validation.common import ( 23 | value_validation as common_value_validation, 24 | ) 25 | from arize.pandas.validation 
import errors as err 26 | 27 | 28 | def _check_annotation_cols( 29 | dataframe: pd.DataFrame, 30 | ) -> List[err.ValidationError]: 31 | """Checks value length and validity for columns matching annotation patterns.""" 32 | checks = [] 33 | for col in dataframe.columns: 34 | if col.endswith(ANNOTATION_LABEL_SUFFIX): 35 | checks.append( 36 | common_value_validation._check_string_column_value_length( 37 | df=dataframe, 38 | col_name=col, 39 | min_len=tracing_constants.ANNOTATION_LABEL_MIN_STR_LENGTH, 40 | max_len=tracing_constants.ANNOTATION_LABEL_MAX_STR_LENGTH, 41 | is_required=False, # Individual columns are not required, null check handles completeness 42 | ) 43 | ) 44 | elif col.endswith(ANNOTATION_SCORE_SUFFIX): 45 | checks.append( 46 | common_value_validation._check_float_column_valid_numbers( 47 | df=dataframe, 48 | col_name=col, 49 | ) 50 | ) 51 | elif col.endswith(ANNOTATION_UPDATED_BY_SUFFIX): 52 | checks.append( 53 | common_value_validation._check_string_column_value_length( 54 | df=dataframe, 55 | col_name=col, 56 | min_len=1, 57 | max_len=tracing_constants.ANNOTATION_UPDATED_BY_MAX_STR_LENGTH, 58 | is_required=False, 59 | ) 60 | ) 61 | elif col.endswith(ANNOTATION_UPDATED_AT_SUFFIX): 62 | checks.append( 63 | common_value_validation._check_value_timestamp( 64 | df=dataframe, 65 | col_name=col, 66 | is_required=False, # updated_at is not strictly required per row 67 | ) 68 | ) 69 | # No check for ANNOTATION_NOTES_COLUMN_NAME here, handled by _check_annotation_notes_column 70 | return list(chain(*checks)) 71 | 72 | 73 | def _check_annotation_columns_null_values( 74 | dataframe: pd.DataFrame, 75 | ) -> List[err.ValidationError]: 76 | """Checks that for a given annotation name, at least one of label or score is non-null per row.""" 77 | invalid_annotation_names = [] 78 | annotation_names = set() 79 | # Find all unique annotation names from column headers 80 | for col in dataframe.columns: 81 | match = re.match(ANNOTATION_NAME_PATTERN, col) 82 | if match: 83 | annotation_names.add(match.group(1)) 84 | 85 | for ann_name in annotation_names: 86 | label_col = ( 87 | f"{ANNOTATION_COLUMN_PREFIX}{ann_name}{ANNOTATION_LABEL_SUFFIX}" 88 | ) 89 | score_col = ( 90 | f"{ANNOTATION_COLUMN_PREFIX}{ann_name}{ANNOTATION_SCORE_SUFFIX}" 91 | ) 92 | 93 | label_exists = label_col in dataframe.columns 94 | score_exists = score_col in dataframe.columns 95 | 96 | # Check only if both label and score columns exist for this name 97 | # If only one exists, its presence is sufficient 98 | if label_exists and score_exists: 99 | # Find rows where BOTH label and score are null 100 | condition = ( 101 | dataframe[label_col].isnull() & dataframe[score_col].isnull() 102 | ) 103 | if condition.any(): 104 | invalid_annotation_names.append(ann_name) 105 | # Check if only label exists but it's always null 106 | elif label_exists and not score_exists: 107 | if dataframe[label_col].isnull().all(): 108 | invalid_annotation_names.append(ann_name) 109 | # Check if only score exists but it's always null 110 | elif not label_exists and score_exists: 111 | if dataframe[score_col].isnull().all(): 112 | invalid_annotation_names.append(ann_name) 113 | 114 | # Use set to report each name only once 115 | unique_invalid_names = sorted(list(set(invalid_annotation_names))) 116 | if unique_invalid_names: 117 | return [ 118 | tracing_err.InvalidNullAnnotationLabelAndScore( 119 | annotation_names=unique_invalid_names 120 | ) 121 | ] 122 | return [] 123 | 124 | 125 | def _check_annotation_notes_column( 126 | dataframe: pd.DataFrame, 127 | 
) -> List[err.ValidationError]: 128 | """Checks the value length for the optional annotation.notes column (raw string).""" 129 | col_name = ANNOTATION_NOTES_COLUMN_NAME 130 | if col_name in dataframe.columns: 131 | # Validate the length of the raw string 132 | return list( 133 | chain( 134 | *common_value_validation._check_string_column_value_length( 135 | df=dataframe, 136 | col_name=col_name, 137 | min_len=0, # Allow empty notes 138 | max_len=tracing_constants.ANNOTATION_NOTES_MAX_STR_LENGTH, 139 | is_required=False, 140 | ) 141 | ) 142 | ) 143 | return [] 144 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/common/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/common/argument_validation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | import pandas as pd 4 | 5 | from arize.pandas.tracing.validation.common import errors as tracing_err 6 | from arize.pandas.validation import errors as err 7 | 8 | 9 | def _check_field_convertible_to_str( 10 | project_name: str, 11 | model_version: Optional[str] = None, 12 | ) -> List[err.InvalidFieldTypeConversion]: 13 | wrong_fields = [] 14 | if project_name is not None and not isinstance(project_name, str): 15 | try: 16 | str(project_name) 17 | except Exception: 18 | wrong_fields.append("project_name") 19 | if model_version is not None and not isinstance(model_version, str): 20 | try: 21 | str(model_version) 22 | except Exception: 23 | wrong_fields.append("model_version") 24 | 25 | if wrong_fields: 26 | return [err.InvalidFieldTypeConversion(wrong_fields, "string")] 27 | return [] 28 | 29 | 30 | def _check_dataframe_type( 31 | dataframe, 32 | ) -> List[tracing_err.InvalidTypeArgument]: 33 | if not isinstance(dataframe, pd.DataFrame): 34 | return [ 35 | tracing_err.InvalidTypeArgument( 36 | wrong_arg=dataframe, 37 | arg_name="dataframe", 38 | arg_type="pandas DataFrame", 39 | ) 40 | ] 41 | return [] 42 | 43 | 44 | def _check_datetime_format_type( 45 | dt_fmt: Any, 46 | ) -> List[tracing_err.InvalidTypeArgument]: 47 | if not isinstance(dt_fmt, str): 48 | return [ 49 | tracing_err.InvalidTypeArgument( 50 | wrong_arg=dt_fmt, 51 | arg_name="dateTime format", 52 | arg_type="string", 53 | ) 54 | ] 55 | return [] 56 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/common/dataframe_form_validation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | 5 | from arize.pandas.tracing.validation.common import errors as tracing_err 6 | from arize.pandas.validation import errors as err 7 | 8 | 9 | def _check_dataframe_index( 10 | dataframe: pd.DataFrame, 11 | ) -> List[err.InvalidDataFrameIndex]: 12 | if (dataframe.index != dataframe.reset_index(drop=True).index).any(): 13 | return [err.InvalidDataFrameIndex()] 14 | return [] 15 | 16 | 17 | def _check_dataframe_required_column_set( 18 | df: pd.DataFrame, 19 | required_columns: List[str], 20 | ) -> List[tracing_err.InvalidDataFrameMissingColumns]: 21 | existing_columns = set(df.columns) 22 | 
missing_cols = [] 23 | for col in required_columns: 24 | if col not in existing_columns: 25 | missing_cols.append(col) 26 | 27 | if missing_cols: 28 | return [ 29 | tracing_err.InvalidDataFrameMissingColumns( 30 | missing_cols=missing_cols 31 | ) 32 | ] 33 | return [] 34 | 35 | 36 | def _check_dataframe_for_duplicate_columns( 37 | df: pd.DataFrame, 38 | ) -> List[tracing_err.InvalidDataFrameDuplicateColumns]: 39 | # Get the duplicated column names from the dataframe 40 | duplicate_columns = df.columns[df.columns.duplicated()] 41 | if not duplicate_columns.empty: 42 | return [tracing_err.InvalidDataFrameDuplicateColumns(duplicate_columns)] 43 | return [] 44 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/evals/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/evals/dataframe_form_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | import pandas as pd 5 | 6 | from arize.pandas.tracing.columns import ( 7 | EVAL_COLUMN_PATTERN, 8 | EVAL_EXPLANATION_PATTERN, 9 | EVAL_LABEL_PATTERN, 10 | EVAL_SCORE_PATTERN, 11 | SPAN_SPAN_ID_COL, 12 | ) 13 | from arize.pandas.tracing.utils import isMissingValue 14 | from arize.pandas.tracing.validation.common import errors as tracing_err 15 | from arize.utils.logging import log_a_list, logger 16 | 17 | 18 | def _log_info_dataframe_extra_column_names( 19 | df: pd.DataFrame, 20 | ) -> None: 21 | if df is None: 22 | return None 23 | irrelevant_columns = [ 24 | col 25 | for col in df.columns 26 | if not ( 27 | pd.Series(col).str.match(EVAL_COLUMN_PATTERN).any() 28 | or col == SPAN_SPAN_ID_COL.name 29 | ) 30 | ] 31 | if irrelevant_columns: 32 | logger.info( 33 | "The following columns do not follow the evaluation column naming convention " 34 | f"and will be ignored: {log_a_list(list_of_str=irrelevant_columns, join_word='and')}. 
" 35 | "Evaluation columns must be named as follows: " 36 | "- eval.<eval_name>.label " 37 | "- eval.<eval_name>.score " 38 | "- eval.<eval_name>.explanation" 39 | ) 40 | return None 41 | 42 | 43 | def _check_dataframe_column_content_type( 44 | df: pd.DataFrame, 45 | ) -> List[tracing_err.InvalidDataFrameColumnContentTypes]: 46 | wrong_labels_cols = [] 47 | wrong_scores_cols = [] 48 | wrong_explanations_cols = [] 49 | errors = [] 50 | eval_label_re = re.compile(EVAL_LABEL_PATTERN) 51 | eval_score_re = re.compile(EVAL_SCORE_PATTERN) 52 | eval_explanation_re = re.compile(EVAL_EXPLANATION_PATTERN) 53 | for column in df.columns: 54 | if column == SPAN_SPAN_ID_COL.name and not all( 55 | isinstance(value, str) for value in df[column] 56 | ): 57 | errors.append( 58 | tracing_err.InvalidDataFrameColumnContentTypes( 59 | invalid_type_cols=[SPAN_SPAN_ID_COL.name], 60 | expected_type="string", 61 | ), 62 | ) 63 | if eval_label_re.match(column): 64 | if not all( 65 | isinstance(value, str) or isMissingValue(value) 66 | for value in df[column] 67 | ): 68 | wrong_labels_cols.append(column) 69 | elif eval_score_re.match(column): 70 | if not all( 71 | isinstance(value, (int, float)) or isMissingValue(value) 72 | for value in df[column] 73 | ): 74 | wrong_scores_cols.append(column) 75 | elif eval_explanation_re.match(column) and not all( 76 | isinstance(value, str) or isMissingValue(value) 77 | for value in df[column] 78 | ): 79 | wrong_explanations_cols.append(column) 80 | 81 | if wrong_labels_cols: 82 | errors.append( 83 | tracing_err.InvalidDataFrameColumnContentTypes( 84 | invalid_type_cols=wrong_labels_cols, 85 | expected_type="strings", 86 | ), 87 | ) 88 | if wrong_scores_cols: 89 | errors.append( 90 | tracing_err.InvalidDataFrameColumnContentTypes( 91 | invalid_type_cols=wrong_scores_cols, 92 | expected_type="ints or floats", 93 | ), 94 | ) 95 | if wrong_explanations_cols: 96 | errors.append( 97 | tracing_err.InvalidDataFrameColumnContentTypes( 98 | invalid_type_cols=wrong_explanations_cols, 99 | expected_type="strings", 100 | ), 101 | ) 102 | return errors 103 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/evals/evals_validation.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List, Optional 3 | 4 | import pandas as pd 5 | 6 | from arize.pandas.tracing.columns import SPAN_SPAN_ID_COL 7 | from arize.pandas.tracing.validation.common import ( 8 | argument_validation as common_arg_validation, 9 | ) 10 | from arize.pandas.tracing.validation.common import ( 11 | dataframe_form_validation as common_df_validation, 12 | ) 13 | from arize.pandas.tracing.validation.common import ( 14 | value_validation as common_value_validation, 15 | ) 16 | from arize.pandas.tracing.validation.evals import ( 17 | dataframe_form_validation as df_validation, 18 | ) 19 | from arize.pandas.tracing.validation.evals import value_validation 20 | from arize.pandas.validation import errors as err 21 | 22 | 23 | def validate_argument_types( 24 | evals_dataframe: pd.DataFrame, 25 | project_name: str, 26 | model_version: Optional[str] = None, 27 | ) -> List[err.ValidationError]: 28 | checks = chain( 29 | common_arg_validation._check_field_convertible_to_str( 30 | project_name, model_version 31 | ), 32 | common_arg_validation._check_dataframe_type(evals_dataframe), 33 | ) 34 | return list(checks) 35 | 36 | 37 | def validate_dataframe_form( 38 | evals_dataframe: pd.DataFrame, 39 | ) -> List[err.ValidationError]:
    """Validates the form/structure of the evals dataframe."""
40 | df_validation._log_info_dataframe_extra_column_names(evals_dataframe) 41 | checks = chain( 42 | # Common 43 | common_df_validation._check_dataframe_index(evals_dataframe), 44 | common_df_validation._check_dataframe_required_column_set( 45 | evals_dataframe, required_columns=[SPAN_SPAN_ID_COL.name] 46 | ), 47 | common_df_validation._check_dataframe_for_duplicate_columns( 48 | evals_dataframe 49 | ), 50 | # Eval specific 51 | df_validation._check_dataframe_column_content_type(evals_dataframe), 52 | ) 53 | return list(checks) 54 | 55 | 56 | def validate_values( 57 | evals_dataframe: pd.DataFrame, 58 | project_name: str, 59 | model_version: Optional[str] = None, 60 | ) -> List[err.ValidationError]: 61 | checks = chain( 62 | # Common 63 | common_value_validation._check_invalid_project_name(project_name), 64 | common_value_validation._check_invalid_model_version(model_version), 65 | # Eval specific 66 | value_validation._check_eval_cols(evals_dataframe), 67 | value_validation._check_eval_columns_null_values(evals_dataframe), 68 | ) 69 | return list(checks) 70 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/evals/value_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import chain 3 | from typing import List 4 | 5 | import pandas as pd 6 | 7 | import arize.pandas.tracing.constants as tracing_constants 8 | from arize.pandas.tracing.columns import ( 9 | EVAL_COLUMN_PREFIX, 10 | EVAL_EXPLANATION_SUFFIX, 11 | EVAL_LABEL_SUFFIX, 12 | EVAL_NAME_PATTERN, 13 | EVAL_SCORE_SUFFIX, 14 | ) 15 | from arize.pandas.tracing.validation.common import errors as tracing_err 16 | from arize.pandas.tracing.validation.common import value_validation 17 | from arize.pandas.validation import errors as err 18 | 19 | 20 | def _check_eval_cols( 21 | dataframe: pd.DataFrame, 22 | ) -> List[err.ValidationError]: 23 | checks = [] 24 | for col in dataframe.columns: 25 | if col.endswith(EVAL_LABEL_SUFFIX): 26 | checks.append( 27 | value_validation._check_string_column_value_length( 28 | df=dataframe, 29 | col_name=col, 30 | min_len=tracing_constants.EVAL_LABEL_MIN_STR_LENGTH, 31 | max_len=tracing_constants.EVAL_LABEL_MAX_STR_LENGTH, 32 | is_required=False, 33 | ) 34 | ) 35 | elif col.endswith(EVAL_SCORE_SUFFIX): 36 | checks.append( 37 | value_validation._check_float_column_valid_numbers( 38 | df=dataframe, 39 | col_name=col, 40 | ) 41 | ) 42 | elif col.endswith(EVAL_EXPLANATION_SUFFIX): 43 | checks.append( 44 | value_validation._check_string_column_value_length( 45 | df=dataframe, 46 | col_name=col, 47 | min_len=0, 48 | max_len=tracing_constants.EVAL_EXPLANATION_MAX_STR_LENGTH, 49 | is_required=False, 50 | ) 51 | ) 52 | return list(chain(*checks)) 53 | 54 | 55 | # Evals are valid if they are entirely null (no label, score, or explanation) since this 56 | # represents a span without an eval. 
Evals are also valid if at least one of label or score 57 | # is not null 58 | def _check_eval_columns_null_values( 59 | dataframe: pd.DataFrame, 60 | ) -> List[err.ValidationError]: 61 | invalid_eval_names = [] 62 | eval_names = set() 63 | for col in dataframe.columns: 64 | match = re.match(EVAL_NAME_PATTERN, col) 65 | if match: 66 | eval_names.add(match.group(1)) 67 | 68 | for eval_name in eval_names: 69 | label_col = f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_LABEL_SUFFIX}" 70 | score_col = f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_SCORE_SUFFIX}" 71 | explanation_col = ( 72 | f"{EVAL_COLUMN_PREFIX}{eval_name}{EVAL_EXPLANATION_SUFFIX}" 73 | ) 74 | columns_to_check = [] 75 | 76 | if label_col in dataframe.columns: 77 | columns_to_check.append(label_col) 78 | if score_col in dataframe.columns: 79 | columns_to_check.append(score_col) 80 | 81 | # If there are explanations, they cannot be orphan () 82 | if explanation_col in dataframe.columns: 83 | condition = ( 84 | dataframe[columns_to_check].isnull().all(axis=1) 85 | & ~dataframe[explanation_col].isnull() 86 | ) 87 | if condition.any(): 88 | invalid_eval_names.append(eval_name) 89 | 90 | if invalid_eval_names: 91 | return [ 92 | tracing_err.InvalidNullEvalLabelAndScore( 93 | eval_names=invalid_eval_names 94 | ) 95 | ] 96 | return [] 97 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/metadata/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/metadata/argument_validation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | 5 | from ....validation.errors import ValidationError 6 | 7 | 8 | class MetadataArgumentError(ValidationError): 9 | def __init__(self, message: str, resolution: str) -> None: 10 | self.message = message 11 | self.resolution = resolution 12 | 13 | def __repr__(self) -> str: 14 | return "Metadata_Argument_Error" 15 | 16 | def error_message(self) -> str: 17 | return f"{self.message} {self.resolution}" 18 | 19 | 20 | def validate_argument_types( 21 | metadata_dataframe, project_name 22 | ) -> List[ValidationError]: 23 | """ 24 | Validates the types of arguments passed to update_spans_metadata. 
25 | 26 | Args: 27 | metadata_dataframe: DataFrame with span IDs and patch documents 28 | project_name: Name of the project 29 | 30 | Returns: 31 | A list of validation errors, empty if none found 32 | """ 33 | errors = [] 34 | 35 | # Check metadata_dataframe type 36 | if not isinstance(metadata_dataframe, pd.DataFrame): 37 | errors.append( 38 | MetadataArgumentError( 39 | "metadata_dataframe must be a pandas DataFrame", 40 | "The metadata_dataframe argument must be a pandas DataFrame.", 41 | ) 42 | ) 43 | 44 | # Check project_name 45 | if not isinstance(project_name, str) or not project_name.strip(): 46 | errors.append( 47 | MetadataArgumentError( 48 | "project_name must be a non-empty string", 49 | "The project_name argument must be a non-empty string.", 50 | ) 51 | ) 52 | 53 | return errors 54 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/metadata/dataframe_form_validation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ....tracing.columns import SPAN_SPAN_ID_COL 4 | from ....validation.errors import ValidationError 5 | 6 | 7 | class MetadataFormError(ValidationError): 8 | def __init__(self, message: str, resolution: str) -> None: 9 | self.message = message 10 | self.resolution = resolution 11 | 12 | def __repr__(self) -> str: 13 | return "Metadata_Form_Error" 14 | 15 | def error_message(self) -> str: 16 | return f"{self.message} {self.resolution}" 17 | 18 | 19 | def validate_dataframe_form( 20 | metadata_dataframe, patch_document_column_name="patch_document" 21 | ) -> List[ValidationError]: 22 | """ 23 | Validates the structure of the metadata update dataframe. 24 | 25 | Args: 26 | metadata_dataframe: DataFrame with span IDs and patch documents or attributes.metadata.* columns 27 | patch_document_column_name: Name of the column containing patch documents 28 | 29 | Returns: 30 | A list of validation errors, empty if none found 31 | """ 32 | errors = [] 33 | 34 | # Check for empty dataframe 35 | if metadata_dataframe.empty: 36 | errors.append( 37 | MetadataFormError( 38 | "metadata_dataframe is empty", 39 | "The metadata_dataframe is empty. No data to send.", 40 | ) 41 | ) 42 | return errors 43 | 44 | # Check for required span_id column 45 | if SPAN_SPAN_ID_COL.name not in metadata_dataframe.columns: 46 | errors.append( 47 | MetadataFormError( 48 | f"Missing required column: {SPAN_SPAN_ID_COL.name}", 49 | f"The metadata_dataframe must contain the span ID column: {SPAN_SPAN_ID_COL.name}.", 50 | ) 51 | ) 52 | return errors 53 | 54 | # Check for metadata columns - either patch_document or attributes.metadata.* columns 55 | has_patch_document = ( 56 | patch_document_column_name in metadata_dataframe.columns 57 | ) 58 | metadata_prefix = "attributes.metadata." 
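    # e.g. a column named "attributes.metadata.topic" (illustrative name) supplies the new value for the "topic" metadata key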
59 | metadata_columns = [ 60 | col 61 | for col in metadata_dataframe.columns 62 | if col.startswith(metadata_prefix) 63 | ] 64 | has_metadata_fields = len(metadata_columns) > 0 65 | 66 | if not has_patch_document and not has_metadata_fields: 67 | errors.append( 68 | MetadataFormError( 69 | "Missing metadata columns", 70 | f"The metadata_dataframe must contain either the patch document column " 71 | f"'{patch_document_column_name}' or at least one column with the prefix " 72 | f"'{metadata_prefix}'.", 73 | ) 74 | ) 75 | return errors 76 | 77 | # Check for null values in required columns 78 | null_columns = [] 79 | 80 | # Span ID cannot be null 81 | if metadata_dataframe[SPAN_SPAN_ID_COL.name].isna().any(): 82 | null_columns.append(SPAN_SPAN_ID_COL.name) 83 | 84 | # If using patch_document, it cannot be null 85 | if ( 86 | has_patch_document 87 | and metadata_dataframe[patch_document_column_name].isna().any() 88 | ): 89 | null_columns.append(patch_document_column_name) 90 | 91 | # If using metadata fields, check each one 92 | if has_metadata_fields: 93 | for col in metadata_columns: 94 | if ( 95 | metadata_dataframe[col].isna().all() 96 | ): # All values in column are null 97 | null_columns.append(col) 98 | 99 | if null_columns: 100 | errors.append( 101 | MetadataFormError( 102 | f"Columns with null values: {', '.join(null_columns)}", 103 | f"The following columns cannot contain null values: {', '.join(null_columns)}.", 104 | ) 105 | ) 106 | 107 | return errors 108 | -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/spans/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/tracing/validation/spans/__init__.py -------------------------------------------------------------------------------- /src/arize/pandas/tracing/validation/spans/spans_validation.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List, Optional 3 | 4 | import pandas as pd 5 | 6 | from arize.pandas.tracing.columns import SPAN_OPENINFERENCE_REQUIRED_COLUMNS 7 | from arize.pandas.tracing.validation.common import ( 8 | argument_validation as common_arg_validation, 9 | ) 10 | from arize.pandas.tracing.validation.common import ( 11 | dataframe_form_validation as common_df_validation, 12 | ) 13 | from arize.pandas.tracing.validation.common import ( 14 | value_validation as common_value_validation, 15 | ) 16 | from arize.pandas.tracing.validation.spans import ( 17 | dataframe_form_validation as df_validation, 18 | ) 19 | from arize.pandas.tracing.validation.spans import value_validation 20 | from arize.pandas.validation import errors as err 21 | 22 | 23 | def validate_argument_types( 24 | spans_dataframe: pd.DataFrame, 25 | project_name: str, 26 | dt_fmt: str, 27 | model_version: Optional[str] = None, 28 | ) -> List[err.ValidationError]: 29 | checks = chain( 30 | common_arg_validation._check_field_convertible_to_str( 31 | project_name, model_version 32 | ), 33 | common_arg_validation._check_dataframe_type(spans_dataframe), 34 | common_arg_validation._check_datetime_format_type(dt_fmt), 35 | ) 36 | return list(checks) 37 | 38 | 39 | def validate_dataframe_form( 40 | spans_dataframe: pd.DataFrame, 41 | ) -> List[err.ValidationError]: 42 | df_validation._log_info_dataframe_extra_column_names(spans_dataframe) 43 | checks = 
chain( 44 | # Common 45 | common_df_validation._check_dataframe_index(spans_dataframe), 46 | common_df_validation._check_dataframe_required_column_set( 47 | spans_dataframe, 48 | required_columns=[ 49 | col.name for col in SPAN_OPENINFERENCE_REQUIRED_COLUMNS 50 | ], 51 | ), 52 | common_df_validation._check_dataframe_for_duplicate_columns( 53 | spans_dataframe 54 | ), 55 | # Spans specific 56 | df_validation._check_dataframe_column_content_type(spans_dataframe), 57 | ) 58 | return list(checks) 59 | 60 | 61 | def validate_values( 62 | spans_dataframe: pd.DataFrame, 63 | project_name: str, 64 | model_version: Optional[str] = None, 65 | ) -> List[err.ValidationError]: 66 | checks = chain( 67 | # Common 68 | common_value_validation._check_invalid_project_name(project_name), 69 | common_value_validation._check_invalid_model_version(model_version), 70 | # Spans specific 71 | value_validation._check_span_root_field_values(spans_dataframe), 72 | value_validation._check_span_attributes_values(spans_dataframe), 73 | ) 74 | return list(checks) 75 | -------------------------------------------------------------------------------- /src/arize/pandas/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/pandas/validation/__init__.py -------------------------------------------------------------------------------- /src/arize/single_log/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/single_log/__init__.py -------------------------------------------------------------------------------- /src/arize/single_log/casting.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Union 3 | 4 | from arize.utils.types import ArizeTypes, TypedValue 5 | 6 | from .errors import CastingError 7 | 8 | 9 | def cast_dictionary(d: dict) -> Union[dict, None]: 10 | if not d: 11 | return None 12 | cast_dict = {} 13 | for k, v in d.items(): 14 | if isinstance(v, TypedValue): 15 | v = cast_value(v) 16 | cast_dict[k] = v 17 | return cast_dict 18 | 19 | 20 | def cast_value( 21 | typed_value: TypedValue, 22 | ) -> Union[str, int, float, List[str], None]: 23 | """ 24 | Casts a TypedValue to its provided type, preserving all null values as None or float('nan'). 25 | 26 | Arguments: 27 | --------- 28 | typed_value: TypedValue 29 | The TypedValue to cast. 30 | 31 | Returns: 32 | ------- 33 | Union[str, int, float, List[str], None] 34 | The cast value. 35 | 36 | Raises: 37 | ------ 38 | CastingError 39 | If the value cannot be cast to the provided type. 
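    Example:
    -------
    Illustrative only; assumes TypedValue is constructed with keyword arguments.
    >>> cast_value(TypedValue(value="7", type=ArizeTypes.INT))
    7
    >>> cast_value(TypedValue(value=7.0, type=ArizeTypes.INT))
    7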
40 | 41 | """ 42 | if typed_value.value is None: 43 | return None 44 | 45 | if typed_value.type == ArizeTypes.FLOAT: 46 | return _cast_to_float(typed_value) 47 | elif typed_value.type == ArizeTypes.INT: 48 | return _cast_to_int(typed_value) 49 | elif typed_value.type == ArizeTypes.STR: 50 | return _cast_to_str(typed_value) 51 | else: 52 | raise CastingError("Unknown casting type", typed_value) 53 | 54 | 55 | def _cast_to_float(typed_value: TypedValue) -> Union[float, None]: 56 | try: 57 | return float(typed_value.value) 58 | except Exception as e: 59 | raise CastingError(str(e), typed_value) from e 60 | 61 | 62 | def _cast_to_int(typed_value: TypedValue) -> Union[int, None]: 63 | # a NaN float can't be cast to an int. Proactively return None instead. 64 | if isinstance(typed_value.value, float) and math.isnan(typed_value.value): 65 | return None 66 | # If the value is a float, to avoid losing data precision, 67 | # we can only cast to an int if it is equivalent to an integer (e.g. 7.0). 68 | if ( 69 | isinstance(typed_value.value, float) 70 | and not typed_value.value.is_integer() 71 | ): 72 | raise CastingError( 73 | "Cannot convert float with non-zero fractional part to int", 74 | typed_value, 75 | ) 76 | try: 77 | return int(typed_value.value) 78 | except Exception as e: 79 | raise CastingError(str(e), typed_value) from e 80 | 81 | 82 | def _cast_to_str(typed_value: TypedValue) -> Union[str, None]: 83 | # a NaN float can't be cast to a string. Proactively return None instead. 84 | if isinstance(typed_value.value, float) and math.isnan(typed_value.value): 85 | return None 86 | try: 87 | return str(typed_value.value) 88 | except Exception as e: 89 | raise CastingError(str(e), typed_value) from e 90 | -------------------------------------------------------------------------------- /src/arize/single_log/errors.py: -------------------------------------------------------------------------------- 1 | from arize.utils.types import TypedValue 2 | 3 | 4 | class CastingError(Exception): 5 | def __str__(self) -> str: 6 | return self.error_message() 7 | 8 | def __init__(self, error_msg: str, typed_value: TypedValue) -> None: 9 | self.error_msg = error_msg 10 | self.typed_value = typed_value 11 | 12 | def error_message(self) -> str: 13 | return ( 14 | f"Failed to cast value {self.typed_value.value} of type {type(self.typed_value.value)} " 15 | f"to type {self.typed_value.type}. " 16 | f"Error: {self.error_msg}." 
17 | ) 18 | -------------------------------------------------------------------------------- /src/arize/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/src/arize/utils/__init__.py -------------------------------------------------------------------------------- /src/arize/utils/constants.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | MAX_BYTES_PER_BULK_RECORD = 100000 5 | MAX_DAYS_WITHIN_RANGE = 365 6 | MIN_PREDICTION_ID_LEN = 1 7 | MAX_PREDICTION_ID_LEN = 512 8 | MIN_DOCUMENT_ID_LEN = 1 9 | MAX_DOCUMENT_ID_LEN = 128 10 | # The maximum number of character for tag values 11 | MAX_TAG_LENGTH = 20_000 12 | MAX_TAG_LENGTH_TRUNCATION = 1_000 13 | # The maximum number of character for embedding raw data 14 | MAX_RAW_DATA_CHARACTERS = 2_000_000 15 | MAX_RAW_DATA_CHARACTERS_TRUNCATION = 5_000 16 | # The maximum number of acceptable years in the past from current time for prediction_timestamps 17 | MAX_PAST_YEARS_FROM_CURRENT_TIME = 5 18 | # The maximum number of acceptable years in the future from current time for prediction_timestamps 19 | MAX_FUTURE_YEARS_FROM_CURRENT_TIME = 1 20 | # The maximum number of character for llm model name 21 | MAX_LLM_MODEL_NAME_LENGTH = 20_000 22 | MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION = 50 23 | # The maximum number of character for prompt template 24 | MAX_PROMPT_TEMPLATE_LENGTH = 50_000 25 | MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION = 5_000 26 | # The maximum number of character for prompt template version 27 | MAX_PROMPT_TEMPLATE_VERSION_LENGTH = 20_000 28 | MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION = 50 29 | # The maximum number of embeddings 30 | MAX_NUMBER_OF_EMBEDDINGS = 30 31 | MAX_EMBEDDING_DIMENSIONALITY = 20_000 32 | # The maximum number of classes for multi class 33 | MAX_NUMBER_OF_MULTI_CLASS_CLASSES = 300 34 | MAX_MULTI_CLASS_NAME_LENGTH = 100 35 | # The maximum number of references in embedding similarity search params 36 | MAX_NUMBER_OF_SIMILARITY_REFERENCES = 10 37 | 38 | # Arize generated columns 39 | GENERATED_PREDICTION_LABEL_COL = "arize_generated_prediction_label" 40 | GENERATED_LLM_PARAMS_JSON_COL = "arize_generated_llm_params_json" 41 | 42 | # reserved columns for LLM run metadata 43 | LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME = "total_token_count" 44 | LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME = "prompt_token_count" 45 | LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME = "response_token_count" 46 | LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME = "response_latency_ms" 47 | 48 | # all reserved tags 49 | RESERVED_TAG_COLS = [ 50 | LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME, 51 | LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME, 52 | LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME, 53 | LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME, 54 | ] 55 | 56 | 57 | # Authentication via environment variables 58 | SPACE_KEY_ENVVAR_NAME = "ARIZE_SPACE_KEY" 59 | API_KEY_ENVVAR_NAME = "ARIZE_API_KEY" 60 | DEVELOPER_KEY_ENVVAR_NAME = "ARIZE_DEVELOPER_KEY" 61 | SPACE_ID_ENVVAR_NAME = "ARIZE_SPACE_ID" 62 | 63 | # Default public Flight endpoint when not provided through env variable nor profile 64 | DEFAULT_ARIZE_FLIGHT_HOST = "flight.arize.com" 65 | DEFAULT_ARIZE_FLIGHT_PORT = 443 66 | 67 | path = Path(__file__).with_name("model_mapping.json") 68 | with path.open("r") as f: 69 | MODEL_MAPPING_CONFIG = json.load(f) 70 | 
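A minimal usage sketch for the limits above, paired with the truncation warning helper from arize/utils/logging.py below; the oversized tag value and variable names here are illustrative:

    from arize.utils.constants import MAX_TAG_LENGTH
    from arize.utils.logging import get_truncation_warning_message, logger

    tag_value = "x" * (MAX_TAG_LENGTH + 1)  # illustrative value over the 20_000-char limit
    if len(tag_value) > MAX_TAG_LENGTH:
        # The platform truncates oversized tag values on ingestion; warn up front
        logger.warning(get_truncation_warning_message("Tag values", MAX_TAG_LENGTH))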
-------------------------------------------------------------------------------- /src/arize/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from typing import Any, List 4 | 5 | 6 | class CustomLogFormatter(logging.Formatter): 7 | grey = "\x1b[38;21m" 8 | blue = "\x1b[38;5;39m" 9 | yellow = "\x1b[33m" 10 | red = "\x1b[38;5;196m" 11 | bold_red = "\x1b[31;1m" 12 | reset = "\x1b[0m" 13 | 14 | def __init__(self, fmt): 15 | super().__init__() 16 | self.fmt = fmt 17 | self.FORMATS = { 18 | logging.DEBUG: self.blue + self.fmt + self.reset, 19 | logging.INFO: self.grey + self.fmt + self.reset, 20 | logging.WARNING: self.yellow + self.fmt + self.reset, 21 | logging.ERROR: self.red + self.fmt + self.reset, 22 | logging.CRITICAL: self.bold_red + self.fmt + self.reset, 23 | } 24 | 25 | def format(self, record): 26 | log_fmt = self.FORMATS.get(record.levelno) 27 | formatter = logging.Formatter(log_fmt) 28 | return formatter.format(record) 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | logger.propagate = False 33 | if logger.hasHandlers(): 34 | logger.handlers.clear() 35 | logger.setLevel(logging.INFO) 36 | fmt = " %(name)s | %(levelname)s | %(message)s" 37 | if hasattr(sys, "ps1"): # for python interactive mode 38 | handler = logging.StreamHandler(sys.stdout) 39 | handler.setLevel(logging.INFO) 40 | handler.setFormatter(CustomLogFormatter(fmt)) 41 | logger.addHandler(handler) 42 | 43 | 44 | def get_truncation_warning_message(instance, limit) -> str: 45 | return ( 46 | f"Attention: {instance} exceeding the {limit} character limit will be " 47 | "automatically truncated upon ingestion into the Arize platform. Should you require " 48 | "a higher limit, please reach out to our support team at support@arize.com" 49 | ) 50 | 51 | 52 | def log_a_list(list_of_str: List[Any], join_word: str) -> str: 53 | if list_of_str is None or len(list_of_str) == 0: 54 | return "" 55 | if len(list_of_str) == 1: 56 | return list_of_str[0] 57 | return ( 58 | f"{', '.join(map(str, list_of_str[:-1]))} {join_word} {list_of_str[-1]}" 59 | ) 60 | -------------------------------------------------------------------------------- /src/arize/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "7.43.1" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/__init__.py -------------------------------------------------------------------------------- /tests/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/__init__.py -------------------------------------------------------------------------------- /tests/experimental/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/__init__.py -------------------------------------------------------------------------------- /tests/experimental/datasets/experiments/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/experiments/__init__.py -------------------------------------------------------------------------------- /tests/experimental/datasets/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/experimental/datasets/validation/__init__.py -------------------------------------------------------------------------------- /tests/experimental/datasets/validation/test_validator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | if sys.version_info < (3, 8): 6 | pytest.skip("Requires Python 3.8 or higher", allow_module_level=True) 7 | 8 | import pandas as pd 9 | 10 | from arize.experimental.datasets.core.client import ( 11 | ArizeDatasetsClient, 12 | _convert_default_columns_to_json_str, 13 | ) 14 | from arize.experimental.datasets.validation.errors import ( 15 | IDColumnUniqueConstraintError, 16 | RequiredColumnsError, 17 | ) 18 | from arize.experimental.datasets.validation.validator import Validator 19 | 20 | 21 | def test_happy_path(): 22 | df = pd.DataFrame( 23 | { 24 | "user_data": [1, 2, 3], 25 | } 26 | ) 27 | 28 | df_new = ArizeDatasetsClient._set_default_columns_for_dataset(df) 29 | differences = set(df_new.columns) ^ { 30 | "id", 31 | "created_at", 32 | "updated_at", 33 | "user_data", 34 | } 35 | assert not differences 36 | 37 | validation_errors = Validator.validate(df) 38 | assert len(validation_errors) == 0 39 | 40 | 41 | def test_missing_columns(): 42 | df = pd.DataFrame( 43 | { 44 | "user_data": [1, 2, 3], 45 | } 46 | ) 47 | 48 | validation_errors = Validator.validate(df) 49 | assert len(validation_errors) == 1 50 | assert type(validation_errors[0]) is RequiredColumnsError 51 | 52 | 53 | def test_non_unique_id_column(): 54 | df = pd.DataFrame( 55 | { 56 | "id": [1, 1, 2], 57 | "user_data": [1, 2, 3], 58 | } 59 | ) 60 | df_new = ArizeDatasetsClient._set_default_columns_for_dataset(df) 61 | 62 | validation_errors = Validator.validate(df_new) 63 | assert len(validation_errors) == 1 64 | assert type(validation_errors[0]) is IDColumnUniqueConstraintError 65 | 66 | 67 | @pytest.mark.skipif( 68 | sys.version_info < (3, 8), reason="Requires Python 3.8 or higher" 69 | ) 70 | def test_dict_to_json_conversion() -> None: 71 | df = pd.DataFrame( 72 | { 73 | "id": [1, 2, 3], 74 | "eval.MyEvaluator.metadata": [ 75 | {"key": "value"}, 76 | {"key": "value"}, 77 | {"key": "value"}, 78 | ], 79 | "not_converted_dict_col": [ 80 | {"key": "value"}, 81 | {"key": "value"}, 82 | {"key": "value"}, 83 | ], 84 | } 85 | ) 86 | # before conversion, the column with the evaluator name is a dict 87 | assert type(df["eval.MyEvaluator.metadata"][0]) is dict 88 | assert type(df["not_converted_dict_col"][0]) is dict 89 | 90 | # Check that only the column with the evaluator name is converted to JSON 91 | converted_df = _convert_default_columns_to_json_str(df) 92 | assert type(converted_df["eval.MyEvaluator.metadata"][0]) is str 93 | assert type(converted_df["not_converted_dict_col"][0]) is dict 94 | -------------------------------------------------------------------------------- /tests/exporter/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/__init__.py -------------------------------------------------------------------------------- /tests/exporter/test_exporter.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/test_exporter.py -------------------------------------------------------------------------------- /tests/exporter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/exporter/utils/__init__.py -------------------------------------------------------------------------------- /tests/exporter/validations/test_validator_invalid_types.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | 4 | from arize.exporter.utils.validation import Validator 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | def test_zero_error(self): 9 | try: 10 | for key, val in valid_inputs.items(): 11 | Validator.validate_input_type(val[0], key, val[1]) 12 | except TypeError: 13 | self.fail("validate_input_type raised TypeError unexpectedly") 14 | 15 | def test_invalid_space_id(self): 16 | space_id = invalid_inputs["space_id"][0] 17 | valid_type = invalid_inputs["space_id"][1] 18 | with self.assertRaisesRegex( 19 | TypeError, 20 | f"space_id {space_id} is type {type(space_id)}, but must be a str", 21 | ): 22 | Validator.validate_input_type(space_id, "space_id", valid_type) 23 | 24 | def test_invalid_model_name(self): 25 | model_name = invalid_inputs["model_name"][0] 26 | valid_type = invalid_inputs["model_name"][1] 27 | with self.assertRaisesRegex( 28 | TypeError, 29 | f"model_name {model_name} is type {type(model_name)}, but must be a str", 30 | ): 31 | Validator.validate_input_type(model_name, "model_name", valid_type) 32 | 33 | def test_invalid_data_type(self): 34 | data_type = invalid_inputs["data_type"][0] 35 | valid_type = invalid_inputs["data_type"][1] 36 | with self.assertRaisesRegex( 37 | TypeError, 38 | f"data_type {data_type} is type {type(data_type)}, but must be a str", 39 | ): 40 | Validator.validate_input_type(data_type, "data_type", valid_type) 41 | 42 | def test_invalid_start_or_end_time(self): 43 | start_time = invalid_inputs["start_time"][0] 44 | valid_type = invalid_inputs["start_time"][1] 45 | with self.assertRaisesRegex( 46 | TypeError, 47 | f"start_time {start_time} is type {type(start_time)}, but must be a datetime", 48 | ): 49 | Validator.validate_input_type(start_time, "start_time", valid_type) 50 | 51 | def test_invalid_path(self): 52 | path = invalid_inputs["path"][0] 53 | valid_type = invalid_inputs["path"][1] 54 | with self.assertRaisesRegex( 55 | TypeError, 56 | f"path {path} is type {type(path)}, but must be a str", 57 | ): 58 | Validator.validate_input_type(path, "path", valid_type) 59 | 60 | 61 | valid_inputs = { 62 | "space_id": ("abc123", str), 63 | "model_name": ("test_model", str), 64 | "data_type": ("predictions", str), 65 | "start_time": (datetime(2023, 4, 1, 0, 0, 0, 0), datetime), 66 | "end_time": (datetime(2023, 4, 15, 0, 0, 0, 0), datetime), 67 | "path": ("example.parquet", str), 68 | } 69 | 70 | invalid_inputs = { 71 | "space_id": (123, str), 72 | "model_name": (123, str), 73 | "data_type": (0.2, str), 74 | 
"start_time": ("2022-10-10", datetime), 75 | "end_time": ("2022-10-15", datetime), 76 | "path": (0.2, str), 77 | } 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/exporter/validations/test_validator_invalid_values.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | 4 | from arize.exporter.utils.validation import Validator 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | def test_valid_data_type(self): 9 | try: 10 | Validator.validate_input_value( 11 | valid_data_type.upper(), "data_type", data_types 12 | ) 13 | except TypeError: 14 | self.fail("validate_input_type raised TypeError unexpectedly") 15 | 16 | def test_invalid_data_type(self): 17 | with self.assertRaisesRegex( 18 | ValueError, 19 | f"data_type is {invalid_data_type.upper()}, but must be one of PREDICTIONS, " 20 | f"CONCLUSIONS, EXPLANATIONS, PREPRODUCTION", 21 | ): 22 | Validator.validate_input_value( 23 | invalid_data_type.upper(), "data_type", data_types 24 | ) 25 | 26 | def test_invalid_start_end_time(self): 27 | with self.assertRaisesRegex( 28 | ValueError, 29 | "start_time must be before end_time", 30 | ): 31 | Validator.validate_start_end_time(start_time, end_time) 32 | 33 | 34 | valid_data_type = "preproduction" 35 | invalid_data_type = "hello" 36 | data_types = ("PREDICTIONS", "CONCLUSIONS", "EXPLANATIONS", "PREPRODUCTION") 37 | start_time = datetime(2023, 6, 15, 10, 30) 38 | end_time = datetime(2023, 6, 10, 10, 30) 39 | 40 | if __name__ == "__main__": 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arize-ai/client_python/8f72212f42490baa7a90831c591b83ab6173ba83/tests/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/pandas/generative/nlp_metrics/test_nlp_metrics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from arize.pandas.generative.nlp_metrics import ( 5 | bleu, 6 | google_bleu, 7 | meteor, 8 | rouge, 9 | sacre_bleu, 10 | ) 11 | 12 | 13 | def get_text_df() -> pd.DataFrame: 14 | return pd.DataFrame( 15 | { 16 | "response": [ 17 | "The cat is on the mat.", 18 | "The NASA Opportunity rover is battling a massive dust storm on Mars.", 19 | ], 20 | "references": [ 21 | ["The cat is on the blue mat."], 22 | [ 23 | "The Opportunity rover is combating a big sandstorm on Mars.", 24 | "A NASA rover is fighting a massive storm on Mars.", 25 | ], 26 | ], 27 | } 28 | ) 29 | 30 | 31 | def get_results_df() -> pd.DataFrame: 32 | return pd.DataFrame( 33 | { 34 | "bleu": [0.6129752413741056, 0.32774568052975916], 35 | "sacrebleu": [61.29752413741059, 32.774568052975916], 36 | "google_bleu": [0.6538461538461539, 0.3695652173913043], 37 | "rouge1": [0.923076923076923, 0.7272727272727272], 38 | "rouge2": [0.7272727272727272, 0.39999999999999997], 39 | "rougeL": [0.923076923076923, 0.7272727272727272], 40 | "rougeLsum": [0.923076923076923, 0.7272727272727272], 41 | "meteor": [0.8757427021441488, 0.7682980599647267], 42 | } 43 | ) 44 | 45 | 46 | def test_bleu_score() -> None: 47 | texts = get_text_df() 48 | results = get_results_df() 49 | 50 | try: 51 | bleu_scores = bleu( 52 | 
response_col=texts["response"], references_col=texts["references"] 53 | ) 54 | except Exception as e: 55 | raise AssertionError("Unexpected Error") from e 56 | 57 | assert (bleu_scores == results["bleu"]).all(), "BLEU scores should match" # type:ignore 58 | 59 | 60 | def test_sacrebleu_score() -> None: 61 | texts = get_text_df() 62 | results = get_results_df() 63 | 64 | try: 65 | sacrebleu_scores = sacre_bleu( 66 | response_col=texts["response"], references_col=texts["references"] 67 | ) 68 | except Exception as e: 69 | raise AssertionError("Unexpected Error") from e 70 | 71 | assert ( 72 | sacrebleu_scores == results["sacrebleu"] 73 | ).all(), "SacreBLEU scores should match" # type:ignore 74 | 75 | 76 | def test_google_bleu_score() -> None: 77 | texts = get_text_df() 78 | results = get_results_df() 79 | 80 | try: 81 | gbleu_scores = google_bleu( 82 | response_col=texts["response"], references_col=texts["references"] 83 | ) 84 | except Exception as e: 85 | raise AssertionError("Unexpected Error") from e 86 | 87 | assert ( 88 | gbleu_scores == results["google_bleu"] 89 | ).all(), "Google BLEU scores should match" # type:ignore 90 | 91 | 92 | def test_rouge_score() -> None: 93 | texts = get_text_df() 94 | results = get_results_df() 95 | 96 | try: 97 | rouge_scores = rouge( 98 | response_col=texts["response"], references_col=texts["references"] 99 | ) 100 | except Exception as e: 101 | raise AssertionError("Unexpected Error") from e 102 | 103 | # Check that only default rouge scores are returned, and they match 104 | assert isinstance(rouge_scores, dict) 105 | assert len(rouge_scores.keys()) == 1 106 | assert list(rouge_scores.keys())[0] == "rougeL" # type:ignore 107 | assert ( 108 | rouge_scores["rougeL"] == results["rougeL"] 109 | ).all(), "ROUGE scores should match" # type: ignore 110 | 111 | rouge_types = [ 112 | "rouge1", 113 | "rouge2", 114 | "rougeL", 115 | "rougeLsum", 116 | ] 117 | try: 118 | rouge_scores = rouge( 119 | response_col=texts["response"], 120 | references_col=texts["references"], 121 | rouge_types=rouge_types, 122 | ) 123 | except Exception as e: 124 | raise AssertionError("Unexpected Error") from e 125 | 126 | # Check that all rouge scores are returned, and they match 127 | assert isinstance(rouge_scores, dict) 128 | assert list(rouge_scores.keys()) == rouge_types 129 | for rtype in rouge_types: 130 | assert ( 131 | rouge_scores[rtype] == results[rtype] 132 | ).all(), f"ROUGE scores ({rtype}) should match" # type: ignore 133 | 134 | 135 | def test_meteor_score() -> None: 136 | texts = get_text_df() 137 | results = get_results_df() 138 | 139 | try: 140 | meteor_scores = meteor( 141 | response_col=texts["response"], references_col=texts["references"] 142 | ) 143 | except Exception as e: 144 | raise AssertionError("Unexpected Error") from e 145 | 146 | assert ( 147 | meteor_scores == results["meteor"] 148 | ).all(), "METEOR scores should match" # type:ignore 149 | 150 | 151 | if __name__ == "__main__": 152 | raise SystemExit(pytest.main([__file__])) 153 | -------------------------------------------------------------------------------- /tests/pandas/tracing/test_columns.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | import pytest 5 | 6 | if sys.version_info >= (3, 8): 7 | from arize.pandas.tracing.columns import ( 8 | ANNOTATION_COLUMN_PATTERN, 9 | ANNOTATION_NAME_PATTERN, 10 | EVAL_COLUMN_PATTERN, 11 | EVAL_EXPLANATION_PATTERN, 12 | EVAL_LABEL_PATTERN, 13 | EVAL_NAME_PATTERN, 14 | 
EVAL_SCORE_PATTERN, 15 | ) 16 | 17 | 18 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 19 | def test_eval_column_pattern(): 20 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.label") 21 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name with spaces.label") 22 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.score") 23 | assert re.match(EVAL_COLUMN_PATTERN, "eval.name.explanation") 24 | assert not re.match(EVAL_COLUMN_PATTERN, "eval.name") 25 | assert not re.match(EVAL_COLUMN_PATTERN, "name.label") 26 | 27 | 28 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 29 | def test_eval_label_pattern(): 30 | assert re.match(EVAL_LABEL_PATTERN, "eval.name.label") 31 | assert re.match(EVAL_LABEL_PATTERN, "eval.name with spaces.label") 32 | assert not re.match(EVAL_LABEL_PATTERN, "eval.name.score") 33 | 34 | 35 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 36 | def test_eval_score_pattern(): 37 | assert re.match(EVAL_SCORE_PATTERN, "eval.name.score") 38 | assert re.match(EVAL_SCORE_PATTERN, "eval.name with spaces.score") 39 | assert not re.match(EVAL_SCORE_PATTERN, "eval.name.label") 40 | 41 | 42 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 43 | def test_eval_explanation_pattern(): 44 | assert re.match(EVAL_EXPLANATION_PATTERN, "eval.name.explanation") 45 | assert re.match( 46 | EVAL_EXPLANATION_PATTERN, "eval.name with spaces.explanation" 47 | ) 48 | assert not re.match(EVAL_EXPLANATION_PATTERN, "eval_name.label") 49 | 50 | 51 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 52 | def test_eval_name_capture(): 53 | matches = [ 54 | ("eval.name_part.label", "name_part"), 55 | ("eval.name with spaces.label", "name with spaces"), 56 | ] 57 | for test_str, expected_name in matches: 58 | match = re.match(EVAL_NAME_PATTERN, test_str) 59 | assert match is not None, f"Failed to match '{test_str}'" 60 | assert ( 61 | match.group(1) == expected_name 62 | ), f"Incorrect name captured for '{test_str}'" 63 | 64 | non_matches = [ 65 | "evalname.", # Missing dot 66 | "name_part.", # Missing prefix 67 | "eval.name", # Missing suffix 68 | ] 69 | for test_str in non_matches: 70 | assert ( 71 | re.match(EVAL_NAME_PATTERN, test_str) is None 72 | ), f"Incorrectly matched '{test_str}'" 73 | 74 | 75 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 76 | def test_annotation_column_pattern(): 77 | """Test the regex pattern for matching valid annotation column names.""" 78 | valid_patterns = [ 79 | "annotation.quality.label", 80 | "annotation.toxicity_score.score", 81 | "annotation.needs review.updated_by", 82 | "annotation.timestamp_col.updated_at", 83 | "annotation.numeric_name123.label", 84 | ] 85 | invalid_patterns = [ 86 | "annotation.quality", # Missing suffix 87 | "annotations.quality.label", # Wrong prefix 88 | "annotation.quality.Score", # Incorrect suffix case 89 | "annotation..label", # Empty name part 90 | "quality.label", # Missing prefix 91 | "annotation.notes", # Specific reserved name (tested separately if needed) 92 | "eval.quality.label", # Different prefix 93 | ] 94 | 95 | for pattern in valid_patterns: 96 | assert re.match( 97 | ANNOTATION_COLUMN_PATTERN, pattern 98 | ), f"Should match: {pattern}" 99 | for pattern in invalid_patterns: 100 | assert not re.match( 101 | ANNOTATION_COLUMN_PATTERN, pattern 102 | ), f"Should NOT match: {pattern}" 103 | 104 | 105 | @pytest.mark.skipif(sys.version_info < (3, 8), 
reason="Requires python>=3.8") 106 | def test_annotation_name_capture(): 107 | """Test capturing the annotation name from column strings.""" 108 | matches = [ 109 | ("annotation.quality.label", "quality"), 110 | ("annotation.name with spaces.score", "name with spaces"), 111 | ("annotation.name_123.updated_by", "name_123"), 112 | ] 113 | for test_str, expected_name in matches: 114 | match = re.match(ANNOTATION_NAME_PATTERN, test_str) 115 | assert match is not None, f"Failed to match '{test_str}'" 116 | assert ( 117 | match.group(1) == expected_name 118 | ), f"Incorrect name captured for '{test_str}'" 119 | 120 | non_matches = [ 121 | "annotationname.label", # Missing dot after prefix 122 | "annotation..score", # Empty name part 123 | "annotation.name", # No suffix 124 | "quality.label", # Missing prefix 125 | ] 126 | for test_str in non_matches: 127 | assert ( 128 | re.match(ANNOTATION_NAME_PATTERN, test_str) is None 129 | ), f"Incorrectly matched '{test_str}'" 130 | 131 | 132 | if __name__ == "__main__": 133 | raise SystemExit(pytest.main([__file__])) 134 | -------------------------------------------------------------------------------- /tests/pandas/tracing/validation/test_invalid_annotation_arguments.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import uuid 4 | 5 | import pandas as pd 6 | import pytest 7 | 8 | if sys.version_info >= (3, 8): 9 | from arize.pandas.tracing.columns import ( 10 | ANNOTATION_LABEL_SUFFIX, 11 | ANNOTATION_NOTES_COLUMN_NAME, 12 | ANNOTATION_SCORE_SUFFIX, 13 | ANNOTATION_UPDATED_AT_SUFFIX, 14 | ANNOTATION_UPDATED_BY_SUFFIX, 15 | SPAN_SPAN_ID_COL, 16 | ) 17 | from arize.pandas.tracing.validation.annotations import ( 18 | annotations_validation, 19 | ) 20 | 21 | 22 | def get_valid_annotation_df(num_rows=2): 23 | """Helper to create a DataFrame with the correct type for the notes column.""" 24 | span_ids = [str(uuid.uuid4()) for _ in range(num_rows)] 25 | current_time_ms = int(time.time() * 1000) 26 | 27 | # Initialize notes with empty lists for all rows to ensure consistent list type 28 | notes_list = [[] for _ in range(num_rows)] 29 | if num_rows > 0: 30 | notes_list[0] = [ 31 | '{"text": "Note 1"}' 32 | ] # Assign the actual note to the first row 33 | 34 | df = pd.DataFrame( 35 | { 36 | SPAN_SPAN_ID_COL.name: span_ids, 37 | f"annotation.quality{ANNOTATION_LABEL_SUFFIX}": ["good", "bad"][ 38 | :num_rows 39 | ], 40 | f"annotation.quality{ANNOTATION_SCORE_SUFFIX}": [0.9, 0.1][ 41 | :num_rows 42 | ], 43 | f"annotation.quality{ANNOTATION_UPDATED_BY_SUFFIX}": [ 44 | "user1", 45 | "user2", 46 | ][:num_rows], 47 | f"annotation.quality{ANNOTATION_UPDATED_AT_SUFFIX}": [ 48 | current_time_ms - 1000, 49 | current_time_ms, 50 | ][:num_rows], 51 | # Use the pre-initialized list with consistent type 52 | ANNOTATION_NOTES_COLUMN_NAME: pd.Series(notes_list, dtype=object), 53 | } 54 | ) 55 | return df 56 | 57 | 58 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 59 | def test_valid_annotation_column_types(): 60 | """Tests that a DataFrame with correct types passes validation.""" 61 | annotations_dataframe = get_valid_annotation_df() 62 | errors = annotations_validation.validate_dataframe_form( 63 | annotations_dataframe=annotations_dataframe 64 | ) 65 | assert len(errors) == 0, "Expected no validation errors for valid types" 66 | 67 | 68 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 69 | def test_invalid_annotation_label_type(): 70 | """Tests 
error for non-string label column.""" 71 | annotations_dataframe = get_valid_annotation_df() 72 | annotations_dataframe[f"annotation.quality{ANNOTATION_LABEL_SUFFIX}"] = [ 73 | 1, 74 | 2, 75 | ] 76 | errors = annotations_validation.validate_dataframe_form( 77 | annotations_dataframe=annotations_dataframe 78 | ) 79 | assert len(errors) > 0 80 | 81 | 82 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 83 | def test_invalid_annotation_score_type(): 84 | """Tests error for non-numeric score column.""" 85 | annotations_dataframe = get_valid_annotation_df() 86 | annotations_dataframe[f"annotation.quality{ANNOTATION_SCORE_SUFFIX}"] = [ 87 | "high", 88 | "low", 89 | ] 90 | errors = annotations_validation.validate_dataframe_form( 91 | annotations_dataframe=annotations_dataframe 92 | ) 93 | assert len(errors) > 0 94 | 95 | 96 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 97 | def test_invalid_annotation_updated_by_type(): 98 | """Tests error for non-string updated_by column.""" 99 | annotations_dataframe = get_valid_annotation_df() 100 | annotations_dataframe[ 101 | f"annotation.quality{ANNOTATION_UPDATED_BY_SUFFIX}" 102 | ] = [100, 200] 103 | errors = annotations_validation.validate_dataframe_form( 104 | annotations_dataframe=annotations_dataframe 105 | ) 106 | assert len(errors) > 0 107 | 108 | 109 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 110 | def test_invalid_annotation_updated_at_type(): 111 | """Tests error for non-numeric updated_at column.""" 112 | annotations_dataframe = get_valid_annotation_df() 113 | annotations_dataframe[ 114 | f"annotation.quality{ANNOTATION_UPDATED_AT_SUFFIX}" 115 | ] = ["yesterday", "today"] 116 | errors = annotations_validation.validate_dataframe_form( 117 | annotations_dataframe=annotations_dataframe 118 | ) 119 | assert len(errors) > 0 120 | 121 | 122 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 123 | def test_invalid_annotation_notes_type(): 124 | """Tests error for non-list notes column.""" 125 | annotations_dataframe = get_valid_annotation_df() 126 | annotations_dataframe[ANNOTATION_NOTES_COLUMN_NAME] = "just a string" 127 | errors = annotations_validation.validate_dataframe_form( 128 | annotations_dataframe=annotations_dataframe 129 | ) 130 | assert len(errors) > 0 131 | 132 | 133 | if __name__ == "__main__": 134 | raise SystemExit(pytest.main([__file__])) 135 | -------------------------------------------------------------------------------- /tests/pandas/tracing/validation/test_invalid_arguments.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | if sys.version_info >= (3, 8): 7 | from arize.pandas.tracing.validation.evals import evals_validation 8 | 9 | valid_spans_dataframe = pd.DataFrame( 10 | { 11 | "context.span_id": ["span_id_11111111", "span_id_22222222"], 12 | "context.trace_id": ["trace_id_11111111", "trace_id_22222222"], 13 | "name": ["name_1", "name_2"], 14 | "start_time": [ 15 | "2024-01-18T18:28:27.429383+00:00", 16 | "2024-01-18T18:28:27.429383+00:00", 17 | ], 18 | "end_time": [ 19 | "2024-01-18T18:28:27.429383+00:00", 20 | "2024-01-18T18:28:27.429383+00:00", 21 | ], 22 | } 23 | ) 24 | 25 | 26 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 27 | def test_valid_eval_column_types(): 28 | evals_dataframe = pd.DataFrame( 29 | { 30 | "context.span_id": ["span_id_1", "span_id_2"], 31 
| "eval.eval_1.label": ["relevant", "irrelevant"], 32 | "eval.eval_1.score": [1.0, None], 33 | "eval.eval_1.explanation": [ 34 | "explanation for relevant", 35 | "explanation for irrelevant", 36 | ], 37 | } 38 | ) 39 | errors = evals_validation.validate_dataframe_form( 40 | evals_dataframe=evals_dataframe 41 | ) 42 | assert len(errors) == 0, "Expected no validation errors for all columns" 43 | 44 | 45 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 46 | def test_invalid_label_columns_type(): 47 | evals_dataframe = pd.DataFrame( 48 | { 49 | "context.span_id": ["span_id_1", "span_id_2"], 50 | "eval.eval_1.label": [ 51 | 1, 52 | 2, 53 | ], 54 | } 55 | ) 56 | errors = evals_validation.validate_dataframe_form( 57 | evals_dataframe=evals_dataframe 58 | ) 59 | assert ( 60 | len(errors) > 0 61 | ), "Expected validation errors for label columns with incorrect type" 62 | 63 | 64 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 65 | def test_invalid_score_columns_type(): 66 | evals_dataframe = pd.DataFrame( 67 | { 68 | "context.span_id": ["span_id_1", "span_id_2"], 69 | "eval.eval_1.score": [ 70 | "1.0", 71 | "None", 72 | ], 73 | } 74 | ) 75 | errors = evals_validation.validate_dataframe_form( 76 | evals_dataframe=evals_dataframe 77 | ) 78 | assert ( 79 | len(errors) > 0 80 | ), "Expected validation errors for score columns with incorrect type" 81 | 82 | 83 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 84 | def test_invalid_explanation_columns_type(): 85 | evals_dataframe = pd.DataFrame( 86 | { 87 | "context.span_id": ["span_id_1", "span_id_2"], 88 | "eval.eval_1.explanation": [ 89 | 1, 90 | 2, 91 | ], 92 | } 93 | ) 94 | errors = evals_validation.validate_dataframe_form( 95 | evals_dataframe=evals_dataframe 96 | ) 97 | assert ( 98 | len(errors) > 0 99 | ), "Expected validation errors for explanation columns with incorrect type" 100 | 101 | 102 | if __name__ == "__main__": 103 | raise SystemExit(pytest.main([__file__])) 104 | -------------------------------------------------------------------------------- /tests/pandas/validation/test_pandas_validator_error_messages.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import arize.pandas.validation.errors as err 4 | from arize.utils.types import Environments, ModelTypes 5 | 6 | # ---------------- 7 | # Parameter checks 8 | # ---------------- 9 | 10 | 11 | def test_missing_columns(): 12 | err_msg = str(err.MissingColumns(["genotype", "phenotype"])) 13 | assert "genotype" in err_msg 14 | assert "phenotype" in err_msg 15 | 16 | 17 | def test_invalid_model_type(): 18 | err_msg = str(err.InvalidModelType()) 19 | assert all(mt.name in err_msg for mt in ModelTypes) 20 | 21 | 22 | def test_invalid_environment(): 23 | err_msg = str(err.InvalidEnvironment()) 24 | assert all(env.name in err_msg for env in Environments) 25 | 26 | 27 | # ----------- 28 | # Type checks 29 | # ----------- 30 | 31 | 32 | def test_Invalid_type(): 33 | err_msg = str(err.InvalidType("123", ["456", "789"], "112")) 34 | assert "123" in err_msg 35 | assert "456" in err_msg 36 | assert "789" in err_msg 37 | assert "112" in err_msg 38 | 39 | 40 | def test_invalid_type_features(): 41 | err_msg = str( 42 | err.InvalidTypeFeatures( 43 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"] 44 | ) 45 | ) 46 | assert "genotype" in err_msg 47 | assert "phenotype" in err_msg 48 | assert "Triceratops" in err_msg 49 | assert "Archaeopteryx" 
in err_msg 50 | 51 | 52 | def test_invalid_type_tags(): 53 | err_msg = str( 54 | err.InvalidTypeTags( 55 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"] 56 | ) 57 | ) 58 | assert "genotype" in err_msg 59 | assert "phenotype" in err_msg 60 | assert "Triceratops" in err_msg 61 | assert "Archaeopteryx" in err_msg 62 | 63 | 64 | def test_invalid_type_shap_values(): 65 | err_msg = str( 66 | err.InvalidTypeShapValues( 67 | ["genotype", "phenotype"], ["Triceratops", "Archaeopteryx"] 68 | ) 69 | ) 70 | assert "genotype" in err_msg 71 | assert "phenotype" in err_msg 72 | assert "Triceratops" in err_msg 73 | assert "Archaeopteryx" in err_msg 74 | 75 | 76 | # ------------ 77 | # Value checks 78 | # ------------ 79 | 80 | 81 | def test_invalid_timestamp(): 82 | err_msg = str(err.InvalidValueTimestamp("Spinosaurus")) 83 | assert "Spinosaurus" in err_msg 84 | 85 | 86 | def test_invalid_missing_value(): 87 | err_msg = str(err.InvalidValueMissingValue("Stegosaurus", "missing")) 88 | assert "Stegosaurus" in err_msg 89 | assert "missing" in err_msg 90 | 91 | 92 | def test_invalid_infinite_value(): 93 | err_msg = str(err.InvalidValueMissingValue("Stegosaurus", "infinite")) 94 | assert "Stegosaurus" in err_msg 95 | assert "infinite" in err_msg 96 | 97 | 98 | if __name__ == "__main__": 99 | raise SystemExit(pytest.main([__file__])) 100 | -------------------------------------------------------------------------------- /tests/pandas/validation/test_pandas_validator_invalid_shap_suffix.py: -------------------------------------------------------------------------------- 1 | from collections import ChainMap 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | from arize.pandas.logger import Schema 7 | from arize.pandas.validation.errors import InvalidShapSuffix 8 | from arize.pandas.validation.validator import Validator 9 | from arize.utils.types import Environments, ModelTypes 10 | 11 | 12 | def test_invalid_feature_columns(): 13 | errors = Validator.validate_params( 14 | **ChainMap( 15 | { 16 | "dataframe": kwargs["dataframe"].assign(feat_shap=[0]), 17 | "schema": Schema( 18 | prediction_id_column_name="prediction_id", 19 | prediction_label_column_name="prediction_label", 20 | feature_column_names=["feat_shap"], 21 | ), 22 | }, 23 | kwargs, 24 | ), 25 | ) 26 | assert len(errors) == 1 27 | assert type(errors[0]) is InvalidShapSuffix 28 | 29 | 30 | def test_invalid_tag_columns(): 31 | errors = Validator.validate_params( 32 | **ChainMap( 33 | { 34 | "dataframe": kwargs["dataframe"].assign(tag_shap=[0]), 35 | "schema": Schema( 36 | prediction_id_column_name="prediction_id", 37 | prediction_label_column_name="prediction_label", 38 | tag_column_names=["tag_shap"], 39 | ), 40 | }, 41 | kwargs, 42 | ), 43 | ) 44 | assert len(errors) == 1 45 | assert type(errors[0]) is InvalidShapSuffix 46 | 47 | 48 | def test_invalid_shap_columns(): 49 | errors = Validator.validate_params( 50 | **ChainMap( 51 | { 52 | "dataframe": kwargs["dataframe"].assign(shap=[0]), 53 | "schema": Schema( 54 | prediction_id_column_name="prediction_id", 55 | prediction_label_column_name="prediction_label", 56 | shap_values_column_names={"feat_shap": "shap"}, 57 | ), 58 | }, 59 | kwargs, 60 | ), 61 | ) 62 | assert len(errors) == 1 63 | assert type(errors[0]) is InvalidShapSuffix 64 | 65 | 66 | def test_invalid_multiple(): 67 | errors = Validator.validate_params( 68 | **ChainMap( 69 | { 70 | "dataframe": kwargs["dataframe"].assign( 71 | feat_shap=[0], tag_shap=[0], shap=[0] 72 | ), 73 | "schema": Schema( 74 | 
prediction_id_column_name="prediction_id", 75 | prediction_label_column_name="prediction_label", 76 | feature_column_names=["feat_shap"], 77 | tag_column_names=["tag_shap"], 78 | shap_values_column_names={"feat_shap": "shap"}, 79 | ), 80 | }, 81 | kwargs, 82 | ), 83 | ) 84 | assert len(errors) == 1 85 | assert type(errors[0]) is InvalidShapSuffix 86 | 87 | 88 | kwargs = { 89 | "model_id": "fraud", 90 | "model_version": "v1.0", 91 | "model_type": ModelTypes.SCORE_CATEGORICAL, 92 | "environment": Environments.PRODUCTION, 93 | "dataframe": pd.DataFrame( 94 | { 95 | "prediction_id": pd.Series(["0"]), 96 | "prediction_label": pd.Series(["fraud"]), 97 | "prediction_score": pd.Series([1]), 98 | "actual_label": pd.Series(["not fraud"]), 99 | "actual_score": pd.Series([0]), 100 | } 101 | ), 102 | "schema": Schema( 103 | prediction_id_column_name="prediction_id", 104 | prediction_label_column_name="prediction_label", 105 | ), 106 | } 107 | 108 | if __name__ == "__main__": 109 | raise SystemExit(pytest.main([__file__])) 110 | -------------------------------------------------------------------------------- /tests/pandas/validation/test_pandas_validator_ranking_param_checks.py: -------------------------------------------------------------------------------- 1 | from collections import ChainMap 2 | from datetime import datetime, timedelta 3 | 4 | import pandas as pd 5 | import pytest 6 | 7 | import arize.pandas.validation.errors as err 8 | from arize.pandas.logger import Schema 9 | from arize.pandas.validation.validator import Validator 10 | from arize.utils.types import Environments, ModelTypes 11 | 12 | kwargs = { 13 | "model_id": "rank", 14 | "model_type": ModelTypes.RANKING, 15 | "environment": Environments.PRODUCTION, 16 | "dataframe": pd.DataFrame( 17 | { 18 | "prediction_timestamp": pd.Series( 19 | [ 20 | datetime.now(), 21 | datetime.now() + timedelta(days=1), 22 | datetime.now() - timedelta(days=364), 23 | datetime.now() + timedelta(days=364), 24 | ] 25 | ), 26 | "prediction_id": pd.Series(["x_1", "x_2", "y_1", "y_2"]), 27 | "prediction_group_id": pd.Series(["X", "X", "Y", "Y"]), 28 | "item_type": pd.Series(["toy", "game", "game", "pens"]), 29 | "ranking_rank": pd.Series([1, 2, 1, 2]), 30 | "ranking_category": pd.Series( 31 | [ 32 | ["click", "purchase"], 33 | ["click", "favor"], 34 | ["favor"], 35 | ["click"], 36 | ] 37 | ), 38 | "ranking_relevance": pd.Series([1, 0, 2, 0]), 39 | } 40 | ), 41 | "schema": Schema( 42 | prediction_id_column_name="prediction_id", 43 | prediction_group_id_column_name="prediction_group_id", 44 | timestamp_column_name="prediction_timestamp", 45 | feature_column_names=["item_type"], 46 | rank_column_name="ranking_rank", 47 | actual_score_column_name="ranking_relevance", 48 | actual_label_column_name="ranking_category", 49 | ), 50 | } 51 | 52 | 53 | def test_ranking_param_check_happy_path(): 54 | errors = Validator.validate_params(**kwargs) 55 | assert len(errors) == 0 56 | 57 | 58 | def test_ranking_param_check_allow_delayed_actuals(): 59 | # The following Schema has no prediction information 60 | # we should allow delayed actuals 61 | errors = Validator.validate_params( 62 | **ChainMap( 63 | { 64 | "schema": Schema( 65 | prediction_id_column_name="prediction_id", 66 | timestamp_column_name="prediction_timestamp", 67 | feature_column_names=["item_type"], 68 | actual_score_column_name="ranking_relevance", 69 | actual_label_column_name="ranking_category", 70 | ) 71 | }, 72 | kwargs, 73 | ) 74 | ) 75 | assert len(errors) == 0 76 | 77 | 78 | def 
test_ranking_param_check_missing_prediction_group_id(): 79 | errors = Validator.validate_params( 80 | **ChainMap( 81 | { 82 | "schema": Schema( 83 | prediction_id_column_name="prediction_id", 84 | prediction_group_id_column_name=None, 85 | timestamp_column_name="prediction_timestamp", 86 | feature_column_names=["item_type"], 87 | rank_column_name="ranking_rank", 88 | actual_score_column_name="ranking_relevance", 89 | actual_label_column_name="ranking_category", 90 | ) 91 | }, 92 | kwargs, 93 | ) 94 | ) 95 | assert len(errors) == 1 96 | assert type(errors[0]) is err.MissingRequiredColumnsForRankingModel 97 | 98 | 99 | def test_ranking_param_check_missing_rank(): 100 | errors = Validator.validate_params( 101 | **ChainMap( 102 | { 103 | "schema": Schema( 104 | prediction_id_column_name="prediction_id", 105 | prediction_group_id_column_name="prediction_group_id", 106 | timestamp_column_name="prediction_timestamp", 107 | feature_column_names=["item_type"], 108 | rank_column_name=None, 109 | actual_score_column_name="ranking_relevance", 110 | actual_label_column_name="ranking_category", 111 | ) 112 | }, 113 | kwargs, 114 | ) 115 | ) 116 | assert len(errors) == 1 117 | assert type(errors[0]) is err.MissingRequiredColumnsForRankingModel 118 | 119 | 120 | def test_ranking_param_check_missing_category(): 121 | errors = Validator.validate_params( 122 | **ChainMap( 123 | { 124 | "schema": Schema( 125 | prediction_id_column_name="prediction_id", 126 | prediction_group_id_column_name="prediction_group_id", 127 | timestamp_column_name="prediction_timestamp", 128 | feature_column_names=["item_type"], 129 | rank_column_name="ranking_rank", 130 | actual_score_column_name=None, 131 | actual_label_column_name=None, 132 | ) 133 | }, 134 | kwargs, 135 | ) 136 | ) 137 | assert len(errors) == 0 138 | 139 | 140 | if __name__ == "__main__": 141 | raise SystemExit(pytest.main([__file__])) 142 | -------------------------------------------------------------------------------- /tests/single_log/test_casting.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Union 5 | 6 | import pytest 7 | 8 | from arize.single_log.casting import cast_value 9 | from arize.single_log.errors import CastingError 10 | from arize.utils.types import ArizeTypes, TypedValue 11 | 12 | 13 | @dataclass 14 | class SingleLogTestCase: 15 | typed_value: TypedValue 16 | expected_value: Union[int, float, str, None] 17 | expected_error: Optional[CastingError] = None 18 | 19 | 20 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 21 | def test_string_to_float_error(): 22 | tv = TypedValue(value="fruit", type=ArizeTypes.FLOAT) 23 | tc = SingleLogTestCase( 24 | typed_value=tv, 25 | expected_value=None, 26 | expected_error=CastingError( 27 | error_msg="could not convert string to float: 'fruit'", 28 | typed_value=tv, 29 | ), 30 | ) 31 | table_test([tc]) 32 | 33 | 34 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 35 | def test_numeric_string_to_int_error(): 36 | tv = TypedValue(value="10.2", type=ArizeTypes.INT) 37 | tc = SingleLogTestCase( 38 | typed_value=tv, 39 | expected_value=None, 40 | expected_error=CastingError( 41 | error_msg="invalid literal for int() with base 10: '10.2'", 42 | typed_value=tv, 43 | ), 44 | ) 45 | table_test([tc]) 46 | 47 | 48 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 49 | def test_empty_string_to_numeric_error(): 
50 | tv = TypedValue(value="", type=ArizeTypes.FLOAT) 51 | tc = SingleLogTestCase( 52 | typed_value=tv, 53 | expected_value=None, 54 | expected_error=CastingError( 55 | error_msg="could not convert string to float: ''", 56 | typed_value=tv, 57 | ), 58 | ) 59 | table_test([tc]) 60 | 61 | 62 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 63 | def test_float_to_int_error(): 64 | tv = TypedValue(value=4.4, type=ArizeTypes.INT) 65 | tc = SingleLogTestCase( 66 | typed_value=tv, 67 | expected_value=None, 68 | expected_error=CastingError( 69 | # this is our custom error; 70 | # native python float->int casting succeeds by taking the floor of the float. 71 | error_msg="Cannot convert float with non-zero fractional part to int", 72 | typed_value=tv, 73 | ), 74 | ) 75 | table_test([tc]) 76 | 77 | 78 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 79 | def test_cast_to_float_no_error(): 80 | tc1 = SingleLogTestCase( 81 | typed_value=TypedValue(value=1, type=ArizeTypes.FLOAT), 82 | expected_value=1.0, 83 | ) 84 | tc2 = SingleLogTestCase( 85 | typed_value=TypedValue(value="7.7", type=ArizeTypes.FLOAT), 86 | expected_value=7.7, 87 | ) 88 | tc3 = SingleLogTestCase( 89 | typed_value=TypedValue(value="NaN", type=ArizeTypes.FLOAT), 90 | expected_value=float("NaN"), 91 | ) 92 | table_test([tc1, tc2, tc3]) 93 | 94 | 95 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 96 | def test_cast_to_int_no_error(): 97 | tc1 = SingleLogTestCase( 98 | typed_value=TypedValue(value=1.0, type=ArizeTypes.INT), expected_value=1 99 | ) 100 | tc2 = SingleLogTestCase( 101 | typed_value=TypedValue(value="7", type=ArizeTypes.INT), expected_value=7 102 | ) 103 | tc3 = SingleLogTestCase( 104 | typed_value=TypedValue(value=None, type=ArizeTypes.INT), 105 | expected_value=None, 106 | ) 107 | table_test([tc1, tc2, tc3]) 108 | 109 | 110 | @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires python>=3.8") 111 | def test_cast_to_string_no_error(): 112 | tc1 = SingleLogTestCase( 113 | typed_value=TypedValue(value=1.0, type=ArizeTypes.STR), 114 | expected_value="1.0", 115 | ) 116 | tc2 = SingleLogTestCase( 117 | typed_value=TypedValue(value=float("NaN"), type=ArizeTypes.STR), 118 | expected_value=None, 119 | ) 120 | tc3 = SingleLogTestCase( 121 | typed_value=TypedValue(value=None, type=ArizeTypes.STR), 122 | expected_value=None, 123 | ) 124 | table_test([tc1, tc2, tc3]) 125 | 126 | 127 | def table_test(test_cases: List[SingleLogTestCase]): 128 | for test_case in test_cases: 129 | try: 130 | v = cast_value(test_case.typed_value) 131 | except Exception as e: 132 | if test_case.expected_error is None: 133 | pytest.fail("Unexpected error!") 134 | else: 135 | assert isinstance(e, CastingError) 136 | assert e.typed_value == test_case.expected_error.typed_value 137 | assert e.error_msg == test_case.expected_error.error_msg 138 | else: 139 | if test_case.expected_value is None: 140 | assert v is None 141 | elif not isinstance(test_case.expected_value, str) and math.isnan( 142 | test_case.expected_value 143 | ): 144 | assert math.isnan(v) 145 | else: 146 | assert test_case.expected_value == v 147 | 148 | 149 | if __name__ == "__main__": 150 | raise SystemExit(pytest.main([__file__])) 151 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import namedtuple 3 | 4 | import pytest 
5 | 6 | from arize.utils import utils 7 | 8 | 9 | def test_response_url(): 10 | response = namedtuple("Response", ["content"])( 11 | json.dumps( 12 | { 13 | "realTimeIngestionUri": ( 14 | "https://app.dev.arize.com/" 15 | "organizations/test-hmac-org/" 16 | "spaces/test-hmac-space/" 17 | "models/modelName/" 18 | "z-upload-classification-data-with-arize?" 19 | "selectedTab=overview" 20 | ) 21 | } 22 | ).encode() 23 | ) 24 | expected = ( 25 | "https://app.dev.arize.com/" 26 | "organizations/test-hmac-org/" 27 | "spaces/test-hmac-space/models/" 28 | "modelName/" 29 | "z-upload-classification-data-with-arize?" 30 | "selectedTab=overview" 31 | ) 32 | assert utils.reconstruct_url(response) == expected 33 | 34 | 35 | if __name__ == "__main__": 36 | raise SystemExit(pytest.main([__file__])) 37 | -------------------------------------------------------------------------------- /tests/types/test_embedding_types.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | 8 | from arize.utils.constants import MAX_RAW_DATA_CHARACTERS 9 | from arize.utils.types import Embedding 10 | 11 | 12 | def random_string(N: int) -> str: 13 | return "".join(random.choices(string.ascii_uppercase + string.digits, k=N)) 14 | 15 | 16 | long_raw_data_string = random_string(MAX_RAW_DATA_CHARACTERS) 17 | long_raw_data_token_array = [random_string(7) for _ in range(11000)] 18 | input_embeddings = { 19 | "correct:complete:list_vector": Embedding( 20 | vector=[1.0, 2, 3], 21 | data="this is a test sentence", 22 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 23 | ), 24 | "correct:complete:ndarray_vector+list_data": Embedding( 25 | vector=np.array([1.0, 2, 3]), 26 | data=["This", "is", "a", "test", "token", "array"], 27 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 28 | ), 29 | "correct:complete:pdSeries_vector+ndarray_data": Embedding( 30 | vector=pd.Series([1.0, 2, 3]), 31 | data=np.array(["This", "is", "a", "test", "token", "array"]), 32 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 33 | ), 34 | "correct:complete:ndarray_vector+pdSeries_data": Embedding( 35 | vector=np.array([1.0, 2, 3]), 36 | data=pd.Series(["This", "is", "a", "test", "token", "array"]), 37 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 38 | ), 39 | "correct:missing:data": Embedding( 40 | vector=np.array([1.0, 2, 3]), 41 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 42 | ), 43 | "correct:missing:link_to_data": Embedding( 44 | vector=pd.Series([1.0, 2, 3]), 45 | data=["This", "is", "a", "test", "token", "array"], 46 | ), 47 | "correct:empty_vector": Embedding( 48 | vector=np.array([]), 49 | data=["This", "is", "a", "test", "token", "array"], 50 | ), 51 | "wrong_type:vector": Embedding( 52 | vector=pd.DataFrame([1.0, 2, 3]), 53 | data=2, 54 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 55 | ), 56 | "wrong_type:data_num": Embedding( 57 | vector=pd.Series([1.0, 2, 3]), 58 | data=2, 59 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 60 | ), 61 | "wrong_type:data_dataframe": Embedding( 62 | vector=pd.Series([1.0, 2, 3]), 63 | data=pd.DataFrame(["This", "is", "a", "test", "token", "array"]), 64 | link_to_data="https://my-bucket.s3.us-west-2.amazonaws.com/puppy.png", 65 | ), 66 | "wrong_type:link_to_data": Embedding( 67 | vector=np.array([1.0, 2, 3]), 68 | 
data=["This", "is", "a", "test", "token", "array"], 69 | link_to_data=True, 70 | ), 71 | "wrong_value:size_1_vector": Embedding( 72 | vector=np.array([1.0]), 73 | data=["This", "is", "a", "test", "token", "array"], 74 | ), 75 | "wrong_value:raw_data_string_too_long": Embedding( 76 | vector=pd.Series([1.0, 2, 3]), 77 | data=long_raw_data_string, 78 | ), 79 | "wrong_value:raw_data_token_array_too_long": Embedding( 80 | vector=pd.Series([1.0, 2, 3]), 81 | data=long_raw_data_token_array, 82 | ), 83 | } 84 | 85 | 86 | def test_correct_embeddings(): 87 | keys = [key for key in input_embeddings if "correct:" in key] 88 | assert len(keys) > 0, "Test configuration error: keys must not be empty" 89 | 90 | for key in keys: 91 | embedding = input_embeddings[key] 92 | try: 93 | embedding.validate(key) 94 | except Exception as err: 95 | raise AssertionError( 96 | f"Correct embeddings should give no errors. Failing key = {key:s}. " 97 | f"Error = {err}" 98 | ) from None 99 | 100 | 101 | def test_wrong_value_fields(): 102 | keys = [key for key in input_embeddings if "wrong_value:" in key] 103 | assert len(keys) > 0, "Test configuration error: keys must not be empty" 104 | 105 | for key in keys: 106 | embedding = input_embeddings[key] 107 | try: 108 | embedding.validate(key) 109 | except Exception as err: 110 | assert isinstance( 111 | err, ValueError 112 | ), "Wrong field values should raise value errors" 113 | 114 | 115 | def test_wrong_type_fields(): 116 | keys = [key for key in input_embeddings if "wrong_type:" in key] 117 | assert len(keys) > 0, "Test configuration error: keys must not be empty" 118 | 119 | for key in keys: 120 | embedding = input_embeddings[key] 121 | try: 122 | embedding.validate(key) 123 | except Exception as err: 124 | assert isinstance( 125 | err, TypeError 126 | ), "Wrong field types should raise type errors" 127 | 128 | 129 | if __name__ == "__main__": 130 | raise SystemExit(pytest.main([__file__])) 131 | -------------------------------------------------------------------------------- /tests/types/test_multi_class_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from arize.utils.constants import ( 4 | MAX_MULTI_CLASS_NAME_LENGTH, 5 | MAX_NUMBER_OF_MULTI_CLASS_CLASSES, 6 | ) 7 | from arize.utils.types import MultiClassActualLabel, MultiClassPredictionLabel 8 | 9 | overMaxClasses = {"class": 0.2} 10 | for i in range(MAX_NUMBER_OF_MULTI_CLASS_CLASSES): 11 | overMaxClasses[f"class_{i}"] = 0.2 12 | overMaxClassLen = "a" 13 | for _ in range(MAX_MULTI_CLASS_NAME_LENGTH): 14 | overMaxClassLen += "a" 15 | 16 | input_labels = { 17 | "correct:prediction_scores": MultiClassPredictionLabel( 18 | prediction_scores={"class1": 0.1, "class2": 0.2}, 19 | ), 20 | "correct:threshold_scores": MultiClassPredictionLabel( 21 | prediction_scores={"class1": 0.1, "class2": 0.2}, 22 | threshold_scores={"class1": 0.1, "class2": 0.2}, 23 | ), 24 | "invalid:wrong_pred_dictionary_type": MultiClassPredictionLabel( 25 | prediction_scores={"class1": "score", "class2": "score2"}, 26 | ), 27 | "invalid:no_prediction_scores": MultiClassPredictionLabel( 28 | prediction_scores={}, 29 | ), 30 | "invalid:too many_prediction_scores": MultiClassPredictionLabel( 31 | prediction_scores=overMaxClasses, 32 | ), 33 | "invalid:pred_empty_class_name": MultiClassPredictionLabel( 34 | prediction_scores={"": 1.1, "class2": 0.2}, 35 | ), 36 | "invalid:pred_class_name_too_long": MultiClassPredictionLabel( 37 | prediction_scores={overMaxClassLen: 1.1, "class2": 0.2}, 
38 | ), 39 | "invalid:pred_score_over_1": MultiClassPredictionLabel( 40 | prediction_scores={"class1": 1.1, "class2": 0.2}, 41 | ), 42 | "invalid:wrong_thresh_dictionary_type": MultiClassPredictionLabel( 43 | prediction_scores={"class1": 0.1, "class2": 0.2}, 44 | threshold_scores={"class1": "score", "class2": 0.2}, 45 | ), 46 | "invalid:pred_thresh_not_same_num_scores": MultiClassPredictionLabel( 47 | prediction_scores={"class1": 0.1, "class2": 0.2}, 48 | threshold_scores={"class1": 0.1}, 49 | ), 50 | "invalid:pred_thresh_not_same_classes": MultiClassPredictionLabel( 51 | prediction_scores={"class1": 0.1, "class2": 0.2}, 52 | threshold_scores={"class1": 0.1, "class3": 0.1}, 53 | ), 54 | "invalid:thresh_score_under_0": MultiClassPredictionLabel( 55 | prediction_scores={"class1": 0.1, "class2": 0.2}, 56 | threshold_scores={"class1": -1, "class2": 0.2}, 57 | ), 58 | "correct:actual_scores": MultiClassActualLabel( 59 | actual_scores={"class1": 0, "class2": 1}, 60 | ), 61 | "correct:actual_scores_multi_1": MultiClassActualLabel( 62 | actual_scores={"class1": 1, "class2": 1}, 63 | ), 64 | "invalid:wrong_actual_dictionary_type": MultiClassActualLabel( 65 | actual_scores={"class1": "score", "class2": 0}, 66 | ), 67 | "invalid:no_actual_scores": MultiClassActualLabel( 68 | actual_scores={}, 69 | ), 70 | "invalid:too_many_actual_scores": MultiClassActualLabel( 71 | actual_scores=overMaxClasses, 72 | ), 73 | "invalid:actual_score_empty_class_name": MultiClassActualLabel( 74 | actual_scores={"": 1, "class2": 0}, 75 | ), 76 | "invalid:act_class_name_too_long": MultiClassActualLabel( 77 | actual_scores={overMaxClassLen: 1.1, "class2": 0.2}, 78 | ), 79 | "invalid:actual_score_not_0_or_1": MultiClassActualLabel( 80 | actual_scores={"class1": 0.7, "class2": 0.2}, 81 | ), 82 | } 83 | 84 | 85 | def test_correct_multi_class_label(): 86 | keys = [key for key in input_labels if "correct:" in key] 87 | assert len(keys) > 0, "Test configuration error: keys must not be empty" 88 | 89 | for key in keys: 90 | multi_class_label = input_labels[key] 91 | try: 92 | multi_class_label.validate() 93 | except Exception as err: 94 | raise AssertionError( 95 | f"Correct multi-class label should give no errors. Failing key = {key:s}. "
96 | f"Error = {err}" 97 | ) from None 98 | 99 | 100 | def test_invalid_scores(): 101 | keys = [key for key in input_labels if "invalid:" in key] 102 | assert len(keys) > 0, "Test configuration error: keys must not be empty" 103 | 104 | for key in keys: 105 | multi_class_label = input_labels[key] 106 | with pytest.raises(ValueError) as e: 107 | multi_class_label.validate() 108 | assert isinstance( 109 | e.value, ValueError 110 | ), "Invalid values should raise value errors" 111 | 112 | 113 | if __name__ == "__main__": 114 | raise SystemExit(pytest.main([__file__])) 115 | -------------------------------------------------------------------------------- /tests/types/test_type_helpers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from arize.utils.types import is_dict_of, is_list_of 4 | 5 | 6 | def test_is_dict_of(): 7 | # Assert only key 8 | assert ( 9 | is_dict_of( 10 | {"class1": 0.1, "class2": 0.2}, 11 | key_allowed_types=str, 12 | ) 13 | is True 14 | ) 15 | # Assert key and value exact types 16 | assert ( 17 | is_dict_of( 18 | {"class1": 0.1, "class2": 0.2}, 19 | key_allowed_types=str, 20 | value_allowed_types=float, 21 | ) 22 | is True 23 | ) 24 | assert ( 25 | is_dict_of( 26 | {"class1": 0.1, "class2": 0.2}, 27 | key_allowed_types=str, 28 | value_allowed_types=int, 29 | ) 30 | is False 31 | ) 32 | # Assert key and value union types 33 | assert ( 34 | is_dict_of( 35 | {"class1": 0.1, "class2": 0.2}, 36 | key_allowed_types=str, 37 | value_allowed_types=(str, float), 38 | ) 39 | is True 40 | ) 41 | # Assert key and exact list of value types 42 | assert ( 43 | is_dict_of( 44 | {"class1": [1, 2], "class2": [3, 4]}, 45 | key_allowed_types=str, 46 | value_list_allowed_types=int, 47 | ) 48 | is True 49 | ) 50 | # Assert key and non-matching list of value types 51 | assert ( 52 | is_dict_of( 53 | {"class1": [1, 2], "class2": [3, 4]}, 54 | key_allowed_types=str, 55 | value_list_allowed_types=str, 56 | ) 57 | is False 58 | ) 59 | # Assert key and union list of value types 60 | assert ( 61 | is_dict_of( 62 | {"class1": [1, 2], "class2": [3, 4]}, 63 | key_allowed_types=str, 64 | value_list_allowed_types=(str, int), 65 | ) 66 | is True 67 | ) 68 | assert ( 69 | is_dict_of( 70 | {"class1": [1, 2], "class2": ["a", "b"]}, 71 | key_allowed_types=str, 72 | value_list_allowed_types=(str, int), 73 | ) 74 | is True 75 | ) 76 | # Assert key and value and list of value types 77 | assert ( 78 | is_dict_of( 79 | {"class1": 1, "class2": ["a", "b"], "class3": [0.4, 0.7]}, 80 | key_allowed_types=str, 81 | value_allowed_types=int, 82 | value_list_allowed_types=(str, float), 83 | ) 84 | is True 85 | ) 86 | assert ( 87 | is_dict_of( 88 | {"class1": 1, "class2": ["a", "b"], "class3": [0.4, 0.7]}, 89 | key_allowed_types=str, 90 | value_allowed_types=str, 91 | value_list_allowed_types=(str, float), 92 | ) 93 | is False 94 | ) 95 | 96 | 97 | def test_is_list_of(): 98 | assert is_list_of([1, 2], int) is True 99 | assert is_list_of([1, 2], float) is False 100 | assert is_list_of(["x", 2], int) is False 101 | assert is_list_of(["x", 2], (str, int)) is True 102 | 103 | 104 | if __name__ == "__main__": 105 | raise SystemExit(pytest.main([__file__])) 106 | --------------------------------------------------------------------------------