├── .github ├── dependabot.yml └── workflows │ ├── ci.yaml │ └── release.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── AGENTS.md ├── LICENSE ├── README.md ├── Tiltfile ├── airflow_ai_sdk ├── __init__.py ├── airflow.py ├── decorators │ ├── agent.py │ ├── branch.py │ ├── embed.py │ └── llm.py ├── models │ ├── base.py │ └── tool.py └── operators │ ├── agent.py │ ├── embed.py │ ├── llm.py │ └── llm_branch.py ├── docs ├── README.md ├── examples │ └── index.md ├── features.md ├── index.md ├── interface │ ├── README.md │ ├── airflow.md │ ├── decorators │ │ ├── agent.md │ │ ├── branch.md │ │ ├── embed.md │ │ └── llm.md │ ├── models │ │ ├── base.md │ │ └── tool.md │ └── operators │ │ ├── agent.md │ │ ├── embed.md │ │ ├── llm.md │ │ └── llm_branch.md └── usage.md ├── examples ├── .astro │ ├── config.yaml │ ├── dag_integrity_exceptions.txt │ └── test_dag_integrity_default.py ├── .dockerignore ├── .gitignore ├── Dockerfile ├── README.md ├── dags │ ├── deep_research.py │ ├── email_generation.py │ ├── github_changelog.py │ ├── product_feedback_summarization.py │ ├── sentiment_classification.py │ ├── support_ticket_routing.py │ └── text_embedding.py ├── docker-compose.yaml ├── packages.txt └── requirements.txt ├── pyproject.toml ├── scripts └── generate_interface_docs.py ├── tests ├── __init__.py ├── operators │ ├── test_agent.py │ ├── test_embed.py │ ├── test_llm.py │ └── test_llm_branch.py └── test_airflow_imports.py └── uv.lock /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "uv" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | ruff: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: install uv 18 | uses: astral-sh/setup-uv@v5 19 | with: 20 | version: "0.6.10" 21 | enable-cache: true 22 | 23 | - name: install python 24 | run: uv python install 25 | 26 | - name: install dependencies 27 | run: uv sync --all-extras --dev 28 | 29 | - name: ruff check 30 | run: uv run ruff check --output-format=github . 31 | 32 | - name: ruff format 33 | run: uv run ruff format --check . 
34 | 35 | - name: minimize uv cache 36 | run: uv cache prune --ci 37 | 38 | test: 39 | runs-on: ubuntu-latest 40 | strategy: 41 | matrix: 42 | python-version: 43 | - "3.10" 44 | - "3.11" 45 | - "3.12" 46 | airflow-version: 47 | - "2.10.0" 48 | - "3.0.0" 49 | 50 | steps: 51 | - uses: actions/checkout@v4 52 | 53 | - name: install uv 54 | uses: astral-sh/setup-uv@v5 55 | with: 56 | version: "0.6.10" 57 | enable-cache: true 58 | 59 | - name: install python 60 | run: uv python install ${{ matrix.python-version }} 61 | 62 | - name: install dependencies 63 | run: uv sync --all-extras --dev --no-install-package apache-airflow 64 | 65 | - name: install airflow 66 | run: uv add apache-airflow==${{ matrix.airflow-version }} 67 | 68 | - name: install requirements necessary for examples 69 | run: uv add sentence-transformers 70 | 71 | - name: print python package versions 72 | run: uv pip freeze 73 | 74 | - name: print airflow version 75 | run: uv run --python ${{ matrix.python-version }} airflow version 76 | 77 | - name: pytest 78 | run: uv run --python ${{ matrix.python-version }} pytest -v 79 | 80 | - name: minimize uv cache 81 | run: uv cache prune --ci 82 | 83 | docs: 84 | runs-on: ubuntu-latest 85 | steps: 86 | - uses: actions/checkout@v4 87 | 88 | - name: install uv 89 | uses: astral-sh/setup-uv@v5 90 | with: 91 | version: "0.6.10" 92 | enable-cache: true 93 | 94 | - name: install python 95 | run: uv python install 96 | 97 | - name: install pre-commit 98 | run: uv add pre-commit 99 | 100 | - name: install dependencies 101 | run: uv sync --all-extras --dev 102 | 103 | - name: generate interface docs 104 | run: uv run python scripts/generate_interface_docs.py 105 | 106 | # if there are changes to the docs, fail and tell the user to run the script 107 | - name: check for changes to the docs 108 | run: | 109 | if git diff --name-only | grep -q "docs"; then 110 | echo "Changes to docs detected. Please run 'uv run python scripts/generate_interface_docs.py' to update the generated docs." 
111 | exit 1 112 | else 113 | echo "No changes to docs detected" 114 | fi 115 | 116 | - name: minimize uv cache 117 | run: uv cache prune --ci 118 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: release 13 | permissions: 14 | id-token: write 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: install uv 19 | uses: astral-sh/setup-uv@v5 20 | with: 21 | version: "0.6.10" 22 | enable-cache: true 23 | 24 | - name: install python 25 | run: uv python install 26 | 27 | - name: build package 28 | run: uv build 29 | 30 | - name: upload package 31 | run: uv publish --trusted-publishing always 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: local 5 | hooks: 6 | - id: check-env-file 7 | name: Check if .env file is empty 8 | entry: bash -c 'FILE=dev/.env; if [ -s "$FILE" ]; then echo "$FILE is not empty. 
Please remove its content."; exit 1; fi' 9 | language: system 10 | types: [file] 11 | pass_filenames: false 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v5.0.0 14 | hooks: 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | - id: check-toml 18 | - id: check-yaml 19 | args: 20 | - --unsafe 21 | - id: debug-statements 22 | - id: end-of-file-fixer 23 | - id: mixed-line-ending 24 | - id: pretty-format-json 25 | args: ["--autofix"] 26 | - id: trailing-whitespace 27 | - id: detect-private-key 28 | - id: detect-aws-credentials 29 | args: ["--allow-missing-credentials"] 30 | - repo: https://github.com/codespell-project/codespell 31 | rev: v2.4.1 32 | hooks: 33 | - id: codespell 34 | name: Run codespell to check for common misspellings in files 35 | language: python 36 | types: [text] 37 | args: 38 | - -L connexion,aci 39 | - repo: https://github.com/pre-commit/pygrep-hooks 40 | rev: v1.10.0 41 | hooks: 42 | - id: python-check-mock-methods 43 | - repo: https://github.com/Lucas-C/pre-commit-hooks 44 | rev: v1.5.5 45 | hooks: 46 | - id: remove-crlf 47 | - id: remove-tabs 48 | exclude: ^docs/make.bat$|^docs/Makefile$|^dev/dags/dbt/jaffle_shop/seeds/raw_orders.csv$ 49 | - repo: https://github.com/asottile/pyupgrade 50 | rev: v3.19.1 51 | hooks: 52 | - id: pyupgrade 53 | args: 54 | - --py39-plus 55 | - --keep-runtime-typing 56 | - repo: https://github.com/astral-sh/ruff-pre-commit 57 | rev: v0.11.2 58 | hooks: 59 | - id: ruff 60 | args: 61 | - --fix 62 | - id: ruff-format 63 | 64 | ci: 65 | autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks 66 | autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate 67 | skip: 68 | - mypy # build of https://github.com/pre-commit/mirrors-mypy:types-PyYAML,types-attrs,attrs,types-requests, 69 | #types-python-dateutil,apache-airflow@v1.5.0 for python@python3 exceeds tier max size 250MiB: 262.6MiB 70 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /AGENTS.md: -------------------------------------------------------------------------------- 1 | Refer to .github/workflows/ci.yaml for examples on how to run tests and linting. 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # airflow-ai-sdk 2 | 3 | A Python SDK for working with LLMs from [Apache Airflow](https://github.com/apache/airflow). It allows users to call LLMs and orchestrate agent calls directly within their Airflow pipelines using decorator-based tasks. 4 | 5 | We find it's often helpful to rely on mature orchestration tooling like Airflow for instrumenting LLM workflows and agents in production, as these LLM workflows follow the same form factor as more traditional workflows like ETL pipelines, operational processes, and ML workflows. 6 | 7 | ## Quick Start 8 | 9 | ```bash 10 | pip install airflow-ai-sdk[openai] 11 | ``` 12 | 13 | Installing with no optional dependencies will give you the slim version of the package. The available optional dependencies are listed in [pyproject.toml](https://github.com/astronomer/airflow-ai-sdk/blob/main/pyproject.toml#L17). 14 | 15 | ## Features 16 | 17 | - **LLM tasks with `@task.llm`:** Define tasks that call language models to process text 18 | - **Agent tasks with `@task.agent`:** Orchestrate multi-step AI reasoning with custom tools 19 | - **Automatic output parsing:** Use type hints to automatically parse and validate LLM outputs 20 | - **Branching with `@task.llm_branch`:** Change DAG control flow based on LLM output 21 | - **Model support:** All models in the Pydantic AI library (OpenAI, Anthropic, Gemini, etc.) 22 | - **Embedding tasks with `@task.embed`:** Create vector embeddings from text 23 | 24 | ## Example 25 | 26 | ```python 27 | from typing import Literal 28 | import pendulum 29 | from airflow.decorators import dag, task 30 | from airflow.models.dagrun import DagRun 31 | 32 | 33 | @task.llm( 34 | model="gpt-4o-mini", 35 | result_type=Literal["positive", "negative", "neutral"], 36 | system_prompt="Classify the sentiment of the given text.", 37 | ) 38 | def process_with_llm(dag_run: DagRun) -> str: 39 | input_text = dag_run.conf.get("input_text") 40 | 41 | # can do pre-processing here (e.g. PII redaction) 42 | return input_text 43 | 44 | 45 | @dag( 46 | schedule=None, 47 | start_date=pendulum.datetime(2025, 1, 1), 48 | catchup=False, 49 | params={"input_text": "I'm very happy with the product."}, 50 | ) 51 | def sentiment_classification(): 52 | process_with_llm() 53 | 54 | 55 | sentiment_classification() 56 | ``` 57 | 58 | ## Examples Repository 59 | 60 | To get started with a complete example environment, check out the [examples repository](https://github.com/astronomer/ai-sdk-examples), which offers a full local Airflow instance with the AI SDK installed and 5 example pipelines: 61 | 62 | ```bash 63 | git clone https://github.com/astronomer/ai-sdk-examples.git 64 | cd ai-sdk-examples 65 | astro dev start 66 | ``` 67 | 68 | If you don't have the Astro CLI installed, run `brew install astro` or see other options [here](https://www.astronomer.io/docs/astro/cli/install-cli). 
69 | 70 | ## Documentation 71 | 72 | For detailed documentation, see the [docs directory](docs/): 73 | 74 | - [Getting Started](docs/index.md) 75 | - [Features](docs/features.md) 76 | - [Usage Guide](docs/usage.md) 77 | - [Examples](docs/examples/index.md) 78 | 79 | ## License 80 | 81 | [LICENSE](LICENSE) 82 | -------------------------------------------------------------------------------- /Tiltfile: -------------------------------------------------------------------------------- 1 | docker_compose('examples/docker-compose.yaml') 2 | 3 | sync_pyproj_toml = sync('./pyproject.toml', '/usr/local/airflow/airflow-ai-sdk/pyproject.toml') 4 | sync_readme = sync('./README.md', '/usr/local/airflow/airflow_ai_sdk/README.md') 5 | sync_src = sync('./airflow_ai_sdk', '/usr/local/airflow/airflow_ai_sdk/airflow_ai_sdk') 6 | 7 | docker_build( 8 | 'airflow-ai-sdk', 9 | context='.', 10 | dockerfile='examples/Dockerfile', 11 | ignore=['.venv', '**/logs/**'], 12 | live_update=[ 13 | sync_pyproj_toml, 14 | sync_src, 15 | sync_readme, 16 | run( 17 | 'cd /usr/local/airflow/airflow_ai_sdk && uv pip install -e .', 18 | trigger=['pyproject.toml'] 19 | ), 20 | ] 21 | ) 22 | -------------------------------------------------------------------------------- /airflow_ai_sdk/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package provides an SDK for building LLM workflows and agents using Apache Airflow. 3 | """ 4 | 5 | from typing import Any 6 | 7 | __version__ = "0.1.3" 8 | 9 | from airflow_ai_sdk.decorators.agent import agent 10 | from airflow_ai_sdk.decorators.branch import llm_branch 11 | from airflow_ai_sdk.decorators.embed import embed 12 | from airflow_ai_sdk.decorators.llm import llm 13 | from airflow_ai_sdk.models.base import BaseModel 14 | 15 | __all__ = ["agent", "llm", "llm_branch", "BaseModel"] 16 | 17 | 18 | def get_provider_info() -> dict[str, Any]: 19 | """Get provider information for Airflow. 20 | 21 | Returns: 22 | A dictionary containing package information and task decorators. 23 | """ 24 | return { 25 | "package-name": "airflow-ai-sdk", 26 | "name": "Airflow AI SDK", 27 | "description": "SDK for building LLM workflows and agents using Apache Airflow", 28 | "versions": [__version__], 29 | "task-decorators": [ 30 | { 31 | "name": "agent", 32 | "class-name": "airflow_ai_sdk.decorators.agent.agent", 33 | }, 34 | { 35 | "name": "llm", 36 | "class-name": "airflow_ai_sdk.decorators.llm.llm", 37 | }, 38 | { 39 | "name": "llm_branch", 40 | "class-name": "airflow_ai_sdk.decorators.branch.llm_branch", 41 | }, 42 | { 43 | "name": "embed", 44 | "class-name": "airflow_ai_sdk.decorators.embed.embed", 45 | }, 46 | ], 47 | } 48 | -------------------------------------------------------------------------------- /airflow_ai_sdk/airflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides compatibility layer for Airflow 2.x and 3.x by importing the necessary 3 | decorators, operators, and context utilities from the appropriate Airflow version. 
4 | """ 5 | 6 | try: 7 | # 3.x 8 | from airflow.providers.standard.decorators.python import _PythonDecoratedOperator 9 | from airflow.providers.standard.operators.branch import BranchMixIn 10 | from airflow.sdk.bases.decorator import TaskDecorator, task_decorator_factory 11 | from airflow.sdk.definitions.context import Context 12 | except ImportError: 13 | # 2.x 14 | from airflow.decorators.base import ( 15 | TaskDecorator, 16 | task_decorator_factory, 17 | ) 18 | from airflow.decorators.python import _PythonDecoratedOperator 19 | from airflow.operators.python import BranchMixIn 20 | from airflow.utils.context import Context 21 | 22 | __all__ = [ 23 | "Context", 24 | "task_decorator_factory", 25 | "TaskDecorator", 26 | "_PythonDecoratedOperator", 27 | "BranchMixIn", 28 | ] 29 | -------------------------------------------------------------------------------- /airflow_ai_sdk/decorators/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the decorators for the agent. 3 | """ 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from pydantic_ai.agent import Agent 8 | 9 | from airflow_ai_sdk.airflow import task_decorator_factory 10 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator 11 | 12 | if TYPE_CHECKING: 13 | from airflow_ai_sdk.airflow import TaskDecorator 14 | 15 | 16 | def agent(agent: Agent, **kwargs: dict[str, Any]) -> "TaskDecorator": 17 | """ 18 | Decorator to execute an `pydantic_ai.Agent` inside an Airflow task. 19 | 20 | Example: 21 | 22 | ```python 23 | from pydantic_ai import Agent 24 | 25 | my_agent = Agent(model="o3-mini", system_prompt="Say hello") 26 | 27 | @task.agent(my_agent) 28 | def greet(name: str) -> str: 29 | return name 30 | ``` 31 | """ 32 | kwargs["agent"] = agent 33 | return task_decorator_factory( 34 | decorated_operator_class=AgentDecoratedOperator, 35 | **kwargs, 36 | ) 37 | -------------------------------------------------------------------------------- /airflow_ai_sdk/decorators/branch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the decorators for the llm_branch decorator. 3 | """ 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from pydantic_ai import models 8 | 9 | from airflow_ai_sdk.airflow import task_decorator_factory 10 | from airflow_ai_sdk.operators.llm_branch import LLMBranchDecoratedOperator 11 | 12 | if TYPE_CHECKING: 13 | from airflow_ai_sdk.airflow import TaskDecorator 14 | 15 | 16 | def llm_branch( 17 | model: models.Model | models.KnownModelName, 18 | system_prompt: str, 19 | allow_multiple_branches: bool = False, 20 | **kwargs: dict[str, Any], 21 | ) -> "TaskDecorator": 22 | """ 23 | Decorator to branch a DAG based on the result of an LLM call. 
24 | 25 | Example: 26 | 27 | ```python 28 | @task 29 | def handle_positive_sentiment(text: str) -> str: 30 | return "Handle positive sentiment" 31 | 32 | @task 33 | def handle_negative_sentiment(text: str) -> str: 34 | return "Handle negative sentiment" 35 | 36 | @task.llm_branch(model="o3-mini", system_prompt="Classify this text by sentiment") 37 | def decide(text: str) -> str: 38 | return text 39 | 40 | # then, in the DAG: 41 | decide >> [handle_positive_sentiment, handle_negative_sentiment] 42 | ``` 43 | """ 44 | kwargs["model"] = model 45 | kwargs["system_prompt"] = system_prompt 46 | kwargs["allow_multiple_branches"] = allow_multiple_branches 47 | return task_decorator_factory( 48 | decorated_operator_class=LLMBranchDecoratedOperator, 49 | **kwargs, 50 | ) 51 | -------------------------------------------------------------------------------- /airflow_ai_sdk/decorators/embed.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the decorators for embedding. 3 | """ 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from airflow_ai_sdk.airflow import task_decorator_factory 8 | from airflow_ai_sdk.operators.embed import EmbedDecoratedOperator 9 | 10 | if TYPE_CHECKING: 11 | from airflow_ai_sdk.airflow import TaskDecorator 12 | 13 | 14 | def embed( 15 | model_name: str = "all-MiniLM-L12-v2", 16 | **kwargs: dict[str, Any], 17 | ) -> "TaskDecorator": 18 | """ 19 | Decorator to embed text using a SentenceTransformer model. 20 | 21 | Args: 22 | model_name: The name of the model to use for the embedding. Passed to 23 | the `SentenceTransformer` constructor. 24 | **kwargs: Keyword arguments to pass to the `EmbedDecoratedOperator` 25 | constructor. 26 | 27 | Example: 28 | 29 | ```python 30 | @task.embed() 31 | def vectorize() -> str: 32 | return "Example text" 33 | ``` 34 | """ 35 | kwargs["model_name"] = model_name 36 | return task_decorator_factory( 37 | decorated_operator_class=EmbedDecoratedOperator, 38 | **kwargs, 39 | ) 40 | -------------------------------------------------------------------------------- /airflow_ai_sdk/decorators/llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the decorators for the llm decorator. 3 | """ 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from pydantic_ai import models 8 | 9 | from airflow_ai_sdk.airflow import task_decorator_factory 10 | from airflow_ai_sdk.models.base import BaseModel 11 | from airflow_ai_sdk.operators.llm import LLMDecoratedOperator 12 | 13 | if TYPE_CHECKING: 14 | from airflow_ai_sdk.airflow import TaskDecorator 15 | 16 | 17 | def llm( 18 | model: models.Model | models.KnownModelName, 19 | system_prompt: str, 20 | result_type: type[BaseModel] | None = None, 21 | **kwargs: dict[str, Any], 22 | ) -> "TaskDecorator": 23 | """ 24 | Decorator to make a single call to an LLM. 
25 | 26 | Example: 27 | 28 | ```python 29 | @task.llm(model="o3-mini", system_prompt="Translate to French") 30 | def translate(text: str) -> str: 31 | return text 32 | ``` 33 | """ 34 | kwargs["model"] = model 35 | kwargs["result_type"] = result_type 36 | kwargs["system_prompt"] = system_prompt 37 | return task_decorator_factory( 38 | decorated_operator_class=LLMDecoratedOperator, 39 | **kwargs, 40 | ) 41 | -------------------------------------------------------------------------------- /airflow_ai_sdk/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a base class for all models in the SDK. The base class ensures 3 | proper serialization of task inputs and outputs as required by Airflow. 4 | """ 5 | 6 | from pydantic import BaseModel as PydanticBaseModel 7 | 8 | 9 | class BaseModel(PydanticBaseModel): 10 | """ 11 | Base class for all models in the SDK. 12 | 13 | This class extends Pydantic's BaseModel to provide a common foundation for all 14 | models used in the SDK. It ensures proper serialization of task inputs and outputs 15 | as required by Airflow. 16 | 17 | Example: 18 | 19 | ```python 20 | from airflow_ai_sdk.models.base import BaseModel 21 | 22 | class MyModel(BaseModel): 23 | name: str 24 | value: int 25 | ``` 26 | """ 27 | -------------------------------------------------------------------------------- /airflow_ai_sdk/models/tool.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a wrapper around pydantic_ai.Tool for better observability in Airflow. 3 | """ 4 | 5 | from pydantic_ai import Tool as PydanticTool 6 | from pydantic_ai.tools import AgentDepsT, _messages 7 | 8 | 9 | class WrappedTool(PydanticTool[AgentDepsT]): 10 | """ 11 | Wrapper around `pydantic_ai.Tool` for better observability in Airflow. 12 | 13 | This class extends the `pydantic_ai.Tool` class to provide enhanced logging 14 | capabilities in Airflow. It wraps tool calls and results in log groups for 15 | better visibility in the Airflow UI. 16 | 17 | Example: 18 | 19 | ```python 20 | from airflow_ai_sdk.models.tool import WrappedTool 21 | from pydantic_ai import Tool 22 | 23 | tool = Tool(my_function, name="my_tool") 24 | wrapped_tool = WrappedTool.from_pydantic_tool(tool) 25 | ``` 26 | """ 27 | 28 | async def run( 29 | self, 30 | message: _messages.ToolCallPart, 31 | *args: object, 32 | **kwargs: object, 33 | ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: 34 | """ 35 | Execute the tool with enhanced logging. 36 | 37 | Args: 38 | message: The tool call message containing the tool name and arguments. 39 | *args: Additional positional arguments for the tool. 40 | **kwargs: Additional keyword arguments for the tool. 41 | 42 | Returns: 43 | The tool's return value wrapped in a ToolReturnPart or RetryPromptPart. 44 | """ 45 | from pprint import pprint 46 | 47 | print(f"::group::Calling tool {message.tool_name} with args {message.args}") 48 | 49 | result = await super().run(message, *args, **kwargs) 50 | print("Result") 51 | pprint(result.content) 52 | 53 | print(f"::endgroup::") 54 | 55 | return result 56 | 57 | @classmethod 58 | def from_pydantic_tool( 59 | cls, tool: PydanticTool[AgentDepsT] 60 | ) -> "WrappedTool[AgentDepsT]": 61 | """ 62 | Create a WrappedTool instance from a pydantic_ai.Tool. 63 | 64 | Args: 65 | tool: The pydantic_ai.Tool instance to wrap. 66 | 67 | Returns: 68 | A new WrappedTool instance with the same configuration as the input tool. 
69 | """ 70 | return cls( 71 | tool.function, 72 | name=tool.name, 73 | description=tool.description, 74 | ) 75 | -------------------------------------------------------------------------------- /airflow_ai_sdk/operators/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the AgentDecoratedOperator class for executing pydantic_ai.Agent 3 | instances within Airflow tasks. 4 | """ 5 | 6 | from typing import Any 7 | 8 | from pydantic_ai import Agent 9 | 10 | from airflow_ai_sdk.airflow import Context, _PythonDecoratedOperator 11 | from airflow_ai_sdk.models.base import BaseModel 12 | from airflow_ai_sdk.models.tool import WrappedTool 13 | 14 | 15 | class AgentDecoratedOperator(_PythonDecoratedOperator): 16 | """ 17 | Operator that executes a `pydantic_ai.Agent`. 18 | 19 | This operator wraps a `pydantic_ai.Agent` instance and executes it within an Airflow task. 20 | It provides enhanced logging capabilities through `WrappedTool`. 21 | 22 | Example: 23 | 24 | ```python 25 | from pydantic_ai import Agent 26 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator 27 | 28 | def prompt() -> str: 29 | return "Hello" 30 | 31 | operator = AgentDecoratedOperator( 32 | task_id="example", 33 | python_callable=prompt, 34 | agent=Agent(model="o3-mini", system_prompt="Say hello"), 35 | ) 36 | ``` 37 | """ 38 | 39 | custom_operator_name = "@task.agent" 40 | 41 | def __init__( 42 | self, 43 | agent: Agent, 44 | op_args: list[Any], 45 | op_kwargs: dict[str, Any], 46 | *args: dict[str, Any], 47 | **kwargs: dict[str, Any], 48 | ): 49 | """ 50 | Initialize the AgentDecoratedOperator. 51 | 52 | Args: 53 | agent: The `pydantic_ai.Agent` instance to execute. 54 | op_args: Positional arguments to pass to the `python_callable`. 55 | op_kwargs: Keyword arguments to pass to the `python_callable`. 56 | *args: Additional positional arguments for the operator. 57 | **kwargs: Additional keyword arguments for the operator. 58 | """ 59 | super().__init__(*args, op_args=op_args, op_kwargs=op_kwargs, **kwargs) 60 | 61 | self.op_args = op_args 62 | self.op_kwargs = op_kwargs 63 | self.agent = agent 64 | 65 | # wrapping the tool will print the tool call and the result in an airflow log group for better observability 66 | self.agent._function_tools = { 67 | name: WrappedTool.from_pydantic_tool(tool) 68 | for name, tool in self.agent._function_tools.items() 69 | } 70 | 71 | def execute(self, context: Context) -> str | dict[str, Any] | list[str]: 72 | """ 73 | Execute the agent with the given context. 74 | 75 | Args: 76 | context: The Airflow context for this task execution. 77 | 78 | Returns: 79 | The result of the agent's execution, which can be a string, dictionary, 80 | or list of strings. 
81 | """ 82 | print("Executing LLM call") 83 | 84 | prompt = super().execute(context) 85 | print(f"Prompt: {prompt}") 86 | 87 | try: 88 | result = self.agent.run_sync(prompt) 89 | print(f"Result: {result}") 90 | except Exception as e: 91 | print(f"Error: {e}") 92 | raise e 93 | 94 | # turn the result into a dict 95 | if isinstance(result.data, BaseModel): 96 | return result.data.model_dump() 97 | 98 | return result.data 99 | -------------------------------------------------------------------------------- /airflow_ai_sdk/operators/embed.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the EmbedDecoratedOperator class for generating text embeddings 3 | using SentenceTransformer models within Airflow tasks. 4 | """ 5 | 6 | from typing import Any 7 | 8 | from airflow_ai_sdk.airflow import Context, _PythonDecoratedOperator 9 | 10 | 11 | class EmbedDecoratedOperator(_PythonDecoratedOperator): 12 | """ 13 | Operator that builds embeddings for text using SentenceTransformer models. 14 | 15 | This operator generates embeddings for text input using a specified SentenceTransformer 16 | model. It provides a convenient way to create embeddings within Airflow tasks. 17 | 18 | Example: 19 | 20 | ```python 21 | from airflow_ai_sdk.operators.embed import EmbedDecoratedOperator 22 | 23 | def produce_text() -> str: 24 | return "document" 25 | 26 | operator = EmbedDecoratedOperator( 27 | task_id="embed", 28 | python_callable=produce_text, 29 | model_name="all-MiniLM-L12-v2", 30 | ) 31 | ``` 32 | """ 33 | 34 | custom_operator_name = "@task.embed" 35 | 36 | def __init__( 37 | self, 38 | op_args: list[Any], 39 | op_kwargs: dict[str, Any], 40 | model_name: str, 41 | encode_kwargs: dict[str, Any] = None, 42 | *args: dict[str, Any], 43 | **kwargs: dict[str, Any], 44 | ): 45 | """ 46 | Initialize the EmbedDecoratedOperator. 47 | 48 | Args: 49 | op_args: Positional arguments to pass to the python_callable. 50 | op_kwargs: Keyword arguments to pass to the python_callable. 51 | model_name: The name of the model to use for the embedding. Passed to the `SentenceTransformer` constructor. 52 | encode_kwargs: Keyword arguments to pass to the `encode` method of the SentenceTransformer model. 53 | *args: Additional positional arguments for the operator. 54 | **kwargs: Additional keyword arguments for the operator. 55 | """ 56 | if encode_kwargs is None: 57 | encode_kwargs = {} 58 | 59 | super().__init__(*args, op_args=op_args, op_kwargs=op_kwargs, **kwargs) 60 | 61 | self.model_name = model_name 62 | self.encode_kwargs = encode_kwargs 63 | 64 | try: 65 | import sentence_transformers # noqa: F401 66 | except ImportError as e: 67 | raise ImportError( 68 | "sentence-transformers is not installed but is required for the embedding operator. Please install it before using the embedding operator." 69 | ) from e 70 | 71 | def execute(self, context: Context) -> list[float]: 72 | """ 73 | Execute the embedding operation with the given context. 74 | 75 | Args: 76 | context: The Airflow context for this task execution. 77 | 78 | Returns: 79 | A list of floats representing the embedding vector for the input text. 
80 | """ 81 | from sentence_transformers import SentenceTransformer 82 | 83 | text = super().execute(context) 84 | if not isinstance(text, str): 85 | raise TypeError("The input text must be a string.") 86 | 87 | model = SentenceTransformer(self.model_name) 88 | return model.encode(text, **self.encode_kwargs).tolist() 89 | -------------------------------------------------------------------------------- /airflow_ai_sdk/operators/llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the LLMDecoratedOperator class for making single LLM calls 3 | within Airflow tasks. 4 | """ 5 | 6 | from typing import Any 7 | 8 | from pydantic import BaseModel 9 | from pydantic_ai import Agent, models 10 | 11 | from airflow_ai_sdk.airflow import Context 12 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator 13 | 14 | 15 | class LLMDecoratedOperator(AgentDecoratedOperator): 16 | """ 17 | Simpler interface for performing a single LLM call. 18 | 19 | This operator provides a simplified interface for making single LLM calls within 20 | Airflow tasks, without the full agent functionality. 21 | 22 | Example: 23 | 24 | ```python 25 | from airflow_ai_sdk.operators.llm import LLMDecoratedOperator 26 | 27 | def make_prompt() -> str: 28 | return "Hello" 29 | 30 | operator = LLMDecoratedOperator( 31 | task_id="llm", 32 | python_callable=make_prompt, 33 | model="o3-mini", 34 | system_prompt="Reply politely", 35 | ) 36 | ``` 37 | """ 38 | 39 | custom_operator_name = "@task.llm" 40 | 41 | def __init__( 42 | self, 43 | model: models.Model | models.KnownModelName, 44 | system_prompt: str, 45 | result_type: type[BaseModel] = str, 46 | **kwargs: dict[str, Any], 47 | ): 48 | """ 49 | Initialize the LLMDecoratedOperator. 50 | 51 | Args: 52 | model: The LLM model to use for the call. 53 | system_prompt: The system prompt to use for the call. 54 | result_type: Optional Pydantic model type to validate and parse the result. 55 | **kwargs: Additional keyword arguments for the operator. 56 | """ 57 | agent = Agent( 58 | model=model, 59 | system_prompt=system_prompt, 60 | result_type=result_type, 61 | ) 62 | super().__init__(agent=agent, **kwargs) 63 | -------------------------------------------------------------------------------- /airflow_ai_sdk/operators/llm_branch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the LLMBranchDecoratedOperator class for branching DAGs based on 3 | LLM decisions within Airflow tasks. 4 | """ 5 | 6 | from enum import Enum 7 | from typing import Any 8 | 9 | from pydantic_ai import Agent, models 10 | 11 | from airflow_ai_sdk.airflow import BranchMixIn, Context 12 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator 13 | 14 | 15 | class LLMBranchDecoratedOperator(AgentDecoratedOperator, BranchMixIn): 16 | """ 17 | Branch a DAG based on the result of an LLM call. 18 | 19 | This operator uses an LLM to decide which downstream task to execute next. 20 | It combines the capabilities of an LLM with Airflow's branching functionality. 
21 | 22 | Example: 23 | 24 | ```python 25 | from airflow_ai_sdk.operators.llm_branch import LLMBranchDecoratedOperator 26 | 27 | def make_prompt() -> str: 28 | return "Choose" 29 | 30 | operator = LLMBranchDecoratedOperator( 31 | task_id="branch", 32 | python_callable=make_prompt, 33 | model="o3-mini", 34 | system_prompt="Return 'a' or 'b'", 35 | ) 36 | ``` 37 | """ 38 | 39 | custom_operator_name = "@task.llm_branch" 40 | 41 | def __init__( 42 | self, 43 | model: models.Model | models.KnownModelName, 44 | system_prompt: str, 45 | allow_multiple_branches: bool = False, 46 | **kwargs: dict[str, Any], 47 | ): 48 | """ 49 | Initialize the LLMBranchDecoratedOperator. 50 | 51 | Args: 52 | model: The LLM model to use for the decision. 53 | system_prompt: The system prompt to use for the decision. 54 | allow_multiple_branches: Whether to allow multiple downstream tasks to be executed. 55 | **kwargs: Additional keyword arguments for the operator. 56 | """ 57 | self.model = model 58 | self.system_prompt = system_prompt 59 | self.allow_multiple_branches = allow_multiple_branches 60 | 61 | agent = Agent( 62 | model=model, 63 | system_prompt=system_prompt, 64 | ) 65 | 66 | super().__init__(agent=agent, **kwargs) 67 | 68 | def execute(self, context: Context) -> str | list[str]: 69 | """ 70 | Execute the branching decision with the given context. 71 | 72 | Args: 73 | context: The Airflow context for this task execution. 74 | 75 | Returns: 76 | The task_id(s) of the downstream task(s) to execute next. 77 | """ 78 | # create an enum of the downstream tasks and add it to the agent 79 | downstream_tasks_enum = Enum( 80 | "DownstreamTasks", 81 | {task_id: task_id for task_id in self.downstream_task_ids}, 82 | ) 83 | 84 | self.agent = Agent( 85 | model=self.model, 86 | system_prompt=self.system_prompt, 87 | result_type=downstream_tasks_enum, 88 | ) 89 | 90 | result = super().execute(context) 91 | 92 | # turn the result into a string 93 | if isinstance(result, Enum): 94 | result = result.value 95 | 96 | # if the response is not a string, cast it to a string 97 | if not isinstance(result, str): 98 | result = str(result) 99 | 100 | if isinstance(result, list) and not self.allow_multiple_branches: 101 | raise ValueError( 102 | "Multiple branches were returned but allow_multiple_branches is False" 103 | ) 104 | 105 | return self.do_branch(context, result) 106 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Airflow AI SDK Documentation 2 | 3 | Welcome to the Airflow AI SDK documentation. This SDK allows you to work with LLMs from Apache Airflow, based on [Pydantic AI](https://ai.pydantic.dev). 
4 | 5 | ## Table of Contents 6 | 7 | ### Getting Started 8 | - [Introduction](index.md) - Overview and quick start guide 9 | - [Features](features.md) - Overview of available features 10 | - [Usage](usage.md) - Basic usage guide 11 | 12 | ### Examples 13 | - [Example DAGs](examples/index.md) - Detailed examples of using the SDK 14 | 15 | ### Reference 16 | - [API Reference](api-reference/) - API documentation 17 | 18 | ### Development 19 | - [GitHub Repository](https://github.com/astronomer/airflow-ai-sdk) - Source code and issues 20 | - [Examples Repository](https://github.com/astronomer/ai-sdk-examples) - Full examples with Airflow 21 | -------------------------------------------------------------------------------- /docs/examples/index.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | The airflow-ai-sdk comes with several example DAGs that demonstrate how to use the various decorators and features. All examples can be found in the [`examples/dags`](https://github.com/astronomer/airflow-ai-sdk/tree/main/examples/dags) directory of the repository. 4 | 5 | ## Available Examples 6 | 7 | ### 1. LLM Calls - GitHub Changelog 8 | 9 | [View example code](https://github.com/astronomer/airflow-ai-sdk/blob/main/examples/dags/github_changelog.py) 10 | 11 | Demonstrates using `@task.llm` to summarize GitHub commits with a language model. 12 | 13 | **Features demonstrated:** 14 | - Using `@task.llm` decorator with a specific model (gpt-4o-mini) 15 | - Specifying a detailed system prompt for commit summarization 16 | - Transforming task input to LLM input (joining commits into a string) 17 | - Weekly scheduling for regular changelog generation 18 | 19 | ### 2. Structured Output - Product Feedback 20 | 21 | [View example code](https://github.com/astronomer/airflow-ai-sdk/blob/main/examples/dags/product_feedback_summarization.py) 22 | 23 | Shows how to use `@task.llm` with Pydantic models to parse structured data from product feedback. 24 | 25 | **Features demonstrated:** 26 | - Creating a custom Pydantic model for structured LLM output 27 | - PII masking in preprocessing 28 | - Using `result_type` parameter with a Pydantic model 29 | - Task mapping with `expand()` to process multiple feedback items 30 | - Conditional execution with `AirflowSkipException` 31 | 32 | ### 3. Agent Tasks - Deep Research 33 | 34 | [View example code](https://github.com/astronomer/airflow-ai-sdk/blob/main/examples/dags/deep_research.py) 35 | 36 | Illustrates how to use `@task.agent` with custom tools to perform deep research on topics. 37 | 38 | **Features demonstrated:** 39 | - Creating a custom agent with tools 40 | - Using the `@task.agent` decorator 41 | - Creating custom tools (web content fetching) 42 | - Integrating with external APIs (DuckDuckGo search) 43 | - Using runtime parameters for dynamic DAG execution 44 | 45 | ### 4. Branching Tasks - Support Ticket Routing 46 | 47 | [View example code](https://github.com/astronomer/airflow-ai-sdk/blob/main/examples/dags/support_ticket_routing.py) 48 | 49 | Demonstrates how to use `@task.llm_branch` to route support tickets based on priority. 50 | 51 | **Features demonstrated:** 52 | - Using `@task.llm_branch` for conditional workflow routing 53 | - Setting up branch tasks based on LLM decisions 54 | - Configuring `allow_multiple_branches` parameter 55 | - Processing DAG run configuration parameters 56 | - Detailed prompt engineering for classification tasks 57 | 58 | ### 5. 
Embedding Tasks - Text Embedding 59 | 60 | [View example code](https://github.com/astronomer/airflow-ai-sdk/blob/main/examples/dags/text_embedding.py) 61 | 62 | Shows how to use `@task.embed` to create vector embeddings from text. 63 | 64 | **Features demonstrated:** 65 | - Using `@task.embed` decorator 66 | - Specifying embedding model parameters 67 | - Generating vector embeddings from text 68 | - Configuring encoding parameters (normalization) 69 | - Task mapping to process multiple texts in parallel 70 | 71 | ## Running the Examples 72 | 73 | The examples can be run in a local Airflow environment with the following steps: 74 | 75 | 1. Clone the [examples repository](https://github.com/astronomer/ai-sdk-examples): 76 | ```bash 77 | git clone https://github.com/astronomer/ai-sdk-examples.git 78 | ``` 79 | 80 | 2. Navigate to the examples directory: 81 | ```bash 82 | cd ai-sdk-examples 83 | ``` 84 | 85 | 3. Start the Airflow environment: 86 | ```bash 87 | astro dev start 88 | ``` 89 | 90 | For more information on the individual examples, refer to the code comments in each example file. 91 | -------------------------------------------------------------------------------- /docs/features.md: -------------------------------------------------------------------------------- 1 | # Features 2 | 3 | ## Core Features 4 | 5 | - **LLM tasks with `@task.llm`:** Define tasks that call language models (e.g. GPT-3.5-turbo) to process text. 6 | - **Agent tasks with `@task.agent`:** Orchestrate multi-step AI reasoning by leveraging custom tools. 7 | - **Automatic output parsing:** Use function type hints (including Pydantic models) to automatically parse and validate LLM outputs. 8 | - **Branching with `@task.llm_branch`:** Change the control flow of a DAG based on the output of an LLM. 9 | - **Model support:** Support for [all models in the Pydantic AI library](https://ai.pydantic.dev/models/) (OpenAI, Anthropic, Gemini, Ollama, Groq, Mistral, Cohere, Bedrock) 10 | - **Embedding tasks with `@task.embed`:** Create vector embeddings from text using sentence-transformers models. 11 | 12 | ## Why Use Airflow for AI Workflows? 13 | 14 | Airflow provides several advantages for orchestrating AI workflows: 15 | 16 | - **Flexible scheduling:** run tasks on a fixed schedule, on-demand, or based on external events 17 | - **Dynamic task mapping:** easily process multiple inputs in parallel with full error handling and observability 18 | - **Branching and conditional logic:** change the control flow of a DAG based on the output of certain tasks 19 | - **Error handling:** built-in support for retries, exponential backoff, and timeouts 20 | - **Resource management:** limit the concurrency of tasks with Airflow Pools 21 | - **Monitoring:** detailed logs and monitoring capabilities 22 | - **Scalability:** designed for production workflows 23 | 24 | ## Task Decorators 25 | 26 | ### @task.llm 27 | 28 | The `@task.llm` decorator enables calling language models from your Airflow tasks. 
It supports: 29 | 30 | - Configurable models from various providers 31 | - System and user prompts 32 | - Structured output parsing with Pydantic models 33 | - Type validation 34 | 35 | ### @task.agent 36 | 37 | The `@task.agent` decorator adds agent capabilities to your Airflow tasks, enabling: 38 | 39 | - Multi-step reasoning 40 | - Tool usage for external operations 41 | - Memory and context management 42 | - Complex problem-solving workflows 43 | 44 | ### @task.llm_branch 45 | 46 | The `@task.llm_branch` decorator adds LLM-based decision making to your DAG control flow: 47 | 48 | - Routes execution based on LLM output 49 | - Ensures output matches a downstream task ID 50 | - Supports both single and multiple branch selection 51 | 52 | ### @task.embed 53 | 54 | The `@task.embed` decorator creates vector embeddings from text: 55 | 56 | - Uses sentence-transformers models 57 | - Creates embeddings usable for semantic search, clustering, etc. 58 | - Configurable model selection 59 | - Optional normalization and other encoding parameters 60 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # airflow-ai-sdk 2 | 3 | This SDK allows you to work with LLMs from Apache Airflow, based on [Pydantic AI](https://ai.pydantic.dev). It enables calling LLMs and orchestrating agent calls directly within Airflow pipelines using decorator-based tasks. 4 | 5 | ## Quick Start 6 | 7 | To install the package with optional dependencies: 8 | 9 | ```bash 10 | pip install airflow-ai-sdk[openai,duckduckgo] 11 | ``` 12 | 13 | Note that installing the package with no optional dependencies will install the slim version, which does not include any LLM models or tools. The available optional packages are listed in the [pyproject.toml](https://github.com/astronomer/airflow-ai-sdk/blob/main/pyproject.toml#L17). 14 | 15 | ## Examples Repository 16 | 17 | Check out the [examples repository](https://github.com/astronomer/ai-sdk-examples), which offers a full local Airflow instance with the AI SDK installed and 5 example pipelines: 18 | 19 | ```bash 20 | git clone https://github.com/astronomer/ai-sdk-examples.git 21 | cd ai-sdk-examples 22 | astro dev start 23 | ``` 24 | 25 | If you don't have the Astro CLI installed, run `brew install astro` (or see other options [here](https://www.astronomer.io/docs/astro/cli/install-cli)). 26 | 27 | ## Design Principles 28 | 29 | We follow the taskflow pattern of Airflow with four decorators: 30 | 31 | - `@task.llm`: Define a task that calls an LLM. Under the hood, this creates a Pydantic AI `Agent` with no tools. 32 | - `@task.agent`: Define a task that calls an agent. You can pass in a Pydantic AI `Agent` directly. 33 | - `@task.llm_branch`: Define a task that branches the control flow of a DAG based on the output of an LLM. Enforces that the LLM output is one of the downstream task_ids. 34 | - `@task.embed`: Define a task that embeds text using a sentence-transformers model. 35 | 36 | The function supplied to each decorator is a translation function that converts the Airflow task's input into the LLM's input. If you don't want to do any translation, you can just return the input unchanged. 
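For example, here is a minimal sketch of that pattern (the names are illustrative, and it assumes the `openai` extra is installed and an OpenAI API key is available to your Airflow workers). The function body only reshapes the task's input; the decorator makes the model call and parses the response into the `result_type`:

```python
import pendulum
from airflow.decorators import dag, task

import airflow_ai_sdk as ai_sdk


class CommitSummary(ai_sdk.BaseModel):
    headline: str
    highlights: list[str]


@task.llm(
    model="gpt-4o-mini",
    result_type=CommitSummary,
    system_prompt="Summarize the provided commit messages.",
)
def summarize_commits(commits: list[str]) -> CommitSummary:
    # translation step: turn the task's input into the prompt text sent to the LLM
    return "\n".join(commits)


@dag(schedule=None, start_date=pendulum.datetime(2025, 1, 1), catchup=False)
def changelog_demo():
    summarize_commits(["fix: handle retries", "feat: add @task.embed decorator"])


changelog_demo()
```

The same shape applies to the other decorators: `@task.agent` and `@task.llm_branch` receive the returned string as the prompt, and `@task.embed` embeds the text the function returns.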
37 |
38 | ## Documentation
39 |
40 | - [Features](features.md) - Overview of available features
41 | - [Usage](usage.md) - Basic usage guide
42 | - [Examples](examples/) - Detailed examples
43 | - [API Reference](api-reference/) - API documentation
44 |
45 | ## Motivation
46 |
47 | AI workflows are becoming increasingly common as organizations look for pragmatic ways to get value out of LLMs. Airflow is a powerful tool for managing the dependencies between tasks and for scheduling and monitoring them, and has been trusted by data teams for 10+ years.
48 |
49 | This SDK is designed to make it easy to integrate LLM workflows into your Airflow pipelines, from simple LLM calls to complex agentic workflows.
50 |
-------------------------------------------------------------------------------- /docs/interface/README.md: --------------------------------------------------------------------------------
1 | # Public Interface Documentation
2 |
3 | This directory contains auto-generated documentation for the public interface of the `airflow_ai_sdk` package. The files are organized into folders that match the SDK's package layout. To regenerate these files, run:
4 |
5 | ```bash
6 | uv run python scripts/generate_interface_docs.py
7 | ```
8 |
-------------------------------------------------------------------------------- /docs/interface/airflow.md: --------------------------------------------------------------------------------
1 | # airflow_ai_sdk.airflow
2 |
3 | This module provides a compatibility layer for Airflow 2.x and 3.x by importing the necessary
4 | decorators, operators, and context utilities from the appropriate Airflow version.
5 |
-------------------------------------------------------------------------------- /docs/interface/decorators/agent.md: --------------------------------------------------------------------------------
1 | # airflow_ai_sdk.decorators.agent
2 |
3 | This module contains the `agent` decorator.
4 |
5 | ## agent
6 |
7 | Decorator to execute a `pydantic_ai.Agent` inside an Airflow task.
8 |
9 | Example:
10 |
11 | ```python
12 | from pydantic_ai import Agent
13 |
14 | my_agent = Agent(model="o3-mini", system_prompt="Say hello")
15 |
16 | @task.agent(my_agent)
17 | def greet(name: str) -> str:
18 |     return name
19 | ```
20 |
-------------------------------------------------------------------------------- /docs/interface/decorators/branch.md: --------------------------------------------------------------------------------
1 | # airflow_ai_sdk.decorators.branch
2 |
3 | This module contains the `llm_branch` decorator.
4 |
5 | ## llm_branch
6 |
7 | Decorator to branch a DAG based on the result of an LLM call.
8 |
9 | Example:
10 |
11 | ```python
12 | @task
13 | def handle_positive_sentiment(text: str) -> str:
14 |     return "Handle positive sentiment"
15 |
16 | @task
17 | def handle_negative_sentiment(text: str) -> str:
18 |     return "Handle negative sentiment"
19 |
20 | @task.llm_branch(model="o3-mini", system_prompt="Classify this text by sentiment")
21 | def decide(text: str) -> str:
22 |     return text
23 |
24 | # then, in the DAG:
25 | decide >> [handle_positive_sentiment, handle_negative_sentiment]
26 | ```
27 |
-------------------------------------------------------------------------------- /docs/interface/decorators/embed.md: --------------------------------------------------------------------------------
1 | # airflow_ai_sdk.decorators.embed
2 |
3 | This module contains the decorators for embedding.
4 | 5 | ## embed 6 | 7 | Decorator to embed text using a SentenceTransformer model. 8 | 9 | Args: 10 | model_name: The name of the model to use for the embedding. Passed to 11 | the `SentenceTransformer` constructor. 12 | **kwargs: Keyword arguments to pass to the `EmbedDecoratedOperator` 13 | constructor. 14 | 15 | Example: 16 | 17 | ```python 18 | @task.embed() 19 | def vectorize() -> str: 20 | return "Example text" 21 | ``` 22 | -------------------------------------------------------------------------------- /docs/interface/decorators/llm.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.decorators.llm 2 | 3 | This module contains the decorators for the llm decorator. 4 | 5 | ## llm 6 | 7 | Decorator to make a single call to an LLM. 8 | 9 | Example: 10 | 11 | ```python 12 | @task.llm(model="o3-mini", system_prompt="Translate to French") 13 | def translate(text: str) -> str: 14 | return text 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/interface/models/base.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.models.base 2 | 3 | This module provides a base class for all models in the SDK. The base class ensures 4 | proper serialization of task inputs and outputs as required by Airflow. 5 | 6 | ## BaseModel 7 | 8 | Base class for all models in the SDK. 9 | 10 | This class extends Pydantic's BaseModel to provide a common foundation for all 11 | models used in the SDK. It ensures proper serialization of task inputs and outputs 12 | as required by Airflow. 13 | 14 | Example: 15 | 16 | ```python 17 | from airflow_ai_sdk.models.base import BaseModel 18 | 19 | class MyModel(BaseModel): 20 | name: str 21 | value: int 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/interface/models/tool.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.models.tool 2 | 3 | This module provides a wrapper around pydantic_ai.Tool for better observability in Airflow. 4 | 5 | ## WrappedTool 6 | 7 | Wrapper around `pydantic_ai.Tool` for better observability in Airflow. 8 | 9 | This class extends the `pydantic_ai.Tool` class to provide enhanced logging 10 | capabilities in Airflow. It wraps tool calls and results in log groups for 11 | better visibility in the Airflow UI. 12 | 13 | Example: 14 | 15 | ```python 16 | from airflow_ai_sdk.models.tool import WrappedTool 17 | from pydantic_ai import Tool 18 | 19 | tool = Tool(my_function, name="my_tool") 20 | wrapped_tool = WrappedTool.from_pydantic_tool(tool) 21 | ``` 22 | -------------------------------------------------------------------------------- /docs/interface/operators/agent.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.operators.agent 2 | 3 | This module provides the AgentDecoratedOperator class for executing pydantic_ai.Agent 4 | instances within Airflow tasks. 5 | 6 | ## AgentDecoratedOperator 7 | 8 | Operator that executes a `pydantic_ai.Agent`. 9 | 10 | This operator wraps a `pydantic_ai.Agent` instance and executes it within an Airflow task. 11 | It provides enhanced logging capabilities through `WrappedTool`. 
12 | 13 | Example: 14 | 15 | ```python 16 | from pydantic_ai import Agent 17 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator 18 | 19 | def prompt() -> str: 20 | return "Hello" 21 | 22 | operator = AgentDecoratedOperator( 23 | task_id="example", 24 | python_callable=prompt, 25 | agent=Agent(model="o3-mini", system_prompt="Say hello"), 26 | ) 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/interface/operators/embed.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.operators.embed 2 | 3 | This module provides the EmbedDecoratedOperator class for generating text embeddings 4 | using SentenceTransformer models within Airflow tasks. 5 | 6 | ## EmbedDecoratedOperator 7 | 8 | Operator that builds embeddings for text using SentenceTransformer models. 9 | 10 | This operator generates embeddings for text input using a specified SentenceTransformer 11 | model. It provides a convenient way to create embeddings within Airflow tasks. 12 | 13 | Example: 14 | 15 | ```python 16 | from airflow_ai_sdk.operators.embed import EmbedDecoratedOperator 17 | 18 | def produce_text() -> str: 19 | return "document" 20 | 21 | operator = EmbedDecoratedOperator( 22 | task_id="embed", 23 | python_callable=produce_text, 24 | model_name="all-MiniLM-L12-v2", 25 | ) 26 | ``` 27 | -------------------------------------------------------------------------------- /docs/interface/operators/llm.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.operators.llm 2 | 3 | This module provides the LLMDecoratedOperator class for making single LLM calls 4 | within Airflow tasks. 5 | 6 | ## LLMDecoratedOperator 7 | 8 | Simpler interface for performing a single LLM call. 9 | 10 | This operator provides a simplified interface for making single LLM calls within 11 | Airflow tasks, without the full agent functionality. 12 | 13 | Example: 14 | 15 | ```python 16 | from airflow_ai_sdk.operators.llm import LLMDecoratedOperator 17 | 18 | def make_prompt() -> str: 19 | return "Hello" 20 | 21 | operator = LLMDecoratedOperator( 22 | task_id="llm", 23 | python_callable=make_prompt, 24 | model="o3-mini", 25 | system_prompt="Reply politely", 26 | ) 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/interface/operators/llm_branch.md: -------------------------------------------------------------------------------- 1 | # airflow_ai_sdk.operators.llm_branch 2 | 3 | This module provides the LLMBranchDecoratedOperator class for branching DAGs based on 4 | LLM decisions within Airflow tasks. 5 | 6 | ## LLMBranchDecoratedOperator 7 | 8 | Branch a DAG based on the result of an LLM call. 9 | 10 | This operator uses an LLM to decide which downstream task to execute next. 11 | It combines the capabilities of an LLM with Airflow's branching functionality. 
12 | 13 | Example: 14 | 15 | ```python 16 | from airflow_ai_sdk.operators.llm_branch import LLMBranchDecoratedOperator 17 | 18 | def make_prompt() -> str: 19 | return "Choose" 20 | 21 | operator = LLMBranchDecoratedOperator( 22 | task_id="branch", 23 | python_callable=make_prompt, 24 | model="o3-mini", 25 | system_prompt="Return 'a' or 'b'", 26 | ) 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Basic Usage 2 | 3 | ## Installation 4 | 5 | Install the SDK with any optional dependencies you need: 6 | 7 | ```bash 8 | pip install airflow-ai-sdk[openai,duckduckgo] 9 | ``` 10 | 11 | Available optional dependencies are defined in the [pyproject.toml](https://github.com/astronomer/airflow-ai-sdk/blob/main/pyproject.toml#L17) file. You can also install optional dependencies from [Pydantic AI](https://ai.pydantic.dev/install/) directly. 12 | 13 | ## Task Decorators 14 | 15 | ### LLM Tasks with @task.llm 16 | 17 | ```python 18 | from airflow.decorators import dag, task 19 | import pendulum 20 | import airflow_ai_sdk as ai_sdk 21 | 22 | @task.llm( 23 | model="gpt-4o-mini", # model name 24 | result_type=str, # return type 25 | system_prompt="You are a helpful assistant." # system prompt for the LLM 26 | ) 27 | def process_with_llm(input_text: str) -> str: 28 | # This function transforms Airflow task input into LLM input 29 | return input_text 30 | 31 | @dag( 32 | schedule=None, 33 | start_date=pendulum.datetime(2025, 1, 1), 34 | catchup=False, 35 | ) 36 | def simple_llm_dag(): 37 | result = process_with_llm("Summarize the benefits of using Airflow with LLMs.") 38 | 39 | simple_llm_dag() 40 | ``` 41 | 42 | ### Structured Output with Pydantic Models 43 | 44 | ```python 45 | from typing import Literal 46 | import airflow_ai_sdk as ai_sdk 47 | 48 | class TextAnalysis(ai_sdk.BaseModel): 49 | summary: str 50 | sentiment: Literal["positive", "negative", "neutral"] 51 | key_points: list[str] 52 | 53 | @task.llm( 54 | model="gpt-4o-mini", 55 | result_type=TextAnalysis, 56 | system_prompt="Analyze the provided text." 57 | ) 58 | def analyze_text(text: str) -> TextAnalysis: 59 | return text 60 | ``` 61 | 62 | ### Agent Tasks with @task.agent 63 | 64 | ```python 65 | from pydantic_ai import Agent 66 | from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool 67 | 68 | research_agent = Agent( 69 | "o3-mini", # model name 70 | system_prompt="You are a research agent that finds information from the web.", 71 | tools=[duckduckgo_search_tool()] # tools the agent can use 72 | ) 73 | 74 | @task.agent(agent=research_agent) 75 | def research_topic(topic: str) -> str: 76 | return topic 77 | ``` 78 | 79 | ### Branching Tasks with @task.llm_branch 80 | 81 | ```python 82 | @task.llm_branch( 83 | model="gpt-4o-mini", 84 | system_prompt="Classify the text based on its priority.", 85 | allow_multiple_branches=False # only select one branch 86 | ) 87 | def classify_priority(text: str) -> str: 88 | return text 89 | 90 | @task 91 | def handle_high_priority(text: str): 92 | print(f"Handling high priority: {text}") 93 | 94 | @task 95 | def handle_medium_priority(text: str): 96 | print(f"Handling medium priority: {text}") 97 | 98 | @task 99 | def handle_low_priority(text: str): 100 | print(f"Handling low priority: {text}") 101 | 102 | @dag(...) 
103 | def priority_routing_dag():
104 |     result = classify_priority("This is an urgent request")
105 |
106 |     high_task = handle_high_priority(result)
107 |     medium_task = handle_medium_priority(result)
108 |     low_task = handle_low_priority(result)
109 |
110 |     result >> [high_task, medium_task, low_task]
111 | ```
112 |
113 | ### Embedding Tasks with @task.embed
114 |
115 | ```python
116 | @task.embed(
117 |     model_name="all-MiniLM-L12-v2",
118 |     encode_kwargs={"normalize_embeddings": True}
119 | )
120 | def create_embeddings(text: str) -> list[float]:
121 |     return text
122 |
123 | @dag(...)
124 | def embedding_dag():
125 |     texts = ["First text", "Second text", "Third text"]
126 |     embeddings = create_embeddings.expand(text=texts)
127 |     # Now use embeddings for semantic search, clustering, etc.
128 | ```
129 |
130 | ## Error Handling
131 |
132 | You can use Airflow's built-in error handling features with these tasks:
133 |
134 | ```python
135 | @task.llm(
136 |     model="gpt-4o-mini",
137 |     result_type=str,
138 |     system_prompt="Answer the question.",
139 |     retries=3, # retry 3 times if the task fails
140 |     retry_delay=pendulum.duration(seconds=30), # wait 30 seconds between retries
141 | )
142 | def answer_question(question: str) -> str:
143 |     return question
144 | ```
145 |
-------------------------------------------------------------------------------- /examples/.astro/config.yaml: --------------------------------------------------------------------------------
1 | project:
2 |   name: examples
3 |
-------------------------------------------------------------------------------- /examples/.astro/dag_integrity_exceptions.txt: --------------------------------------------------------------------------------
1 | # Add dag files to exempt from parse test below. ex: dags/
2 |
-------------------------------------------------------------------------------- /examples/.astro/test_dag_integrity_default.py: --------------------------------------------------------------------------------
1 | """Test the validity of all DAGs.
**USED BY DEV PARSE COMMAND DO NOT EDIT**""" 2 | 3 | import logging 4 | import os 5 | from contextlib import contextmanager 6 | 7 | import pytest 8 | from airflow.hooks.base import BaseHook 9 | from airflow.models import Connection, DagBag, Variable 10 | from airflow.utils.db import initdb 11 | 12 | # init airflow database 13 | initdb() 14 | 15 | # The following code patches errors caused by missing OS Variables, Airflow Connections, and Airflow Variables 16 | 17 | 18 | # =========== MONKEYPATCH BaseHook.get_connection() =========== 19 | def basehook_get_connection_monkeypatch(key: str, *args, **kwargs): 20 | print( 21 | f"Attempted to fetch connection during parse returning an empty Connection object for {key}" 22 | ) 23 | return Connection(key) 24 | 25 | 26 | BaseHook.get_connection = basehook_get_connection_monkeypatch 27 | # # =========== /MONKEYPATCH BASEHOOK.GET_CONNECTION() =========== 28 | 29 | 30 | # =========== MONKEYPATCH OS.GETENV() =========== 31 | def os_getenv_monkeypatch(key: str, *args, **kwargs): 32 | default = None 33 | if args: 34 | default = args[0] # os.getenv should get at most 1 arg after the key 35 | if kwargs: 36 | default = kwargs.get("default") # and sometimes kwarg if people are using the sig 37 | 38 | env_value = os.environ.get(key, None) 39 | 40 | if env_value: 41 | return env_value # if the env_value is set, return it 42 | if ( 43 | key == "JENKINS_HOME" and default is None 44 | ): # fix https://github.com/astronomer/astro-cli/issues/601 45 | return None 46 | if default: 47 | return default # otherwise return whatever default has been passed 48 | return f"MOCKED_{key.upper()}_VALUE" # if absolutely nothing has been passed - return the mocked value 49 | 50 | 51 | os.getenv = os_getenv_monkeypatch 52 | # # =========== /MONKEYPATCH OS.GETENV() =========== 53 | 54 | # =========== MONKEYPATCH VARIABLE.GET() =========== 55 | 56 | 57 | class magic_dict(dict): 58 | def __init__(self, *args, **kwargs): 59 | self.update(*args, **kwargs) 60 | 61 | def __getitem__(self, key): 62 | return {}.get(key, "MOCKED_KEY_VALUE") 63 | 64 | 65 | _no_default = object() # allow falsey defaults 66 | 67 | 68 | def variable_get_monkeypatch(key: str, default_var=_no_default, deserialize_json=False): 69 | print( 70 | f"Attempted to get Variable value during parse, returning a mocked value for {key}" 71 | ) 72 | 73 | if default_var is not _no_default: 74 | return default_var 75 | if deserialize_json: 76 | return magic_dict() 77 | return "NON_DEFAULT_MOCKED_VARIABLE_VALUE" 78 | 79 | 80 | Variable.get = variable_get_monkeypatch 81 | # # =========== /MONKEYPATCH VARIABLE.GET() =========== 82 | 83 | 84 | @contextmanager 85 | def suppress_logging(namespace): 86 | """ 87 | Suppress logging within a specific namespace to keep tests "clean" during build 88 | """ 89 | logger = logging.getLogger(namespace) 90 | old_value = logger.disabled 91 | logger.disabled = True 92 | try: 93 | yield 94 | finally: 95 | logger.disabled = old_value 96 | 97 | 98 | def get_import_errors(): 99 | """ 100 | Generate a tuple for import errors in the dag bag, and include DAGs without errors. 
101 | """ 102 | with suppress_logging("airflow"): 103 | dag_bag = DagBag(include_examples=False) 104 | 105 | def strip_path_prefix(path): 106 | return os.path.relpath(path, os.environ.get("AIRFLOW_HOME")) 107 | 108 | # Initialize an empty list to store the tuples 109 | result = [] 110 | 111 | # Iterate over the items in import_errors 112 | for k, v in dag_bag.import_errors.items(): 113 | result.append((strip_path_prefix(k), v.strip())) 114 | 115 | # Check if there are DAGs without errors 116 | for file_path in dag_bag.dags: 117 | # Check if the file_path is not in import_errors, meaning no errors 118 | if file_path not in dag_bag.import_errors: 119 | result.append((strip_path_prefix(file_path), "No import errors")) 120 | 121 | return result 122 | 123 | 124 | @pytest.mark.parametrize( 125 | "rel_path, rv", get_import_errors(), ids=[x[0] for x in get_import_errors()] 126 | ) 127 | def test_file_imports(rel_path, rv): 128 | """Test for import errors on a file""" 129 | if os.path.exists(".astro/dag_integrity_exceptions.txt"): 130 | with open(".astro/dag_integrity_exceptions.txt") as f: 131 | exceptions = f.readlines() 132 | print(f"Exceptions: {exceptions}") 133 | if (rv != "No import errors") and rel_path not in exceptions: 134 | # If rv is not "No import errors," consider it a failed test 135 | raise Exception(f"{rel_path} failed to import with message \n {rv}") 136 | # If rv is "No import errors," consider it a passed test 137 | print(f"{rel_path} passed the import test") 138 | -------------------------------------------------------------------------------- /examples/.dockerignore: -------------------------------------------------------------------------------- 1 | astro 2 | .git 3 | .env 4 | airflow_settings.yaml 5 | logs/ 6 | .venv 7 | airflow.db 8 | airflow.cfg 9 | __pycache__/ 10 | **/*.pyc 11 | **/logs/** 12 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | .git 2 | .env 3 | .DS_Store 4 | airflow_settings.yaml 5 | __pycache__/ 6 | astro 7 | .venv 8 | airflow-webserver.pid 9 | webserver_config.py 10 | airflow.cfg 11 | airflow.db 12 | logs/ 13 | **/logs/** 14 | -------------------------------------------------------------------------------- /examples/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM astrocrpublic.azurecr.io/runtime:3.0-1-base 2 | 3 | USER root 4 | 5 | RUN pip install -U uv 6 | 7 | COPY ./examples/requirements.txt ${AIRFLOW_HOME}/requirements.txt 8 | COPY ./pyproject.toml ${AIRFLOW_HOME}/airflow_ai_sdk/pyproject.toml 9 | COPY ./README.md ${AIRFLOW_HOME}/airflow_ai_sdk/README.md 10 | COPY ./uv.lock ${AIRFLOW_HOME}/airflow_ai_sdk/uv.lock 11 | COPY ./airflow_ai_sdk ${AIRFLOW_HOME}/airflow_ai_sdk/airflow_ai_sdk 12 | 13 | # install the package in editable mode 14 | RUN uv pip install --system -e "${AIRFLOW_HOME}/airflow_ai_sdk[openai,duckduckgo]" 15 | 16 | # install the requirements 17 | RUN uv pip install --system -r "${AIRFLOW_HOME}/requirements.txt" 18 | 19 | # make sure astro user owns the package 20 | RUN chown -R astro:astro ${AIRFLOW_HOME}/airflow_ai_sdk 21 | 22 | USER astro 23 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # examples 2 | 3 | This directory contains examples of how to use the Airflow AI SDK. 
To run Airflow locally, run (from the root of the repo): 4 | 5 | ```bash 6 | export AIRFLOW_HOME=$(pwd)/examples AIRFLOW__CORE__LOAD_EXAMPLES=false && uv run airflow standalone 7 | ``` 8 | 9 | Each example can also be run with: 10 | 11 | ```bash 12 | uv run examples/dags/example_dag.py 13 | ``` 14 | -------------------------------------------------------------------------------- /examples/dags/deep_research.py: -------------------------------------------------------------------------------- 1 | """ 2 | This shows how to use the SDK to build a deep research agent. 3 | """ 4 | 5 | import pendulum 6 | import requests 7 | try: 8 | from airflow.sdk import dag, task 9 | except ImportError: 10 | from airflow.decorators import dag, task 11 | from airflow.models.dagrun import DagRun 12 | from airflow.models.param import Param 13 | from bs4 import BeautifulSoup 14 | from pydantic_ai import Agent 15 | from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool 16 | 17 | 18 | async def get_page_content(url: str) -> str: 19 | """ 20 | Get the content of a page. 21 | """ 22 | response = requests.get(url) 23 | soup = BeautifulSoup(response.text, "html.parser") 24 | 25 | distillation_agent = Agent( 26 | "gpt-4o-mini", 27 | system_prompt=""" 28 | You are responsible for distilling information from a text. The summary will be used by a research agent to generate a research report. 29 | 30 | Keep the summary concise and to the point, focusing on only key information. 31 | """, 32 | ) 33 | 34 | return await distillation_agent.run(soup.get_text()) 35 | 36 | deep_research_agent = Agent( 37 | "o3-mini", 38 | system_prompt=""" 39 | You are a deep research agent who is very skilled at distilling information from the web. You are given a query and your job is to generate a research report. 40 | 41 | You can search the web by using the `duckduckgo_search_tool`. You can also use the `get_page_content` tool to get the contents of a page. 42 | 43 | Keep going until you have enough information to generate a research report. Assume you know nothing about the query or contents, so you need to search the web for relevant information. 44 | 45 | Find at least 8-10 sources to include in the research report. If you run out of sources, keep searching the web for more information with variations of the question. 46 | 47 | Do not use quotes in your search queries. 48 | 49 | Do not generate new information, only distill information from the web. If you want to cite a source, make sure you fetch the full contents because the summary may not be enough. 50 | """, 51 | tools=[duckduckgo_search_tool(), get_page_content], 52 | ) 53 | 54 | @task.agent(agent=deep_research_agent) 55 | def deep_research_task(dag_run: DagRun) -> str: 56 | """ 57 | This task performs a deep research on the given query. 
58 | """ 59 | query = dag_run.conf.get("query") 60 | 61 | if not query: 62 | raise ValueError("Query is required") 63 | 64 | print(f"Performing deep research on {query}") 65 | 66 | return query 67 | 68 | 69 | @task 70 | def upload_results(results: str): 71 | print("Uploading results") 72 | print("-" * 100) 73 | print(results) 74 | print("-" * 100) 75 | 76 | @dag( 77 | schedule=None, 78 | start_date=pendulum.datetime(2025, 3, 1, tz="UTC"), 79 | catchup=False, 80 | params={ 81 | "query": Param( 82 | type="string", 83 | default="How has the field of data engineering evolved in the last 5 years?", 84 | ), 85 | }, 86 | ) 87 | def deep_research(): 88 | results = deep_research_task() 89 | upload_results(results) 90 | 91 | dag = deep_research() 92 | -------------------------------------------------------------------------------- /examples/dags/email_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example consumes a list of prospects and generates personalized email messages for each prospect. 3 | """ 4 | 5 | import pendulum 6 | try: 7 | from airflow.sdk import dag, task 8 | except ImportError: 9 | from airflow.decorators import dag, task 10 | from airflow.exceptions import AirflowSkipException 11 | 12 | import airflow_ai_sdk as ai_sdk 13 | 14 | 15 | @task 16 | def get_prospects() -> list[dict]: 17 | """ 18 | Get the list of prospects from the database. 19 | """ 20 | return [ 21 | { 22 | "name": "John Doe", 23 | "company": "Acme Inc.", 24 | "industry": "Software", 25 | "job_title": "CTO", 26 | "lead_source": "LinkedIn", 27 | }, 28 | { 29 | "name": "Jane Smith", 30 | "company": "Smith Corp.", 31 | "industry": "Financial Services", 32 | "job_title": "Data Engineer", 33 | "lead_source": "Product Trial" 34 | }, 35 | { 36 | "name": "Bob Johnson", 37 | "company": "Tech Solutions", 38 | "industry": "Adtech", 39 | "job_title": "VP of Engineering", 40 | "lead_source": "Contact Us Form", 41 | }, 42 | { 43 | "name": "Alice Brown", 44 | "company": "DataTech", 45 | "industry": "Consulting", 46 | "job_title": "Data Analyst", 47 | "lead_source": "Meetup", 48 | }, 49 | { 50 | "name": "Charlie Green", 51 | "company": "GreenTech", 52 | "industry": "Climate Tech", 53 | "job_title": "Data Engineering Manager", 54 | "lead_source": "LinkedIn", 55 | }, 56 | ] 57 | 58 | 59 | class Email(ai_sdk.BaseModel): 60 | subject: str 61 | body: str 62 | 63 | @task.llm( 64 | model="o3-mini", 65 | result_type=Email, 66 | system_prompt=""" 67 | You are a sales agent who is responsible for generating personalized email messages for prospects for Astro, 68 | the best managed Airflow service on the market. Given the audience is technical, you should focus on the 69 | features and technology as opposed to more generic marketing/sales language. 70 | 71 | You will be given a list of prospects and your job is to generate a personalized email message for each prospect. 
72 | 73 | Here are some things to focus on: 74 | - Use the prospect's name in the email subject line 75 | - Keep the email subject line concise and to the point 76 | - Use the prospect's company and job title to personalize the email 77 | - Think very hard about what the prospect would want to read in an email 78 | - Include a call to action in the email 79 | - Ask a question in the email 80 | 81 | Here are some things to avoid: 82 | - Don't use generic language 83 | - Don't use filler words 84 | - Don't use vague language 85 | - Don't use clichés 86 | 87 | Here is some helpful information about Astro: 88 | - **Not just managed Airflow** – Astro is a **unified DataOps platform** that lets you seamlessly **build, run, and observe** data pipelines in one place, going beyond basic Apache Airflow-as-a-service. This unified approach eliminates fragmented tools and silos across the data lifecycle. 89 | - **Orchestration is mission-critical** – Modern businesses run on data pipelines. Over 90% of data engineers recommend Apache Airflow, and more than half of large enterprises use it for their **most critical workloads**. Astro delivers Airflow’s power as a **trusted, enterprise-grade service**, ensuring these vital pipelines are dependable and scalable. 90 | - **Data pipelines drive revenue** – Orchestrated data workflows aren’t just for internal reports anymore. **85%+ of teams plan to use Airflow for customer-facing, revenue-generating solutions** (like AI-driven products and automated customer experiences) in the next year. Astro’s platform helps organizations **deliver these innovations faster**, turning data pipelines into a competitive advantage. 91 | - **Eliminates engineering pain** – Astro **handles the heavy lifting** of pipeline infrastructure so your team doesn’t have to. It abstracts away maintenance headaches like cluster scaling, scheduling failovers, and Airflow upgrades, freeing your data engineers to focus on building value rather than managing servers. Teams using Astro report migrating complex workflows “a lot faster than expected” and no longer worry about keeping Airflow healthy. 92 | - **Built-in observability** – With Astro, you get **pipeline observability and alerting out-of-the-box**. The platform provides SLA dashboards, data lineage tracking, and real-time alerts on failures, all integrated with your orchestration. This means you can quickly detect issues, ensure data quality, and trust that your pipelines deliver fresh data on time – without bolting on third-party monitoring tools. 93 | - **Intelligent automation** – Astro goes beyond manual scheduling with AI-driven capabilities. It can auto-tune and even self-heal pipelines (e.g. retrying tasks, adjusting resources), and it offers smart assistants for DAG authoring (like natural language pipeline generation). The result is a boost in **pipeline reliability and efficiency** – Astro users see significant gains in uptime and team productivity across the data lifecycle. 94 | - **24×7 expert support** – Adopting Astro means you’re backed by **Apache Airflow experts** whenever you need help. Astronomer provides **24/7 enterprise support** from top Airflow committers, giving your team direct access to experts for troubleshooting and best-practice guidance. This white-glove support and professional services de-risk your data projects and ensure success in production. 95 | - **Boosts developer productivity** – Astro comes with tooling that supercharges data engineering workflows. 
For example, the **Astro CLI** lets you run and test DAGs locally in a production-like environment, and Astro’s cloud IDE and CI/CD integrations make it easy to write, version, and deploy pipelines with less boilerplate. These features let your team iterate faster and with confidence. 96 | - **Northern Trust (Financial Services)** – Replaced a legacy scheduler (Control-M) with Astro to modernize its data workflows, laying a solid foundation for future growth and innovation. By migrating to Astro, Northern Trust eliminated the limitations of their old batch processes and can now deliver data products faster in a highly regulated environment. 97 | - **Black Crow AI (Marketing Tech)** – Turned to Astro to overcome massive Airflow scaling challenges as their data operations grew. With Astro’s managed orchestration, Black Crow AI now reliably delivers **AI-driven data products** to customers, even as data volumes and workloads spike with company growth. 98 | - **McKenzie Intelligence (Geospatial Analytics)** – Used Astro to eliminate manual data-processing tasks and enforce consistency, effectively tripling their efficiency in analyzing disaster impacts. This automation enabled McKenzie to run critical catastrophe assessment pipelines 24/7 worldwide, vastly improving response time and coverage. 99 | - **Bestow (Life Insurance)** – Overcame early pipeline bottlenecks by adopting Astro, which accelerated developer productivity and operational efficiency. By offloading orchestration to Astro, Bestow’s engineering team removed maintenance burdens and delivered new insurance insights faster, helping transform the life insurance landscape with data-driven services. 100 | - **SciPlay (Gaming)** – Scaled up game data analytics with Astro’s managed Airflow, allowing this social gaming leader to handle surging data without missing a beat. Offloading pipeline orchestration to Astro helped SciPlay drive rapid innovation in player analytics and personalized features, directly supporting player engagement and revenue growth. 101 | - **Black Wealth Data Center (Non-profit)** – Chose Astro as a scalable, sustainable Airflow solution to run their data pipelines for social impact. Astro’s fully managed service allowed BWDC to expand their analytics initiatives without worrying about infrastructure limits or platform reliability, so they can focus on their mission of closing the racial wealth gap. 102 | - **Anastasia (Retail Analytics)** – Migrated from AWS Step Functions to Astro to power its AI‑powered insights platform for small retailers. With Astro orchestrating complex workflows behind the scenes, Anastasia optimizes clients’ inventory and sales predictions reliably, addressing the pressing operational challenges that SMBs face in real time. 103 | - **Laurel (Timekeeping AI)** – Freed up its data team by moving to Astro’s managed Airflow, giving engineers more time to build revenue-generating ML pipelines instead of fighting fires. This partnership has accelerated Laurel’s machine learning development for automated timekeeping, as the data team can iterate on models without being bogged down by pipeline maintenance. 104 | - **Texas Rangers (Sports)** – Orchestrated the MLB team’s analytics on Astro and cut data delivery time by 24 hours with zero additional infrastructure cost. Faster data availability means coaches and analysts get next-day insights instead of a two-day lag, improving game preparation and in-game decision-making with up-to-the-minute analytics. 
105 | - **Autodesk (Software)** – Retired a legacy Oozie scheduler and migrated hundreds of critical workflows to Astro with help from Astronomer’s experts. By partnering with Astro, Autodesk gained a modern, Airflow-powered orchestration backbone for its cloud transformation – one that scales with demand and removes the pain of managing their own scheduling infrastructure. 106 | - **CRED (Fintech)** – Switched from a brittle Apache NiFi setup to Astro’s fully managed Airflow to keep pace with hyper-growth in users and data. With Astro, CRED achieved faster and more reliable data pipelines on a scalable Airflow foundation, ensuring that as their business grew, their data platform stayed ahead of demand instead of becoming a bottleneck. 107 | - **VTEX (E-Commerce)** – Adopted Astro to enforce consistency and reliability across complex data environments in its global commerce platform. Astro’s managed infrastructure and one-click Airflow upgrades meant VTEX could cut through pipeline complexity and always stay on the latest features. The time saved on debugging and upkeep has allowed VTEX’s data team to move much faster and extend orchestration to new teams, unlocking use cases in recruitment analytics, sales dashboards, and more. 108 | 109 | Here are some examples of successful emails: 110 | 111 | ### Email 1: Modernize Your Data Pipelines with Astro 112 | 113 | **Subject:** Modernize Your Data Pipelines with Astro’s Unified DataOps Platform 114 | 115 | Hi [Name], 116 | 117 | I’m reaching out to introduce Astro—a unified DataOps platform that goes beyond managed Airflow to help you build, run, and observe data pipelines effortlessly. 118 | 119 | **Key benefits include:** 120 | - A single platform to streamline complex data workflows 121 | - Built-in observability with SLA dashboards, data lineage, and real-time alerts 122 | - 24×7 expert support from top Airflow committers to help your team every step of the way 123 | 124 | Companies like Northern Trust have modernized their data workflows with Astro, leaving legacy systems behind. I’d love to show you how Astro can eliminate engineering headaches and accelerate your data innovation. 125 | 126 | Are you open to a brief call next week? 127 | 128 | Best regards, 129 | [Your Name] 130 | [Your Title] 131 | [Your Contact Information] 132 | 133 | --- 134 | 135 | ### Email 2: Accelerate Your Data Innovation with Intelligent Automation 136 | 137 | **Subject:** Accelerate Your Data Innovation with Intelligent Automation 138 | 139 | Hi [Name], 140 | 141 | In today’s competitive landscape, efficient data pipelines are key to unlocking revenue-generating insights. Astro’s unified DataOps platform offers intelligent automation that can auto-tune and even self-heal your pipelines. 142 | 143 | **Why Astro?** 144 | - Seamlessly manage critical data workflows with minimal manual intervention 145 | - Empower your team with developer tools like the Astro CLI for faster iteration 146 | - Proven results: Companies like Black Crow AI leverage Astro to reliably deliver AI-driven data products even as workloads spike 147 | 148 | I’d love to discuss how Astro can help your organization deliver innovations faster and free up your engineering team to focus on strategic initiatives. 
149 | 150 | Looking forward to connecting, 151 | [Your Name] 152 | [Your Title] 153 | [Your Contact Information] 154 | 155 | --- 156 | 157 | ### Email 3: Overcome Pipeline Challenges and Scale with Confidence 158 | 159 | **Subject:** Overcome Pipeline Challenges & Scale with Astro’s DataOps Platform 160 | 161 | Hi [Name], 162 | 163 | Managing and scaling data pipelines shouldn’t hold back your growth. Astro is built to solve common data engineering pain points by handling the heavy lifting of orchestration, scaling, and maintenance. 164 | 165 | **How Astro makes a difference:** 166 | - Eliminates the headaches of infrastructure management so your team can focus on high-value projects 167 | - Ensures robust, scalable data pipelines with enterprise-grade reliability 168 | - Success story: Autodesk transitioned hundreds of workflows to Astro, modernizing their orchestration backbone with expert support 169 | 170 | Let’s explore how Astro can empower your team to scale data operations seamlessly. Would you be available for a quick call this week? 171 | 172 | Thanks, 173 | [Your Name] 174 | [Your Title] 175 | [Your Contact Information] 176 | 177 | --- 178 | 179 | ### Email 4: Empower Your Team with Superior Developer Productivity 180 | 181 | **Subject:** Empower Your Team with Astro’s Productivity-Boosting Tools 182 | 183 | Hi [Name], 184 | 185 | I wanted to share how Astro’s unified DataOps platform can dramatically boost your team’s productivity. By offloading pipeline orchestration and maintenance to Astro, your developers can focus on building innovative, revenue-generating solutions. 186 | 187 | **Astro’s key productivity benefits:** 188 | - Local testing with the Astro CLI and seamless CI/CD integrations for efficient deployments 189 | - Intuitive tooling that lets your team iterate faster without worrying about infrastructure challenges 190 | - Proven results: Bestow has accelerated its product delivery by reducing pipeline bottlenecks through Astro 191 | 192 | Could we schedule a time to discuss how Astro can help your team work smarter, not harder? 193 | 194 | Best, 195 | [Your Name] 196 | [Your Title] 197 | [Your Contact Information] 198 | 199 | --- 200 | 201 | ### Email 5: Enhance Data Reliability with Built-in Observability & 24×7 Support 202 | 203 | **Subject:** Enhance Data Reliability with Astro’s Observability & Expert Support 204 | 205 | Hi [Name], 206 | 207 | Ensuring data quality and reliability is critical in today’s data-driven environment. Astro’s unified DataOps platform not only orchestrates your data pipelines but also offers built-in observability and 24×7 expert support. 208 | 209 | **What sets Astro apart:** 210 | - Real-time monitoring with SLA dashboards and data lineage tracking 211 | - Proactive alerts to quickly identify and resolve pipeline issues before they impact your business 212 | - Trusted by leading organizations like VTEX and SciPlay for its consistent reliability and round-the-clock support 213 | 214 | I’d love to share more on how Astro can help you maintain flawless data operations. Are you available for a brief call this week? 215 | 216 | Regards, 217 | [Your Name] 218 | [Your Title] 219 | [Your Contact Information] 220 | """ 221 | ) 222 | def generate_email(prospect: dict | None = None) -> Email: 223 | """ 224 | Generate a personalized email message for the prospect. 
225 | """ 226 | if prospect is None: 227 | raise AirflowSkipException("No prospect provided") 228 | 229 | return f""" 230 | Name: {prospect["name"]} 231 | Company: {prospect["company"]} 232 | Industry: {prospect["industry"]} 233 | Job Title: {prospect["job_title"]} 234 | """ 235 | 236 | 237 | @task 238 | def send_email(email: dict[str, str] | None = None): 239 | """ 240 | Send the email to the prospect. Just print the email for now. 241 | """ 242 | if email is None: 243 | raise AirflowSkipException("No email provided") 244 | 245 | from pprint import pprint 246 | 247 | pprint(email) 248 | 249 | 250 | @dag( 251 | schedule=None, 252 | start_date=pendulum.datetime(2025, 3, 1, tz="UTC"), 253 | catchup=False, 254 | ) 255 | def email_generation(): 256 | prospects = get_prospects() 257 | emails = generate_email.expand(prospect=prospects) 258 | send_email.expand(email=emails) 259 | 260 | my_dag = email_generation() 261 | -------------------------------------------------------------------------------- /examples/dags/github_changelog.py: -------------------------------------------------------------------------------- 1 | """ 2 | This shows how to use the SDK to build a simple GitHub change summarization workflow. 3 | """ 4 | 5 | import os 6 | 7 | import pendulum 8 | try: 9 | from airflow.sdk import dag, task 10 | except ImportError: 11 | from airflow.decorators import dag, task 12 | from github import Github 13 | 14 | 15 | @task 16 | def get_recent_commits(data_interval_start: pendulum.DateTime, data_interval_end: pendulum.DateTime) -> list[str]: 17 | """ 18 | This task returns a mocked list of recent commits. In a real workflow, this 19 | task would get the recent commits from a database or API. 20 | """ 21 | print(f"Getting commits for {data_interval_start} to {data_interval_end}") 22 | gh = Github(os.getenv("GITHUB_TOKEN")) 23 | repo = gh.get_repo("apache/airflow") 24 | commits = repo.get_commits(since=data_interval_start, until=data_interval_end) 25 | return [f"{commit.commit.sha}: {commit.commit.message}" for commit in commits] 26 | 27 | @task.llm( 28 | model="gpt-4o-mini", 29 | system_prompt=""" 30 | Your job is to summarize the commits to the Airflow project given a week's worth 31 | of commits. Pay particular attention to large changes and new features as opposed 32 | to bug fixes and minor changes. 33 | 34 | You don't need to include every commit, just the most important ones. Add a one line 35 | overall summary of the changes at the top, followed by bullet points of the most 36 | important changes. 37 | 38 | Example output: 39 | 40 | This week, we made architectural changes to the core scheduler to make it more 41 | maintainable and easier to understand. 42 | 43 | - Made the scheduler 20% faster (commit 1234567) 44 | - Added a new task type: `example_task` (commit 1234568) 45 | - Added a new operator: `example_operator` (commit 1234569) 46 | - Added a new sensor: `example_sensor` (commit 1234570) 47 | """ 48 | ) 49 | def summarize_commits(commits: list[str] | None = None) -> str: 50 | """ 51 | This task summarizes the commits. 52 | """ 53 | # don't need to do any translation 54 | return "\n".join(commits) 55 | 56 | @task 57 | def send_summaries(summaries: str): 58 | """ 59 | This task prints the summaries. In a real workflow, this task would send the summaries to a chat channel. 
60 | """ 61 | print(summaries) 62 | 63 | @dag( 64 | schedule="@weekly", 65 | start_date=pendulum.datetime(2025, 3, 1, tz="UTC"), 66 | catchup=False, 67 | ) 68 | def github_changelog(): 69 | commits = get_recent_commits() 70 | summaries = summarize_commits(commits=commits) 71 | send_summaries(summaries) 72 | 73 | dag = github_changelog() 74 | -------------------------------------------------------------------------------- /examples/dags/product_feedback_summarization.py: -------------------------------------------------------------------------------- 1 | """ 2 | This shows how to use the SDK to build a simple product feedback summarization workflow. 3 | """ 4 | 5 | from typing import Any, Literal 6 | 7 | import pendulum 8 | try: 9 | from airflow.sdk import dag, task 10 | except ImportError: 11 | from airflow.decorators import dag, task 12 | from airflow.exceptions import AirflowSkipException 13 | 14 | import airflow_ai_sdk as ai_sdk 15 | 16 | 17 | @task 18 | def get_product_feedback() -> list[str]: 19 | """ 20 | This task returns a mocked list of product feedback. In a real workflow, this 21 | task would get the product feedback from a database or API. 22 | """ 23 | return [ 24 | "I absolutely love Apache Airflow’s intuitive user interface and its robust DAG visualization capabilities. The scheduling and task dependency features are a joy to work with and greatly enhance my workflow efficiency. I would love to see an auto-scaling feature for tasks in future releases to further optimize performance.", 25 | "The overall experience with Apache Airflow has been disappointing due to its steep learning curve and inconsistent documentation. Many features seem underdeveloped, and the UI often feels clunky and unresponsive. It would be great if future updates included a revamped interface and clearer setup guides.", 26 | "Apache Airflow shines with its flexible Python-based task definitions and comprehensive logging system, making it a standout tool in workflow management. The integration capabilities are top-notch and have streamlined my data pipelines remarkably. I do hope that upcoming versions will include enhanced real-time monitoring features to further improve user experience.", 27 | "My experience with Apache Airflow has been largely negative, primarily because of the frequent performance lags and the overwhelming complexity of the configuration process. The lack of clear error messages only adds to the frustration. I wish the development team would simplify the setup process and incorporate more user-friendly error reporting mechanisms.", 28 | "I am very impressed with Apache Airflow’s modular design and its extensive library of operators, which together make it a powerful tool for orchestrating complex workflows. The overall stability during routine operations is commendable, and the community support is excellent. However, a feature for customizable dashboards would be a welcome addition in future updates.", 29 | "Using Apache Airflow has been a challenging experience due to its unintuitive interface and the slow performance during high-load periods. The limited documentation on advanced features makes troubleshooting a real hassle. I strongly recommend that future versions include comprehensive tutorials and performance optimizations to address these issues.", 30 | "Apache Airflow offers a remarkable level of flexibility with its DAG management and integration capabilities, which I find very useful. 
The platform consistently performs well in orchestrating multi-step data processes, and I appreciate its strong community backing. A potential enhancement could be the introduction of an in-built scheduling calendar for a more streamlined planning process.", 31 | "I have encountered numerous issues with Apache Airflow, including a clunky UI and recurring glitches that disrupt workflow execution. The error handling is inadequate, making it difficult to pinpoint problems during failures. It would be beneficial if the next update focused on enhancing UI stability and implementing a more robust error recovery system." 32 | "This is a review that is about something random" 33 | ] 34 | 35 | class ProductFeedbackSummary(ai_sdk.BaseModel): 36 | summary: str 37 | sentiment: Literal["positive", "negative", "neutral"] 38 | feature_requests: list[str] 39 | 40 | 41 | @task.llm(model="gpt-4o-mini", result_type=ProductFeedbackSummary, system_prompt="Extract the summary, sentiment, and feature requests from the product feedback.",) 42 | def summarize_product_feedback(feedback: str | None = None) -> ProductFeedbackSummary: 43 | """ 44 | This task summarizes the product feedback. You can add logic here to transform the input 45 | before summarizing it. 46 | """ 47 | # if the feedback doesn't mention Airflow, skip it 48 | if "Airflow" not in feedback: 49 | raise AirflowSkipException("Feedback does not mention Airflow") 50 | 51 | return feedback 52 | 53 | 54 | @task 55 | def upload_summaries(summaries: list[dict[str, Any]]): 56 | """ 57 | This task prints the summaries. In a real workflow, this task would upload the summaries to a database or API. 58 | """ 59 | from pprint import pprint 60 | for summary in summaries: 61 | pprint(summary) 62 | 63 | @dag( 64 | schedule=None, 65 | start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), 66 | catchup=False, 67 | ) 68 | def product_feedback_summarization(): 69 | feedback = get_product_feedback() 70 | summaries = summarize_product_feedback.expand(feedback=feedback) 71 | upload_summaries(summaries) 72 | 73 | dag = product_feedback_summarization() 74 | -------------------------------------------------------------------------------- /examples/dags/sentiment_classification.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from airflow.decorators import dag, task 3 | import pendulum 4 | from airflow.models.dagrun import DagRun 5 | 6 | 7 | @task.llm( 8 | model="gpt-4o-mini", 9 | result_type=Literal["positive", "negative", "neutral"], 10 | system_prompt="Classify the sentiment of the given text.", 11 | ) 12 | def process_with_llm(dag_run: DagRun) -> str: 13 | input_text = dag_run.conf.get("input_text") 14 | 15 | # can do pre-processing here (e.g. PII redaction) 16 | return input_text 17 | 18 | 19 | @dag( 20 | schedule=None, 21 | start_date=pendulum.datetime(2025, 1, 1), 22 | catchup=False, 23 | params={"input_text": "I'm very happy with the product."}, 24 | ) 25 | def sentiment_classification(): 26 | process_with_llm() 27 | 28 | 29 | sentiment_classification() 30 | -------------------------------------------------------------------------------- /examples/dags/support_ticket_routing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example DAG that routes support tickets to the correct department using the llm_branch decorator. 
3 | """ 4 | 5 | import pendulum 6 | try: 7 | from airflow.sdk import dag, task 8 | except ImportError: 9 | from airflow.decorators import dag, task 10 | from airflow.models.dagrun import DagRun 11 | 12 | 13 | def mask_pii(ticket_content: str) -> str: 14 | """ 15 | This function masks PII in the ticket content. You could do this one of a few ways... 16 | - Use regexes / string replacements to mask PII 17 | - Use a pretrained model via an API (e.g. HuggingFace) 18 | - Use a custom model run locally on the machine 19 | """ 20 | return ticket_content 21 | 22 | @task.llm_branch( 23 | model="gpt-4o-mini", 24 | system_prompt=""" 25 | You are a support agent that routes support tickets based on the priority of the ticket. 26 | 27 | Here are the priority definitions: 28 | - P0: Critical issues that impact the user's ability to use the product, specifically for a production deployment. 29 | - P1: Issues that impact the user's ability to use the product, but not as severely (or not for their production deployment). 30 | - P2: Issues that are low priority and can wait until the next business day 31 | - P3: Issues that are not important or time sensitive 32 | 33 | Here are some examples of tickets and their priorities: 34 | - "Our production deployment just went down because it ran out of memory. Please help.": P0 35 | - "Our staging / dev / QA deployment just went down because it ran out of memory. Please help.": P1 36 | - "I'm having trouble logging in to my account.": P1 37 | - "The UI is not loading.": P1 38 | - "I need help setting up my account.": P2 39 | - "I have a question about the product.": P3 40 | """, 41 | allow_multiple_branches=True, 42 | ) 43 | def route_ticket(dag_run: DagRun) -> str: 44 | """ 45 | This task routes the support ticket to the correct department based on the priority of the ticket. It also does 46 | PII masking on the ticket content before sending it to the LLM. 47 | """ 48 | ticket_content = dag_run.conf.get("ticket") 49 | 50 | # mask PII in the ticket content 51 | ticket_content = mask_pii(ticket_content) 52 | 53 | return ticket_content 54 | 55 | @task 56 | def handle_p0_ticket(ticket: str): 57 | print(f"Handling P0 ticket: {ticket}") 58 | 59 | @task 60 | def handle_p1_ticket(ticket: str): 61 | print(f"Handling P1 ticket: {ticket}") 62 | 63 | @task 64 | def handle_p2_ticket(ticket: str): 65 | print(f"Handling P2 ticket: {ticket}") 66 | 67 | @task 68 | def handle_p3_ticket(ticket: str): 69 | print(f"Handling P3 ticket: {ticket}") 70 | 71 | @dag( 72 | start_date=pendulum.datetime(2025, 1, 1, tz="UTC"), 73 | schedule=None, 74 | catchup=False, 75 | params={"ticket": "Hi, our production deployment just went down because it ran out of memory. Please help."} 76 | ) 77 | def support_ticket_routing(): 78 | ticket = route_ticket() 79 | 80 | handle_p0_ticket(ticket) 81 | handle_p1_ticket(ticket) 82 | handle_p2_ticket(ticket) 83 | handle_p3_ticket(ticket) 84 | 85 | dag = support_ticket_routing() 86 | -------------------------------------------------------------------------------- /examples/dags/text_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example DAG that demonstrates how to use the @task.embed decorator to create vector embeddings. 3 | """ 4 | 5 | import pendulum 6 | 7 | from airflow.decorators import dag, task 8 | 9 | @task 10 | def get_texts() -> list[str]: 11 | """ 12 | This task returns a list of texts to embed. In a real workflow, this 13 | task would get the texts from a database or API. 
14 | """ 15 | return [ 16 | "The quick brown fox jumps over the lazy dog", 17 | "A fast orange fox leaps over a sleepy canine", 18 | "The weather is beautiful today", 19 | ] 20 | 21 | @task.embed( 22 | model_name="all-MiniLM-L12-v2", # default model 23 | encode_kwargs={"normalize_embeddings": True} # optional kwargs for the encode method 24 | ) 25 | def create_embeddings(text: str) -> list[float]: 26 | """ 27 | This task creates embeddings for the given text. The decorator handles 28 | the model initialization and encoding. 29 | """ 30 | return text 31 | 32 | @task 33 | def store_embeddings(embeddings: list[list[float]]): 34 | """ 35 | This task stores the embeddings. In a real workflow, this task would 36 | store the embeddings in a vector database. 37 | """ 38 | print(f"Storing {len(embeddings)} embeddings") 39 | 40 | @dag( 41 | schedule=None, 42 | start_date=pendulum.datetime(2025, 1, 1, tz="UTC"), 43 | catchup=False, 44 | ) 45 | def text_embedding(): 46 | texts = get_texts() 47 | embeddings = create_embeddings.expand(text=texts) 48 | store_embeddings(embeddings) 49 | 50 | text_embedding() 51 | -------------------------------------------------------------------------------- /examples/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | version: '3' 3 | 4 | x-airflow-common: &airflow-common 5 | image: airflow-ai-sdk 6 | build: 7 | context: .. 8 | dockerfile: examples/Dockerfile 9 | env_file: .env 10 | networks: 11 | - airflow 12 | environment: &common-env-vars 13 | AIRFLOW__API__BASE_URL: "http://localhost:8080" 14 | AIRFLOW__API__PORT: 8080 15 | AIRFLOW__API_AUTH__JWT_SECRET: "airflow-ai-sdk" 16 | AIRFLOW__CORE__AUTH_MANAGER: airflow.api_fastapi.auth.managers.simple.simple_auth_manager.SimpleAuthManager 17 | AIRFLOW__CORE__SIMPLE_AUTH_MANAGER_ALL_ADMINS: "True" 18 | AIRFLOW__CORE__EXECUTION_API_SERVER_URL: "http://api-server:8080/execution/" 19 | AIRFLOW__CORE__EXECUTOR: LocalExecutor 20 | AIRFLOW__CORE__FERNET_KEY: '' 21 | AIRFLOW__CORE__LOAD_EXAMPLES: "False" 22 | AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql://airflow:pg_password@postgres:5432/airflow 23 | AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: postgresql://airflow:pg_password@postgres:5432/airflow 24 | AIRFLOW__SCHEDULER__STANDALONE_DAG_PROCESSOR: True 25 | AIRFLOW__WEBSERVER__SECRET_KEY: "airflow-ai-sdk" 26 | AIRFLOW__WEBSERVER__RBAC: "True" 27 | AIRFLOW__WEBSERVER__EXPOSE_CONFIG: "True" 28 | ASTRONOMER_ENVIRONMENT: local 29 | OPENLINEAGE_DISABLED: "True" 30 | AIRFLOW__SCHEDULER__ENABLE_HEALTH_CHECK: 'true' 31 | volumes: 32 | - ./dags:/usr/local/airflow/dags 33 | - ./plugins:/usr/local/airflow/plugins 34 | - ./include:/usr/local/airflow/include 35 | - ./tests:/usr/local/airflow/tests 36 | - airflow_logs:/usr/local/airflow/logs 37 | 38 | networks: 39 | airflow: 40 | driver: bridge 41 | 42 | volumes: 43 | postgres_data: 44 | driver: local 45 | airflow_logs: 46 | driver: local 47 | 48 | services: 49 | postgres: 50 | image: postgres:13 51 | restart: unless-stopped 52 | networks: 53 | - airflow 54 | ports: 55 | - "5432:5432" 56 | volumes: 57 | - postgres_data:/var/lib/postgresql/data 58 | environment: 59 | POSTGRES_USER: airflow 60 | POSTGRES_PASSWORD: pg_password 61 | POSTGRES_DB: airflow 62 | healthcheck: 63 | test: ["CMD", "pg_isready", "-U", "airflow"] 64 | interval: 10s 65 | retries: 5 66 | start_period: 5s 67 | 68 | db-migration: 69 | <<: *airflow-common 70 | depends_on: 71 | - postgres 72 | command: 73 | - airflow 74 | - db 75 | - migrate 76 | 77 | scheduler: 78 | <<: 
*airflow-common 79 | depends_on: 80 | - db-migration 81 | command: 82 | - airflow 83 | - scheduler 84 | restart: unless-stopped 85 | healthcheck: 86 | test: ["CMD", "curl", "--fail", "http://localhost:8974/health"] 87 | interval: 30s 88 | timeout: 10s 89 | retries: 5 90 | start_period: 30s 91 | 92 | dag-processor: 93 | <<: *airflow-common 94 | depends_on: 95 | - db-migration 96 | command: 97 | - airflow 98 | - dag-processor 99 | restart: unless-stopped 100 | healthcheck: 101 | test: ["CMD-SHELL", 'airflow jobs check --job-type DagProcessorJob --hostname "$${HOSTNAME}"'] 102 | interval: 30s 103 | timeout: 10s 104 | retries: 5 105 | start_period: 30s 106 | 107 | api-server: 108 | <<: *airflow-common 109 | depends_on: 110 | - db-migration 111 | command: 112 | - airflow 113 | - api-server 114 | restart: unless-stopped 115 | ports: 116 | - "8080:8080" 117 | healthcheck: 118 | test: ["CMD", "curl", "--fail", "http://localhost:8080/api/v2/version"] 119 | interval: 30s 120 | timeout: 10s 121 | retries: 5 122 | start_period: 30s 123 | 124 | triggerer: 125 | <<: *airflow-common 126 | depends_on: 127 | - db-migration 128 | command: 129 | - airflow 130 | - triggerer 131 | restart: unless-stopped 132 | healthcheck: 133 | test: ["CMD-SHELL", 'airflow jobs check --job-type TriggererJob --hostname "$${HOSTNAME}"'] 134 | interval: 30s 135 | timeout: 10s 136 | retries: 5 137 | start_period: 30s 138 | -------------------------------------------------------------------------------- /examples/packages.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astronomer/airflow-ai-sdk/572ce0cf17b1eeb1002c380fc003a12e797c2d6a/examples/packages.txt -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | # Astro Runtime includes the following pre-installed providers packages: https://www.astronomer.io/docs/astro/runtime-image-architecture#provider-packages 2 | PyGithub 3 | BeautifulSoup4 4 | sentence-transformers 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "airflow-ai-sdk" 7 | dynamic = ["version"] 8 | description = "SDK for building LLM workflows and agents using Apache Airflow" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Framework :: Apache Airflow", 14 | "Framework :: Apache Airflow :: Provider", 15 | "Intended Audience :: Developers", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | ] 24 | dependencies = [ 25 | "apache-airflow>=2.7.0", 26 | "typing-extensions>=4.0.0", 27 | "pydantic-ai-slim<0.1.0", 28 | ] 29 | 30 | [project.optional-dependencies] 31 | # models 32 | openai = ["pydantic-ai-slim[openai]<0.1.0"] 33 | cohere = ["pydantic-ai-slim[cohere]<0.1.0"] 34 | vertexai = ["pydantic-ai-slim[vertexai]<0.1.0"] 35 | anthropic = ["pydantic-ai-slim[anthropic]<0.1.0"] 36 | groq = ["pydantic-ai-slim[groq]<0.1.0"] 37 | 
mistral = ["pydantic-ai-slim[mistral]<0.1.0"] 38 | bedrock = ["pydantic-ai-slim[bedrock]<0.1.0"] 39 | 40 | # tools 41 | duckduckgo = ["pydantic-ai-slim[duckduckgo]<0.1.0"] 42 | 43 | # mcp 44 | mcp = ["pydantic-ai-slim[mcp]<0.1.0"] 45 | 46 | [dependency-groups] 47 | dev = ["ruff>=0.11.2"] 48 | 49 | [tool.hatch.version] 50 | path = "airflow_ai_sdk/__init__.py" 51 | 52 | [tool.hatch.build.targets.wheel] 53 | packages = ["airflow_ai_sdk"] 54 | 55 | [project.entry-points.apache_airflow_provider] 56 | provider_info = "airflow_ai_sdk:get_provider_info" 57 | 58 | 59 | [tool.ruff] 60 | line-length = 88 61 | indent-width = 4 62 | exclude = ["tests", "examples"] 63 | 64 | target-version = "py312" 65 | 66 | [tool.ruff.lint] 67 | select = [ 68 | "ANN", # Require type annotations on functions (flake8-annotations) – helps MyPy catch missing types 69 | "TC", # Enforce type-checking imports (flake8-type-checking) – e.g. heavy imports in `if TYPE_CHECKING` for faster startup 70 | "B", # Bugbear rules – catch common mistakes and potential issues, some with performance impact (flake8-bugbear) 71 | "ARG", # Flag unused function arguments (flake8-unused-arguments) – helps remove dead code and catch bugs 72 | "C4", # Optimize comprehensions (flake8-comprehensions) – avoid unnecessary list()/set() around generators 73 | "SIM", # Simplify expressions (flake8-simplify) – enforce idiomatic, efficient Python constructs 74 | "PERF", # Performance tweaks (perflint) – flag code that can run faster 75 | "FAST", # FastAPI-specific rules – catch common FastAPI issues (e.g. missing `Annotated` on deps, unused path params) 76 | "S", # Security checks (flake8-bandit) – warn about insecure patterns (use of eval, weak cryptography, etc.) 77 | "RET", # Return statement consistency (flake8-return) – e.g. no mix of return vs return None in the same function 78 | "RSE", # Raise statement hygiene (flake8-raise) – avoid raising generic Exceptions or improper usage 79 | "N", # PEP8 naming conventions (pep8-naming) – enforce standard naming for classes, variables, constants 80 | "Q", # Quote consistency (flake8-quotes) – enforce a consistent string quote style across the codebase 81 | "I", # Import order (isort) – ensure imports are sorted into groups (stdlib, third-party, local) for readability 82 | "UP", # Python upgrades (pyupgrade) – suggest modern syntax (f-strings, walrus operator, etc.) 
for clean, up-to-date code 83 | "W293", # Avoid blank lines with whitespace 84 | ] 85 | 86 | [tool.ruff.lint.flake8-annotations] 87 | mypy-init-return = true # Don't require `-> None` on __init__ (MyPy treats __init__ as implicitly returning None) 88 | ignore-fully-untyped = false # Even fully untyped functions are reported (ensure *every* function has type hints) 89 | 90 | [tool.ruff.lint.flake8-type-checking] 91 | strict = true # Treat missing type-checking guards strictly (e.g., insist on using `if TYPE_CHECKING` for imports of heavy modules) 92 | 93 | [tool.ruff.lint.flake8-quotes] 94 | inline-quotes = "double" # Use double quotes for regular strings 95 | multiline-quotes = "double" # Use triple-double quotes for docstrings and multi-line strings 96 | 97 | [tool.uv] 98 | dev-dependencies = [ 99 | "ruff>=0.11.2", 100 | "pytest>=7.0.0", 101 | "pytest-mock>=3.10.0", 102 | ] 103 | 104 | [tool.pytest.ini_options] 105 | testpaths = ["tests"] 106 | python_files = "test_*.py" 107 | python_classes = "Test*" 108 | python_functions = "test_*" 109 | -------------------------------------------------------------------------------- /scripts/generate_interface_docs.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | from pathlib import Path 4 | 5 | PACKAGE = Path("airflow_ai_sdk") 6 | DOCS_DIR = Path("docs/interface") 7 | DOCS_DIR.mkdir(parents=True, exist_ok=True) 8 | 9 | 10 | import subprocess 11 | from types import ModuleType 12 | from typing import TextIO 13 | 14 | 15 | def document_module(module: ModuleType, file: TextIO) -> None: 16 | file.write(f"# {module.__name__}\n\n") 17 | if module.__doc__: 18 | file.write(inspect.getdoc(module)) 19 | file.write("\n\n") 20 | for name, obj in inspect.getmembers(module): 21 | if name.startswith("_"): 22 | continue 23 | if inspect.isfunction(obj) or inspect.isclass(obj): 24 | obj_module = getattr(obj, "__module__", module.__name__) 25 | if obj_module != module.__name__: 26 | continue 27 | file.write(f"## {name}\n\n") 28 | doc = inspect.getdoc(obj) or "No documentation." 29 | file.write(doc) 30 | file.write("\n\n") 31 | 32 | 33 | def main() -> None: 34 | for path in PACKAGE.rglob("*.py"): 35 | if path.name == "__init__.py": 36 | continue 37 | module_name = path.with_suffix("").as_posix().replace("/", ".") 38 | module = importlib.import_module(module_name) 39 | 40 | out_path = DOCS_DIR / path.relative_to(PACKAGE).with_suffix(".md") 41 | out_path.parent.mkdir(parents=True, exist_ok=True) 42 | 43 | with out_path.open("w") as f: 44 | document_module(module, f) 45 | 46 | # run pre-commit hooks 47 | subprocess.run(["pre-commit", "run", "--all-files"]) # noqa: S603 S607 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test package for airflow-ai-sdk. 3 | """ 4 | -------------------------------------------------------------------------------- /tests/operators/test_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the AgentDecoratedOperator class. 
3 | """ 4 | 5 | from unittest.mock import MagicMock, patch 6 | 7 | import pytest 8 | from airflow.utils.context import Context 9 | from pydantic_ai import Tool 10 | from pydantic_ai.agent import AgentRunResult 11 | 12 | from airflow_ai_sdk.models.base import BaseModel 13 | from airflow_ai_sdk.operators.agent import AgentDecoratedOperator, WrappedTool 14 | 15 | 16 | def tool1(input: str) -> str: 17 | """Tool 1.""" 18 | return f"tool1_result: {input}" 19 | 20 | def tool2(input: str) -> str: 21 | """Tool 2.""" 22 | return f"tool2_result: {input}" 23 | 24 | 25 | @pytest.fixture 26 | def base_config(): 27 | """Base configuration for tests.""" 28 | return { 29 | "op_args": [], 30 | "op_kwargs": {}, 31 | } 32 | 33 | 34 | @pytest.fixture 35 | def mock_context(): 36 | """Create a mock context.""" 37 | return MagicMock(spec=Context) 38 | 39 | 40 | @pytest.fixture 41 | def mock_agent_with_tools(): 42 | """Create a mock agent with tools.""" 43 | mock_agent = MagicMock() 44 | mock_agent._function_tools = {"tool1": Tool(tool1), "tool2": Tool(tool2)} 45 | return mock_agent 46 | 47 | 48 | @pytest.fixture 49 | def mock_agent_no_tools(): 50 | """Create a mock agent without tools.""" 51 | mock_agent = MagicMock() 52 | mock_agent._function_tools = {} 53 | return mock_agent 54 | 55 | 56 | @pytest.fixture 57 | def patched_agent_class(): 58 | """Patch the Agent class.""" 59 | with patch("airflow_ai_sdk.operators.agent.Agent") as mock_agent_class: 60 | yield mock_agent_class 61 | 62 | 63 | @pytest.fixture 64 | def patched_super_execute(): 65 | """Patch _PythonDecoratedOperator.execute.""" 66 | with patch("airflow_ai_sdk.operators.agent._PythonDecoratedOperator.execute") as mock_super_execute: 67 | mock_super_execute.return_value = "test_prompt" 68 | yield mock_super_execute 69 | 70 | 71 | def test_init(base_config, mock_agent_with_tools): 72 | """Test the initialization of AgentDecoratedOperator.""" 73 | operator = AgentDecoratedOperator( 74 | agent=mock_agent_with_tools, 75 | task_id="test_task", 76 | python_callable=lambda: "test", 77 | op_args=base_config["op_args"], 78 | op_kwargs=base_config["op_kwargs"], 79 | ) 80 | 81 | # Make sure the tools were wrapped by checking the agent's function tools' class 82 | print(operator.agent._function_tools) 83 | assert isinstance(operator.agent._function_tools["tool1"], WrappedTool) 84 | assert isinstance(operator.agent._function_tools["tool2"], WrappedTool) 85 | 86 | def test_execute_with_string_result(base_config, mock_context, mock_agent_no_tools): 87 | """Test execute method with a string result.""" 88 | # Mock the result of run_sync 89 | mock_result = MagicMock(spec=AgentRunResult) 90 | mock_result.data = "test_result" 91 | mock_agent_no_tools.run_sync.return_value = mock_result 92 | 93 | # Create the operator 94 | operator = AgentDecoratedOperator( 95 | agent=mock_agent_no_tools, 96 | task_id="test_task", 97 | python_callable=lambda: "test", 98 | op_args=base_config["op_args"], 99 | op_kwargs=base_config["op_kwargs"], 100 | ) 101 | 102 | # Call execute 103 | result = operator.execute(mock_context) 104 | 105 | # Verify the result 106 | assert result == "test_result" 107 | 108 | 109 | def test_execute_with_base_model_result(base_config, mock_context, mock_agent_no_tools): 110 | """Test execute method with a BaseModel result.""" 111 | # Create a test model 112 | class TestModel(BaseModel): 113 | field1: str 114 | field2: int 115 | 116 | test_model = TestModel(field1="test", field2=42) 117 | 118 | # Mock the result of run_sync 119 | mock_result = MagicMock() 120 | 
mock_result.data = test_model 121 | mock_agent_no_tools.run_sync.return_value = mock_result 122 | 123 | # Create the operator 124 | operator = AgentDecoratedOperator( 125 | agent=mock_agent_no_tools, 126 | task_id="test_task", 127 | python_callable=lambda: "test", 128 | op_args=base_config["op_args"], 129 | op_kwargs=base_config["op_kwargs"], 130 | ) 131 | 132 | # Call execute 133 | result = operator.execute(mock_context) 134 | 135 | # Verify the result 136 | assert result == {"field1": "test", "field2": 42} 137 | 138 | 139 | def test_execute_with_error(base_config, mock_context, mock_agent_no_tools): 140 | """Test execute method when an error occurs.""" 141 | # Configure the mock agent's run_sync to raise an exception 142 | error_message = "Test error" 143 | mock_agent_no_tools.run_sync.side_effect = ValueError(error_message) 144 | 145 | # Create the operator 146 | operator = AgentDecoratedOperator( 147 | agent=mock_agent_no_tools, 148 | task_id="test_task", 149 | python_callable=lambda: "test", 150 | op_args=base_config["op_args"], 151 | op_kwargs=base_config["op_kwargs"], 152 | ) 153 | 154 | # Call execute 155 | with pytest.raises(ValueError, match=error_message): 156 | operator.execute(mock_context) 157 | 158 | # Verify that run_sync was called 159 | mock_agent_no_tools.run_sync.assert_called_once_with("test") 160 | -------------------------------------------------------------------------------- /tests/operators/test_embed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sys 3 | import importlib 4 | from unittest.mock import patch, MagicMock, call 5 | from airflow_ai_sdk.airflow import _PythonDecoratedOperator, task_decorator_factory 6 | from airflow_ai_sdk.operators.embed import EmbedDecoratedOperator 7 | from airflow_ai_sdk.decorators.embed import embed 8 | 9 | class StubArray: 10 | def __init__(self, data): 11 | self._data = data 12 | def tolist(self): 13 | return self._data 14 | 15 | class StubModel: 16 | def __init__(self, name): 17 | self.name = name 18 | 19 | def encode(self, text, **kwargs): 20 | # Store kwargs for assertion 21 | self.last_encode_kwargs = kwargs 22 | return StubArray([0.1, 0.2, 0.3]) 23 | 24 | @patch.object(_PythonDecoratedOperator, "execute", autospec=True) 25 | @patch("sentence_transformers.SentenceTransformer", autospec=True) 26 | def test_execute_returns_vector(mock_sentence_transformer, mock_super_execute): 27 | mock_super_execute.return_value = "hello world" 28 | mock_model = StubModel("test-model") 29 | mock_sentence_transformer.return_value = mock_model 30 | 31 | op = EmbedDecoratedOperator( 32 | task_id="embed_test", 33 | python_callable=lambda: "ignored", 34 | op_args=None, 35 | op_kwargs=None, 36 | model_name="test-model", 37 | ) 38 | 39 | vec = op.execute(context=None) 40 | 41 | mock_super_execute.assert_called_once_with(op, None) 42 | mock_sentence_transformer.assert_called_once_with("test-model") 43 | assert vec == [0.1, 0.2, 0.3] 44 | 45 | @patch.object(_PythonDecoratedOperator, "execute", autospec=True) 46 | @patch("sentence_transformers.SentenceTransformer", autospec=True) 47 | def test_execute_with_encode_kwargs(mock_sentence_transformer, mock_super_execute): 48 | mock_super_execute.return_value = "hello world" 49 | mock_model = StubModel("test-model") 50 | mock_sentence_transformer.return_value = mock_model 51 | 52 | encode_kwargs = { 53 | "normalize_embeddings": True, 54 | "batch_size": 32, 55 | "show_progress_bar": False 56 | } 57 | 58 | op = EmbedDecoratedOperator( 59 | 
task_id="embed_test", 60 | python_callable=lambda: "ignored", 61 | op_args=None, 62 | op_kwargs=None, 63 | model_name="test-model", 64 | encode_kwargs=encode_kwargs 65 | ) 66 | 67 | vec = op.execute(context=None) 68 | 69 | mock_super_execute.assert_called_once_with(op, None) 70 | mock_sentence_transformer.assert_called_once_with("test-model") 71 | assert vec == [0.1, 0.2, 0.3] 72 | # Verify encode_kwargs were passed correctly 73 | assert mock_model.last_encode_kwargs == encode_kwargs 74 | 75 | @patch.object(_PythonDecoratedOperator, "execute", autospec=True) 76 | def test_execute_raises_error_on_non_str(mock_super_execute): 77 | mock_super_execute.return_value = 12345 78 | 79 | op = EmbedDecoratedOperator( 80 | task_id="embed_test", 81 | python_callable=lambda: "ignored", 82 | op_args=None, 83 | op_kwargs=None, 84 | model_name="test-model", 85 | ) 86 | 87 | with pytest.raises(TypeError) as excinfo: 88 | op.execute(context=None) 89 | 90 | msg = str(excinfo.value) 91 | assert "text" in msg.lower() 92 | assert "str" in msg.lower() 93 | 94 | @patch("airflow_ai_sdk.decorators.embed.task_decorator_factory") 95 | def test_embed_decorator(mock_task_decorator_factory): 96 | # Setup mock 97 | mock_decorator = MagicMock() 98 | mock_task_decorator_factory.return_value = mock_decorator 99 | 100 | # Test with custom model name 101 | custom_model = "custom-model" 102 | result = embed(model_name=custom_model) 103 | 104 | # Verify decorator factory was called correctly 105 | mock_task_decorator_factory.assert_called_once_with( 106 | decorated_operator_class=EmbedDecoratedOperator, 107 | model_name=custom_model 108 | ) 109 | # Verify the result is the mock decorator 110 | assert result == mock_decorator 111 | 112 | @patch("airflow_ai_sdk.decorators.embed.task_decorator_factory") 113 | def test_embed_decorator_with_default_model(mock_task_decorator_factory): 114 | # Reset mock 115 | mock_task_decorator_factory.reset_mock() 116 | 117 | # Setup mock 118 | mock_decorator = MagicMock() 119 | mock_task_decorator_factory.return_value = mock_decorator 120 | 121 | # Call with default model name 122 | result = embed() 123 | 124 | # Verify decorator factory was called with default model name 125 | mock_task_decorator_factory.assert_called_once_with( 126 | decorated_operator_class=EmbedDecoratedOperator, 127 | model_name="all-MiniLM-L12-v2" # Default value defined in embed.py 128 | ) 129 | # Verify the result is the mock decorator 130 | assert result == mock_decorator 131 | 132 | @patch("airflow_ai_sdk.decorators.embed.task_decorator_factory") 133 | def test_embed_decorator_passes_additional_kwargs(mock_task_decorator_factory): 134 | # Reset mock 135 | mock_task_decorator_factory.reset_mock() 136 | 137 | # Setup mock 138 | mock_decorator = MagicMock() 139 | mock_task_decorator_factory.return_value = mock_decorator 140 | 141 | # Call with additional kwargs 142 | additional_kwargs = { 143 | "task_id": "custom_task_id", 144 | "pool": "custom_pool", 145 | "priority_weight": 10, 146 | "encode_kwargs": {"normalize_embeddings": True} 147 | } 148 | 149 | result = embed(model_name="custom-model", **additional_kwargs) 150 | 151 | # Verify decorator factory was called with all kwargs 152 | expected_kwargs = { 153 | "decorated_operator_class": EmbedDecoratedOperator, 154 | "model_name": "custom-model", 155 | **additional_kwargs 156 | } 157 | mock_task_decorator_factory.assert_called_once_with(**expected_kwargs) 158 | 159 | # Verify the result is the mock decorator 160 | assert result == mock_decorator 161 | 162 | def 
test_import_error_when_sentence_transformers_not_installed(): 163 | # Mock that sentence_transformers is not installed 164 | with patch.dict(sys.modules, {'sentence_transformers': None}): 165 | # Make sure the import attempt fails 166 | with pytest.raises(ImportError) as excinfo: 167 | # Force module reload to trigger the import check 168 | import importlib 169 | if 'airflow_ai_sdk.operators.embed' in sys.modules: 170 | importlib.reload(sys.modules['airflow_ai_sdk.operators.embed']) 171 | else: 172 | import airflow_ai_sdk.operators.embed 173 | 174 | # Try to create an operator - this should raise ImportError 175 | EmbedDecoratedOperator( 176 | task_id="embed_test", 177 | python_callable=lambda: "ignored", 178 | op_args=None, 179 | op_kwargs=None, 180 | model_name="test-model", 181 | ) 182 | 183 | # Check for expected error message 184 | error_msg = str(excinfo.value) 185 | assert "sentence-transformers is not installed" in error_msg 186 | 187 | class ExceptionRaisingModel: 188 | def __init__(self, name): 189 | self.name = name 190 | 191 | def encode(self, text, **kwargs): 192 | raise ValueError("Test exception from model encoding") 193 | 194 | @patch.object(_PythonDecoratedOperator, "execute", autospec=True) 195 | @patch("sentence_transformers.SentenceTransformer", autospec=True) 196 | def test_model_exception_handling(mock_sentence_transformer, mock_super_execute): 197 | # Setup mocks 198 | mock_super_execute.return_value = "hello world" 199 | mock_sentence_transformer.return_value = ExceptionRaisingModel("test-model") 200 | 201 | op = EmbedDecoratedOperator( 202 | task_id="embed_test", 203 | python_callable=lambda: "ignored", 204 | op_args=None, 205 | op_kwargs=None, 206 | model_name="test-model", 207 | ) 208 | 209 | # The operator should re-raise the exception 210 | with pytest.raises(ValueError) as excinfo: 211 | op.execute(context=None) 212 | 213 | # Check that the exception came from the model 214 | assert "Test exception from model encoding" in str(excinfo.value) 215 | 216 | # Verify the proper methods were called 217 | mock_super_execute.assert_called_once_with(op, None) 218 | mock_sentence_transformer.assert_called_once_with("test-model") 219 | 220 | @patch.object(_PythonDecoratedOperator, "execute", autospec=True) 221 | @patch("sentence_transformers.SentenceTransformer", autospec=True) 222 | def test_non_string_input_types(mock_sentence_transformer, mock_super_execute): 223 | mock_model = StubModel("test-model") 224 | mock_sentence_transformer.return_value = mock_model 225 | 226 | op = EmbedDecoratedOperator( 227 | task_id="embed_test", 228 | python_callable=lambda: "ignored", 229 | op_args=None, 230 | op_kwargs=None, 231 | model_name="test-model", 232 | ) 233 | 234 | # Test input types that aren't strings and should cause TypeError 235 | non_string_inputs = [ 236 | 123, # int 237 | 1.23, # float 238 | True, # bool 239 | ["item1", "item2"], # list 240 | {"key": "value"}, # dict 241 | (1, 2, 3), # tuple 242 | None # None 243 | ] 244 | 245 | for input_val in non_string_inputs: 246 | mock_super_execute.reset_mock() 247 | mock_super_execute.return_value = input_val 248 | 249 | with pytest.raises(TypeError) as excinfo: 250 | op.execute(context=None) 251 | 252 | error_msg = str(excinfo.value) 253 | assert "text" in error_msg.lower() 254 | assert "str" in error_msg.lower() 255 | mock_super_execute.assert_called_once_with(op, None) 256 | -------------------------------------------------------------------------------- /tests/operators/test_llm.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests for the LLMDecoratedOperator class. 3 | """ 4 | 5 | from unittest.mock import MagicMock, patch 6 | 7 | import pytest 8 | from pydantic import BaseModel 9 | from pydantic_ai.models import Model 10 | 11 | from airflow_ai_sdk.operators.llm import LLMDecoratedOperator 12 | 13 | 14 | @pytest.fixture 15 | def base_config(): 16 | """Base configuration for tests.""" 17 | return { 18 | "model": "gpt-4", 19 | "system_prompt": "You are a helpful assistant.", 20 | "op_args": [], 21 | "op_kwargs": {}, 22 | } 23 | 24 | 25 | @pytest.fixture 26 | def mock_agent(): 27 | """Create a mock agent.""" 28 | return MagicMock() 29 | 30 | 31 | @pytest.fixture 32 | def patched_agent_class(mock_agent): 33 | """Patch the Agent class.""" 34 | with patch("airflow_ai_sdk.operators.llm.Agent") as mock_agent_class: 35 | mock_agent_class.return_value = mock_agent 36 | yield mock_agent_class 37 | 38 | 39 | @pytest.fixture 40 | def patched_super_init(): 41 | """Patch the AgentDecoratedOperator.__init__ method.""" 42 | with patch("airflow_ai_sdk.operators.llm.AgentDecoratedOperator.__init__", return_value=None) as mock_super_init: 43 | yield mock_super_init 44 | 45 | 46 | def test_init_with_default_result_type(base_config, patched_agent_class, patched_super_init, mock_agent): 47 | """Test initialization with default result type (str).""" 48 | # Create the operator 49 | operator = LLMDecoratedOperator( 50 | model=base_config["model"], 51 | system_prompt=base_config["system_prompt"], 52 | task_id="test_task", 53 | op_args=base_config["op_args"], 54 | op_kwargs=base_config["op_kwargs"], 55 | python_callable=lambda: "test", 56 | ) 57 | 58 | # Verify that Agent was created with the correct arguments 59 | patched_agent_class.assert_called_once_with( 60 | model=base_config["model"], 61 | system_prompt=base_config["system_prompt"], 62 | result_type=str, 63 | ) 64 | 65 | # Verify that AgentDecoratedOperator.__init__ was called with the mock agent 66 | patched_super_init.assert_called_once() 67 | args, kwargs = patched_super_init.call_args 68 | assert kwargs["agent"] == mock_agent 69 | assert "task_id" in kwargs 70 | assert "op_args" in kwargs 71 | assert "op_kwargs" in kwargs 72 | assert "python_callable" in kwargs 73 | 74 | 75 | def test_init_with_custom_result_type(base_config, patched_agent_class, patched_super_init, mock_agent): 76 | """Test initialization with custom result type.""" 77 | # Create a test model 78 | class TestModel(BaseModel): 79 | field1: str 80 | field2: int 81 | 82 | # Create the operator 83 | operator = LLMDecoratedOperator( 84 | model=base_config["model"], 85 | system_prompt=base_config["system_prompt"], 86 | result_type=TestModel, 87 | task_id="test_task", 88 | op_args=base_config["op_args"], 89 | op_kwargs=base_config["op_kwargs"], 90 | python_callable=lambda: "test", 91 | ) 92 | 93 | # Verify that Agent was created with the correct arguments 94 | patched_agent_class.assert_called_once_with( 95 | model=base_config["model"], 96 | system_prompt=base_config["system_prompt"], 97 | result_type=TestModel, 98 | ) 99 | 100 | # Verify that AgentDecoratedOperator.__init__ was called with the mock agent 101 | patched_super_init.assert_called_once() 102 | args, kwargs = patched_super_init.call_args 103 | assert kwargs["agent"] == mock_agent 104 | assert "task_id" in kwargs 105 | assert "op_args" in kwargs 106 | assert "op_kwargs" in kwargs 107 | assert "python_callable" in kwargs 108 | 109 | 110 | def 
test_init_with_model_object(base_config, patched_agent_class, patched_super_init, mock_agent): 111 | """Test initialization with a Model object instead of a string.""" 112 | # Create a mock model object 113 | mock_model = MagicMock(spec=Model) 114 | 115 | # Create the operator 116 | operator = LLMDecoratedOperator( 117 | model=mock_model, 118 | system_prompt=base_config["system_prompt"], 119 | task_id="test_task", 120 | op_args=base_config["op_args"], 121 | op_kwargs=base_config["op_kwargs"], 122 | python_callable=lambda: "test", 123 | ) 124 | 125 | # Verify that Agent was created with the correct arguments 126 | patched_agent_class.assert_called_once_with( 127 | model=mock_model, 128 | system_prompt=base_config["system_prompt"], 129 | result_type=str, 130 | ) 131 | 132 | # Verify that AgentDecoratedOperator.__init__ was called with the mock agent 133 | patched_super_init.assert_called_once() 134 | args, kwargs = patched_super_init.call_args 135 | assert kwargs["agent"] == mock_agent 136 | assert "task_id" in kwargs 137 | assert "op_args" in kwargs 138 | assert "op_kwargs" in kwargs 139 | assert "python_callable" in kwargs 140 | -------------------------------------------------------------------------------- /tests/operators/test_llm_branch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the LLMBranchDecoratedOperator class. 3 | """ 4 | 5 | from enum import Enum 6 | from unittest.mock import MagicMock, patch 7 | 8 | import pytest 9 | from airflow.utils.context import Context 10 | 11 | from airflow_ai_sdk.operators.llm_branch import LLMBranchDecoratedOperator 12 | 13 | 14 | @pytest.fixture 15 | def base_config(): 16 | """Base configuration for tests.""" 17 | return { 18 | "model": "gpt-4", 19 | "system_prompt": "You are a helpful assistant.", 20 | "op_args": [], 21 | "op_kwargs": {}, 22 | } 23 | 24 | 25 | @pytest.fixture 26 | def mock_context(): 27 | """Create a mock context.""" 28 | return MagicMock(spec=Context) 29 | 30 | 31 | @pytest.fixture 32 | def mock_agent(): 33 | """Create a mock agent.""" 34 | return MagicMock() 35 | 36 | 37 | @pytest.fixture 38 | def patched_agent_class(mock_agent): 39 | """Patch the Agent class.""" 40 | with patch("airflow_ai_sdk.operators.llm_branch.Agent") as mock_agent_class: 41 | mock_agent_class.return_value = mock_agent 42 | yield mock_agent_class 43 | 44 | 45 | @pytest.fixture 46 | def patched_super_init(): 47 | """Patch the AgentDecoratedOperator.__init__ method.""" 48 | with patch("airflow_ai_sdk.operators.llm_branch.AgentDecoratedOperator.__init__", return_value=None) as mock_super_init: 49 | yield mock_super_init 50 | 51 | 52 | def test_init(base_config, patched_agent_class, patched_super_init, mock_agent): 53 | """Test initialization of LLMBranchDecoratedOperator.""" 54 | # Create the operator 55 | operator = LLMBranchDecoratedOperator( 56 | model=base_config["model"], 57 | system_prompt=base_config["system_prompt"], 58 | task_id="test_task", 59 | op_args=base_config["op_args"], 60 | op_kwargs=base_config["op_kwargs"], 61 | python_callable=lambda: "test", 62 | ) 63 | 64 | # Verify that Agent was created with the correct arguments 65 | patched_agent_class.assert_called_once_with( 66 | model=base_config["model"], 67 | system_prompt=base_config["system_prompt"], 68 | ) 69 | 70 | # Verify that the properties were set correctly 71 | assert operator.model == base_config["model"] 72 | assert operator.system_prompt == base_config["system_prompt"] 73 | assert operator.allow_multiple_branches is False 
74 | 75 | # Verify that AgentDecoratedOperator.__init__ was called with the mock agent 76 | patched_super_init.assert_called_once() 77 | args, kwargs = patched_super_init.call_args 78 | assert kwargs["agent"] == mock_agent 79 | assert "task_id" in kwargs 80 | assert "op_args" in kwargs 81 | assert "op_kwargs" in kwargs 82 | assert "python_callable" in kwargs 83 | 84 | 85 | def test_init_with_multiple_branches(base_config, patched_agent_class, patched_super_init, mock_agent): 86 | """Test initialization with allow_multiple_branches=True.""" 87 | # Create the operator 88 | operator = LLMBranchDecoratedOperator( 89 | model=base_config["model"], 90 | system_prompt=base_config["system_prompt"], 91 | allow_multiple_branches=True, 92 | task_id="test_task", 93 | op_args=base_config["op_args"], 94 | op_kwargs=base_config["op_kwargs"], 95 | python_callable=lambda: "test", 96 | ) 97 | 98 | # Verify the allow_multiple_branches property was set 99 | assert operator.allow_multiple_branches is True 100 | 101 | 102 | def test_execute_with_enum_result(base_config, mock_context, mock_agent): 103 | """Test execute method with an Enum result.""" 104 | with patch("airflow_ai_sdk.operators.llm_branch.Agent") as mock_agent_class: 105 | with patch("airflow_ai_sdk.operators.llm_branch.AgentDecoratedOperator.execute") as mock_super_execute: 106 | # Set up mock agent 107 | mock_agent_class.return_value = mock_agent 108 | 109 | # Mock the result of super().execute 110 | task_id = "task2" 111 | mock_super_execute.return_value = task_id 112 | 113 | # Mock the do_branch method to return a list of tasks 114 | mock_do_branch_result = ["task2"] 115 | 116 | # Create the operator 117 | with patch("airflow_ai_sdk.operators.llm_branch.AgentDecoratedOperator"): 118 | with patch.object(LLMBranchDecoratedOperator, "do_branch", return_value=mock_do_branch_result): 119 | operator = LLMBranchDecoratedOperator( 120 | model=base_config["model"], 121 | system_prompt=base_config["system_prompt"], 122 | task_id="test_task", 123 | op_args=base_config["op_args"], 124 | op_kwargs=base_config["op_kwargs"], 125 | python_callable=lambda: "test", 126 | ) 127 | 128 | # Set downstream task IDs 129 | operator.downstream_task_ids = ["task1", "task2", "task3"] 130 | 131 | # Call execute 132 | result = operator.execute(mock_context) 133 | 134 | # Verify a new Agent was created with the correct enum result_type 135 | assert mock_agent_class.call_count == 2 # Once in __init__ and once in execute 136 | 137 | # Verify that super().execute was called 138 | mock_super_execute.assert_called_once_with(mock_context) 139 | 140 | # Verify the result 141 | assert result == mock_do_branch_result 142 | 143 | 144 | def test_execute_with_non_string_result(base_config, mock_context, mock_agent): 145 | """Test execute method with a non-string result that needs to be cast to a string.""" 146 | with patch("airflow_ai_sdk.operators.llm_branch.Agent") as mock_agent_class: 147 | with patch("airflow_ai_sdk.operators.llm_branch.AgentDecoratedOperator.execute") as mock_super_execute: 148 | # Set up mock agent 149 | mock_agent_class.return_value = mock_agent 150 | 151 | # Mock the result of super().execute to return a non-string value 152 | mock_super_execute.return_value = 123 153 | 154 | # Mock the do_branch method to return a list of tasks 155 | mock_do_branch_result = ["task1"] 156 | 157 | # Create the operator 158 | with patch("airflow_ai_sdk.operators.llm_branch.AgentDecoratedOperator"): 159 | with patch.object(LLMBranchDecoratedOperator, "do_branch", 
return_value=mock_do_branch_result) as mock_do_branch: 160 | operator = LLMBranchDecoratedOperator( 161 | model=base_config["model"], 162 | system_prompt=base_config["system_prompt"], 163 | task_id="test_task", 164 | op_args=base_config["op_args"], 165 | op_kwargs=base_config["op_kwargs"], 166 | python_callable=lambda: "test", 167 | ) 168 | 169 | # Set downstream task IDs 170 | operator.downstream_task_ids = ["task1", "task2", "task3"] 171 | 172 | # Call execute 173 | result = operator.execute(mock_context) 174 | 175 | # Verify that do_branch was called with the string representation 176 | mock_do_branch.assert_called_once_with(mock_context, "123") 177 | 178 | # Verify the result 179 | assert result == mock_do_branch_result 180 | -------------------------------------------------------------------------------- /tests/test_airflow_imports.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests that all imports from airflow.py resolve correctly. 3 | """ 4 | 5 | import pytest 6 | 7 | 8 | def test_imports_resolve(): 9 | """Test that all imports from airflow.py can be imported successfully.""" 10 | try: 11 | from airflow_ai_sdk.airflow import ( 12 | Context, 13 | task_decorator_factory, 14 | TaskDecorator, 15 | _PythonDecoratedOperator, 16 | BranchMixIn, 17 | ) 18 | # If we get here, imports resolved successfully 19 | assert True 20 | except ImportError as e: 21 | pytest.fail(f"Failed to import from airflow.py: {e}") 22 | --------------------------------------------------------------------------------
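The mask_pii helper in examples/dags/support_ticket_routing.py is deliberately a pass-through; its docstring lists regexes, a hosted model, or a local model as possible implementations. Below is a minimal, illustrative sketch of the regex approach only — the patterns and placeholder tokens are assumptions for the sketch, not part of the SDK, and a real deployment would need a fuller ruleset or a dedicated PII-redaction service.

    import re

    # Illustrative only: naive patterns for emails and US-style phone numbers.
    # These are assumptions for the sketch, not an exhaustive PII ruleset.
    EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
    PHONE_RE = re.compile(r"\b(?:\+?1[\s.-]?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b")

    def mask_pii(ticket_content: str) -> str:
        """Replace obvious PII with placeholder tokens before the text reaches the LLM."""
        masked = EMAIL_RE.sub("[EMAIL]", ticket_content)
        masked = PHONE_RE.sub("[PHONE]", masked)
        return masked

    # Example:
    # mask_pii("Call me at 555-123-4567 or jane@example.com")
    # -> "Call me at [PHONE] or [EMAIL]"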