├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── blog_sample_agents ├── 0-Notebook-environment │ ├── bedrock_agent_helper.py │ ├── requirements.txt │ ├── setup_environment.ipynb │ └── utils.ipynb ├── 1-Sample-rag-agent │ └── sample_rag_agent.ipynb ├── 2-Sample-text2sql-agent │ ├── data_prep.py │ ├── lambda_function.py │ ├── openapi_schema.json │ └── sample_text2sql_agent.ipynb ├── 3-Cleanup │ ├── cleanup_rag_agent.ipynb │ └── cleanup_text2sql_agent.ipynb ├── README.MD ├── execute_eval.sh └── img │ ├── rag_langfuse_dashboard.png │ ├── rag_trace.png │ ├── rag_trace_dashboard.png │ ├── ts_langfuse_dashboard.png │ ├── ts_trace.png │ └── ts_trace_dashboard.png ├── config.env.tpl ├── data_files └── sample_data_file.json ├── driver.py ├── evaluators ├── README.MD ├── cot_evaluator.py ├── custom_evaluator.py ├── rag_evaluator.py └── text2sql_evaluator.py ├── helpers ├── README.MD ├── agent_info_extractor.py └── cot_helper.py ├── img └── evaluation_workflow.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | __pycache__/ 3 | venv/ 4 | *.cfg 5 | *.env 6 | hidden/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open Source Bedrock Agent Evaluation 2 | 3 | Open Source Bedrock Agent Evaluation is an evalauation framework for Amazon Bedrock agent tool-use and chain-of-thought reasoning with observability dashboards in LangFuse. 4 | 5 | ## Existing AWS assets 6 | https://github.com/awslabs/agent-evaluation implements an LLM agent (evaluator) that will orchestrate conversations with your own agent (target) and evaluate the responses during the conversation. 7 | 8 | Our repository provides the following additional features: 9 | 10 | ## Features 11 | 12 | - Test your own Bedrock Agent with custom questions 13 | - Provides the option for LLM-as-a-judge without ground truth reference 14 | - Includes both Agent Goal metrics for chain of thought , and Task specific metrics with RAG, Text2SQL and custom tools 15 | - Observability with integration with Langfuse that includes latency and cost information 16 | - Dashboard comparison for comparison of agents with multiple Bedrock LLMs 17 | 18 | ## Evaluation Workflow 19 | 20 | ![Evaluation Workflow](img/evaluation_workflow.png) 21 | 22 | ## Evaluation Results in Langfuse 23 | 24 | ### Dashboard 25 | ![Example Dashboard](blog_sample_agents/img/rag_langfuse_dashboard.png) 26 | 27 | ### Panel of Traces 28 | ![Example Traces](blog_sample_agents/img/rag_trace_dashboard.png) 29 | 30 | ### Individual Trace 31 | ![Example Trace](blog_sample_agents/img/rag_trace.png) 32 | 33 | 34 | ### Deployment Options 35 | 1. Clone this repo to a SageMaker notebook instance 36 | 2. Clone this repo locally and set up AWS CLI credentials to your AWS account 37 | 38 | ### Pre-Requisites 39 | 40 | 1. Set up a LangFuse account using the cloud https://www.langfuse.com or the self-host option for AWS https://github.com/aws-samples/deploy-langfuse-on-ecs-with-fargate/tree/main/langfuse-v3 41 | 42 | 2. Create an organization in Langfuse 43 | 44 | 3. Create a project within your Langfuse organization 45 | 46 | 4. Save your Langfuse project keys (Secret Key, Public Key, and Host) to use in config 47 | 48 | 5. If you are using the self-hosted option and want to see model costs then you must create a model definition in Langfuse for the LLM used by your agent, instructions can be found here https://langfuse.com/docs/model-usage-and-cost#custom-model-definitions 49 | 50 | ### SageMaker Notebook Deployment Steps 51 | 52 | 1. Create a SageMaker notebook instance in your AWS account 53 | 54 | 2. Open a terminal and navigate to the SageMaker/ folder within the instance 55 | ```bash 56 | cd SageMaker/ 57 | ``` 58 | 59 | 3. Clone this repository 60 | ```bash 61 | git clone https://github.com/aws-samples/open-source-bedrock-agent-evaluation.git 62 | ``` 63 | 64 | 4. Navigate to the repository and install the necessary requirements 65 | ```bash 66 | cd amazon-bedrock-agent-evaluation-framework/ 67 | pip3 install -r requirements.txt 68 | ``` 69 | 70 | ### Local Deployment Steps 71 | 72 | 1. Clone this repository 73 | ```bash 74 | git clone https://github.com/aws-samples/open-source-bedrock-agent-evaluation.git 75 | ``` 76 | 77 | 2. Navigate to the repository and install the necessary requirements 78 | ```bash 79 | cd open-source-bedrock-agent-evaluation/ 80 | pip3 install -r requirements.txt 81 | ``` 82 | 83 | 3. Set up AWS CLI to access AWS account resources locally https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html 84 | 85 | 86 | ### Agent Evaluation Options 87 | 1. Bring you own agent to evaluate 88 | 2. Create sample agents from this repository and run evaluations 89 | 90 | ### Option 1: Bring your own agent to evaluate 91 | 1. Bring your existing agent you want to evaluate (Currently RAG and Text2SQL evaluations built-in) 92 | 2. Create a dataset file for evaluations, manually or using the generator (Refer to the data_files/sample_data_file.json for the necessary format) 93 | 94 | 3. Copy the template configuration file and fill in the necessary information 95 | ```bash 96 | cp config_tpl.env.tpl config.env 97 | ``` 98 | 99 | 4. Run driver.py to execute evaluation job against dataset 100 | ```bash 101 | python3 driver.py 102 | ``` 103 | 104 | 5. Check your Langfuse project console to see the evaluation results! 105 | 106 | ### Option 2: Create Sample Agents to run Evaluations 107 | Follow the instructions in the [Blog Sample Agents README](blog_sample_agents/README.MD). This is a guided way to run the evaluation framework on pre-created Bedrock Agents. 108 | 109 | ## Security 110 | 111 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 112 | 113 | ## License 114 | 115 | This library is licensed under the MIT-0 License. See the LICENSE file. 116 | -------------------------------------------------------------------------------- /blog_sample_agents/0-Notebook-environment/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | awscli 3 | botocore 4 | opensearch-py 5 | retrying 6 | termcolor 7 | rich 8 | datasets 9 | pandas -------------------------------------------------------------------------------- /blog_sample_agents/0-Notebook-environment/setup_environment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "045a7b02-6a49-4b6d-b11f-43ed4ad3f190", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# Prepare Notebook Environment for Sample Agents\n", 11 | "In this section we prepare this notebook environment with the necessary dependencies to create the sample agents." 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "be755083-73cf-48b0-8d2f-06769d176503", 17 | "metadata": {}, 18 | "source": [ 19 | "#### Run the pip3 commands below to install all needed packages" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "22cff0fe-0fc5-455c-a8b7-88f7711afba7", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "!pip3 install -r requirements.txt\n", 30 | "!pip3 install --upgrade boto3\n", 31 | "!pip3 show boto3" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "057a6dc9-bfea-44ce-8a7d-36fa5aca1d97", 37 | "metadata": {}, 38 | "source": [ 39 | "#### Ensure the latest version of boto3 is shown below" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "5aaf9d4b-6c68-452b-be00-65aee9ec8af4", 46 | "metadata": { 47 | "tags": [] 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "!pip freeze | grep boto3" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "3d89cd8b-73ff-4cef-b040-31cc4e4b0233", 57 | "metadata": {}, 58 | "source": [ 59 | "#### Import all needed Python libraries" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "f56fe856-e1d4-42f1-aaf0-424f04ff4ba9", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import os\n", 70 | "IMPORTS_PATH = os.path.abspath(os.path.join(os.getcwd(), \"utils.ipynb\"))\n", 71 | "%store IMPORTS_PATH" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "4ac91f55-082e-4508-815a-2e36013d1841", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "%run $IMPORTS_PATH" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "id": "17e2d364-aab0-45fc-bf0f-2ad4bd8d9918", 87 | "metadata": {}, 88 | "source": [ 89 | "#### Extract account information needed for agent creation and define needed agent models" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "657e0b1a-d4ff-4cfc-bb50-b64f888ca773", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# boto3 session\n", 100 | "sts_client = boto3.client('sts')\n", 101 | "session = boto3.session.Session()\n", 102 | "\n", 103 | "# Account info\n", 104 | "account_id = sts_client.get_caller_identity()[\"Account\"]\n", 105 | "region = session.region_name\n", 106 | "\n", 107 | "# Foundation model used for agents, default to cross-region inference profile\n", 108 | "agent_foundation_model = [\"us.anthropic.claude-3-5-sonnet-20241022-v2:0\"]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "62cf3207-9bad-4be7-8dc9-378b8c487819", 114 | "metadata": {}, 115 | "source": [ 116 | "### (BEFORE YOU PROCEED) Ensure that you have access to all Bedrock models in your AWS account\n", 117 | "If you have to enable model access, give a couple minutes before proceeding with agent creation" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "9c6a97f3-d0a7-4715-b24d-7cb1f3fda07a", 123 | "metadata": {}, 124 | "source": [ 125 | "#### Store all needed variables in environment for future use in development" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "70ef6cb8-0b7b-42ee-8c91-673bbea5b126", 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# Store account info\n", 138 | "%store account_id\n", 139 | "%store region\n", 140 | "\n", 141 | "# Store model lists\n", 142 | "%store agent_foundation_model" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "cd013e32-049b-44d5-925e-b46f0778f46a", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "# Output SageMaker Notebook execution role for user to add policies to\n", 153 | "\n", 154 | "# Get the SageMaker session\n", 155 | "sagemaker_session = sagemaker.Session()\n", 156 | "\n", 157 | "# Get the execution role\n", 158 | "role = sagemaker_session.get_caller_identity_arn()\n", 159 | "\n", 160 | "print(f\"SageMaker Execution Role: {role}\")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "id": "eb4da212-5d32-4e6d-ae03-866f8f419145", 166 | "metadata": {}, 167 | "source": [ 168 | "### Navigate to the SageMaker Notebook execution role displayed above in IAM and attach the following policies: \n", 169 | "\n", 170 | "#### [\"BedrockFullAccess\", \"IAMFullAccess\", \"AWSLambda_FullAccess\", \"AmazonS3FullAccess\", \"AmazonAthenaFullAccess\"]\n", 171 | "\n", 172 | "Give a few minutes for these permissions to update" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "90d19548-c949-47f4-af33-89365e8401b5", 178 | "metadata": {}, 179 | "source": [ 180 | "## Proceed to agent creation notebooks now!" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "conda_python3", 187 | "language": "python", 188 | "name": "conda_python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.16" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 5 205 | } 206 | -------------------------------------------------------------------------------- /blog_sample_agents/0-Notebook-environment/utils.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ef601009-d039-44ee-ae31-bcfba6cf172b", 6 | "metadata": {}, 7 | "source": [ 8 | "#### Define all needed import statements" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "b1614159-cfab-4199-9c45-8f652b1523fb", 15 | "metadata": { 16 | "tags": [] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# Standard Python libraries\n", 21 | "import boto3\n", 22 | "import warnings\n", 23 | "import sagemaker\n", 24 | "import os\n", 25 | "import json\n", 26 | "import time\n", 27 | "import uuid\n", 28 | "import inspect\n", 29 | "from datasets import load_dataset\n", 30 | "\n", 31 | "# Import needed functions to create agent (Bedrock sample code: https://github.com/awslabs/amazon-bedrock-agent-samples/blob/main/examples/amazon-bedrock-multi-agent-collaboration/energy_efficiency_management_agent/1-energy-forecast/1_forecasting_agent.ipynb)\n", 32 | "from bedrock_agent_helper import AgentsForAmazonBedrock\n", 33 | "\n", 34 | "warnings.filterwarnings('ignore')\n", 35 | "print(\"Successfully imported necessary libraries into notebook\")" 36 | ] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "conda_python3", 42 | "language": "python", 43 | "name": "conda_python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.10.16" 56 | } 57 | }, 58 | "nbformat": 4, 59 | "nbformat_minor": 5 60 | } 61 | -------------------------------------------------------------------------------- /blog_sample_agents/1-Sample-rag-agent/sample_rag_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sample RAG Agent Walkthrough\n", 8 | "\n", 9 | "This notebook will walk users through setting up a sample RAG Agent with the [Hugging Face 'rag-mini-wikipedia' dataset](https://huggingface.co/datasets/rag-datasets/rag-mini-wikipedia) and evaluating it with [Bedrock Agent Evaluation Framework](https://github.com/aws-samples/amazon-bedrock-agent-evaluation-framework)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Prerequisites\n", 17 | "\n", 18 | "This notebook assumes that you have gone through the notebook environment setup in the 0-Notebook-environment/ folder and have set up a Langfuse project" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "#### Ensure the latest version of boto3 is shown below\n", 26 | "\n", 27 | "##### If not then run through setup_environment.ipynb in the 0-Notebook-environment/ folder" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "!pip freeze | grep boto3" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "#### Load in environment variables to notebook" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# Retrieve import path\n", 53 | "%store -r IMPORTS_PATH\n", 54 | "\n", 55 | "# Retrieve account info\n", 56 | "%store -r account_id\n", 57 | "%store -r region\n", 58 | "\n", 59 | "# Retrieve model lists\n", 60 | "%store -r agent_foundation_model" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "#### Retrieve imports environment variable and bring libraries into notebook" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "%run $IMPORTS_PATH" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### Download Hugging Face 'rag-mini-wikipedia' dataset" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Retrieves dataset corpus using Datasets Python library\n", 93 | "\n", 94 | "ds_corpus = load_dataset(\"rag-datasets/rag-mini-wikipedia\", \"text-corpus\")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### Write text corpus to file and upload to Amazon S3 to use as data source for knowledge base" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Write whole corpus to a .txt file\n", 111 | "\n", 112 | "with open('mini_wiki.txt', 'w') as f:\n", 113 | " f.write(str(ds_corpus['passages']['passage']))\n", 114 | "\n", 115 | "print(\"You can now view the whole Wikipedia corpus in mini_wiki.txt\")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Create Amazon S3 bucket and upload .txt. file to Amazon S3 bucket\n", 125 | "\n", 126 | "s3_client = boto3.client('s3')\n", 127 | "\n", 128 | "wiki_bucket_name = f\"rag-mini-wiki-{account_id}-{uuid.uuid4().hex[:6]}\"\n", 129 | "\n", 130 | "if region == 'us-east-1':\n", 131 | " s3_client.create_bucket(\n", 132 | " Bucket=wiki_bucket_name\n", 133 | " )\n", 134 | "else:\n", 135 | " s3_client.create_bucket(\n", 136 | " Bucket=wiki_bucket_name,\n", 137 | " CreateBucketConfiguration={\n", 138 | " 'LocationConstraint': region\n", 139 | " }\n", 140 | " )\n", 141 | "\n", 142 | "%store wiki_bucket_name\n", 143 | "\n", 144 | "print(\"Created bucket with name '{}' in region '{}'\".format(wiki_bucket_name, region))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# Place .txt corpus in S3 bucket\n", 154 | "\n", 155 | "s3_client.upload_file('mini_wiki.txt', wiki_bucket_name, 'mini_wiki.txt')\n", 156 | "\n", 157 | "print(\"Uploaded corpus to '{}'\".format(wiki_bucket_name))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "### Create Bedrock Knowledge Base\n", 165 | "#### Follow the steps below to create a Bedrock Knowledge Base in the AWS Console manually\n", 166 | "Note: Ensure this knowledge base is located in the same region as this notebook!" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "Step 1: Navigate to the 'Amazon Bedrock' service in the AWS Console and navigate to the 'Knowledge Bases' section\n", 174 | "\n", 175 | "Step 2: Click 'Create' and select 'Knowledge Base with vector store'\n", 176 | "\n", 177 | "Step 3: Name the Knowledge Base 'mini-wiki-kb' and select the Amazon S3 data source radio button\n", 178 | "\n", 179 | "Step 4: Name the data source 'mini-wiki-data' and select the S3 bucket file 'mini_wiki.txt' that was uploaded, \n", 180 | " e.x. s3://rag-mini-wikipedia-data-XXXXXXXXXXXX/mini_wiki.txt\n", 181 | "\n", 182 | "Step 5: Use the default parsing and default chunking options\n", 183 | "\n", 184 | "Step 6: Select the 'Titan Text Embeddings V2' embedding model and create an Amazon OpenSearch Serverless vector store with the quick create option\n", 185 | "\n", 186 | "Step 7: Now create the knowledge base (this process may take several minutes)\n", 187 | "\n", 188 | "Step 8: Manually sync the data source with the knowledge base by clicking on the data source and selecting 'Sync' and wait for the process to finish before proceeding to the next step" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# Fetch knowledge base ID\n", 198 | "\n", 199 | "bedrock_agent_client = boto3.client(\"bedrock-agent\", region)\n", 200 | "\n", 201 | "# Call the list_knowledge_bases method\n", 202 | "response = bedrock_agent_client.list_knowledge_bases()\n", 203 | "wiki_kb_id = None\n", 204 | "\n", 205 | "# Iterate through knowledge bases and find needed one\n", 206 | "if 'knowledgeBaseSummaries' in response:\n", 207 | " for kb in response['knowledgeBaseSummaries']:\n", 208 | " if 'mini-wiki-kb' in kb['name']:\n", 209 | " wiki_kb_id = kb['knowledgeBaseId']\n", 210 | "\n", 211 | "%store wiki_kb_id\n", 212 | "\n", 213 | "wiki_kb_id" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## Create RAG Agent" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "agent_name = 'sample-rag-agent'\n", 230 | "agent_description = \"RAG agent to run against the Hugging Face 'rag-mini-wikipedia' dataset\"\n", 231 | "agent_instruction = \"\"\"Use the associated knowledge base to answer questions.\"\"\"" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "agents = AgentsForAmazonBedrock()\n", 241 | "\n", 242 | "rag_agent = agents.create_agent(\n", 243 | " agent_name,\n", 244 | " agent_description,\n", 245 | " agent_instruction,\n", 246 | " agent_foundation_model,\n", 247 | " code_interpretation=False,\n", 248 | " verbose=False\n", 249 | ")\n", 250 | "\n", 251 | "rag_agent" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "rag_agent_id = rag_agent[0]\n", 261 | "rag_agent_arn = f\"arn:aws:bedrock:{region}:{account_id}:agent/{rag_agent_id}\"\n", 262 | "\n", 263 | "rag_agent_id, rag_agent_arn" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "agents.associate_kb_with_agent(\n", 273 | " rag_agent_id,\n", 274 | " \"Hugging Face 'rag-mini-wikipedia' dataset\", \n", 275 | " wiki_kb_id\n", 276 | ")" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "### Test RAG Agent" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "#### Invoke Sample RAG Agent Test Alias to see that it answers question properly" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "# Ask example question to agent\n", 300 | "\n", 301 | "bedrock_agent_runtime_client = boto3.client(\"bedrock-agent-runtime\", region)\n", 302 | "\n", 303 | "session_id:str = str(uuid.uuid1())\n", 304 | "\n", 305 | "test_query = \"Who suggested Lincoln grow a beard?\"\n", 306 | "response = bedrock_agent_runtime_client.invoke_agent(\n", 307 | " inputText=test_query,\n", 308 | " agentId=rag_agent_id,\n", 309 | " agentAliasId=\"TSTALIASID\", \n", 310 | " sessionId=session_id,\n", 311 | " enableTrace=True, \n", 312 | " endSession=False,\n", 313 | " sessionState={}\n", 314 | ")\n", 315 | "\n", 316 | "print(\"Request sent to Agent\")\n", 317 | "print(\"====================\")\n", 318 | "print(\"Agent processing query now\")\n", 319 | "print(\"====================\")\n", 320 | "\n", 321 | "# Initialize an empty string to store the answer\n", 322 | "answer = \"\"\n", 323 | "\n", 324 | "# Iterate through the event stream\n", 325 | "for event in response['completion']:\n", 326 | " # Check if the event is a 'chunk' event\n", 327 | " if 'chunk' in event:\n", 328 | " chunk_obj = event['chunk']\n", 329 | " if 'bytes' in chunk_obj:\n", 330 | " # Decode the bytes and append to the answer\n", 331 | " chunk_data = chunk_obj['bytes'].decode('utf-8')\n", 332 | " answer += chunk_data\n", 333 | "\n", 334 | "# Now 'answer' contains the full response from the agent\n", 335 | "print(\"Agent Answer: {}\".format(answer))\n", 336 | "print(\"====================\")" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "#### Prepare agent and create alias for use with evaluation framework" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "rag_agent_alias_id, rag_agent_alias_arn = agents.create_agent_alias(\n", 353 | " rag_agent[0], 'v1'\n", 354 | ")\n", 355 | "\n", 356 | "%store rag_agent_alias_arn\n", 357 | "rag_agent_alias_id, rag_agent_alias_arn" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "### Create input .json file for all RAG using ground truth provided by Hugging Face dataset" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "#### Below is the option to specify the number of questions to generate. \n", 372 | "\n", 373 | "#### Default is 10, set to -1 to run through all questions, or specify to any other desired number" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "number_questions = 10" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "# Create input json data file for evaluation framework and place so it can be run by evaluation framework by user\n", 392 | "\n", 393 | "ds_qa = load_dataset(\"rag-datasets/rag-mini-wikipedia\", \"question-answer\")\n", 394 | "\n", 395 | "input_data_dict = {}\n", 396 | "\n", 397 | "# Iterate through all elements in dataset\n", 398 | "for index, data in enumerate(ds_qa['test']):\n", 399 | "\n", 400 | " # Extract desired number of questions\n", 401 | " if number_questions != -1:\n", 402 | " if index == number_questions:\n", 403 | " break\n", 404 | "\n", 405 | " qa_pair = {\n", 406 | " \"question_id\": data['id'],\n", 407 | " \"question_type\": \"RAG\",\n", 408 | " \"question\": data['question'],\n", 409 | " \"ground_truth\": data['answer']\n", 410 | " } \n", 411 | " input_data_dict[\"Trajectory{}\".format(index)] = [qa_pair]\n", 412 | "\n", 413 | "# Save to JSON file\n", 414 | "with open('rag_data_file_auto.json', 'w', encoding='utf-8') as f:\n", 415 | " json.dump(input_data_dict, f, indent=4, ensure_ascii=False)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "## Create config.env that evaluation tool needs\n", 423 | "Note: Input Langfuse host and keys into the variables below" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "user_input = f\"\"\"\n", 433 | "\n", 434 | "AGENT_ID=\"{rag_agent_id}\"\n", 435 | "AGENT_ALIAS_ID=\"{rag_agent_alias_id}\"\n", 436 | "\n", 437 | "DATA_FILE_PATH=\"blog_sample_agents/1-Sample-rag-agent/rag_data_file_auto.json\"\n", 438 | "\n", 439 | "LANGFUSE_PUBLIC_KEY=\"FILL_IN\"\n", 440 | "LANGFUSE_SECRET_KEY=\"FILL_IN\"\n", 441 | "LANGFUSE_HOST=\"FILL_IN\"\n", 442 | "\n", 443 | "\"\"\"" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "import os\n", 453 | "from string import Template\n", 454 | "\n", 455 | "# Set the correct paths relative to current location\n", 456 | "base_dir = os.path.dirname(os.path.dirname(os.getcwd())) # Go up two levels\n", 457 | "template_file_path = os.path.join(base_dir, 'config.env.tpl')\n", 458 | "config_file_path = os.path.join(base_dir, 'config.env')\n", 459 | "\n", 460 | "# Read the template file from the Bedrock Agent Evaluation Framework\n", 461 | "with open(template_file_path, 'r') as template_file:\n", 462 | " template_content = template_file.read()\n", 463 | "\n", 464 | "\n", 465 | "# Convert template content and user input into dictionaries\n", 466 | "def parse_env_content(content):\n", 467 | " env_dict = {}\n", 468 | " for line in content.split('\\n'):\n", 469 | " line = line.strip()\n", 470 | " if line and not line.startswith('#'):\n", 471 | " if '=' in line:\n", 472 | " key, value = line.split('=', 1)\n", 473 | " env_dict[key.strip()] = value.strip()\n", 474 | " return env_dict\n", 475 | "\n", 476 | "template_dict = parse_env_content(template_content)\n", 477 | "user_dict = parse_env_content(user_input)\n", 478 | "\n", 479 | "# Merge dictionaries, with user input taking precedence\n", 480 | "final_dict = {**template_dict, **user_dict}\n", 481 | "\n", 482 | "# Create the config.env content\n", 483 | "config_content = \"\"\n", 484 | "for key, value in final_dict.items():\n", 485 | " config_content += f\"{key}={value}\\n\"\n", 486 | "\n", 487 | "# Write to config.env file in the correct folder\n", 488 | "with open(config_file_path, 'w') as config_file:\n", 489 | " config_file.write(config_content)\n", 490 | "\n", 491 | "print(f\"config.env file has been created successfully in amazon-bedrock-agent-evaluation-framework!\")" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "## Run [Bedrock Agent Evaluation Framework](https://github.com/aws-samples/amazon-bedrock-agent-evaluation-framework) on the newly created sample RAG agent!\n", 499 | "Note: For some questions, the RAG agent may run into issues evaluating, in that case an error trace will show in Langfuse" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "![Langfuse Dashboard](../img/rag_langfuse_dashboard.png)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "![Trace Dashboard](../img/rag_trace_dashboard.png)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "![Trace](../img/rag_trace.png)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "# Execute bash script to run evaluation\n", 530 | "!cd .. && chmod +x execute_eval.sh && ./execute_eval.sh" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": {}, 536 | "source": [ 537 | "### Navigate to your Langfuse host address, open the relevant Langfuse project, and view the traces populated there during evaluation run" 538 | ] 539 | } 540 | ], 541 | "metadata": { 542 | "kernelspec": { 543 | "display_name": "conda_python3", 544 | "language": "python", 545 | "name": "conda_python3" 546 | }, 547 | "language_info": { 548 | "codemirror_mode": { 549 | "name": "ipython", 550 | "version": 3 551 | }, 552 | "file_extension": ".py", 553 | "mimetype": "text/x-python", 554 | "name": "python", 555 | "nbconvert_exporter": "python", 556 | "pygments_lexer": "ipython3", 557 | "version": "3.10.16" 558 | } 559 | }, 560 | "nbformat": 4, 561 | "nbformat_minor": 4 562 | } 563 | -------------------------------------------------------------------------------- /blog_sample_agents/2-Sample-text2sql-agent/data_prep.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import json 3 | import os 4 | import pandas as pd 5 | import sqlite3 6 | import uuid 7 | import zipfile 8 | from botocore.exceptions import ClientError 9 | from pathlib import Path 10 | import time 11 | import joblib 12 | 13 | 14 | # S3 Bucket Creation and Setup 15 | def create_s3_bucket(bucket_name, region): 16 | s3_client = boto3.client('s3') 17 | s3 = boto3.resource('s3') 18 | 19 | try: 20 | 21 | if region == 'us-east-1': 22 | # For us-east-1, don't specify LocationConstraint 23 | s3_client.create_bucket(Bucket=bucket_name) 24 | print(f"Created bucket: {bucket_name}") 25 | else: 26 | s3_client.create_bucket( 27 | Bucket=bucket_name, 28 | CreateBucketConfiguration={'LocationConstraint': region} 29 | ) 30 | print(f"Created bucket: {bucket_name}") 31 | 32 | return bucket_name 33 | 34 | except ClientError as e: 35 | print(f"Error: {e}") 36 | return None 37 | 38 | # Data Processing Functions 39 | def create_and_unzip(first_zip, new_directory, second_zip): 40 | notebook_dir = Path.cwd() 41 | new_dir = Path(new_directory) 42 | new_dir.mkdir(parents=True, exist_ok=True) 43 | print(f"Created directory: {new_dir}") 44 | 45 | with zipfile.ZipFile(first_zip, 'r') as zip_ref: 46 | zip_ref.extractall(new_directory) 47 | print(f"Unzipped {first_zip} to {new_directory}") 48 | 49 | subdirs = [d for d in new_dir.iterdir() if d.is_dir()] 50 | 51 | if subdirs: 52 | auto_found_dir = subdirs[0] 53 | print(f"Found directory: {auto_found_dir}") 54 | second_zip_path = auto_found_dir / second_zip 55 | print(second_zip_path) 56 | if second_zip_path.exists(): 57 | with zipfile.ZipFile(second_zip_path, 'r') as zip_ref: 58 | zip_ref.extractall(notebook_dir) 59 | print(f"Unzipped {second_zip} in {notebook_dir}") 60 | else: 61 | print(f"Could not find {second_zip} in {auto_found_dir}") 62 | else: 63 | print(f"No directories found in {new_dir}") 64 | 65 | def process_database_and_upload(database_folder, bucket_name): 66 | folder_path = Path(database_folder) 67 | database_name = folder_path.name 68 | 69 | sqlite_files = list(folder_path.glob('*.db')) + list(folder_path.glob('*.sqlite')) 70 | if not sqlite_files: 71 | print(f"No SQLite file found in {database_folder}") 72 | return 73 | 74 | sqlite_file = sqlite_files[0] 75 | print(f"\nProcessing database: {database_name}") 76 | print(f"SQLite file: {sqlite_file}") 77 | 78 | try: 79 | conn = sqlite3.connect(sqlite_file) 80 | cursor = conn.cursor() 81 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 82 | tables = cursor.fetchall() 83 | s3_client = boto3.client('s3') 84 | 85 | for table in tables: 86 | table_name = table[0] 87 | print(f"\nProcessing table: {table_name}") 88 | 89 | try: 90 | df = pd.read_sql_query(f'SELECT * FROM "{table_name}"', conn) 91 | parquet_filename = f"{table_name}.parquet" 92 | local_parquet_path = folder_path / parquet_filename 93 | df.to_parquet(local_parquet_path, index=False) 94 | 95 | s3_key = f"{database_name}/{parquet_filename}" 96 | s3_client.upload_file( 97 | str(local_parquet_path), 98 | bucket_name, 99 | s3_key 100 | ) 101 | print(f"Uploaded to s3://{bucket_name}/{s3_key}") 102 | os.remove(local_parquet_path) 103 | 104 | except Exception as e: 105 | print(f"Error processing table {table_name}: {str(e)}") 106 | continue 107 | 108 | conn.close() 109 | 110 | except Exception as e: 111 | print(f"Error processing database {database_name}: {str(e)}") 112 | return 113 | 114 | def set_athena_result_location(result_bucket): 115 | try: 116 | athena_client = boto3.client('athena') 117 | s3_output_location = f's3://{result_bucket}/athena-results/' 118 | 119 | response = athena_client.update_work_group( 120 | WorkGroup='primary', 121 | ConfigurationUpdates={ 122 | 'ResultConfigurationUpdates': { 123 | 'OutputLocation': s3_output_location 124 | }, 125 | 'EnforceWorkGroupConfiguration': True 126 | } 127 | ) 128 | 129 | print(f"Successfully set Athena query result location to: {s3_output_location}") 130 | return True 131 | 132 | except Exception as e: 133 | print(f"Error setting result location: {e}") 134 | return False 135 | 136 | def list_s3_folders_and_files(bucket_name): 137 | """Get all database folders and their parquet files""" 138 | s3_client = boto3.client('s3') 139 | 140 | try: 141 | # Get all objects in bucket 142 | paginator = s3_client.get_paginator('list_objects_v2') 143 | database_tables = {} 144 | 145 | for page in paginator.paginate(Bucket=bucket_name): 146 | if 'Contents' not in page: 147 | continue 148 | 149 | for obj in page['Contents']: 150 | # Split the key into parts 151 | parts = obj['Key'].split('/') 152 | 153 | # Check if it's a parquet file 154 | if len(parts) >= 2 and parts[-1].endswith('.parquet'): 155 | database_name = parts[0] 156 | table_name = parts[-1].replace('.parquet', '') 157 | 158 | if database_name not in database_tables: 159 | database_tables[database_name] = [] 160 | database_tables[database_name].append(table_name) 161 | 162 | return database_tables 163 | 164 | except ClientError as e: 165 | print(f"Error listing S3 contents: {e}") 166 | return None 167 | 168 | 169 | def generate_and_create_table(results_bucket_name, parquet_bucket_name, database_name, table_name): 170 | """Generate and create a single table""" 171 | try: 172 | # Generate DDL 173 | s3_path = f's3://{parquet_bucket_name}/{database_name}/{table_name}.parquet' 174 | df = pd.read_parquet(s3_path) 175 | 176 | # Map pandas types to Athena types 177 | type_mapping = { 178 | 'object': 'string', 179 | 'int64': 'int', 180 | 'float64': 'double', 181 | 'bool': 'boolean', 182 | 'datetime64[ns]': 'timestamp' 183 | } 184 | 185 | # Generate column definitions 186 | columns = [] 187 | for col, dtype in df.dtypes.items(): 188 | athena_type = type_mapping.get(str(dtype), 'string') 189 | columns.append(f"`{col}` {athena_type}") 190 | 191 | # Create DDL statement 192 | column_definitions = ',\n '.join(columns) 193 | s3_location = f's3://{parquet_bucket_name}/{database_name}/' 194 | 195 | ddl = f"""CREATE EXTERNAL TABLE IF NOT EXISTS {database_name}.{table_name} ( 196 | {column_definitions} 197 | ) 198 | STORED AS PARQUET 199 | LOCATION '{s3_location}';""" 200 | 201 | print(f"\nGenerating table: {database_name}.{table_name}") 202 | 203 | # Execute DDL 204 | athena_client = boto3.client('athena') 205 | 206 | # Create table 207 | response = athena_client.start_query_execution( 208 | QueryString=ddl, 209 | QueryExecutionContext={ 210 | 'Database': database_name 211 | }, 212 | ResultConfiguration={ 213 | 'OutputLocation': f's3://{results_bucket_name}/athena-results/' 214 | } 215 | ) 216 | 217 | # Wait for table creation 218 | query_execution_id = response['QueryExecutionId'] 219 | while True: 220 | response = athena_client.get_query_execution(QueryExecutionId=query_execution_id) 221 | state = response['QueryExecution']['Status']['State'] 222 | if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']: 223 | break 224 | time.sleep(1) 225 | 226 | if state == 'SUCCEEDED': 227 | print(f"Created table: {database_name}.{table_name}") 228 | return True 229 | else: 230 | print(f"Failed to create table {database_name}.{table_name}: {state}") 231 | return False 232 | 233 | except Exception as e: 234 | print(f"Error creating table {database_name}.{table_name}: {e}") 235 | return False 236 | 237 | def create_all_databases_and_tables(results_bucket_name, parquet_bucket_name): 238 | """Create all databases and tables from S3 bucket structure""" 239 | try: 240 | # Get database and table structure from S3 241 | database_tables = list_s3_folders_and_files(parquet_bucket_name) 242 | if not database_tables: 243 | print("No databases/tables found in S3") 244 | return False 245 | 246 | athena_client = boto3.client('athena') 247 | 248 | # Process each database 249 | for database_name, tables in database_tables.items(): 250 | print(f"\nProcessing database: {database_name}") 251 | 252 | # Create database 253 | create_database = f"CREATE DATABASE IF NOT EXISTS {database_name}" 254 | response = athena_client.start_query_execution( 255 | QueryString=create_database, 256 | ResultConfiguration={ 257 | 'OutputLocation': f's3://{results_bucket_name}/athena-results/' 258 | } 259 | ) 260 | 261 | # Wait for database creation 262 | query_execution_id = response['QueryExecutionId'] 263 | while True: 264 | response = athena_client.get_query_execution(QueryExecutionId=query_execution_id) 265 | state = response['QueryExecution']['Status']['State'] 266 | if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']: 267 | break 268 | time.sleep(1) 269 | 270 | if state == 'SUCCEEDED': 271 | print(f"Created database: {database_name}") 272 | 273 | # Create each table in the database 274 | for table_name in tables: 275 | generate_and_create_table( 276 | results_bucket_name, 277 | parquet_bucket_name, 278 | database_name, 279 | table_name 280 | ) 281 | else: 282 | print(f"Failed to create database {database_name}: {state}") 283 | continue 284 | 285 | return True 286 | 287 | except Exception as e: 288 | print(f"Error in create_all_databases_and_tables: {e}") 289 | return False 290 | 291 | def filter_on_db(input_file, output_file, db_name): 292 | # Read the input JSON file 293 | with open(input_file, 'r') as f: 294 | data = json.load(f) 295 | 296 | # Filter entries where db_id is "california_schools" 297 | filtered_data = [entry for entry in data if entry.get('db_id') == db_name] 298 | 299 | # Write the filtered data to output file 300 | with open(output_file, 'w') as f: 301 | json.dump(filtered_data, f, indent=2) 302 | 303 | def main(): 304 | # Configuration 305 | REGION = os.environ.get('REGION') 306 | BASE_BUCKET_NAME = os.environ.get('BASE_BUCKET_NAME') 307 | ATHENA_RESULTS_BUCKET_NAME = os.environ.get('ATHENA_RESULTS_BUCKET_NAME') 308 | BASE_DIR = os.environ.get('BASE_DIR') 309 | DATABASE_NAME = os.environ.get('DATABASE_NAME') 310 | 311 | 312 | # Step 1: Unzip files 313 | create_and_unzip( 314 | first_zip='dev.zip', 315 | new_directory='unzipped_dev', 316 | second_zip='dev_databases.zip' 317 | ) 318 | 319 | # Step 2: Create buckets 320 | main_bucket = create_s3_bucket(BASE_BUCKET_NAME, REGION) 321 | athena_results_bucket = create_s3_bucket(ATHENA_RESULTS_BUCKET_NAME, REGION) 322 | 323 | if not all([main_bucket, athena_results_bucket]): 324 | print("Failed to create required buckets") 325 | return 326 | 327 | # Step 3: Process and upload database 328 | base_path = Path(BASE_DIR) 329 | target_folder = base_path/DATABASE_NAME # Create path to specific database folder 330 | 331 | if target_folder.exists() and target_folder.is_dir(): 332 | process_database_and_upload(target_folder, main_bucket) 333 | else: 334 | print(f"Database folder '{DATABASE_NAME}' not found in {BASE_DIR}") 335 | 336 | # Step 4: Setup Athena configurations 337 | set_athena_result_location(athena_results_bucket) 338 | 339 | # Step 5: Create Athena databases and tables 340 | success = create_all_databases_and_tables(athena_results_bucket, main_bucket) 341 | if success: 342 | print("\nCompleted creating all databases and tables in Athena!") 343 | 344 | # Step 6: Generate birdsql.json 345 | filter_on_db('unzipped_dev/dev_20240627/dev.json','birdsql_data.json',DATABASE_NAME) 346 | print("Created birdsql_data.json for agent") 347 | 348 | 349 | if __name__ == "__main__": 350 | main() 351 | -------------------------------------------------------------------------------- /blog_sample_agents/2-Sample-text2sql-agent/lambda_function.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import time 3 | import os 4 | import uuid 5 | import json 6 | import sys 7 | from collections import defaultdict 8 | 9 | athena_client = boto3.client('athena') 10 | 11 | def get_schema(database_name="california_schools"): 12 | """ 13 | Get schema information for all tables in Athena databases 14 | """ 15 | 16 | sql = f""" 17 | SELECT 18 | table_name, 19 | column_name, 20 | data_type 21 | FROM information_schema.columns 22 | WHERE table_schema = '{database_name}' 23 | ORDER BY table_name, ordinal_position; 24 | """ 25 | 26 | try: 27 | # Start query execution 28 | response = athena_client.start_query_execution( 29 | QueryString=sql, 30 | QueryExecutionContext={ 31 | 'Database': database_name 32 | } 33 | ) 34 | 35 | query_execution_id = response['QueryExecutionId'] 36 | 37 | def wait_for_query_completion(query_execution_id): 38 | while True: 39 | response = athena_client.get_query_execution( 40 | QueryExecutionId=query_execution_id 41 | ) 42 | state = response['QueryExecution']['Status']['State'] 43 | 44 | if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']: 45 | print(f"Query {state}") 46 | return state 47 | 48 | print("Waiting for query to complete...") 49 | time.sleep(2) 50 | 51 | # Wait for query completion 52 | state = wait_for_query_completion(query_execution_id) 53 | 54 | if state == 'SUCCEEDED': 55 | # Get query results 56 | results = athena_client.get_query_results( 57 | QueryExecutionId=query_execution_id 58 | ) 59 | print("Got query results for schema") 60 | # Assuming you have a database connection and cursor setup 61 | # cursor.execute(sql) 62 | # results = cursor.fetchall() 63 | 64 | database_structure = [] 65 | table_dict = {} 66 | 67 | # Skip the header row 68 | rows = results['ResultSet']['Rows'][1:] 69 | 70 | for row in rows: 71 | # Extract values from the Data structure 72 | table_name = row['Data'][0]['VarCharValue'] 73 | column_name = row['Data'][1]['VarCharValue'] 74 | data_type = row['Data'][2]['VarCharValue'] 75 | 76 | # Initialize table if not exists 77 | if table_name not in table_dict: 78 | table_dict[table_name] = [] 79 | 80 | # Append column information 81 | table_dict[table_name].append((column_name, data_type)) 82 | 83 | # Convert to the desired format 84 | for table_name, columns in table_dict.items(): 85 | database_structure.append({ 86 | "table_name": table_name, 87 | "columns": columns 88 | }) 89 | 90 | return database_structure 91 | 92 | else: 93 | raise Exception(f"Query failed with state: {state}") 94 | except Exception as e: 95 | print(f"Error getting schema: {e}") 96 | raise 97 | 98 | def query_athena(query, database_name='california_schools'): 99 | """ 100 | Execute a query on Athena 101 | """ 102 | try: 103 | # Start query execution 104 | response = athena_client.start_query_execution( 105 | QueryString=query, 106 | QueryExecutionContext={ 107 | 'Database': database_name 108 | } 109 | ) 110 | 111 | query_execution_id = response['QueryExecutionId'] 112 | 113 | def wait_for_query_completion(query_execution_id): 114 | while True: 115 | response = athena_client.get_query_execution( 116 | QueryExecutionId=query_execution_id 117 | ) 118 | state = response['QueryExecution']['Status']['State'] 119 | 120 | if state == 'FAILED': 121 | error_message = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') 122 | raise Exception(f"Query failed: {error_message}") 123 | 124 | if state == 'CANCELLED': 125 | raise Exception("Query was cancelled") 126 | 127 | if state == 'SUCCEEDED': 128 | return state 129 | 130 | print("Waiting for query to complete...") 131 | time.sleep(2) 132 | 133 | # Wait for query completion 134 | state = wait_for_query_completion(query_execution_id) 135 | print("query complete") 136 | # Get query results 137 | print(state) 138 | 139 | if state == 'SUCCEEDED': 140 | results = athena_client.get_query_results( 141 | QueryExecutionId=query_execution_id 142 | ) 143 | print("got query results") 144 | print(results) 145 | # Process results 146 | processed_results = [] 147 | headers = [] 148 | 149 | # Get headers from first row 150 | if results['ResultSet']['Rows']: 151 | headers = [field.get('VarCharValue', '') for field in results['ResultSet']['Rows'][0]['Data']] 152 | 153 | # Process data rows 154 | for row in results['ResultSet']['Rows'][1:]: 155 | values = [field.get('VarCharValue', '') for field in row['Data']] 156 | row_dict = dict(zip(headers, values)) 157 | processed_results.append(row_dict) 158 | 159 | print(processed_results) 160 | return processed_results 161 | 162 | else: 163 | raise Exception(f"Query failed with state: {state}") 164 | 165 | except Exception as e: 166 | print(f"Error executing query: {e}") 167 | raise 168 | 169 | def upload_result_s3(result, bucket, key): 170 | s3_client = boto3.client('s3') 171 | s3_client.put_object( 172 | Bucket=bucket, 173 | Key=key, 174 | Body=json.dumps(result) 175 | ) 176 | return { 177 | "storage": "s3", 178 | "bucket": bucket, 179 | "key": key 180 | } 181 | 182 | def lambda_handler(event, context): 183 | result = None 184 | error_message = None 185 | 186 | try: 187 | if event['apiPath'] == "/getschema": 188 | result = get_schema() 189 | 190 | elif event['apiPath'] == "/queryathena": 191 | params =event['parameters'] 192 | for param in params: 193 | if param.get("name") == "query": 194 | query = param.get("value") 195 | print(query) 196 | 197 | result = query_athena(query) 198 | print("end of query ") 199 | 200 | else: 201 | raise ValueError(f"Unknown apiPath: {event['apiPath']}") 202 | 203 | if result: 204 | print("Query Result:", result) 205 | 206 | except Exception as e: 207 | error_message = str(e) 208 | print(f"Error occurred: {error_message}") 209 | 210 | BUCKET_NAME = os.environ['BUCKET_NAME'] 211 | KEY = str(uuid.uuid4()) + '.json' 212 | size = sys.getsizeof(str(result)) if result else 0 213 | print(f"Response size: {size} bytes") 214 | 215 | if size > 20000: 216 | print('Size greater than 20KB, writing to a file in S3') 217 | result = upload_result_s3(result, BUCKET_NAME, KEY) 218 | response_body = { 219 | 'application/json': { 220 | 'body': f"Result uploaded to S3. Bucket: {BUCKET_NAME}, Key: {KEY}" 221 | } 222 | } 223 | else: 224 | response_body = { 225 | 'application/json': { 226 | 'body': str(result) if result else error_message 227 | } 228 | } 229 | 230 | action_response = { 231 | 'actionGroup': event['actionGroup'], 232 | 'apiPath': event['apiPath'], 233 | 'httpMethod': event['httpMethod'], 234 | 'httpStatusCode': 200 if result else 500, 235 | 'responseBody': response_body 236 | } 237 | 238 | api_response = { 239 | 'messageVersion': '1.0', 240 | 'response': action_response, 241 | } 242 | 243 | return api_response -------------------------------------------------------------------------------- /blog_sample_agents/2-Sample-text2sql-agent/openapi_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "openapi": "3.0.1", 3 | "info": { 4 | "title": "Database schema look up and query APIs", 5 | "version": "1.0.0", 6 | "description": "APIs for looking up database table schemas and making queries to database tables." 7 | }, 8 | "paths": { 9 | "/getschema": { 10 | "get": { 11 | "summary": "Get a list of all columns in the athena database", 12 | "description": "Get the list of all columns in the athena database table. Return all the column information in database table.", 13 | "operationId": "getschema", 14 | "responses": { 15 | "200": { 16 | "description": "Gets the list of table names and their schemas in the database", 17 | "content": { 18 | "application/json": { 19 | "schema": { 20 | "type": "array", 21 | "items": { 22 | "type": "object", 23 | "properties": { 24 | "Table": { 25 | "type": "string", 26 | "description": "The name of the table in the database." 27 | }, 28 | "Schema": { 29 | "type": "string", 30 | "description": "The schema of the table in the database. Contains all columns needed for making queries." 31 | } 32 | } 33 | } 34 | } 35 | } 36 | } 37 | } 38 | } 39 | } 40 | }, 41 | "/queryathena": { 42 | "get": { 43 | "summary": "API to send query to the athena database table", 44 | "description": "Send a query to the database table to retrieve information pertaining to the users question. The API takes in only one SQL query at a time, sends the SQL statement and returns the query results from the table. This API should be called for each SQL query to a database table.", 45 | "operationId": "queryathena", 46 | "parameters": [ 47 | { 48 | "name": "query", 49 | "in": "query", 50 | "required": true, 51 | "schema": { 52 | "type": "string" 53 | }, 54 | "description": "SQL statement to query database table." 55 | } 56 | ], 57 | "responses": { 58 | "200": { 59 | "description": "Query sent successfully", 60 | "content": { 61 | "application/json": { 62 | "schema": { 63 | "type": "object", 64 | "properties": { 65 | "responseBody": { 66 | "type": "string", 67 | "description": "The query response from the database." 68 | } 69 | } 70 | } 71 | } 72 | } 73 | }, 74 | "400": { 75 | "description": "Bad request. One or more required fields are missing or invalid." 76 | } 77 | } 78 | } 79 | } 80 | } 81 | } -------------------------------------------------------------------------------- /blog_sample_agents/2-Sample-text2sql-agent/sample_text2sql_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sample Text2SQL Agent Walkthrough\n", 8 | "\n", 9 | "This notebook will walk users through setting up a Text2SQL Agent with [BirdSQL - Mini Dev Dataset](https://github.com/bird-bench/mini_dev) and evaluating it with [Bedrock Agent Evaluation Framework](https://github.com/aws-samples/amazon-bedrock-agent-evaluation-framework/tree/main)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Prerequisites\n", 17 | "\n", 18 | "This notebook assumes that you have gone through the notebook environment setup in the 0-Notebook-environment/ folder and have set up a Langfuse project" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "#### Ensure the latest version of boto3 is shown below\n", 26 | "\n", 27 | "##### If not then run through setup_environment.ipynb in the 0-Notebook-environment/ folder" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "!pip freeze | grep \"boto3\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "#### Load in environment variables to notebook" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# Retrieve import path\n", 53 | "%store -r IMPORTS_PATH\n", 54 | "\n", 55 | "# Retrieve account info\n", 56 | "%store -r account_id\n", 57 | "%store -r region\n", 58 | "\n", 59 | "# Retrieve model lists\n", 60 | "%store -r agent_foundation_model" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "#### Retrieve imports environment variable and bring libraries into notebook" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "%run $IMPORTS_PATH" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# Download BirdSQL - Mini Dev Dataset\n", 84 | "Note: This can take up to several minutes" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# Download .zip file to local directory\n", 94 | "\n", 95 | "!wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Transform and store data for Text2SQL agent" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "In order to run the Text2SQL agent, we will need to setup the Athena databases to make SQL queries against. The following script will:\n", 110 | "1. Unzip the downloaded folder\n", 111 | "2. Create S3 buckets\n", 112 | "3. Convert .sqlite files into individual .parquet files for each table\n", 113 | "4. Upload to the database s3 bucket\n", 114 | "5. Set up appropriate Athena permissions\n", 115 | "6. Create databases in Athena" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Store the variables in os\n", 125 | "\n", 126 | "random_suffix_1 = uuid.uuid4().hex[:6]\n", 127 | "base_bucket_name = f\"{'text2sql-agent'}-{account_id}-{random_suffix_1}\"\n", 128 | "random_suffix_2 = uuid.uuid4().hex[:6]\n", 129 | "athena_results_bucket_name = f\"{'text2sql-athena-results'}-{account_id}-{random_suffix_2}\"\n", 130 | "athena_database_name = 'california_schools'\n", 131 | "\n", 132 | "os.environ['REGION'] = region\n", 133 | "os.environ['BASE_BUCKET_NAME'] = base_bucket_name\n", 134 | "os.environ['ATHENA_RESULTS_BUCKET_NAME'] = athena_results_bucket_name\n", 135 | "os.environ['BASE_DIR'] = 'dev_databases'\n", 136 | "os.environ['DATABASE_NAME'] = athena_database_name\n", 137 | "\n", 138 | "%store base_bucket_name\n", 139 | "%store athena_results_bucket_name\n", 140 | "%store athena_database_name" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "%run data_prep.py" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "# Create Text2SQL Agent" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "agent_name = 'sample-text2sql-agent'\n", 166 | "agent_description = \"Text2SQL agent to run against Bird-SQL Mini-Dev benchmark dataset\"\n", 167 | "agent_instruction = \"\"\"\n", 168 | "You are an AI Agent specialized in generating SQL queries for Amazon Athena against Amazon S3 .parquet files. \n", 169 | "Your primary task is to interpret user queries, generate appropriate SQL queries, and provide the executed sql \n", 170 | "query as well as relevant answers based on the data. Follow these instructions carefully: 1. Before generating any \n", 171 | "SQL query, use the /getschema tool to familiarize yourself with the data structure. 2. When generating an SQL query: \n", 172 | "a. Write the query as a single line, removing all newline characters. b. Column names must be exactly as they appear \n", 173 | "in the schema, including spaces. Do not replace spaces with underscores. c. Always enclose column names that contain \n", 174 | "spaces in double quotes (\"). d. Be extra careful with column names containing special characters or spaces. \n", 175 | "3. Column name handling: a. Never modify column names. Use them exactly as they appear in the schema. \n", 176 | "b. If a column name contains spaces or special characters, always enclose it in double quotes (\"). \n", 177 | "c. Do not use underscores in place of spaces in column names. 4. Query output format: \n", 178 | "a. Always include the exact query that was run in your response. Start your response with \n", 179 | "\"Executed SQL Query:\" followed by the exact query that was run. b. Format the SQL query in a code block \n", 180 | "using three backticks (```). c. After the query, provide your explanation and analysis. \n", 181 | "5. When providing your response: a. Start with the executed SQL query as specified in step \n", 182 | "4. b. Double-check that all column names in your generated query match the schema exactly. \n", 183 | "c. Ask for clarifications from the user if required. 6. Error handling: a. \n", 184 | "If a query fails due to column name issues: - Review the schema and correct any mismatched column names. - \n", 185 | "Ensure all column names with spaces are enclosed in double quotes. - Regenerate the query with corrected column names. - \n", 186 | "Display both the failed query and the corrected query. b. Implement retry logic with up to 3 attempts for failed queries. \n", 187 | "Here are a few examples of generating SQL queries based on a question: \n", 188 | "Question: What is the highest eligible free rate for K-12 students in the schools in Alameda County? \n", 189 | "Executed SQL Query: \"SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' \n", 190 | "ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1\" Question: Please list the zip \n", 191 | "code of all the charter schools in Fresno County Office of Education. Executed SQL Query: \"SELECT T2.Zip FROM frpm \n", 192 | "AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T1.`District Name` = 'Fresno County Office of Education' \n", 193 | "AND T1.`Charter School (Y/N)` = 1\" Question: Consider the average difference between K-12 enrollment and 15-17 enrollment \n", 194 | "of schools that are locally funded, list the names and DOC type of schools which has a difference above this average. \n", 195 | "Executed SQL Query: \"SELECT T2.School, T2.DOC FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode \n", 196 | "WHERE T2.FundingType = 'Locally funded' AND (T1.`Enrollment (K-12)` - T1.`Enrollment (Ages 5-17)`) > \n", 197 | "(SELECT AVG(T3.`Enrollment (K-12)` - T3.`Enrollment (Ages 5-17)`) FROM frpm AS T3 INNER JOIN schools AS T4 ON T3.CDSCode = \n", 198 | "T4.CDSCode WHERE T4.FundingType = 'Locally funded')\"\n", 199 | "\"\"\"" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "agents = AgentsForAmazonBedrock()\n", 209 | "\n", 210 | "text2sql_agent = agents.create_agent(\n", 211 | " agent_name,\n", 212 | " agent_description,\n", 213 | " agent_instruction,\n", 214 | " agent_foundation_model,\n", 215 | " code_interpretation=False,\n", 216 | " verbose=False\n", 217 | ")\n", 218 | "\n", 219 | "text2sql_agent" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "text2sql_agent_id = text2sql_agent[0]\n", 229 | "text2sql_agent_arn = f\"arn:aws:bedrock:{region}:{account_id}:agent/{text2sql_agent_id}\"\n", 230 | "\n", 231 | "text2sql_agent_id, text2sql_agent_arn" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "api_schema_string = '''{\n", 241 | " \"openapi\": \"3.0.1\",\n", 242 | " \"info\": {\n", 243 | " \"title\": \"Database schema look up and query APIs\",\n", 244 | " \"version\": \"1.0.0\",\n", 245 | " \"description\": \"APIs for looking up database table schemas and making queries to database tables.\"\n", 246 | " },\n", 247 | " \"paths\": {\n", 248 | " \"/getschema\": {\n", 249 | " \"get\": {\n", 250 | " \"summary\": \"Get a list of all columns in the athena database\",\n", 251 | " \"description\": \"Get the list of all columns in the athena database table. Return all the column information in database table.\",\n", 252 | " \"operationId\": \"getschema\",\n", 253 | " \"responses\": {\n", 254 | " \"200\": {\n", 255 | " \"description\": \"Gets the list of table names and their schemas in the database\",\n", 256 | " \"content\": {\n", 257 | " \"application/json\": {\n", 258 | " \"schema\": {\n", 259 | " \"type\": \"array\",\n", 260 | " \"items\": {\n", 261 | " \"type\": \"object\",\n", 262 | " \"properties\": {\n", 263 | " \"Table\": {\n", 264 | " \"type\": \"string\",\n", 265 | " \"description\": \"The name of the table in the database.\"\n", 266 | " },\n", 267 | " \"Schema\": {\n", 268 | " \"type\": \"string\",\n", 269 | " \"description\": \"The schema of the table in the database. Contains all columns needed for making queries.\"\n", 270 | " }\n", 271 | " }\n", 272 | " }\n", 273 | " }\n", 274 | " }\n", 275 | " }\n", 276 | " }\n", 277 | " }\n", 278 | " }\n", 279 | " },\n", 280 | " \"/queryathena\": {\n", 281 | " \"get\": {\n", 282 | " \"summary\": \"API to send query to the athena database table\",\n", 283 | " \"description\": \"Send a query to the database table to retrieve information pertaining to the users question. The API takes in only one SQL query at a time, sends the SQL statement and returns the query results from the table. This API should be called for each SQL query to a database table.\",\n", 284 | " \"operationId\": \"queryathena\",\n", 285 | " \"parameters\": [\n", 286 | " {\n", 287 | " \"name\": \"query\",\n", 288 | " \"in\": \"query\",\n", 289 | " \"required\": true,\n", 290 | " \"schema\": {\n", 291 | " \"type\": \"string\"\n", 292 | " },\n", 293 | " \"description\": \"SQL statement to query database table.\"\n", 294 | " }\n", 295 | " ],\n", 296 | " \"responses\": {\n", 297 | " \"200\": {\n", 298 | " \"description\": \"Query sent successfully\",\n", 299 | " \"content\": {\n", 300 | " \"application/json\": {\n", 301 | " \"schema\": {\n", 302 | " \"type\": \"object\",\n", 303 | " \"properties\": {\n", 304 | " \"responseBody\": {\n", 305 | " \"type\": \"string\",\n", 306 | " \"description\": \"The query response from the database.\"\n", 307 | " }\n", 308 | " }\n", 309 | " }\n", 310 | " }\n", 311 | " }\n", 312 | " },\n", 313 | " \"400\": {\n", 314 | " \"description\": \"Bad request. One or more required fields are missing or invalid.\"\n", 315 | " }\n", 316 | " }\n", 317 | " }\n", 318 | " }\n", 319 | " }\n", 320 | " } '''" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "api_schema = {\"payload\": api_schema_string}" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "### Attach Lambda function and create ActionGroup" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "text2sql_lambda_function_name = \"text2sql\"\n", 346 | "text2sql_lambda_function_arn = f\"arn:aws:lambda:{region}:{account_id}:function:{text2sql_lambda_function_name}\"\n", 347 | "%store text2sql_lambda_function_name\n", 348 | "%store text2sql_lambda_function_arn" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "agents.add_action_group_with_lambda(\n", 358 | " agent_name=agent_name,\n", 359 | " lambda_function_name=text2sql_lambda_function_name,\n", 360 | " source_code_file=\"lambda_function.py\",\n", 361 | " agent_action_group_name=\"queryAthena\",\n", 362 | " agent_action_group_description=\"Action for getting the database schema and querying with Athena\",\n", 363 | " api_schema=api_schema,\n", 364 | " verbose=True\n", 365 | ")" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "# Add environment variable to lambda function\n", 375 | "\n", 376 | "# Create unique bucket name\n", 377 | "bucket_suffix = uuid.uuid4().hex[:6]\n", 378 | "birdsql_bucket_name = f\"birdsql-results-{bucket_suffix}\"\n", 379 | "\n", 380 | "s3_client = boto3.client('s3')\n", 381 | "s3 = boto3.resource('s3')\n", 382 | " \n", 383 | "if region == 'us-east-1':\n", 384 | " # For us-east-1, don't specify LocationConstraint\n", 385 | " s3_client.create_bucket(Bucket=birdsql_bucket_name)\n", 386 | " print(f\"Created query results bucket: {birdsql_bucket_name}\")\n", 387 | "else:\n", 388 | " s3_client.create_bucket(\n", 389 | " Bucket=birdsql_bucket_name,\n", 390 | " CreateBucketConfiguration={'LocationConstraint': region}\n", 391 | " )\n", 392 | " print(f\"Created query results bucket: {birdsql_bucket_name}\")\n", 393 | "\n", 394 | "# Update Lambda environment variables\n", 395 | "lambda_client = boto3.client('lambda')\n", 396 | "\n", 397 | "try:\n", 398 | " # Get current configuration\n", 399 | " response = lambda_client.get_function_configuration(FunctionName=text2sql_lambda_function_name)\n", 400 | " current_env = response.get('Environment', {}).get('Variables', {})\n", 401 | " \n", 402 | " # Add new environment variable\n", 403 | " current_env['BUCKET_NAME'] = birdsql_bucket_name\n", 404 | " \n", 405 | " # Update Lambda configuration\n", 406 | " lambda_client.update_function_configuration(\n", 407 | " FunctionName=text2sql_lambda_function_name,\n", 408 | " Environment={\n", 409 | " 'Variables': current_env\n", 410 | " }\n", 411 | " )\n", 412 | " print(f\"Added BUCKET_NAME environment variable to '{text2sql_lambda_function_name}' Lambda function\")\n", 413 | "except Exception as e:\n", 414 | " print(f\"Error updating Lambda: {str(e)}\")\n", 415 | "\n", 416 | "%store birdsql_bucket_name" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "### Add resource based policy to Lambda function to allow agent to invoke" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "lambda_client = boto3.client('lambda', region)\n", 433 | "\n", 434 | "# Define the resource policy statement\n", 435 | "policy_statement = {\n", 436 | " \"Sid\": \"AllowBedrockAgentAccess\",\n", 437 | " \"Effect\": \"Allow\",\n", 438 | " \"Principal\": {\n", 439 | " \"Service\": \"bedrock.amazonaws.com\"\n", 440 | " },\n", 441 | " \"Action\": \"lambda:InvokeFunction\",\n", 442 | " \"Resource\": text2sql_lambda_function_arn,\n", 443 | " \"Condition\": {\n", 444 | " \"ArnEquals\": {\n", 445 | " \"aws:SourceArn\": text2sql_agent_arn\n", 446 | " }\n", 447 | " }\n", 448 | "}\n", 449 | "\n", 450 | "try:\n", 451 | " # Get the current policy\n", 452 | " response = lambda_client.get_policy(FunctionName=text2sql_lambda_function_arn)\n", 453 | " current_policy = json.loads(response['Policy'])\n", 454 | " \n", 455 | " # Add the new statement to the existing policy\n", 456 | " current_policy['Statement'].append(policy_statement)\n", 457 | " \n", 458 | "except lambda_client.exceptions.ResourceNotFoundException:\n", 459 | " # If there's no existing policy, create a new one\n", 460 | " current_policy = {\n", 461 | " \"Version\": \"2012-10-17\",\n", 462 | " \"Statement\": [policy_statement]\n", 463 | " }\n", 464 | "\n", 465 | "# Convert the policy to JSON string\n", 466 | "updated_policy = json.dumps(current_policy)\n", 467 | "\n", 468 | "# Add or update the resource policy\n", 469 | "response = lambda_client.add_permission(\n", 470 | " FunctionName=text2sql_lambda_function_arn,\n", 471 | " StatementId=\"AllowText2SQLAgentAccess\",\n", 472 | " Action=\"lambda:InvokeFunction\",\n", 473 | " Principal=\"bedrock.amazonaws.com\",\n", 474 | " SourceArn=text2sql_agent_arn\n", 475 | ")\n", 476 | "\n", 477 | "print(\"Resource policy added successfully.\")\n", 478 | "print(\"Response:\", response)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "### Add permissions to Lambda function execution role" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "# Create clients\n", 495 | "iam_client = boto3.client('iam')\n", 496 | "lambda_client = boto3.client('lambda', region)\n", 497 | "\n", 498 | "# Get the function configuration\n", 499 | "response = lambda_client.get_function_configuration(FunctionName=text2sql_lambda_function_name)\n", 500 | "role_arn = response['Role']\n", 501 | "role_name = role_arn.split('/')[-1]\n", 502 | "\n", 503 | "# Policy ARNs to attach\n", 504 | "policy_arns = [\n", 505 | " 'arn:aws:iam::aws:policy/AmazonAthenaFullAccess',\n", 506 | " 'arn:aws:iam::aws:policy/AmazonS3FullAccess'\n", 507 | "]\n", 508 | "\n", 509 | "# Attach each policy\n", 510 | "for policy_arn in policy_arns:\n", 511 | " try:\n", 512 | " iam_client.attach_role_policy(\n", 513 | " RoleName=role_name,\n", 514 | " PolicyArn=policy_arn\n", 515 | " )\n", 516 | " print(f\"Successfully attached {policy_arn} to role {role_name}\")\n", 517 | " except Exception as e:\n", 518 | " print(f\"Error attaching {policy_arn}: {str(e)}\")\n", 519 | "\n", 520 | "# Verify attached policies\n", 521 | "try:\n", 522 | " response = iam_client.list_attached_role_policies(RoleName=role_name)\n", 523 | " print(\"\\nAttached policies:\")\n", 524 | " for policy in response['AttachedPolicies']:\n", 525 | " print(f\"- {policy['PolicyName']}\")\n", 526 | "except Exception as e:\n", 527 | " print(f\"Error listing policies: {str(e)}\")" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "### Invoke Text2SQL Agent Test Alias to see that it answers question properly" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [ 543 | "%%time\n", 544 | "\n", 545 | "bedrock_agent_runtime_client = boto3.client(\"bedrock-agent-runtime\", region)\n", 546 | "\n", 547 | "session_id:str = str(uuid.uuid1())\n", 548 | "\n", 549 | "query = \"What is the highest eligible free rate for K-12 students in the schools in Alameda County?\"\n", 550 | "response = bedrock_agent_runtime_client.invoke_agent(\n", 551 | " inputText=query,\n", 552 | " agentId=text2sql_agent_id,\n", 553 | " agentAliasId=\"TSTALIASID\", \n", 554 | " sessionId=session_id,\n", 555 | " enableTrace=True, \n", 556 | " endSession=False,\n", 557 | " sessionState={}\n", 558 | ")\n", 559 | "\n", 560 | "print(\"Request sent to Agent\")\n", 561 | "print(\"====================\")\n", 562 | "print(\"Agent processing query now\")\n", 563 | "print(\"====================\")\n", 564 | "\n", 565 | "# Initialize an empty string to store the answer\n", 566 | "answer = \"\"\n", 567 | "\n", 568 | "# Iterate through the event stream\n", 569 | "for event in response['completion']:\n", 570 | " # Check if the event is a 'chunk' event\n", 571 | " if 'chunk' in event:\n", 572 | " chunk_obj = event['chunk']\n", 573 | " if 'bytes' in chunk_obj:\n", 574 | " # Decode the bytes and append to the answer\n", 575 | " chunk_data = chunk_obj['bytes'].decode('utf-8')\n", 576 | " answer += chunk_data\n", 577 | "\n", 578 | "# Now 'answer' contains the full response from the agent\n", 579 | "print(\"Agent Answer: {}\".format(answer))\n", 580 | "print(\"====================\")" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "### Now that agent has been tested, prepare it by creating an alias for use with evaluation framework" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "metadata": {}, 594 | "outputs": [], 595 | "source": [ 596 | "text2sql_agent_alias_id, text2sql_agent_alias_arn = agents.create_agent_alias(\n", 597 | " text2sql_agent[0], 'v1'\n", 598 | ")\n", 599 | "\n", 600 | "text2sql_agent_alias_id, text2sql_agent_alias_arn" 601 | ] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "metadata": {}, 606 | "source": [ 607 | "### Create input file for evaluation framework to evaluate agent's Text2SQL capabilities" 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "metadata": {}, 613 | "source": [ 614 | "#### Below is the option to specify the number of questions to generate. \n", 615 | "\n", 616 | "#### Default is 10, set to -1 to run through all questions, or specify to any other desired number" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": null, 622 | "metadata": {}, 623 | "outputs": [], 624 | "source": [ 625 | "num_questions = 10" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": null, 631 | "metadata": {}, 632 | "outputs": [], 633 | "source": [ 634 | "ATHENA_RESULTS_BUCKET_NAME = os.environ.get('ATHENA_RESULTS_BUCKET_NAME')\n", 635 | "DATABASE_NAME = os.environ.get('DATABASE_NAME')\n", 636 | "\n", 637 | "ATHENA_RESULTS_BUCKET_NAME, DATABASE_NAME" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "metadata": {}, 643 | "source": [ 644 | "### Generate ground truth information for each question\n", 645 | "Please wait for all information to be generated before proceeding to next step\n", 646 | "\n", 647 | "Note: This script uses an LLM to generate a ground truth SQL query so not all questions will be able to have a ground truth, in that case the question is skipped" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": null, 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "import boto3\n", 657 | "import time\n", 658 | "import pandas as pd\n", 659 | "import json\n", 660 | "from typing import Dict, List, Any\n", 661 | "\n", 662 | "def run_query(query: str, athena_client):\n", 663 | " \"\"\"\n", 664 | " Run Athena query and return results as pandas DataFrame\n", 665 | " \"\"\"\n", 666 | " \n", 667 | " try:\n", 668 | " response = athena_client.start_query_execution(\n", 669 | " QueryString=query,\n", 670 | " QueryExecutionContext={\n", 671 | " 'Database': DATABASE_NAME\n", 672 | " },\n", 673 | " ResultConfiguration={\n", 674 | " 'OutputLocation': f's3://{ATHENA_RESULTS_BUCKET_NAME}/athena-results/'\n", 675 | " }\n", 676 | " )\n", 677 | " \n", 678 | " query_execution_id = response['QueryExecutionId']\n", 679 | " \n", 680 | " while True:\n", 681 | " response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)\n", 682 | " state = response['QueryExecution']['Status']['State']\n", 683 | " \n", 684 | " if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:\n", 685 | " break\n", 686 | " \n", 687 | " time.sleep(1)\n", 688 | " \n", 689 | " if state == 'SUCCEEDED':\n", 690 | " response = athena_client.get_query_results(QueryExecutionId=query_execution_id)\n", 691 | " # Process data\n", 692 | " results = []\n", 693 | " for row in response['ResultSet']['Rows'][1:]: # Skip header row\n", 694 | " for field in row['Data']:\n", 695 | " value = field.get('VarCharValue', '')\n", 696 | " \n", 697 | " results.append(str(value))\n", 698 | " \n", 699 | " # Join results with commas\n", 700 | " return ', '.join(results)\n", 701 | " else:\n", 702 | " print(f\"Generated query failed with state: {state}\")\n", 703 | " return None\n", 704 | " \n", 705 | " except Exception as e:\n", 706 | " print(f\"Error running query: {e}\")\n", 707 | " return None\n", 708 | "\n", 709 | "def generate_ground_truth_answer(question_id: int, question: str, sql_query: str, \n", 710 | " sql_context: str, query_results: str) -> Dict:\n", 711 | " \"\"\"\n", 712 | " Generate ground truth answer using AWS Bedrock based on question and query results\n", 713 | " \"\"\"\n", 714 | " bedrock_runtime = boto3.client(\n", 715 | " service_name='bedrock-runtime'\n", 716 | " )\n", 717 | " \n", 718 | " # Construct prompt for Bedrock\n", 719 | " prompt = f\"\"\"You are generating ground truth answers that will be used to evaluate the factual correctness of Text2SQL agent responses.\n", 720 | "\n", 721 | "Question: {question}\n", 722 | "Query Results: {query_results}\n", 723 | "\n", 724 | "Generate a natural language answer that:\n", 725 | "1. States all numerical values and facts from the query results explicitly\n", 726 | "2. Uses consistent formatting for numbers (maintain exact precision from results)\n", 727 | "3. Includes all relevant values if multiple results are returned\n", 728 | "4. States the answer in a clear, declarative way that directly addresses the question\n", 729 | "5. Avoids additional interpretations or information not present in the query results\n", 730 | "\n", 731 | "Remember:\n", 732 | "- Focus only on the facts present in the query results\n", 733 | "- Use the exact numbers shown in the results\n", 734 | "- Structure the answer to make fact-checking straightforward\n", 735 | "- Be explicit about any percentages, counts, or measurements\n", 736 | "- Make sure every number in the query results is mentioned in your answer\n", 737 | "\n", 738 | "Your answer should be easy to compare with other responses for factual accuracy.\"\"\"\n", 739 | "\n", 740 | " # Create request body for Claude model\n", 741 | " body = json.dumps({\n", 742 | " \"anthropic_version\": \"bedrock-2023-05-31\",\n", 743 | " \"max_tokens\": 512,\n", 744 | " \"temperature\": 0.5,\n", 745 | " \"messages\": [\n", 746 | " {\n", 747 | " \"role\": \"user\",\n", 748 | " \"content\": [{\"type\": \"text\", \"text\": prompt}],\n", 749 | " }\n", 750 | " ],\n", 751 | " })\n", 752 | "\n", 753 | " try:\n", 754 | " # Call Bedrock\n", 755 | "\n", 756 | " response = None\n", 757 | "\n", 758 | " # Use cross-region inference profile\n", 759 | " response = bedrock_runtime.invoke_model(\n", 760 | " modelId='us.anthropic.claude-3-5-sonnet-20241022-v2:0', # or your preferred model\n", 761 | " body=body\n", 762 | " )\n", 763 | " \n", 764 | " # Parse response\n", 765 | " response_body = json.loads(response['body'].read())\n", 766 | " answer = response_body['content'][0]['text']\n", 767 | " \n", 768 | " # Format the response in the required structure\n", 769 | " formatted_response = {\n", 770 | " \"question_id\": question_id,\n", 771 | " \"question\": question,\n", 772 | " \"question_type\": \"TEXT2SQL\",\n", 773 | " \"ground_truth\": {\n", 774 | " \"ground_truth_sql_query\": sql_query,\n", 775 | " \"ground_truth_sql_context\": sql_context,\n", 776 | " \"ground_truth_query_result\": query_results,\n", 777 | " \"ground_truth_answer\": answer\n", 778 | " }\n", 779 | " }\n", 780 | " \n", 781 | " return formatted_response\n", 782 | " \n", 783 | " except Exception as e:\n", 784 | " print(f\"Error generating answer: {e}\")\n", 785 | " return None\n", 786 | "\n", 787 | " \n", 788 | "def get_schema(athena_client):\n", 789 | " \"\"\"\n", 790 | " Get schema information for all tables in Athena databases\n", 791 | " \"\"\"\n", 792 | "\n", 793 | " sql = f\"\"\"\n", 794 | " SELECT\n", 795 | " table_name,\n", 796 | " column_name,\n", 797 | " data_type\n", 798 | " FROM information_schema.columns\n", 799 | " WHERE table_schema = '{DATABASE_NAME}'\n", 800 | " ORDER BY table_name, ordinal_position;\n", 801 | " \"\"\"\n", 802 | " \n", 803 | " try:\n", 804 | " # Start query execution\n", 805 | " response = athena_client.start_query_execution(\n", 806 | " QueryString=sql,\n", 807 | " QueryExecutionContext={\n", 808 | " 'Database': DATABASE_NAME\n", 809 | " }\n", 810 | " )\n", 811 | " \n", 812 | " query_execution_id = response['QueryExecutionId']\n", 813 | " \n", 814 | " def wait_for_query_completion(query_execution_id):\n", 815 | " while True:\n", 816 | " response = athena_client.get_query_execution(\n", 817 | " QueryExecutionId=query_execution_id\n", 818 | " )\n", 819 | " state = response['QueryExecution']['Status']['State']\n", 820 | " \n", 821 | " if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:\n", 822 | " return state\n", 823 | " \n", 824 | " time.sleep(2)\n", 825 | " \n", 826 | " # Wait for query completion\n", 827 | " state = wait_for_query_completion(query_execution_id)\n", 828 | "\n", 829 | " if state == 'SUCCEEDED':\n", 830 | " # Get query results\n", 831 | " results = athena_client.get_query_results(\n", 832 | " QueryExecutionId=query_execution_id\n", 833 | " )\n", 834 | " # Assuming you have a database connection and cursor setup\n", 835 | " # cursor.execute(sql)\n", 836 | " # results = cursor.fetchall()\n", 837 | " \n", 838 | " database_structure = []\n", 839 | " table_dict = {}\n", 840 | "\n", 841 | " # Skip the header row\n", 842 | " rows = results['ResultSet']['Rows'][1:]\n", 843 | "\n", 844 | " for row in rows:\n", 845 | " # Extract values from the Data structure\n", 846 | " table_name = row['Data'][0]['VarCharValue']\n", 847 | " column_name = row['Data'][1]['VarCharValue']\n", 848 | " data_type = row['Data'][2]['VarCharValue']\n", 849 | " \n", 850 | " # Initialize table if not exists\n", 851 | " if table_name not in table_dict:\n", 852 | " table_dict[table_name] = []\n", 853 | " \n", 854 | " # Append column information\n", 855 | " table_dict[table_name].append((column_name, data_type))\n", 856 | "\n", 857 | " # Convert to the desired format\n", 858 | " for table_name, columns in table_dict.items():\n", 859 | " database_structure.append({\n", 860 | " \"table_name\": table_name,\n", 861 | " \"columns\": columns\n", 862 | " })\n", 863 | "\n", 864 | " return database_structure\n", 865 | "\n", 866 | " else:\n", 867 | " raise Exception(f\"Query to get schema failed with state: {state}\")\n", 868 | " except Exception as e:\n", 869 | " print(f\"Error getting schema: {e}\")\n", 870 | " raise\n", 871 | "\n", 872 | "def generate_dataset(input_file: str, output_file: str, athena_client, num_questions):\n", 873 | " \"\"\"\n", 874 | " Generate dataset with ground truth answers in trajectory format\n", 875 | " \"\"\"\n", 876 | " try:\n", 877 | " # Read input file\n", 878 | " with open(input_file, 'r') as f:\n", 879 | " questions_data = json.load(f)\n", 880 | " \n", 881 | " # Initialize trajectories dictionary\n", 882 | " trajectories = {}\n", 883 | "\n", 884 | " # Keep track of number of trajectories created\n", 885 | " num_trajectories = 0\n", 886 | " \n", 887 | " # Process each question\n", 888 | " for idx, item in enumerate(questions_data):\n", 889 | "\n", 890 | " if num_questions != -1:\n", 891 | " if num_trajectories == num_questions:\n", 892 | " break\n", 893 | " \n", 894 | " question_id = item.get('question_id', 0)\n", 895 | " question = item['question']\n", 896 | " sql_query = item['SQL']\n", 897 | " \n", 898 | " print(f\"\\nProcessing question {question_id}: {question}\")\n", 899 | " \n", 900 | " # Get table schema\n", 901 | " sql_context = get_schema(athena_client)\n", 902 | " # Run query\n", 903 | " query_results = run_query(sql_query.replace('`','\"'), athena_client)\n", 904 | " if query_results is not None:\n", 905 | " # Generate answer with formatted response\n", 906 | " response = generate_ground_truth_answer(\n", 907 | " question_id=question_id,\n", 908 | " question=question,\n", 909 | " sql_query=sql_query,\n", 910 | " sql_context=str(sql_context),\n", 911 | " query_results=query_results\n", 912 | " )\n", 913 | " \n", 914 | " if response:\n", 915 | " # Increment number of trajectories\n", 916 | " num_trajectories += 1\n", 917 | " \n", 918 | " # Create trajectory key\n", 919 | " trajectory_key = f\"Trajectory{num_trajectories}\"\n", 920 | " \n", 921 | " # Format the response for this trajectory\n", 922 | " trajectory_response = [response]\n", 923 | " \n", 924 | " # Add to trajectories dictionary\n", 925 | " trajectories[trajectory_key] = trajectory_response\n", 926 | " print(f\"Generated ground truth for question {question_id}\")\n", 927 | " else:\n", 928 | " # Don't increment number of trajectories\n", 929 | " print(\"Ground truth unable to be generated for this question, skipping\")\n", 930 | " continue\n", 931 | " \n", 932 | " # Write results to output file\n", 933 | " with open(output_file, 'w') as f:\n", 934 | " json.dump(trajectories, f, indent=2)\n", 935 | " \n", 936 | " print(f\"\\nProcessed {len(trajectories)} questions. Results saved to {output_file}\")\n", 937 | " \n", 938 | " except Exception as e:\n", 939 | " print(f\"Error generating dataset: {e}\")\n", 940 | "\n", 941 | "INPUT_FILE = \"birdsql_data.json\"\n", 942 | "OUTPUT_FILE = \"text2sql_data_file_auto.json\"\n", 943 | "\n", 944 | "athena_client = boto3.client('athena')\n", 945 | "\n", 946 | "generate_dataset(INPUT_FILE, OUTPUT_FILE,athena_client, num_questions)" 947 | ] 948 | }, 949 | { 950 | "cell_type": "markdown", 951 | "metadata": {}, 952 | "source": [ 953 | "## Create config.env that evaluation tool needs\n", 954 | "Note: Input Langfuse host and keys into the variables below" 955 | ] 956 | }, 957 | { 958 | "cell_type": "code", 959 | "execution_count": null, 960 | "metadata": {}, 961 | "outputs": [], 962 | "source": [ 963 | "user_input = f\"\"\"\n", 964 | "\n", 965 | "AGENT_ID=\"{text2sql_agent_id}\"\n", 966 | "AGENT_ALIAS_ID=\"{text2sql_agent_alias_id}\"\n", 967 | "\n", 968 | "DATA_FILE_PATH=\"blog_sample_agents/2-Sample-text2sql-agent/text2sql_data_file_auto.json\"\n", 969 | "\n", 970 | "LANGFUSE_PUBLIC_KEY=\"FILL_IN\"\n", 971 | "LANGFUSE_SECRET_KEY=\"FILL_IN\"\n", 972 | "LANGFUSE_HOST=\"FILL_IN\"\n", 973 | "\n", 974 | "\"\"\"" 975 | ] 976 | }, 977 | { 978 | "cell_type": "code", 979 | "execution_count": null, 980 | "metadata": {}, 981 | "outputs": [], 982 | "source": [ 983 | "import os\n", 984 | "from string import Template\n", 985 | "\n", 986 | "# Set the correct paths relative to current location\n", 987 | "base_dir = os.path.dirname(os.path.dirname(os.getcwd())) # Go up two levels\n", 988 | "template_file_path = os.path.join(base_dir, 'config.env.tpl')\n", 989 | "config_file_path = os.path.join(base_dir, 'config.env')\n", 990 | "\n", 991 | "# Read the template file from the Bedrock Agent Evaluation Framework\n", 992 | "with open(template_file_path, 'r') as template_file:\n", 993 | " template_content = template_file.read()\n", 994 | "\n", 995 | "\n", 996 | "# Convert template content and user input into dictionaries\n", 997 | "def parse_env_content(content):\n", 998 | " env_dict = {}\n", 999 | " for line in content.split('\\n'):\n", 1000 | " line = line.strip()\n", 1001 | " if line and not line.startswith('#'):\n", 1002 | " if '=' in line:\n", 1003 | " key, value = line.split('=', 1)\n", 1004 | " env_dict[key.strip()] = value.strip()\n", 1005 | " return env_dict\n", 1006 | "\n", 1007 | "template_dict = parse_env_content(template_content)\n", 1008 | "user_dict = parse_env_content(user_input)\n", 1009 | "\n", 1010 | "# Merge dictionaries, with user input taking precedence\n", 1011 | "final_dict = {**template_dict, **user_dict}\n", 1012 | "\n", 1013 | "# Create the config.env content\n", 1014 | "config_content = \"\"\n", 1015 | "for key, value in final_dict.items():\n", 1016 | " config_content += f\"{key}={value}\\n\"\n", 1017 | "\n", 1018 | "# Write to config.env file in the correct folder\n", 1019 | "with open(config_file_path, 'w') as config_file:\n", 1020 | " config_file.write(config_content)\n", 1021 | "\n", 1022 | "print(f\"config.env file has been created successfully in amazon-bedrock-agent-evaluation-framework!\")" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "markdown", 1027 | "metadata": {}, 1028 | "source": [ 1029 | "## Run [Bedrock Agent Evaluation Framework](https://github.com/aws-samples/amazon-bedrock-agent-evaluation-framework) to get results on the Text2SQL Agent!\n", 1030 | "Note: For some questions, the Text2SQL agent may not be able to generate an executable query, in that case an error trace will show in Langfuse" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "markdown", 1035 | "metadata": {}, 1036 | "source": [ 1037 | "![Langfuse Dashboard](../img/ts_langfuse_dashboard.png)" 1038 | ] 1039 | }, 1040 | { 1041 | "cell_type": "markdown", 1042 | "metadata": {}, 1043 | "source": [ 1044 | "![Trace Dashboard](../img/ts_trace_dashboard.png)" 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "markdown", 1049 | "metadata": {}, 1050 | "source": [ 1051 | "![Trace](../img/ts_trace.png)" 1052 | ] 1053 | }, 1054 | { 1055 | "cell_type": "code", 1056 | "execution_count": null, 1057 | "metadata": {}, 1058 | "outputs": [], 1059 | "source": [ 1060 | "# Execute bash script to run evaluation\n", 1061 | "!cd .. && chmod +x execute_eval.sh && ./execute_eval.sh" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": {}, 1067 | "source": [ 1068 | "### Navigate to your Langfuse host address, open the relevant Langfuse project, and view the traces populated there during evaluation run" 1069 | ] 1070 | } 1071 | ], 1072 | "metadata": { 1073 | "kernelspec": { 1074 | "display_name": "conda_python3", 1075 | "language": "python", 1076 | "name": "conda_python3" 1077 | }, 1078 | "language_info": { 1079 | "codemirror_mode": { 1080 | "name": "ipython", 1081 | "version": 3 1082 | }, 1083 | "file_extension": ".py", 1084 | "mimetype": "text/x-python", 1085 | "name": "python", 1086 | "nbconvert_exporter": "python", 1087 | "pygments_lexer": "ipython3", 1088 | "version": "3.10.16" 1089 | } 1090 | }, 1091 | "nbformat": 4, 1092 | "nbformat_minor": 4 1093 | } 1094 | -------------------------------------------------------------------------------- /blog_sample_agents/3-Cleanup/cleanup_rag_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5353afe6-a969-4878-97a3-213a6206fb6c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Sample RAG Agent Cleanup\n", 9 | "In this section we clean up the resources created for the sample RAG agent" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "b6dfeab2-3f2e-44cd-bd2a-4785312b47c7", 15 | "metadata": {}, 16 | "source": [ 17 | "#### Ensure the latest version of boto3 is shown below" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "5ce4f3a3-2b73-4bbc-a2ba-77fcdb87262a", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!pip freeze | grep boto3" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "608e1f20-13a0-4a1f-9f8d-a82cc7b2aa4c", 33 | "metadata": {}, 34 | "source": [ 35 | "#### Load in environment variables to notebook" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "3ae0ce5f-9bc5-44da-8348-5d6ba3ac0960", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Retrieve import path\n", 46 | "%store -r IMPORTS_PATH\n", 47 | "\n", 48 | "# Retrieve account info\n", 49 | "%store -r region\n", 50 | "\n", 51 | "# Retrive relevant resources\n", 52 | "%store -r wiki_bucket_name\n", 53 | "%store -r wiki_kb_id" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "97a50653-c04e-416f-a316-321e91357ccd", 59 | "metadata": {}, 60 | "source": [ 61 | "#### Retrieve imports environment variable and bring libraries into notebook" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "a48fec72-d0b9-460a-b5d1-a462f5849c86", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "%run $IMPORTS_PATH" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "cb3bcbc6-c01a-41e3-877d-32e44a49921b", 77 | "metadata": {}, 78 | "source": [ 79 | "#### Define Clients" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "5b334a10-51bb-4328-a05d-f0b490d05fb2", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "agents = AgentsForAmazonBedrock()\n", 90 | "s3_client = boto3.client('s3', region)\n", 91 | "bedrock_agent_client = boto3.client('bedrock-agent', region)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "76cd89f7-0fce-415c-b187-2fdea80eaabc", 97 | "metadata": {}, 98 | "source": [ 99 | "#### Destroy Sample RAG Agent" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "50b4e188-6b3e-498a-aad9-cd205b1756e8", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "agents.delete_agent(\"sample-rag-agent\",delete_role_flag=True)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "75cd4ee6-5b38-4196-8ec3-4566d6e8dcf1", 115 | "metadata": {}, 116 | "source": [ 117 | "#### Destroy S3 bucket" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "d14ce6a2-5bf4-49ef-b365-cdf526f90b94", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | " # First, delete all objects and versions in the bucket\n", 128 | "s3_resource = boto3.resource('s3')\n", 129 | "bucket = s3_resource.Bucket(wiki_bucket_name)\n", 130 | "\n", 131 | "# Delete all objects and their versions\n", 132 | "bucket.objects.all().delete()\n", 133 | "bucket.object_versions.all().delete()\n", 134 | "\n", 135 | "# Now delete the empty bucket\n", 136 | "s3_client.delete_bucket(Bucket=wiki_bucket_name)\n", 137 | "print(f\"Bucket {wiki_bucket_name} has been successfully deleted\")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "d30277c6-3306-41c6-9dc5-b245ba485b9e", 143 | "metadata": {}, 144 | "source": [ 145 | "#### Destroy Knowledge Base" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "d8138843-a45c-4142-931b-db8113f11139", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=wiki_kb_id)\n", 156 | "print(f\"Knowledge base {wiki_kb_id} has been successfully deleted\")" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "conda_python3", 163 | "language": "python", 164 | "name": "conda_python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.10.16" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 5 181 | } 182 | -------------------------------------------------------------------------------- /blog_sample_agents/3-Cleanup/cleanup_text2sql_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5353afe6-a969-4878-97a3-213a6206fb6c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Sample Text2SQL Agent Cleanup\n", 9 | "In this section we clean up the resources created for the sample Text2SQL agent" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "b6dfeab2-3f2e-44cd-bd2a-4785312b47c7", 15 | "metadata": {}, 16 | "source": [ 17 | "#### Ensure the latest version of boto3 is shown below" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "5ce4f3a3-2b73-4bbc-a2ba-77fcdb87262a", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!pip freeze | grep boto3" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "608e1f20-13a0-4a1f-9f8d-a82cc7b2aa4c", 33 | "metadata": {}, 34 | "source": [ 35 | "#### Load in environment variables to notebook" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "3ae0ce5f-9bc5-44da-8348-5d6ba3ac0960", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Retrieve import path\n", 46 | "%store -r IMPORTS_PATH\n", 47 | "\n", 48 | "# Retrieve account info\n", 49 | "%store -r region\n", 50 | "\n", 51 | "# Retrive relevant resources\n", 52 | "%store -r base_bucket_name\n", 53 | "%store -r athena_results_bucket_name\n", 54 | "%store -r athena_database_name\n", 55 | "%store -r text2sql_lambda_function_name" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "97a50653-c04e-416f-a316-321e91357ccd", 61 | "metadata": {}, 62 | "source": [ 63 | "#### Retrieve imports environment variable and bring libraries into notebook" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "a48fec72-d0b9-460a-b5d1-a462f5849c86", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "%run $IMPORTS_PATH" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "cb3bcbc6-c01a-41e3-877d-32e44a49921b", 79 | "metadata": {}, 80 | "source": [ 81 | "#### Define Clients" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "5b334a10-51bb-4328-a05d-f0b490d05fb2", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "agents = AgentsForAmazonBedrock()\n", 92 | "s3_client = boto3.client('s3', region)\n", 93 | "lambda_client = boto3.client('lambda', region)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "76cd89f7-0fce-415c-b187-2fdea80eaabc", 99 | "metadata": {}, 100 | "source": [ 101 | "#### Destroy Sample Text2SQL Agent" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "50b4e188-6b3e-498a-aad9-cd205b1756e8", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "agents.delete_agent(\"sample-text2sql-agent\",delete_role_flag=True)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "97cb4b21-b689-4618-b5b1-2b4bccb53888", 117 | "metadata": {}, 118 | "source": [ 119 | "#### Destroy Athena database" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "fd871e69-4fd0-438f-b2d6-1272754b6c82", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def execute_query(query, database):\n", 130 | " athena_client = boto3.client('athena')\n", 131 | " return athena_client.start_query_execution(\n", 132 | " QueryString=query,\n", 133 | " QueryExecutionContext={'Database': database},\n", 134 | " ResultConfiguration={'OutputLocation': f's3://{athena_results_bucket_name}/'}\n", 135 | " )\n", 136 | "\n", 137 | "def delete_all_tables_and_database():\n", 138 | " athena_client = boto3.client('athena')\n", 139 | "\n", 140 | " # Get all tables\n", 141 | " tables = [t['Name'] for t in athena_client.list_table_metadata(\n", 142 | " CatalogName='AwsDataCatalog', # Add this parameter\n", 143 | " DatabaseName=athena_database_name\n", 144 | " )['TableMetadataList']]\n", 145 | "\n", 146 | " # Drop all tables\n", 147 | " for table in tables:\n", 148 | " execute_query(f\"DROP TABLE IF EXISTS `{table}`\", athena_database_name)\n", 149 | " print(f\"Dropped table: {table}\")\n", 150 | "\n", 151 | " # Drop the database\n", 152 | " execute_query(f\"DROP DATABASE IF EXISTS `{athena_database_name}`\", 'default')\n", 153 | " print(f\"Dropped database: {athena_database_name}\")\n", 154 | "\n", 155 | "delete_all_tables_and_database()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "id": "75cd4ee6-5b38-4196-8ec3-4566d6e8dcf1", 161 | "metadata": {}, 162 | "source": [ 163 | "#### Destroy S3 buckets" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "d14ce6a2-5bf4-49ef-b365-cdf526f90b94", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "s3_resource = boto3.resource('s3')\n", 174 | "\n", 175 | "# Delete contents of base bucket\n", 176 | "bucket1 = s3_resource.Bucket(base_bucket_name)\n", 177 | "bucket1.objects.all().delete()\n", 178 | "bucket1.object_versions.all().delete()\n", 179 | "\n", 180 | "# Delete contents of Athena results bucket\n", 181 | "bucket2 = s3_resource.Bucket(athena_results_bucket_name)\n", 182 | "bucket2.objects.all().delete()\n", 183 | "bucket2.object_versions.all().delete()\n", 184 | "\n", 185 | "# Delete the empty buckets\n", 186 | "s3_client = boto3.client('s3')\n", 187 | "s3_client.delete_bucket(Bucket=base_bucket_name)\n", 188 | "s3_client.delete_bucket(Bucket=athena_results_bucket_name)\n", 189 | "\n", 190 | "print(f\"Bucket {base_bucket_name} has been successfully deleted\")\n", 191 | "print(f\"Bucket {athena_results_bucket_name} has been successfully deleted\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "e5ed5a54-add4-480c-8c35-820e323490ad", 197 | "metadata": {}, 198 | "source": [ 199 | "#### Destroy Lambda function" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "0ea468e1-dda9-4c8c-89ee-958e009f0980", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "lambda_client.delete_function(FunctionName=text2sql_lambda_function_name)\n", 210 | "print(f\"Successfully deleted Lambda function: {text2sql_lambda_function_name}\")" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "conda_python3", 217 | "language": "python", 218 | "name": "conda_python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.10.16" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 5 235 | } 236 | -------------------------------------------------------------------------------- /blog_sample_agents/README.MD: -------------------------------------------------------------------------------- 1 | # Deploying Sample Agents for Evaluation 2 | 3 | You can choose one or both of the RAG and Text2SQL sample agent to try out evaluations 4 | 5 | ## Deployment Steps 6 | 7 | 1. Set up a Langfuse project using either the cloud version www.langfuse.com or the AWS self-hosted option https://github.com/aws-samples/deploy-langfuse-on-ecs-with-fargate/tree/main/langfuse-v3 8 | 9 | 2. If you are using the self-hosted option and want to see model costs then you must create a model definition in Langfuse for "us.anthropic.claude-3-5-sonnet-20241022-v2:0", instructions can be found here https://langfuse.com/docs/model-usage-and-cost#custom-model-definitions 10 | 11 | 3. Create a SageMaker notebook instance in your AWS account 12 | 13 | 4. Open a terminal and navigate to the SageMaker/ folder within the instance 14 | ```bash 15 | cd SageMaker/ 16 | ``` 17 | 18 | 5. Clone this repository 19 | ```bash 20 | git clone https://github.com/aws-samples/amazon-bedrock-agent-evaluation-framework 21 | ``` 22 | 23 | 6. Navigate to the repository and install the necessary requirements 24 | ```bash 25 | cd amazon-bedrock-agent-evaluation-framework/ 26 | pip3 install -r requirements.txt 27 | ``` 28 | 29 | 7. Go to the blog_sample_agents/ folder and navigate to 0-Notebook-environment/setup-environment.ipynb to set up your Jupyter environment 30 | 31 | 8. Choose the conda_python3 kernel for the SageMaker notebook 32 | 33 | 9. Follow the respective agent notebooks to deploy each agent and evaluate it with a benchmark dataset! 34 | 35 | 36 | ## RAG / Text2SQL Agent Setup 37 | 38 | 1. Run through the RAG/Text2SQL notebook to create the agents in your AWS account 39 | (WARNING: DUE TO NATURE OF SQL QUERIES OPTIMIZED FOR DIFFERENT DATABASE ENGINES, SOME MORE COMPLEX TEXT2SQL SAMPLE QUESTIONS MAY EITHER NOT WORK OR HAVE A LOW EVALUATION SCORE) 40 | 2. Check the langfuse console for traces and evaluation metrics (Refer to the 'Navigating the Langfuse Console' section in the root readme) -------------------------------------------------------------------------------- /blog_sample_agents/execute_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/.bashrc 3 | 4 | echo "Starting script execution..." 5 | echo "--------------------------------------" 6 | 7 | echo "Loading bash configuration..." 8 | source ~/.bashrc 9 | echo "Bash configuration loaded" 10 | echo "--------------------------------------" 11 | 12 | echo "Changing to evaluation framework root directory..." 13 | cd .. 14 | echo "Directory changed successfully" 15 | echo "--------------------------------------" 16 | 17 | echo "Installing requirements..." 18 | pip3 install -r requirements.txt 19 | echo "Requirements installation complete" 20 | echo "--------------------------------------" 21 | 22 | echo "Running evaluation ..." 23 | python3 driver.py 24 | echo "--------------------------------------" 25 | echo "Script execution completed" 26 | -------------------------------------------------------------------------------- /blog_sample_agents/img/rag_langfuse_dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/rag_langfuse_dashboard.png -------------------------------------------------------------------------------- /blog_sample_agents/img/rag_trace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/rag_trace.png -------------------------------------------------------------------------------- /blog_sample_agents/img/rag_trace_dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/rag_trace_dashboard.png -------------------------------------------------------------------------------- /blog_sample_agents/img/ts_langfuse_dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/ts_langfuse_dashboard.png -------------------------------------------------------------------------------- /blog_sample_agents/img/ts_trace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/ts_trace.png -------------------------------------------------------------------------------- /blog_sample_agents/img/ts_trace_dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/blog_sample_agents/img/ts_trace_dashboard.png -------------------------------------------------------------------------------- /config.env.tpl: -------------------------------------------------------------------------------- 1 | # Bedrock Models for Evaluation 2 | MODEL_ID_EVAL="us.anthropic.claude-3-5-haiku-20241022-v1:0" 3 | MODEL_ID_EVAL_COT="us.anthropic.claude-3-5-sonnet-20241022-v2:0" 4 | EMBEDDING_MODEL_ID="amazon.titan-embed-text-v2:0" 5 | 6 | # Model parameters 7 | MAX_TOKENS = 2048 8 | TEMPERATURE = 0 9 | TOP_P = 1 10 | 11 | # Bedrock Agent details 12 | AGENT_ID="" 13 | AGENT_ALIAS_ID="" 14 | 15 | # Trajectories to evaluate, place data file in data_files/ folder 16 | DATA_FILE_PATH="data_files/DATA_FILE_NAME" 17 | 18 | # Langfuse Project Setup 19 | LANGFUSE_PUBLIC_KEY="" 20 | LANGFUSE_SECRET_KEY="" 21 | LANGFUSE_HOST="" 22 | -------------------------------------------------------------------------------- /data_files/sample_data_file.json: -------------------------------------------------------------------------------- 1 | { 2 | "Trajectory0": [ 3 | { 4 | "question_id": 0, 5 | "question_type": "RAG", 6 | "question": "Was Abraham Lincoln the sixteenth President of the United States?", 7 | "ground_truth": "yes" 8 | } 9 | ], 10 | "Trajectory1": [ 11 | { 12 | "question_id": 1, 13 | "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?", 14 | "question_type": "TEXT2SQL", 15 | "ground_truth": { 16 | "ground_truth_sql_query": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1", 17 | "ground_truth_sql_context": "[{'table_name': 'frpm', 'columns': [('cdscode', 'varchar'), ('academic year', 'varchar'), ('county code', 'varchar'), ('district code', 'integer'), ('school code', 'varchar'), ('county name', 'varchar'), ('district name', 'varchar'), ('school name', 'varchar'), ('district type', 'varchar'), ('school type', 'varchar'), ('educational option type', 'varchar'), ('nslp provision status', 'varchar'), ('charter school (y/n)', 'double'), ('charter school number', 'varchar'), ('charter funding type', 'varchar'), ('irc', 'double'), ('low grade', 'varchar'), ('high grade', 'varchar'), ('enrollment (k-12)', 'double'), ('free meal count (k-12)', 'double'), ('percent (%) eligible free (k-12)', 'double'), ('frpm count (k-12)', 'double'), ('percent (%) eligible frpm (k-12)', 'double'), ('enrollment (ages 5-17)', 'double'), ('free meal count (ages 5-17)', 'double'), ('percent (%) eligible free (ages 5-17)', 'double'), ('frpm count (ages 5-17)', 'double'), ('percent (%) eligible frpm (ages 5-17)', 'double'), ('2013-14 calpads fall 1 certification status', 'integer')]}, {'table_name': 'satscores', 'columns': [('cds', 'varchar'), ('rtype', 'varchar'), ('sname', 'varchar'), ('dname', 'varchar'), ('cname', 'varchar'), ('enroll12', 'integer'), ('numtsttakr', 'integer'), ('avgscrread', 'double'), ('avgscrmath', 'double'), ('avgscrwrite', 'double'), ('numge1500', 'double')]}, {'table_name': 'schools', 'columns': [('cdscode', 'varchar'), ('ncesdist', 'varchar'), ('ncesschool', 'varchar'), ('statustype', 'varchar'), ('county', 'varchar'), ('district', 'varchar'), ('school', 'varchar'), ('street', 'varchar'), ('streetabr', 'varchar'), ('city', 'varchar'), ('zip', 'varchar'), ('state', 'varchar'), ('mailstreet', 'varchar'), ('mailstrabr', 'varchar'), ('mailcity', 'varchar'), ('mailzip', 'varchar'), ('mailstate', 'varchar'), ('phone', 'varchar'), ('ext', 'varchar'), ('website', 'varchar'), ('opendate', 'varchar'), ('closeddate', 'varchar'), ('charter', 'double'), ('charternum', 'varchar'), ('fundingtype', 'varchar'), ('doc', 'varchar'), ('doctype', 'varchar'), ('soc', 'varchar'), ('soctype', 'varchar'), ('edopscode', 'varchar'), ('edopsname', 'varchar'), ('eilcode', 'varchar'), ('eilname', 'varchar'), ('gsoffered', 'varchar'), ('gsserved', 'varchar'), ('virtual', 'varchar'), ('magnet', 'double'), ('latitude', 'double'), ('longitude', 'double'), ('admfname1', 'varchar'), ('admlname1', 'varchar'), ('admemail1', 'varchar'), ('admfname2', 'varchar'), ('admlname2', 'varchar'), ('admemail2', 'varchar'), ('admfname3', 'varchar'), ('admlname3', 'varchar'), ('admemail3', 'varchar'), ('lastupdate', 'varchar')]}]", 18 | "ground_truth_query_result": "1.0", 19 | "ground_truth_answer": "The highest eligible free rate for K-12 students in schools in Alameda County is 1.0." 20 | } 21 | } 22 | ], 23 | "Trajectory2": [ 24 | { 25 | "question_id": 2, 26 | "question_type": "CUSTOM", 27 | "question": "Generate a bar chart of the top 5 gene biomarkers based on their p value and include their names in the x axis.", 28 | "ground_truth": "" 29 | } 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /driver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import boto3 4 | import sys 5 | import json 6 | from typing import Dict, Any, List 7 | from evaluators.rag_evaluator import RAGEvaluator 8 | from evaluators.text2sql_evaluator import Text2SQLEvaluator 9 | from evaluators.custom_evaluator import CustomEvaluator 10 | from botocore.client import Config 11 | from helpers.agent_info_extractor import AgentInfoExtractor 12 | import time 13 | 14 | from dotenv import load_dotenv 15 | 16 | # Load environment variables from config.env 17 | load_dotenv('config.env') 18 | 19 | # Get environment variables 20 | 21 | #AGENT SETUP 22 | AGENT_ID = os.getenv('AGENT_ID') 23 | AGENT_ALIAS_ID = os.getenv('AGENT_ALIAS_ID') 24 | 25 | #LANGFUSE SETUP 26 | LANGFUSE_PUBLIC_KEY = os.getenv('LANGFUSE_PUBLIC_KEY') 27 | LANGFUSE_SECRET_KEY = os.getenv('LANGFUSE_SECRET_KEY') 28 | LANGFUSE_HOST = os.getenv('LANGFUSE_HOST') 29 | 30 | #MODEL HYPERPARAMETERS 31 | MAX_TOKENS = int(os.getenv('MAX_TOKENS')) 32 | TEMPERATURE = float(os.getenv('TEMPERATURE')) 33 | TOP_P = float(os.getenv('TOP_P')) 34 | 35 | #EVALUATION MODELS 36 | MODEL_ID_EVAL = os.getenv('MODEL_ID_EVAL') 37 | EMBEDDING_MODEL_ID = os.getenv('EMBEDDING_MODEL_ID') 38 | MODEL_ID_EVAL_COT = os.getenv('MODEL_ID_EVAL_COT') 39 | 40 | #DATA 41 | DATA_FILE_PATH = os.getenv('DATA_FILE_PATH') 42 | 43 | def setup_environment() -> None: 44 | """Setup environment variables for Langfuse""" 45 | langfuse_vars = { 46 | "LANGFUSE_PUBLIC_KEY": LANGFUSE_PUBLIC_KEY, 47 | "LANGFUSE_SECRET_KEY": LANGFUSE_SECRET_KEY, 48 | "LANGFUSE_HOST": LANGFUSE_HOST 49 | } 50 | for key, value in langfuse_vars.items(): 51 | os.environ[key] = value 52 | 53 | def get_config() -> Dict[str, Any]: 54 | """Get configuration settings""" 55 | 56 | # Create shared clients 57 | bedrock_config = Config( 58 | connect_timeout=120, 59 | read_timeout=120, 60 | retries={'max_attempts': 0} 61 | ) 62 | 63 | shared_clients = { 64 | 'bedrock_agent_client': boto3.client('bedrock-agent'), 65 | 'bedrock_agent_runtime': boto3.client( 66 | 'bedrock-agent-runtime', 67 | config=bedrock_config 68 | ), 69 | 'bedrock_runtime': boto3.client('bedrock-runtime') 70 | } 71 | 72 | return { 73 | 'AGENT_ID': AGENT_ID, 74 | 'AGENT_ALIAS_ID': AGENT_ALIAS_ID, 75 | 'MODEL_ID_EVAL': MODEL_ID_EVAL, 76 | 'EMBEDDING_MODEL_ID': EMBEDDING_MODEL_ID, 77 | 'TEMPERATURE': TEMPERATURE, 78 | 'MAX_TOKENS': MAX_TOKENS, 79 | 'MODEL_ID_EVAL_COT': MODEL_ID_EVAL_COT, 80 | 'TOP_P': TOP_P, 81 | 'ENABLE_TRACE': True, 82 | 'clients': shared_clients 83 | } 84 | 85 | 86 | def create_evaluator(eval_type: str, config: Dict[str, Any], 87 | agent_info: Dict[str, Any], data: Dict[str, Any], trace_id: str, 88 | session_id: str, trajectory_id: str) -> Any: 89 | """Create appropriate evaluator based on evaluation type""" 90 | evaluator_map = { 91 | 'RAG': RAGEvaluator, 92 | 'TEXT2SQL': Text2SQLEvaluator, 93 | 'CUSTOM': CustomEvaluator 94 | # Add other evaluator types here 95 | } 96 | 97 | evaluator_class = evaluator_map.get(eval_type) 98 | if not evaluator_class: 99 | raise ValueError(f"Unknown evaluation type: {eval_type}") 100 | 101 | return evaluator_class( 102 | config=config, 103 | agent_info=agent_info, 104 | eval_type=eval_type, 105 | question=data['question'], 106 | ground_truth=data['ground_truth'], 107 | trace_id=trace_id, 108 | session_id=session_id, 109 | trajectory_id = trajectory_id, 110 | question_id=data['question_id'] 111 | ) 112 | 113 | def run_evaluation(data_file: str) -> None: 114 | """Main evaluation function""" 115 | # Setup 116 | setup_environment() 117 | config = get_config() 118 | 119 | # Initialize clients and extractors 120 | extractor = AgentInfoExtractor(config['clients']['bedrock_agent_client']) 121 | agent_info = extractor.extract_agent_info(AGENT_ID, AGENT_ALIAS_ID) 122 | 123 | # Load and process data 124 | with open(data_file, 'r') as f: 125 | data_dict = json.load(f) 126 | 127 | #For each data file, go into each trajectory 128 | for trajectoryID, questions in data_dict.items(): 129 | #Iterate through all the questions in each trajectory 130 | 131 | # Create unqiue session ID for trajectory 132 | session_id = str(uuid.uuid4()) 133 | print(f"Session ID for {trajectoryID}: {session_id}") 134 | 135 | #go through each question in each trajectory 136 | for question in questions: 137 | #get the evaluation type for the question 138 | eval_type = question.get('question_type') 139 | question_id = question['question_id'] 140 | 141 | print(f"Running {trajectoryID} - {eval_type} - Q{question_id} evaluation") 142 | 143 | trace_id = str(uuid.uuid1()) 144 | 145 | try: 146 | evaluator = create_evaluator( 147 | eval_type=eval_type, 148 | config=config, 149 | agent_info=agent_info, 150 | data=question, 151 | trace_id=trace_id, 152 | session_id=session_id, 153 | trajectory_id= trajectoryID 154 | ) 155 | 156 | 157 | results = evaluator.run_evaluation() 158 | if results is None: 159 | print(f"Skipping {trajectoryID} question {question_id} due to evaluation failure") 160 | time.sleep(90) 161 | continue 162 | 163 | print(f"Successfully evaluated {trajectoryID} question {question_id}") 164 | # print(results) 165 | time.sleep(90) 166 | 167 | except Exception as e: 168 | print(f"Failed to evalute for {trajectoryID} question {question_id}: {str(e)}") 169 | #if not a bedrock error, continue to next question 170 | time.sleep(90) 171 | continue 172 | 173 | except KeyboardInterrupt: 174 | sys.exit(0) 175 | 176 | # Driver 177 | if __name__ == "__main__": 178 | #Name of the data file 179 | run_evaluation(DATA_FILE_PATH) -------------------------------------------------------------------------------- /evaluators/README.MD: -------------------------------------------------------------------------------- 1 | Information about evaluators here -------------------------------------------------------------------------------- /evaluators/cot_evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Dict, Any, Optional, Tuple 3 | from datetime import datetime 4 | from langfuse import Langfuse 5 | import helpers.cot_helper as cot_helper 6 | import time 7 | import json 8 | import re 9 | 10 | class ToolEvaluator(ABC): 11 | def __init__(self, 12 | config: Dict[str, Any], 13 | agent_info: Dict[str, Any], 14 | eval_type: str, 15 | question: str, 16 | ground_truth: Any, 17 | trace_id: str, 18 | session_id: str, 19 | question_id: int, 20 | trajectory_id: str): 21 | """ 22 | Base class for tool evaluation 23 | 24 | Args: 25 | config (Dict[str, Any]): Configuration dictionary containing credentials and settings 26 | agent_info (Dict[str, Any]): Information about the agent being evaluated 27 | eval_type (str): Type of evaluation being performed 28 | question (str): Question to evaluate 29 | ground_truth (Any): Ground truth answer 30 | trace_id (str): Unique identifier for the evaluation trace 31 | session_id (str): Identifier for the evaluation session 32 | question_id (int): Identifier for the specific question 33 | """ 34 | self.config = config 35 | self.agent_info = agent_info 36 | self.eval_type = eval_type 37 | self.question = question 38 | self.ground_truth = ground_truth 39 | self.trace_id = trace_id 40 | self.session_id = session_id 41 | self.question_id = question_id 42 | self.trajectory_id = trajectory_id 43 | self.clients = config.get('clients', {}) 44 | self.langfuse = Langfuse() 45 | 46 | self._initialize_clients() 47 | 48 | @abstractmethod 49 | def _initialize_clients(self) -> None: 50 | """Initialize tool-specific clients""" 51 | pass 52 | 53 | @abstractmethod 54 | def invoke_agent(self, tries: int = 1) -> Tuple[Dict[str, Any], datetime]: 55 | """ 56 | Invoke the specific tool and process its response 57 | 58 | Args: 59 | tries (int): Number of retry attempts 60 | 61 | Returns: 62 | Tuple containing processed response and start time 63 | """ 64 | pass 65 | 66 | @abstractmethod 67 | def evaluate_response(self, metadata: Dict[str, Any]) -> Dict[str, Any]: 68 | """ 69 | Evaluate tool response using specified metrics 70 | 71 | Args: 72 | metadata (Dict[str, Any]): Metadata for evaluation 73 | 74 | Returns: 75 | Dict containing evaluation results 76 | """ 77 | pass 78 | 79 | def _add_agent_collaborators(self, agents_used, trimmed_orc_trace): 80 | 81 | for item in trimmed_orc_trace: 82 | # Check for invocationInput 83 | if 'invocationInput' in item: 84 | if 'agentCollaboratorInvocationInput' in item['invocationInput']: 85 | name = item['invocationInput']['agentCollaboratorInvocationInput']['agentCollaboratorName'] 86 | agents_used.add(name) 87 | 88 | # Check for observation with agentCollaboratorInvocationOutput 89 | if 'observation' in item: 90 | if 'agentCollaboratorInvocationOutput' in item['observation']: 91 | name = item['observation']['agentCollaboratorInvocationOutput']['agentCollaboratorName'] 92 | agents_used.add(name) 93 | 94 | return agents_used 95 | 96 | # print("Collaborators Used for Question: {}".format(collaborator_names)) 97 | 98 | def _create_trace(self) -> Any: 99 | """Create and initialize a Langfuse trace""" 100 | traj_num = re.findall(r'\d+$', self.trajectory_id)[0] 101 | 102 | return self.langfuse.trace( 103 | id=self.trace_id, 104 | session_id=self.session_id, 105 | input=self.question, 106 | name=f"T{traj_num}-Q{self.question_id}-{self.eval_type}", 107 | user_id=self.config['AGENT_ID'], 108 | tags=[self.eval_type, self.agent_info['agentModel'], self.agent_info['agentType']] 109 | ) 110 | 111 | 112 | def _handle_error(self, trace: Any, error: Exception, stage: str) -> None: 113 | """Handle and log errors during evaluation without raising""" 114 | traj_num = re.findall(r'\d+$', self.trajectory_id)[0] 115 | 116 | error_message = f"{stage} error: {str(error)}" 117 | trace.update( 118 | name=f"[ERROR] T{traj_num}-Q{self.question_id}-{self.eval_type}", 119 | metadata={"errorMessage": error_message}, 120 | output={"Agent Error": error_message}, 121 | tags=["ERROR"] 122 | ) 123 | print(f"Error in {stage}: {error_message}") 124 | 125 | 126 | def process_trace_step(self,trace_step): 127 | 128 | if 'orchestrationTrace' in trace_step: 129 | # print("This is an orchestration trace") 130 | 131 | trace = trace_step['orchestrationTrace'] 132 | 133 | orchestration_trace = { 134 | 'name': 'Orchestration', 135 | 'input': json.loads(trace['modelInvocationInput'].get('text', '{}')), 136 | 'output': json.loads(trace['modelInvocationOutput']['rawResponse'].get('content', '{}')), 137 | 'metadata': trace['modelInvocationOutput'].get('metadata', {}), 138 | 'trace_id': trace['modelInvocationInput'].get('traceId') 139 | } 140 | 141 | if 'observation' in trace and 'finalResponse' in trace['observation']: 142 | orchestration_trace['final_response'] = { 143 | 'text': trace['observation']['finalResponse'].get('text') 144 | } 145 | 146 | # print("Data: {}".format(orchestration_trace)) 147 | return orchestration_trace 148 | 149 | def combine_traces(self,full_trace): 150 | 151 | trace_ids = [] 152 | trace_steps = [] 153 | cur_dict = {} 154 | 155 | def find_trace_id(data): 156 | if isinstance(data, dict): 157 | # If traceId is directly in this dictionary, return it 158 | if 'traceId' in data: 159 | return data['traceId'] 160 | # Otherwise search through all values in the dictionary 161 | for value in data.values(): 162 | result = find_trace_id(value) 163 | if result: 164 | return result 165 | # If the value is a list, search through its elements 166 | elif isinstance(data, list): 167 | for item in data: 168 | result = find_trace_id(item) 169 | if result: 170 | return result 171 | return None 172 | 173 | #iterate through all the traces 174 | for cur_trace in full_trace: 175 | 176 | cur_trace_id = find_trace_id(cur_trace) 177 | #LOGIC FOR INITIALIZING NEW DICTIONARY 178 | #only for the first instsance of a single trace ID 179 | if cur_trace_id not in trace_ids: 180 | # print("Unique trace ID: {}".format(cur_trace_id)) 181 | #initialize new dict with the agent information 182 | if cur_dict: 183 | trace_steps.append(cur_dict) 184 | cur_dict = {} 185 | 186 | cur_dict = {key: value for key, value in cur_trace.items() if key != 'trace'} 187 | 188 | # print("Unique dict: {}".format(cur_dict)) 189 | trace_ids.append(cur_trace_id) 190 | 191 | #LOGIC FOR ADDING TO EXISTING DICTIOANRY 192 | #append to cur_dict what's in trace.anytracetype (orchestrationTrace) and put the whole thing in there 193 | 194 | if 'orchestrationTrace' in cur_trace['trace']: 195 | first_key = next(iter(cur_trace['trace']['orchestrationTrace'])) 196 | cur_dict[first_key] = cur_trace['trace']['orchestrationTrace'][first_key] 197 | 198 | if cur_dict: 199 | trace_steps.append(cur_dict) 200 | 201 | return trace_steps 202 | 203 | 204 | def run_evaluation(self) -> Dict[str, Any]: 205 | """Run the complete evaluation pipeline""" 206 | trace = self._create_trace() 207 | 208 | # Invoke try block 209 | try: 210 | 211 | # Invoke tool and get processed response 212 | full_trace, processed_response, agent_start_time = self.invoke_agent() 213 | 214 | #if there is no response, then raise an error 215 | if not processed_response or not processed_response.get('agent_answer'): 216 | self._handle_error(trace, Exception("Failed to get or process agent response"), "Agent Processing") 217 | return None 218 | 219 | trace.update( 220 | metadata={ 221 | "Ground Truth": self.ground_truth, 222 | str(self.eval_type + " Evaluation Model"): self.config['MODEL_ID_EVAL'], 223 | "Chain of Thought Evaluation Model": self.config['MODEL_ID_EVAL_COT'] 224 | }, 225 | output=processed_response['agent_answer'] 226 | ) 227 | 228 | # Evaluation try block 229 | try: 230 | 231 | # Eliminate unneeded information for COT evaluation 232 | orc_trace_full = [item['trace']['orchestrationTrace'] for item in full_trace if 'orchestrationTrace' in item['trace']] 233 | 234 | #Combine all the traces with the same trace ID 235 | trace_step_spans = self.combine_traces(full_trace) 236 | 237 | trimmed_orc_trace = [item['rationale']['text'] for item in orc_trace_full if 'rationale' in item] 238 | 239 | trace_steps = "" 240 | for i, item in enumerate(trimmed_orc_trace, 1): 241 | trace_steps += f"Step {i}: {item}\n" 242 | 243 | agents_used = {self.agent_info['agentName']} 244 | 245 | # Add collaborator agents if multi-agent in use 246 | if self.agent_info['agentType'] == "MULTI-AGENT": 247 | agents_used = self._add_agent_collaborators(agents_used, orc_trace_full) 248 | 249 | # Chain of thought processes whole agent trace + agent info 250 | cot_eval_results, cot_system_prompt = cot_helper.evaluate_cot(trace_steps, processed_response['agent_answer'],self.agent_info, self.clients['bedrock_runtime'], self.config['MODEL_ID_EVAL_COT']) 251 | 252 | # Create an evaluation generation 253 | agent_generation = trace.generation( 254 | name= "Agent Generation Information", 255 | input=[ 256 | {"role": "system", "content": self.agent_info['agentInstruction']}, 257 | {"role": "user", "content": self.question} 258 | ], 259 | model=self.agent_info['agentModel'], 260 | model_parameters={"temperature": self.config['TEMPERATURE']}, 261 | start_time=agent_start_time, 262 | metadata=processed_response.get('agent_generation_metadata') 263 | ) 264 | 265 | agent_generation.end( 266 | output=processed_response.get('agent_answer'), 267 | usage_details={ 268 | "input": processed_response.get('input_tokens'), 269 | "output": processed_response.get('output_tokens') 270 | } 271 | ) 272 | 273 | #CHAIN OF THOUGHT EVALUATION SECTION START 274 | 275 | # Create generation based on CoT output 276 | cot_generation = trace.generation( 277 | name="CoT Evaluation LLM-As-Judge Generation", 278 | input=[ 279 | {"role": "system", "content": cot_system_prompt}, 280 | {"role": "user", "content": self.question} 281 | ], 282 | output=cot_eval_results, 283 | metadata={"agents_used": agents_used, 'model_used': self.config['MODEL_ID_EVAL_COT']} 284 | ) 285 | 286 | 287 | for index, step in enumerate(trace_step_spans): 288 | 289 | # Create trace step spans 290 | subtrace_span = cot_generation.span( 291 | name="Agent Trace Step {}".format(index+1), 292 | input = step.get('modelInvocationInput'), 293 | output={'Model Raw Response': step.get('modelInvocationOutput', {}).get('rawResponse'), 294 | "Model Rationale": step.get('rationale')}, 295 | metadata = {"Model Output metadata": step.get('modelInvocationOutput', {}).get('metadata'), 296 | "Observation": step.get('observation')} 297 | ) 298 | 299 | subtrace_span.end() 300 | 301 | # Prevents trace spans from getting sent out of order 302 | time.sleep(1) 303 | 304 | cot_generation.end() 305 | 306 | #Send the scores of chain of thought evaluation 307 | for metric_name, value in cot_eval_results.items(): 308 | cot_generation.score( 309 | name=str("COT_" + metric_name), 310 | value=value['score'], 311 | comment = value['explanation'], 312 | ) 313 | 314 | #CHAIN OF THOUGHT EVALUATION END 315 | 316 | 317 | #AGENT EVALAUATION RESULTS START 318 | 319 | # Prepare metadata and evaluate 320 | evaluation_metadata = { 321 | 'question': self.question, 322 | 'ground_truth': self.ground_truth, 323 | 'agent_response': processed_response.get('agent_answer'), 324 | 'evaluation_metadata': processed_response.get('agent_generation_metadata'), 325 | **self.config 326 | } 327 | 328 | evaluation_results = self.evaluate_response(evaluation_metadata) 329 | 330 | # TODO: Make the logic better, stopgap solution to work with custom 331 | if self.eval_type != "CUSTOM": 332 | for metric_name, metric_info in evaluation_results['metrics_scores'].items(): 333 | trace.score(name=str(self.eval_type + "_" + metric_name), value=metric_info.get('score'), comment=metric_info.get('explanation')) 334 | 335 | # Update trace with final results 336 | return { 337 | 'question_id': self.question_id, 338 | 'question': self.question, 339 | 'ground_truth': self.ground_truth, 340 | 'agent_response': processed_response, 341 | 'evaluation_results': evaluation_results, 342 | 'trace_id': self.trace_id 343 | } 344 | 345 | except Exception as e: 346 | self._handle_error(trace, e, "Evaluation") 347 | return None 348 | 349 | except Exception as e: 350 | self._handle_error(trace, e, "Agent Invocation") 351 | return None 352 | 353 | except KeyboardInterrupt as e: 354 | self._handle_error(trace,e, "Manually Stopped Evaluation Job") 355 | raise KeyboardInterrupt -------------------------------------------------------------------------------- /evaluators/custom_evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Tuple 2 | import time 3 | from datetime import datetime 4 | from ragas import evaluate 5 | from evaluators.cot_evaluator import ToolEvaluator 6 | 7 | class CustomEvaluator(ToolEvaluator): 8 | def __init__(self, **kwargs): 9 | """ 10 | Initialize Custom Evaluator with all necessary components 11 | 12 | Args: 13 | **kwargs: Arguments passed to parent class 14 | """ 15 | super().__init__(**kwargs) 16 | 17 | def _initialize_clients(self) -> None: 18 | """Initialize evaluation-specific models using shared clients""" 19 | # Use shared clients 20 | self.bedrock_agent_client = self.clients['bedrock_agent_client'] 21 | self.bedrock_agent_runtime_client = self.clients['bedrock_agent_runtime'] 22 | self.bedrock_client = self.clients['bedrock_runtime'] 23 | 24 | def evaluate_response(self, metadata: Dict[str, Any]) -> Dict[str, Any]: 25 | """ 26 | Evaluate just COT with no scores returned 27 | 28 | Args: 29 | metadata (Dict[str, Any]): Evaluation metadata 30 | 31 | Returns: 32 | Dict containing evaluation results 33 | """ 34 | 35 | return {"metric_scores": {}} 36 | 37 | def invoke_agent(self, tries: int = 1) -> Tuple[Dict[str, Any], datetime]: 38 | """ 39 | Invoke the Custom tool and process its response with retry logic 40 | 41 | Args: 42 | trace_id (str): Unique identifier for the trace 43 | tries (int): Number of retry attempts 44 | 45 | Returns: 46 | Tuple of (processed_response, start_time) 47 | """ 48 | agent_start_time = datetime.now() 49 | max_retries = 3 50 | 51 | try: 52 | # Invoke agent 53 | raw_response = self.bedrock_agent_runtime_client.invoke_agent( 54 | inputText=self.question, 55 | agentId=self.config['AGENT_ID'], 56 | agentAliasId=self.config['AGENT_ALIAS_ID'], 57 | sessionId=self.session_id, 58 | enableTrace=self.config['ENABLE_TRACE'] 59 | ) 60 | 61 | # Process response 62 | agent_answer = None 63 | input_tokens = 0 64 | output_tokens = 0 65 | orc_trace_full = [] 66 | full_trace = [] 67 | 68 | for event in raw_response['completion']: 69 | if 'chunk' in event: 70 | agent_answer = event['chunk']['bytes'].decode('utf-8') 71 | 72 | elif "trace" in event: 73 | 74 | full_trace.append(event['trace']) 75 | 76 | trace_obj = event['trace']['trace'] 77 | 78 | 79 | if "orchestrationTrace" in trace_obj: 80 | orc_trace = trace_obj['orchestrationTrace'] 81 | # Add trace to full_trace object 82 | orc_trace_full.append(orc_trace) 83 | 84 | # Extract token usage 85 | if 'modelInvocationOutput' in orc_trace: 86 | usage = orc_trace['modelInvocationOutput']['metadata']['usage'] 87 | input_tokens += usage.get('inputTokens',0) 88 | output_tokens += usage.get('outputTokens',0) 89 | 90 | processed_response = { 91 | 'agent_generation_metadata': {'ResponseMetadata': raw_response.get('ResponseMetadata', {})}, 92 | 'agent_answer': agent_answer, 93 | 'input_tokens': input_tokens, 94 | 'output_tokens': output_tokens 95 | } 96 | 97 | return full_trace, processed_response, agent_start_time 98 | 99 | except Exception as e: 100 | if (hasattr(e, 'response') and 101 | 'Error' in e.response and 102 | e.response['Error'].get('Code') == 'throttlingException' and 103 | tries <= max_retries): 104 | 105 | wait_time = 30 * tries 106 | print(f"Throttling occurred. Attempt {tries} of {max_retries}. " 107 | f"Waiting {wait_time} seconds before retry...") 108 | time.sleep(wait_time) 109 | return self.invoke_agent(tries + 1) 110 | else: 111 | raise e -------------------------------------------------------------------------------- /evaluators/rag_evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Tuple 2 | import boto3 3 | import time 4 | import math 5 | from botocore.client import Config 6 | from datetime import datetime 7 | from langchain_aws.chat_models.bedrock import ChatBedrock 8 | from langchain_aws.embeddings.bedrock import BedrockEmbeddings 9 | from datasets import Dataset 10 | from ragas import evaluate 11 | from evaluators.cot_evaluator import ToolEvaluator 12 | from ragas.metrics import ( 13 | faithfulness, 14 | answer_relevancy, 15 | context_recall, 16 | answer_similarity 17 | ) 18 | 19 | class RAGEvaluator(ToolEvaluator): 20 | def __init__(self, **kwargs): 21 | """ 22 | Initialize RAG Evaluator with all necessary components 23 | 24 | Args: 25 | **kwargs: Arguments passed to parent class 26 | """ 27 | super().__init__(**kwargs) 28 | 29 | def _initialize_clients(self) -> None: 30 | """Initialize evaluation-specific models using shared clients""" 31 | # Use shared clients 32 | self.bedrock_agent_client = self.clients['bedrock_agent_client'] 33 | self.bedrock_agent_runtime_client = self.clients['bedrock_agent_runtime'] 34 | self.bedrock_client = self.clients['bedrock_runtime'] 35 | 36 | # Initialize evaluation models 37 | self.llm_for_evaluation = ChatBedrock( 38 | model_id=self.config['MODEL_ID_EVAL'], 39 | max_tokens=100000, 40 | client=self.bedrock_client # Use shared client 41 | ) 42 | 43 | self.bedrock_embeddings = BedrockEmbeddings( 44 | model_id=self.config['EMBEDDING_MODEL_ID'], 45 | client=self.bedrock_client # Use shared client 46 | ) 47 | 48 | def prepare_evaluation_dataset(self, metadata: Dict[str, Any]) -> Dataset: 49 | """ 50 | Prepare dataset for RAG evaluation 51 | 52 | Args: 53 | metadata (Dict[str, Any]): Evaluation metadata 54 | 55 | Returns: 56 | Dataset object ready for evaluation 57 | """ 58 | return Dataset.from_dict({ 59 | "question": [metadata['question']], 60 | "answer": [metadata['agent_response']], 61 | "contexts": [metadata['evaluation_metadata']['rag_contexts']], 62 | "ground_truth": [metadata['ground_truth']] 63 | }) 64 | 65 | def evaluate_response(self, metadata: Dict[str, Any]) -> Dict[str, Any]: 66 | """ 67 | Evaluate the RAG response using specified metrics 68 | 69 | Args: 70 | metadata (Dict[str, Any]): Evaluation metadata 71 | 72 | Returns: 73 | Dict containing evaluation results 74 | """ 75 | try: 76 | dataset = self.prepare_evaluation_dataset(metadata) 77 | evaluation_results = evaluate( 78 | dataset=dataset, 79 | metrics=[ 80 | faithfulness, 81 | answer_relevancy, 82 | context_recall, 83 | answer_similarity 84 | ], 85 | llm=self.llm_for_evaluation, 86 | embeddings=self.bedrock_embeddings 87 | ) 88 | 89 | except Exception as e: 90 | raise Exception("Error: {}".format(e)) 91 | 92 | # Check for NaN values in scores and throw error if found 93 | for metric, score in evaluation_results.scores[0].items(): 94 | if math.isnan(score): 95 | raise Exception("Empty score detected, RAGAS had issue evaluating") 96 | 97 | return { 98 | 'metrics_scores': { 99 | metric: {'score': score} for metric, score in evaluation_results.scores[0].items() 100 | } 101 | } 102 | 103 | 104 | def invoke_agent(self, tries: int = 1) -> Tuple[Dict[str, Any], datetime]: 105 | """ 106 | Invoke the RAG tool and process its response with retry logic 107 | 108 | Args: 109 | tries (int): Number of retry attempts 110 | 111 | Returns: 112 | Tuple of (processed_response, start_time) 113 | """ 114 | agent_start_time = datetime.now() 115 | max_retries = 3 116 | 117 | try: 118 | # Invoke agent 119 | raw_response = self.bedrock_agent_runtime_client.invoke_agent( 120 | inputText=self.question, 121 | agentId=self.config['AGENT_ID'], 122 | agentAliasId=self.config['AGENT_ALIAS_ID'], 123 | # Test that this works 124 | sessionId=self.session_id, 125 | # sessionId=self.trace_id, 126 | enableTrace=self.config['ENABLE_TRACE'] 127 | ) 128 | 129 | 130 | # Process response 131 | rag_contexts = [] 132 | agent_answer = None 133 | input_tokens = 0 134 | output_tokens = 0 135 | full_trace = [] 136 | 137 | for event in raw_response['completion']: 138 | if 'chunk' in event: 139 | agent_answer = event['chunk']['bytes'].decode('utf-8') 140 | 141 | elif "trace" in event: 142 | full_trace.append(event['trace']) 143 | trace_obj = event['trace']['trace'] 144 | # print(trace_obj) 145 | if "orchestrationTrace" in trace_obj: 146 | orc_trace = trace_obj['orchestrationTrace'] 147 | 148 | # Extract context from knowledge base lookup 149 | if 'observation' in orc_trace: 150 | obs_trace = orc_trace['observation'] 151 | if 'knowledgeBaseLookupOutput' in obs_trace: 152 | output_trace = obs_trace['knowledgeBaseLookupOutput'] 153 | if 'retrievedReferences' in output_trace: 154 | for ref in output_trace['retrievedReferences']: 155 | rag_contexts.append(ref['content']['text']) 156 | 157 | # Extract token usage 158 | if 'modelInvocationOutput' in orc_trace: 159 | usage = orc_trace['modelInvocationOutput']['metadata']['usage'] 160 | input_tokens += usage.get('inputTokens',0) 161 | output_tokens += usage.get('outputTokens',0) 162 | 163 | 164 | 165 | processed_response = { 166 | 'agent_generation_metadata': {'ResponseMetadata': raw_response.get('ResponseMetadata', {}), "rag_contexts": rag_contexts}, 167 | 'agent_answer': agent_answer, 168 | 'input_tokens': input_tokens, 169 | 'output_tokens': output_tokens 170 | } 171 | 172 | return full_trace, processed_response, agent_start_time 173 | 174 | except Exception as e: 175 | if (hasattr(e, 'response') and 176 | 'Error' in e.response and 177 | e.response['Error'].get('Code') == 'throttlingException' and 178 | tries <= max_retries): 179 | 180 | wait_time = 30 * tries 181 | print(f"Throttling occurred. Attempt {tries} of {max_retries}. " 182 | f"Waiting {wait_time} seconds before retry...") 183 | time.sleep(wait_time) 184 | return self.invoke_agent(tries + 1) 185 | else: 186 | raise e -------------------------------------------------------------------------------- /evaluators/text2sql_evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Tuple 2 | from datetime import datetime 3 | from langchain_aws.chat_models import ChatBedrock 4 | from ragas.llms import LangchainLLMWrapper 5 | import json 6 | import time 7 | from evaluators.cot_evaluator import ToolEvaluator 8 | 9 | class Text2SQLEvaluator(ToolEvaluator): 10 | def __init__(self, **kwargs): 11 | """Initialize Text2SQL Evaluator with all necessary components""" 12 | super().__init__(**kwargs) 13 | 14 | def _initialize_clients(self) -> None: 15 | """Initialize evaluation-specific models using shared clients""" 16 | self.bedrock_agent_runtime_client = self.clients['bedrock_agent_runtime'] 17 | self.bedrock_client = self.clients['bedrock_runtime'] 18 | 19 | # Initialize evaluation model 20 | self.bedrock_model = ChatBedrock( 21 | model_id=self.config['MODEL_ID_EVAL'], 22 | client=self.bedrock_client 23 | ) 24 | self.evaluator_llm = LangchainLLMWrapper(self.bedrock_model) 25 | 26 | def evaluate_response(self, metadata: Dict[str, Any]) -> Dict[str, Any]: 27 | """Evaluate Text2SQL response using LLM as judge with two key metrics""" 28 | try: 29 | 30 | evaluation_prompt = f"""You are an expert evaluator for Text2SQL systems. Evaluate the following response based on two key metrics. 31 | 32 | Question: {metadata['question']} 33 | Database Schema: {metadata['ground_truth']['ground_truth_sql_context']} 34 | 35 | Ground Truth SQL: {metadata['ground_truth']['ground_truth_sql_query']} 36 | Generated SQL: {metadata['evaluation_metadata']['agent_query']} 37 | 38 | Ground Truth Answer: {metadata['ground_truth']['ground_truth_answer']} 39 | Generated Answer: {metadata['agent_response']} 40 | 41 | Query Result: {metadata['ground_truth']['ground_truth_query_result']} 42 | 43 | Evaluate and provide scores (0-1) and explanations for these metrics: 44 | 45 | SQL Semantic Equivalence: Evaluate if the generated SQL would produce the same results as the ground truth SQL. 46 | Answer Correctness: Check if the generated answer correctly represents the query results and matches ground truth. 47 | 48 | Provide your evaluation in this exact JSON format: 49 | {{ 50 | "metrics_scores": {{ 51 | "sql_semantic_equivalence": {{ 52 | "score": numeric_value, 53 | "explanation": "Brief explanation of why this score was given" 54 | }}, 55 | "answer_correctness": {{ 56 | "score": numeric_value, 57 | "explanation": "Brief explanation of why this score was given" 58 | }} 59 | }} 60 | }} 61 | """ 62 | 63 | # Call LLM for evaluation, use Claude 3 Sonnet 64 | response = self.bedrock_client.invoke_model( 65 | modelId='anthropic.claude-3-sonnet-20240229-v1:0', 66 | body=json.dumps({ 67 | "anthropic_version": "bedrock-2023-05-31", 68 | "max_tokens": 1024, 69 | "temperature": 0, 70 | "messages": [ 71 | { 72 | "role": "user", 73 | "content": [{"type": "text", "text": evaluation_prompt}], 74 | } 75 | ], 76 | }) 77 | ) 78 | 79 | # Parse and return the evaluation 80 | evaluation = json.loads(json.loads(response['body'].read())['content'][0]['text']) 81 | return evaluation 82 | 83 | except Exception as e: 84 | raise Exception(f"error: {str(e)}") 85 | 86 | def invoke_agent(self, tries: int = 1) -> Tuple[Dict[str, Any], datetime]: 87 | """ 88 | Invoke the Text2SQL agent and process its response 89 | 90 | Args: 91 | question (str): Question to process 92 | trace_id (str): Unique identifier for the trace 93 | tries (int): Number of retry attempts 94 | 95 | Returns: 96 | Tuple of (processed_response, start_time) 97 | """ 98 | agent_start_time = datetime.now() 99 | max_retries = 3 100 | 101 | try: 102 | # Invoke agent 103 | raw_response = self.bedrock_agent_runtime_client.invoke_agent( 104 | inputText=self.question, 105 | agentId=self.config['AGENT_ID'], 106 | agentAliasId=self.config['AGENT_ALIAS_ID'], 107 | # Confirm that this works 108 | sessionId=self.session_id, 109 | enableTrace=self.config['ENABLE_TRACE'] 110 | ) 111 | 112 | # Process response 113 | agent_query = "" 114 | agent_answer = None 115 | end_event_received = False 116 | input_tokens = 0 117 | output_tokens = 0 118 | full_trace = [] 119 | 120 | 121 | for event in raw_response['completion']: 122 | if 'chunk' in event: 123 | data = event['chunk']['bytes'] 124 | agent_answer = data.decode('utf-8') 125 | end_event_received = True 126 | 127 | elif "trace" in event: 128 | full_trace.append(event['trace']) 129 | trace_obj = event['trace']['trace'] 130 | if "orchestrationTrace" in trace_obj: 131 | orc_trace = trace_obj['orchestrationTrace'] 132 | 133 | # Extract SQL query 134 | if 'invocationInput' in orc_trace: 135 | invoc_trace = orc_trace['invocationInput'] 136 | if 'actionGroupInvocationInput' in invoc_trace: 137 | action_trace = invoc_trace['actionGroupInvocationInput'] 138 | if 'apiPath' in action_trace: 139 | if action_trace['apiPath'] == "/queryredshift": 140 | agent_query = action_trace['parameters'][0]['value'] 141 | 142 | # Extract token usage if available 143 | if 'modelInvocationOutput' in orc_trace: 144 | usage = orc_trace['modelInvocationOutput']['metadata']['usage'] 145 | input_tokens += usage.get('inputTokens',0) 146 | output_tokens += usage.get('outputTokens',0) 147 | 148 | if not end_event_received: 149 | raise Exception("End event not received") 150 | 151 | processed_response = { 152 | 'agent_generation_metadata': { 153 | "agent_query": agent_query, 154 | 'ResponseMetadata': raw_response.get('ResponseMetadata', {}) 155 | }, 156 | 'agent_answer': agent_answer, 157 | 'input_tokens': input_tokens, 158 | 'output_tokens': output_tokens 159 | } 160 | 161 | return full_trace, processed_response, agent_start_time 162 | 163 | except Exception as e: 164 | if (hasattr(e, 'response') and 165 | 'Error' in e.response and 166 | e.response['Error'].get('Code') == 'throttlingException' and 167 | tries <= max_retries): 168 | 169 | wait_time = 30 * tries 170 | print(f"Throttling occurred. Attempt {tries} of {max_retries}. " 171 | f"Waiting {wait_time} seconds before retry...") 172 | time.sleep(wait_time) 173 | return self.invoke_agent(tries + 1) 174 | else: 175 | raise e -------------------------------------------------------------------------------- /helpers/README.MD: -------------------------------------------------------------------------------- 1 | Explain helper files here -------------------------------------------------------------------------------- /helpers/agent_info_extractor.py: -------------------------------------------------------------------------------- 1 | class AgentInfoExtractor: 2 | def __init__(self, bedrock_agent_client): 3 | self.client = bedrock_agent_client 4 | 5 | def get_agent_alias_version(self, agent_id, alias_id): 6 | """Get agent version from alias information""" 7 | alias_info = self.client.get_agent_alias(agentAliasId=alias_id, agentId=agent_id) 8 | return alias_info['agentAlias']['routingConfiguration'][0]['agentVersion'] 9 | 10 | def get_agent_name(self, agent_id): 11 | """Get agent name from agent information""" 12 | agent_info = self.client.get_agent(agentId=agent_id) 13 | return agent_info['agent']['agentName'] 14 | 15 | def get_agent_version_details(self, agent_id, agent_version): 16 | """Get detailed information about an agent version""" 17 | version_info = self.client.get_agent_version(agentId=agent_id, agentVersion=agent_version) 18 | 19 | 20 | #LOGIC FOR CHANGING THE MODEL NAME IN CASE OF USING CROSS-REGION REFERENCE 21 | model_id = version_info['agentVersion']['foundationModel'].split('/')[-1] 22 | if model_id.startswith("us."): 23 | # print("Changing model name from cross-region reference") 24 | model_id = model_id[3:] 25 | 26 | 27 | return { 28 | 'model_id': model_id, 29 | 'instruction': version_info['agentVersion']['instruction'], 30 | 'description': version_info['agentVersion']['description'] 31 | } 32 | 33 | def get_action_groups(self, agent_id, agent_version): 34 | """Get action groups for an agent""" 35 | return self.client.list_agent_action_groups( 36 | agentId=agent_id, 37 | agentVersion=agent_version 38 | )['actionGroupSummaries'] 39 | 40 | def create_agent_info(self, agent_id, alias_id, agent_type): 41 | """Create a standardized agent info dictionary""" 42 | agent_name = self.get_agent_name(agent_id) 43 | agent_version = self.get_agent_alias_version(agent_id, alias_id) 44 | version_details = self.get_agent_version_details(agent_id, agent_version) 45 | action_groups = self.get_action_groups(agent_id, agent_version) 46 | 47 | 48 | # Agent info in a dictionary 49 | agent_info = { 50 | "agentId": agent_id, 51 | "agentAlias": alias_id, 52 | "agentName": agent_name, 53 | "agentVersion": agent_version, 54 | "agentType": agent_type, 55 | "agentModel": version_details['model_id'], 56 | "agentDescription": version_details['description'], 57 | "agentInstruction": version_details['instruction'], 58 | "actionGroups": action_groups 59 | } 60 | 61 | # print("Agent info: {}".format(agent_info)) 62 | 63 | return agent_info 64 | 65 | def get_collaborator_info(self, agent_id, agent_version): 66 | """Get information about collaborator agents""" 67 | 68 | # Get whole list of collaborators 69 | collaborators = self.client.list_agent_collaborators( 70 | agentId=agent_id, 71 | agentVersion=agent_version 72 | )['agentCollaboratorSummaries'] 73 | 74 | #new dictionary of collaborator info consisting of: name, description, instruction 75 | collaborator_info = {} 76 | for collab in collaborators: 77 | 78 | # Create dictionary with specific collaborator's info 79 | collab_info = { 80 | "collaborationInstruction": collab['collaborationInstruction'], 81 | } 82 | 83 | # Use collab_name as key 84 | collaborator_info[collab['collaboratorName']] = collab_info 85 | 86 | return collaborator_info 87 | 88 | def extract_agent_info(self, agent_id, agent_alias_id): 89 | """Main method to extract all agent information""" 90 | agents_info = {} 91 | 92 | # Check if agent is collaborative 93 | is_collaborative = self.client.get_agent( 94 | agentId=agent_id 95 | )['agent']['agentCollaboration'] != "DISABLED" 96 | 97 | # Add main agent info 98 | agent_type = "MULTI-AGENT" if is_collaborative else "SINGLE-AGENT" 99 | 100 | # Create agent_info dictionary 101 | agents_info = self.create_agent_info(agent_id, agent_alias_id, agent_type) 102 | 103 | # Add collaborator info if collaborative 104 | if is_collaborative: 105 | agent_version = self.get_agent_alias_version(agent_id, agent_alias_id) 106 | 107 | # Create new key called 'collaborators' 108 | agents_info['collaborators'] = self.get_collaborator_info(agent_id, agent_version) 109 | 110 | else: 111 | # Create new key called 'collaborators' 112 | agents_info['collaborators'] = None 113 | 114 | return agents_info -------------------------------------------------------------------------------- /helpers/cot_helper.py: -------------------------------------------------------------------------------- 1 | from langchain_aws import ChatBedrock 2 | from langchain.prompts import PromptTemplate 3 | import json 4 | 5 | # Goal: Evaluate agent CoT using LLM-as-judge and output results 6 | def evaluate_cot(agent_cot:str, agent_response:str, agent_info:list, client, MODEL_ID_EVAL_COT): 7 | 8 | # Clean inputs to template 9 | agent_instructions = agent_info['agentInstruction'] 10 | collaborator_instructions = agent_info['collaborators'] 11 | # clean_agent_cot = agent_cot 12 | 13 | # Initialize Bedrock client 14 | llm = ChatBedrock(model = MODEL_ID_EVAL_COT, client=client) 15 | 16 | system_prompt_template = PromptTemplate( 17 | 18 | input_variables=["agent_instructions", "collaborator_instructions", "agent_cot"], 19 | template=""" 20 | You are an expert evaluator analyzing AI Agent execution. Evaluate the agent's chain of thought performance on three key metrics: 21 | 22 | Agent Instructions: 23 | {agent_instructions} 24 | 25 | Collaboration Context: 26 | {collaborator_instructions} 27 | 28 | Agent Chain-of-Thought: 29 | {agent_cot} 30 | 31 | Final Agent Response: 32 | {agent_response} 33 | 34 | Evaluate on these three critical aspects: 35 | 36 | Helpfulness: How well does the execution satisfy explicit and implicit expectations? 37 | - Is it sensible, coherent, and clear? 38 | - Does it solve the task effectively? 39 | - Does it follow instructions? 40 | - Is it appropriately specific/general? 41 | - Does it anticipate user needs? 42 | 43 | Faithfulness: Does the execution stick to available information and context? 44 | - Does it avoid making unfounded claims? 45 | - Does it stay within the scope of given information? 46 | - Are conclusions properly supported? 47 | - Does it avoid contradicting the context? 48 | 49 | Instruction Following: Does it respect all explicit directions? 50 | - Follows specific requirements 51 | - Adheres to given constraints 52 | - Respects defined boundaries 53 | - Completes all requested steps 54 | 55 | Output your evaluation in the following Python dictionary format: 56 | 57 | {{ 58 | "helpfulness": {{ 59 | "score": , 60 | "explanation": "" 61 | }}, 62 | 63 | "faithfulness": {{ 64 | "score": , 65 | "explanation": "" 66 | }}, 67 | 68 | "instruction_following": {{ 69 | "score": , 70 | "explanation": "" 71 | }}, 72 | 73 | "overall": {{ 74 | "score": , 75 | "explanation": "" 76 | }} 77 | }} 78 | 79 | Provide clear, concise explanations focusing on specific examples from the execution. Ensure your output is a valid Python dictionary that can be parsed directly. Do not include any text before or after the dictionary. 80 | """ 81 | ) 82 | 83 | # Format the system prompt with function parameters 84 | system_prompt = system_prompt_template.format(agent_cot=agent_cot, agent_instructions=agent_instructions, collaborator_instructions=collaborator_instructions, agent_response=agent_response) 85 | 86 | # Create the messages list 87 | messages = [ 88 | {"role": "system", "content": system_prompt}, 89 | {"role": "user", "content": "Please generate the chain-of-thought evaluation as specified."} 90 | ] 91 | 92 | # Invoke the model to get CoT evaluation 93 | response = llm.invoke(messages) 94 | 95 | # Convert model response to dictionary 96 | eval_results = json.loads(response.content) 97 | 98 | def clean_prompt_indentation(prompt_string): 99 | # Split into lines and strip leading/trailing whitespace 100 | lines = prompt_string.split('\n') 101 | cleaned_lines = [line.strip() for line in lines] 102 | # Rejoin with newlines 103 | return '\n'.join(cleaned_lines) 104 | 105 | system_prompt = clean_prompt_indentation(system_prompt) 106 | 107 | return eval_results, system_prompt -------------------------------------------------------------------------------- /img/evaluation_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/open-source-bedrock-agent-evaluation/a4254b324bc3c684e393d544197d4c92b0258465/img/evaluation_workflow.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.11.2 3 | aiosignal==1.3.1 4 | annotated-types==0.7.0 5 | anyio==4.6.2.post1 6 | appdirs==1.4.4 7 | attrs==24.2.0 8 | backoff==2.2.1 9 | boto3==1.36.4 10 | botocore==1.36.4 11 | certifi==2024.8.30 12 | charset-normalizer==3.4.0 13 | click==8.1.7 14 | dataclasses-json==0.6.7 15 | datasets==3.1.0 16 | dill==0.3.8 17 | distro==1.9.0 18 | docopt==0.6.2 19 | filelock==3.16.1 20 | frozenlist==1.5.0 21 | fsspec==2024.9.0 22 | h11==0.16.0 23 | httpcore==1.0.7 24 | httpx==0.27.2 25 | httpx-sse==0.4.0 26 | huggingface-hub==0.26.2 27 | idna==3.10 28 | jiter==0.7.1 29 | jmespath==1.0.1 30 | joblib==1.4.2 31 | jsonpatch==1.33 32 | jsonpointer==3.0.0 33 | langchain==0.3.7 34 | langchain-aws==0.2.7 35 | langchain-community==0.3.7 36 | langchain-core==0.3.18 37 | langchain-openai==0.2.8 38 | langchain-text-splitters==0.3.2 39 | langfuse==2.54.1 40 | langsmith==0.1.143 41 | marshmallow==3.23.1 42 | multidict==6.1.0 43 | multiprocess==0.70.16 44 | mypy-extensions==1.0.0 45 | nest-asyncio==1.6.0 46 | nltk==3.9.1 47 | numpy==1.26.4 48 | openai==1.54.4 49 | orjson==3.10.11 50 | packaging==24.2 51 | pandas==2.2.3 52 | pipreqs==0.4.13 53 | propcache==0.2.0 54 | pyarrow==18.0.0 55 | pydantic==2.9.2 56 | pydantic-settings==2.6.1 57 | pydantic_core==2.23.4 58 | pysbd==0.3.4 59 | python-dateutil==2.9.0.post0 60 | python-dotenv==1.0.1 61 | pytz==2024.2 62 | PyYAML==6.0.2 63 | ragas==0.2.13 64 | regex==2024.11.6 65 | requests==2.32.3 66 | requests-toolbelt==1.0.0 67 | s3transfer==0.11.1 68 | six==1.16.0 69 | sniffio==1.3.1 70 | SQLAlchemy==2.0.35 71 | tenacity==9.0.0 72 | tiktoken==0.8.0 73 | tqdm==4.67.0 74 | typing-inspect==0.9.0 75 | typing_extensions==4.12.2 76 | tzdata==2024.2 77 | urllib3==2.2.3 78 | wrapt==1.16.0 79 | xxhash==3.5.0 80 | yarg==0.1.10 81 | yarl==1.17.1 82 | --------------------------------------------------------------------------------