├── .flake8 ├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ ├── ci.yaml │ └── publish.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyproject.toml ├── requirements.in ├── requirements.txt ├── setup.cfg ├── src └── pytest_langchain │ ├── __init__.py │ ├── __main__.py │ ├── plugin.py │ └── tests │ ├── conftest.py │ └── test_chain.py └── tests ├── __init__.py ├── test_cli.py └── test_data ├── config.yaml └── llm_chain.yaml /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | extend-ignore = E203, W503 4 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners 2 | 3 | * @ajndkr 4 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Please see the documentation for all configuration options: 2 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 3 | # 4 | # pip dependencies are not updated by Dependabot; requirements.txt is compiled from requirements.in with pip-compile 5 | 6 | version: 2 7 | updates: 8 | 9 | # Maintain dependencies for github actions 10 | - package-ecosystem: github-actions 11 | directory: / 12 | schedule: 13 | interval: daily 14 | labels: 15 | - ci/cd 16 | open-pull-requests-limit: 5 17 | rebase-strategy: disabled 18 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: # Allow running on-demand 5 | push: 6 |
branches: 7 | - main 8 | paths: 9 | - .github/workflows/ci.yaml 10 | - src/** 11 | - tests/** 12 | - '!**.md' 13 | 14 | jobs: 15 | build-and-test: 16 | if: github.event.pull_request.draft == false 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Code Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: 3.10.6 26 | cache: pip 27 | cache-dependency-path: requirements*.txt 28 | 29 | - name: Install dependencies 30 | run: | 31 | sudo apt-get update 32 | pip install -r requirements.txt 33 | pip install -e . 34 | 35 | - name: Unit Tests 36 | run: | 37 | pip install pytest pytest-cov 38 | pytest --cov=src --cov-report=term-missing:skip-covered --cov-fail-under=70 -v tests/ 39 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: # Allow running on-demand 5 | push: 6 | tags: 7 | - "*" 8 | 9 | jobs: 10 | # forked from https://github.com/rochacbruno/python-project-template/blob/main/.github/workflows/release.yml 11 | build-and-publish: 12 | runs-on: ubuntu-20.04 13 | steps: 14 | - name: Code Checkout 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3.10.6 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -r requirements.txt 26 | 27 | - name: Build and publish package 28 | env: 29 | TWINE_USERNAME: __token__ 30 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 31 | run: | 32 | python -m build . 
--outdir dist/ 33 | twine upload dist/* 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,python,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,python,visualstudiocode 3 | 4 | ### macOS ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | 14 | # Thumbnails 15 | ._* 16 | 17 | # Files that might appear in the root of a volume 18 | .DocumentRevisions-V100 19 | .fseventsd 20 | .Spotlight-V100 21 | .TemporaryItems 22 | .Trashes 23 | .VolumeIcon.icns 24 | .com.apple.timemachine.donotpresent 25 | 26 | # Directories potentially created on remote AFP share 27 | .AppleDB 28 | .AppleDesktop 29 | Network Trash Folder 30 | Temporary Items 31 | .apdisk 32 | 33 | ### macOS Patch ### 34 | # iCloud generated files 35 | *.icloud 36 | 37 | ### Python ### 38 | # Byte-compiled / optimized / DLL files 39 | __pycache__/ 40 | *.py[cod] 41 | *$py.class 42 | 43 | # C extensions 44 | *.so 45 | 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | share/python-wheels/ 61 | *.egg-info/ 62 | .installed.cfg 63 | *.egg 64 | MANIFEST 65 | 66 | # PyInstaller 67 | # Usually these files are written by a python script from a template 68 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
69 | *.manifest 70 | *.spec 71 | 72 | # Installer logs 73 | pip-log.txt 74 | pip-delete-this-directory.txt 75 | 76 | # Unit test / coverage reports 77 | htmlcov/ 78 | .tox/ 79 | .nox/ 80 | .coverage 81 | .coverage.* 82 | .cache 83 | nosetests.xml 84 | coverage.xml 85 | *.cover 86 | *.py,cover 87 | .hypothesis/ 88 | .pytest_cache/ 89 | cover/ 90 | 91 | # Translations 92 | *.mo 93 | *.pot 94 | 95 | # Django stuff: 96 | *.log 97 | local_settings.py 98 | db.sqlite3 99 | db.sqlite3-journal 100 | 101 | # Flask stuff: 102 | instance/ 103 | .webassets-cache 104 | 105 | # Scrapy stuff: 106 | .scrapy 107 | 108 | # Sphinx documentation 109 | docs/_build/ 110 | 111 | # PyBuilder 112 | .pybuilder/ 113 | target/ 114 | 115 | # Jupyter Notebook 116 | .ipynb_checkpoints 117 | 118 | # IPython 119 | profile_default/ 120 | ipython_config.py 121 | 122 | # pyenv 123 | # For a library or package, you might want to ignore these files since the code is 124 | # intended to run in multiple environments; otherwise, check them in: 125 | # .python-version 126 | 127 | # pipenv 128 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 129 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 130 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 131 | # install all needed dependencies. 132 | #Pipfile.lock 133 | 134 | # poetry 135 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 136 | # This is especially recommended for binary packages to ensure reproducibility, and is more 137 | # commonly ignored for libraries. 138 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 139 | #poetry.lock 140 | 141 | # pdm 142 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
143 | #pdm.lock 144 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 145 | # in version control. 146 | # https://pdm.fming.dev/#use-with-ide 147 | .pdm.toml 148 | 149 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 150 | __pypackages__/ 151 | 152 | # Celery stuff 153 | celerybeat-schedule 154 | celerybeat.pid 155 | 156 | # SageMath parsed files 157 | *.sage.py 158 | 159 | # Environments 160 | .env 161 | .venv 162 | env/ 163 | venv/ 164 | ENV/ 165 | env.bak/ 166 | venv.bak/ 167 | 168 | # Spyder project settings 169 | .spyderproject 170 | .spyproject 171 | 172 | # Rope project settings 173 | .ropeproject 174 | 175 | # mkdocs documentation 176 | /site 177 | 178 | # mypy 179 | .mypy_cache/ 180 | .dmypy.json 181 | dmypy.json 182 | 183 | # Pyre type checker 184 | .pyre/ 185 | 186 | # pytype static type analyzer 187 | .pytype/ 188 | 189 | # Cython debug symbols 190 | cython_debug/ 191 | 192 | # PyCharm 193 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 194 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 195 | # and can be added to the global gitignore or merged into this file. For a more nuclear 196 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
197 | #.idea/ 198 | 199 | ### VisualStudioCode ### 200 | .vscode/* 201 | !.vscode/settings.json 202 | !.vscode/tasks.json 203 | !.vscode/launch.json 204 | !.vscode/extensions.json 205 | !.vscode/*.code-snippets 206 | 207 | # Local History for Visual Studio Code 208 | .history/ 209 | 210 | # Built Visual Studio Code Extensions 211 | *.vsix 212 | 213 | ### VisualStudioCode Patch ### 214 | # Ignore all local history of files 215 | .history 216 | .ionide 217 | 218 | # End of https://www.toptal.com/developers/gitignore/api/macos,python,visualstudiocode 219 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-builtin-literals 11 | - id: check-docstring-first 12 | - id: check-json 13 | - id: check-merge-conflict 14 | - id: check-toml 15 | - id: check-xml 16 | - id: check-yaml 17 | args: [--allow-multiple-documents] 18 | - id: debug-statements 19 | - id: detect-private-key 20 | - id: end-of-file-fixer 21 | - id: fix-byte-order-marker 22 | - id: fix-encoding-pragma 23 | args: [--remove] 24 | - id: mixed-line-ending 25 | - id: trailing-whitespace 26 | args: [--markdown-linebreak-ext=md] 27 | exclude: setup.cfg 28 | 29 | - repo: https://github.com/psf/black 30 | rev: 23.1.0 31 | hooks: 32 | - id: black 33 | 34 | - repo: https://github.com/PyCQA/flake8 35 | rev: 6.0.0 36 | hooks: 37 | - id: flake8 38 | 39 | - repo: https://github.com/PyCQA/isort 40 | rev: 5.12.0 41 | hooks: 42 | - id: isort 43 | args: [--profile, black] 44 | 45 | - repo: https://github.com/pre-commit/mirrors-mypy 46 | rev: v1.0.1 47 | hooks: 48 | - id: mypy 49 | additional_dependencies: [types-setuptools, types-PyYAML] 50 | exclude: (/tests?/) 51 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ajinkya Indulkar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | include LICENSE 3 | include README.md 4 | 5 | graft assets 6 | graft src 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜️🔗✅ pytest-langchain 2 | 3 | Pytest-style test runner for langchain projects. 4 | 5 |
6 | 7 | [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/ajndkr/pytest-langchain/blob/main/LICENSE) 8 | [![CI](https://github.com/ajndkr/pytest-langchain/actions/workflows/ci.yaml/badge.svg)](https://github.com/ajndkr/pytest-langchain/actions/workflows/ci.yaml) 9 | [![Publish](https://github.com/ajndkr/pytest-langchain/actions/workflows/publish.yaml/badge.svg)](https://github.com/ajndkr/pytest-langchain/actions/workflows/publish.yaml) 10 | 11 |
12 | 13 | ## Installation 14 | 15 | ### Install from PyPI: 16 | 17 | ``` 18 | pip install pytest-langchain 19 | ``` 20 | 21 | ### Install from source: 22 | 23 | ``` 24 | git clone https://github.com/ajndkr/pytest-langchain 25 | cd pytest-langchain 26 | pip install . 27 | ``` 28 | 29 | ## Usage 30 | 31 | - Serialise your LLM chain into a YAML file. 32 | Refer to [docs](https://langchain.readthedocs.io/en/latest/modules/chains/generic/serialization.html) 33 | for more details. 34 | 35 | - Create a new configuration YAML file to run `pytest-langchain` with the following structure: 36 | 37 | ```yaml 38 | chain_file: <path/to/chain.yaml> 39 | test_cases: 40 | - [<input>, <expected output>] 41 | - [<input>, <expected output>] 42 | - ... 43 | ``` 44 | 45 | - Run `pytest-langchain`: 46 | 47 | ``` 48 | pytest-langchain -c <path/to/config.yaml> --openai-api-key <your-openai-api-key> 49 | ``` 50 | 51 | For more options, run `pytest-langchain --help`. 52 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | filterwarnings = ["ignore:.*"] 7 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # base 2 | -e file:. 3 | 4 | # ci/cd 5 | pre-commit 6 | black 7 | flake8 8 | isort 9 | mypy 10 | pytest 11 | pytest-cov 12 | bumpversion 13 | build 14 | setuptools 15 | wheel 16 | twine 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | -e file:.
8 | # via -r requirements.in 9 | aiodns==3.0.0 10 | # via aleph-alpha-client 11 | aiohttp==3.8.4 12 | # via 13 | # aiohttp-retry 14 | # aleph-alpha-client 15 | # langchain 16 | aiohttp-retry==2.8.3 17 | # via aleph-alpha-client 18 | aiosignal==1.3.1 19 | # via aiohttp 20 | aleph-alpha-client==2.16.0 21 | # via langchain 22 | async-timeout==4.0.2 23 | # via aiohttp 24 | attrs==22.2.0 25 | # via 26 | # aiohttp 27 | # pytest 28 | black==23.1.0 29 | # via -r requirements.in 30 | bleach==6.0.0 31 | # via readme-renderer 32 | build==0.10.0 33 | # via -r requirements.in 34 | bump2version==1.0.1 35 | # via bumpversion 36 | bumpversion==0.6.0 37 | # via -r requirements.in 38 | certifi==2022.12.7 39 | # via requests 40 | cffi==1.15.1 41 | # via pycares 42 | cfgv==3.3.1 43 | # via pre-commit 44 | charset-normalizer==3.0.1 45 | # via 46 | # aiohttp 47 | # requests 48 | click==8.1.3 49 | # via 50 | # black 51 | # pytest-langchain 52 | coverage[toml]==7.2.0 53 | # via pytest-cov 54 | dataclasses-json==0.5.7 55 | # via langchain 56 | distlib==0.3.6 57 | # via virtualenv 58 | docutils==0.19 59 | # via readme-renderer 60 | exceptiongroup==1.1.0 61 | # via pytest 62 | filelock==3.9.0 63 | # via virtualenv 64 | flake8==6.0.0 65 | # via -r requirements.in 66 | frozenlist==1.3.3 67 | # via 68 | # aiohttp 69 | # aiosignal 70 | greenlet==2.0.2 71 | # via sqlalchemy 72 | identify==2.5.18 73 | # via pre-commit 74 | idna==3.4 75 | # via 76 | # requests 77 | # yarl 78 | importlib-metadata==6.0.0 79 | # via 80 | # keyring 81 | # twine 82 | iniconfig==2.0.0 83 | # via pytest 84 | isort==5.12.0 85 | # via -r requirements.in 86 | jaraco-classes==3.2.3 87 | # via keyring 88 | keyring==23.13.1 89 | # via twine 90 | langchain==0.0.94 91 | # via pytest-langchain 92 | markdown-it-py==2.2.0 93 | # via rich 94 | marshmallow==3.19.0 95 | # via 96 | # dataclasses-json 97 | # marshmallow-enum 98 | marshmallow-enum==1.5.1 99 | # via dataclasses-json 100 | mccabe==0.7.0 101 | # via flake8 102 | mdurl==0.1.2 
103 | # via markdown-it-py 104 | more-itertools==9.0.0 105 | # via jaraco-classes 106 | multidict==6.0.4 107 | # via 108 | # aiohttp 109 | # yarl 110 | mypy==1.0.1 111 | # via -r requirements.in 112 | mypy-extensions==1.0.0 113 | # via 114 | # black 115 | # mypy 116 | # typing-inspect 117 | nodeenv==1.7.0 118 | # via pre-commit 119 | numpy==1.24.2 120 | # via langchain 121 | packaging==23.0 122 | # via 123 | # black 124 | # build 125 | # marshmallow 126 | # pytest 127 | pathspec==0.11.0 128 | # via black 129 | pkginfo==1.9.6 130 | # via twine 131 | platformdirs==3.0.0 132 | # via 133 | # black 134 | # virtualenv 135 | pluggy==1.0.0 136 | # via pytest 137 | pre-commit==3.1.0 138 | # via -r requirements.in 139 | pycares==4.3.0 140 | # via aiodns 141 | pycodestyle==2.10.0 142 | # via flake8 143 | pycparser==2.21 144 | # via cffi 145 | pydantic==1.10.5 146 | # via langchain 147 | pyflakes==3.0.1 148 | # via flake8 149 | pygments==2.14.0 150 | # via 151 | # readme-renderer 152 | # rich 153 | pyproject-hooks==1.0.0 154 | # via build 155 | pytest==7.2.1 156 | # via 157 | # -r requirements.in 158 | # pytest-cov 159 | # pytest-langchain 160 | pytest-cov==4.0.0 161 | # via -r requirements.in 162 | pyyaml==6.0 163 | # via 164 | # langchain 165 | # pre-commit 166 | # pytest-langchain 167 | readme-renderer==37.3 168 | # via twine 169 | requests==2.28.2 170 | # via 171 | # aleph-alpha-client 172 | # langchain 173 | # requests-toolbelt 174 | # twine 175 | requests-toolbelt==0.10.1 176 | # via twine 177 | rfc3986==2.0.0 178 | # via twine 179 | rich==13.3.1 180 | # via twine 181 | six==1.16.0 182 | # via bleach 183 | sqlalchemy==1.4.46 184 | # via langchain 185 | tenacity==8.2.1 186 | # via langchain 187 | tokenizers==0.13.2 188 | # via aleph-alpha-client 189 | tomli==2.0.1 190 | # via 191 | # black 192 | # build 193 | # coverage 194 | # mypy 195 | # pyproject-hooks 196 | # pytest 197 | twine==4.0.2 198 | # via -r requirements.in 199 | typing-extensions==4.5.0 200 | # via 201 | # 
mypy 202 | # pydantic 203 | # typing-inspect 204 | typing-inspect==0.8.0 205 | # via dataclasses-json 206 | urllib3==1.26.14 207 | # via 208 | # aleph-alpha-client 209 | # requests 210 | # twine 211 | virtualenv==20.19.0 212 | # via pre-commit 213 | webencodings==0.5.1 214 | # via bleach 215 | wheel==0.38.4 216 | # via -r requirements.in 217 | yarl==1.8.2 218 | # via aiohttp 219 | zipp==3.15.0 220 | # via importlib-metadata 221 | 222 | # The following packages are considered to be unsafe in a requirements file: 223 | # setuptools 224 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.0 3 | commit = True 4 | tag = False 5 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)? 6 | serialize = 7 | {major}.{minor}.{patch} 8 | 9 | [bumpversion:file:src/pytest_langchain/__init__.py] 10 | search = __version__ = "{current_version}" 11 | replace = __version__ = "{new_version}" 12 | 13 | [metadata] 14 | name = pytest-langchain 15 | version = attr: pytest_langchain.__version__ 16 | author = Ajinkya Indulkar 17 | author_email = 26824103+ajndkr@users.noreply.github.com 18 | description = Pytest-style test runner for langchain agents 19 | long_description = file: README.md 20 | long_description_content_type = text/markdown 21 | url = https://github.com/ajndkr/pytest-langchain 22 | license = MIT License 23 | classifiers = 24 | Programming Language :: Python :: 3 25 | License :: OSI Approved :: MIT License 26 | Operating System :: OS Independent 27 | 28 | [options] 29 | package_dir = 30 | = src 31 | packages = find_namespace: 32 | python_requires = >=3.7,<3.11 33 | install_requires = 34 | click 35 | pytest 36 | PyYAML 37 | langchain 38 | 39 | [options.packages.find] 40 | where = src 41 | 42 | [options.entry_points] 43 | pytest11 = 44 | langchain = pytest_langchain.plugin 45 | console_scripts = 46 | pytest-langchain =
pytest_langchain:main 47 | 48 | [mypy] 49 | ignore_missing_imports = True 50 | -------------------------------------------------------------------------------- /src/pytest_langchain/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | pytest-langchain. 3 | 4 | Description: 5 | Pytest-style test runner for langchain agents. 6 | """ 7 | 8 | from .__main__ import main 9 | 10 | __version__ = "0.1.0" 11 | __all__ = ["main"] 12 | -------------------------------------------------------------------------------- /src/pytest_langchain/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import click 5 | import pytest 6 | 7 | 8 | @click.command() 9 | @click.option( 10 | "-c", 11 | "--config", 12 | "config_file", 13 | type=click.Path(exists=True), 14 | help="Path to config file.", 15 | ) 16 | @click.option("--openai-api-key", "openai_api_key", type=str, help="OpenAI API key.") 17 | def main(config_file: str, openai_api_key: str): 18 | """pytest-langchain CLI""" 19 | config_file = str(Path(config_file).absolute()) 20 | config_dir = str(Path(config_file).parent) 21 | 22 | os.chdir(Path(__file__).parent) 23 | pytest.main( 24 | [ 25 | "--rootdir", 26 | str(Path(__file__).parent), 27 | "--langchain-config-dir", 28 | config_dir, 29 | "--langchain-config-file", 30 | config_file, 31 | "--openai-api-key", 32 | openai_api_key, 33 | "tests/", 34 | "-v", 35 | ], 36 | plugins=["langchain"], 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /src/pytest_langchain/plugin.py: -------------------------------------------------------------------------------- 1 | """Custom pytest plugin for pytest-langchain.""" 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption( 6 | "--langchain-config-dir", 7 | action="store", 8 | help="path to config directory", 
9 | ) 10 | parser.addoption( 11 | "--langchain-config-file", 12 | action="store", 13 | help="path to config YAML file", 14 | ) 15 | parser.addoption( 16 | "--openai-api-key", 17 | action="store", 18 | help="OpenAI API key", 19 | ) 20 | -------------------------------------------------------------------------------- /src/pytest_langchain/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Dict 4 | 5 | import pytest 6 | import yaml 7 | from langchain.chains import load_chain 8 | 9 | ALLOWED_EXTENSIONS = [".yaml"] 10 | PYTEST_CLI_ERROR = 4 11 | PYTEST_INTERNAL_ERROR = 3 12 | 13 | 14 | def pytest_generate_tests(metafunc): 15 | if "test_case" in metafunc.fixturenames: 16 | config = load_config( 17 | metafunc.config.getoption("--langchain-config-file"), 18 | metafunc.config.getoption("--langchain-config-dir"), 19 | ) 20 | metafunc.parametrize("test_case", config["test_cases"]) 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def config_file(request): 25 | return request.config.getoption("--langchain-config-file") 26 | 27 | 28 | @pytest.fixture(scope="session") 29 | def config_dir(request): 30 | return request.config.getoption("--langchain-config-dir") 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def config(config_file, config_dir): 35 | return load_config(config_file, config_dir) 36 | 37 | 38 | @pytest.fixture(scope="session") 39 | def chain_file(config: Dict[str, Any]): 40 | return config["chain_file"] 41 | 42 | 43 | @pytest.fixture(scope="session") 44 | def llm_chain(request, chain_file, config_dir): 45 | """Loads and tests LLMChain""" 46 | if not Path(chain_file).exists(): 47 | chain_file = str(Path(config_dir) / Path(chain_file)) 48 | 49 | os.environ["OPENAI_API_KEY"] = request.config.getoption("--openai-api-key") 50 | 51 | print("OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]) 52 | 53 | return load_chain(chain_file) 54 | 55 | 56 | def 
load_config(file: str, dir: str) -> Dict[str, Any]: 57 | """Loads config dictionary from path""" 58 | if file is None: 59 | pytest.exit("Config file not provided", PYTEST_CLI_ERROR) 60 | 61 | if not Path(file).exists(): 62 | file = str(Path(dir) / Path(file)) 63 | 64 | if Path(file).suffix not in ALLOWED_EXTENSIONS: 65 | pytest.exit("Invalid config file extension", PYTEST_CLI_ERROR) 66 | 67 | config = yaml.safe_load(Path(file).read_text()) 68 | if "test_cases" not in config: 69 | pytest.exit( 70 | "Config file does not contain a 'test_cases' key", PYTEST_INTERNAL_ERROR 71 | ) 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /src/pytest_langchain/tests/test_chain.py: -------------------------------------------------------------------------------- 1 | def test_chain(llm_chain, test_case): 2 | """Loads and tests LLMChain""" 3 | print(test_case) 4 | chain_input, chain_output = test_case 5 | assert llm_chain.run(chain_input) == chain_output 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajndkr/pytest-langchain/426dfc6846215d93606e4d4129ba76300f818cc1/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import pytest 4 | from click.testing import CliRunner 5 | 6 | from pytest_langchain import main 7 | 8 | 9 | @pytest.fixture(scope="function") 10 | def runner(request): 11 | """Fixture for CLI runner.""" 12 | return CliRunner() 13 | 14 | 15 | def test_cli_help(runner): 16 | """Tests the help command of Shoeblender CLI.""" 17 | # invoke help command 18 | result = runner.invoke(main, ["--help"]) 19 | 20 | assert not result.exception 21 | assert "pytest-langchain 
CLI" in result.output 22 | assert result.exit_code == 0 23 | 24 | 25 | def test_cli_run(runner, tmp_path, monkeypatch): 26 | """Tests pkgviz CLI run.""" 27 | 28 | llm_chain = mock.Mock() 29 | llm_chain.run = mock.Mock(return_value="test") 30 | 31 | monkeypatch.setattr("langchain.chains.load_chain", llm_chain) 32 | 33 | runner.invoke( 34 | main, ["-c", "tests/test_data/config.yaml", "--openai-api-key", "test"] 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test_data/config.yaml: -------------------------------------------------------------------------------- 1 | chain_file: llm_chain.yaml 2 | test_cases: 3 | - - "colorful socks" 4 | - "\n\nSocktastic!" 5 | - - "toys" 6 | - "\n\nTeddy's Toy Factory" 7 | -------------------------------------------------------------------------------- /tests/test_data/llm_chain.yaml: -------------------------------------------------------------------------------- 1 | _type: llm_chain 2 | llm: 3 | _type: openai 4 | best_of: 1 5 | frequency_penalty: 0 6 | logit_bias: {} 7 | max_tokens: 256 8 | model_name: text-davinci-003 9 | n: 1 10 | presence_penalty: 0 11 | request_timeout: null 12 | temperature: 0 13 | top_p: 1 14 | memory: null 15 | output_key: text 16 | prompt: 17 | _type: prompt 18 | input_variables: 19 | - product 20 | output_parser: null 21 | template: What is a good name for a company that makes {product}? 22 | template_format: f-string 23 | validate_template: true 24 | verbose: false 25 | --------------------------------------------------------------------------------