├── .github ├── .github │ ├── pull_request_template.md │ └── workflows │ │ ├── pre-commit.yml │ │ └── pytest.yml ├── pull_request_template.md └── workflows │ ├── pre-commit.yml │ └── pytest.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .testing_env ├── LICENSE ├── README.md ├── docs ├── development.md ├── wyvern_architecture.png └── wyvern_logo.jpg ├── examples ├── __init__.py ├── example_business_logic.py ├── feature_store_main.py └── real_time_features_main.py ├── log_config.yml ├── poetry.lock ├── pyproject.toml ├── sdk_ref.md ├── setup.cfg ├── tests ├── __init__.py ├── components │ ├── __init__.py │ └── business_logic │ │ ├── __init__.py │ │ └── test_pinning_business_logic.py ├── conftest.py ├── feature_store │ ├── __init__.py │ └── test_real_time_features.py ├── load_test │ └── load_test.js └── scenarios │ ├── __init__.py │ ├── single_entity_pipelines │ ├── __init__.py │ └── test_single_entity_pipeline.py │ ├── test_indexation.py │ └── test_product_ranking.py └── wyvern ├── __init__.py ├── api.py ├── aws ├── __init__.py └── kinesis.py ├── cli └── commands.py ├── clients ├── __init__.py └── snowflake.py ├── components ├── __init__.py ├── api_route_component.py ├── business_logic │ ├── __init__.py │ ├── boosting_business_logic.py │ ├── business_logic.py │ └── pinning_business_logic.py ├── candidates │ ├── __init__.py │ └── candidate_logger.py ├── component.py ├── events │ ├── __init__.py │ └── events.py ├── features │ ├── __init__.py │ ├── feature_logger.py │ ├── feature_retrieval_pipeline.py │ ├── feature_store.py │ └── realtime_features_component.py ├── helpers │ ├── __init__.py │ ├── linear_algebra.py │ ├── polars.py │ └── sorting.py ├── impressions │ ├── __init__.py │ └── impression_logger.py ├── index │ ├── __init__.py │ └── _index.py ├── models │ ├── __init__.py │ ├── model_chain_component.py │ ├── model_component.py │ └── modelbit_component.py ├── pagination │ ├── __init__.py │ ├── pagination_component.py │ └── pagination_fields.py ├── 
pipeline_component.py ├── ranking_pipeline.py └── single_entity_pipeline.py ├── config.py ├── core ├── __init__.py ├── compression.py └── http.py ├── entities ├── __init__.py ├── candidate_entities.py ├── feature_entities.py ├── identifier.py ├── identifier_entities.py ├── index_entities.py ├── model_entities.py └── request.py ├── event_logging ├── __init__.py └── event_logger.py ├── exceptions.py ├── experimentation ├── __init__.py ├── client.py ├── experimentation_logging.py └── providers │ ├── __init__.py │ ├── base.py │ └── eppo_provider.py ├── feature_store ├── __init__.py ├── constants.py ├── feature_server.py ├── historical_feature_util.py └── schemas.py ├── helper ├── __init__.py └── sort.py ├── index.py ├── redis.py ├── request_context.py ├── service.py ├── tracking.py ├── utils.py ├── web_frameworks ├── __init__.py └── fastapi.py ├── wyvern_logging.py ├── wyvern_request.py ├── wyvern_tracing.py └── wyvern_typing.py /.github/.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | - [ ] Does this PR have impact on local development experience? 
If yes, make sure you have a plan and add the documentation to address issues that come with the change 2 | - [ ] bump version 3 | - [ ] make a release 4 | - [ ] publish to pypi service 5 | -------------------------------------------------------------------------------- /.github/.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.10" 16 | - uses: pre-commit/action@v3.0.0 17 | -------------------------------------------------------------------------------- /.github/.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | # Run this job on pushes to `main`, and for pull requests. If you don't specify 4 | # `branches: [main]`, then this action runs _twice_ on pull requests, which is 5 | # annoying. 6 | on: 7 | pull_request: 8 | push: 9 | branches: [main] 10 | 11 | jobs: 12 | test: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | # If you wanted to use multiple Python versions, you'd have to specify a matrix in the job and 18 | # reference the matrix python version here. 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | 23 | # Cache the installation of Poetry itself, e.g. the next step. This prevents the workflow 24 | # from installing Poetry every time, which can be slow. Note the use of the Poetry version 25 | # number in the cache key, and the "-0" suffix: this allows you to invalidate the cache 26 | # manually if/when you want to upgrade Poetry, or if something goes wrong. This could be 27 | # mildly cleaner by using an environment variable, but I don't really care.
28 | - name: cache poetry install 29 | uses: actions/cache@v3 30 | with: 31 | path: ~/.local 32 | key: poetry-1.2.2-0 33 | 34 | # Install Poetry. You could do this manually, or there are several actions that do this. 35 | # `snok/install-poetry` seems to be minimal yet complete, and really just calls out to 36 | # Poetry's default install script, which feels correct. I pin the Poetry version here 37 | # because Poetry does occasionally change APIs between versions and I don't want my 38 | # actions to break if it does. 39 | # 40 | # The key configuration value here is `virtualenvs-in-project: true`: this creates the 41 | # venv as a `.venv` in your testing directory, which allows the next step to easily 42 | # cache it. 43 | - uses: snok/install-poetry@v1 44 | with: 45 | version: 1.2.2 46 | virtualenvs-create: true 47 | virtualenvs-in-project: true 48 | 49 | # Cache your dependencies (i.e. all the stuff in your `pyproject.toml`). Note the cache 50 | # key: if you're using multiple Python versions, or multiple OSes, you'd need to include 51 | # them in the cache key. I'm not, so it can be simple and just depend on the poetry.lock. 52 | - name: cache deps 53 | id: cache-deps 54 | uses: actions/cache@v3 55 | with: 56 | path: .venv 57 | key: pydeps-${{ hashFiles('**/poetry.lock') }} 58 | 59 | # Install dependencies. `--no-root` means "install all dependencies but not the project 60 | # itself", which is what you want, to avoid caching _your_ code. The `if` statement 61 | # ensures this only runs on a cache miss. 62 | - run: poetry install --no-interaction --no-root 63 | if: steps.cache-deps.outputs.cache-hit != 'true' 64 | 65 | # Now install _your_ project. This isn't necessary for many types of projects -- particularly 66 | # things like Django apps don't need this. But it's a good idea since it fully exercises the 67 | # pyproject.toml and makes sure that if you add things like console-scripts at some point that 68 | # they'll be installed and working.
69 | - run: poetry install --no-interaction 70 | 71 | # And finally run tests. I'm using pytest and all my pytest config is in my `pyproject.toml` 72 | # so this line is super-simple. But it could be as complex as you need. 73 | - run: poetry run pytest 74 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | - [ ] Does this PR have an impact on local development experience? If yes, make sure you have a plan and add the documentation to address issues that come with the change 2 | - [ ] bump version 3 | - [ ] make a release 4 | - [ ] publish to pypi service 5 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.11" 16 | - uses: pre-commit/action@v3.0.0 17 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | # Run this job on pushes to `main`, and for pull requests. If you don't specify 4 | # `branches: [main]`, then this action runs _twice_ on pull requests, which is 5 | # annoying. 6 | on: 7 | pull_request: 8 | push: 9 | branches: [main] 10 | 11 | jobs: 12 | test: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | # If you wanted to use multiple Python versions, you'd have to specify a matrix in the job and 18 | # reference the matrix python version here.
19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.11" 22 | 23 | # Cache the installation of Poetry itself, e.g. the next step. This prevents the workflow 24 | # from installing Poetry every time, which can be slow. Note the use of the Poetry version 25 | # number in the cache key, and the "-0" suffix: this allows you to invalidate the cache 26 | # manually if/when you want to upgrade Poetry, or if something goes wrong. This could be 27 | # mildly cleaner by using an environment variable, but I don't really care. 28 | - name: cache poetry install 29 | uses: actions/cache@v3 30 | with: 31 | path: ~/.local 32 | key: poetry-1.2.2-0 33 | 34 | # Install Poetry. You could do this manually, or there are several actions that do this. 35 | # `snok/install-poetry` seems to be minimal yet complete, and really just calls out to 36 | # Poetry's default install script, which feels correct. I pin the Poetry version here 37 | # because Poetry does occasionally change APIs between versions and I don't want my 38 | # actions to break if it does. 39 | # 40 | # The key configuration value here is `virtualenvs-in-project: true`: this creates the 41 | # venv as a `.venv` in your testing directory, which allows the next step to easily 42 | # cache it. 43 | - uses: snok/install-poetry@v1 44 | with: 45 | version: 1.2.2 46 | virtualenvs-create: true 47 | virtualenvs-in-project: true 48 | 49 | # Cache your dependencies (i.e. all the stuff in your `pyproject.toml`). Note the cache 50 | # key: if you're using multiple Python versions, or multiple OSes, you'd need to include 51 | # them in the cache key. I'm not, so it can be simple and just depend on the poetry.lock. 52 | - name: cache deps 53 | id: cache-deps 54 | uses: actions/cache@v3 55 | with: 56 | path: .venv 57 | key: pydeps-${{ hashFiles('**/poetry.lock') }} 58 | 59 | # Install dependencies. 
`--no-root` means "install all dependencies but not the project 60 | # itself", which is what you want, to avoid caching _your_ code. The `if` statement 61 | # ensures this only runs on a cache miss. 62 | - run: poetry install --no-interaction --no-root 63 | if: steps.cache-deps.outputs.cache-hit != 'true' 64 | 65 | # Now install _your_ project. This isn't necessary for many types of projects -- particularly 66 | # things like Django apps don't need this. But it's a good idea since it fully exercises the 67 | # pyproject.toml and makes sure that if you add things like console-scripts at some point that 68 | # they'll be installed and working. 69 | - run: poetry install --no-interaction 70 | 71 | # And finally run tests. I'm using pytest and all my pytest config is in my `pyproject.toml` 72 | # so this line is super-simple. But it could be as complex as you need. 73 | - run: poetry run pytest 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # PyCharm 63 | .idea/ 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # ctags 135 | tags 136 | tags.lock 137 | tags.temp 138 | 139 | # mac 140 | .DS_Store 141 | 142 | # vim 143 | *.swp 144 | 145 | # pyright 146 | pyrightconfig.json 147 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Standard hooks 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.1.0 5 | hooks: 6 | - id: check-added-large-files 7 | - id: check-case-conflict 8 | - id: check-merge-conflict 9 | - id: check-symlinks 10 | - id: check-yaml 11 | - id: debug-statements 12 | - id: end-of-file-fixer 13 | - id: mixed-line-ending 14 | - id: trailing-whitespace 15 | - id: fix-encoding-pragma 16 | 17 | # Black, the code formatter, natively supports pre-commit 18 | - repo: https://github.com/psf/black 19 | rev: 22.3.0 20 | hooks: 21 | - id: black 22 | 23 | - repo: https://github.com/pycqa/flake8 24 | rev: 4.0.1 25 | hooks: 26 | - id: flake8 27 | additional_dependencies: [flake8-bugbear, pep8-naming] 28 | 29 | - repo: https://github.com/pre-commit/mirrors-autopep8 30 | rev: v2.0.0 31 | hooks: 32 | - id: autopep8 33 | 34 | - repo: https://github.com/pre-commit/mirrors-mypy 35 | rev: v1.0.0 36 | hooks: 37 | - id: mypy 38 | args: [--show-error-codes, --python-version=3.10] 39 | additional_dependencies: 40 | - fastapi 
41 | - uvicorn 42 | - typer 43 | - types-pyyaml 44 | - types-boto3 45 | - pyhumps 46 | - setuptools 47 | - pandas-stubs 48 | - types-redis 49 | - python-dotenv 50 | - feast 51 | - pytest 52 | - types-protobuf 53 | - snowflake-connector-python 54 | - fastapi-utils 55 | - pyinstrument 56 | - ddtrace 57 | - msgspec 58 | - lz4 59 | - types-requests 60 | - more-itertools 61 | - tqdm 62 | - types-tqdm 63 | - nest-asyncio 64 | - aiohttp 65 | - polars 66 | exclude: "^tests/" 67 | 68 | # Check for spelling 69 | - repo: https://github.com/codespell-project/codespell 70 | rev: v2.1.0 71 | hooks: 72 | - id: codespell 73 | exclude: ".supp$" 74 | args: ["-L", "nd,ot,thist,paramater"] 75 | 76 | - repo: https://github.com/asottile/add-trailing-comma 77 | rev: v2.2.1 78 | hooks: 79 | - id: add-trailing-comma 80 | 81 | - repo: https://github.com/pre-commit/pygrep-hooks 82 | rev: v1.9.0 83 | hooks: 84 | - id: python-check-blanket-noqa 85 | - id: python-check-mock-methods 86 | - id: python-no-log-warn 87 | - id: python-use-type-annotations 88 | 89 | - repo: https://github.com/prettier/pre-commit 90 | rev: 57f39166b5a5a504d6808b87ab98d41ebf095b46 91 | hooks: 92 | - id: prettier 93 | 94 | # Disallow some common capitalization mistakes 95 | - repo: local 96 | hooks: 97 | - id: disallow-caps 98 | name: Disallow improper capitalization 99 | language: pygrep 100 | entry: PyBind|Numpy|Cmake|CCache|PyTest|PyTest-Cov 101 | exclude: | 102 | (?x)( 103 | ^\.pre-commit-config.yaml$ | 104 | poetry.lock$ 105 | ) 106 | - repo: https://github.com/pycqa/isort 107 | rev: 5.12.0 108 | hooks: 109 | - id: isort 110 | name: isort (python) 111 | -------------------------------------------------------------------------------- /.testing_env: -------------------------------------------------------------------------------- 1 | DD_TRACE_ENABLED=false 2 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Wyvern AI uses Elastic License 2.0 2 | 3 | Elastic License 2.0 (ELv2) 4 | 5 | **Acceptance** 6 | By using the software, you agree to all of the terms and conditions below. 7 | 8 | **Copyright License** 9 | The licensor grants you a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable license to use, copy, distribute, make available, and prepare derivative works of the software, in each case subject to the limitations and conditions below 10 | 11 | **Limitations** 12 | You may not provide the software to third parties as a hosted or managed service, where the service provides users with access to any substantial set of the features or functionality of the software. 13 | 14 | You may not move, change, disable, or circumvent the license key functionality in the software, and you may not remove or obscure any functionality in the software that is protected by the license key. 15 | 16 | You may not alter, remove, or obscure any licensing, copyright, or other notices of the licensor in the software. Any use of the licensor’s trademarks is subject to applicable law. 17 | 18 | **Patents** 19 | The licensor grants you a license, under any patent claims the licensor can license, or becomes able to license, to make, have made, use, sell, offer for sale, import and have imported the software, in each case subject to the limitations and conditions in this license. This license does not cover any patent claims that you cause to be infringed by modifications or additions to the software. If you or your company make any written claim that the software infringes or contributes to infringement of any patent, your patent license for the software granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company. 
20 | 21 | **Notices** 22 | You must ensure that anyone who gets a copy of any part of the software from you also gets a copy of these terms. 23 | 24 | If you modify the software, you must include in any modified copies of the software prominent notices stating that you have modified the software. 25 | 26 | **No Other Rights** 27 | These terms do not imply any licenses other than those expressly granted in these terms. 28 | 29 | **Termination** 30 | If you use the software in violation of these terms, such use is not licensed, and your licenses will automatically terminate. If the licensor provides you with a notice of your violation, and you cease all violation of this license no later than 30 days after you receive that notice, your licenses will be reinstated retroactively. However, if you violate these terms after such reinstatement, any additional violation of these terms will cause your licenses to terminate automatically and permanently. 31 | 32 | **No Liability** 33 | As far as the law allows, the software comes as is, without any warranty or condition, and the licensor will not be liable to you for any damages arising out of these terms or the use or nature of the software, under any kind of legal claim. 34 | 35 | **Definitions** 36 | The *licensor* is the entity offering these terms, and the *software* is the software the licensor makes available under these terms, including any portion of it. 37 | 38 | *you* refers to the individual or entity agreeing to these terms. 39 | 40 | *your company* is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. *control* means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect. 
41 | 42 | *your licenses* are all the licenses granted to you for the software under these terms. 43 | 44 | *use* means anything you do with the software requiring one of your licenses. 45 | 46 | *trademark* means trademarks, service marks, and similar rights. 47 | -------------------------------------------------------------------------------- /docs/development.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | This project is managed by Poetry 4 | 5 | ## Set up PyEnv 6 | 7 | Install PyEnv using Brew: 8 | 9 | ```bash 10 | brew install pyenv 11 | ``` 12 | 13 | Add this to your ~/.zshrc or ~/.bashrc depending on what you use. Documentation copied from [here](https://github.com/pyenv/pyenv#set-up-your-shell-environment-for-pyenv) 14 | 15 | ```bash 16 | export PYENV_ROOT="$HOME/.pyenv" 17 | command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" 18 | eval "$(pyenv init -)" 19 | ``` 20 | 21 | ## Set up Poetry env 22 | 23 | ### 1. Install Poetry 24 | 25 | ```bash 26 | brew install poetry 27 | ``` 28 | 29 | ### 2. Set up Poetry to create virtual envs in the local directory 30 | 31 | ```bash 32 | poetry config virtualenvs.in-project true 33 | ``` 34 | 35 | ### 3. Python Version 36 | 37 | Poetry apparently has trouble initializing the Python version itself, so you'll have to force it to use the correct version 38 | 39 | At the time of this writing, the correct version is 3.10, so just run: 40 | 41 | ``` 42 | poetry env use 3.10 43 | ``` 44 | 45 | And it'll switch the python version to the correct one. You only need to do this once 46 | 47 | ### 4. Set up the virtual environment 48 | 49 | Have poetry set up all of the configs 50 | 51 | ```bash 52 | poetry install 53 | ``` 54 | 55 | ### (Optional) 5. 
Set up auto-poetry shell spawning 56 | 57 | Add this to your ~/.zshrc: 58 | 59 | This automatically activates the project's `.venv` virtual environment whenever you `cd` into a directory that contains a `pyproject.toml` 60 | 61 | ```bash 62 | ### Automatically activate virtual environment 63 | function auto_poetry_shell { 64 | if [ -f "pyproject.toml" ] ; then 65 | source ./.venv/bin/activate 66 | fi 67 | } 68 | 69 | function cd { 70 | builtin cd "$@" 71 | auto_poetry_shell 72 | } 73 | 74 | auto_poetry_shell 75 | ``` 76 | 77 | ## Set up Pre-commit 78 | 79 | ```bash 80 | brew install pre-commit 81 | pre-commit install 82 | ``` 83 | 84 | ## Generate SDK reference 85 | 86 | The markdown is generated from the docstrings in the code. To generate the markdown, run: 87 | 88 | ```bash 89 | pydoc-markdown -I $(pwd) > sdk_ref.md 90 | ``` 91 | -------------------------------------------------------------------------------- /docs/wyvern_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/docs/wyvern_architecture.png -------------------------------------------------------------------------------- /docs/wyvern_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/docs/wyvern_logo.jpg -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/examples/__init__.py -------------------------------------------------------------------------------- /examples/example_business_logic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | import json 4 | import logging
5 | from typing import List 6 | 7 | from pydantic import BaseModel 8 | 9 | from wyvern.components.business_logic.boosting_business_logic import ( 10 | BoostingBusinessLogicComponent, 11 | ) 12 | from wyvern.components.business_logic.business_logic import ( 13 | BusinessLogicEvent, 14 | BusinessLogicPipeline, 15 | BusinessLogicRequest, 16 | ) 17 | from wyvern.components.component import Component 18 | from wyvern.entities.candidate_entities import CandidateSetEntity, ScoredCandidate 19 | from wyvern.entities.identifier_entities import ProductEntity, QueryEntity 20 | from wyvern.entities.request import BaseWyvernRequest 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class SimpleProductEntity(ProductEntity): 26 | product_name: str = "" 27 | product_description: str = "" 28 | 29 | 30 | class ExampleProductSearchRankingRequest( 31 | BaseWyvernRequest, 32 | QueryEntity, 33 | CandidateSetEntity[SimpleProductEntity], 34 | ): 35 | pass 36 | 37 | 38 | class CandleBoostingBusinessLogicComponent( 39 | BoostingBusinessLogicComponent[ 40 | SimpleProductEntity, 41 | ExampleProductSearchRankingRequest, 42 | ], 43 | ): 44 | async def execute( 45 | self, 46 | input: BusinessLogicRequest[ 47 | SimpleProductEntity, 48 | ExampleProductSearchRankingRequest, 49 | ], 50 | **kwargs, 51 | ) -> List[ScoredCandidate[SimpleProductEntity]]: 52 | # TODO (suchintan): Feature request: Add a way to load a CSV from S3 53 | # Define this in a CSV (query, entity_keys, boost) 54 | # Get access to S3 55 | # Load CSV from S3 56 | # Reference the file and boost it -- reload the file every 15 minutes or something like that 57 | logger.info(f"Boosting candles for query={input.request.query}") 58 | if input.request.query == "candle": 59 | return self.boost(input.scored_candidates, entity_keys={"3"}, boost=100) 60 | else: 61 | return input.scored_candidates 62 | 63 | 64 | class AlwaysBoostWaxSealProduct( 65 | BoostingBusinessLogicComponent[ 66 | SimpleProductEntity, 67 | 
ExampleProductSearchRankingRequest, 68 | ], 69 | ): 70 | async def execute( 71 | self, 72 | input: BusinessLogicRequest[ 73 | SimpleProductEntity, 74 | ExampleProductSearchRankingRequest, 75 | ], 76 | **kwargs, 77 | ) -> List[ScoredCandidate[SimpleProductEntity]]: 78 | return self.boost(input.scored_candidates, entity_keys={"7"}, boost=100) 79 | 80 | 81 | class SearchBusinessLogicPipeline( 82 | BusinessLogicPipeline[SimpleProductEntity, ExampleProductSearchRankingRequest], 83 | ): 84 | def __init__(self): 85 | super().__init__( 86 | CandleBoostingBusinessLogicComponent(), 87 | AlwaysBoostWaxSealProduct(), 88 | name="search_business_logic_pipeline", 89 | ) 90 | 91 | 92 | search_business_logic_pipeline = SearchBusinessLogicPipeline() 93 | 94 | 95 | class ExampleProductSearchRankingCandidateResponse(BaseModel): 96 | product_name: str 97 | old_rank: int 98 | old_score: float 99 | new_rank: int 100 | new_score: float 101 | 102 | 103 | class ExampleProductSearchRankingResponse(BaseModel): 104 | ranked_products: List[ExampleProductSearchRankingCandidateResponse] 105 | events: List[BusinessLogicEvent] 106 | 107 | 108 | class ProductQueryRankingBusinessLogicComponent( 109 | Component[ExampleProductSearchRankingRequest, ExampleProductSearchRankingResponse], 110 | ): 111 | def __init__(self): 112 | super().__init__( 113 | search_business_logic_pipeline, 114 | name="product_query_ranking_business_logic_component", 115 | ) 116 | 117 | async def execute( 118 | self, input: ExampleProductSearchRankingRequest, **kwargs 119 | ) -> ExampleProductSearchRankingResponse: 120 | logger.info(f"Input request: {input}") 121 | # Set up a really silly score 122 | scored_candidates: List[ScoredCandidate] = [ 123 | ScoredCandidate(entity=candidate, score=(len(input.candidates) - i)) 124 | for i, candidate in enumerate(input.candidates) 125 | ] 126 | 127 | business_logic_request = BusinessLogicRequest[ 128 | SimpleProductEntity, 129 | ExampleProductSearchRankingRequest, 130 | ]( 131 | 
request=input, 132 | scored_candidates=scored_candidates, 133 | ) 134 | 135 | ranked_products = await search_business_logic_pipeline.execute( 136 | business_logic_request, 137 | ) 138 | 139 | pretty_ranked_products = [ 140 | ExampleProductSearchRankingCandidateResponse( 141 | product_name=entity_score.entity.product_name, 142 | old_rank=input.candidates.index(entity_score.entity), 143 | old_score=ranked_products.request.scored_candidates[ 144 | input.candidates.index(entity_score.entity) 145 | ].score, 146 | new_rank=i, 147 | new_score=entity_score.score, 148 | ) 149 | for i, entity_score in enumerate(ranked_products.adjusted_candidates) 150 | ] 151 | 152 | return ExampleProductSearchRankingResponse( 153 | ranked_products=pretty_ranked_products, 154 | ) 155 | 156 | 157 | def create_example_business_logic_component() -> ProductQueryRankingBusinessLogicComponent: 158 | business_logic_component = ProductQueryRankingBusinessLogicComponent() 159 | return business_logic_component 160 | 161 | 162 | async def sample_product_query_ranking_request() -> None: 163 | """ 164 | How to run this: `python wyvern/examples/example_business_logic.py` 165 | 166 | Json representation of the request: 167 | ``` 168 | { 169 | "request_id": "rrr", 170 | "query": "candle", 171 | "candidates": [ 172 | {"product_id": "1", "product_name": "scented candle"}, 173 | {"product_id": "2", "product_name": "hot candle"}, 174 | {"product_id": "3", "product_name": "pumpkin candle"}, 175 | {"product_id": "4", "product_name": "unrelated item"}, 176 | {"product_id": "5", "product_name": "candle holder accessory"}, 177 | {"product_id": "6", "product_name": "earwax holder"}, 178 | {"product_id": "7", "product_name": "wax seal"} 179 | ], 180 | } 181 | ``` 182 | """ 183 | logger.info("Start query product business logic case...") 184 | req = ExampleProductSearchRankingRequest( 185 | request_id="rrr", 186 | query="candle", 187 | candidates=[ 188 | SimpleProductEntity(product_id="1", product_name="scented 
candle"), 189 | SimpleProductEntity(product_id="2", product_name="hot candle"), 190 | SimpleProductEntity(product_id="3", product_name="pumpkin candle"), 191 | SimpleProductEntity(product_id="4", product_name="unrelated item"), 192 | SimpleProductEntity(product_id="5", product_name="candle holder accessory"), 193 | SimpleProductEntity(product_id="6", product_name="earwax holder"), 194 | SimpleProductEntity(product_id="7", product_name="wax seal"), 195 | ], 196 | ) 197 | 198 | component = create_example_business_logic_component() 199 | await component.initialize() 200 | 201 | response = await component.execute(req) 202 | 203 | json_formatted_str = json.dumps( 204 | response.dict(), 205 | indent=2, 206 | ) 207 | logger.info(f"Response: {json_formatted_str}") 208 | 209 | 210 | if __name__ == "__main__": 211 | asyncio.run(sample_product_query_ranking_request()) 212 | -------------------------------------------------------------------------------- /examples/feature_store_main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from typing import Dict 4 | 5 | import typer 6 | from pydantic import BaseModel 7 | 8 | from wyvern import Identifier 9 | from wyvern.components.api_route_component import APIRouteComponent 10 | from wyvern.components.features.feature_store import ( 11 | FeatureStoreRetrievalRequest, 12 | feature_store_retrieval_component, 13 | ) 14 | from wyvern.entities.feature_entities import FeatureData 15 | from wyvern.service import WyvernService 16 | 17 | wyvern_cli_app = typer.Typer() 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class FeatureStoreResponse(BaseModel): 23 | feature_data: Dict[str, FeatureData] 24 | 25 | 26 | class FeatureStoreTestingComponent( 27 | APIRouteComponent[FeatureStoreRetrievalRequest, FeatureStoreResponse], 28 | ): 29 | PATH = "/feature-store-testing" 30 | REQUEST_SCHEMA_CLASS = FeatureStoreRetrievalRequest 31 | 
RESPONSE_SCHEMA_CLASS = FeatureStoreResponse 32 | 33 | async def execute( 34 | self, input: FeatureStoreRetrievalRequest, **kwargs 35 | ) -> FeatureStoreResponse: 36 | logger.info(f"Executing input {input}") 37 | feature_df = await feature_store_retrieval_component.execute(input) 38 | feature_dicts = feature_df.df.to_dicts() 39 | feature_data: Dict[str, FeatureData] = { 40 | str(feature_dict["IDENTIFIER"]): FeatureData( 41 | identifier=Identifier( 42 | identifier_type=feature_dict["IDENTIFIER"].split("::")[0], 43 | identifier=feature_dict["IDENTIFIER"].split("::")[1], 44 | ), 45 | features={ 46 | feature_name: feature_value 47 | for feature_name, feature_value in feature_dict.items() 48 | if feature_name != "IDENTIFIER" 49 | }, 50 | ) 51 | for feature_dict in feature_dicts 52 | } 53 | return FeatureStoreResponse( 54 | feature_data=feature_data, 55 | ) 56 | 57 | 58 | @wyvern_cli_app.command() 59 | def run( 60 | host: str = "127.0.0.1", 61 | port: int = 8000, 62 | ) -> None: 63 | """ 64 | Run your wyvern service 65 | """ 66 | WyvernService.run( 67 | route_components=[FeatureStoreTestingComponent], 68 | host=host, 69 | port=port, 70 | ) 71 | 72 | 73 | if __name__ == "__main__": 74 | # TODO (suchintan): Add support for hot swapping code here 75 | wyvern_cli_app() 76 | -------------------------------------------------------------------------------- /log_config.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | disable_existing_loggers: False 3 | formatters: 4 | timestamped: 5 | format: "%(levelname)s: [%(asctime)s] [%(name)s] %(message)s" 6 | handlers: 7 | console: 8 | class: logging.StreamHandler 9 | formatter: timestamped 10 | level: INFO 11 | stream: ext://sys.stdout 12 | root: 13 | level: INFO 14 | handlers: [console] 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 
[tool.poetry] 2 | name = "wyvern-ai" 3 | version = "0.0.30" 4 | description = "" 5 | authors = ["Wyvern AI "] 6 | readme = "README.md" 7 | packages = [ 8 | { include = "wyvern" }, 9 | ] 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.8,<3.12" 13 | pydantic = "^1.10.4" 14 | fastapi = "^0.95.2" 15 | uvicorn = "^0.22.0" 16 | typer = {extras = ["all"], version = "^0.9.0"} 17 | pyyaml = "^6.0" 18 | pyhumps = "^3.8.0" 19 | python-dotenv = "^1.0.0" 20 | pandas = "^1.5,<2.0" 21 | feast = {version = "^0.34.1", extras = ["redis", "snowflake"]} 22 | snowflake-connector-python = "^3.1" 23 | boto3 = "^1.26.146" 24 | ddtrace = "^1.14.0" 25 | msgspec = "^0.16.0" 26 | lz4 = "^4.3.2" 27 | more-itertools = "^9.1.0" 28 | tqdm = "^4.65.0" 29 | nest-asyncio = "^1.5.7" 30 | eppo-server-sdk = "^1.2.3" 31 | scipy = "^1.10.1" 32 | aiohttp = {extras = ["speedups"], version = "^3.8.5"} 33 | requests = "^2.31.0" 34 | platformdirs = "^3.8" 35 | posthog = "^3.0.2" 36 | polars = "^0.19.6" 37 | 38 | 39 | [tool.poetry.group.dev.dependencies] 40 | ipython = "^8.9.0" 41 | pytest = "^7.2.1" 42 | isort = "^5.12.0" 43 | types-pyyaml = "^6.0.12.6" 44 | black = "^22.6.0" 45 | pip-tools = "^6.12.2" 46 | twine = "^4.0.2" 47 | pytest-asyncio = "^0.21.0" 48 | pytest-mock = "^3.10.0" 49 | types-boto3 = "^1.0.2" 50 | pyinstrument = "^4.4.0" 51 | pytest-dotenv = "^0.5.2" 52 | ipykernel = "^6.25.0" 53 | aioresponses = "^0.7.4" 54 | 55 | [build-system] 56 | requires = ["poetry-core"] 57 | build-backend = "poetry.core.masonry.api" 58 | 59 | [[tool.mypy.overrides]] 60 | module=[ 61 | "scipy.spatial.distance.*", 62 | "setuptools.*", 63 | "ddtrace.*", 64 | "nest_asyncio.*", 65 | "lz4.*", 66 | "posthog.*", 67 | ] 68 | ignore_missing_imports = true 69 | 70 | [tool.isort] 71 | profile = "black" 72 | 73 | [tool.pytest.ini_options] 74 | addopts = "-v" 75 | filterwarnings = ["ignore::DeprecationWarning"] 76 | log_cli = true 77 | log_cli_level = "INFO" 78 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s 
(%(filename)s:%(lineno)s)" 79 | log_cli_date_format = "%Y-%m-%d %H:%M:%S" 80 | env_files = [ 81 | ".testing_env", 82 | ] 83 | 84 | [tool.poetry.scripts] 85 | wyvern = "wyvern.cli.commands:app" 86 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | # black will introduce W503 which is not PEP8 compatible right now. However, https://peps.python.org/pep-0008/#should-a-line-break-before-or-after-a-binary-operator 4 | # Ignore E203 due to https://github.com/psf/black/issues/315 5 | ignore = N805,N802,B008,W503,E203 6 | extend-immutable-calls = fastapi.Depends, fastapi.params.Depends 7 | 8 | [metadata] 9 | description-file = README.md 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/__init__.py -------------------------------------------------------------------------------- /tests/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/components/__init__.py -------------------------------------------------------------------------------- /tests/components/business_logic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/components/business_logic/__init__.py -------------------------------------------------------------------------------- /tests/components/business_logic/test_pinning_business_logic.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | from collections import defaultdict 3 | from typing import Dict, List 4 | 5 | import pytest 6 | 7 | from wyvern import request_context 8 | from wyvern.components.business_logic.business_logic import ( 9 | BusinessLogicPipeline, 10 | BusinessLogicRequest, 11 | BusinessLogicResponse, 12 | ) 13 | from wyvern.components.business_logic.pinning_business_logic import ( 14 | PinningBusinessLogicComponent, 15 | ) 16 | from wyvern.entities.candidate_entities import ScoredCandidate 17 | from wyvern.entities.feature_entities import FeatureDataFrame 18 | from wyvern.entities.identifier_entities import ProductEntity 19 | from wyvern.entities.request import BaseWyvernRequest 20 | from wyvern.wyvern_request import WyvernRequest 21 | 22 | 23 | async def set_up_pinning_components( 24 | scored_candidates: List[ScoredCandidate[ProductEntity]], 25 | entity_pins: Dict[str, int], 26 | allow_down_ranking=True, 27 | ) -> BusinessLogicResponse[ProductEntity, BaseWyvernRequest]: 28 | class TestPins(PinningBusinessLogicComponent[ProductEntity, BaseWyvernRequest]): 29 | async def execute( 30 | self, 31 | input: BusinessLogicRequest[ProductEntity, BaseWyvernRequest], 32 | **kwargs, 33 | ) -> List[ScoredCandidate[ProductEntity]]: 34 | return self.pin( 35 | input.scored_candidates, 36 | entity_pins=entity_pins, 37 | allow_down_ranking=allow_down_ranking, 38 | ) 39 | 40 | class TestBusinessLogicPipeline( 41 | BusinessLogicPipeline[ProductEntity, BaseWyvernRequest], 42 | ): 43 | def __init__(self): 44 | """ 45 | Add new business logic components here. All business logic steps are executed in the order defined here. 
46 | """ 47 | super().__init__( 48 | TestPins(), 49 | name="test_business_logic_pipeline", 50 | ) 51 | 52 | pipeline = TestBusinessLogicPipeline() 53 | await pipeline.initialize() 54 | 55 | request = BusinessLogicRequest[ProductEntity, BaseWyvernRequest]( 56 | request=BaseWyvernRequest(request_id="123"), 57 | scored_candidates=scored_candidates, 58 | ) 59 | 60 | request_context.set( 61 | WyvernRequest( 62 | method="POST", 63 | url="TestTest", 64 | url_path="Test", 65 | json=request, 66 | headers={}, 67 | entity_store={}, 68 | events=[], 69 | feature_df=FeatureDataFrame(), 70 | feature_orig_identifiers=defaultdict(dict), 71 | model_output_map={}, 72 | ), 73 | ) 74 | return await pipeline.execute(request) 75 | 76 | 77 | def generate_scored_candidates(id_score_pairs: Dict[str, float]): 78 | return [ 79 | ScoredCandidate(entity=ProductEntity(product_id=id), score=score) 80 | for id, score in id_score_pairs.items() 81 | ] 82 | 83 | 84 | @pytest.mark.asyncio 85 | async def test_pins(): 86 | scored_candidates = generate_scored_candidates( 87 | { 88 | "product_1": 6, 89 | "product_2": 5, 90 | "product_3": 4, 91 | "product_4": 3, 92 | "product_5": 2, 93 | "product_6": 1, 94 | }, 95 | ) 96 | 97 | pins = { 98 | "product_6": 11, 99 | "product_5": 10, 100 | "product_3": 0, 101 | "product_2": 0, 102 | "product_4": 2, 103 | } 104 | 105 | result = await set_up_pinning_components(scored_candidates, pins) 106 | 107 | adjusted_candidates = [ 108 | candidate.entity.product_id for candidate in result.adjusted_candidates 109 | ] 110 | # Bug -- product_4 is coming in at index 3, not index 2 like requested.. 
due to the other boosts 111 | expected_order = [ 112 | "product_2", 113 | "product_3", 114 | "product_1", 115 | "product_4", 116 | "product_5", 117 | "product_6", 118 | ] 119 | assert adjusted_candidates == expected_order 120 | 121 | 122 | @pytest.mark.asyncio 123 | async def test_pins__no_down_ranking(): 124 | scored_candidates = generate_scored_candidates( 125 | { 126 | "product_1": 6, 127 | "product_2": 5, 128 | "product_3": 4, 129 | "product_4": 3, 130 | "product_5": 2, 131 | "product_6": 1, 132 | }, 133 | ) 134 | 135 | pins = { 136 | "product_6": 11, 137 | "product_5": 12, 138 | "product_3": 0, 139 | "product_2": 22, 140 | "product_4": 2, 141 | } 142 | 143 | result = await set_up_pinning_components( 144 | scored_candidates, 145 | pins, 146 | allow_down_ranking=False, 147 | ) 148 | 149 | adjusted_candidates = [ 150 | candidate.entity.product_id for candidate in result.adjusted_candidates 151 | ] 152 | # Bug -- product_4 is coming in at index 3, not index 2 like requested.. due to the other boosts 153 | expected_order = [ 154 | "product_3", 155 | "product_1", 156 | "product_2", 157 | "product_4", 158 | "product_5", 159 | "product_6", 160 | ] 161 | assert adjusted_candidates == expected_order 162 | 163 | 164 | """ 165 | TODO (suchintan): 166 | Test cases: 167 | 1. Pin any product 168 | 2. Pin multiple products in different order 169 | 3. Allow down ranking = false and true 170 | 4. Pin a product that is not in the list 171 | 5. Pin a product that is in the list but not in the top 10 172 | 6. Pin multiple products to the same position 173 | 7. Pin to the top of the list 174 | 8. 
Pin to the bottom of the list 175 | """ 176 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | from ddtrace import tracer 4 | 5 | 6 | @pytest.fixture(scope="session", autouse=True) 7 | def disable_ddtrace(): 8 | tracer.enabled = False 9 | -------------------------------------------------------------------------------- /tests/feature_store/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/feature_store/__init__.py -------------------------------------------------------------------------------- /tests/load_test/load_test.js: -------------------------------------------------------------------------------- 1 | // Install k6 `brew install k6` 2 | import http from "k6/http"; 3 | import { sleep, check } from "k6"; 4 | 5 | export const options = { 6 | vus: 1, 7 | duration: "5s", 8 | }; 9 | 10 | export default function () { 11 | const api_key = ""; 12 | 13 | const params = { 14 | headers: { 15 | "Content-Type": "application/json", 16 | "x-api-key": api_key, 17 | }, 18 | }; 19 | 20 | const candidates = Array.from({ length: 1000 }, (_, i) => ({ 21 | product_id: i + Math.floor(Math.random() * 10000), 22 | opensearch_score: 1000 - i, 23 | })); 24 | 25 | const payload = JSON.stringify({ 26 | request_id: "test_request_id", 27 | query: { query: "candle" }, 28 | candidates: candidates, 29 | user: { user_id: "1234", user_name: "user_name" }, 30 | user_page_size: 10, 31 | user_page: 20, 32 | candidate_page_size: 1000, 33 | candidate_page: 0, 34 | }); 35 | 36 | const r = http.post( 37 | "https://api.wyvern.ai/api/v1/product-search-ranking", 38 | payload, 39 | params 40 | ); 41 | check(r, { "status was 200": (r) => r.status == 200 }); 42 | console.log(r.body); 43 | 
sleep(1); 44 | } 45 | /* 46 | { 47 | "request_id": "test_request_id", 48 | "query": {"query": "candle"}, 49 | "candidates": [ 50 | {"product_id": "0", "opensearch_score": 1000}, 51 | ], 52 | "user": {"user_id": "1234", "user_name": "user_name"}, 53 | "user_page_size": 10, 54 | "user_page": 20, 55 | "candidate_page_size": 1000, 56 | "candidate_page": 0 57 | } 58 | */ 59 | -------------------------------------------------------------------------------- /tests/scenarios/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/scenarios/__init__.py -------------------------------------------------------------------------------- /tests/scenarios/single_entity_pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/tests/scenarios/single_entity_pipelines/__init__.py -------------------------------------------------------------------------------- /tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | 4 | import pytest 5 | from fastapi.testclient import TestClient 6 | 7 | from wyvern.components.business_logic.business_logic import ( 8 | SingleEntityBusinessLogicComponent, 9 | SingleEntityBusinessLogicPipeline, 10 | SingleEntityBusinessLogicRequest, 11 | ) 12 | from wyvern.components.models.model_chain_component import SingleEntityModelChain 13 | from wyvern.components.models.model_component import SingleEntityModelComponent 14 | from wyvern.components.single_entity_pipeline import ( 15 | SingleEntityPipeline, 16 | SingleEntityPipelineResponse, 17 | ) 18 | from wyvern.entities.identifier import Identifier 19 | from 
wyvern.entities.identifier_entities import WyvernEntity 20 | from wyvern.entities.model_entities import ModelOutput 21 | from wyvern.entities.request import BaseWyvernRequest 22 | from wyvern.service import WyvernService 23 | 24 | 25 | class Seller(WyvernEntity): 26 | seller_id: str 27 | 28 | def generate_identifier(self) -> Identifier: 29 | return Identifier( 30 | identifier=self.seller_id, 31 | identifier_type="seller", 32 | ) 33 | 34 | 35 | class Buyer(WyvernEntity): 36 | buyer_id: str 37 | 38 | def generate_identifier(self) -> Identifier: 39 | return Identifier( 40 | identifier=self.buyer_id, 41 | identifier_type="buyer", 42 | ) 43 | 44 | 45 | class Order(WyvernEntity): 46 | order_id: str 47 | 48 | def generate_identifier(self) -> Identifier: 49 | return Identifier( 50 | identifier=self.order_id, 51 | identifier_type="order", 52 | ) 53 | 54 | 55 | class FraudRequest(BaseWyvernRequest): 56 | seller: Seller 57 | buyer: Buyer 58 | order: Order 59 | 60 | 61 | class FraudResponse(SingleEntityPipelineResponse[float]): 62 | reasons: List[str] 63 | 64 | 65 | class FraudRuleModel(SingleEntityModelComponent[FraudRequest, ModelOutput[float]]): 66 | async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: 67 | return ModelOutput( 68 | data={ 69 | input.order.identifier: 1, 70 | }, 71 | ) 72 | 73 | 74 | class FraudAssessmentModel( 75 | SingleEntityModelComponent[FraudRequest, ModelOutput[float]], 76 | ): 77 | async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: 78 | return ModelOutput( 79 | data={ 80 | input.order.identifier: 1, 81 | }, 82 | ) 83 | 84 | 85 | fraud_model = SingleEntityModelChain[FraudRequest, ModelOutput[float]]( 86 | FraudRuleModel(), 87 | FraudAssessmentModel(), 88 | name="fraud_model", 89 | ) 90 | 91 | 92 | class FraudBusinessLogicComponent( 93 | SingleEntityBusinessLogicComponent[FraudRequest, float], 94 | ): 95 | async def execute( 96 | self, 97 | input: SingleEntityBusinessLogicRequest[FraudRequest, 
float], 98 | **kwargs, 99 | ) -> float: 100 | if input.request.seller.identifier.identifier == "test_seller_new": 101 | return 0.0 102 | return input.model_output 103 | 104 | 105 | fraud_biz_pipeline = SingleEntityBusinessLogicPipeline( 106 | FraudBusinessLogicComponent(), 107 | name="fraud_biz_pipeline", 108 | ) 109 | 110 | 111 | class FraudPipeline(SingleEntityPipeline[FraudRequest, float]): 112 | PATH = "/fraud" 113 | REQUEST_SCHEMA_CLASS = FraudRequest 114 | RESPONSE_SCHEMA_CLASS = FraudResponse 115 | 116 | def generate_response( 117 | self, 118 | input: FraudRequest, 119 | pipeline_output: float, 120 | ) -> FraudResponse: 121 | if pipeline_output == 0.0: 122 | return FraudResponse( 123 | data=pipeline_output, 124 | reasons=["Fraudulent order detected!"], 125 | ) 126 | return FraudResponse( 127 | data=pipeline_output, 128 | reasons=[], 129 | ) 130 | 131 | 132 | fraud_pipeline = FraudPipeline(model=fraud_model, business_logic=fraud_biz_pipeline) 133 | 134 | 135 | @pytest.fixture 136 | def mock_redis(mocker): 137 | with mocker.patch( 138 | "wyvern.redis.wyvern_redis.mget", 139 | return_value=[], 140 | ): 141 | yield 142 | 143 | 144 | @pytest.fixture 145 | def test_client(mock_redis): 146 | wyvern_app = WyvernService.generate_app( 147 | route_components=[fraud_pipeline], 148 | ) 149 | yield TestClient(wyvern_app) 150 | 151 | 152 | def test_end_to_end(test_client): 153 | response = test_client.post( 154 | "/api/v1/fraud", 155 | json={ 156 | "request_id": "test_request_id", 157 | "seller": {"seller_id": "test_seller_id"}, 158 | "buyer": {"buyer_id": "test_buyer_id"}, 159 | "order": {"order_id": "test_order_id"}, 160 | }, 161 | ) 162 | assert response.status_code == 200 163 | assert response.json() == {"data": 1.0, "reasons": []} 164 | 165 | 166 | def test_end_to_end__new_seller(test_client): 167 | response = test_client.post( 168 | "/api/v1/fraud", 169 | json={ 170 | "request_id": "test_request_id", 171 | "seller": {"seller_id": "test_seller_new"}, 172 | "buyer": 
{"buyer_id": "test_buyer_id"}, 173 | "order": {"order_id": "test_order_id"}, 174 | }, 175 | ) 176 | assert response.status_code == 200 177 | assert response.json() == {"data": 0.0, "reasons": ["Fraudulent order detected!"]} 178 | -------------------------------------------------------------------------------- /tests/scenarios/test_indexation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | from fastapi.testclient import TestClient 4 | 5 | from wyvern.service import WyvernService 6 | 7 | PRODUCT_ENTITY_1 = { 8 | "product_id": "1", 9 | "product_name": "test_product1", 10 | "product_description": "test_product1_description", 11 | } 12 | PRODUCT_ENTITY_2 = { 13 | "product_id": "2", 14 | "product_name": "test_product2", 15 | "product_description": "test_product2_description", 16 | } 17 | 18 | PRODUCT_ENTITY_1_WITH_ID = { 19 | "id": "1", 20 | "product_name": "test_product1", 21 | "product_description": "test_product1_description", 22 | } 23 | PRODUCT_ENTITY_2_WITH_ID = { 24 | "id": "2", 25 | "product_name": "test_product2", 26 | "product_description": "test_product2_description", 27 | } 28 | 29 | 30 | @pytest.fixture 31 | def mock_redis(mocker): 32 | with mocker.patch( 33 | "wyvern.redis.wyvern_redis.bulk_index", 34 | return_value=["1", "2"], 35 | ), mocker.patch( 36 | "wyvern.redis.wyvern_redis.get_entity", 37 | return_value=PRODUCT_ENTITY_1, 38 | ), mocker.patch( 39 | "wyvern.redis.wyvern_redis.get_entities", 40 | return_value=[ 41 | PRODUCT_ENTITY_1, 42 | PRODUCT_ENTITY_2, 43 | ], 44 | ), mocker.patch( 45 | "wyvern.redis.wyvern_redis.delete_entity", 46 | ), mocker.patch( 47 | "wyvern.redis.wyvern_redis.delete_entities", 48 | ): 49 | yield 50 | 51 | 52 | @pytest.fixture 53 | def test_client(mock_redis): 54 | wyvern_service = WyvernService.generate() 55 | yield TestClient(wyvern_service.service.app) 56 | 57 | 58 | @pytest.mark.asyncio 59 | async def test_product_upload(test_client): 60 
| response = test_client.post( 61 | "/api/v1/entities/upload", 62 | json={ 63 | "entities": [ 64 | PRODUCT_ENTITY_1, 65 | PRODUCT_ENTITY_2, 66 | ], 67 | "entity_type": "product", 68 | }, 69 | ) 70 | assert response.status_code == 200 71 | assert response.json() == { 72 | "entity_type": "product", 73 | "entity_ids": ["1", "2"], 74 | } 75 | 76 | 77 | @pytest.mark.asyncio 78 | async def test_product_upload__with_different_entity_key(test_client): 79 | response = test_client.post( 80 | "/api/v1/entities/upload", 81 | json={ 82 | "entities": [ 83 | PRODUCT_ENTITY_1_WITH_ID, 84 | PRODUCT_ENTITY_2_WITH_ID, 85 | ], 86 | "entity_type": "product", 87 | "entity_key": "id", 88 | }, 89 | ) 90 | assert response.status_code == 200 91 | assert response.json() == { 92 | "entity_type": "product", 93 | "entity_ids": ["1", "2"], 94 | } 95 | 96 | 97 | @pytest.mark.asyncio 98 | async def test_get_products(test_client): 99 | response = test_client.post( 100 | "/api/v1/entities/get", 101 | json={ 102 | "entity_ids": ["1", "2"], 103 | "entity_type": "product", 104 | }, 105 | ) 106 | assert response.status_code == 200 107 | assert response.json() == { 108 | "entity_type": "product", 109 | "entities": { 110 | "1": PRODUCT_ENTITY_1, 111 | "2": PRODUCT_ENTITY_2, 112 | }, 113 | } 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_delete_products(test_client): 118 | response = test_client.post( 119 | "/api/v1/entities/delete", 120 | json={ 121 | "entity_ids": ["1", "2"], 122 | "entity_type": "product", 123 | }, 124 | ) 125 | assert response.status_code == 200 126 | assert response.json() == { 127 | "entity_type": "product", 128 | "entity_ids": ["1", "2"], 129 | } 130 | -------------------------------------------------------------------------------- /wyvern/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from wyvern.components.features.realtime_features_component import ( 3 | RealtimeFeatureComponent, 4 | ) 5 | 
from wyvern.components.models.model_chain_component import SingleEntityModelChain 6 | from wyvern.components.models.model_component import ( 7 | ModelComponent, 8 | MultiEntityModelComponent, 9 | SingleEntityModelComponent, 10 | ) 11 | from wyvern.components.pipeline_component import PipelineComponent 12 | from wyvern.components.ranking_pipeline import ( 13 | RankingPipeline, 14 | RankingRequest, 15 | RankingResponse, 16 | ) 17 | from wyvern.components.single_entity_pipeline import ( 18 | SingleEntityPipeline, 19 | SingleEntityPipelineResponse, 20 | ) 21 | from wyvern.entities.candidate_entities import CandidateSetEntity 22 | from wyvern.entities.identifier import CompositeIdentifier, Identifier, IdentifierType 23 | from wyvern.entities.identifier_entities import ( 24 | ProductEntity, 25 | QueryEntity, 26 | UserEntity, 27 | WyvernDataModel, 28 | WyvernEntity, 29 | ) 30 | from wyvern.entities.model_entities import ChainedModelInput, ModelInput, ModelOutput 31 | from wyvern.feature_store.feature_server import generate_wyvern_store_app 32 | from wyvern.service import WyvernService 33 | from wyvern.wyvern_logging import setup_logging 34 | from wyvern.wyvern_tracing import setup_tracing 35 | from wyvern.wyvern_typing import WyvernFeature 36 | 37 | setup_logging() 38 | setup_tracing() 39 | 40 | 41 | __all__ = [ 42 | "generate_wyvern_store_app", 43 | "CandidateSetEntity", 44 | "ChainedModelInput", 45 | "CompositeIdentifier", 46 | "Identifier", 47 | "IdentifierType", 48 | "ModelComponent", 49 | "ModelInput", 50 | "ModelOutput", 51 | "MultiEntityModelComponent", 52 | "PipelineComponent", 53 | "ProductEntity", 54 | "QueryEntity", 55 | "RankingPipeline", 56 | "RankingResponse", 57 | "RankingRequest", 58 | "RealtimeFeatureComponent", 59 | "SingleEntityModelChain", 60 | "SingleEntityModelComponent", 61 | "SingleEntityPipeline", 62 | "SingleEntityPipelineResponse", 63 | "UserEntity", 64 | "WyvernDataModel", 65 | "WyvernEntity", 66 | "WyvernFeature", 67 | "WyvernService", 68 | ] 
69 | -------------------------------------------------------------------------------- /wyvern/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/aws/__init__.py -------------------------------------------------------------------------------- /wyvern/aws/kinesis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import traceback 4 | from enum import Enum 5 | from typing import Callable, List 6 | 7 | import boto3 8 | from ddtrace import tracer 9 | from pydantic import BaseModel 10 | 11 | from wyvern.config import settings 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | CHUNK_SIZE = 100 16 | 17 | 18 | class KinesisFirehoseStream(str, Enum): 19 | """ 20 | Enum for Kinesis Firehose stream names 21 | 22 | Usage: 23 | ``` 24 | >>> KinesisFirehoseStream.EVENT_STREAM.get_stream_name() 25 | ``` 26 | """ 27 | 28 | EVENT_STREAM = "event-stream" 29 | 30 | def get_stream_name( 31 | self, 32 | customer_specific: bool = True, 33 | env_specific: bool = True, 34 | ) -> str: 35 | """ 36 | Returns the stream name for the given stream 37 | 38 | Args: 39 | customer_specific: Whether the stream name should be customer specific 40 | env_specific: Whether the stream name should be environment specific 41 | 42 | Returns: 43 | The stream name 44 | """ 45 | stream_name = self.value 46 | if customer_specific: 47 | stream_name = f"{settings.PROJECT_NAME}-{stream_name}" 48 | 49 | if env_specific: 50 | env_name = settings.ENVIRONMENT 51 | stream_name = f"{stream_name}-{env_name}" 52 | 53 | return stream_name 54 | 55 | 56 | class WyvernKinesisFirehose: 57 | """ 58 | Wrapper around boto3 Kinesis Firehose client 59 | """ 60 | 61 | def __init__(self): 62 | self.firehose_client = boto3.client( 63 | "firehose", 64 | aws_access_key_id=settings.AWS_ACCESS_KEY_ID, 65 | 
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, 66 | region_name=settings.AWS_REGION_NAME, 67 | ) 68 | 69 | def put_record_batch_callable( 70 | self, 71 | stream_name: KinesisFirehoseStream, 72 | record_generator: List[Callable[[], List[BaseModel]]], 73 | ): 74 | """ 75 | Puts records to the given stream. This is a callable that can be used with FastAPI's BackgroundTasks. This 76 | way events can be logged asynchronously after the response is sent to the client. 77 | 78 | Args: 79 | stream_name (KinesisFirehoseStream): The stream to put records to 80 | record_generator (List[Callable[[], List[BaseModel]]]): A list of functions that return a list of records 81 | 82 | Returns: 83 | None 84 | """ 85 | with tracer.trace("flush_records_to_kinesis_firehose"): 86 | records = [ 87 | record 88 | for record_generator in record_generator 89 | for record in record_generator() 90 | ] 91 | self.put_record_batch(stream_name, records) 92 | 93 | def put_record_batch( 94 | self, 95 | stream_name: KinesisFirehoseStream, 96 | records: List[BaseModel], 97 | ): 98 | """ 99 | Puts records to the given stream 100 | 101 | Args: 102 | stream_name (KinesisFirehoseStream): The stream to put records to 103 | records (List[BaseModel]): A list of records 104 | 105 | Returns: 106 | None 107 | """ 108 | if not records: 109 | return 110 | dict_records = [{"Data": record.json()} for record in records] 111 | 112 | record_chunks = [ 113 | dict_records[i : (i + CHUNK_SIZE)] 114 | for i in range(0, len(dict_records), CHUNK_SIZE) 115 | ] 116 | for chunk in record_chunks: 117 | if settings.EVENT_LOGGING_ENABLED and settings.ENVIRONMENT != "development": 118 | try: 119 | self.firehose_client.put_record_batch( 120 | DeliveryStreamName=stream_name.get_stream_name(), 121 | Records=chunk, 122 | ) 123 | except Exception: 124 | logger.exception( 125 | "Failed to put records to kinesis firehose", 126 | traceback.format_exc(), 127 | ) 128 | else: 129 | logger.debug( 130 | "Logging disabled. 
Not sending records to Kinesis Firehose. Records: {chunk}", 131 | ) 132 | 133 | 134 | wyvern_kinesis_firehose = WyvernKinesisFirehose() 135 | -------------------------------------------------------------------------------- /wyvern/cli/commands.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import getpass 3 | import importlib 4 | import os 5 | import platform 6 | import shutil 7 | import subprocess 8 | import sys 9 | import zipfile 10 | from pathlib import Path 11 | from typing import Optional 12 | 13 | import requests 14 | import typer 15 | import uvicorn 16 | from humps.main import decamelize 17 | from platformdirs import user_data_dir 18 | from typing_extensions import Annotated 19 | 20 | from wyvern import tracking 21 | 22 | REDIS_VERSION = "redis-6.2.9" 23 | app = typer.Typer() 24 | WYVERN_TEMPLATE_URL = "https://codeload.github.com/Wyvern-AI/wyvern-starter/zip/main" 25 | REDIS_URL = f"https://download.redis.io/releases/{REDIS_VERSION}.tar.gz" 26 | REDIS_DIR = os.path.join(user_data_dir("redis_cli"), "redis") 27 | REDIS_BIN = os.path.join(REDIS_DIR, REDIS_VERSION, "src", "redis-server") 28 | 29 | 30 | def _replace_info(project: str, author: Optional[str] = None): 31 | toml_file_path = os.path.join(project, "pyproject.toml") 32 | if os.path.isfile(toml_file_path): 33 | author = author or getpass.getuser() or "" 34 | with open(toml_file_path, "r") as file: 35 | file_content = file.read() 36 | file_content = file_content.replace("wyvern-starter", project) 37 | file_content = file_content.replace('authors = [""]', f'authors = ["{author}"]') 38 | with open(toml_file_path, "w") as file: 39 | file.write(file_content) 40 | 41 | 42 | def is_redis_installed(): 43 | return os.path.exists(REDIS_BIN) 44 | 45 | 46 | def is_redis_running(): 47 | if platform.system().lower() == "windows": 48 | try: 49 | subprocess.run( 50 | ["tasklist"], 51 | stdout=subprocess.PIPE, 52 | stderr=subprocess.PIPE, 53 | 
check=True, 54 | ) 55 | except subprocess.CalledProcessError: 56 | return False 57 | return ( 58 | "redis-server.exe" 59 | in subprocess.run(["tasklist"], stdout=subprocess.PIPE).stdout.decode() 60 | ) 61 | 62 | else: # For Unix-like systems (Linux, macOS) 63 | try: 64 | subprocess.run( 65 | ["ps", "-A"], 66 | stdout=subprocess.PIPE, 67 | stderr=subprocess.PIPE, 68 | check=True, 69 | ) 70 | except subprocess.CalledProcessError: 71 | return False 72 | return ( 73 | "redis-server" 74 | in subprocess.run(["ps", "-A"], stdout=subprocess.PIPE).stdout.decode() 75 | ) 76 | 77 | 78 | def try_install_redis(): 79 | if is_redis_installed(): 80 | typer.echo("Redis is already installed.") 81 | return 82 | 83 | typer.echo(f"Installing Redis in {REDIS_DIR}...") 84 | os.makedirs(REDIS_DIR, exist_ok=True) 85 | subprocess.run( 86 | ["curl", "-L", REDIS_URL, "-o", os.path.join(REDIS_DIR, "redis.tar.gz")], 87 | ) 88 | subprocess.run( 89 | ["tar", "xzvf", os.path.join(REDIS_DIR, "redis.tar.gz")], 90 | cwd=REDIS_DIR, 91 | ) 92 | subprocess.run(["make"], cwd=os.path.join(REDIS_DIR, REDIS_VERSION)) 93 | shutil.move( 94 | os.path.join(REDIS_DIR, REDIS_VERSION, "src", "redis-server"), 95 | REDIS_BIN, 96 | ) 97 | 98 | 99 | @app.command() 100 | def init( 101 | project: str = typer.Argument(..., help="Name of the project"), 102 | ) -> None: 103 | """ 104 | Initializes Wyvern application template code 105 | 106 | Args: 107 | project (str): Name of the project 108 | """ 109 | 110 | # decamelize project name first 111 | project = decamelize(project) 112 | 113 | tracking.capture(event="oss_init_start") 114 | typer.echo("Initializing Wyvern application template code...") 115 | 116 | # validate project name 117 | if "/" in project: 118 | typer.echo("Error: Invalid project name. 
Project name cannot contain '/'") 119 | return 120 | 121 | if Path(project).exists(): 122 | typer.echo(f"Error: Destination path '{project}' already exists.") 123 | return 124 | 125 | response = requests.get(WYVERN_TEMPLATE_URL) 126 | 127 | if response.status_code != 200: 128 | typer.echo(f"Error: Unable to download code from {WYVERN_TEMPLATE_URL}") 129 | return 130 | 131 | with open("temp.zip", "wb") as temp_zip: 132 | temp_zip.write(response.content) 133 | 134 | with zipfile.ZipFile("temp.zip", "r") as zip_ref: 135 | zip_ref.extractall(project) 136 | 137 | os.remove("temp.zip") 138 | # Flatten the extracted content into the destination directory 139 | extracted_dir = os.path.join(project, os.listdir(project)[0]) 140 | for item in os.listdir(extracted_dir): 141 | item_path = os.path.join(extracted_dir, item) 142 | if os.path.isfile(item_path) or os.path.isdir(item_path): 143 | shutil.move(item_path, os.path.join(project, item)) 144 | shutil.rmtree(extracted_dir) 145 | 146 | # add a .env file to the new repository with ENVIRONMENT=development 147 | with open(os.path.join(project, ".env"), "w") as env_file: 148 | env_file.write("ENVIRONMENT=development\n") 149 | 150 | tracking.capture(event="oss_init_succeed") 151 | typer.echo( 152 | f"Successfully initialized Wyvern application template code in {project}", 153 | ) 154 | 155 | 156 | @app.command() 157 | def run( 158 | path: str = "pipelines.main:app", 159 | host: Annotated[ 160 | str, 161 | typer.Option(help="Host to run the application on"), 162 | ] = "0.0.0.0", 163 | port: Annotated[ 164 | int, 165 | typer.Option(help="Port to run the application on. Default port is 5001"), 166 | ] = 5001, 167 | ) -> None: 168 | """ 169 | Starts Wyvern application server 170 | 171 | Example usage: 172 | wyvern run --path pipelines.main:app --host 0.0.0.0 --port 5001 173 | 174 | Args: 175 | path (str): path to the wyvern app. Default path is pipelines.main:app 176 | host (str): Host to run the application on. 
Default host is 0.0.0.0 177 | port (int): Port to run the application on. Default port is 5001 178 | """ 179 | tracking.capture(event="oss_run_start") 180 | typer.echo("Running your ML application") 181 | # import the app from path 182 | try: 183 | sys.path.append(".") 184 | module_path, app_name = path.split(":") 185 | module = importlib.import_module(module_path) 186 | except ImportError: 187 | tracking.capture(event="oss_run_failed_import") 188 | typer.echo(f"Failed to import {path}") 189 | raise 190 | fastapi_app = getattr(module, app_name) 191 | config = uvicorn.Config( 192 | fastapi_app, 193 | host=host, 194 | port=port, 195 | ) 196 | uvicorn_server = uvicorn.Server(config=config) 197 | tracking.capture(event="oss_run_succeed") 198 | uvicorn_server.run() 199 | 200 | 201 | @app.command() 202 | def redis() -> None: 203 | """Starts Redis server. This command will also install redis locally if it's not installed.""" 204 | tracking.capture(event="oss_redis_start") 205 | try_install_redis() 206 | 207 | if is_redis_running(): 208 | typer.echo("Redis is already running.") 209 | return 210 | 211 | typer.echo("Starting Redis...") 212 | tracking.capture(event="oss_redis_succeed") 213 | subprocess.run([REDIS_BIN]) 214 | -------------------------------------------------------------------------------- /wyvern/clients/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/clients/__init__.py -------------------------------------------------------------------------------- /wyvern/clients/snowflake.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import snowflake.connector 3 | 4 | from wyvern.config import settings 5 | 6 | 7 | def generate_snowflake_ctx() -> snowflake.connector.SnowflakeConnection: 8 | """ 9 | Generate a Snowflake context from the settings 10 | """ 11 | 
return snowflake.connector.connect( 12 | user=settings.SNOWFLAKE_USER, 13 | password=settings.SNOWFLAKE_PASSWORD, 14 | role=settings.SNOWFLAKE_ROLE, 15 | account=settings.SNOWFLAKE_ACCOUNT, 16 | warehouse=settings.SNOWFLAKE_WAREHOUSE, 17 | database=settings.SNOWFLAKE_DATABASE, 18 | schema=settings.SNOWFLAKE_OFFLINE_STORE_SCHEMA, 19 | ) 20 | -------------------------------------------------------------------------------- /wyvern/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/__init__.py -------------------------------------------------------------------------------- /wyvern/components/api_route_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import deque 3 | from typing import Deque, List, Optional, Tuple, Type 4 | 5 | from ddtrace import tracer 6 | 7 | from wyvern import request_context 8 | from wyvern.components.component import Component 9 | from wyvern.entities.identifier import Identifier 10 | from wyvern.entities.identifier_entities import WyvernDataModel, WyvernEntity 11 | from wyvern.redis import wyvern_redis 12 | from wyvern.wyvern_typing import REQUEST_SCHEMA, RESPONSE_SCHEMA 13 | 14 | 15 | class APIRouteComponent(Component[REQUEST_SCHEMA, RESPONSE_SCHEMA]): 16 | """ 17 | APIRouteComponent is the base class for all the API routes in Wyvern. It is a Component that 18 | takes in a request schema and a response schema, and it is responsible for hydrating the request 19 | data with Wyvern Index data, and then pass the hydrated data to the next component in the pipeline. 20 | 21 | The APIRouteComponent is also responsible for the API routing, which means it is responsible for 22 | the API versioning and the API path. 23 | 24 | Attributes: 25 | API_VERSION: the version of the API. 
This is used in the API routing. The default value is "v1". 26 | PATH: the path of the API. This is used in the API routing. 27 | REQUEST_SCHEMA_CLASS: the class of the request schema. This is used to validate the request data. 28 | RESPONSE_SCHEMA_CLASS: the class of the response schema. This is used to validate the response data. 29 | API_NAME: the name of the API. This is used in the API routing. If not provided, the name of the 30 | APIRouteComponent will be used. 31 | """ 32 | 33 | # this is the api version 34 | API_VERSION: str = "v1" 35 | 36 | # this is the path for the API route 37 | PATH: str 38 | # this is the class of request schema represented by pydantic BaseModel 39 | REQUEST_SCHEMA_CLASS: Type[REQUEST_SCHEMA] 40 | # this is the class of response schema represented by pydantic BaseModel 41 | RESPONSE_SCHEMA_CLASS: Type[RESPONSE_SCHEMA] 42 | 43 | API_NAME: str = "" 44 | 45 | def __init__(self, *upstreams: Component, name: Optional[str] = None) -> None: 46 | super().__init__(*upstreams, name=name) 47 | self.api_name = self.API_NAME or self.name 48 | 49 | async def warm_up(self, input: REQUEST_SCHEMA) -> None: 50 | """ 51 | This is the warm-up function that is called before the API route is called. 
52 | """ 53 | # TODO shu: hydrate 54 | await self.hydrate(input) 55 | return 56 | 57 | @tracer.wrap(name="APIRouteComponent.hydrate") 58 | async def hydrate(self, input: REQUEST_SCHEMA) -> None: 59 | """ 60 | Wyvern APIRouteComponent recursively hydrate the request input data with Wyvern Index data 61 | 62 | TODO: this function could be moved to a global place 63 | """ 64 | if not isinstance(input, WyvernDataModel): 65 | return 66 | # use BFS to go through the input pydantic model 67 | # hydrate the data for each WyvernEntity that is encountered layer by layer if there are nested WyvernEntity 68 | identifiers: List[Identifier] = input.get_all_identifiers(cached=False) 69 | queue: Deque[WyvernDataModel] = deque([input]) 70 | while identifiers and queue: 71 | identifiers, queue = await self._bfs_hydrate(identifiers, queue) 72 | 73 | async def _bfs_hydrate( 74 | self, 75 | identifiers: List[Identifier], 76 | queue: Deque[WyvernDataModel], 77 | ) -> Tuple[List[Identifier], Deque[WyvernDataModel]]: 78 | """ 79 | This is a helper function for hydrate. It does a BFS on the input WyvernDataModel and hydrate the data. 
80 | 81 | Args: 82 | identifiers: a list of identifiers that need to be hydrated 83 | queue: a queue of WyvernDataModel that need to be hydrated 84 | 85 | Returns: 86 | The next level identifiers and the next level queue 87 | """ 88 | current_request = request_context.ensure_current_request() 89 | 90 | # load all the entities from Wyvern Index to self.entity_store 91 | index_keys = [identifier.index_key() for identifier in identifiers] 92 | 93 | # we're doing an in place update for the entity_store here to save redundant iterations 94 | # and improve the performance of the code 95 | await wyvern_redis.mget_update_in_place(index_keys, current_request) 96 | 97 | next_level_queue: Deque[WyvernDataModel] = deque([]) 98 | next_level_identifiers: List[Identifier] = [] 99 | while queue: 100 | current_obj = queue.popleft() 101 | # go through all the fields of the current object, and add WyvernDataModel to the queue 102 | for field in current_obj.__fields__: 103 | value = getattr(current_obj, field) 104 | if isinstance(value, WyvernDataModel): 105 | queue.append(value) 106 | if isinstance(value, List): 107 | # if the field is a list, we need to check each item in the list 108 | # to make sure WyvernDataModel items are enqueued 109 | for item in value: 110 | if isinstance(item, WyvernDataModel): 111 | queue.append(item) 112 | 113 | if isinstance(current_obj, WyvernEntity): 114 | # if the current node is a WyvernEntity, 115 | # we need to hydrate the data if the entity exists in the index 116 | # get the entity from wyvern index 117 | index_key = current_obj.identifier.index_key() 118 | entity = current_request.entity_store.get(index_key) 119 | 120 | # load the data into the entity 121 | if entity: 122 | current_obj.load_fields(entity) 123 | 124 | # generate the next level queue 125 | for ( 126 | id_field_name, 127 | entity_field_name, 128 | ) in current_obj.nested_hydration().items(): 129 | id_field_value = getattr(current_obj, id_field_name) 130 | if not id_field_value: 131 
| continue 132 | entity_class: Type[WyvernEntity] = current_obj.__fields__[ 133 | entity_field_name 134 | ].type_ 135 | entity_obj = entity_class(**{id_field_name: id_field_value}) 136 | setattr( 137 | current_obj, 138 | entity_field_name, 139 | entity_obj, 140 | ) 141 | next_level_identifiers.append(entity_obj.identifier) 142 | next_level_queue.append(getattr(current_obj, entity_field_name)) 143 | return next_level_identifiers, next_level_queue 144 | -------------------------------------------------------------------------------- /wyvern/components/business_logic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/business_logic/__init__.py -------------------------------------------------------------------------------- /wyvern/components/business_logic/pinning_business_logic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from typing import Callable, Dict, Generic, List, Optional 4 | 5 | from wyvern.components.business_logic.business_logic import BusinessLogicComponent 6 | from wyvern.components.component import Component 7 | from wyvern.entities.candidate_entities import ( 8 | GENERALIZED_WYVERN_ENTITY, 9 | ScoredCandidate, 10 | ) 11 | from wyvern.wyvern_typing import REQUEST_ENTITY 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class PinningBusinessLogicComponent( 17 | BusinessLogicComponent[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 18 | Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 19 | ): 20 | """ 21 | A component that performs boosting on an entity with a set of candidates 22 | 23 | The request itself could contain more than just entities, for example it may contain a query and so on 24 | """ 25 | 26 | def __init__(self, *upstreams: Component): 27 | super().__init__(*upstreams, 
name=self.__class__.__name__) 28 | 29 | def pin( 30 | self, 31 | scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], 32 | entity_pins: Dict[str, int], 33 | entity_key_mapping: Callable[ 34 | [GENERALIZED_WYVERN_ENTITY], 35 | str, 36 | ] = lambda candidate: candidate.identifier.identifier, 37 | allow_down_ranking: bool = False, 38 | ) -> List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]]: 39 | """ 40 | Pins the supplied entity to the specific position 41 | 42 | Args: 43 | scored_candidates: The list of scored candidates 44 | entity_pins: The map of entity keys (unique identifiers) to pin, and their pinning position 45 | entity_key_mapping: A lambda function that takes in a candidate entity and 46 | returns the field we should apply the pin to 47 | allow_down_ranking: Whether to allow down-ranking of candidates that are not pinned 48 | 49 | Returns: 50 | The list of scored candidates with the pinned entities 51 | """ 52 | applied_pins_score: Dict[int, float] = {} 53 | re_scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] = [] 54 | for index, candidate in enumerate(scored_candidates): 55 | current_position = len(scored_candidates) - index 56 | pin_candidate_new_position: Optional[int] = None 57 | # Determine which desired pins are applicable for this candidate set 58 | entity_key = entity_key_mapping(candidate.entity) 59 | if entity_key in entity_pins: 60 | desired_position = entity_pins[entity_key] 61 | 62 | if allow_down_ranking or desired_position < current_position: 63 | pin_candidate_new_position = min( 64 | desired_position, 65 | len(scored_candidates) - 1, 66 | ) 67 | 68 | if pin_candidate_new_position is None: 69 | re_scored_candidates.append(candidate) 70 | continue 71 | 72 | pinned_score = self._get_pinned_score( 73 | applied_pins_score, 74 | candidate, 75 | pin_candidate_new_position, 76 | scored_candidates, 77 | ) 78 | 79 | re_scored_candidates.append( 80 | ScoredCandidate(entity=candidate.entity, score=pinned_score), 81 | ) 
82 | 83 | self._update_applied_pins_score( 84 | applied_pins_score, 85 | pin_candidate_new_position, 86 | pinned_score, 87 | ) 88 | 89 | return re_scored_candidates 90 | 91 | def _update_applied_pins_score( 92 | self, 93 | applied_pins_score: Dict[int, float], 94 | current_position: int, 95 | new_score: float, 96 | ): 97 | """ 98 | Updates the applied pins score dictionary with the new score for the given position 99 | 100 | Args: 101 | applied_pins_score: The dictionary of applied pins score 102 | current_position: The current position to update 103 | new_score: The new score to apply 104 | """ 105 | if current_position in applied_pins_score: 106 | # This means this position already had a pin applied to it.. so we need to update the position 107 | existing_pin_score = applied_pins_score[current_position] 108 | if existing_pin_score > new_score: 109 | # new_score is smaller, so let's update our memory of the pin to occupy the previous position 110 | self._update_applied_pins_score( 111 | applied_pins_score, 112 | current_position - 1, 113 | existing_pin_score, 114 | ) 115 | else: 116 | # new_score is higher, so let's update our memory of the pin to occupy the next position 117 | self._update_applied_pins_score( 118 | applied_pins_score, 119 | current_position + 1, 120 | existing_pin_score, 121 | ) 122 | 123 | applied_pins_score[current_position] = new_score 124 | 125 | def _get_pinned_score( 126 | self, 127 | applied_pins_score: Dict[int, float], 128 | candidate: ScoredCandidate[GENERALIZED_WYVERN_ENTITY], 129 | pin_candidate_new_position: int, 130 | scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], 131 | ) -> float: 132 | """ 133 | Gets the score for the pinned candidate 134 | 135 | Args: 136 | applied_pins_score: The dictionary of applied pins score 137 | candidate: The candidate to pin 138 | pin_candidate_new_position: The new position to pin the candidate to 139 | scored_candidates: The list of scored candidates 140 | 141 | Returns: 142 | The 
score for the pinned candidate 143 | """ 144 | if pin_candidate_new_position >= len(scored_candidates) - 1: 145 | # Pinned position is outside or at the bottom of the candidate set, 146 | # subtract current score from the lowest score 147 | # If there are multiple pins at the bottom, it will respect their relative score 148 | # The reciprocal is used to ensure that higher scored products end up having a higher final score 149 | return scored_candidates[-1].score - (1.0 / candidate.score) 150 | elif pin_candidate_new_position == 0: 151 | # Pinned position is at the top of the candidate set -- add the highest score to the current candidate score 152 | # If there are multiple pins at position 1, it will currently respect their relative score 153 | # This makes sense to me, but we can change it if we want 154 | return scored_candidates[0].score + candidate.score 155 | else: 156 | # Average the scores of the candidates on either side of the pinned position 157 | left_position = pin_candidate_new_position 158 | right_position = pin_candidate_new_position + 1 159 | left_side_score = ( 160 | scored_candidates[left_position].score 161 | if left_position not in applied_pins_score 162 | else applied_pins_score[left_position] 163 | ) 164 | right_side_score = ( 165 | scored_candidates[right_position].score 166 | if right_position not in applied_pins_score 167 | else applied_pins_score[right_position] 168 | ) 169 | logger.debug( 170 | f"applied_pins_score={applied_pins_score} candidate={candidate.entity.get_all_identifiers()} " 171 | f"pin_candidate_new_position={pin_candidate_new_position} " 172 | f"left_side_score={left_side_score} right_side_score={right_side_score}", 173 | ) 174 | return (left_side_score + right_side_score) / 2 175 | -------------------------------------------------------------------------------- /wyvern/components/candidates/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/candidates/__init__.py -------------------------------------------------------------------------------- /wyvern/components/candidates/candidate_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Generic, List 4 | 5 | from ddtrace import tracer 6 | from pydantic.generics import GenericModel 7 | 8 | from wyvern import request_context 9 | from wyvern.components.component import Component 10 | from wyvern.components.events.events import EntityEventData, EventType, LoggedEvent 11 | from wyvern.entities.candidate_entities import ( 12 | GENERALIZED_WYVERN_ENTITY, 13 | ScoredCandidate, 14 | ) 15 | from wyvern.event_logging import event_logger 16 | from wyvern.wyvern_typing import REQUEST_ENTITY 17 | 18 | 19 | class CandidateEventData(EntityEventData): 20 | """ 21 | Event data for a candidate event 22 | 23 | Attributes: 24 | candidate_score: The score of the candidate 25 | candidate_order: The order of the candidate in the list of candidates 26 | """ 27 | 28 | candidate_score: float 29 | candidate_order: int 30 | 31 | 32 | class CandidateEvent(LoggedEvent[CandidateEventData]): 33 | event_type: EventType = EventType.CANDIDATE 34 | 35 | 36 | class CandidateEventLoggingRequest( 37 | GenericModel, 38 | Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 39 | ): 40 | request: REQUEST_ENTITY 41 | scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] 42 | 43 | 44 | class CandidateEventLoggingComponent( 45 | Component[ 46 | CandidateEventLoggingRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 47 | None, 48 | ], 49 | Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 50 | ): 51 | @tracer.wrap(name="CandidateEventLoggingComponent.execute") 52 | async def execute( 53 | self, 54 | input: 
CandidateEventLoggingRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 55 | **kwargs 56 | ) -> None: 57 | current_span = tracer.current_span() 58 | if current_span: 59 | current_span.set_tag("candidate_size", len(input.scored_candidates)) 60 | wyvern_request = request_context.ensure_current_request() 61 | url_path = wyvern_request.url_path 62 | run_id = wyvern_request.run_id 63 | 64 | def candidate_events_generator() -> List[CandidateEvent]: 65 | timestamp = datetime.utcnow() 66 | candidate_events = [ 67 | CandidateEvent( 68 | request_id=input.request.request_id, 69 | run_id=run_id, 70 | api_source=url_path, 71 | event_timestamp=timestamp, 72 | event_data=CandidateEventData( 73 | entity_identifier=candidate.entity.identifier.identifier, 74 | entity_identifier_type=candidate.entity.identifier.identifier_type, 75 | candidate_score=candidate.score, 76 | candidate_order=i, 77 | ), 78 | ) 79 | for i, candidate in enumerate(input.scored_candidates) 80 | ] 81 | return candidate_events 82 | 83 | event_logger.log_events(candidate_events_generator) # type: ignore 84 | -------------------------------------------------------------------------------- /wyvern/components/events/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/events/__init__.py -------------------------------------------------------------------------------- /wyvern/components/events/events.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from enum import Enum 4 | from typing import Generic, Optional, TypeVar 5 | 6 | from pydantic import BaseModel, Field 7 | from pydantic.generics import GenericModel 8 | 9 | from wyvern import request_context 10 | 11 | EVENT_DATA = TypeVar("EVENT_DATA", bound=BaseModel) 12 | 13 | 14 | def generate_run_id() -> str: 15 | 
curr_wyvern_request = request_context.current() 16 | if curr_wyvern_request is None: 17 | return str(0) 18 | return str(curr_wyvern_request.run_id) 19 | 20 | 21 | class EventType(str, Enum): 22 | """Enum for the different types of events that can be logged.""" 23 | 24 | BUSINESS_LOGIC = "BUSINESS_LOGIC" 25 | CANDIDATE = "CANDIDATE" 26 | FEATURE = "FEATURE" 27 | MODEL = "MODEL" 28 | IMPRESSION = "IMPRESSION" 29 | EXPERIMENTATION = "EXPERIMENTATION" 30 | CUSTOM = "CUSTOM" 31 | 32 | 33 | class LoggedEvent(GenericModel, Generic[EVENT_DATA]): 34 | """Base class for all logged events. 35 | 36 | Attributes: 37 | request_id: The request ID of the request that triggered the event. 38 | api_source: The API source of the request that triggered the event. 39 | event_timestamp: The timestamp of the event. 40 | event_type: The type of the event. 41 | event_data: The data associated with the event. This is a generic type that can be any subclass of BaseModel. 42 | """ 43 | 44 | request_id: Optional[str] 45 | api_source: Optional[str] 46 | event_timestamp: Optional[datetime] 47 | event_type: EventType 48 | event_data: EVENT_DATA 49 | run_id: str = Field(default_factory=generate_run_id) 50 | 51 | 52 | class EntityEventData(BaseModel): 53 | """Base class for all entity event data. 54 | 55 | Attributes: 56 | entity_identifier: The identifier of the entity that the event is associated with. 57 | entity_identifier_type: The type of the entity identifier. 58 | """ 59 | 60 | entity_identifier: str 61 | entity_identifier_type: str 62 | 63 | 64 | class CustomEntityEventData(EntityEventData): 65 | event_name: str 66 | 67 | 68 | ENTITY_EVENT_DATA_TYPE = TypeVar("ENTITY_EVENT_DATA_TYPE", bound=EntityEventData) 69 | 70 | 71 | class CustomEvent(LoggedEvent[ENTITY_EVENT_DATA_TYPE]): 72 | """Class for custom events. Custom event data must be a subclass of EntityEventData. 73 | 74 | Attributes: 75 | event_type: The type of the event. This is always EventType.CUSTOM. 
76 | """ 77 | 78 | event_type: EventType = EventType.CUSTOM 79 | -------------------------------------------------------------------------------- /wyvern/components/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/features/__init__.py -------------------------------------------------------------------------------- /wyvern/components/features/feature_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Generic 4 | 5 | from pydantic.generics import GenericModel 6 | from pydantic.main import BaseModel 7 | 8 | from wyvern import request_context 9 | from wyvern.components.component import Component 10 | from wyvern.components.events.events import EventType, LoggedEvent 11 | from wyvern.entities.feature_entities import IDENTIFIER 12 | from wyvern.event_logging import event_logger 13 | from wyvern.wyvern_typing import REQUEST_ENTITY, WyvernFeature 14 | 15 | 16 | class FeatureLogEventData(BaseModel): 17 | """Data for a feature event. 18 | 19 | Attributes: 20 | feature_identifier: The identifier of the feature. 21 | feature_identifier_type: The type of the feature identifier. 22 | feature_name: The name of the feature. 23 | feature_value: The value of the feature. 24 | """ 25 | 26 | feature_identifier: str 27 | feature_identifier_type: str 28 | feature_name: str 29 | feature_value: WyvernFeature 30 | 31 | 32 | class FeatureEvent(LoggedEvent[FeatureLogEventData]): 33 | """A feature event. 34 | 35 | Attributes: 36 | event_type: The type of the event. Defaults to EventType.FEATURE. 37 | """ 38 | 39 | event_type: EventType = EventType.FEATURE 40 | 41 | 42 | class FeatureEventLoggingRequest( 43 | GenericModel, 44 | Generic[REQUEST_ENTITY], 45 | ): 46 | """A request to log feature events. 
47 | 48 | Attributes: 49 | request: The request to log feature events for. 50 | feature_df: The feature data frame to log. 51 | """ 52 | 53 | request: REQUEST_ENTITY 54 | 55 | class Config: 56 | arbitrary_types_allowed = True 57 | 58 | 59 | class FeatureEventLoggingComponent( 60 | Component[FeatureEventLoggingRequest[REQUEST_ENTITY], None], 61 | Generic[REQUEST_ENTITY], 62 | ): 63 | """A component that logs feature events.""" 64 | 65 | async def execute( 66 | self, input: FeatureEventLoggingRequest[REQUEST_ENTITY], **kwargs 67 | ) -> None: 68 | """Logs feature events.""" 69 | wyvern_request = request_context.ensure_current_request() 70 | url_path = wyvern_request.url_path 71 | run_id = wyvern_request.run_id 72 | 73 | def feature_event_generator(): 74 | """Generates feature events. This is a generator function that's called by the event logger. It's never called directly. 75 | 76 | Returns: 77 | A list of feature events. 78 | """ 79 | timestamp = datetime.utcnow() 80 | 81 | # Extract column names excluding "IDENTIFIER" 82 | feature_columns = wyvern_request.feature_df.df.columns[1:] 83 | 84 | return [ 85 | FeatureEvent( 86 | request_id=input.request.request_id, 87 | run_id=run_id, 88 | api_source=url_path, 89 | event_timestamp=timestamp, 90 | event_data=FeatureLogEventData( 91 | feature_identifier_type=wyvern_request.get_original_identifier( 92 | row[IDENTIFIER], 93 | col, 94 | ).identifier_type, 95 | feature_identifier=wyvern_request.get_original_identifier( 96 | row[IDENTIFIER], 97 | col, 98 | ).identifier, 99 | feature_name=col, 100 | feature_value=row[col], 101 | ), 102 | ) 103 | for row in wyvern_request.feature_df.df.iter_rows(named=True) 104 | for col in feature_columns 105 | if row[col] 106 | ] 107 | 108 | event_logger.log_events(feature_event_generator) # type: ignore 109 | -------------------------------------------------------------------------------- /wyvern/components/helpers/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/helpers/__init__.py -------------------------------------------------------------------------------- /wyvern/components/helpers/linear_algebra.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | from typing import List, Tuple 4 | 5 | from scipy.spatial.distance import cosine 6 | 7 | from wyvern.components.component import Component 8 | 9 | 10 | class CosineSimilarityComponent( 11 | Component[List[Tuple[List[float], List[float]]], List[float]], 12 | ): 13 | """ 14 | A component that computes cosine similarity in parallel for all pairs of embeddings. 15 | """ 16 | 17 | def __init__(self, name: str): 18 | super().__init__(name=name) 19 | 20 | async def execute( 21 | self, 22 | input: List[Tuple[List[float], List[float]]], 23 | **kwargs, 24 | ) -> List[float]: 25 | """ 26 | Computes cosine similarity in parallel for all pairs of embeddings. 27 | 28 | Args: 29 | input: List of tuples of embeddings to compute cosine similarity for. 30 | 31 | Returns: 32 | List of cosine similarities. 33 | """ 34 | tasks = await asyncio.gather( 35 | *[ 36 | self.cosine_similarity(embedding1, embedding2) 37 | for (embedding1, embedding2) in input 38 | ], 39 | return_exceptions=False, 40 | ) 41 | # TODO (suchintan): Handle exceptions in cosine similarity function 42 | return list(tasks) 43 | 44 | async def cosine_similarity( 45 | self, 46 | embedding_1: List[float], 47 | embedding_2: List[float], 48 | ) -> float: 49 | """ 50 | Computes cosine similarity between two embeddings. 
51 | """ 52 | return 1 - cosine(embedding_1, embedding_2) 53 | -------------------------------------------------------------------------------- /wyvern/components/helpers/polars.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import polars as pl 3 | 4 | 5 | def cast_float32_to_float64(df) -> pl.DataFrame: 6 | float32_cols = [ 7 | col for col, dtype in zip(df.columns, df.dtypes) if dtype == pl.Float32 8 | ] 9 | return df.with_columns([df[col].cast(pl.Float64) for col in float32_cols]) 10 | -------------------------------------------------------------------------------- /wyvern/components/helpers/sorting.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | 4 | from wyvern.components.component import Component 5 | from wyvern.entities.candidate_entities import ( 6 | GENERALIZED_WYVERN_ENTITY, 7 | ScoredCandidate, 8 | ) 9 | 10 | 11 | class SortingComponent( 12 | Component[ 13 | List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], 14 | List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], 15 | ], 16 | ): 17 | """ 18 | Sorts a list of candidates based on a score. 19 | """ 20 | 21 | async def execute( 22 | self, 23 | input: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], 24 | descending=True, 25 | **kwargs 26 | ) -> List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]]: 27 | """ 28 | Sorts a list of candidates based on a score. 29 | 30 | Args: 31 | input: A list of candidates to be sorted. Each candidate must have a score. 32 | descending: Whether to sort in descending order. Defaults to True. 33 | 34 | Returns: 35 | A sorted list of candidates. 
36 | """ 37 | return sorted(input, key=lambda candidate: candidate.score, reverse=descending) 38 | -------------------------------------------------------------------------------- /wyvern/components/impressions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/impressions/__init__.py -------------------------------------------------------------------------------- /wyvern/components/impressions/impression_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Generic, List 4 | 5 | from ddtrace import tracer 6 | from pydantic.generics import GenericModel 7 | 8 | from wyvern import request_context 9 | from wyvern.components.component import Component 10 | from wyvern.components.events.events import EntityEventData, EventType, LoggedEvent 11 | from wyvern.entities.candidate_entities import ( 12 | GENERALIZED_WYVERN_ENTITY, 13 | ScoredCandidate, 14 | ) 15 | from wyvern.event_logging import event_logger 16 | from wyvern.wyvern_typing import REQUEST_ENTITY 17 | 18 | 19 | class ImpressionEventData(EntityEventData): 20 | """ 21 | Impression event data. This is the data that is logged for each impression. 22 | 23 | Args: 24 | impression_score: The score of the impression. 25 | impression_order: The order of the impression. 26 | """ 27 | 28 | impression_score: float 29 | impression_order: int 30 | 31 | 32 | class ImpressionEvent(LoggedEvent[ImpressionEventData]): 33 | """ 34 | Impression event. This is the event that is logged for each impression. 35 | 36 | Args: 37 | event_type: The type of the event. This is always EventType.IMPRESSION. 
38 | """ 39 | 40 | event_type: EventType = EventType.IMPRESSION 41 | 42 | 43 | class ImpressionEventLoggingRequest( 44 | GenericModel, 45 | Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 46 | ): 47 | """ 48 | Impression event logging request. 49 | 50 | Args: 51 | request: The request that was made. 52 | scored_impressions: The scored impressions. This is a list of scored candidates. 53 | Each scored candidate has an entity and a score. 54 | """ 55 | 56 | request: REQUEST_ENTITY 57 | scored_impressions: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] 58 | 59 | 60 | class ImpressionEventLoggingComponent( 61 | Component[ 62 | ImpressionEventLoggingRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 63 | None, 64 | ], 65 | Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 66 | ): 67 | """ 68 | Impression event logging component. This component logs impression events. 69 | """ 70 | 71 | @tracer.wrap(name="ImpressionEventLoggingComponent.execute") 72 | async def execute( 73 | self, 74 | input: ImpressionEventLoggingRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], 75 | **kwargs 76 | ) -> None: 77 | """ 78 | Logs impression events. 79 | 80 | Args: 81 | input: The input to the component. This contains the request and the scored impressions. 
82 | 83 | Returns: 84 | None 85 | """ 86 | current_span = tracer.current_span() 87 | if current_span: 88 | current_span.set_tag("impression_size", len(input.scored_impressions)) 89 | wyvern_request = request_context.ensure_current_request() 90 | url_path = wyvern_request.url_path 91 | run_id = wyvern_request.run_id 92 | 93 | def impression_events_generator() -> List[ImpressionEvent]: 94 | timestamp = datetime.utcnow() 95 | impression_events = [ 96 | ImpressionEvent( 97 | request_id=input.request.request_id, 98 | run_id=run_id, 99 | api_source=url_path, 100 | event_timestamp=timestamp, 101 | event_data=ImpressionEventData( 102 | entity_identifier=impression.entity.identifier.identifier, 103 | entity_identifier_type=impression.entity.identifier.identifier_type, 104 | impression_score=impression.score, 105 | impression_order=i, 106 | ), 107 | ) 108 | for i, impression in enumerate(input.scored_impressions) 109 | ] 110 | return impression_events 111 | 112 | event_logger.log_events(impression_events_generator) # type: ignore 113 | -------------------------------------------------------------------------------- /wyvern/components/index/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ._index import IndexDeleteComponent, IndexGetComponent, IndexUploadComponent 3 | 4 | __all__ = [ 5 | "IndexDeleteComponent", 6 | "IndexGetComponent", 7 | "IndexUploadComponent", 8 | ] 9 | -------------------------------------------------------------------------------- /wyvern/components/index/_index.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import logging 5 | from typing import Any, Dict, List, Type 6 | 7 | from wyvern.components.api_route_component import APIRouteComponent 8 | from wyvern.entities.index_entities import ( 9 | DeleteEntitiesRequest, 10 | DeleteEntitiesResponse, 11 | 
EntitiesRequest, 12 | GetEntitiesResponse, 13 | IndexRequest, 14 | IndexResponse, 15 | ) 16 | from wyvern.exceptions import WyvernEntityValidationError, WyvernError 17 | from wyvern.index import WyvernEntityIndex, WyvernIndex 18 | from wyvern.redis import wyvern_redis 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class IndexUploadComponent( 24 | APIRouteComponent[IndexRequest, IndexResponse], 25 | ): 26 | PATH: str = "/entities/upload" 27 | REQUEST_SCHEMA_CLASS: Type[IndexRequest] = IndexRequest 28 | RESPONSE_SCHEMA_CLASS: Type[IndexResponse] = IndexResponse 29 | 30 | async def execute( 31 | self, 32 | input: IndexRequest, 33 | **kwargs, 34 | ) -> IndexResponse: 35 | """ 36 | bulk index entities with redis pipeline 37 | """ 38 | 39 | entity_internal_key = f"{input.entity_type.value}_id" 40 | entity_key: str = input.entity_key or entity_internal_key 41 | 42 | entities: List[Dict[str, Any]] = [] 43 | for entity in input.entities: 44 | # validation: entity must have entity_key 45 | if entity_key not in entity: 46 | raise WyvernEntityValidationError( 47 | entity_key=entity_key, 48 | entity=entity, 49 | ) 50 | 51 | if entity_internal_key not in entity: 52 | entity[entity_internal_key] = entity[entity_key] 53 | elif (entity_internal_key in entity) and ( 54 | entity[entity_internal_key] != entity[entity_key] 55 | ): 56 | logger.warning( 57 | f"entity already has an internal key={entity_internal_key} " 58 | f"with value={entity[entity_internal_key]}, " 59 | f"skipping setting the value to {entity[entity_key]}", 60 | ) 61 | 62 | entities.append(entity) 63 | 64 | entity_ids = await wyvern_redis.bulk_index( 65 | entities, 66 | entity_key, 67 | input.entity_type.value, 68 | ) 69 | 70 | return IndexResponse( 71 | entity_type=input.entity_type.value, 72 | entity_ids=entity_ids, 73 | ) 74 | 75 | 76 | class IndexDeleteComponent( 77 | APIRouteComponent[DeleteEntitiesRequest, DeleteEntitiesResponse], 78 | ): 79 | PATH: str = "/entities/delete" 80 | 
REQUEST_SCHEMA_CLASS: Type[DeleteEntitiesRequest] = DeleteEntitiesRequest 81 | RESPONSE_SCHEMA_CLASS: Type[DeleteEntitiesResponse] = DeleteEntitiesResponse 82 | 83 | async def execute( 84 | self, 85 | input: DeleteEntitiesRequest, 86 | **kwargs, 87 | ) -> DeleteEntitiesResponse: 88 | await WyvernIndex.bulk_delete(input.entity_type.value, input.entity_ids) 89 | return DeleteEntitiesResponse( 90 | entity_ids=input.entity_ids, 91 | entity_type=input.entity_type.value, 92 | ) 93 | 94 | 95 | class IndexGetComponent( 96 | APIRouteComponent[EntitiesRequest, GetEntitiesResponse], 97 | ): 98 | PATH: str = "/entities/get" 99 | REQUEST_SCHEMA_CLASS: Type[EntitiesRequest] = EntitiesRequest 100 | RESPONSE_SCHEMA_CLASS: Type[GetEntitiesResponse] = GetEntitiesResponse 101 | 102 | async def execute( 103 | self, 104 | input: EntitiesRequest, 105 | **kwargs, 106 | ) -> GetEntitiesResponse: 107 | entities = await WyvernEntityIndex.bulk_get( 108 | entity_type=input.entity_type.value, 109 | entity_ids=input.entity_ids, 110 | ) 111 | if len(entities) != len(input.entity_ids): 112 | raise WyvernError("Unexpected Error") 113 | entity_map = {input.entity_ids[i]: entities[i] for i in range(len(entities))} 114 | return GetEntitiesResponse( 115 | entity_type=input.entity_type.value, 116 | entities=entity_map, 117 | ) 118 | -------------------------------------------------------------------------------- /wyvern/components/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/models/__init__.py -------------------------------------------------------------------------------- /wyvern/components/models/model_chain_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from functools import cached_property 3 | from typing import Optional, Set 4 | 5 | from 
wyvern.components.models.model_component import ( 6 | BaseModelComponent, 7 | MultiEntityModelComponent, 8 | SingleEntityModelComponent, 9 | ) 10 | from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput 11 | from wyvern.exceptions import MissingModelChainOutputError 12 | from wyvern.wyvern_typing import REQUEST_ENTITY 13 | 14 | 15 | class MultiEntityModelChain(MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT]): 16 | def __init__(self, *upstreams: BaseModelComponent, name: Optional[str] = None): 17 | super().__init__(*upstreams, name=name) 18 | self.chain = upstreams 19 | 20 | @cached_property 21 | def manifest_feature_names(self) -> Set[str]: 22 | feature_names: Set[str] = set() 23 | for model in self.chain: 24 | feature_names = feature_names.union(model.manifest_feature_names) 25 | return feature_names 26 | 27 | async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: 28 | output = None 29 | prev_model: Optional[BaseModelComponent] = None 30 | for model in self.chain: 31 | curr_input: ChainedModelInput 32 | if prev_model is not None and output is not None: 33 | curr_input = ChainedModelInput( 34 | request=input.request, 35 | entities=input.entities, 36 | upstream_model_name=prev_model.name, 37 | upstream_model_output=output.data, 38 | ) 39 | else: 40 | curr_input = ChainedModelInput( 41 | request=input.request, 42 | entities=input.entities, 43 | upstream_model_name=None, 44 | upstream_model_output={}, 45 | ) 46 | output = await model.execute(curr_input, **kwargs) 47 | prev_model = model 48 | 49 | if output is None: 50 | raise MissingModelChainOutputError() 51 | 52 | # TODO: do type checking to make sure the output is of the correct type 53 | return output 54 | 55 | 56 | class SingleEntityModelChain(SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT]): 57 | def __init__( 58 | self, *upstreams: SingleEntityModelComponent, name: Optional[str] = None 59 | ): 60 | super().__init__(*upstreams, name=name) 61 | 
self.chain = upstreams 62 | 63 | @cached_property 64 | def manifest_feature_names(self) -> Set[str]: 65 | feature_names: Set[str] = set() 66 | for model in self.chain: 67 | feature_names = feature_names.union(model.manifest_feature_names) 68 | return feature_names 69 | 70 | async def inference(self, input: REQUEST_ENTITY, **kwargs) -> MODEL_OUTPUT: 71 | output = None 72 | for model in self.chain: 73 | output = await model.execute(input, **kwargs) 74 | 75 | if output is None: 76 | raise MissingModelChainOutputError() 77 | 78 | # TODO: do type checking to make sure the output is of the correct type 79 | return output 80 | -------------------------------------------------------------------------------- /wyvern/components/models/modelbit_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | import logging 4 | from functools import cached_property 5 | from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union, final 6 | 7 | from wyvern.components.models.model_component import ( 8 | BaseModelComponent, 9 | MultiEntityModelComponent, 10 | SingleEntityModelComponent, 11 | ) 12 | from wyvern.config import settings 13 | from wyvern.core.http import aiohttp_client 14 | from wyvern.entities.identifier import Identifier 15 | from wyvern.entities.identifier_entities import WyvernEntity 16 | from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT 17 | from wyvern.entities.request import BaseWyvernRequest 18 | from wyvern.exceptions import ( 19 | WyvernModelbitTokenMissingError, 20 | WyvernModelbitValidationError, 21 | ) 22 | from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY 23 | 24 | JSON: TypeAlias = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class ModelbitMixin(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]): 29 | AUTH_TOKEN: str = "" 30 | URL: str = "" 31 | 32 | def 
__init__( 33 | self, 34 | *upstreams, 35 | name: Optional[str] = None, 36 | auth_token: Optional[str] = None, 37 | url: Optional[str] = None, 38 | cache_output: bool = False, 39 | ) -> None: 40 | """ 41 | Args: 42 | *upstreams: A list of upstream components. 43 | name: A string that represents the name of the model. 44 | auth_token: A string that represents the auth token for Modelbit. 45 | url: A string that represents the url for Modelbit. 46 | 47 | Raises: 48 | WyvernModelbitTokenMissingError: If the auth token is not provided. 49 | """ 50 | super().__init__(*upstreams, name=name, cache_output=cache_output) 51 | self._auth_token = auth_token or self.AUTH_TOKEN 52 | self._modelbit_url = url or self.URL 53 | self.headers = { 54 | "Authorization": self._auth_token, 55 | "Content-Type": "application/json", 56 | } 57 | 58 | if not self._auth_token: 59 | raise WyvernModelbitTokenMissingError() 60 | 61 | @cached_property 62 | def modelbit_features(self) -> List[str]: 63 | """ 64 | This is a cached property that returns a list of modelbit features. This method should be implemented by the 65 | subclass. 66 | """ 67 | return [] 68 | 69 | @cached_property 70 | def manifest_feature_names(self) -> Set[str]: 71 | """ 72 | This is a cached property that returns a set of manifest feature names. This method wraps around the 73 | modelbit_features property. 74 | """ 75 | return set(self.modelbit_features) 76 | 77 | async def inference(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: 78 | """ 79 | This method sends a request to Modelbit and returns the output. 
80 | """ 81 | # TODO shu: currently we don't support modelbit inference just for request if the input contains entities 82 | 83 | target_identifiers, all_requests = await self.build_requests(input) 84 | 85 | if len(target_identifiers) != len(all_requests): 86 | raise WyvernModelbitValidationError( 87 | f"Number of identifiers ({len(target_identifiers)}) " 88 | f"does not match number of modelbit requests ({len(all_requests)})", 89 | ) 90 | 91 | # split requests into smaller batches and parallelize them 92 | futures = [ 93 | aiohttp_client().post( 94 | self._modelbit_url, 95 | headers=self.headers, 96 | json={"data": all_requests[i : i + settings.MODELBIT_BATCH_SIZE]}, 97 | ) 98 | for i in range(0, len(all_requests), settings.MODELBIT_BATCH_SIZE) 99 | ] 100 | responses = await asyncio.gather(*futures) 101 | # resp_list: List[List[float]] = resp.json().get("data", []) 102 | output_data: Dict[Identifier, Optional[Union[float, str, List[float]]]] = {} 103 | 104 | for batch_idx, resp in enumerate(responses): 105 | if resp.status != 200: 106 | text = await resp.text() 107 | logger.warning(f"Modelbit inference failed: {text}") 108 | continue 109 | resp_list: List[List[Union[float, str, List[float], None]]] = ( 110 | await resp.json() 111 | ).get( 112 | "data", 113 | [], 114 | ) 115 | for idx, individual_output in enumerate(resp_list): 116 | # individual_output[0] is the index of modelbit output which is useless so we'll not use it 117 | # individual_output[1] is the actual output 118 | output_data[ 119 | target_identifiers[batch_idx * settings.MODELBIT_BATCH_SIZE + idx] 120 | ] = individual_output[1] 121 | 122 | return self.model_output_type( 123 | data=output_data, 124 | model_name=self.name, 125 | ) 126 | 127 | async def build_requests( 128 | self, 129 | input: INPUT_TYPE, 130 | ) -> Tuple[List[Identifier], List[Any]]: 131 | """ 132 | This method builds requests for Modelbit. This method should be implemented by the subclass. 
133 | """ 134 | raise NotImplementedError 135 | 136 | 137 | class ModelbitComponent( 138 | ModelbitMixin[MODEL_INPUT, MODEL_OUTPUT], 139 | MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT], 140 | ): 141 | """ 142 | ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement 143 | all modelbit models. 144 | 145 | ModelbitComponent is a subclass of ModelComponent. 146 | 147 | Attributes: 148 | AUTH_TOKEN: A class variable that stores the auth token for Modelbit. 149 | URL: A class variable that stores the url for Modelbit. 150 | """ 151 | 152 | async def build_requests( 153 | self, 154 | input: MODEL_INPUT, 155 | ) -> Tuple[List[Identifier], List[Any]]: 156 | """ 157 | Please refer to modlebit batch inference API: 158 | https://doc.modelbit.com/deployments/rest-api/ 159 | """ 160 | target_entities: List[ 161 | Union[WyvernEntity, BaseWyvernRequest] 162 | ] = input.entities or [input.request] 163 | target_identifiers = [entity.identifier for entity in target_entities] 164 | identifier_features_tuples = self.get_features( 165 | target_identifiers, 166 | self.modelbit_features, 167 | ) 168 | 169 | all_requests = [ 170 | [idx + 1, features] 171 | for idx, (identifier, features) in enumerate(identifier_features_tuples) 172 | ] 173 | return target_identifiers, all_requests 174 | 175 | 176 | class SingleEntityModelbitComponent( 177 | ModelbitMixin[REQUEST_ENTITY, MODEL_OUTPUT], 178 | SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT], 179 | ): 180 | @final 181 | async def build_requests( 182 | self, 183 | input: REQUEST_ENTITY, 184 | ) -> Tuple[List[Identifier], List[Any]]: 185 | target_identifier, request = await self.build_request(input) 186 | all_requests = [[1, request]] 187 | return [target_identifier], all_requests 188 | 189 | async def build_request(self, input: REQUEST_ENTITY) -> Tuple[Identifier, Any]: 190 | raise NotImplementedError 191 | 
-------------------------------------------------------------------------------- /wyvern/components/pagination/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/components/pagination/__init__.py -------------------------------------------------------------------------------- /wyvern/components/pagination/pagination_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from typing import Generic, List 4 | 5 | from pydantic.generics import GenericModel 6 | 7 | from wyvern.components.component import Component 8 | from wyvern.components.pagination.pagination_fields import PaginationFields 9 | from wyvern.exceptions import PaginationError 10 | from wyvern.wyvern_typing import T 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class PaginationRequest(GenericModel, Generic[T]): 16 | """ 17 | This is the input to the PaginationComponent. 18 | 19 | Attributes: 20 | pagination_fields: The pagination fields that are used to compute the pagination. 21 | entities: The entities that need to be paginated. 22 | """ 23 | 24 | pagination_fields: PaginationFields 25 | entities: List[T] 26 | 27 | 28 | class PaginationComponent(Component[PaginationRequest[T], List[T]]): 29 | """ 30 | This component is used to paginate the entities. It takes in the pagination fields and the entities and returns 31 | the paginated entities. 32 | """ 33 | 34 | def __init__(self): 35 | super().__init__(name="PaginationComponent") 36 | 37 | async def execute(self, input: PaginationRequest[T], **kwargs) -> List[T]: 38 | """ 39 | This method paginates the entities based on the pagination fields. 40 | 41 | Validations: 42 | 1. The ranking page should be greater than or equal to 0. 43 | 2. The candidate page should be greater than or equal to 0. 44 | 3. 
The candidate page size should be less than or equal to 1000. 45 | 4. The number of entities should be less than or equal to 1000. 46 | 5. The user page size should be less than or equal to 100. 47 | 6. The user page size should be less than or equal to the candidate page size. 48 | 7. The end index should be less than the number of entities. 49 | 8. The end index should be greater than the start index. 50 | 51 | Returns: 52 | The paginated entities. 53 | """ 54 | if len(input.entities) == 0: 55 | logger.info("Found no entities to paginate, skipping pagination") 56 | return [] 57 | 58 | user_page = input.pagination_fields.user_page 59 | candidate_page = input.pagination_fields.candidate_page 60 | candidate_page_size = input.pagination_fields.candidate_page_size 61 | user_page_size = input.pagination_fields.user_page_size 62 | 63 | ranking_page = user_page - ( 64 | candidate_page * candidate_page_size / user_page_size 65 | ) 66 | 67 | start_index = int(ranking_page * user_page_size) 68 | end_index = min(int((ranking_page + 1) * user_page_size), len(input.entities)) 69 | 70 | # TODO (suchintan): Add test case, this can happen if candidate page > user page 71 | if ranking_page < 0: 72 | message = ( 73 | f"Ranking page {ranking_page} is less than 0. Is the user_page correct?. 
" 74 | f"pagination_fields={input.pagination_fields}" 75 | ) 76 | logger.error(message) 77 | raise PaginationError(message) 78 | 79 | # TODO (suchintan): I wonder if we can have this kind of validation live in the FastApi layer 80 | # TODO (suchintan): Add test case 81 | if candidate_page < 0 or user_page < 0: 82 | message = ( 83 | f"User page {user_page} or candidate page {candidate_page} is less than 0, " 84 | f"pagination_fields={input.pagination_fields}" 85 | ) 86 | logger.error(message) 87 | raise PaginationError(message) 88 | 89 | # TODO (suchintan): Add test case 90 | if candidate_page_size > 1000 or candidate_page_size < 0: 91 | message = ( 92 | f"Candidate page size {candidate_page_size} is greater than 1000 or less than 0, " 93 | f"pagination_fields={input.pagination_fields}" 94 | ) 95 | logger.error(message) 96 | raise PaginationError(message) 97 | 98 | # TODO (suchintan): Add test case 99 | if len(input.entities) > 1000: 100 | message = ( 101 | f"Number of entities {len(input.entities)} is greater than 1000, " 102 | f"pagination_fields={input.pagination_fields}" 103 | ) 104 | logger.error(message) 105 | raise PaginationError(message) 106 | 107 | # TODO (suchintan): Add test case 108 | if user_page_size > 100 or user_page_size < 0: 109 | message = ( 110 | f"User page size {user_page_size} is greater than 100 or less than 0, " 111 | f"pagination_fields={input.pagination_fields}" 112 | ) 113 | logger.error(message) 114 | raise PaginationError(message) 115 | 116 | if user_page_size > candidate_page_size: 117 | message = ( 118 | f"User page size {user_page_size} is greater than candidate page size {candidate_page_size}, " 119 | f"pagination_fields={input.pagination_fields}" 120 | ) 121 | logger.error(message) 122 | raise PaginationError(message) 123 | 124 | # TODO (suchintan): Add test case 125 | if end_index > len(input.entities): 126 | message = ( 127 | f"Computed End index {end_index} is greater than the number of entities {len(input.entities)}, " 128 | 
f"pagination_fields={input.pagination_fields}" 129 | ) 130 | logger.error(message) 131 | raise PaginationError(message) 132 | 133 | # This should NEVER happen, but add a case here 134 | if end_index <= start_index: 135 | message = ( 136 | f"Computed end_index={end_index} is less than or equal to the start_index={start_index} " 137 | f"number_of_entities={len(input.entities)}, " 138 | f"pagination_fields={input.pagination_fields}" 139 | ) 140 | logger.error(message) 141 | raise PaginationError(message) 142 | 143 | # TODO (suchintan): Add test case 144 | return input.entities[start_index:end_index] 145 | -------------------------------------------------------------------------------- /wyvern/components/pagination/pagination_fields.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pydantic import BaseModel 3 | 4 | 5 | class PaginationFields(BaseModel): 6 | """ 7 | Pagination fields for requests. This is a mixin class that can be used in any request that requires pagination. 8 | 9 | Attributes: 10 | user_page_size: Zero-indexed user facing page number 11 | user_page: Number of items per user facing page 12 | candidate_page_size: This is the size of the candidate page. 
13 | candidate_page: This is the zero-indexed page number for the candidate set 14 | """ 15 | 16 | user_page_size: int 17 | user_page: int 18 | candidate_page_size: int 19 | candidate_page: int 20 | -------------------------------------------------------------------------------- /wyvern/components/pipeline_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from functools import cached_property 3 | from typing import Optional, Set, Type 4 | 5 | from ddtrace import tracer 6 | 7 | from wyvern.components.api_route_component import APIRouteComponent 8 | from wyvern.components.component import Component 9 | from wyvern.components.features.feature_retrieval_pipeline import ( 10 | FeatureRetrievalPipeline, 11 | FeatureRetrievalPipelineRequest, 12 | ) 13 | from wyvern.components.features.realtime_features_component import ( 14 | RealtimeFeatureComponent, 15 | ) 16 | from wyvern.exceptions import ComponentAlreadyDefinedInPipelineComponentError 17 | from wyvern.wyvern_typing import REQUEST_ENTITY, RESPONSE_SCHEMA 18 | 19 | 20 | class PipelineComponent(APIRouteComponent[REQUEST_ENTITY, RESPONSE_SCHEMA]): 21 | """ 22 | PipelineComponent is the base class for all the pipeline components in Wyvern. It is a Component that 23 | takes in a request entity and a response schema, and it is responsible for hydrating the request 24 | data with Wyvern Index data, and then pass the hydrated data to the next component in the pipeline. 
25 | """ 26 | 27 | def __init__( 28 | self, 29 | *upstreams: Component, 30 | name: Optional[str] = None, 31 | handle_feature_store_exceptions: bool = False, 32 | ) -> None: 33 | for upstream in upstreams: 34 | if isinstance(upstream, FeatureRetrievalPipeline): 35 | raise ComponentAlreadyDefinedInPipelineComponentError( 36 | component_type="FeatureRetrievalPipeline", 37 | ) 38 | 39 | self.feature_retrieval_pipeline = FeatureRetrievalPipeline[REQUEST_ENTITY]( 40 | name=f"{self.__class__.__name__}-feature_retrieval", 41 | handle_exceptions=handle_feature_store_exceptions, 42 | ) 43 | self.feature_names: Set[str] = set() 44 | super().__init__(*upstreams, self.feature_retrieval_pipeline, name=name) 45 | 46 | @cached_property 47 | def realtime_features_overrides(self) -> Set[Type[RealtimeFeatureComponent]]: 48 | """ 49 | This function defines the set of RealtimeFeatureComponents that generates features 50 | with non-deterministic feature names. 51 | For example, feature names like matched_query_brand. 52 | That feature is defined like matched_query_{input.query.matched_query}, so it can refer to 10 or 20 features 53 | """ 54 | return set() 55 | 56 | async def initialize(self) -> None: 57 | # get all the feature names from all the upstream components 58 | for component in self.initialized_components: 59 | for feature_name in component.manifest_feature_names: 60 | self.feature_names.add(feature_name) 61 | 62 | @tracer.wrap(name="PipelineComponent.retrieve_features") 63 | async def retrieve_features(self, request: REQUEST_ENTITY) -> None: 64 | """ 65 | TODO shu: it doesn't support feature overrides. 
Write code to support that 66 | """ 67 | feature_request = FeatureRetrievalPipelineRequest[REQUEST_ENTITY]( 68 | request=request, 69 | requested_feature_names=self.feature_names, 70 | feature_overrides=self.realtime_features_overrides, 71 | ) 72 | await self.feature_retrieval_pipeline.execute( 73 | feature_request, 74 | ) 75 | 76 | async def warm_up(self, input: REQUEST_ENTITY) -> None: 77 | await super().warm_up(input) 78 | 79 | # TODO shu: split feature_retrieval_pipeline into 80 | # 1. feature retrieval from feature store 2. realtime feature computation 81 | # then the warm_up and feature retrieval from feature store can be done in parallel 82 | # suchintan: we also need to retrieve brand features from the feature store 83 | # and brand info would only be available via hydration so hydration has to be done first 84 | await self.retrieve_features(input) 85 | -------------------------------------------------------------------------------- /wyvern/components/ranking_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Generic, List, Optional 3 | 4 | from pydantic import BaseModel 5 | 6 | from wyvern.components.business_logic.business_logic import ( 7 | BusinessLogicPipeline, 8 | BusinessLogicRequest, 9 | ) 10 | from wyvern.components.candidates.candidate_logger import CandidateEventLoggingComponent 11 | from wyvern.components.events.events import LoggedEvent 12 | from wyvern.components.impressions.impression_logger import ( 13 | ImpressionEventLoggingComponent, 14 | ImpressionEventLoggingRequest, 15 | ) 16 | from wyvern.components.models.model_component import ModelComponent 17 | from wyvern.components.pagination.pagination_component import ( 18 | PaginationComponent, 19 | PaginationRequest, 20 | ) 21 | from wyvern.components.pagination.pagination_fields import PaginationFields 22 | from wyvern.components.pipeline_component import PipelineComponent 23 | from 
wyvern.entities.candidate_entities import ScoredCandidate 24 | from wyvern.entities.identifier_entities import QueryEntity 25 | from wyvern.entities.model_entities import ModelInput 26 | from wyvern.entities.request import BaseWyvernRequest 27 | from wyvern.event_logging import event_logger 28 | from wyvern.wyvern_typing import WYVERN_ENTITY 29 | 30 | 31 | class RankingRequest( 32 | BaseWyvernRequest, 33 | PaginationFields, 34 | Generic[WYVERN_ENTITY], 35 | ): 36 | """ 37 | This is the request for the ranking pipeline. 38 | 39 | Attributes: 40 | query: the query entity 41 | candidates: the list of candidate entities 42 | """ 43 | 44 | query: QueryEntity 45 | candidates: List[WYVERN_ENTITY] 46 | 47 | 48 | class ResponseCandidate(BaseModel): 49 | """ 50 | This is the response candidate. 51 | 52 | Attributes: 53 | candidate_id: the identifier of the candidate 54 | ranked_score: the ranked score of the candidate 55 | """ 56 | 57 | candidate_id: str 58 | ranked_score: float 59 | 60 | 61 | class RankingResponse(BaseModel): 62 | """ 63 | This is the response for the ranking pipeline. 64 | 65 | Attributes: 66 | ranked_candidates: the list of ranked candidates 67 | events: the list of logged events 68 | """ 69 | 70 | ranked_candidates: List[ResponseCandidate] 71 | events: Optional[List[LoggedEvent[Any]]] 72 | 73 | 74 | class RankingPipeline( 75 | PipelineComponent[RankingRequest, RankingResponse], 76 | Generic[WYVERN_ENTITY], 77 | ): 78 | """ 79 | This is the ranking pipeline. 80 | 81 | Attributes: 82 | PATH: the path of the API. This is used in the API routing. The default value is "/ranking". 
83 | """ 84 | 85 | PATH: str = "/ranking" 86 | 87 | def __init__(self, name: Optional[str] = None): 88 | self.pagination_component = PaginationComponent[ 89 | ScoredCandidate[WYVERN_ENTITY] 90 | ]() 91 | self.ranking_model = self.get_model() 92 | self.candidate_logging_component = CandidateEventLoggingComponent[ 93 | WYVERN_ENTITY, 94 | RankingRequest[WYVERN_ENTITY], 95 | ]() 96 | self.impression_logging_component = ImpressionEventLoggingComponent[ 97 | WYVERN_ENTITY, 98 | RankingRequest[WYVERN_ENTITY], 99 | ]() 100 | 101 | upstream_components = [ 102 | self.pagination_component, 103 | self.ranking_model, 104 | self.candidate_logging_component, 105 | self.impression_logging_component, 106 | ] 107 | self.business_logic_pipeline: BusinessLogicPipeline 108 | business_logic = self.get_business_logic() 109 | if business_logic: 110 | self.business_logic_pipeline = business_logic 111 | else: 112 | self.business_logic_pipeline = BusinessLogicPipeline[ 113 | WYVERN_ENTITY, 114 | RankingRequest[WYVERN_ENTITY], 115 | ]() 116 | upstream_components.append(self.business_logic_pipeline) 117 | 118 | super().__init__( 119 | *upstream_components, 120 | name=name, 121 | ) 122 | 123 | def get_model(self) -> ModelComponent: 124 | """ 125 | This is the ranking model. 126 | 127 | The model input should be a subclass of ModelInput. 128 | Its output should be scored candidates 129 | """ 130 | raise NotImplementedError 131 | 132 | def get_business_logic(self) -> Optional[BusinessLogicPipeline]: 133 | """ 134 | This is the business logic pipeline. It is optional. If not provided, the ranking pipeline will not 135 | apply any business logic. 136 | 137 | The business logic pipeline should be a subclass of BusinessLogicPipeline. Some examples of business logic 138 | for ranking pipeline are: 139 | 1. Deduplication 140 | 2. Filtering 141 | 3. 
(De)boosting 142 | """ 143 | return None 144 | 145 | async def execute( 146 | self, 147 | input: RankingRequest[WYVERN_ENTITY], 148 | **kwargs, 149 | ) -> RankingResponse: 150 | ranked_candidates = await self.rank_candidates(input) 151 | 152 | pagination_request = PaginationRequest[ScoredCandidate[WYVERN_ENTITY]]( 153 | pagination_fields=input, 154 | entities=ranked_candidates, 155 | ) 156 | paginated_candidates = await self.pagination_component.execute( 157 | pagination_request, 158 | ) 159 | 160 | # TODO (suchintan): This should be automatic -- add this to the pipeline abstraction 161 | impression_logging_request = ImpressionEventLoggingRequest[ 162 | WYVERN_ENTITY, 163 | RankingRequest[WYVERN_ENTITY], 164 | ]( 165 | scored_impressions=paginated_candidates, 166 | request=input, 167 | ) 168 | await self.impression_logging_component.execute(impression_logging_request) 169 | 170 | response_ranked_candidates = [ 171 | ResponseCandidate( 172 | candidate_id=candidate.entity.identifier.identifier, 173 | ranked_score=candidate.score, 174 | ) 175 | for candidate in paginated_candidates 176 | ] 177 | 178 | response = RankingResponse( 179 | ranked_candidates=response_ranked_candidates, 180 | events=event_logger.get_logged_events() if input.include_events else None, 181 | ) 182 | 183 | return response 184 | 185 | async def rank_candidates( 186 | self, 187 | request: RankingRequest[WYVERN_ENTITY], 188 | ) -> List[ScoredCandidate[WYVERN_ENTITY]]: 189 | """ 190 | This function ranks the candidates. 191 | 192 | 1. It first calls the ranking model to get the model scores for the candidates. 193 | 2. It then calls the business logic pipeline to adjust the model scores. 194 | 3. It returns the adjusted candidates. 
195 | 196 | Args: 197 | request: the ranking request 198 | 199 | Returns: 200 | A list of ScoredCandidate 201 | """ 202 | model_input = ModelInput[WYVERN_ENTITY, RankingRequest[WYVERN_ENTITY]]( 203 | request=request, 204 | entities=request.candidates, 205 | ) 206 | model_outputs = await self.ranking_model.execute(model_input) 207 | 208 | scored_candidates: List[ScoredCandidate] = [ 209 | ScoredCandidate( 210 | entity=candidate, 211 | score=( 212 | model_outputs.data.get(candidate.identifier) or 0 213 | ), # TODO (shu): what to do if model score is None? 214 | ) 215 | for i, candidate in enumerate(request.candidates) 216 | ] 217 | 218 | business_logic_request = BusinessLogicRequest[ 219 | WYVERN_ENTITY, 220 | RankingRequest[WYVERN_ENTITY], 221 | ]( 222 | request=request, 223 | scored_candidates=scored_candidates, 224 | ) 225 | 226 | # business_logic makes sure the candidates are sorted 227 | business_logic_response = await self.business_logic_pipeline.execute( 228 | business_logic_request, 229 | ) 230 | return business_logic_response.adjusted_candidates 231 | -------------------------------------------------------------------------------- /wyvern/components/single_entity_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Generic, List, Optional 3 | 4 | from pydantic.generics import GenericModel 5 | 6 | from wyvern.components.business_logic.business_logic import ( 7 | SingleEntityBusinessLogicPipeline, 8 | SingleEntityBusinessLogicRequest, 9 | ) 10 | from wyvern.components.component import Component 11 | from wyvern.components.events.events import LoggedEvent 12 | from wyvern.components.models.model_component import SingleEntityModelComponent 13 | from wyvern.components.pipeline_component import PipelineComponent 14 | from wyvern.entities.identifier import Identifier 15 | from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE 16 | from wyvern.event_logging 
import event_logger 17 | from wyvern.exceptions import MissingModelOutputError 18 | from wyvern.wyvern_typing import REQUEST_ENTITY 19 | 20 | 21 | class SingleEntityPipelineResponse(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): 22 | data: Optional[MODEL_OUTPUT_DATA_TYPE] = None 23 | events: Optional[List[LoggedEvent[Any]]] = None 24 | 25 | 26 | class SingleEntityPipeline( 27 | PipelineComponent[ 28 | REQUEST_ENTITY, 29 | SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE], 30 | ], 31 | Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], 32 | ): 33 | def __init__( 34 | self, 35 | *upstreams: Component, 36 | model: SingleEntityModelComponent, 37 | business_logic: Optional[ 38 | SingleEntityBusinessLogicPipeline[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE] 39 | ] = None, 40 | name: Optional[str] = None, 41 | handle_feature_store_exceptions: bool = False, 42 | ) -> None: 43 | upstream_components = list(upstreams) 44 | 45 | self.model = model 46 | upstream_components.append(self.model) 47 | 48 | if not business_logic: 49 | business_logic = SingleEntityBusinessLogicPipeline[ 50 | REQUEST_ENTITY, 51 | MODEL_OUTPUT_DATA_TYPE, 52 | ]() 53 | self.business_logic = business_logic 54 | upstream_components.append(self.business_logic) 55 | 56 | super().__init__( 57 | *upstream_components, 58 | name=name, 59 | handle_feature_store_exceptions=handle_feature_store_exceptions, 60 | ) 61 | 62 | async def execute( 63 | self, 64 | input: REQUEST_ENTITY, 65 | **kwargs, 66 | ) -> SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]: 67 | output = await self.model.execute(input, **kwargs) 68 | identifiers: List[Identifier] = list(output.data.keys()) 69 | if not identifiers: 70 | raise MissingModelOutputError() 71 | identifier = identifiers[0] 72 | model_output_data: MODEL_OUTPUT_DATA_TYPE = output.data.get(identifier) 73 | 74 | business_logic_input = SingleEntityBusinessLogicRequest[ 75 | REQUEST_ENTITY, 76 | MODEL_OUTPUT_DATA_TYPE, 77 | ]( 78 | identifier=identifier, 79 | request=input, 80 | 
model_output=model_output_data, 81 | ) 82 | business_logic_output = await self.business_logic.execute( 83 | input=business_logic_input, 84 | **kwargs, 85 | ) 86 | return self.generate_response( 87 | input, 88 | business_logic_output.adjusted_output, 89 | ) 90 | 91 | def generate_response( 92 | self, 93 | input: REQUEST_ENTITY, 94 | pipeline_output: Optional[MODEL_OUTPUT_DATA_TYPE], 95 | ) -> SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]: 96 | return SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]( 97 | data=pipeline_output, 98 | events=event_logger.get_logged_events() if input.include_events else None, 99 | ) 100 | -------------------------------------------------------------------------------- /wyvern/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pydantic import BaseSettings 3 | 4 | from wyvern.experimentation.providers.base import ExperimentationProvider 5 | 6 | 7 | class Settings(BaseSettings): 8 | """Settings for the Wyvern service 9 | 10 | Extends from BaseSettings class, allowing values to be overridden by environment variables. This is useful 11 | in production for secrets you do not wish to save in code 12 | 13 | Attributes: 14 | ENVIRONMENT: The environment the service is running in. Default to `development`. 15 | PROJECT_NAME: The name of the project. Default to `default`. 16 | REDIS_HOST: The host of the redis instance. Default to `localhost`. 17 | REDIS_PORT: The port of the redis instance. Default to `6379`. 18 | 19 | WYVERN_API_KEY: The API key for the Wyvern API. Default to `""`, empty string. 20 | WYVERN_BASE_URL: The base url of the Wyvern API. Default to `https://api.wyvern.ai` 21 | WYVERN_ONLINE_FEATURES_PATH: 22 | The path to the online features endpoint. Default to `/feature/get-online-features`. 23 | WYVERN_HISTORICAL_FEATURES_PATH: 24 | The path to the historical features endpoint. Default to `/feature/get-historical-features`. 
25 | WYVERN_FEATURE_STORE_URL: The url of the Wyvern feature store. Default to `https://api.wyvern.ai`. 26 | 27 | SNOWFLAKE_ACCOUNT: The account name of the Snowflake instance. Default to `""`, empty string. 28 | SNOWFLAKE_USER: The username of the Snowflake instance. Default to `""`, empty string. 29 | SNOWFLAKE_PASSWORD: The password of the Snowflake instance. Default to `""`, empty string. 30 | SNOWFLAKE_ROLE: The role of the Snowflake instance. Default to `""`, empty string. 31 | SNOWFLAKE_WAREHOUSE: The warehouse of the Snowflake instance. Default to `""`, empty string. 32 | SNOWFLAKE_DATABASE: The database of the Snowflake instance. Default to `""`, empty string. 33 | SNOWFLAKE_OFFLINE_STORE_SCHEMA: The schema of the Snowflake instance. Default to `PUBLIC`. 34 | 35 | AWS_ACCESS_KEY_ID: The access key id for the AWS instance. Default to `""`, empty string. 36 | AWS_SECRET_ACCESS_KEY: The secret access key for the AWS instance. Default to `""`, empty string. 37 | AWS_REGION_NAME: The region name for the AWS instance. Default to `us-east-1`. 38 | 39 | FEATURE_STORE_TIMEOUT: The timeout for the feature store. Default to `60` seconds. 40 | SERVER_TIMEOUT: The timeout for the server. Default to `60` seconds. 41 | 42 | REDIS_BATCH_SIZE: The batch size for the redis instance. Default to `100`. 43 | WYVERN_INDEX_VERSION: The version of the Wyvern index. Default to `1`. 44 | MODELBIT_BATCH_SIZE: The batch size for the modelbit. Default to `30`. 45 | 46 | EXPERIMENTATION_ENABLED: Whether experimentation is enabled. Default to `False`. 47 | EXPERIMENTATION_PROVIDER: The experimentation provider. Default to `ExperimentationProvider.EPPO.value`. 48 | EPPO_API_KEY: The API key for EPPO (an experimentation provider). Default to `""`, empty string. 49 | 50 | FEATURE_STORE_ENABLED: Whether the feature store is enabled. Default to `True`. 51 | EVENT_LOGGING_ENABLED: Whether event logging is enabled. Default to `True`. 
52 | """ 53 | 54 | ENVIRONMENT: str = "development" 55 | 56 | PROJECT_NAME: str = "default" 57 | 58 | REDIS_HOST: str = "localhost" 59 | REDIS_PORT: int = 6379 60 | 61 | # URLs 62 | WYVERN_BASE_URL = "https://api.wyvern.ai" 63 | WYVERN_ONLINE_FEATURES_PATH: str = "/feature/get-online-features" 64 | WYVERN_HISTORICAL_FEATURES_PATH: str = "/feature/get-historical-features" 65 | WYVERN_FEATURE_STORE_URL: str = "https://api.wyvern.ai" 66 | 67 | WYVERN_API_KEY: str = "" 68 | 69 | # Snowflake configurations 70 | SNOWFLAKE_ACCOUNT: str = "" 71 | SNOWFLAKE_USER: str = "" 72 | SNOWFLAKE_PASSWORD: str = "" 73 | SNOWFLAKE_ROLE: str = "" 74 | SNOWFLAKE_WAREHOUSE: str = "" 75 | SNOWFLAKE_DATABASE: str = "" 76 | SNOWFLAKE_OFFLINE_STORE_SCHEMA: str = "PUBLIC" 77 | SNOWFLAKE_REALTIME_FEATURE_LOG_TABLE: str = "FEATURE_LOGS" 78 | 79 | # NOTE: aws configs are used for feature logging with AWS firehose 80 | AWS_ACCESS_KEY_ID: str = "" 81 | AWS_SECRET_ACCESS_KEY: str = "" 82 | AWS_REGION_NAME: str = "us-east-1" 83 | 84 | FEATURE_STORE_TIMEOUT: int = 60 85 | SERVER_TIMEOUT: int = 60 86 | 87 | # pipeline service configurations 88 | REDIS_BATCH_SIZE: int = 100 89 | 90 | WYVERN_INDEX_VERSION: int = 1 91 | 92 | MODELBIT_BATCH_SIZE: int = 30 93 | MODEL_BATCH_SIZE: int = 30 94 | 95 | # experimentation configurations 96 | EXPERIMENTATION_ENABLED: bool = False 97 | EXPERIMENTATION_PROVIDER: str = ExperimentationProvider.EPPO.value 98 | EPPO_API_KEY: str = "" 99 | 100 | # wyvern component flag 101 | FEATURE_STORE_ENABLED: bool = True 102 | EVENT_LOGGING_ENABLED: bool = True 103 | 104 | class Config: 105 | env_file = (".env", ".env.prod") 106 | env_file_encoding = "utf-8" 107 | 108 | 109 | settings = Settings() 110 | -------------------------------------------------------------------------------- /wyvern/core/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/core/__init__.py -------------------------------------------------------------------------------- /wyvern/core/compression.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, Union 3 | 4 | import lz4.frame 5 | import msgspec 6 | 7 | msgspec_json_encoder = msgspec.json.Encoder() 8 | msgspec_json_decoder = msgspec.json.Decoder() 9 | 10 | 11 | def wyvern_encode(data: Dict[str, Any]) -> bytes: 12 | """ 13 | encode a dict to compressed bytes using lz4.frame 14 | """ 15 | return lz4.frame.compress(msgspec_json_encoder.encode(data)) 16 | 17 | 18 | def wyvern_decode(data: Union[bytes, str]) -> Dict[str, Any]: 19 | """ 20 | decode compressed bytes to a dict with lz4.frame 21 | """ 22 | return msgspec_json_decoder.decode(lz4.frame.decompress(data)) 23 | -------------------------------------------------------------------------------- /wyvern/core/http.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | import aiohttp 5 | 6 | from wyvern.exceptions import WyvernError 7 | 8 | logger = logging.getLogger(__name__) 9 | DEFAULT_REQUEST_TIMEOUT = 60 10 | timeout = aiohttp.ClientTimeout(total=DEFAULT_REQUEST_TIMEOUT) 11 | 12 | 13 | class AiohttpClientWrapper: 14 | """AiohttpClientWrapper is a singleton wrapper around aiohttp.ClientSession.""" 15 | 16 | async_client = None 17 | 18 | def start(self): 19 | """Instantiate the client. Call from the FastAPI startup hook.""" 20 | self.async_client = aiohttp.ClientSession(timeout=timeout) 21 | 22 | async def stop(self): 23 | """Gracefully shutdown. 
Call from FastAPI shutdown hook.""" 24 | if not self.async_client: 25 | return 26 | if self.async_client and not self.async_client.closed: 27 | await self.async_client.close() 28 | self.async_client = None 29 | 30 | def __call__(self): 31 | """Calling the instantiated AiohttpClientWrapper returns the wrapped singleton.""" 32 | # Ensure we don't use it if not started / running 33 | if self.async_client is None: 34 | raise WyvernError("AiohttpClientWrapper not started") 35 | 36 | return self.async_client 37 | 38 | 39 | aiohttp_client = AiohttpClientWrapper() 40 | """ 41 | The aiohttp client singleton. Use this to make requests. 42 | 43 | Example: 44 | ```python 45 | from wyvern.core.http import aiohttp_client 46 | aiohttp_client().get("https://www.wyvern.ai") 47 | ``` 48 | """ 49 | -------------------------------------------------------------------------------- /wyvern/entities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/entities/__init__.py -------------------------------------------------------------------------------- /wyvern/entities/candidate_entities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | from typing import Generic, List, TypeVar 5 | 6 | from pydantic.generics import GenericModel 7 | 8 | from wyvern.entities.identifier_entities import WyvernDataModel 9 | from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY 10 | 11 | 12 | # TODO (suchintan): This should be renamed to ScoredEntity probably 13 | class ScoredCandidate( 14 | GenericModel, 15 | Generic[GENERALIZED_WYVERN_ENTITY], 16 | ): 17 | """ 18 | A candidate entity with a score. 19 | 20 | Attributes: 21 | entity: The candidate entity. 22 | score: The score of the candidate entity. Defaults to 0.0. 
23 | """ 24 | 25 | entity: GENERALIZED_WYVERN_ENTITY 26 | score: float = 0.0 27 | 28 | 29 | class CandidateSetEntity( 30 | WyvernDataModel, 31 | GenericModel, 32 | Generic[GENERALIZED_WYVERN_ENTITY], 33 | ): 34 | """ 35 | A set of candidate entities. This is a generic model that can be used to represent a set of candidate entities. 36 | Attributes: 37 | candidates: The list of candidate entities. 38 | """ 39 | 40 | candidates: List[GENERALIZED_WYVERN_ENTITY] 41 | 42 | 43 | CANDIDATE_SET_ENTITY = TypeVar("CANDIDATE_SET_ENTITY", bound=CandidateSetEntity) 44 | -------------------------------------------------------------------------------- /wyvern/entities/feature_entities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import logging 5 | from typing import Dict, List 6 | 7 | import polars as pl 8 | from pydantic.main import BaseModel 9 | 10 | from wyvern.entities.identifier import Identifier, get_identifier_key 11 | from wyvern.wyvern_typing import WyvernFeature 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | IDENTIFIER = "IDENTIFIER" 16 | 17 | 18 | class FeatureData(BaseModel, frozen=True): 19 | """ 20 | A class to represent the features of an entity. 21 | 22 | Attributes: 23 | identifier: The identifier of the entity. 24 | features: A dictionary of feature names to feature values. 25 | """ 26 | 27 | identifier: Identifier 28 | features: Dict[str, WyvernFeature] = {} 29 | 30 | def __str__(self) -> str: 31 | return f"identifier={self.identifier} features={self.features}" 32 | 33 | def __repr__(self): 34 | return self.__str__() 35 | 36 | 37 | class FeatureDataFrame(BaseModel): 38 | """ 39 | A class to store features in a polars dataframe. 
40 | """ 41 | 42 | df: pl.DataFrame = pl.DataFrame().with_columns( 43 | pl.Series(name=IDENTIFIER, dtype=pl.Utf8), 44 | ) 45 | 46 | class Config: 47 | arbitrary_types_allowed = True 48 | frozen = True 49 | 50 | def get_features( 51 | self, 52 | identifiers: List[Identifier], 53 | feature_names: List[str], 54 | ) -> pl.DataFrame: 55 | # Filter the dataframe by identifier. If the identifier is a composite identifier, use the primary identifier 56 | identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] 57 | return self.get_features_by_identifier_keys( 58 | identifier_keys=identifier_keys, 59 | feature_names=feature_names, 60 | ) 61 | 62 | def get_features_by_identifier_keys( 63 | self, 64 | identifier_keys: List[str], 65 | feature_names: List[str], 66 | ) -> pl.DataFrame: 67 | # Filter the dataframe by identifier 68 | df = self.df.filter(pl.col(IDENTIFIER).is_in(identifier_keys)) 69 | 70 | # Process feature names, adding identifier to the selection 71 | feature_names = [IDENTIFIER] + feature_names 72 | existing_cols = df.columns 73 | for col_name in feature_names: 74 | if col_name not in existing_cols: 75 | # Add a new column filled with None values if it doesn't exist 76 | df = df.with_columns(pl.lit(None).alias(col_name)) 77 | df = df.select(feature_names) 78 | 79 | return df 80 | 81 | def get_all_features_for_identifier(self, identifier: Identifier) -> pl.DataFrame: 82 | identifier_key = get_identifier_key(identifier) 83 | return self.df.filter(pl.col(IDENTIFIER) == identifier_key) 84 | 85 | @staticmethod 86 | def build_empty_df( 87 | identifiers: List[Identifier], 88 | feature_names: List[str], 89 | ) -> FeatureDataFrame: 90 | """ 91 | Builds an empty polars df with the given identifiers and feature names. 
92 | """ 93 | identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] 94 | df_columns = [ 95 | pl.Series(name=IDENTIFIER, values=identifier_keys, dtype=pl.Object), 96 | ] 97 | df_columns.extend( 98 | [pl.lit(None).alias(feature_name) for feature_name in feature_names], # type: ignore 99 | ) 100 | return FeatureDataFrame(df=pl.DataFrame().with_columns(df_columns)) 101 | -------------------------------------------------------------------------------- /wyvern/entities/identifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import logging 5 | from enum import Enum 6 | from typing import Union 7 | 8 | from pydantic.main import BaseModel 9 | 10 | from wyvern.config import settings 11 | from wyvern.utils import generate_index_key 12 | 13 | COMPOSITE_SEPARATOR = ":" 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SimpleIdentifierType(str, Enum): 18 | """ 19 | Simple identifier types are those that are not composite. 20 | """ 21 | 22 | PRODUCT = "product" 23 | QUERY = "query" 24 | BRAND = "brand" 25 | CATEGORY = "category" 26 | USER = "user" 27 | REQUEST = "request" 28 | 29 | 30 | def composite( 31 | primary_identifier_type: SimpleIdentifierType, 32 | secondary_identifier_type: SimpleIdentifierType, 33 | ) -> str: 34 | """ 35 | Composite identifier types are those that are composite. For example, a product with id p_1234 and type "product" 36 | a user with id u_1234 and type "user" would have a composite identifier of "p_1234:u_1234", and a composite 37 | identifier_type of "product:user". This is useful for indexing and searching for composite entities. 38 | """ 39 | return f"{primary_identifier_type.value}{COMPOSITE_SEPARATOR}{secondary_identifier_type.value}" 40 | 41 | 42 | class CompositeIdentifierType(str, Enum): 43 | """ 44 | Composite identifier types are those that are composite. 
For example, a composite identifier type of 45 | "product:user" would be a composite identifier type for a product and a user. This is useful for indexing and 46 | searching for composite entities. 47 | """ 48 | 49 | PRODUCT_QUERY = composite( 50 | SimpleIdentifierType.PRODUCT, 51 | SimpleIdentifierType.QUERY, 52 | ) 53 | BRAND_QUERY = composite(SimpleIdentifierType.BRAND, SimpleIdentifierType.QUERY) 54 | CATEGORY_QUERY = composite( 55 | SimpleIdentifierType.CATEGORY, 56 | SimpleIdentifierType.QUERY, 57 | ) 58 | USER_PRODUCT = composite(SimpleIdentifierType.PRODUCT, SimpleIdentifierType.USER) 59 | USER_BRAND = composite(SimpleIdentifierType.BRAND, SimpleIdentifierType.USER) 60 | USER_CATEGORY = composite(SimpleIdentifierType.CATEGORY, SimpleIdentifierType.USER) 61 | QUERY_USER = composite(SimpleIdentifierType.QUERY, SimpleIdentifierType.USER) 62 | 63 | 64 | IdentifierType = Union[SimpleIdentifierType, CompositeIdentifierType] 65 | 66 | 67 | class Identifier(BaseModel): 68 | """ 69 | Identifiers exist to represent a unique entity through their unique id and their type 70 | For example: a product with id p_1234 and type "product" or a user with id u_1234 and type "user" 71 | 72 | Composite identifiers are also possible, for example: 73 | a product with id p_1234 and type "product" 74 | a user with id u_1234 and type "user" 75 | 76 | The composite identifier would be "p_1234:u_1234", 77 | and the composite identifier_type would be "product:user" 78 | """ 79 | 80 | identifier: str 81 | identifier_type: str 82 | 83 | class Config: 84 | frozen = True 85 | 86 | def __str__(self) -> str: 87 | return f"{self.identifier_type}::{self.identifier}" 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | def __hash__(self): 93 | return hash(self.__str__()) 94 | 95 | @staticmethod 96 | def as_identifier_type( 97 | identifier_type_string: str, 98 | ) -> IdentifierType: 99 | try: 100 | return SimpleIdentifierType(identifier_type_string) 101 | except ValueError: 102 | 
pass 103 | return CompositeIdentifierType(identifier_type_string) 104 | 105 | def index_key(self) -> str: 106 | return generate_index_key( 107 | settings.PROJECT_NAME, 108 | self.identifier_type, 109 | self.identifier, 110 | ) 111 | 112 | 113 | class CompositeIdentifier(Identifier): 114 | """ 115 | Composite identifiers exist to represent a unique entity through their unique id and their type. At most, they 116 | can have two identifiers and two identifier types. For example: 117 | a product with id p_1234 and type "product" 118 | a user with id u_1234 and type "user" 119 | 120 | The composite identifier would be "p_1234:u_1234", and the composite identifier_type would be "product:user". 121 | """ 122 | 123 | primary_identifier: Identifier 124 | secondary_identifier: Identifier 125 | 126 | def __init__( 127 | self, primary_identifier: Identifier, secondary_identifier: Identifier, **kwargs 128 | ): 129 | identifier = f"{primary_identifier.identifier}{COMPOSITE_SEPARATOR}{secondary_identifier.identifier}" 130 | identifier_type = self.as_identifier_type( 131 | primary_identifier.identifier_type 132 | + COMPOSITE_SEPARATOR 133 | + secondary_identifier.identifier_type, 134 | ) 135 | super().__init__( 136 | identifier=identifier, 137 | identifier_type=identifier_type.value, 138 | primary_identifier=primary_identifier, 139 | secondary_identifier=secondary_identifier, 140 | **kwargs, 141 | ) 142 | 143 | 144 | def get_identifier_key( 145 | identifier: Identifier, 146 | ) -> str: 147 | """ 148 | Returns the identifier key for a given identifier. If the identifier is a composite identifier, the primary 149 | identifier is used. This is useful while doing feature retrievals for composite entities. 
150 | """ 151 | if isinstance(identifier, CompositeIdentifier): 152 | return str(identifier.primary_identifier) 153 | return str(identifier) 154 | -------------------------------------------------------------------------------- /wyvern/entities/index_entities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, List, Optional 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | from wyvern.entities.identifier import SimpleIdentifierType 7 | 8 | MIN_INDEX_ITEMS = 0 9 | MAX_INDEX_ITEMS = 1000 10 | 11 | 12 | class IndexResponse(BaseModel): 13 | entity_type: str 14 | entity_ids: List[str] 15 | 16 | 17 | class IndexRequest(BaseModel): 18 | entities: List[Dict[Any, Any]] = Field( 19 | min_items=MIN_INDEX_ITEMS, 20 | max_items=MAX_INDEX_ITEMS, 21 | ) 22 | entity_type: SimpleIdentifierType 23 | entity_key: Optional[str] 24 | 25 | 26 | class EntitiesRequest(BaseModel): 27 | entity_ids: List[str] = Field( 28 | min_items=MIN_INDEX_ITEMS, 29 | max_items=MAX_INDEX_ITEMS, 30 | ) 31 | entity_type: SimpleIdentifierType 32 | 33 | 34 | class DeleteEntitiesRequest(EntitiesRequest): 35 | pass 36 | 37 | 38 | class GetEntitiesResponse(BaseModel): 39 | entity_type: str 40 | entities: Dict[str, Optional[Dict[Any, Any]]] = Field(default_factory=dict) 41 | 42 | 43 | class DeleteEntitiesResponse(BaseModel): 44 | entity_type: str 45 | entity_ids: List[str] 46 | -------------------------------------------------------------------------------- /wyvern/entities/model_entities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, Generic, List, Optional, TypeVar, Union 3 | 4 | from pydantic.generics import GenericModel 5 | 6 | from wyvern.entities.identifier import Identifier 7 | from wyvern.exceptions import WyvernModelInputError 8 | from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY 9 | 
10 | MODEL_OUTPUT_DATA_TYPE = TypeVar( 11 | "MODEL_OUTPUT_DATA_TYPE", 12 | bound=Union[ 13 | float, 14 | str, 15 | List[float], 16 | Dict[str, Any], 17 | ], 18 | ) 19 | """ 20 | MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats 21 | (e.g. a list of probabilities, embeddings, etc.) 22 | """ 23 | 24 | 25 | class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): 26 | """ 27 | This class defines the output of a model. 28 | 29 | Args: 30 | data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. 31 | model_name: The name of the model. This is optional. 32 | """ 33 | 34 | data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] 35 | model_name: Optional[str] = None 36 | 37 | def get_entity_output( 38 | self, 39 | identifier: Identifier, 40 | ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: 41 | """ 42 | Get the model output for a given entity identifier. 43 | 44 | Args: 45 | identifier: The identifier of the entity. 46 | 47 | Returns: 48 | The model output for the given entity identifier. This can also be None if the model output is None. 49 | """ 50 | return self.data.get(identifier) 51 | 52 | 53 | class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): 54 | """ 55 | This class defines the input to a model. 56 | 57 | Args: 58 | request: The request that will be used to generate the model input. 59 | entities: A list of entities that will be used to generate the model input. 60 | """ 61 | 62 | request: REQUEST_ENTITY 63 | entities: List[GENERALIZED_WYVERN_ENTITY] = [] 64 | 65 | @property 66 | def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: 67 | """ 68 | Get the first entity in the list of entities. This is useful when you know that there is only one entity. 69 | 70 | Returns: 71 | The first entity in the list of entities. 
72 | """ 73 | if not self.entities: 74 | raise WyvernModelInputError(model_input=self) 75 | return self.entities[0] 76 | 77 | @property 78 | def first_identifier(self) -> Identifier: 79 | """ 80 | Get the identifier of the first entity in the list of entities. This is useful when you know that there is only 81 | one entity. 82 | 83 | Returns: 84 | The identifier of the first entity in the list of entities. 85 | """ 86 | return self.first_entity.identifier 87 | 88 | 89 | MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) 90 | MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) 91 | 92 | 93 | class ChainedModelInput(ModelInput, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): 94 | upstream_model_output: Dict[ 95 | Identifier, 96 | Optional[ 97 | Union[ 98 | float, 99 | str, 100 | List[float], 101 | Dict[str, Optional[Union[float, str, list[float]]]], 102 | ] 103 | ], 104 | ] 105 | upstream_model_name: Optional[str] = None 106 | -------------------------------------------------------------------------------- /wyvern/entities/request.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Optional 3 | 4 | from pydantic import PrivateAttr 5 | 6 | from wyvern.entities.identifier import Identifier 7 | from wyvern.entities.identifier_entities import WyvernDataModel 8 | 9 | 10 | class BaseWyvernRequest(WyvernDataModel): 11 | """ 12 | Base class for all Wyvern requests. This class is used to generate an identifier for the request. 13 | 14 | Attributes: 15 | request_id: The request id. 16 | include_events: Whether to include events in the response. 
17 | """ 18 | 19 | request_id: str 20 | include_events: Optional[bool] = False 21 | 22 | _identifier: Identifier = PrivateAttr() 23 | 24 | def __init__(self, **kwargs): 25 | super().__init__(**kwargs) 26 | self._identifier = self.generate_identifier() 27 | 28 | @property 29 | def identifier(self) -> Identifier: 30 | return self._identifier 31 | 32 | def generate_identifier(self) -> Identifier: 33 | """ 34 | Generates an identifier for the request. 35 | 36 | Returns: 37 | Identifier: The identifier for the request. The identifier type is "request". 38 | """ 39 | return Identifier( 40 | identifier=self.request_id, 41 | identifier_type="request", 42 | ) 43 | -------------------------------------------------------------------------------- /wyvern/event_logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/event_logging/__init__.py -------------------------------------------------------------------------------- /wyvern/event_logging/event_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Any, Callable, List 4 | 5 | from wyvern import request_context 6 | from wyvern.components.events.events import ( 7 | ENTITY_EVENT_DATA_TYPE, 8 | CustomEvent, 9 | LoggedEvent, 10 | ) 11 | 12 | 13 | def log_events(event_generator: Callable[[], List[LoggedEvent]]): 14 | """ 15 | Logs events to the current request context. 16 | 17 | Args: 18 | event_generator: A function that returns a list of events to be logged. 19 | """ 20 | request_context.ensure_current_request().events.append(event_generator) 21 | 22 | 23 | def get_logged_events() -> List[LoggedEvent[Any]]: 24 | """ 25 | Returns: 26 | A list of all the events logged in the current request context. 
27 | """ 28 | return [ 29 | event 30 | for event_generator in request_context.ensure_current_request().events 31 | for event in event_generator() 32 | ] 33 | 34 | 35 | def get_logged_events_generator() -> List[Callable[[], List[LoggedEvent[Any]]]]: 36 | """ 37 | Returns: 38 | A list of all the event generators logged in the current request context. 39 | """ 40 | return request_context.ensure_current_request().events 41 | 42 | 43 | def log_custom_events(events: List[ENTITY_EVENT_DATA_TYPE]) -> None: 44 | """ 45 | Logs custom events to the current request context. 46 | 47 | Args: 48 | events: A list of custom events to be logged. 49 | """ 50 | request = request_context.ensure_current_request() 51 | api_source = request.url_path 52 | request_id = request.request_id 53 | run_id = request.run_id 54 | 55 | def event_generator() -> List[LoggedEvent[Any]]: 56 | timestamp = datetime.utcnow() 57 | return [ 58 | CustomEvent( 59 | request_id=request_id, 60 | run_id=run_id, 61 | api_source=api_source, 62 | event_timestamp=timestamp, 63 | event_data=event, 64 | ) 65 | for event in events 66 | ] 67 | 68 | request_context.ensure_current_request().events.append(event_generator) 69 | -------------------------------------------------------------------------------- /wyvern/exceptions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Optional 3 | 4 | from wyvern import request_context 5 | 6 | 7 | class WyvernError(Exception): 8 | """Base class for all Wyvern errors. 9 | 10 | Attributes: 11 | message: The error message. 12 | error_code: The error code. 
13 | """ 14 | 15 | message = "Wyvern error" 16 | 17 | def __init__( 18 | self, 19 | message: Optional[str] = None, 20 | error_code: int = 0, 21 | **kwargs, 22 | ) -> None: 23 | self.error_code = error_code 24 | self.kwargs = kwargs 25 | if message: 26 | self.message = message 27 | try: 28 | self._error_string = self.message.format(**kwargs) 29 | except Exception: 30 | # at least get the core message out if something happened 31 | self._error_string = self.message 32 | wyvern_request = request_context.current() 33 | request_id = None 34 | if wyvern_request and wyvern_request.request_id: 35 | request_id = wyvern_request.request_id 36 | elif "request_id" in kwargs: 37 | request_id = kwargs["request_id"] 38 | 39 | self.request_id = request_id 40 | if self.request_id: 41 | self._error_string = f"[request_id={self.request_id}] {self._error_string}" 42 | 43 | def __str__(self) -> str: 44 | return f"{self.__class__.__name__}: {self._error_string}" 45 | 46 | 47 | class WyvernEntityValidationError(WyvernError): 48 | """ 49 | Raised when entity data is invalid 50 | """ 51 | 52 | message = "{entity_key} is missing in entity data: {entity}" 53 | 54 | 55 | class PaginationError(WyvernError): 56 | """ 57 | Raised when there is an error in pagination 58 | """ 59 | 60 | pass 61 | 62 | 63 | class WyvernRouteRegistrationError(WyvernError): 64 | """ 65 | Raised when there is an error in registering a route 66 | """ 67 | 68 | message = ( 69 | "WyvernRouteRegistrationError: Invalid component: {component}. To register a route, " 70 | "the component must be a subclass of APIComponentRoute" 71 | ) 72 | 73 | 74 | class ComponentAlreadyDefinedInPipelineComponentError(WyvernError): 75 | """ 76 | Raised when a component is already defined in a pipeline component 77 | """ 78 | 79 | message = "'{component_type}' is already defined by the PipelineComponent. It cannot be passed as an upstream!" 
class WyvernFeatureStoreError(WyvernError):
    """
    Raised when there is an error in feature store
    """

    message = "Received error from feature store: {error}"


class WyvernFeatureNameError(WyvernError):
    """
    Raised when there is an error in feature name
    """

    message = (
        "Invalid online feature names: {invalid_feature_names}. "
        "feature references must have format 'feature_view:feature', e.g. customer_fv:daily_transactions. "
        "Are these realtime features? Make sure you define realtime feature component and register them."
    )


class WyvernFeatureValueError(WyvernError):
    """
    Raised when there is an error in feature value
    """

    message = "More than one feature value found for identifier={identifier} feature_name={feature_name}."


class WyvernModelInputError(WyvernError):
    """
    Raised when there is an error in model input
    """

    # The two literals are implicitly concatenated, so the first one must end with a separator;
    # otherwise the rendered message runs "{model_input}" straight into "ModelInput.entities".
    message = (
        "Invalid ModelInput: {model_input}. "
        "ModelInput.entities must contain at least one entity."
    )


class WyvernModelbitTokenMissingError(WyvernError):
    """
    Raised when modelbit token is missing
    """

    message = "Modelbit authentication token is required."


class WyvernModelbitValidationError(WyvernError):
    """
    Raised when modelbit validation fails
    """

    message = "Generated modelbit requests length does not match the number of target entities."


class WyvernAPIKeyMissingError(WyvernError):
    """
    Raised when api key is missing
    """

    message = (
        "Wyvern api key is missing. "
        "Pass api_key to WyvernAPI or define WYVERN_API_KEY in your environment."
    )


class ExperimentationProviderNotSupportedError(WyvernError):
    """
    Raised when experimentation provider is not supported
    """

    # Previously this reused the feature-store wording ("Received error from feature store"),
    # which misreported the failure.
    message = "Experimentation provider is not supported: {provider_name}"


class ExperimentationClientInitializationError(WyvernError):
    """
    Raised when experimentation client initialization fails
    """

    message = "Failed to initialize experimentation client for provider: {provider_name}, {error}"


class EntityColumnMissingError(WyvernError):
    """
    Raised when an expected entity column is missing from the entity data
    """

    message = "Entity column {entity} is missing in the entity data"


class MissingModelChainOutputError(WyvernError):
    """
    Raised when a model chain produces no output
    """

    message = "Model chain output is missing"


class MissingModelOutputError(WyvernError):
    """
    Raised when an identifier is missing from a model's output
    """

    message = "Identifier is missing in the model output"


class WyvernLoggingOriginalIdentifierMissingError(WyvernError):
    """
    Raised when original identifier is missing during feature logging
    """

    message = "Original identifier is missing for primary identifier={identifier} feature_name={feature_name}."
182 | -------------------------------------------------------------------------------- /wyvern/experimentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/experimentation/__init__.py -------------------------------------------------------------------------------- /wyvern/experimentation/client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import traceback 4 | from typing import Optional 5 | 6 | from wyvern.config import settings 7 | from wyvern.exceptions import ExperimentationProviderNotSupportedError 8 | from wyvern.experimentation.providers.base import ExperimentationProvider 9 | from wyvern.experimentation.providers.eppo_provider import EppoExperimentationClient 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class ExperimentationClient: 15 | """ 16 | A client for interacting with experimentation providers. 17 | """ 18 | 19 | def __init__(self, provider_name: str, api_key: Optional[str] = None): 20 | """ 21 | Initializes the ExperimentationClient with a specified provider. 22 | 23 | Args: 24 | - provider_name (str): The name of the experimentation provider (e.g., "eppo"). 
25 | """ 26 | if not settings.EXPERIMENTATION_ENABLED: 27 | logger.info("Experimentation is disabled") 28 | self.enabled = False 29 | return 30 | 31 | self.enabled = True 32 | if provider_name == ExperimentationProvider.EPPO.value: 33 | logger.info("Using EPPO experimentation provider") 34 | self.provider = EppoExperimentationClient(api_key=api_key) 35 | else: 36 | raise ExperimentationProviderNotSupportedError(provider_name=provider_name) 37 | 38 | def get_experiment_result( 39 | self, experiment_id: str, entity_id: str, **kwargs 40 | ) -> Optional[str]: 41 | """ 42 | Get the result (variant) for a given experiment and entity using the chosen provider. 43 | 44 | Args: 45 | - experiment_id (str): The unique ID of the experiment. 46 | - entity_id (str): The unique ID of the entity. 47 | - kwargs (dict): Any additional arguments to pass to the provider for targeting. 48 | 49 | Returns: 50 | - str: The result (variant) assigned to the entity for the specified experiment. 51 | """ 52 | if not self.enabled: 53 | logger.error( 54 | "get_experiment_result called when experimentation is disabled", 55 | ) 56 | return None 57 | 58 | result = None 59 | has_error = False 60 | 61 | try: 62 | result = self.provider.get_result(experiment_id, entity_id, **kwargs) 63 | except Exception: 64 | logger.exception( 65 | f"Error getting experiment result. 
Experiment ID: {experiment_id}, Entity ID: {entity_id} | " 66 | f"{traceback.format_exc()}", 67 | ) 68 | has_error = True 69 | 70 | self.provider.log_result(experiment_id, entity_id, result, has_error, **kwargs) 71 | return result 72 | 73 | 74 | experimentation_client = ExperimentationClient( 75 | provider_name=settings.EXPERIMENTATION_PROVIDER, 76 | ) 77 | -------------------------------------------------------------------------------- /wyvern/experimentation/experimentation_logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Dict, Optional 4 | 5 | from pydantic import BaseModel 6 | 7 | from wyvern.components.events.events import EventType, LoggedEvent 8 | 9 | 10 | class ExperimentationEventData(BaseModel): 11 | """ 12 | Data class for ExperimentationEvent. 13 | 14 | Attributes: 15 | experiment_id: The experiment id. 16 | entity_id: The entity id. 17 | result: The result of the experiment. Can be None. 18 | timestamp: The timestamp of the event. 19 | metadata: The metadata of the event such as targeting parameters etc. 20 | has_error: Whether the request has errored or not. 21 | """ 22 | 23 | experiment_id: str 24 | entity_id: str 25 | result: Optional[str] 26 | timestamp: datetime 27 | metadata: Dict 28 | has_error: bool 29 | 30 | 31 | class ExperimentationEvent(LoggedEvent[ExperimentationEventData]): 32 | """ 33 | Event class for ExperimentationEvent. 34 | 35 | Attributes: 36 | event_type: The event type. This is always EventType.EXPERIMENTATION. 
37 | """ 38 | 39 | event_type: EventType = EventType.EXPERIMENTATION 40 | -------------------------------------------------------------------------------- /wyvern/experimentation/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/experimentation/providers/__init__.py -------------------------------------------------------------------------------- /wyvern/experimentation/providers/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from enum import Enum 4 | from typing import Optional 5 | 6 | 7 | class ExperimentationProvider(str, Enum): 8 | """ 9 | An enum for the experimentation providers. 10 | """ 11 | 12 | EPPO = "eppo" 13 | 14 | 15 | class BaseExperimentationProvider(ABC): 16 | """ 17 | A base class for experimentation providers. 18 | All providers should inherit from this and implement the necessary methods. 19 | """ 20 | 21 | @abstractmethod 22 | def get_result(self, experiment_id: str, entity_id: str, **kwargs) -> Optional[str]: 23 | """ 24 | Get the result (variant) for a given experiment and entity. 25 | 26 | Args: 27 | - experiment_id (str): The unique ID of the experiment. 28 | - entity_id (str): The unique ID of the entity. 29 | - kwargs (dict): Any additional arguments to pass to the provider for targeting. 30 | 31 | Returns: 32 | - str | None: The result (variant) assigned to the entity for the specified experiment or None. 33 | """ 34 | raise NotImplementedError 35 | 36 | @abstractmethod 37 | def log_result( 38 | self, 39 | experiment_id: str, 40 | entity_id: str, 41 | variant: Optional[str] = None, 42 | has_error: bool = False, 43 | **kwargs 44 | ) -> None: 45 | """ 46 | Log the result (variant) for a given experiment and entity. 
47 | 48 | Args: 49 | - experiment_id (str): The unique ID of the experiment. 50 | - entity_id (str): The unique ID of the entity. 51 | - variant (str): The result (variant) assigned to the entity for the specified experiment. 52 | - kwargs (dict): Any additional arguments to pass to the provider for targeting. 53 | 54 | Returns: 55 | - None 56 | """ 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /wyvern/experimentation/providers/eppo_provider.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import List, Optional 4 | 5 | import eppo_client # type: ignore 6 | from eppo_client.assignment_logger import AssignmentLogger # type: ignore 7 | from eppo_client.config import Config # type: ignore 8 | 9 | from wyvern import request_context 10 | from wyvern.components.events.events import LoggedEvent 11 | from wyvern.config import settings 12 | from wyvern.event_logging import event_logger 13 | from wyvern.exceptions import ExperimentationClientInitializationError 14 | from wyvern.experimentation.experimentation_logging import ( 15 | ExperimentationEvent, 16 | ExperimentationEventData, 17 | ) 18 | from wyvern.experimentation.providers.base import ( 19 | BaseExperimentationProvider, 20 | ExperimentationProvider, 21 | ) 22 | 23 | 24 | class EppoExperimentationClient(BaseExperimentationProvider): 25 | """ 26 | An experimentation client specifically for the Eppo platform. 27 | 28 | Extends the BaseExperimentationProvider to provide functionality using the Eppo client. 
29 | 30 | Methods: 31 | - __init__() -> None 32 | - get_result(experiment_id: str, entity_id: str, **kwargs) -> str 33 | - log_result(experiment_id: str, entity_id: str, variant: str) -> None 34 | """ 35 | 36 | def __init__(self, api_key: Optional[str] = None): 37 | api_key = api_key or settings.EPPO_API_KEY 38 | # AssignmentLogger is a dummy logger that does not log anything. 39 | # We handle logging ourselves in the log_result method. 40 | try: 41 | client_config = Config( 42 | api_key=api_key, 43 | assignment_logger=AssignmentLogger(), 44 | ) 45 | eppo_client.init(client_config) 46 | except Exception as e: 47 | raise ExperimentationClientInitializationError( 48 | provider_name=ExperimentationProvider.EPPO.value, 49 | error=e, 50 | ) 51 | 52 | def get_result(self, experiment_id: str, entity_id: str, **kwargs) -> Optional[str]: 53 | """ 54 | Fetches the variant for a given experiment and entity from the Eppo client. 55 | 56 | Args: 57 | - experiment_id (str): The unique ID of the experiment. 58 | - entity_id (str): The unique ID of the entity (e.g., user or other subject). 59 | - **kwargs: Additional arguments to be passed to the Eppo client's get_assignment method. 60 | 61 | Returns: 62 | - str | None: The result (variant) assigned to the entity for the specified experiment or None. 63 | """ 64 | client = eppo_client.get_instance() 65 | variation = client.get_assignment_variation(entity_id, experiment_id, kwargs) 66 | return variation.value if variation else None 67 | 68 | def log_result( 69 | self, 70 | experiment_id: str, 71 | entity_id: str, 72 | variant: Optional[str] = None, 73 | has_error: bool = False, 74 | **kwargs 75 | ) -> None: 76 | """ 77 | Logs the result for a given experiment and entity. 78 | 79 | Args: 80 | - experiment_id (str): The unique ID of the experiment. 81 | - entity_id (str): The unique ID of the entity. 82 | - variant (str): The assigned variant for the given experiment and entity. 
83 | 84 | Note: This method is overridden to do nothing because the assignment logger we set in Eppo already 85 | handles result logging upon assignment. 86 | """ 87 | 88 | request = request_context.ensure_current_request() 89 | 90 | def event_generator() -> List[LoggedEvent[ExperimentationEventData]]: 91 | timestamp = datetime.utcnow() 92 | request_id = request.request_id 93 | api_source = request.url_path 94 | run_id = request.run_id 95 | 96 | return [ 97 | ExperimentationEvent( 98 | request_id=request_id, 99 | run_id=run_id, 100 | api_source=api_source, 101 | event_timestamp=timestamp, 102 | event_data=ExperimentationEventData( 103 | experiment_id=experiment_id, 104 | entity_id=entity_id, 105 | result=variant, 106 | timestamp=timestamp, 107 | metadata=kwargs, 108 | has_error=has_error, 109 | ), 110 | ), 111 | ] 112 | 113 | event_logger.log_events(event_generator) 114 | -------------------------------------------------------------------------------- /wyvern/feature_store/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/feature_store/__init__.py -------------------------------------------------------------------------------- /wyvern/feature_store/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | SQL_COLUMN_SEPARATOR = "__" 3 | FULL_FEATURE_NAME_SEPARATOR = ":" 4 | -------------------------------------------------------------------------------- /wyvern/feature_store/schemas.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | from typing import Any, Dict, List, Optional 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class GetOnlineFeaturesRequest(BaseModel): 9 | """ 10 | Request object for getting online features. 
11 | 12 | Attributes: 13 | entities: A dictionary of entity name to entity value. 14 | features: A list of feature names. 15 | full_feature_names: A boolean indicating whether to return full feature names. If True, the feature names will 16 | be returned in the format `__`. If False, only the feature names will be 17 | returned. 18 | """ 19 | 20 | entities: Dict[str, Any] = {} 21 | features: List[str] = [] 22 | full_feature_names: bool = False 23 | 24 | 25 | class GetHistoricalFeaturesRequest(BaseModel): 26 | """ 27 | Request object for getting historical features. 28 | 29 | Attributes: 30 | entities: A dictionary of entity name to entity value. 31 | timestamps: A list of timestamps. Used to retrieve historical features at specific timestamps. If not provided, 32 | the latest feature values will be returned. 33 | features: A list of feature names. 34 | """ 35 | 36 | entities: Dict[str, List[Any]] 37 | timestamps: List[datetime] = [] 38 | features: List[str] = [] 39 | 40 | 41 | class GetFeastHistoricalFeaturesRequest(BaseModel): 42 | """ 43 | Request object for getting historical features from Feast. 44 | 45 | Attributes: 46 | full_feature_names: A boolean indicating whether to return full feature names. If True, the feature names will 47 | be returned in the format `__`. If False, only the feature names will be 48 | returned. 49 | entities: A dictionary of entity name to entity value. 50 | features: A list of feature names. 51 | """ 52 | 53 | full_feature_names: bool = False 54 | entities: Dict[str, List[Any]] = {} 55 | features: List[str] = [] 56 | 57 | 58 | class GetHistoricalFeaturesResponse(BaseModel): 59 | """ 60 | Response object for getting historical features. 61 | 62 | Attributes: 63 | results: A list of dictionaries containing feature values. 64 | """ 65 | 66 | results: List[Dict[str, Any]] = [] 67 | 68 | 69 | class MaterializeRequest(BaseModel): 70 | """ 71 | Request object for materializing feature views. 
72 | 73 | Attributes: 74 | end_date: The end date of the materialization window. Defaults to the current time. 75 | feature_views: A list of feature view names to materialize. If not provided, all feature views will be 76 | materialized. 77 | start_date: The start date of the materialization window. Defaults to None, which will use the start date of 78 | the feature view. 79 | """ 80 | 81 | end_date: datetime = Field(default_factory=datetime.utcnow) 82 | feature_views: Optional[List[str]] = None 83 | start_date: Optional[datetime] = None 84 | 85 | 86 | class RequestEntityIdentifierObjects(BaseModel): 87 | """ 88 | Request object for getting entity identifier objects. 89 | 90 | Attributes: 91 | request_ids: A list of request IDs. 92 | entity_identifiers: A list of entity identifiers. 93 | feature_names: A list of feature names. 94 | """ 95 | 96 | request_ids: List[str] = [] 97 | entity_identifiers: List[str] = [] 98 | feature_names: List[str] = [] 99 | -------------------------------------------------------------------------------- /wyvern/helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/helper/__init__.py -------------------------------------------------------------------------------- /wyvern/helper/sort.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from enum import Enum 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class SortEnum(str, Enum): 8 | """ 9 | Enum for sort order. 10 | """ 11 | 12 | asc = "asc" 13 | desc = "desc" 14 | 15 | 16 | class Sort(BaseModel): 17 | """ 18 | Sort class for sorting the results. 19 | 20 | Attributes: 21 | sort_key: The key to sort on. 22 | sort_field: The field to sort on. 23 | sort_order: The order to sort on. Defaults to desc. 
24 | """ 25 | 26 | sort_key: str 27 | sort_field: str 28 | sort_order: SortEnum = SortEnum.desc 29 | -------------------------------------------------------------------------------- /wyvern/index.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, List, Optional, Sequence 3 | 4 | from wyvern.redis import wyvern_redis 5 | 6 | 7 | class WyvernIndex: 8 | @classmethod 9 | async def get(cls, entity_type: str, entity_id: str) -> Optional[Dict[str, Any]]: 10 | return await wyvern_redis.get_entity( 11 | entity_type=entity_type, 12 | entity_id=entity_id, 13 | ) 14 | 15 | @classmethod 16 | async def bulk_get( 17 | cls, 18 | entity_type: str, 19 | entity_ids: Sequence[str], 20 | ) -> List[Optional[Dict[str, Any]]]: 21 | if not entity_ids: 22 | return [] 23 | return await wyvern_redis.get_entities( 24 | entity_type=entity_type, 25 | entity_ids=entity_ids, 26 | ) 27 | 28 | @classmethod 29 | async def delete(cls, entity_type: str, entity_id: str) -> None: 30 | await wyvern_redis.delete_entity( 31 | entity_type=entity_type, 32 | entity_id=entity_id, 33 | ) 34 | 35 | @classmethod 36 | async def bulk_delete( 37 | cls, 38 | entity_type: str, 39 | entity_ids: Sequence[str], 40 | ) -> None: 41 | if not entity_ids: 42 | return 43 | await wyvern_redis.delete_entities( 44 | entity_type=entity_type, 45 | entity_ids=entity_ids, 46 | ) 47 | 48 | 49 | class WyvernEntityIndex: 50 | @classmethod 51 | async def get( 52 | cls, 53 | entity_type: str, 54 | entity_id: str, 55 | ) -> Optional[Dict[str, Any]]: 56 | return await WyvernIndex.get( 57 | entity_type=entity_type, 58 | entity_id=entity_id, 59 | ) 60 | 61 | @classmethod 62 | async def bulk_get( 63 | cls, 64 | entity_type: str, 65 | entity_ids: Sequence[str], 66 | ) -> List[Optional[Dict[str, Any]]]: 67 | return await WyvernIndex.bulk_get( 68 | entity_type=entity_type, 69 | entity_ids=entity_ids, 70 | ) 71 | 72 | @classmethod 73 | async def 
delete(cls, entity_type: str, entity_id: str) -> None: 74 | await WyvernIndex.delete( 75 | entity_type=entity_type, 76 | entity_id=entity_id, 77 | ) 78 | 79 | @classmethod 80 | async def bulk_delete( 81 | cls, 82 | entity_type: str, 83 | entity_ids: Sequence[str], 84 | ) -> None: 85 | await WyvernIndex.bulk_delete( 86 | entity_type=entity_type, 87 | entity_ids=entity_ids, 88 | ) 89 | -------------------------------------------------------------------------------- /wyvern/redis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from typing import Any, Dict, List, Optional, Sequence 4 | 5 | from redis.asyncio import Redis 6 | 7 | from wyvern.config import settings 8 | from wyvern.core.compression import wyvern_decode, wyvern_encode 9 | from wyvern.utils import generate_index_key 10 | from wyvern.wyvern_request import WyvernRequest 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | REDIS_BATCH_SIZE = settings.REDIS_BATCH_SIZE 15 | 16 | 17 | class WyvernRedis: 18 | """ 19 | WyvernRedis is a wrapper for redis client to help index your entities in redis with Wyvern's convention 20 | """ 21 | 22 | def __init__( 23 | self, 24 | scope: str = "", 25 | redis_host: Optional[str] = None, 26 | redis_port: Optional[int] = None, 27 | ) -> None: 28 | """ 29 | scope is used to prefix the redis key. You can use the environment variable PROJECT_NAME to set the scope. 
30 | """ 31 | host = redis_host or settings.REDIS_HOST 32 | if not host: 33 | raise ValueError("redis host is not set or found in environment variable") 34 | port = redis_port or settings.REDIS_PORT 35 | if not port: 36 | raise ValueError("redis port is not set or found in environment variable") 37 | self.redis_connection: Redis = Redis( 38 | host=host, 39 | port=port, 40 | ) 41 | self.key_prefix = scope or settings.PROJECT_NAME 42 | 43 | # TODO (shu): This entire file shouldn't be called redis.py -- this is specific to indexing 44 | # We should actually have a redis.py file that does any of the required logic.. and mock that at most 45 | async def bulk_index( 46 | self, 47 | entities: List[Dict[str, Any]], 48 | entity_key: str, 49 | entity_type: str, 50 | ) -> List[str]: 51 | if not entities: 52 | return [] 53 | mapping = { 54 | generate_index_key( 55 | self.key_prefix, 56 | entity_type, 57 | entity[entity_key], 58 | ): wyvern_encode(entity) 59 | for entity in entities 60 | } 61 | await self.redis_connection.mset(mapping=mapping) # type: ignore 62 | return [entity[entity_key] for entity in entities] 63 | 64 | async def get(self, index_key: str) -> Optional[str]: 65 | return await self.redis_connection.get(index_key) 66 | 67 | async def mget(self, index_keys: List[str]) -> List[Optional[str]]: 68 | if not index_keys: 69 | return [] 70 | return await self.redis_connection.mget(index_keys) 71 | 72 | async def mget_json( 73 | self, 74 | index_keys: List[str], 75 | ) -> List[Optional[Dict[str, Any]]]: 76 | results = await self.mget(index_keys) 77 | return [wyvern_decode(val) if val is not None else None for val in results] 78 | 79 | async def mget_update_in_place( 80 | self, 81 | index_keys: List[str], 82 | wyvern_request: WyvernRequest, 83 | ) -> None: 84 | # single mget way 85 | results = await self.mget(index_keys) 86 | wyvern_request.entity_store = { 87 | key: wyvern_decode(val) if val is not None else None 88 | for key, val in zip(index_keys, results) 89 | } 90 | 
91 | async def get_entity( 92 | self, 93 | entity_type: str, 94 | entity_id: str, 95 | ) -> Optional[Dict[str, Any]]: 96 | """ 97 | get entity from redis 98 | """ 99 | index_key = generate_index_key( 100 | self.key_prefix, 101 | entity_type, 102 | entity_id, 103 | ) 104 | 105 | encoded_entity = await self.get(index_key) 106 | if not encoded_entity: 107 | return None 108 | return wyvern_decode(encoded_entity) 109 | 110 | async def get_entities( 111 | self, 112 | entity_type: str, 113 | entity_ids: Sequence[str], 114 | ) -> List[Optional[Dict[str, Any]]]: 115 | """ 116 | get entity from redis 117 | """ 118 | index_keys = [ 119 | generate_index_key(self.key_prefix, entity_type, entity_id) 120 | for entity_id in entity_ids 121 | ] 122 | if not index_keys: 123 | return [] 124 | return await self.mget_json(index_keys) 125 | 126 | async def delete_entity( 127 | self, 128 | entity_type: str, 129 | entity_id: str, 130 | ) -> None: 131 | """ 132 | delete entity from redis 133 | """ 134 | index_key = generate_index_key(self.key_prefix, entity_type, entity_id) 135 | await self.redis_connection.delete(index_key) 136 | 137 | async def delete_entities( 138 | self, 139 | entity_type: str, 140 | entity_ids: Sequence[str], 141 | ) -> None: 142 | """ 143 | delete entities from redis 144 | """ 145 | index_keys = [ 146 | generate_index_key(self.key_prefix, entity_type, entity_id) 147 | for entity_id in entity_ids 148 | ] 149 | await self.redis_connection.delete(*index_keys) 150 | 151 | 152 | wyvern_redis = WyvernRedis() 153 | -------------------------------------------------------------------------------- /wyvern/request_context.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from contextvars import ContextVar 3 | from typing import Optional 4 | 5 | from wyvern.wyvern_request import WyvernRequest 6 | 7 | _request_context: ContextVar[Optional[WyvernRequest]] = ContextVar( 8 | "Global request context", 9 | default=None, 
10 | ) 11 | 12 | 13 | def current() -> Optional[WyvernRequest]: 14 | """ 15 | Get the current request context 16 | 17 | Returns: 18 | The current request context, or None if there is none 19 | """ 20 | return _request_context.get() 21 | 22 | 23 | def ensure_current_request() -> WyvernRequest: 24 | """ 25 | Get the current request context, or raise an error if there is none 26 | 27 | Returns: 28 | The current request context if there is one 29 | 30 | Raises: 31 | RuntimeError: If there is no current request context 32 | """ 33 | request = current() 34 | if request is None: 35 | raise RuntimeError("No wyvern request context") 36 | return request 37 | 38 | 39 | def set(request: WyvernRequest) -> None: 40 | """ 41 | Set the current request context 42 | 43 | Args: 44 | request: The request context to set 45 | 46 | Returns: 47 | None 48 | """ 49 | _request_context.set(request) 50 | 51 | 52 | def reset() -> None: 53 | """ 54 | Reset the current request context 55 | 56 | Returns: 57 | None 58 | """ 59 | _request_context.set(None) 60 | -------------------------------------------------------------------------------- /wyvern/service.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import asyncio 5 | from typing import List, Optional, Type, Union 6 | 7 | from dotenv import load_dotenv 8 | from fastapi import FastAPI 9 | 10 | from wyvern.components.api_route_component import APIRouteComponent 11 | from wyvern.components.features.realtime_features_component import ( 12 | RealtimeFeatureComponent, 13 | ) 14 | from wyvern.components.index import ( 15 | IndexDeleteComponent, 16 | IndexGetComponent, 17 | IndexUploadComponent, 18 | ) 19 | from wyvern.web_frameworks.fastapi import WyvernFastapi 20 | 21 | 22 | class WyvernService: 23 | """ 24 | The class to define, generate and run a Wyvern service 25 | 26 | Attributes: 27 | host: The host to run the service on. 
Defaults to localhost. 28 | port: The port to run the service on. Defaults to 5000. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | *, 34 | host: str = "127.0.0.1", 35 | port: int = 5000, 36 | ) -> None: 37 | self.host = host 38 | self.port = port 39 | self.service = WyvernFastapi(host=self.host, port=self.port) 40 | 41 | async def register_routes( 42 | self, 43 | route_components: List[Union[Type[APIRouteComponent], APIRouteComponent]], 44 | ) -> None: 45 | """ 46 | Register the routes for the Wyvern service 47 | 48 | Args: 49 | route_components: The list of route components to register 50 | 51 | Returns: 52 | None 53 | """ 54 | for route_component in route_components: 55 | await self.service.register_route(route_component=route_component) 56 | 57 | def _run( 58 | self, 59 | ) -> None: 60 | """ 61 | Run the Wyvern service 62 | 63 | Returns: 64 | None 65 | """ 66 | load_dotenv() 67 | self.service.run() 68 | 69 | @staticmethod 70 | def generate( 71 | *, 72 | route_components: Optional[ 73 | List[Union[Type[APIRouteComponent], APIRouteComponent]] 74 | ] = None, 75 | realtime_feature_components: Optional[ 76 | List[Type[RealtimeFeatureComponent]] 77 | ] = None, 78 | host: str = "127.0.0.1", 79 | port: int = 5000, 80 | ) -> WyvernService: 81 | """ 82 | Generate a Wyvern service 83 | 84 | Args: 85 | route_components: The list of route components to register. Defaults to None. 86 | realtime_feature_components: The list of realtime feature components to register. Defaults to None. 87 | host: The host to run the service on. Defaults to localhost. 88 | port: The port to run the service on. Defaults to 5000. 
89 | 90 | Returns: 91 | WyvernService: The generated Wyvern service 92 | """ 93 | route_components = route_components or [] 94 | service = WyvernService(host=host, port=port) 95 | asyncio.run( 96 | service.register_routes( 97 | [ 98 | IndexDeleteComponent, 99 | IndexGetComponent, 100 | IndexUploadComponent, 101 | *route_components, 102 | ], 103 | ), 104 | ) 105 | return service 106 | 107 | @staticmethod 108 | def run( 109 | *, 110 | route_components: List[Union[Type[APIRouteComponent], APIRouteComponent]], 111 | realtime_feature_components: Optional[ 112 | List[Type[RealtimeFeatureComponent]] 113 | ] = None, 114 | host: str = "127.0.0.1", 115 | port: int = 5000, 116 | ): 117 | """ 118 | Generate and run a Wyvern service 119 | 120 | Args: 121 | route_components: The list of route components to register 122 | realtime_feature_components: The list of realtime feature components to register. Defaults to None. 123 | host: The host to run the service on. Defaults to localhost. 124 | port: The port to run the service on. Defaults to 5000. 125 | 126 | Returns: 127 | None 128 | """ 129 | service = WyvernService.generate( 130 | route_components=route_components, 131 | realtime_feature_components=realtime_feature_components, 132 | host=host, 133 | port=port, 134 | ) 135 | service._run() 136 | 137 | @staticmethod 138 | def generate_app( 139 | *, 140 | route_components: Optional[ 141 | List[Union[Type[APIRouteComponent], APIRouteComponent]] 142 | ] = None, 143 | realtime_feature_components: Optional[ 144 | List[Type[RealtimeFeatureComponent]] 145 | ] = None, 146 | host: str = "127.0.0.1", 147 | port: int = 5000, 148 | ) -> FastAPI: 149 | """ 150 | Generate a Wyvern service and return the FastAPI app 151 | 152 | Args: 153 | route_components: The list of route components to register. Defaults to None. 154 | realtime_feature_components: The list of realtime feature components to register. Defaults to None. 155 | host (str, optional): The host to run the service on. 
Defaults to localhost. 156 | port (int, optional): The port to run the service on. Defaults to 5000. 157 | 158 | Returns: 159 | FastAPI: The generated FastAPI app 160 | """ 161 | service = WyvernService.generate( 162 | route_components=route_components, 163 | realtime_feature_components=realtime_feature_components, 164 | host=host, 165 | port=port, 166 | ) 167 | return service.service.app 168 | -------------------------------------------------------------------------------- /wyvern/tracking.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib.metadata 3 | import platform 4 | from typing import Any, Dict, Optional 5 | 6 | from posthog import Posthog 7 | 8 | from wyvern.config import settings 9 | 10 | posthog = Posthog( 11 | "phc_bVT2ugnZhMHRWqMvSRHPdeTjaPxQqT3QSsI3r5FlQR5", 12 | host="https://app.posthog.com", 13 | disable_geoip=False, 14 | ) 15 | 16 | 17 | def get_oss_version() -> str: 18 | try: 19 | return importlib.metadata.version("wyvern-ai") 20 | except Exception: 21 | return "unknown" 22 | 23 | 24 | def analytics_metadata() -> Dict[str, Any]: 25 | return { 26 | "os": platform.system().lower(), 27 | "oss_version": get_oss_version(), 28 | "machine": platform.machine(), 29 | "platform": platform.platform(), 30 | "python_version": platform.python_version(), 31 | "environment": settings.ENVIRONMENT, 32 | } 33 | 34 | 35 | def capture( 36 | event: str, 37 | distinct_id: str = "oss", 38 | data: Optional[Dict[str, Any]] = None, 39 | ) -> None: 40 | try: 41 | data = data or {} 42 | data.update(analytics_metadata()) 43 | posthog.capture( 44 | distinct_id=distinct_id, 45 | event=event, 46 | properties=data, 47 | ) 48 | except Exception as e: 49 | posthog.capture( 50 | distinct_id=distinct_id, 51 | event="failure", 52 | properties={ 53 | "capture_error": str(e), 54 | }, 55 | ) 56 | -------------------------------------------------------------------------------- /wyvern/utils.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from wyvern.config import settings 3 | 4 | 5 | def generate_index_key( 6 | scope: str, 7 | entity_type: str, 8 | entity_id: str, 9 | ) -> str: 10 | return f"{scope}:{settings.WYVERN_INDEX_VERSION}:{entity_type}:{entity_id}" 11 | -------------------------------------------------------------------------------- /wyvern/web_frameworks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Skyvern-AI/wyvern/a6f5f72860e6f528e81fb816fa3f4275d08bed1d/wyvern/web_frameworks/__init__.py -------------------------------------------------------------------------------- /wyvern/wyvern_logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import logging.config 4 | import os 5 | 6 | import yaml 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def setup_logging(): 12 | """ 13 | Setup logging configuration by loading from log_config.yml file. Logs an error if the 14 | file cannot be found or loaded and uses default logging configuration. 15 | """ 16 | # this log_config.yml file path is changed compared to the original library code 17 | path = os.path.abspath("log_config.yml") 18 | 19 | if os.path.exists(path): 20 | with open(path, "rt") as f: 21 | try: 22 | config = yaml.safe_load(f.read()) 23 | 24 | # logfile_path = config["handlers"]["file"]["filename"] 25 | # os.makedirs(logfile_path, exist_ok=True) 26 | logging.config.dictConfig(config) 27 | except Exception as e: 28 | logger.error("Error in Logging Configuration. Using default configs") 29 | raise e 30 | # logging.basicConfig(level=logging.INFO) 31 | else: 32 | logging.basicConfig(level=logging.INFO) 33 | logger.debug("Failed to load configuration file. 
Using default configs") 34 | -------------------------------------------------------------------------------- /wyvern/wyvern_request.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union 7 | from urllib.parse import urlparse 8 | 9 | import fastapi 10 | from pydantic import BaseModel 11 | 12 | from wyvern import request_context 13 | from wyvern.components.events.events import LoggedEvent 14 | from wyvern.entities.feature_entities import FeatureDataFrame 15 | from wyvern.entities.identifier import Identifier 16 | from wyvern.exceptions import WyvernLoggingOriginalIdentifierMissingError 17 | 18 | 19 | class ShadowRequest: 20 | callable: Callable[..., Coroutine[Any, Any, Any]] 21 | args: Tuple[Any, ...] 22 | kwargs: Dict[str, Any] 23 | 24 | def __init__( 25 | self, shadow_callable: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs 26 | ): 27 | self.callable = shadow_callable 28 | self.args = args 29 | self.kwargs = kwargs 30 | 31 | 32 | @dataclass 33 | class WyvernRequest: 34 | """ 35 | WyvernRequest is a dataclass that represents a request to the Wyvern service. It is used to pass 36 | information between the various components of the Wyvern service. 37 | 38 | Attributes: 39 | method: The HTTP method of the request 40 | url: The full URL of the request 41 | url_path: The path of the URL of the request 42 | json: The JSON body of the request, represented by pydantic model 43 | headers: The headers of the request 44 | entity_store: A dictionary that can be used to store entities that are created during the request 45 | events: A list of functions that return a list of LoggedEvents. 
These functions are called at the end of 46 | the request to log events to the event store 47 | feature_df: The feature data frame that is created during the request 48 | request_id: The request ID of the request 49 | """ 50 | 51 | method: str 52 | url: str 53 | url_path: str 54 | json: BaseModel 55 | headers: Dict[Any, Any] 56 | 57 | entity_store: Dict[str, Optional[Dict[str, Any]]] 58 | # TODO (suchintan): Validate that there is no thread leakage here 59 | # The list of list here is a minor performance optimization to prevent copying of lists for events 60 | events: List[Callable[[], List[LoggedEvent[Any]]]] 61 | 62 | feature_df: FeatureDataFrame 63 | # feature_orig_identifiers is a hack to get around the fact that the feature dataframe does not store 64 | # the original identifiers of the entities. This is needed for logging the features with the correct 65 | # identifiers. The below map is a map of the feature name to the primary identifier key of the entity to the 66 | # original identifier of the entity 67 | feature_orig_identifiers: Dict[str, Dict[str, Identifier]] 68 | 69 | # the key is the name of the model and the value is a map of the identifier to the model score 70 | model_output_map: Dict[ 71 | str, 72 | Dict[ 73 | Identifier, 74 | Union[ 75 | float, 76 | str, 77 | List[float], 78 | Dict[str, Optional[Union[float, str, list[float]]]], 79 | None, 80 | ], 81 | ], 82 | ] 83 | 84 | request_id: Optional[str] = None 85 | run_id: str = "0" 86 | 87 | shadow_requests: Optional[List[ShadowRequest]] = None 88 | 89 | # TODO: params 90 | 91 | @classmethod 92 | def parse_fastapi_request( 93 | cls, 94 | json: BaseModel, 95 | req: fastapi.Request, 96 | run_id: str = "0", 97 | request_id: Optional[str] = None, 98 | ) -> WyvernRequest: 99 | """ 100 | Parses a FastAPI request into a WyvernRequest 101 | 102 | Args: 103 | json: The JSON body of the request, represented by pydantic model 104 | req: The FastAPI request 105 | request_id: The request ID of the request 106 | 
107 | Returns: 108 | A WyvernRequest 109 | """ 110 | return cls( 111 | method=req.method, 112 | url=str(req.url), 113 | url_path=urlparse(str(req.url)).path, 114 | json=json, 115 | headers=dict(req.headers), 116 | entity_store={}, 117 | events=[], 118 | feature_df=FeatureDataFrame(), 119 | feature_orig_identifiers=defaultdict(dict), 120 | model_output_map={}, 121 | request_id=request_id, 122 | run_id=run_id, 123 | ) 124 | 125 | def cache_model_output( 126 | self, 127 | model_name: str, 128 | data: Dict[ 129 | Identifier, 130 | Union[ 131 | float, 132 | str, 133 | List[float], 134 | Dict[str, Optional[Union[float, str, list[float]]]], 135 | None, 136 | ], 137 | ], 138 | ) -> None: 139 | if model_name not in self.model_output_map: 140 | self.model_output_map[model_name] = {} 141 | self.model_output_map[model_name].update(data) 142 | 143 | def get_model_output( 144 | self, 145 | model_name: str, 146 | identifier: Identifier, 147 | ) -> Optional[ 148 | Union[ 149 | float, 150 | str, 151 | List[float], 152 | Dict[str, Optional[Union[float, str, list[float]]]], 153 | None, 154 | ] 155 | ]: 156 | if model_name not in self.model_output_map: 157 | return None 158 | return self.model_output_map[model_name].get(identifier) 159 | 160 | def get_original_identifier( 161 | self, 162 | primary_identifier_key: str, 163 | feature_name: str, 164 | ) -> Identifier: 165 | """Gets the original identifier for a feature name and primary identifier key. 166 | 167 | Args: 168 | primary_identifier_key: The primary identifier key. 169 | feature_name: The name of the feature. 170 | 171 | 172 | Returns: 173 | The original identifier. 
174 | """ 175 | try: 176 | return self.feature_orig_identifiers[feature_name][primary_identifier_key] 177 | except KeyError: 178 | raise WyvernLoggingOriginalIdentifierMissingError( 179 | identifier=primary_identifier_key, 180 | feature_name=feature_name, 181 | ) 182 | 183 | def add_shadow_request_call( 184 | self, 185 | shadow_request: ShadowRequest, 186 | ): 187 | if self.shadow_requests is None: 188 | self.shadow_requests = [] 189 | 190 | self.shadow_requests.append(shadow_request) 191 | 192 | async def execute_shadow_requests(self): 193 | if self.shadow_requests is None: 194 | return 195 | try: 196 | request_context.set(self) 197 | for shadow_request in self.shadow_requests: 198 | await shadow_request.callable( 199 | *shadow_request.args, **shadow_request.kwargs 200 | ) 201 | finally: 202 | request_context.reset() 203 | -------------------------------------------------------------------------------- /wyvern/wyvern_tracing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ddtrace import tracer 3 | from ddtrace.filters import FilterRequestsOnUrl 4 | 5 | from wyvern.config import settings 6 | 7 | 8 | def setup_tracing(): 9 | """ 10 | Setup tracing for Wyvern service. Tracing is disabled in development mode and for healthcheck requests. 
11 | """ 12 | tracer.configure( 13 | settings={ 14 | "FILTERS": [ 15 | FilterRequestsOnUrl(r"http://.*/healthcheck$"), 16 | ], 17 | }, 18 | ) 19 | 20 | if settings.ENVIRONMENT == "development": 21 | tracer.enabled = False 22 | -------------------------------------------------------------------------------- /wyvern/wyvern_typing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List, TypeVar, Union 3 | 4 | from pydantic import BaseModel 5 | 6 | from wyvern.entities.identifier_entities import WyvernEntity 7 | from wyvern.entities.request import BaseWyvernRequest 8 | 9 | T = TypeVar("T") 10 | REQUEST_ENTITY = TypeVar("REQUEST_ENTITY", bound=BaseWyvernRequest) 11 | WYVERN_ENTITY = TypeVar("WYVERN_ENTITY", bound=WyvernEntity) 12 | GENERALIZED_WYVERN_ENTITY = TypeVar( 13 | "GENERALIZED_WYVERN_ENTITY", 14 | bound=Union[WyvernEntity, BaseWyvernRequest], 15 | ) 16 | INPUT_TYPE = TypeVar("INPUT_TYPE") 17 | OUTPUT_TYPE = TypeVar("OUTPUT_TYPE") 18 | UPSTREAM_INPUT_TYPE = TypeVar("UPSTREAM_INPUT_TYPE") 19 | UPSTREAM_OUTPUT_TYPE = TypeVar("UPSTREAM_OUTPUT_TYPE") 20 | REQUEST_SCHEMA = TypeVar("REQUEST_SCHEMA", bound=BaseModel) 21 | RESPONSE_SCHEMA = TypeVar("RESPONSE_SCHEMA", bound=BaseModel) 22 | 23 | WyvernFeature = Union[float, str, List[float], None] 24 | """A WyvernFeature defines the type of a feature in Wyvern. It can be a float, a string, a list of floats, or None.""" 25 | --------------------------------------------------------------------------------