├── .codespellignore ├── .env.example ├── .github └── workflows │ ├── integration-tests.yml │ └── unit-tests.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── langgraph.json ├── ntbk └── testing.ipynb ├── pyproject.toml ├── src └── enrichment_agent │ ├── __init__.py │ ├── configuration.py │ ├── graph.py │ ├── prompts.py │ ├── state.py │ ├── tools.py │ └── utils.py ├── static ├── config.png ├── overview.png ├── studio.png └── studio_example.png └── tests ├── casettes └── 48fccba2-fb6a-440c-96da-4439d876cf5e.yaml ├── integration_tests ├── __init__.py └── test_graph.py └── unit_tests ├── __init__.py └── test_configuration.py /.codespellignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/data-enrichment/028fd9a96a1179397df5748e06870c2269308f20/.codespellignore -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | TAVILY_API_KEY=... 2 | 3 | # To separate your traces from other applications 4 | LANGSMITH_PROJECT=data-enrichment 5 | 6 | # The following depend on your selected configuration 7 | 8 | ## LLM choice: 9 | ANTHROPIC_API_KEY=.... 10 | FIREWORKS_API_KEY=... 11 | OPENAI_API_KEY=... 12 | -------------------------------------------------------------------------------- /.github/workflows/integration-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run integration tests for the current project once per day 2 | 3 | name: Integration Tests 4 | 5 | on: 6 | schedule: 7 | - cron: "37 14 * * *" # Run at 14:37 UTC every day (7:37 AM PDT / 6:37 AM PST) 8 | workflow_dispatch: # Allows triggering the workflow manually in GitHub UI 9 | 10 | # If another scheduled run starts while this workflow is still running, 11 | # cancel the earlier run in favor of the next run.
12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | integration-tests: 18 | name: Integration Tests 19 | strategy: 20 | matrix: 21 | os: [ubuntu-latest] 22 | python-version: ["3.11", "3.12"] 23 | runs-on: ${{ matrix.os }} 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | curl -LsSf https://astral.sh/uv/install.sh | sh 33 | uv venv 34 | uv pip install -r pyproject.toml 35 | uv pip install -U pytest-asyncio vcrpy 36 | - name: Run integration tests 37 | env: 38 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 39 | TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} 40 | LANGSMITH_API_KEY: ${{ secrets.LANGSMITH_API_KEY }} 41 | LANGSMITH_TRACING: true 42 | run: | 43 | uv run pytest tests/integration_tests 44 | -------------------------------------------------------------------------------- /.github/workflows/unit-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run unit tests for the current project 2 | 3 | name: CI 4 | 5 | on: 6 | push: 7 | branches: ["main"] 8 | pull_request: 9 | workflow_dispatch: # Allows triggering the workflow manually in GitHub UI 10 | 11 | # If another push to the same PR or branch happens while this workflow is still running, 12 | # cancel the earlier run in favor of the next run. 
13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | unit-tests: 19 | name: Unit Tests 20 | strategy: 21 | matrix: 22 | os: [ubuntu-latest] 23 | python-version: ["3.11", "3.12"] 24 | runs-on: ${{ matrix.os }} 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | curl -LsSf https://astral.sh/uv/install.sh | sh 34 | uv venv 35 | uv pip install -r pyproject.toml 36 | - name: Lint with ruff 37 | run: | 38 | uv pip install ruff 39 | uv run ruff check . 40 | - name: Lint with mypy 41 | run: | 42 | uv pip install mypy 43 | uv run mypy --strict src/ 44 | - name: Check README spelling 45 | uses: codespell-project/actions-codespell@v2 46 | with: 47 | ignore_words_file: .codespellignore 48 | path: README.md 49 | - name: Check code spelling 50 | uses: codespell-project/actions-codespell@v2 51 | with: 52 | ignore_words_file: .codespellignore 53 | path: src/ 54 | - name: Run tests with pytest 55 | run: | 56 | uv pip install pytest 57 | uv run pytest tests/unit_tests 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | .DS_Store 164 | uv.lock 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | test tests: 10 | python -m pytest $(TEST_FILE) 11 | 12 | test_watch: 13 | python -m ptw --snapshot-update --now .
. -- -vv tests/unit_tests 14 | 15 | test_profile: 16 | python -m pytest -vv tests/unit_tests/ --profile-svg 17 | 18 | extended_tests: 19 | python -m pytest --only-extended $(TEST_FILE) 20 | 21 | 22 | ###################### 23 | # LINTING AND FORMATTING 24 | ###################### 25 | 26 | # Define a variable for Python and notebook files. 27 | PYTHON_FILES=src/ 28 | MYPY_CACHE=.mypy_cache 29 | lint format: PYTHON_FILES=. 30 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 31 | lint_package: PYTHON_FILES=src 32 | lint_tests: PYTHON_FILES=tests 33 | lint_tests: MYPY_CACHE=.mypy_cache_test 34 | 35 | lint lint_diff lint_package lint_tests: 36 | python -m ruff check . 37 | [ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff 38 | [ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I $(PYTHON_FILES) 39 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) 40 | [ "$(PYTHON_FILES)" = "" ] || python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 41 | 42 | format format_diff: 43 | ruff format $(PYTHON_FILES) 44 | ruff check --select I --fix $(PYTHON_FILES) 45 | 46 | spell_check: 47 | codespell --toml pyproject.toml 48 | 49 | spell_fix: 50 | codespell --toml pyproject.toml -w 51 | 52 | ###################### 53 | # HELP 54 | ###################### 55 | 56 | help: 57 | @echo '----' 58 | @echo 'format - run code formatters' 59 | @echo 'lint - run linters' 60 | @echo 'test - run unit tests' 61 | @echo 'tests - run unit tests' 62 | @echo 'test TEST_FILE= - run all tests in file' 63 | @echo 'test_watch - run unit tests in watch mode' 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LangGraph Data Enrichment Template 2 | 3 |
[![CI](https://github.com/langchain-ai/data-enrichment/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/langchain-ai/data-enrichment/actions/workflows/unit-tests.yml) 4 | [![Integration Tests](https://github.com/langchain-ai/data-enrichment/actions/workflows/integration-tests.yml/badge.svg)](https://github.com/langchain-ai/data-enrichment/actions/workflows/integration-tests.yml) 5 | [![Open in - LangGraph Studio](https://img.shields.io/badge/Open_in-LangGraph_Studio-00324d.svg?logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSI4NS4zMzMiIGhlaWdodD0iODUuMzMzIiB2ZXJzaW9uPSIxLjAiIHZpZXdCb3g9IjAgMCA2NCA2NCI+PHBhdGggZD0iTTEzIDcuOGMtNi4zIDMuMS03LjEgNi4zLTYuOCAyNS43LjQgMjQuNi4zIDI0LjUgMjUuOSAyNC41QzU3LjUgNTggNTggNTcuNSA1OCAzMi4zIDU4IDcuMyA1Ni43IDYgMzIgNmMtMTIuOCAwLTE2LjEuMy0xOSAxLjhtMzcuNiAxNi42YzIuOCAyLjggMy40IDQuMiAzLjQgNy42cy0uNiA0LjgtMy40IDcuNkw0Ny4yIDQzSDE2LjhsLTMuNC0zLjRjLTQuOC00LjgtNC44LTEwLjQgMC0xNS4ybDMuNC0zLjRoMzAuNHoiLz48cGF0aCBkPSJNMTguOSAyNS42Yy0xLjEgMS4zLTEgMS43LjQgMi41LjkuNiAxLjcgMS44IDEuNyAyLjcgMCAxIC43IDIuOCAxLjYgNC4xIDEuNCAxLjkgMS40IDIuNS4zIDMuMi0xIC42LS42LjkgMS40LjkgMS41IDAgMi43LS41IDIuNy0xIDAtLjYgMS4xLS44IDIuNi0uNGwyLjYuNy0xLjgtMi45Yy01LjktOS4zLTkuNC0xMi4zLTExLjUtOS44TTM5IDI2YzAgMS4xLS45IDIuNS0yIDMuMi0yLjQgMS41LTIuNiAzLjQtLjUgNC4yLjguMyAyIDEuNyAyLjUgMy4xLjYgMS41IDEuNCAyLjMgMiAyIDEuNS0uOSAxLjItMy41LS40LTMuNS0yLjEgMC0yLjgtMi44LS44LTMuMyAxLjYtLjQgMS42LS41IDAtLjYtMS4xLS4xLTEuNS0uNi0xLjItMS42LjctMS43IDMuMy0yLjEgMy41LS41LjEuNS4yIDEuNi4zIDIuMiAwIC43LjkgMS40IDEuOSAxLjYgMi4xLjQgMi4zLTIuMy4yLTMuMi0uOC0uMy0yLTEuNy0yLjUtMy4xLTEuMS0zLTMtMy4zLTMtLjUiLz48L3N2Zz4=)](https://langgraph-studio.vercel.app/templates/open?githubUrl=https://github.com/langchain-ai/data-enrichment) 6 | 7 | Producing structured results (e.g., to populate a database or spreadsheet) from open-ended research (e.g., web research) is a common use case that LLM-powered agents are well-suited to handle. 
Here, we provide a general template for this kind of "data enrichment" agent using [LangGraph](https://github.com/langchain-ai/langgraph) in [LangGraph Studio](https://github.com/langchain-ai/langgraph-studio). It contains an example graph exported from `src/enrichment_agent/graph.py` that implements a research assistant capable of automatically gathering information on various topics from the web and structuring the results into a user-defined JSON format. 8 | 9 | ![Overview of agent](./static/overview.png) 10 | 11 | ## What it does 12 | 13 | The enrichment agent defined in `src/enrichment_agent/graph.py` performs the following steps: 14 | 15 | 1. Takes a research **topic** and requested **extraction_schema** as input. 16 | 2. Searches the web for relevant information 17 | 3. Reads and extracts key details from websites 18 | 4. Organizes the findings into the requested structured format 19 | 5. Validates the gathered information for completeness and accuracy 20 | 21 | ![Graph view in LangGraph studio UI](./static/studio.png) 22 | 23 | ## Getting Started 24 | 25 | Assuming you have already [installed LangGraph Studio](https://github.com/langchain-ai/langgraph-studio?tab=readme-ov-file#download), to set up: 26 | 27 | 1. Create a `.env` file. 28 | 29 | ```bash 30 | cp .env.example .env 31 | ``` 32 | 33 | 2. Define required API keys in your `.env` file. 34 | 35 | The primary [search tool](./src/enrichment_agent/tools.py) [^1] used is [Tavily](https://tavily.com/). Create an API key [here](https://app.tavily.com/sign-in). 36 | 37 | 40 | 41 | ### Setup Model 42 | 43 | The default value for `model` is shown below: 44 | 45 | ```yaml 46 | model: anthropic/claude-3-5-sonnet-20240620 47 | ``` 48 | 49 | Follow the instructions below to get set up, or pick one of the additional options. 50 | 51 | #### Anthropic 52 | 53 | To use Anthropic's chat models: 54 | 55 | 1. Sign up for an [Anthropic API key](https://console.anthropic.com/) if you haven't already. 56 | 2.
Once you have your API key, add it to your `.env` file: 57 | 58 | ``` 59 | ANTHROPIC_API_KEY=your-api-key 60 | ``` 61 | #### OpenAI 62 | 63 | To use OpenAI's chat models: 64 | 65 | 1. Sign up for an [OpenAI API key](https://platform.openai.com/signup). 66 | 2. Once you have your API key, add it to your `.env` file: 67 | ``` 68 | OPENAI_API_KEY=your-api-key 69 | ``` 70 | 71 | 72 | 73 | 74 | 75 | 78 | 79 | 3. Consider a research topic and desired extraction schema. 80 | 81 | As an example, here is a research topic we can consider. 82 | 83 | ``` 84 | "Top 5 chip providers for LLM Training" 85 | ``` 86 | 87 | And here is a desired extraction schema (pasted in as "`extraction_schema`"): 88 | 89 | ```json 90 | { 91 | "type": "object", 92 | "properties": { 93 | "companies": { 94 | "type": "array", 95 | "items": { 96 | "type": "object", 97 | "properties": { 98 | "name": { 99 | "type": "string", 100 | "description": "Company name" 101 | }, 102 | "technologies": { 103 | "type": "string", 104 | "description": "Brief summary of key technologies used by the company" 105 | }, 106 | "market_share": { 107 | "type": "string", 108 | "description": "Overview of market share for this company" 109 | }, 110 | "future_outlook": { 111 | "type": "string", 112 | "description": "Brief summary of future prospects and developments in the field for this company" 113 | }, 114 | "key_powers": { 115 | "type": "string", 116 | "description": "Which of the 7 Powers (Scale Economies, Network Economies, Counter Positioning, Switching Costs, Branding, Cornered Resource, Process Power) best describe this company's competitive advantage" 117 | } 118 | }, 119 | "required": ["name", "technologies", "market_share", "future_outlook"] 120 | }, 121 | "description": "List of companies" 122 | } 123 | }, 124 | "required": ["companies"] 125 | } 126 | ``` 127 | 128 | 4. Open the folder in LangGraph Studio and input `topic` and `extraction_schema`.
129 | 130 | ![Results In Studio](./static/studio_example.png) 131 | 132 | ## How to customize 133 | 134 | 1. **Customize research targets**: Provide a custom JSON `extraction_schema` when calling the graph to gather different types of information. 135 | 2. **Select a different model**: We default to Anthropic's Claude 3.5 Sonnet (`anthropic/claude-3-5-sonnet-20240620`). You can select a compatible chat model using `provider/model-name` via configuration. Example: `openai/gpt-4o-mini`. 136 | 3. **Customize the prompt**: We provide a default prompt in [prompts.py](./src/enrichment_agent/prompts.py). You can easily update this via configuration. 137 | 138 | For quick prototyping, these configurations can be set in the Studio UI. 139 | 140 | ![Config In Studio](./static/config.png) 141 | 142 | You can also quickly extend this template by: 143 | 144 | - Adding new tools and API connections in [tools.py](./src/enrichment_agent/tools.py). These can be any Python functions. 145 | - Adding additional steps in [graph.py](./src/enrichment_agent/graph.py). 146 | 147 | ## Development 148 | 149 | While iterating on your graph, you can edit past state and rerun your app from past states to debug specific nodes. Local changes will be automatically applied via hot reload. Try adding an interrupt before the agent calls tools, updating the default prompt in `src/enrichment_agent/prompts.py` to take on a persona, or adding additional nodes and edges! 150 | 151 | Follow-up requests will be appended to the same thread. You can create an entirely new thread, clearing previous history, using the `+` button in the top right. 152 | 153 | You can find the latest (under construction) docs on [LangGraph](https://github.com/langchain-ai/langgraph) here, including examples and other references. Using those guides can help you pick the right patterns to adapt here for your use case. 154 | 155 | LangGraph Studio also integrates with [LangSmith](https://smith.langchain.com/) for more in-depth tracing and collaboration with teammates.
156 | 157 | [^1]: https://python.langchain.com/docs/concepts/#tools 158 | 159 | ## LangGraph API 160 | 161 | We can also interact with the graph using the LangGraph API. 162 | 163 | See `ntbk/testing.ipynb` for an example of how to do this. 164 | 165 | LangGraph Cloud (see [here](https://langchain-ai.github.io/langgraph/cloud/#overview)) makes it possible to deploy the agent. 166 | 167 | -------------------------------------------------------------------------------- /langgraph.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": ["."], 3 | "graphs": { 4 | "agent": "./src/enrichment_agent/graph.py:graph" 5 | }, 6 | "env": ".env" 7 | } 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "enrichment-agent" 3 | version = "0.0.1" 4 | description = "An agent that populates and enriches custom schemas" 5 | authors = [ 6 | { name = "William Fu-Hinthorn", email = "13333726+hinthornw@users.noreply.github.com" }, 7 | ] 8 | readme = "README.md" 9 | license = { text = "MIT" } 10 | requires-python = ">=3.10" # @dataclass(kw_only=True) in configuration.py requires Python 3.10+ 11 | dependencies = [ 12 | "langgraph>=0.2.19", 13 | "langchain-openai>=0.1.22", 14 | "langchain-anthropic>=0.1.23", 15 | "langchain>=0.2.14", 16 | "langchain-fireworks>=0.1.7", 17 | "python-dotenv>=1.0.1", 18 | "langchain-community>=0.2.13", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | dev = ["mypy>=1.11.1", "ruff>=0.6.1", "pytest-asyncio"] 23 | 24 | [build-system] 25 | requires = ["setuptools>=73.0.0", "wheel"] 26 | build-backend = "setuptools.build_meta" 27 | 28 | [tool.setuptools] 29 | packages = ["enrichment_agent"] 30 | [tool.setuptools.package-dir] 31 | "enrichment_agent" = "src/enrichment_agent" 32 | "langgraph.templates.enrichment_agent" = "src/enrichment_agent" 33 | 34 | 35 | [tool.setuptools.package-data] 36 | "*" = ["py.typed"] 37 | 38 |
[tool.ruff] 39 | lint.select = [ 40 | "E", # pycodestyle 41 | "F", # pyflakes 42 | "I", # isort 43 | "D", # pydocstyle 44 | "D401", # First line should be in imperative mood 45 | "T201", 46 | "UP", 47 | ] 48 | include = ["*.py", "*.pyi", "*.ipynb"] 49 | lint.ignore = ["UP006", "UP007", "UP035", "D417", "E501"] 50 | [tool.ruff.lint.per-file-ignores] 51 | "tests/*" = ["D", "UP"] 52 | "ntbk/*" = ["D", "UP", "T201"] 53 | [tool.ruff.lint.pydocstyle] 54 | convention = "google" 55 | -------------------------------------------------------------------------------- /src/enrichment_agent/__init__.py: -------------------------------------------------------------------------------- 1 | """Enrichment for a pre-defined schema.""" 2 | 3 | from enrichment_agent.graph import graph 4 | 5 | __all__ = ["graph"] 6 | -------------------------------------------------------------------------------- /src/enrichment_agent/configuration.py: -------------------------------------------------------------------------------- 1 | """Define the configurable parameters for the agent.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field, fields 6 | from typing import Annotated, Optional 7 | 8 | from langchain_core.runnables import RunnableConfig, ensure_config 9 | 10 | from enrichment_agent import prompts 11 | 12 | 13 | @dataclass(kw_only=True) 14 | class Configuration: 15 | """Configurable parameters for the enrichment agent, read from ``config["configurable"]``.""" 16 | 17 | model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( 18 | default="anthropic/claude-3-5-sonnet-20240620", 19 | metadata={ 20 | "description": "The name of the language model to use for the agent. " 21 | "Should be in the form: provider/model-name." 22 | }, 23 | ) 24 | 25 | prompt: str = field( 26 | default=prompts.MAIN_PROMPT, # graph nodes fill this template via str.format(info=..., topic=...) 27 | metadata={ 28 | "description": "The main prompt template to use for the agent's interactions. " 29 | "Expects two f-string arguments: {info} and {topic}."
30 | }, 31 | ) 32 | 33 | max_search_results: int = field( 34 | default=10, 35 | metadata={ 36 | "description": "The maximum number of search results to return for each search query." 37 | }, 38 | ) 39 | 40 | max_info_tool_calls: int = field( # NOTE(review): not read anywhere in graph.py - confirm it is enforced elsewhere 41 | default=3, 42 | metadata={ 43 | "description": "The maximum number of times the Info tool can be called during a single interaction." 44 | }, 45 | ) 46 | 47 | max_loops: int = field( 48 | default=6, 49 | metadata={ 50 | "description": "The maximum number of interaction loops allowed before the agent terminates." 51 | }, 52 | ) 53 | 54 | @classmethod 55 | def from_runnable_config( 56 | cls, config: Optional[RunnableConfig] = None 57 | ) -> Configuration: 58 | """Build a Configuration from ``config["configurable"]``; unknown keys are ignored and missing ones keep their defaults.""" 59 | config = ensure_config(config) 60 | configurable = config.get("configurable") or {} 61 | _fields = {f.name for f in fields(cls) if f.init} # names accepted by __init__ 62 | return cls(**{k: v for k, v in configurable.items() if k in _fields}) 63 | -------------------------------------------------------------------------------- /src/enrichment_agent/graph.py: -------------------------------------------------------------------------------- 1 | """Define a data enrichment agent. 2 | 3 | Works with a chat model with tool calling support.
4 | """ 5 | 6 | import json 7 | from typing import Any, Dict, List, Literal, Optional, cast 8 | 9 | from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage 10 | from langchain_core.runnables import RunnableConfig 11 | from langgraph.graph import StateGraph 12 | from langgraph.prebuilt import ToolNode 13 | from pydantic import BaseModel, Field 14 | 15 | from enrichment_agent import prompts 16 | from enrichment_agent.configuration import Configuration 17 | from enrichment_agent.state import InputState, OutputState, State 18 | from enrichment_agent.tools import scrape_website, search 19 | from enrichment_agent.utils import init_model 20 | 21 | 22 | async def call_agent_model( 23 | state: State, *, config: Optional[RunnableConfig] = None 24 | ) -> Dict[str, Any]: 25 | """Call the primary Language Model (LLM) to decide on the next research action. 26 | 27 | This asynchronous function performs the following steps: 28 | 1. Initializes configuration and sets up the 'Info' tool, which is the user-defined extraction schema. 29 | 2. Prepares the prompt and message history for the LLM. 30 | 3. Initializes and configures the LLM with available tools. 31 | 4. Invokes the LLM and processes its response. 32 | 5. Handles the LLM's decision to either continue research or submit final info. 
33 | """ 34 | # Load configuration from the provided RunnableConfig 35 | configuration = Configuration.from_runnable_config(config) 36 | 37 | # Define the 'Info' tool, which is the user-defined extraction schema 38 | info_tool = { 39 | "name": "Info", 40 | "description": "Call this when you have gathered all the relevant info", 41 | "parameters": state.extraction_schema, 42 | } 43 | 44 | # Format the prompt defined in prompts.py with the extraction schema and topic 45 | p = configuration.prompt.format( 46 | info=json.dumps(state.extraction_schema, indent=2), topic=state.topic 47 | ) 48 | 49 | # Create the messages list with the formatted prompt and the previous messages 50 | messages = [HumanMessage(content=p)] + state.messages 51 | 52 | # Initialize the raw model with the provided configuration and bind the tools 53 | raw_model = init_model(config) 54 | model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any") 55 | response = cast(AIMessage, await model.ainvoke(messages)) 56 | 57 | # Initialize info to None 58 | info = None 59 | 60 | # Check if the response has tool calls 61 | if response.tool_calls: 62 | for tool_call in response.tool_calls: 63 | if tool_call["name"] == "Info": 64 | info = tool_call["args"] 65 | break 66 | if info is not None: 67 | # The agent is submitting their answer; 68 | # ensure it isn't erroneously attempting to simultaneously perform research 69 | response.tool_calls = [ 70 | next(tc for tc in response.tool_calls if tc["name"] == "Info") 71 | ] 72 | response_messages: List[BaseMessage] = [response] 73 | if not response.tool_calls: # If LLM didn't respect the tool_choice 74 | response_messages.append( 75 | HumanMessage(content="Please respond by calling one of the provided tools.") 76 | ) 77 | return { 78 | "messages": response_messages, 79 | "info": info, 80 | # Add 1 to the step count 81 | "loop_step": 1, 82 | } 83 | 84 | 85 | class InfoIsSatisfactory(BaseModel): 86 | """Validate whether the current extracted 
info is satisfactory and complete.""" 87 | 88 | reason: List[str] = Field( 89 | description="First, provide reasoning for why this is either good or bad as a final result. Must include at least 3 reasons." 90 | ) 91 | is_satisfactory: bool = Field( 92 | description="After providing your reasoning, provide a value indicating whether the result is satisfactory. If not, you will continue researching." 93 | ) 94 | improvement_instructions: Optional[str] = Field( 95 | description="If the result is not satisfactory, provide clear and specific instructions on what needs to be improved or added to make the information satisfactory." 96 | " This should include details on missing information, areas that need more depth, or specific aspects to focus on in further research.", 97 | default=None, 98 | ) 99 | 100 | 101 | async def reflect( 102 | state: State, *, config: Optional[RunnableConfig] = None 103 | ) -> Dict[str, Any]: 104 | """Validate the quality of the data enrichment agent's output. 105 | 106 | This asynchronous function performs the following steps: 107 | 1. Prepares the initial prompt using the main prompt template. 108 | 2. Constructs a message history for the model. 109 | 3. Prepares a checker prompt to evaluate the presumed info. 110 | 4. Initializes and configures a language model with structured output. 111 | 5. Invokes the model to assess the quality of the gathered information. 112 | 6. Processes the model's response and determines if the info is satisfactory. 113 | """ 114 | p = prompts.MAIN_PROMPT.format( 115 | info=json.dumps(state.extraction_schema, indent=2), topic=state.topic 116 | ) 117 | last_message = state.messages[-1] 118 | if not isinstance(last_message, AIMessage): 119 | raise ValueError( 120 | f"{reflect.__name__} expects the last message in the state to be an AI message with tool calls." 
121 | f" Got: {type(last_message)}" 122 | ) 123 | messages = [HumanMessage(content=p)] + state.messages[:-1] 124 | presumed_info = state.info 125 | checker_prompt = """I am thinking of calling the info tool with the info below. \ 126 | Is this good? Give your reasoning as well. \ 127 | You can encourage the Assistant to look at specific URLs if that seems relevant, or do more searches. 128 | If you don't think it is good, you should be very specific about what could be improved. 129 | 130 | {presumed_info}""" 131 | p1 = checker_prompt.format(presumed_info=json.dumps(presumed_info or {}, indent=2)) 132 | messages.append(HumanMessage(content=p1)) 133 | raw_model = init_model(config) 134 | bound_model = raw_model.with_structured_output(InfoIsSatisfactory) 135 | response = cast(InfoIsSatisfactory, await bound_model.ainvoke(messages)) 136 | if response.is_satisfactory and presumed_info: 137 | return { 138 | "info": presumed_info, 139 | "messages": [ 140 | ToolMessage( 141 | tool_call_id=last_message.tool_calls[0]["id"], 142 | content="\n".join(response.reason), 143 | name="Info", 144 | additional_kwargs={"artifact": response.model_dump()}, 145 | status="success", 146 | ) 147 | ], 148 | } 149 | else: 150 | return { 151 | "messages": [ 152 | ToolMessage( 153 | tool_call_id=last_message.tool_calls[0]["id"], 154 | content=f"Unsatisfactory response:\n{response.improvement_instructions}", 155 | name="Info", 156 | additional_kwargs={"artifact": response.model_dump()}, 157 | status="error", 158 | ) 159 | ] 160 | } 161 | 162 | 163 | def route_after_agent( 164 | state: State, 165 | ) -> Literal["reflect", "tools", "call_agent_model", "__end__"]: 166 | """Schedule the next node after the agent's action. 167 | 168 | This function determines the next step in the research process based on the 169 | last message in the state. It handles three main scenarios: 170 | 171 | 1. Error recovery: If the last message is unexpectedly not an AIMessage. 172 | 2. 
Info submission: If the agent has called the "Info" tool to submit findings. 173 | 3. Continued research: If the agent has called any other tool. 174 | """ 175 | last_message = state.messages[-1] 176 | 177 | # "If for some reason the last message is not an AIMessage (due to a bug or unexpected behavior elsewhere in the code), 178 | # it ensures the system doesn't crash but instead tries to recover by calling the agent model again. 179 | if not isinstance(last_message, AIMessage): 180 | return "call_agent_model" 181 | # If the "Into" tool was called, then the model provided its extraction output. Reflect on the result 182 | if last_message.tool_calls and last_message.tool_calls[0]["name"] == "Info": 183 | return "reflect" 184 | # The last message is a tool call that is not "Info" (extraction output) 185 | else: 186 | return "tools" 187 | 188 | 189 | def route_after_checker( 190 | state: State, config: RunnableConfig 191 | ) -> Literal["__end__", "call_agent_model"]: 192 | """Schedule the next node after the checker's evaluation. 193 | 194 | This function determines whether to continue the research process or end it 195 | based on the checker's evaluation and the current state of the research. 196 | """ 197 | configurable = Configuration.from_runnable_config(config) 198 | last_message = state.messages[-1] 199 | 200 | if state.loop_step < configurable.max_loops: 201 | if not state.info: 202 | return "call_agent_model" 203 | if not isinstance(last_message, ToolMessage): 204 | raise ValueError( 205 | f"{route_after_checker.__name__} expected a tool messages. Received: {type(last_message)}." 206 | ) 207 | if last_message.status == "error": 208 | # Research deemed unsatisfactory 209 | return "call_agent_model" 210 | # It's great! 
211 | return "__end__" 212 | else: 213 | return "__end__" 214 | 215 | 216 | # Create the graph 217 | workflow = StateGraph( 218 | State, input=InputState, output=OutputState, config_schema=Configuration 219 | ) 220 | workflow.add_node(call_agent_model) 221 | workflow.add_node(reflect) 222 | workflow.add_node("tools", ToolNode([search, scrape_website])) 223 | workflow.add_edge("__start__", "call_agent_model") 224 | workflow.add_conditional_edges("call_agent_model", route_after_agent) 225 | workflow.add_edge("tools", "call_agent_model") 226 | workflow.add_conditional_edges("reflect", route_after_checker) 227 | 228 | graph = workflow.compile() 229 | graph.name = "ResearchTopic" 230 | -------------------------------------------------------------------------------- /src/enrichment_agent/prompts.py: -------------------------------------------------------------------------------- 1 | """Default prompts used in this project.""" 2 | 3 | MAIN_PROMPT = """You are doing web research on behalf of a user. You are trying to figure out this information: 4 | 5 | 6 | {info} 7 | 8 | 9 | You have access to the following tools: 10 | 11 | - `Search`: call a search tool and get back some results 12 | - `ScrapeWebsite`: scrape a website and get relevant notes about the given request. This will update the notes above. 13 | - `Info`: call this when you are done and have gathered all the relevant info 14 | 15 | Here is the information you have about the topic you are researching: 16 | 17 | Topic: {topic}""" 18 | -------------------------------------------------------------------------------- /src/enrichment_agent/state.py: -------------------------------------------------------------------------------- 1 | """State definitions. 2 | 3 | State is the interface between the graph and end user as well as the 4 | data model used internally by the graph. 
5 | """ 6 | 7 | import operator 8 | from dataclasses import dataclass, field 9 | from typing import Annotated, Any, List, Optional 10 | 11 | from langchain_core.messages import BaseMessage 12 | from langgraph.graph import add_messages 13 | 14 | 15 | @dataclass(kw_only=True) 16 | class InputState: 17 | """Input state defines the interface between the graph and the user (external API).""" 18 | 19 | topic: str 20 | "The topic for which the agent is tasked to gather information." 21 | 22 | extraction_schema: dict[str, Any] 23 | "The JSON schema that defines the information the agent is tasked with filling out." 24 | 25 | info: Optional[dict[str, Any]] = field(default=None) 26 | "The info state tracks the current extracted data for the given topic, conforming to the provided schema. This is primarily populated by the agent." 27 | 28 | 29 | @dataclass(kw_only=True) 30 | class State(InputState): 31 | """A graph's State defines three main things. 32 | 33 | 1. The structure of the data to be passed between nodes (which "channels" to read from/write to and their types) 34 | 2. Default values for each field 35 | 3. Reducers for the state's fields. Reducers are functions that determine how to apply updates to the state. 36 | See [Reducers](https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers) for more information. 37 | """ 38 | 39 | messages: Annotated[List[BaseMessage], add_messages] = field(default_factory=list) 40 | """ 41 | Messages track the primary execution state of the agent. 42 | 43 | Typically accumulates a pattern of: 44 | 45 | 1. HumanMessage - user input 46 | 2. AIMessage with .tool_calls - agent picking tool(s) to use to collect 47 | information 48 | 3. ToolMessage(s) - the responses (or errors) from the executed tools 49 | 50 | (... repeat steps 2 and 3 as needed ...) 51 | 4. AIMessage without .tool_calls - agent responding in unstructured 52 | format to the user. 53 | 54 | 5. HumanMessage - user responds with the next conversational turn.
55 | 56 | (... repeat steps 2-5 as needed ... ) 57 | 58 | Updates are merged by the `add_messages` reducer, which matches messages by ID. 59 | 60 | By default, this keeps the list "append-only", unless a 61 | new message has the same ID as an existing message. 62 | 63 | In that case, the new message replaces the existing message 64 | with that ID; messages with new IDs are appended. 65 | See the LangGraph `add_messages` reference for details. 66 | 67 | """ 68 | 69 | loop_step: Annotated[int, operator.add] = field(default=0) # node updates are summed via operator.add; route_after_checker ends the run once this is >= Configuration.max_loops 70 | 71 | # Feel free to add additional attributes to your state as needed. 72 | # Common examples include retrieved documents, extracted entities, API connections, etc. 73 | 74 | 75 | @dataclass(kw_only=True) 76 | class OutputState: 77 | """The response object for the end user. 78 | 79 | This class defines the structure of the output that will be provided 80 | to the user after the graph's execution is complete. 81 | """ 82 | 83 | info: dict[str, Any] 84 | """ 85 | A dictionary containing the extracted and processed information 86 | based on the user's query and the graph's execution. 87 | This is the primary output of the enrichment process. 88 | """ 89 | -------------------------------------------------------------------------------- /src/enrichment_agent/tools.py: -------------------------------------------------------------------------------- 1 | """Tools for data enrichment. 2 | 3 | This module contains functions that are directly exposed to the LLM as tools. 4 | These tools can be used for tasks such as web searching and scraping. 5 | Users can edit and extend these tools as needed.
6 | """ 7 | 8 | import json 9 | from typing import Any, Optional, cast 10 | 11 | import aiohttp 12 | from langchain_community.tools.tavily_search import TavilySearchResults 13 | from langchain_core.runnables import RunnableConfig 14 | from langchain_core.tools import InjectedToolArg 15 | from langgraph.prebuilt import InjectedState 16 | from typing_extensions import Annotated 17 | 18 | from enrichment_agent.configuration import Configuration 19 | from enrichment_agent.state import State 20 | from enrichment_agent.utils import init_model 21 | 22 | 23 | async def search( 24 | query: str, *, config: Annotated[RunnableConfig, InjectedToolArg] 25 | ) -> Optional[list[dict[str, Any]]]: 26 | """Query a search engine. 27 | 28 | This function queries the web to fetch comprehensive, accurate, and trusted results. It's particularly useful 29 | for answering questions about current events. Provide as much context in the query as needed to ensure high recall. 30 | """ 31 | configuration = Configuration.from_runnable_config(config) 32 | wrapped = TavilySearchResults(max_results=configuration.max_search_results) 33 | result = await wrapped.ainvoke({"query": query}) 34 | return cast(list[dict[str, Any]], result) 35 | 36 | 37 | _INFO_PROMPT = """You are doing web research on behalf of a user. You are trying to find out this information: 38 | 39 | 40 | {info} 41 | 42 | 43 | You just scraped the following website: {url} 44 | 45 | Based on the website content below, jot down some notes about the website. 46 | 47 | 48 | {content} 49 | """ 50 | 51 | 52 | async def scrape_website( 53 | url: str, 54 | *, 55 | state: Annotated[State, InjectedState], 56 | config: Annotated[RunnableConfig, InjectedToolArg], 57 | ) -> str: 58 | """Scrape and summarize content from a given URL. 59 | 60 | Returns: 61 | str: A summary of the scraped content, tailored to the extraction schema. 
62 | """ 63 | async with aiohttp.ClientSession() as session: 64 | async with session.get(url) as response: 65 | content = await response.text() 66 | 67 | p = _INFO_PROMPT.format( 68 | info=json.dumps(state.extraction_schema, indent=2), 69 | url=url, 70 | content=content[:40_000], 71 | ) 72 | raw_model = init_model(config) 73 | result = await raw_model.ainvoke(p) 74 | return str(result.content) 75 | -------------------------------------------------------------------------------- /src/enrichment_agent/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used in our graph.""" 2 | 3 | from typing import Optional 4 | 5 | from langchain.chat_models import init_chat_model 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_core.messages import AnyMessage 8 | from langchain_core.runnables import RunnableConfig 9 | 10 | from enrichment_agent.configuration import Configuration 11 | 12 | 13 | def get_message_text(msg: AnyMessage) -> str: 14 | """Get the text content of a message.""" 15 | content = msg.content 16 | if isinstance(content, str): 17 | return content 18 | elif isinstance(content, dict): 19 | return content.get("text", "") 20 | else: 21 | txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content] 22 | return "".join(txts).strip() 23 | 24 | 25 | def init_model(config: Optional[RunnableConfig] = None) -> BaseChatModel: 26 | """Initialize the configured chat model.""" 27 | configuration = Configuration.from_runnable_config(config) 28 | fully_specified_name = configuration.model 29 | if "/" in fully_specified_name: 30 | provider, model = fully_specified_name.split("/", maxsplit=1) 31 | else: 32 | provider = None 33 | model = fully_specified_name 34 | return init_chat_model(model, model_provider=provider) 35 | -------------------------------------------------------------------------------- /static/config.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/data-enrichment/028fd9a96a1179397df5748e06870c2269308f20/static/config.png -------------------------------------------------------------------------------- /static/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/data-enrichment/028fd9a96a1179397df5748e06870c2269308f20/static/overview.png -------------------------------------------------------------------------------- /static/studio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/data-enrichment/028fd9a96a1179397df5748e06870c2269308f20/static/studio.png -------------------------------------------------------------------------------- /static/studio_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/data-enrichment/028fd9a96a1179397df5748e06870c2269308f20/static/studio_example.png -------------------------------------------------------------------------------- /tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Define any integration tests you want in this directory.""" 2 | -------------------------------------------------------------------------------- /tests/integration_tests/test_graph.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | from langsmith import unit 5 | 6 | from enrichment_agent import graph 7 | 8 | 9 | @pytest.fixture(scope="function") 10 | def extraction_schema() -> Dict[str, Any]: 11 | return { 12 | "type": "object", 13 | "properties": { 14 | "founder": { 15 | "type": "string", 16 | "description": "The name of the company founder.", 17 | }, 18 | 
"websiteUrl": { 19 | "type": "string", 20 | "description": "Website URL of the company, e.g.: https://openai.com/, or https://microsoft.com", 21 | }, 22 | "products_sold": { 23 | "type": "array", 24 | "items": {"type": "string"}, 25 | "description": "A list of products sold by the company.", 26 | }, 27 | }, 28 | "required": ["founder", "websiteUrl", "products_sold"], 29 | } 30 | 31 | 32 | @pytest.mark.asyncio 33 | @unit 34 | async def test_researcher_simple_runthrough(extraction_schema: Dict[str, Any]) -> None: 35 | res = await graph.ainvoke( 36 | { 37 | "topic": "LangChain", 38 | "extraction_schema": extraction_schema, 39 | } 40 | ) 41 | 42 | assert res["info"] is not None 43 | assert "harrison" in res["info"]["founder"].lower() 44 | 45 | 46 | @pytest.fixture(scope="function") 47 | def array_extraction_schema() -> Dict[str, Any]: 48 | return { 49 | "type": "object", 50 | "properties": { 51 | "providers": { 52 | "type": "array", 53 | "items": { 54 | "type": "object", 55 | "properties": { 56 | "name": {"type": "string", "description": "Company name"}, 57 | "technology_summary": { 58 | "type": "string", 59 | "description": "Brief summary of their chip technology for LLM training", 60 | }, 61 | "current_market_share": { 62 | "type": "string", 63 | "description": "Estimated current market share percentage or position", 64 | }, 65 | "future_outlook": { 66 | "type": "string", 67 | "description": "Brief paragraph on future prospects and developments", 68 | }, 69 | }, 70 | "required": [ 71 | "name", 72 | "technology_summary", 73 | "current_market_share", 74 | "future_outlook", 75 | ], 76 | }, 77 | "description": "List of top chip providers for LLM Training", 78 | }, 79 | "overall_market_trends": { 80 | "type": "string", 81 | "description": "Brief paragraph on general trends in the LLM chip market", 82 | }, 83 | }, 84 | "required": ["providers", "overall_market_trends"], 85 | } 86 | 87 | 88 | @pytest.mark.asyncio 89 | @unit 90 | async def 
test_researcher_list_type(array_extraction_schema: Dict[str, Any]) -> None: 91 | res = await graph.ainvoke( 92 | { 93 | "topic": "Top 5 chip providers for LLM training", 94 | "extraction_schema": array_extraction_schema, 95 | } 96 | ) 97 | # Check that nvidia is amongst them lol 98 | info = res["info"] 99 | assert "providers" in info 100 | assert isinstance(res["info"]["providers"], list) 101 | assert len(info["providers"]) == 5 # Ensure we got exactly 5 providers 102 | 103 | # Check for NVIDIA's presence 104 | nvidia_present = any( 105 | provider["name"].lower().strip() == "nvidia" for provider in info["providers"] 106 | ) 107 | assert ( 108 | nvidia_present 109 | ), "NVIDIA should be among the top 5 chip providers for LLM training" 110 | 111 | # Validate structure of each provider 112 | for provider in info["providers"]: 113 | assert "name" in provider 114 | assert "technology_summary" in provider 115 | assert "current_market_share" in provider 116 | assert "future_outlook" in provider 117 | 118 | # Check for overall market trends 119 | assert "overall_market_trends" in info 120 | assert isinstance(info["overall_market_trends"], str) 121 | assert len(info["overall_market_trends"]) > 0 122 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Define any unit tests you may want in this directory.""" 2 | -------------------------------------------------------------------------------- /tests/unit_tests/test_configuration.py: -------------------------------------------------------------------------------- 1 | from enrichment_agent.configuration import Configuration 2 | 3 | 4 | def test_configuration_from_none() -> None: 5 | Configuration.from_runnable_config() 6 | --------------------------------------------------------------------------------