├── .gitignore ├── .pre-commit-config.yaml ├── .style.yapf ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── app.py ├── cdk.json ├── cdk └── cdk_stack.py ├── create_api_key_secrets.sh ├── lambda └── websearch_lambda.py ├── pyproject.toml ├── requirements.txt ├── test └── invoke-agent.py └── v2 ├── .gitignore ├── README.md ├── app.py ├── cdk.json ├── cdk ├── __init__.py ├── bedrock_agent.py ├── cdk_stack.py ├── lambda_functions.py ├── log_groups.py ├── utils.py └── websearch_lambda_layers.py ├── config.yaml ├── lambda ├── advanced_web_search │ ├── advanced_web_search_lambda.py │ ├── llm_operations.py │ ├── models.py │ ├── search_operations_tavily.py │ ├── strut_output_bedrock.py │ ├── test_advanced_web_search.py │ └── utils.py └── websearch │ ├── config.py │ ├── exceptions.py │ ├── google_search.py │ ├── models.py │ ├── response_formatter.py │ ├── search_providers.py │ ├── test_websearch.py │ ├── utils.py │ ├── validators.py │ └── websearch_lambda.py └── lambda_layer └── advanced_websearch_libraries └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | cdk.out 2 | **/__pycache__/** 3 | cdk/__pychache__/** 4 | env.sh 5 | .aider* 6 | .env 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | --- 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.6.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: debug-statements 11 | 12 | - repo: https://github.com/pycqa/flake8 13 | rev: 7.1.0 14 | hooks: 15 | - id: flake8 16 | additional_dependencies: [pep8-naming] 17 | args: 18 | - "--max-line-length=140" 19 | 20 | - repo: https://github.com/ambv/black 21 | rev: 24.4.2 22 | hooks: 23 
| - id: black 24 | language_version: python3.12 25 | args: ["--line-length", "140"] 26 | 27 | # - repo: https://github.com/pre-commit/mirrors-yapf 28 | # rev: v0.32.0 29 | # hooks: 30 | # - id: yapf 31 | # args: ["--exclude", "cdk.out"] 32 | # additional_dependencies: [toml] 33 | 34 | - repo: https://github.com/pycqa/isort 35 | rev: 5.13.2 36 | hooks: 37 | - id: isort 38 | name: isort (python) 39 | 40 | - repo: https://github.com/trailofbits/pip-audit 41 | rev: v2.7.3 42 | hooks: 43 | - id: pip-audit 44 | 45 | - repo: https://github.com/adrienverge/yamllint.git 46 | rev: v1.35.1 47 | hooks: 48 | - id: yamllint 49 | 50 | - repo: https://github.com/PyCQA/bandit 51 | rev: 1.7.9 52 | hooks: 53 | - id: bandit 54 | args: ["-c", "pyproject.toml", "-r", "."] 55 | additional_dependencies: ["bandit[toml]"] 56 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | column_limit = 130 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
GitHub provides additional documentation on
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Web Search Agent with Amazon Bedrock 2 | 3 | This repository contains the code for implementing a web search agent using Amazon Bedrock, as described in the AWS blog post "How to use Amazon Bedrock agents with a web search API to integrate dynamic web content in your generative AI application." 4 | 5 | ## Overview 6 | 7 | This project demonstrates how to create an AI agent that can perform web searches using Amazon Bedrock. It integrates two web search APIs (SerpAPI and Tavily AI) and showcases how to build and deploy the agent using both the AWS Management Console and AWS CDK. 
8 | 9 | ## Features 10 | 11 | - Integration with Amazon Bedrock for AI agent creation 12 | - Web search functionality using SerpAPI (Google Search) and Tavily AI 13 | - AWS CDK deployment script for infrastructure as code 14 | - Lambda function for handling web search requests 15 | - Example of how to use AWS Secrets Manager for API key storage 16 | 17 | ## Prerequisites 18 | 19 | - An active AWS Account 20 | - [AWS CDK](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html#getting_started_install) version 2.174.3 or later 21 | - Python 3.11 or later 22 | - API keys for [SerpAPI](https://serpapi.com/) and [Tavily AI](https://tavily.com/) 23 | 24 | ## Setup 25 | 26 | 1. Clone this repository: 27 | ``` 28 | git clone https://github.com/aws-samples/websearch_agent 29 | ``` 30 | 31 | 2. Create and activate a Python virtual environment: 32 | ``` 33 | python -m venv .venv 34 | source .venv/bin/activate 35 | ``` 36 | 37 | 3. Install the required dependencies: 38 | ``` 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | 4. Store your API keys in AWS Secrets Manager: 43 | ``` 44 | aws secretsmanager create-secret --name SERPER_API_KEY --secret-string "your_serper_api_key" 45 | aws secretsmanager create-secret --name TAVILY_API_KEY --secret-string "your_tavily_api_key" 46 | ``` 47 | 48 | 5. Deploy the CDK stack: 49 | ``` 50 | cdk deploy 51 | ``` 52 | 53 | ## Usage 54 | 55 | After deployment, you can test the agent using the provided Python script: 56 | 57 | ``` 58 | python invoke-agent.py --agent_id --agent_alias_id --prompt "What are the latest AWS news?" 59 | ``` 60 | 61 | Replace `` and `` with the values output by the CDK deployment. 
62 | 63 | ## Project Structure 64 | 65 | - `/app.py`: Top-level definition of the AWS CDK app 66 | - `/cdk/`: Contains the stack definition for the web search agent 67 | - `/lambda/`: Contains the Lambda function code for handling API calls 68 | - `/test/`: Contains a Python script to test the deployed agent 69 | 70 | ## Cleaning Up 71 | 72 | To remove all resources created by this project: 73 | 74 | 1. Run `cdk destroy` to delete the CDK-managed resources. 75 | 2. Delete the secrets from AWS Secrets Manager: 76 | ``` 77 | aws secretsmanager delete-secret --secret-id SERPER_API_KEY 78 | aws secretsmanager delete-secret --secret-id TAVILY_API_KEY 79 | ``` 80 | 81 | ## Security 82 | 83 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 84 | 85 | ## License 86 | 87 | This project is licensed under the MIT-0 License. See the [LICENSE](LICENSE) file for details. 88 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | # SPDX-License-Identifier: MIT-0 5 | 6 | import aws_cdk as cdk 7 | import cdk_nag 8 | 9 | from cdk.cdk_stack import WebSearchAgentStack 10 | 11 | app = cdk.App() 12 | WebSearchAgentStack(app, "WebSearchAgentStack") 13 | 14 | cdk.Aspects.of(app).add(cdk_nag.AwsSolutionsChecks(verbose=True)) 15 | app.synth() 16 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "requirements*.txt", 11 | "source.bat", 12 | "**/__init__.py", 13 | "**/__pycache__", 14 | "tests" 15 | ] 16 | }, 17 | "context": { 18 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 19 | "@aws-cdk/core:checkSecretUsage": true, 20 | "@aws-cdk/core:target-partitions": [ 21 | "aws", 22 | "aws-cn" 23 | ], 24 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 25 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 26 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 27 | "@aws-cdk/aws-iam:minimizePolicies": true, 28 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 29 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 30 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 31 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 32 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 33 | "@aws-cdk/core:enablePartitionLiterals": true, 34 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 35 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 36 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 37 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 38 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 39 | "@aws-cdk/aws-route53-patters:useCertificate": true, 40 | 
"@aws-cdk/customresources:installLatestAwsSdkDefault": false, 41 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 42 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 43 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 44 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 45 | "@aws-cdk/aws-redshift:columnId": true, 46 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 47 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 48 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 49 | "@aws-cdk/aws-kms:aliasNameRef": true, 50 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 51 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 52 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 53 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 54 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 55 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 56 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 57 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 58 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 59 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true, 60 | "@aws-cdk/aws-cloudwatch-actions:changeLambdaPermissionLogicalIdForLambdaAction": true, 61 | "@aws-cdk/aws-codepipeline:crossAccountKeysDefaultValueToFalse": true, 62 | "@aws-cdk/aws-codepipeline:defaultPipelineTypeToV2": true, 63 | "@aws-cdk/aws-kms:reduceCrossAccountRegionPolicyScope": true, 64 | "@aws-cdk/aws-eks:nodegroupNameAttribute": true, 65 | "@aws-cdk/aws-ec2:ebsDefaultGp3Volume": true, 66 | "@aws-cdk/aws-ecs:removeDefaultDeploymentAlarm": true, 67 | "@aws-cdk/custom-resources:logApiResponseDataPropertyTrueDefault": false 68 | } 69 | } 70 | -------------------------------------------------------------------------------- 
/cdk/cdk_stack.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | from typing import Any, Self 4 | 5 | import aws_cdk.aws_iam as iam 6 | import cdk_nag 7 | from aws_cdk import Aws, CfnOutput, Duration, Stack 8 | from aws_cdk import aws_bedrock as bedrock 9 | from aws_cdk import aws_lambda as _lambda 10 | from constructs import Construct 11 | 12 | MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" 13 | FUNCTION_NAME = "websearch_lambda" 14 | AGENT_NAME = "websearch_agent" 15 | ACTION_GROUP_NAME = "action-group-web-search" 16 | 17 | 18 | class WebSearchAgentStack(Stack): # type: ignore 19 | def __init__(self: Self, scope: Construct, construct_id: str, **kwargs: Any) -> None: # type: ignore 20 | super().__init__(scope, construct_id, **kwargs) 21 | 22 | lambda_policy = iam.Policy( 23 | self, 24 | "LambdaPolicy", 25 | statements=[ 26 | iam.PolicyStatement( 27 | sid="CreateLogGroup", 28 | effect=iam.Effect.ALLOW, 29 | actions=["logs:CreateLogGroup"], 30 | resources=[f"arn:aws:logs:{Aws.REGION}:{Aws.ACCOUNT_ID}:*"], 31 | ), 32 | iam.PolicyStatement( 33 | sid="CreateLogStreamAndPutLogEvents", 34 | effect=iam.Effect.ALLOW, 35 | actions=["logs:CreateLogStream", "logs:PutLogEvents"], 36 | resources=[ 37 | f"arn:aws:logs:{Aws.REGION}:{Aws.ACCOUNT_ID}:log-group:/aws/lambda/{FUNCTION_NAME}", 38 | f"arn:aws:logs:{Aws.REGION}:{Aws.ACCOUNT_ID}:log-group:/aws/lambda/{FUNCTION_NAME}:log-stream:*", 39 | ], 40 | # log-group:/aws/lambda/action-group-web-search-2j017:log-stream: 41 | # resources=[f"arn:aws:logs:{Aws.REGION}:{Aws.ACCOUNT_ID}:log-group:/aws/lambda/{FUNCTION_NAME}:*"], 42 | ), 43 | iam.PolicyStatement( 44 | sid="GetSecretsManagerSecret", 45 | effect=iam.Effect.ALLOW, 46 | actions=["secretsmanager:GetSecretValue"], 47 | resources=[ 48 | f"arn:aws:secretsmanager:{Aws.REGION}:{Aws.ACCOUNT_ID}:secret:SERPER_API_KEY-*", 49 | 
f"arn:aws:secretsmanager:{Aws.REGION}:{Aws.ACCOUNT_ID}:secret:TAVILY_API_KEY-*", 50 | ], 51 | ), 52 | ], 53 | ) 54 | cdk_nag.NagSuppressions.add_resource_suppressions( 55 | lambda_policy, 56 | [cdk_nag.NagPackSuppression(id="AwsSolutions-IAM5", reason="'log-stream:*' - stream names dynamically generated at runtime")], 57 | ) 58 | 59 | lambda_role = iam.Role( 60 | self, 61 | "LambdaRole", 62 | role_name=f"{FUNCTION_NAME}_role", 63 | assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"), 64 | ) 65 | 66 | lambda_role.attach_inline_policy(lambda_policy) 67 | 68 | lambda_function = _lambda.Function( 69 | self, 70 | "WebSearch", 71 | function_name=FUNCTION_NAME, 72 | runtime=_lambda.Runtime.PYTHON_3_12, 73 | architecture=_lambda.Architecture.ARM_64, 74 | code=_lambda.Code.from_asset("lambda"), 75 | handler="websearch_lambda.lambda_handler", 76 | timeout=Duration.seconds(30), 77 | role=lambda_role, 78 | environment={"LOG_LEVEL": "DEBUG", "ACTION_GROUP": f"{ACTION_GROUP_NAME}"}, 79 | ) 80 | 81 | bedrock_account_principal = iam.PrincipalWithConditions( 82 | iam.ServicePrincipal("bedrock.amazonaws.com"), 83 | conditions={ 84 | "StringEquals": {"aws:SourceAccount": f"{Aws.ACCOUNT_ID}"}, 85 | }, 86 | ) 87 | lambda_function.add_permission( 88 | id="LambdaResourcePolicyAgentsInvokeFunction", 89 | principal=bedrock_account_principal, 90 | action="lambda:invokeFunction", 91 | ) 92 | 93 | agent_policy = iam.Policy( 94 | self, 95 | "AgentPolicy", 96 | statements=[ 97 | iam.PolicyStatement( 98 | sid="AmazonBedrockAgentBedrockFoundationModelPolicy", 99 | effect=iam.Effect.ALLOW, 100 | actions=["bedrock:InvokeModel"], 101 | resources=[f"arn:aws:bedrock:{Aws.REGION}::foundation-model/{MODEL_ID}"], 102 | ), 103 | ], 104 | ) 105 | 106 | agent_role_trust = iam.PrincipalWithConditions( 107 | iam.ServicePrincipal("bedrock.amazonaws.com"), 108 | conditions={ 109 | "StringLike": {"aws:SourceAccount": f"{Aws.ACCOUNT_ID}"}, 110 | "ArnLike": {"aws:SourceArn": 
f"arn:aws:bedrock:{Aws.REGION}:{Aws.ACCOUNT_ID}:agent/*"}, 111 | }, 112 | ) 113 | agent_role = iam.Role( 114 | self, 115 | "AmazonBedrockExecutionRoleForAgents", 116 | role_name=f"AmazonBedrockExecutionRoleForAgents_{AGENT_NAME}", 117 | assumed_by=agent_role_trust, 118 | ) 119 | agent_role.attach_inline_policy(agent_policy) 120 | 121 | action_group = bedrock.CfnAgent.AgentActionGroupProperty( 122 | action_group_name=f"{ACTION_GROUP_NAME}", 123 | description="Action that will trigger the lambda", 124 | action_group_executor=bedrock.CfnAgent.ActionGroupExecutorProperty(lambda_=lambda_function.function_arn), 125 | function_schema=bedrock.CfnAgent.FunctionSchemaProperty( 126 | functions=[ 127 | bedrock.CfnAgent.FunctionProperty( 128 | name="tavily-ai-search", 129 | description=""" 130 | To retrieve information via the internet 131 | or for topics that the LLM does not know about and 132 | intense research is needed. 133 | """, 134 | parameters={ 135 | "search_query": bedrock.CfnAgent.ParameterDetailProperty( 136 | type="string", 137 | description="The search query for the Tavily web search.", 138 | required=True, 139 | ) 140 | }, 141 | ), 142 | bedrock.CfnAgent.FunctionProperty( 143 | name="google-search", 144 | description="For targeted news, like 'what are the latest news in Austria' or similar.", 145 | parameters={ 146 | "search_query": bedrock.CfnAgent.ParameterDetailProperty( 147 | type="string", 148 | description="The search query for the Google web search.", 149 | required=True, 150 | ) 151 | }, 152 | ), 153 | ] 154 | ), 155 | ) 156 | 157 | agent_instruction = """ 158 | You are an agent that can handle various tasks as described below: 159 | 160 | 1/ Helping users do research and finding up to date information. For up to date information always 161 | uses web search. 
Web search has two flavours: 162 | 163 | 1a/ Google Search - this is great for looking up up to date information and current events 164 | 165 | 2b/ Tavily AI Search - this is used to do deep research on topics your user is interested in. 166 | Not good on being used on news as it does not order search results by date. 167 | 168 | 2/ Retrieving knowledge from the vast knowledge bases that you are connected to. 169 | """ 170 | 171 | agent = bedrock.CfnAgent( 172 | self, 173 | "WebSearchAgent", 174 | agent_name=f"{AGENT_NAME}", 175 | foundation_model=f"{MODEL_ID}", 176 | action_groups=[action_group], 177 | auto_prepare=True, 178 | instruction=agent_instruction, 179 | agent_resource_role_arn=agent_role.role_arn, 180 | ) 181 | 182 | agent_alias = bedrock.CfnAgentAlias( 183 | self, 184 | "WebSearchAgentAlias", 185 | agent_id=agent.attr_agent_id, 186 | agent_alias_name=f"{AGENT_NAME}", 187 | ) 188 | 189 | CfnOutput(self, "agent_id", value=agent.attr_agent_id) 190 | CfnOutput(self, "agent_alias_id", value=agent_alias.attr_agent_alias_id) 191 | CfnOutput(self, "agent_version", value=agent.attr_agent_version) 192 | -------------------------------------------------------------------------------- /create_api_key_secrets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Ensure TAVILY_API_KEY and SERPER_API_KEY environment variables are set." 4 | 5 | echo "Creating SERPER_API_KEY secret in secrets manager..." 6 | aws secretsmanager create-secret \ 7 | --name SERPER_API_KEY \ 8 | --description "The API secret key for Serper." \ 9 | --secret-string "$SERPER_API_KEY" 10 | 11 | echo "Creating TAVILY_API_KEY secret in secrets manager..." 12 | aws secretsmanager create-secret \ 13 | --name TAVILY_API_KEY \ 14 | --description "The API secret key for Tavily AI." \ 15 | --secret-string "$TAVILY_API_KEY" 16 | 17 | echo "Done." 
18 | -------------------------------------------------------------------------------- /lambda/websearch_lambda.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | import http.client 4 | import json 5 | import logging 6 | import os 7 | import urllib.parse 8 | import urllib.request 9 | 10 | import boto3 11 | 12 | log_level = os.environ.get("LOG_LEVEL", "INFO").strip().upper() 13 | logging.basicConfig(format="[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(log_level) 16 | 17 | AWS_REGION = os.environ.get("AWS_REGION", "us-east-1") 18 | ACTION_GROUP_NAME = os.environ.get("ACTION_GROUP", "action-group-web-search-d213q") 19 | FUNCTION_NAMES = ["tavily-ai-search", "google-search"] 20 | 21 | 22 | def is_env_var_set(env_var: str) -> bool: 23 | return env_var in os.environ and os.environ[env_var] not in ("", "0", "false", "False") 24 | 25 | 26 | def get_from_secretstore_or_env(key: str) -> str: 27 | if is_env_var_set(key): 28 | logger.warning(f"getting value for {key} from environment var; recommended to use AWS Secrets Manager instead") 29 | return os.environ[key] 30 | 31 | session = boto3.session.Session() 32 | secrets_manager = session.client(service_name="secretsmanager", region_name=AWS_REGION) 33 | try: 34 | secret_value = secrets_manager.get_secret_value(SecretId=key) 35 | except Exception as e: 36 | logger.error(f"could not get secret {key} from secrets manager: {e}") 37 | raise e 38 | 39 | secret: str = secret_value["SecretString"] 40 | 41 | return secret 42 | 43 | 44 | SERPER_API_KEY = get_from_secretstore_or_env("SERPER_API_KEY") 45 | TAVILY_API_KEY = get_from_secretstore_or_env("TAVILY_API_KEY") 46 | 47 | 48 | def extract_search_params(action_group, function, parameters): 49 | if action_group != ACTION_GROUP_NAME: 50 | 
logger.error(f"unexpected name '{action_group}'; expected valid action group name '{ACTION_GROUP_NAME}'") 51 | return None, None 52 | 53 | if function not in FUNCTION_NAMES: 54 | logger.error(f"unexpected function name '{function}'; valid function names are'{FUNCTION_NAMES}'") 55 | return None, None 56 | 57 | search_query = next( 58 | (param["value"] for param in parameters if param["name"] == "search_query"), 59 | None, 60 | ) 61 | 62 | target_website = next( 63 | (param["value"] for param in parameters if param["name"] == "target_website"), 64 | None, 65 | ) 66 | 67 | logger.debug(f"extract_search_params: {search_query=} {target_website=}") 68 | 69 | return search_query, target_website 70 | 71 | 72 | def google_search(search_query: str, target_website: str = "") -> str: 73 | query = search_query 74 | if target_website: 75 | query += f" site:{target_website}" 76 | 77 | conn = http.client.HTTPSConnection("google.serper.dev") 78 | payload = json.dumps({"q": query}) 79 | headers = {"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"} 80 | 81 | search_type = "news" # "news", "search", 82 | conn.request("POST", f"/{search_type}", payload, headers) 83 | res = conn.getresponse() 84 | data = res.read() 85 | 86 | return data.decode("utf-8") 87 | 88 | 89 | def tavily_ai_search(search_query: str, target_website: str = "") -> str: 90 | logger.info(f"executing Tavily AI search with {search_query=}") 91 | 92 | base_url = "https://api.tavily.com/search" 93 | headers = {"Content-Type": "application/json", "Accept": "application/json"} 94 | payload = { 95 | "api_key": TAVILY_API_KEY, 96 | "query": search_query, 97 | "search_depth": "advanced", 98 | "include_images": False, 99 | "include_answer": False, 100 | "include_raw_content": False, 101 | "max_results": 3, 102 | "include_domains": [target_website] if target_website else [], 103 | "exclude_domains": [], 104 | } 105 | 106 | data = json.dumps(payload).encode("utf-8") 107 | request = urllib.request.Request(base_url, 
data=data, headers=headers) # nosec: B310 fixed url we want to open 108 | 109 | try: 110 | response = urllib.request.urlopen(request) # nosec: B310 fixed url we want to open 111 | response_data: str = response.read().decode("utf-8") 112 | logger.debug(f"response from Tavily AI search {response_data=}") 113 | return response_data 114 | except urllib.error.HTTPError as e: 115 | logger.error(f"failed to retrieve search results from Tavily AI Search, error: {e.code}") 116 | 117 | return "" 118 | 119 | 120 | def lambda_handler(event, _): # type: ignore 121 | logging.debug(f"lambda_handler {event=}") 122 | 123 | action_group = event["actionGroup"] 124 | function = event["function"] 125 | parameters = event.get("parameters", []) 126 | 127 | logger.info(f"lambda_handler: {action_group=} {function=}") 128 | 129 | search_query, target_website = extract_search_params(action_group, function, parameters) 130 | 131 | search_results: str = "" 132 | if function == "tavily-ai-search": 133 | search_results = tavily_ai_search(search_query, target_website) 134 | elif function == "google-search": 135 | search_results = google_search(search_query, target_website) 136 | 137 | logger.debug(f"query results {search_results=}") 138 | 139 | # Prepare the response 140 | function_response_body = {"TEXT": {"body": f"Here are the top search results for the query '{search_query}': {search_results} "}} 141 | 142 | action_response = { 143 | "actionGroup": action_group, 144 | "function": function, 145 | "functionResponse": {"responseBody": function_response_body}, 146 | } 147 | 148 | response = {"response": action_response, "messageVersion": event["messageVersion"]} 149 | 150 | logger.debug(f"lambda_handler: {response=}") 151 | 152 | return response 153 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.bandit] 2 | exclude_dirs = [ 3 | ".venv", 4 | ".git", 5 | 
"__pycache__", 6 | "cdk.out", 7 | ] 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aws-cdk-lib>=2.150.0 2 | constructs>=10.3.0 3 | cdk-nag>=2.28.158 4 | boto3>=1.34.150 5 | -------------------------------------------------------------------------------- /test/invoke-agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | 5 | Convenience python script to test agent. 6 | Example use: 7 | python invoke-agent.py --agent_id RRSS0TTN3H --agent_alias_id 1PSAF8M2LF --prompt "What are the latest AWS news?" 8 | 9 | """ 10 | 11 | import argparse 12 | import os 13 | import random 14 | import string 15 | 16 | try: 17 | import boto3 18 | except ImportError: 19 | print("Error: The boto3 library is required. Please install e.g. with 'pip install boto3'") 20 | os._exit(0) 21 | import botocore 22 | 23 | REGION = os.environ["AWS_REGION"] 24 | 25 | session = boto3.session.Session() 26 | agents_runtime_client = session.client(service_name="bedrock-agent-runtime", region_name=REGION) 27 | 28 | 29 | def invoke_agent(agent_id, agent_alias_id, session_id, prompt): 30 | """ 31 | Sends a prompt for the agent to process and respond to. 32 | 33 | :param agent_id: The unique identifier of the agent to use. 34 | :param agent_alias_id: The alias of the agent to use. 35 | :param session_id: The unique identifier of the session. Use the same value across requests 36 | to continue the same conversation. 37 | :param prompt: The prompt that you want Claude to complete. 38 | :return: Inference response from the model. 39 | """ 40 | 41 | try: 42 | # Note: The execution time depends on the foundation model, complexity of the agent, and 43 | # the length of the prompt. 
In some cases, it can take up to a minute or more to generate a response. 44 | response = agents_runtime_client.invoke_agent( 45 | agentId=agent_id, 46 | agentAliasId=agent_alias_id, 47 | sessionId=session_id, 48 | inputText=prompt, 49 | ) 50 | 51 | completion = "" 52 | for event in response.get("completion"): 53 | chunk = event["chunk"] 54 | completion = completion + chunk["bytes"].decode() 55 | 56 | except botocore.exceptions.ClientError as e: 57 | print(f"Couldn't invoke agent. {e}") 58 | raise 59 | 60 | return completion 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--agent_id", required=True, help="The unique identifier of the agent to use.") 66 | parser.add_argument("--agent_alias_id", required=True, help="The alias of the agent to use.") 67 | parser.add_argument("--prompt", required=True, help="The prompt that you want Claude to complete.") 68 | args = parser.parse_args() 69 | 70 | session_id = "".join(random.choices(string.ascii_lowercase, k=10)) # nosec B311 no cryptographic purpose use 71 | 72 | response = invoke_agent(args.agent_id, args.agent_alias_id, session_id, args.prompt) 73 | 74 | print(response) 75 | -------------------------------------------------------------------------------- /v2/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/websearch_agent/fbd3bc0b08af39e6fc184419e0ac0d1785f1a106/v2/.gitignore -------------------------------------------------------------------------------- /v2/README.md: -------------------------------------------------------------------------------- 1 | # Web Search Agent V2 with Amazon Bedrock 2 | 3 | This is version 2 of the Web Search Agent, featuring enhanced capabilities and improved architecture. This version introduces advanced web search capabilities, asynchronous operations, and integration with Claude 3 models. 
4 | 5 | ## New Features in V2 6 | 7 | - **Advanced Web Search**: Intelligent query rewriting and result analysis 8 | - **Asynchronous Operations**: Parallel search execution for better performance 9 | - **Structured Output**: Integration with Bedrock's Claude 3 models for better response formatting 10 | - **Enhanced Error Handling**: Robust error management and validation 11 | - **Cross-Region Support**: Flexible deployment across AWS regions 12 | - **Type Safety**: Comprehensive type hints and Pydantic models 13 | 14 | ## Architecture Components 15 | 16 | - **Web Search Lambda**: Basic web search functionality using Google Search and Tavily AI 17 | - **Advanced Web Search Lambda**: Enhanced search with query refinement and result analysis 18 | - **Lambda Layers**: Optimized dependency management for ARM64 architecture 19 | - **Bedrock Agent**: Updated agent configuration for improved interaction 20 | 21 | ## Prerequisites 22 | 23 | - An active AWS Account 24 | - AWS CDK version 2.174.3 or later 25 | - Python 3.12 or later 26 | - API keys for [SerpAPI](https://serpapi.com/) and [Tavily AI](https://tavily.com/) 27 | 28 | ## Setup 29 | 30 | 1. Clone the repository and follow the setup steps that are the same as in V1. 31 | 32 | ``` 33 | git clone https://github.com/aws-samples/websearch_agent 34 | cd websearch_agent 35 | ``` 36 | 37 | 2. Create and activate a Python virtual environment: 38 | 39 | ``` 40 | python -m venv .venv 41 | source .venv/bin/activate 42 | ``` 43 | 44 | 3. Install the required dependencies: 45 | 46 | ``` 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | 4. Store your API keys in AWS Secrets Manager: 51 | 52 | ``` 53 | aws secretsmanager create-secret --name SERPER_API_KEY --secret-string "your_serper_api_key" 54 | aws secretsmanager create-secret --name TAVILY_API_KEY --secret-string "your_tavily_api_key" 55 | ``` 56 | 57 | 5.
#!/usr/bin/env python3

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""CDK entry point: instantiates and synthesizes the v2 web-search agent stack."""

import aws_cdk as cdk
import cdk_nag

from cdk.cdk_stack import WebSearchAgentStack

app = cdk.App()
# Stack id carries a "-v2" suffix (see README: avoids naming conflicts with a
# deployed v1 stack — TODO confirm when both versions coexist).
WebSearchAgentStack(app, "WebSearchAgentStack-v2")

# Attach cdk-nag AWS Solutions checks to the whole app; must be added before synth().
cdk.Aspects.of(app).add(cdk_nag.AwsSolutionsChecks(verbose=True))
app.synth()
"@aws-cdk/customresources:installLatestAwsSdkDefault": false, 41 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 42 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 43 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 44 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 45 | "@aws-cdk/aws-redshift:columnId": true, 46 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 47 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 48 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 49 | "@aws-cdk/aws-kms:aliasNameRef": true, 50 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 51 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 52 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 53 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 54 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 55 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 56 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 57 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 58 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 59 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true, 60 | "@aws-cdk/aws-cloudwatch-actions:changeLambdaPermissionLogicalIdForLambdaAction": true, 61 | "@aws-cdk/aws-codepipeline:crossAccountKeysDefaultValueToFalse": true, 62 | "@aws-cdk/aws-codepipeline:defaultPipelineTypeToV2": true, 63 | "@aws-cdk/aws-kms:reduceCrossAccountRegionPolicyScope": true, 64 | "@aws-cdk/aws-eks:nodegroupNameAttribute": true, 65 | "@aws-cdk/aws-ec2:ebsDefaultGp3Volume": true, 66 | "@aws-cdk/aws-ecs:removeDefaultDeploymentAlarm": true, 67 | "@aws-cdk/custom-resources:logApiResponseDataPropertyTrueDefault": false 68 | } 69 | } 70 | -------------------------------------------------------------------------------- 
from aws_cdk import aws_bedrock as bedrock
from aws_cdk import aws_iam as iam
from aws_cdk import aws_lambda as _lambda
from constructs import Construct


def create_bedrock_agent(
    self: Construct,
    construct_id: str,
    websearch_lambda: _lambda.Function,
    advanced_web_search_lambda: _lambda.Function,
    agent_role: iam.Role,
    config: dict,
) -> tuple[bedrock.CfnAgent, bedrock.CfnAgentAlias]:
    """Create the Bedrock agent with two Lambda-backed action groups and an alias.

    :param self: Scope (the stack) the constructs are created in.
    :param construct_id: Id of the calling stack (currently unused in this function).
    :param websearch_lambda: Lambda backing the basic web-search tools.
    :param advanced_web_search_lambda: Lambda backing the advanced search tool.
    :param agent_role: Execution role the Bedrock agent assumes.
    :param config: Parsed config.yaml (action-group names, agent name, model ids).
    :return: The created agent and its alias.
    """
    # Basic web-search action group: exposes two tools (Tavily AI and Google/Serper).
    # The function schemas are the tool interface the agent reasons over; the
    # description strings steer the model's tool selection.
    websearch_action_group = bedrock.CfnAgent.AgentActionGroupProperty(
        action_group_name=f"{config['WEBSEARCH_ACTION_GROUP_NAME']}",
        description="Action that will trigger the websearch lambda",
        action_group_executor=bedrock.CfnAgent.ActionGroupExecutorProperty(
            lambda_=websearch_lambda.function_arn  # pylint: disable=unexpected-keyword-arg
        ),
        function_schema=bedrock.CfnAgent.FunctionSchemaProperty(
            functions=[
                bedrock.CfnAgent.FunctionProperty(
                    name="tavily-ai-search",
                    description="""
                    To retrieve information via the internet
                    or for topics that the LLM does not know about and
                    intense research is needed.
                    """,
                    parameters={
                        "search_query": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="The search query for the Tavily web search.",
                            required=True,
                        ),
                        "target_website": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Limits the search to a specific website.",
                            required=False,
                        ),
                        "search_depth": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="The depth of the search. Can be 'basic' or 'advanced'. Basic uses 1 credit, advanced uses 2 credits.",
                            required=False,
                        ),
                        "max_results": bedrock.CfnAgent.ParameterDetailProperty(
                            type="integer",
                            description="The maximum number of search results to return. Default is 3.",
                            required=False,
                        ),
                        "topic": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Category of search. Supports 'general' or 'news'. Default is 'general'.",
                            required=False,
                        ),
                    },
                ),
                bedrock.CfnAgent.FunctionProperty(
                    name="google-search",
                    description="For targeted news, like 'what are the latest news in Austria' or similar.",
                    parameters={
                        "search_query": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="The search query for the Google web search.",
                            required=True,
                        ),
                        "target_website": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Limits the search to a specific website.",
                            required=False,
                        ),
                        "search_type": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Type of search to perform. Options: 'search' (web search), 'news' (news search).",
                            required=False,
                        ),
                        "time_period": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Filter results by recency. Options: 'qdr:h' (past hour), 'qdr:d' (past day), 'qdr:w' (past week), 'qdr:m' (past month), 'qdr:y' (past year).",
                            required=False,
                        ),
                        "country_code": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="Two-letter country code for localized results.",
                            required=False,
                        ),
                    },
                ),
            ]
        ),
    )

    # Advanced-search action group: a single tool that does query rewriting,
    # subquery fan-out and result checking in its own Lambda.
    advanced_web_search_action_group = bedrock.CfnAgent.AgentActionGroupProperty(
        action_group_name=f"{config['ADVANCED_SEARCH_ACTION_GROUP_NAME']}",
        description="Action that will trigger the advanced web search lambda",
        action_group_executor=bedrock.CfnAgent.ActionGroupExecutorProperty(
            lambda_=advanced_web_search_lambda.function_arn  # pylint: disable=unexpected-keyword-arg
        ),
        function_schema=bedrock.CfnAgent.FunctionSchemaProperty(
            functions=[
                bedrock.CfnAgent.FunctionProperty(
                    name="advanced-web-search",
                    description="""
                    To perform a comprehensive search with query rewriting,
                    subquery searching, and result checking.
                    """,
                    parameters={
                        "search_query": bedrock.CfnAgent.ParameterDetailProperty(
                            type="string",
                            description="The search query for the Advanced Web Search.",
                            required=True,
                        )
                    },
                ),
            ]
        ),
    )

    # System instruction presented to the agent; runtime text, kept verbatim.
    agent_instruction = """
    You are an agent that can handle various tasks as described below:

    1/ Helping users do research and finding up to date information. For up to date information always
    uses web search. Web search has three flavours:

    1a/ Google Search - this is great for looking up up to date information and current events

    1b/ Tavily AI Search - this is used to do deep research on topics your user is interested in.
    Not good on being used on news as it does not order search results by date.

    1c/ Advanced Web Search - this performs a comprehensive search with query rewriting,
    subquery searching, and result checking. Use this for complex queries that require
    a more thorough analysis.

    """

    agent = bedrock.CfnAgent(
        self,
        "WebSearchAgent",
        agent_name=f"{config['AGENT_NAME']}",
        foundation_model=config["SMART_LLM"],
        action_groups=[websearch_action_group, advanced_web_search_action_group],
        # auto_prepare: re-prepare the agent automatically whenever it changes.
        auto_prepare=True,
        instruction=agent_instruction,
        agent_resource_role_arn=agent_role.role_arn,
    )

    # Alias reuses the agent name; clients invoke via agent_id + alias_id.
    agent_alias = bedrock.CfnAgentAlias(
        self,
        "WebSearchAgentAlias",
        agent_id=agent.attr_agent_id,
        agent_alias_name=f"{config['AGENT_NAME']}",
    )

    return agent, agent_alias
# SPDX-License-Identifier: MIT-0
"""CDK stack that provisions the v2 web-search Bedrock agent and its Lambdas."""
from typing import Any, Self

import aws_cdk.aws_iam as iam
import cdk_nag
from aws_cdk import CfnOutput, Stack
from aws_cdk import aws_lambda as _lambda
from constructs import Construct
from .websearch_lambda_layers import WebSearchLambdaLayers

from .bedrock_agent import create_bedrock_agent
from .lambda_functions import create_lambda_functions, create_lambda_roles
from .log_groups import create_log_groups
from .utils import load_config, get_account_id


class WebSearchAgentStack(Stack):
    """Wires layers, log groups, IAM policies/roles, both search Lambdas and the agent."""

    def __init__(
        self: Self, scope: Construct, construct_id: str, **kwargs: Any
    ) -> None:
        super().__init__(scope, construct_id, **kwargs)

        # Load configuration (config.yaml plus derived model ids)
        config = load_config()

        # Create Lambda layers (Powertools + bundled project dependencies, ARM64)
        lambda_layers = WebSearchLambdaLayers(
            self,
            "WebSearchLambdaLayers",
            stack_name=construct_id,
            architecture=_lambda.Architecture.ARM_64,
            region=config['DEPLOY_REGION']
        )

        # Create log groups
        websearch_log_group, advanced_web_search_log_group = create_log_groups(
            self, construct_id, config
        )

        # Shared execution policy for both Lambdas: CloudWatch logging plus read
        # access to the two API-key secrets.
        lambda_policy = iam.Policy(
            self,
            "LambdaPolicy",
            statements=[
                iam.PolicyStatement(
                    sid="CreateLogGroup",
                    effect=iam.Effect.ALLOW,
                    actions=["logs:CreateLogGroup"],
                    resources=[
                        f"arn:aws:logs:{config['DEPLOY_REGION']}:{get_account_id()}:*"
                    ],
                ),
                iam.PolicyStatement(
                    sid="CreateLogStreamAndPutLogEvents",
                    effect=iam.Effect.ALLOW,
                    actions=["logs:CreateLogStream", "logs:PutLogEvents"],
                    resources=[
                        websearch_log_group.log_group_arn,
                        f"{websearch_log_group.log_group_arn}:log-stream:*",
                        advanced_web_search_log_group.log_group_arn,
                        f"{advanced_web_search_log_group.log_group_arn}:log-stream:*",
                    ],
                ),
                iam.PolicyStatement(
                    sid="GetSecretsManagerSecret",
                    effect=iam.Effect.ALLOW,
                    actions=["secretsmanager:GetSecretValue"],
                    resources=[
                        # "-*" suffix: Secrets Manager appends a random suffix to secret ARNs
                        f"arn:aws:secretsmanager:{config['DEPLOY_REGION']}:{get_account_id()}:secret:SERPER_API_KEY-*",
                        f"arn:aws:secretsmanager:{config['DEPLOY_REGION']}:{get_account_id()}:secret:TAVILY_API_KEY-*",
                    ],
                ),
            ],
        )
        cdk_nag.NagSuppressions.add_resource_suppressions(
            lambda_policy,
            [
                cdk_nag.NagPackSuppression(
                    id="AwsSolutions-IAM5",
                    reason="'log-stream:*' - stream names dynamically generated at runtime",
                )
            ],
        )

        # Create Lambda roles
        websearch_lambda_role, advanced_web_search_lambda_role = create_lambda_roles(
            self, construct_id, lambda_policy, config
        )

        # Add suppressions for Lambda roles
        cdk_nag.NagSuppressions.add_resource_suppressions(
            websearch_lambda_role,
            [
                cdk_nag.NagPackSuppression(
                    id="AwsSolutions-IAM5",
                    reason="Lambda role requires access to CloudWatch logs and Secrets Manager",
                    applies_to=[
                        "Resource::arn:aws:logs:::log-group:/aws/lambda/*:*",
                        "Resource::arn:aws:secretsmanager:::secret:SERPER_API_KEY-*",
                        "Resource::arn:aws:secretsmanager:::secret:TAVILY_API_KEY-*",
                    ],
                )
            ],
        )
        cdk_nag.NagSuppressions.add_resource_suppressions(
            advanced_web_search_lambda_role,
            [
                cdk_nag.NagPackSuppression(
                    id="AwsSolutions-IAM5",
                    reason="Lambda role requires access to CloudWatch logs and Secrets Manager",
                    applies_to=[
                        "Resource::arn:aws:logs:::log-group:/aws/lambda/*:*",
                        "Resource::arn:aws:secretsmanager:::secret:SERPER_API_KEY-*",
                        "Resource::arn:aws:secretsmanager:::secret:TAVILY_API_KEY-*",
                    ],
                )
            ],
        )

        # Create Lambda functions
        websearch_lambda, advanced_web_search_lambda = create_lambda_functions(
            self,
            construct_id,
            websearch_lambda_role,
            advanced_web_search_lambda_role,
            lambda_layers,
            config,
        )

        # Allow Bedrock (restricted to this account via aws:SourceAccount) to
        # invoke both Lambdas.
        bedrock_account_principal = iam.PrincipalWithConditions(
            iam.ServicePrincipal("bedrock.amazonaws.com"),
            conditions={
                "StringEquals": {"aws:SourceAccount": f"{get_account_id()}"},
            },
        )
        websearch_lambda.add_permission(
            id="WebSearchLambdaResourcePolicyAgentsInvokeFunction",
            principal=bedrock_account_principal,
            action="lambda:invokeFunction",
        )
        advanced_web_search_lambda.add_permission(
            id="AdvancedWebSearchLambdaResourcePolicyAgentsInvokeFunction",
            principal=bedrock_account_principal,
            action="lambda:invokeFunction",
        )

        # Policy letting the agent invoke the two configured foundation models.
        agent_policy = iam.Policy(
            self,
            "AgentPolicy",
            statements=[
                iam.PolicyStatement(
                    sid="AmazonBedrockAgentBedrockFoundationModelPolicy",
                    effect=iam.Effect.ALLOW,
                    actions=["bedrock:InvokeModel"],
                    resources=[
                        f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['SMART_LLM']}",
                        f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['FAST_LLM']}",
                    ],
                ),
            ],
        )

        # Add suppression for agent policy
        cdk_nag.NagSuppressions.add_resource_suppressions(
            agent_policy,
            [
                cdk_nag.NagPackSuppression(
                    id="AwsSolutions-IAM5",
                    reason="Agent policy requires access to Bedrock foundation models",
                    applies_to=[
                        f"Resource::arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['SMART_LLM']}",
                        f"Resource::arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['FAST_LLM']}",
                    ],
                )
            ],
        )

        # Trust policy: only Bedrock agents from this account may assume the role.
        agent_role_trust = iam.PrincipalWithConditions(
            iam.ServicePrincipal("bedrock.amazonaws.com"),
            conditions={
                "StringLike": {"aws:SourceAccount": f"{get_account_id()}"},
                "ArnLike": {
                    "aws:SourceArn": f"arn:aws:bedrock:{config['DEPLOY_REGION']}:{get_account_id()}:agent/*"
                },
            },
        )
        agent_role = iam.Role(
            self,
            "AmazonBedrockExecutionRoleForAgents",
            role_name=f"AmazonBedrockExecutionRoleForAgents_{config['AGENT_NAME']}",
            assumed_by=agent_role_trust,
        )
        agent_role.attach_inline_policy(agent_policy)

        # Create Bedrock agent
        agent, agent_alias = create_bedrock_agent(
            self,
            construct_id,
            websearch_lambda,
            advanced_web_search_lambda,
            agent_role,
            config,
        )

        # Outputs consumed by test/invoke-agent.py (--agent_id / --agent_alias_id).
        CfnOutput(self, "agent_id", value=agent.attr_agent_id)
        CfnOutput(self, "agent_alias_id", value=agent_alias.attr_agent_alias_id)
        CfnOutput(self, "agent_version", value=agent.attr_agent_version)
def create_lambda_roles(
    self: Construct,
    construct_id: str,
    lambda_policy: iam.Policy,
    config: dict,
) -> tuple[iam.Role, iam.Role]:
    """Create the execution roles for both search Lambdas.

    Both roles share the inline ``lambda_policy`` (logging + secrets access) and
    additionally get a statement allowing Bedrock model invocation — for the
    plain model ids as well as for the geo-prefixed cross-region variants.

    :param self: Scope (the stack) the roles are created in.
    :param construct_id: Id of the calling stack (currently unused in this function).
    :param lambda_policy: Shared inline policy attached to both roles.
    :param config: Parsed config.yaml (function names, region, model ids).
    :return: (websearch role, advanced-web-search role)
    """
    websearch_lambda_role = iam.Role(
        self,
        "WebSearchLambdaRole",
        role_name=f"{config['WEBSEARCH_FUNCTION_NAME']}_role",
        assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"),
    )
    websearch_lambda_role.attach_inline_policy(lambda_policy)

    advanced_web_search_lambda_role = iam.Role(
        self,
        "AdvancedWebSearchLambdaRole",
        role_name=f"{config['ADVANCED_SEARCH_FUNCTION_NAME']}_role",
        assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"),
    )
    advanced_web_search_lambda_role.attach_inline_policy(lambda_policy)

    # Geographic prefix of the region plus dot, e.g. "us-west-2" -> "us.",
    # matching the cross-region model-id prefixing done in utils.load_config.
    cross_region_prefix = config["DEPLOY_REGION"][:2] + "."

    # Add Bedrock model invocation permissions to both Lambda roles
    bedrock_policy = iam.PolicyStatement(
        sid="AmazonBedrockInvokeModelPolicy",
        effect=iam.Effect.ALLOW,
        actions=[
            "bedrock:InvokeModel",
            "bedrock:Converse",
        ],
        resources=[
            f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['SMART_LLM']}",
            f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{config['FAST_LLM']}",
            # NOTE(review): prefixed ids look like inference-profile ids; verify
            # the "foundation-model/" ARN form is what Bedrock checks for them.
            f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{cross_region_prefix}{config['SMART_LLM']}",
            f"arn:aws:bedrock:{config['DEPLOY_REGION']}::foundation-model/{cross_region_prefix}{config['FAST_LLM']}",
        ],
    )
    websearch_lambda_role.add_to_policy(bedrock_policy)
    advanced_web_search_lambda_role.add_to_policy(bedrock_policy)

    return websearch_lambda_role, advanced_web_search_lambda_role
import yaml
from aws_cdk import Aws


def load_config(file_path: str = "config.yaml") -> dict:
    """Load config.yaml and derive the effective Bedrock model ids.

    When cross-region inference is enabled, model ids are prefixed with the
    geographic inference-profile prefix derived from the deployment region
    (e.g. "us-west-2" -> "us.").

    Bug fix: config.yaml ships the key ``CROSS_REGION_INFERENCE``, while this
    code previously only read ``USE_CROSS_REGION_INFERENCE`` and then compared
    its *string* default ``"True"`` with the boolean ``True`` — so the prefix
    was never applied. Both keys are now accepted and the value is normalized
    from either a YAML boolean or a "true"/"false" string.

    :param file_path: Path to the YAML configuration file.
    :return: The parsed configuration, updated with DEPLOY_REGION, SMART_LLM,
        FAST_LLM and use_cross_region_inference (kept as a string, since it is
        forwarded verbatim into a Lambda environment variable).
    """
    with open(file_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    DEPLOY_REGION = config.get("DEPLOY_REGION", "us-west-2")

    # Accept the documented key (CROSS_REGION_INFERENCE) and the legacy one;
    # default to enabled, matching the intent of the previous "True" default.
    raw_flag = config.get(
        "USE_CROSS_REGION_INFERENCE", config.get("CROSS_REGION_INFERENCE", True)
    )
    if isinstance(raw_flag, str):
        use_cross_region_inference = raw_flag.strip().lower() == "true"
    else:
        use_cross_region_inference = bool(raw_flag)

    SMART_LLM = config.get("SMART_LLM", "anthropic.claude-3-5-sonnet-20240620-v1:0")
    FAST_LLM = config.get("FAST_LLM", "anthropic.claude-3-haiku-20240307-v1:0")
    if use_cross_region_inference:
        # Geographic prefix ("us", "eu", "ap", ...) of the deployment region.
        prefix = DEPLOY_REGION[:2]
        SMART_LLM = f"{prefix}.{SMART_LLM}"
        FAST_LLM = f"{prefix}.{FAST_LLM}"

    config.update(
        {
            "DEPLOY_REGION": DEPLOY_REGION,
            "SMART_LLM": SMART_LLM,
            "FAST_LLM": FAST_LLM,
            # Stored as a string: Lambda environment variable values must be
            # strings (see create_lambda_functions).
            "use_cross_region_inference": str(use_cross_region_inference),
        }
    )

    return config


def get_account_id() -> str:
    """Return the account id token of the current CDK deployment."""
    return Aws.ACCOUNT_ID
# Powertools layer ARN template; {python_version} is e.g. "python312".
# From https://docs.powertools.aws.dev/lambda/python/latest/#lambda-layer
AWS_LAMBDA_POWERTOOL_LAYER_VERSION_ARN = "arn:aws:lambda:{region}:017000801446:layer:AWSLambdaPowertoolsPythonV3-{python_version}-arm64:2"


class WebSearchLambdaLayers(Construct):
    """Lambda layers for the search functions.

    Exposes two attributes consumed by ``create_lambda_functions``:

    - ``aws_lambda_powertools``: the published AWS Powertools layer, referenced
      by ARN for the configured region and Python runtime.
    - ``project_dependencies``: a layer built from
      ``lambda_layer/advanced_websearch_libraries/requirements.txt`` via Docker
      bundling for the target architecture.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        stack_name: str,
        architecture: _lambda.Architecture,
        python_runtime: _lambda.Runtime = _lambda.Runtime.PYTHON_3_12,
        region: str = Aws.REGION,
        **kwargs,
    ) -> None:
        super().__init__(scope, construct_id, **kwargs)

        self._runtime = python_runtime
        self._architecture = architecture
        self._region = region
        # e.g. "python3.12" -> "python312", as used in the Powertools layer ARN.
        self._python_handle = self._runtime.name.replace(".", "")

        # AWS Lambda PowerTools (existing published layer, referenced by ARN)
        self.aws_lambda_powertools = _lambda.LayerVersion.from_layer_version_arn(
            self,
            f"{stack_name}-lambda-powertools-layer",
            layer_version_arn=AWS_LAMBDA_POWERTOOL_LAYER_VERSION_ARN.format(
                region=self._region, python_version=self._python_handle
            ),
        )

        # Project dependencies layer (bundled from requirements.txt)
        self.project_dependencies = self._create_layer_from_asset(
            layer_name=f"{stack_name}-project-dependencies-layer",
            path_to_layer_assets="lambda_layer/advanced_websearch_libraries",
            description="Lambda layer containing project dependencies (aiohttp, pydantic, langchain)",
        )

    def _create_layer_from_asset(
        self, layer_name: str, path_to_layer_assets: str, description: str
    ) -> _lambda.LayerVersion:
        """Bundle ``requirements.txt`` under *path_to_layer_assets* into a layer.

        Runs pip inside the runtime's official bundling container for the target
        architecture and installs into /asset-output/python, the layout Lambda
        layers require for Python packages.
        """
        ecr = (
            self._runtime.bundling_image.image
            + f":latest-{self._architecture.to_string()}"
        )
        bundling_option = BundlingOptions(
            image=DockerImage(ecr),
            command=[
                "bash",
                "-c",
                "pip --no-cache-dir install -r requirements.txt -t /asset-output/python",
            ],
            output_type=BundlingOutput.AUTO_DISCOVER,
            platform=self._architecture.docker_platform,
            network="host",
        )
        layer_asset = Asset(
            self,
            f"{layer_name}-BundledAsset",
            path=path_to_layer_assets,
            bundling=bundling_option,
        )
        layer_version = _lambda.LayerVersion(
            self,
            layer_name,
            code=_lambda.Code.from_bucket(
                layer_asset.bucket, layer_asset.s3_object_key
            ),
            compatible_runtimes=[self._runtime],
            compatible_architectures=[self._architecture],
            removal_policy=RemovalPolicy.DESTROY,
            layer_version_name=layer_name,
            description=description,
        )

        # Adding metadata entries to CF template for local testing.
        # NOTE(review): the cast target is aws_opsworks.CfnLayer, which looks
        # like a mis-import (the L1 type here is aws_lambda.CfnLayerVersion);
        # harmless at runtime since cast() is typing-only — confirm and fix the
        # import where it is declared.
        cfn_layer = cast(CfnLayer, layer_version.node.default_child)
        layer_asset.add_resource_metadata(
            resource=cfn_layer, resource_property="Content"
        )
        return layer_version
^([0-9a-zA-Z][_-]?){1,100}$) 24 | AGENT_NAME: "websearch_agent" 25 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/advanced_web_search_lambda.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import asyncio 5 | from models import ( 6 | InitialQuery, 7 | AggregatedSearchResults, 8 | AdvancedWebSearchResult, 9 | FinalAnswer, 10 | ) 11 | from llm_operations import rewrite_query, analyze_results, formulate_final_answer 12 | from search_operations_tavily import perform_tavily_searches 13 | 14 | log_level = os.environ.get("LOG_LEVEL", "INFO").strip().upper() 15 | logging.basicConfig( 16 | format="[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" 17 | ) 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(log_level) 20 | 21 | ACTION_GROUP_NAME = os.environ.get("ACTION_GROUP", "action-group-advanced-web-search") 22 | FUNCTION_NAME = "advanced-web-search" 23 | 24 | 25 | async def advanced_web_search(search_query: str) -> AdvancedWebSearchResult: 26 | """ 27 | Performs an advanced web search with iterative query refinement and result analysis. 28 | 29 | This async function takes a search query and iteratively: 30 | 1. Rewrites the query into multiple variations 31 | 2. Performs searches using the Tavily search API 32 | 3. Analyzes the aggregated results 33 | 4. 
Either returns a final answer or refines the query for another iteration 34 | 35 | Args: 36 | search_query (str): The initial search query string 37 | 38 | Returns: 39 | AdvancedWebSearchResult: Object containing: 40 | - original_query: The initial search query 41 | - final_answer: FinalAnswer object with the answer and references 42 | - search_iterations: Number of search iterations performed 43 | - total_queries: Total number of queries executed 44 | - total_results: Total number of results retrieved 45 | 46 | The function will attempt up to 3 iterations of query refinement. If no satisfactory 47 | answer is found after 3 iterations, it returns a result indicating the failure to 48 | find an answer. 49 | """ 50 | initial_query = InitialQuery(query=search_query) 51 | iterations = 0 52 | max_iterations = 3 53 | total_queries = 0 54 | total_results = 0 55 | 56 | while iterations < max_iterations: 57 | iterations += 1 58 | rewritten_queries = rewrite_query(initial_query) 59 | new_queries = rewritten_queries.rewritten_queries 60 | tavily_results = await perform_tavily_searches(new_queries) 61 | 62 | total_queries += len(new_queries) 63 | total_results += sum(result.count for result in tavily_results) 64 | 65 | aggregated_results = AggregatedSearchResults( 66 | original_query=initial_query.query, 67 | rewritten_queries=new_queries, 68 | search_results=tavily_results, 69 | ) 70 | analysis = analyze_results(aggregated_results) 71 | 72 | if analysis.is_question_answered: 73 | final_answer = formulate_final_answer(aggregated_results) 74 | return AdvancedWebSearchResult( 75 | original_query=initial_query.query, 76 | final_answer=final_answer, 77 | search_iterations=iterations, 78 | total_queries=total_queries, 79 | total_results=total_results, 80 | ) 81 | 82 | initial_query = InitialQuery(query=analysis.explanation) 83 | 84 | return AdvancedWebSearchResult( 85 | original_query=search_query, 86 | final_answer=FinalAnswer( 87 | original_query=search_query, 88 | answer="Unable 
to find a satisfactory answer after multiple attempts", 89 | references=[], 90 | ), 91 | search_iterations=iterations, 92 | total_queries=total_queries, 93 | total_results=total_results, 94 | ) 95 | 96 | 97 | async def async_lambda_handler(event, context): 98 | """ 99 | Handles Lambda function invocations for advanced web search requests. 100 | 101 | This async function processes incoming Lambda events containing web search queries, 102 | executes the search using the advanced_web_search function, and formats the results 103 | into the expected response structure. 104 | Args: 105 | event (dict): AWS Lambda event object containing: 106 | - actionGroup: The action group identifier 107 | - function: The function name 108 | - parameters: List of parameters including search_query 109 | - messageVersion: Message format version 110 | context: AWS Lambda context object (unused) 111 | 112 | Returns: 113 | dict: Response object containing: 114 | - response: Action response with search results 115 | - messageVersion: Message format version from input event 116 | 117 | Raises: 118 | Exception: Any errors during execution are caught, logged and returned as error responses 119 | """ 120 | logging.debug(f"lambda_handler {event=}") 121 | 122 | try: 123 | action_group = event["actionGroup"] 124 | function = event["function"] 125 | parameters = event.get("parameters", []) 126 | 127 | logger.info(f"lambda_handler: {action_group=} {function=}") 128 | 129 | search_query = next( 130 | (param["value"] for param in parameters if param["name"] == "search_query"), 131 | None, 132 | ) 133 | 134 | if search_query: 135 | search_results = await advanced_web_search(search_query) 136 | result_text = json.dumps(search_results.dict(), indent=2) 137 | else: 138 | result_text = "Error: Invalid search query" 139 | 140 | logger.debug(f"query results {result_text=}") 141 | 142 | function_response_body = { 143 | "TEXT": { 144 | "body": f"Here are the advanced web search results for the query 
'{search_query}':\n{result_text}" 145 | } 146 | } 147 | 148 | action_response = { 149 | "actionGroup": action_group, 150 | "function": function, 151 | "functionResponse": {"responseBody": function_response_body}, 152 | } 153 | 154 | response = { 155 | "response": action_response, 156 | "messageVersion": event["messageVersion"], 157 | } 158 | 159 | logger.debug(f"lambda_handler: {response=}") 160 | 161 | return response 162 | 163 | except Exception as e: 164 | logger.error(f"An error occurred: {str(e)}") 165 | return { 166 | "response": { 167 | "actionGroup": ACTION_GROUP_NAME, 168 | "function": FUNCTION_NAME, 169 | "functionResponse": { 170 | "responseBody": { 171 | "TEXT": { 172 | "body": f"An error occurred while processing the request: {str(e)}" 173 | } 174 | } 175 | }, 176 | }, 177 | "messageVersion": event.get("messageVersion", "1.0"), 178 | } 179 | 180 | 181 | def lambda_handler(event, context): 182 | """ 183 | Main Lambda handler function that executes the async_lambda_handler. 184 | 185 | This function serves as a synchronous wrapper around the asynchronous 186 | async_lambda_handler, using asyncio.run to execute the async function 187 | in a synchronous context as required by AWS Lambda. 
188 | 189 | Args: 190 | event (dict): AWS Lambda event object containing the request details 191 | context: AWS Lambda context object 192 | 193 | Returns: 194 | dict: The response from async_lambda_handler containing the search results 195 | or error information 196 | """ 197 | return asyncio.run(async_lambda_handler(event, context)) 198 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/llm_operations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from botocore.exceptions import ClientError 4 | from models import ( 5 | InitialQuery, 6 | RewrittenQueries, 7 | AggregatedSearchResults, 8 | LLMAnalysisResult, 9 | FinalAnswer, 10 | ) 11 | from strut_output_bedrock import create_bedrock_structured_output 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | SMART_LLM = os.environ.get("SMART_LLM", "anthropic.claude-3-5-sonnet-20240620-v1:0") 16 | FAST_LLM = os.environ.get("FAST_LLM", "anthropic.claude-3-haiku-20240307-v1:0") 17 | AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") 18 | 19 | 20 | def rewrite_query(initial_query: InitialQuery) -> RewrittenQueries: 21 | try: 22 | system_prompt = """ 23 | As an AI language model specializing in query optimization, your task is to rewrite the given search query to improve search results. 24 | Follow these guidelines: 25 | 1. Maintain the original intent and core subject of the query. 26 | 2. Expand on any abbreviations or acronyms. 27 | 3. Add relevant context or specificity that might yield better results. 28 | 4. Consider alternative phrasings that might capture different aspects of the topic. 29 | 30 | Your output must be a JSON object with two fields: 31 | 1. 'original_query': The exact query provided to you. 32 | 2. 'rewritten_queries': A list of 3 rewritten versions of the original query. 
33 | 34 | Example output format: 35 | { 36 | "original_query": "What are quantum computers?", 37 | "rewritten_queries": [ 38 | "Explain the basic principles and functionality of quantum computing", 39 | "How do quantum computers differ from classical computers in terms of processing and capabilities?", 40 | "Recent advancements and potential applications of quantum computing technology" 41 | ] 42 | } 43 | """ 44 | rewritten_queries = create_bedrock_structured_output( 45 | pydantic_model=RewrittenQueries, 46 | model_id=SMART_LLM, 47 | temperature=0.7, 48 | system_prompt=system_prompt, 49 | region_name=AWS_REGION, 50 | ) 51 | 52 | query_rewrite_prompt = f"Rewrite the following search query to improve search results: {initial_query.query}" 53 | logger.debug(f"Query rewrite prompt: {query_rewrite_prompt}") 54 | 55 | result = rewritten_queries.invoke(initial_query.query) 56 | logger.debug(f"Rewritten queries: {result}") 57 | return result 58 | 59 | except ClientError as e: 60 | logger.error(f"Error rewriting query: {e}") 61 | return RewrittenQueries( 62 | original_query=initial_query.query, rewritten_queries=[initial_query.query] 63 | ) 64 | 65 | 66 | def analyze_results(aggregated_results: AggregatedSearchResults) -> LLMAnalysisResult: 67 | try: 68 | system_prompt = """ 69 | As an advanced AI analyst specializing in information evaluation, your task is to thoroughly analyze the provided search results and determine if they sufficiently answer the original query. Follow these guidelines: 70 | 71 | 1. Comprehension: Carefully read and understand the original query and all search results. 72 | 2. Relevance Assessment: Evaluate how directly each result addresses the query's main points. 73 | 3. Depth of Information: Assess the depth and breadth of information provided in the results. 74 | 4. Credibility: Consider the sources of information and their reliability. 75 | 5. Completeness: Determine if all aspects of the query are addressed in the collective results. 76 | 6. 
Contradictions: Identify any conflicting information across the results. 77 | 7. Currency: Evaluate if the information is up-to-date and relevant to the current context. 78 | 79 | Your output must be a JSON object with two fields: 80 | 1. 'is_question_answered': A boolean indicating whether the search results sufficiently answer the original query. 81 | 2. 'explanation': A detailed explanation of your analysis, including: 82 | - Why you believe the question is or is not sufficiently answered. 83 | - Key points from the search results that support your conclusion. 84 | - Any gaps in information or areas that require further investigation. 85 | - Suggestions for refining the search if the question is not fully answered. 86 | 87 | Example output format: 88 | { 89 | "is_question_answered": true, 90 | "explanation": "The search results provide a comprehensive answer to the query about recent developments in quantum computing. Key points include: 1) Breakthrough in error correction techniques by Company X, 2) Demonstration of quantum supremacy by Company Y, and 3) Application of quantum algorithms in drug discovery by Research Institute Z. The information comes from reputable sources and covers both theoretical advancements and practical applications. While the results are sufficient, further investigation into the scalability of these developments could provide additional context." 91 | } 92 | 93 | Analyze the following search results and provide your assessment: 94 | """ 95 | 96 | analysis = create_bedrock_structured_output( 97 | pydantic_model=LLMAnalysisResult, 98 | model_id=SMART_LLM, 99 | temperature=0.5, 100 | system_prompt=system_prompt, 101 | region_name=AWS_REGION, 102 | ) 103 | return analysis.invoke(aggregated_results.dict()) 104 | except ClientError as e: 105 | logger.error(f"Error analyzing results: {e}") 106 | return LLMAnalysisResult( 107 | is_question_answered=False, 108 | explanation="Error occurred during analysis. 
The system encountered an issue while processing the search results. Please try again or refine your query for a new search.", 109 | ) 110 | 111 | 112 | def formulate_final_answer(aggregated_results: AggregatedSearchResults) -> FinalAnswer: 113 | try: 114 | system_prompt = """ 115 | As an advanced AI specializing in information synthesis and summarization, your task is to formulate a comprehensive final answer based on the provided search results. Follow these guidelines: 116 | 117 | 1. Comprehension: Thoroughly understand the original query and all search results. 118 | 2. Synthesis: Combine information from multiple sources to create a coherent and comprehensive answer. 119 | 3. Relevance: Ensure the answer directly addresses the original query. 120 | 4. Accuracy: Cross-reference information across sources to ensure factual correctness. 121 | 5. Completeness: Cover all major aspects of the query in your answer. 122 | 6. Clarity: Present the information in a clear, logical, and easy-to-understand manner. 123 | 7. Objectivity: Maintain a neutral tone and present different viewpoints if applicable. 124 | 8. Currency: Emphasize the most recent developments or information when relevant. 125 | 9. References: Properly cite sources used in formulating the answer. 126 | 127 | Your output must be a JSON object with the following fields: 128 | 1. 'original_query': The exact original query string. 129 | 2. 'answer': A detailed, well-structured answer to the query, typically 2-3 paragraphs long. 130 | 3. 'references': A list of dictionaries, each containing 'title' and 'url' of the sources used. 131 | 132 | Example output format: 133 | { 134 | "original_query": "What are the latest developments in quantum computing?", 135 | "answer": "Recent developments in quantum computing have been marked by significant breakthroughs in both theoretical and practical domains. One of the most notable advancements is in error correction techniques, crucial for maintaining quantum coherence. 
Researchers at Company X have developed a new quantum error correction code that can protect quantum information for substantially longer periods, potentially bringing us closer to practical quantum computers. 136 | 137 | Another major milestone is the demonstration of quantum supremacy by Company Y. Their 54-qubit processor completed a specific task in 200 seconds that would take the most powerful classical supercomputer approximately 10,000 years. This achievement, while still for a narrow application, represents a significant step towards proving the practical advantages of quantum computing. 138 | 139 | In the realm of applications, quantum algorithms are showing promise in drug discovery. Research Institute Z has successfully used a quantum simulator to model complex molecular interactions, potentially accelerating the process of identifying new therapeutic compounds. While these developments are exciting, it's important to note that many challenges remain, particularly in scaling up quantum systems and making them more robust for practical, real-world applications.", 140 | "references": [ 141 | { 142 | "title": "Breakthrough in Quantum Error Correction", 143 | "url": "https://company-x.com/quantum-error-correction" 144 | }, 145 | { 146 | "title": "Quantum Supremacy Achieved", 147 | "url": "https://company-y.com/quantum-supremacy-paper" 148 | }, 149 | { 150 | "title": "Quantum Computing in Drug Discovery", 151 | "url": "https://institute-z.org/quantum-drug-discovery" 152 | } 153 | ] 154 | } 155 | 156 | Analyze the following aggregated search results and formulate a final answer: 157 | """ 158 | 159 | final_answer = create_bedrock_structured_output( 160 | pydantic_model=FinalAnswer, 161 | model_id=SMART_LLM, 162 | temperature=0.7, 163 | system_prompt=system_prompt, 164 | region_name=AWS_REGION, 165 | ) 166 | return final_answer.invoke(aggregated_results.dict()) 167 | except ClientError as e: 168 | logger.error(f"Error formulating final answer: {e}") 169 | 
return FinalAnswer( 170 | original_query=aggregated_results.original_query, 171 | answer="We apologize, but an error occurred while formulating the final answer. This could be due to a temporary system issue or complexity in processing the search results. Please try your query again or rephrase it for a new search.", 172 | references=[], 173 | ) 174 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | from pydantic import BaseModel, Field 3 | 4 | 5 | class InitialQuery(BaseModel): 6 | query: str = Field(..., description="The initial query from the human user") 7 | 8 | class Config: 9 | json_schema_extra = { 10 | "example": { 11 | "query": "What are the latest developments in quantum computing?" 12 | } 13 | } 14 | 15 | 16 | class RewrittenQueries(BaseModel): 17 | original_query: str = Field(..., description="The original query") 18 | rewritten_queries: List[str] = Field( 19 | ..., 20 | description="List of rewritten queries generated to improve diversity and quality of the search results", 21 | ) 22 | 23 | 24 | class TavilySearchQuery(BaseModel): 25 | query: str = Field(..., description="The search query") 26 | search_depth: str = Field("advanced", description="The depth of the search") 27 | include_images: bool = Field( 28 | False, description="Whether to include images in the results" 29 | ) 30 | include_answer: bool = Field( 31 | False, description="Whether to include an answer in the results" 32 | ) 33 | include_raw_content: bool = Field( 34 | False, description="Whether to include raw content in the results" 35 | ) 36 | max_results: int = Field(3, description="The maximum number of results to return") 37 | include_domains: List[str] = Field( 38 | [], description="List of domains to include in the search" 39 | ) 40 | exclude_domains: List[str] = Field( 41 | 
[], description="List of domains to exclude from the search" 42 | ) 43 | 44 | class Config: 45 | json_schema_extra = { 46 | "example": { 47 | "query": "Recent breakthroughs in quantum computing technology", 48 | "search_depth": "advanced", 49 | "include_images": False, 50 | "include_answer": False, 51 | "include_raw_content": False, 52 | "max_results": 3, 53 | "include_domains": [], 54 | "exclude_domains": [], 55 | } 56 | } 57 | 58 | 59 | class TavilySearchResult(BaseModel): 60 | query: str = Field(..., description="The search query used") 61 | results: List[Dict[str, Any]] = Field(..., description="List of search results") 62 | count: int = Field(..., description="Number of results returned") 63 | 64 | class Config: 65 | json_schema_extra = { 66 | "example": { 67 | "query": "Recent breakthroughs in quantum computing technology", 68 | "results": [ 69 | { 70 | "title": "Latest Quantum Computing Breakthroughs", 71 | "url": "https://example.com/quantum-breakthroughs", 72 | "content": "Summary of recent quantum computing advancements...", 73 | } 74 | ], 75 | "count": 1, 76 | } 77 | } 78 | 79 | 80 | class AggregatedSearchResults(BaseModel): 81 | original_query: str = Field(..., description="The original query") 82 | rewritten_queries: List[str] = Field(..., description="List of rewritten queries") 83 | search_results: List[TavilySearchResult] = Field( 84 | ..., description="List of search results for each query" 85 | ) 86 | 87 | class Config: 88 | json_schema_extra = { 89 | "example": { 90 | "original_query": "What are the latest developments in quantum computing?", 91 | "rewritten_queries": [ 92 | "Recent breakthroughs in quantum computing technology", 93 | "Current state of quantum computing research and applications", 94 | ], 95 | "search_results": [ 96 | { 97 | "query": "Recent breakthroughs in quantum computing technology", 98 | "results": [ 99 | { 100 | "title": "Latest Quantum Computing Breakthroughs", 101 | "url": "https://example.com/quantum-breakthroughs", 102 
| "content": "Summary of recent quantum computing advancements...", 103 | } 104 | ], 105 | "count": 1, 106 | } 107 | ], 108 | } 109 | } 110 | 111 | 112 | class LLMAnalysisResult(BaseModel): 113 | is_question_answered: bool = Field( 114 | ..., description="Whether the question is sufficiently answered" 115 | ) 116 | explanation: str = Field(..., description="Explanation of the analysis") 117 | 118 | class Config: 119 | json_schema_extra = { 120 | "example": { 121 | "is_question_answered": True, 122 | "explanation": "The search results provide comprehensive information on recent quantum computing developments, including breakthroughs in error correction and quantum supremacy demonstrations.", 123 | } 124 | } 125 | 126 | 127 | class FinalAnswer(BaseModel): 128 | original_query: str = Field(..., description="The original query") 129 | answer: str = Field(..., description="The final answer to the query") 130 | references: List[Dict[str, str]] = Field( 131 | ..., description="List of references used in the answer" 132 | ) 133 | 134 | class Config: 135 | json_schema_extra = { 136 | "example": { 137 | "original_query": "What are the latest developments in quantum computing?", 138 | "answer": "Recent developments in quantum computing include significant progress in error correction techniques and demonstrations of quantum supremacy by major tech companies...", 139 | "references": [ 140 | { 141 | "title": "Latest Quantum Computing Breakthroughs", 142 | "url": "https://example.com/quantum-breakthroughs", 143 | } 144 | ], 145 | } 146 | } 147 | 148 | 149 | class AdvancedWebSearchResult(BaseModel): 150 | original_query: str = Field(..., description="The original query") 151 | final_answer: Optional[FinalAnswer] = Field( 152 | None, description="The final answer, if available" 153 | ) 154 | search_iterations: int = Field( 155 | ..., description="Number of search iterations performed" 156 | ) 157 | total_queries: int = Field(..., description="Total number of queries made") 158 | 
total_results: int = Field(..., description="Total number of results obtained") 159 | 160 | class Config: 161 | json_schema_extra = { 162 | "example": { 163 | "original_query": "What are the latest developments in quantum computing?", 164 | "final_answer": { 165 | "original_query": "What are the latest developments in quantum computing?", 166 | "answer": "Recent developments in quantum computing include significant progress in error correction techniques and demonstrations of quantum supremacy by major tech companies...", 167 | "references": [ 168 | { 169 | "title": "Latest Quantum Computing Breakthroughs", 170 | "url": "https://example.com/quantum-breakthroughs", 171 | } 172 | ], 173 | }, 174 | "search_iterations": 2, 175 | "total_queries": 3, 176 | "total_results": 5, 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/search_operations_tavily.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import boto3 5 | from botocore.exceptions import ClientError 6 | import asyncio 7 | from typing import List 8 | from tavily import AsyncTavilyClient 9 | from models import TavilySearchResult 10 | 11 | from utils import get_from_secretstore_or_env 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") 16 | 17 | 18 | def get_tavily_api_key(): 19 | return get_from_secretstore_or_env("TAVILY_API_KEY", AWS_REGION) 20 | 21 | 22 | async def perform_tavily_searches(queries: List[str]) -> List[TavilySearchResult]: 23 | """ 24 | Performs asynchronous searches using the Tavily API for multiple queries. 25 | 26 | Args: 27 | queries (List[str]): A list of search queries to process. 28 | 29 | Returns: 30 | List[TavilySearchResult]: A list of TavilySearchResult objects containing the search results 31 | for each query. 
Each result includes the original query, search results, and result count. 32 | The Pydantic object is defined in models.py 33 | 34 | Example: 35 | queries = ["tesla news", "AI advances"] 36 | results = await perform_tavily_searches(queries) 37 | """ 38 | tavily_api_key = get_tavily_api_key() 39 | client = AsyncTavilyClient(api_key=tavily_api_key) 40 | tasks = [search_tavily(client, query) for query in queries] 41 | return await asyncio.gather(*tasks) 42 | 43 | 44 | async def search_tavily(client: AsyncTavilyClient, query: str) -> TavilySearchResult: 45 | """ 46 | Performs a single search using the Tavily API client. 47 | Args: 48 | client (AsyncTavilyClient): The initialized Tavily API client instance 49 | query (str): The search query string to process 50 | 51 | Returns: 52 | TavilySearchResult: A TavilySearchResult object containing: 53 | - query: The original search query 54 | - results: List of search results (empty list if error) 55 | - count: Number of results found (0 if error) 56 | 57 | The search is configured to: 58 | - Use "advanced" search depth 59 | - Exclude images 60 | - Exclude answer summaries 61 | - Exclude raw content 62 | - Return maximum 3 results 63 | 64 | If an error occurs during the API request, it will be logged and an empty 65 | result will be returned rather than raising an exception. 
66 | """ 67 | try: 68 | response = await client.search( 69 | query=query, 70 | search_depth="advanced", 71 | include_images=False, 72 | include_answer=False, 73 | include_raw_content=False, 74 | max_results=3, 75 | ) 76 | return TavilySearchResult( 77 | query=query, 78 | results=response.get("results", []), 79 | count=len(response.get("results", [])), 80 | ) 81 | except Exception as e: 82 | logger.error(f"Error during Tavily API request for query '{query}': {str(e)}") 83 | return TavilySearchResult(query=query, results=[], count=0) 84 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/strut_output_bedrock.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Dict, Any, Optional, Type 3 | from pydantic import BaseModel, ValidationError 4 | from langchain_core.runnables import Runnable 5 | from langchain_core.messages import AIMessage, HumanMessage 6 | import boto3 7 | import os 8 | import json 9 | 10 | 11 | class MaxTokensReachedException(Exception): 12 | pass 13 | 14 | 15 | class BedrockStructuredOutput(Runnable): 16 | 17 | def __init__( 18 | self, 19 | pydantic_model: Type[BaseModel], 20 | model_id: str = "anthropic.claude-3-haiku-20240307-v1:0", 21 | temperature: float = 0.0, 22 | max_tokens: Optional[int] = 4096, 23 | top_p: float = 1.0, 24 | region_name: str = "us-west-2", # os.getenv("AWS_REGION"), # TODO: CHange this back to normal region 25 | system_prompt: str = "", 26 | include_raw: bool = False, 27 | bedrock_client: Optional[boto3.client] = None, 28 | ): 29 | self.pydantic_model = pydantic_model 30 | self.model_id = model_id 31 | self.temperature = temperature 32 | self.max_tokens = 4096 if max_tokens is None else max_tokens # Set default here 33 | self.top_p = top_p 34 | self.region_name = region_name 35 | self.system_prompt = system_prompt 36 | self.include_raw = include_raw 37 | self.tool_config = 
self._create_tool_config() 38 | self.bedrock_client = ( 39 | bedrock_client 40 | if bedrock_client is not None 41 | else boto3.client("bedrock-runtime", region_name="us-west-2") 42 | ) 43 | 44 | def _create_tool_config(self) -> Dict: 45 | json_schema = self.pydantic_model.model_json_schema() 46 | return { 47 | "tools": [ 48 | { 49 | "toolSpec": { 50 | "name": "structured_output", 51 | "description": "Generate structured output", 52 | "inputSchema": {"json": json_schema}, 53 | } 54 | } 55 | ], 56 | "toolChoice": {"tool": {"name": "structured_output"}}, 57 | } 58 | 59 | def _prepare_message(self, inputs: Any) -> List[Dict[str, Any]]: 60 | return [{"role": "user", "content": [{"text": str(inputs)}]}] 61 | 62 | def _prepare_system_message(self) -> List[Dict[str, Any]]: 63 | if self.system_prompt: 64 | return [{"text": self.system_prompt}] 65 | return [] 66 | 67 | def invoke(self, inputs: Any, config: Optional[Dict[str, Any]] = None) -> Any: 68 | messages = self._prepare_message(inputs) 69 | system = self._prepare_system_message() 70 | 71 | response = self.bedrock_client.converse( 72 | modelId=self.model_id, 73 | messages=messages, 74 | system=system, 75 | inferenceConfig={ 76 | "temperature": self.temperature, 77 | "maxTokens": self.max_tokens, 78 | "topP": self.top_p, 79 | }, 80 | toolConfig=self.tool_config, 81 | ) 82 | 83 | if self.include_raw: 84 | return self._structure_output(response) 85 | else: 86 | return self._parse_response(response) 87 | 88 | def _structure_output(self, response: Dict) -> Dict[str, Any]: 89 | """ 90 | Structures the model response into a dictionary containing both raw and parsed outputs. 
91 | 92 | Args: 93 | response (Dict): The raw response dictionary from the Bedrock model 94 | 95 | Returns: 96 | Dict[str, Any]: A dictionary containing: 97 | - raw (AIMessage): The raw model response structured as an AIMessage 98 | - parsed (Optional[BaseModel]): The response parsed into the configured Pydantic model, 99 | or None if parsing failed 100 | - parsing_error (Optional[str]): Error message if parsing failed, None otherwise 101 | 102 | Notes: 103 | - Attempts to create both raw and parsed versions of the model output 104 | - If parsing fails, the parsed output will be None and the error will be captured 105 | - Raw output is always included regardless of parsing success 106 | """ 107 | raw_output = self._create_raw_output(response) 108 | try: 109 | parsed_output = self._parse_response(response) 110 | parsing_error = None 111 | except Exception as e: 112 | parsed_output = None 113 | parsing_error = str(e) 114 | 115 | return { 116 | "raw": raw_output, 117 | "parsed": parsed_output, 118 | "parsing_error": parsing_error, 119 | } 120 | 121 | def _create_raw_output(self, response: Dict) -> AIMessage: 122 | """ 123 | Creates a structured AIMessage from the raw Bedrock model response. 
124 | 125 | Args: 126 | response (Dict): The raw response dictionary from the Bedrock model containing: 127 | - output.message.content: List of content items 128 | - usage: Token usage statistics 129 | - stopReason: Reason for response completion 130 | - modelId: ID of the model used 131 | - id: Response identifier 132 | 133 | Returns: 134 | AIMessage: A message object containing: 135 | - content: Empty string 136 | - additional_kwargs: Dictionary with stop_reason and model_id 137 | - id: Response identifier 138 | - tool_calls: List containing extracted tool call information 139 | - usage_metadata: Dictionary with token usage statistics: 140 | - input_tokens: Number of prompt tokens 141 | - output_tokens: Number of completion tokens 142 | - total_tokens: Total tokens used 143 | """ 144 | output_message = response.get("output", {}).get("message", {}) 145 | content = output_message.get("content", []) 146 | usage = response.get("usage", {}) 147 | stop_reason = response.get("stopReason") 148 | model_id = response.get("modelId") 149 | 150 | return AIMessage( 151 | content="", 152 | additional_kwargs={ 153 | "stop_reason": stop_reason, 154 | "model_id": model_id, 155 | }, 156 | id=response.get("id", ""), 157 | tool_calls=[self._extract_tool_call(content)], 158 | usage_metadata={ 159 | "input_tokens": usage.get("promptTokens", 0), 160 | "output_tokens": usage.get("completionTokens", 0), 161 | "total_tokens": usage.get("totalTokens", 0), 162 | }, 163 | ) 164 | 165 | def _extract_tool_call(self, content: List[Dict]) -> Dict: 166 | """ 167 | Extracts tool call information from the model response content. 168 | 169 | Args: 170 | content (List[Dict]): The content section of the model response containing possible tool use data. 
171 | Returns: 172 | Dict: A dictionary containing the tool call information with the following structure: 173 | { 174 | 'name': The name of the tool that was called (str), 175 | 'args': The input arguments passed to the tool (Dict), 176 | 'id': The unique identifier of the tool call (str), 177 | 'type': Always set to 'tool_call' (str) 178 | } 179 | Returns an empty dictionary if no tool use is found in the content. 180 | """ 181 | tool_use = next((c["toolUse"] for c in content if "toolUse" in c), None) 182 | if tool_use: 183 | return { 184 | "name": tool_use.get("name", ""), 185 | "args": tool_use.get("input", {}), 186 | "id": tool_use.get("id", ""), 187 | "type": "tool_call", 188 | } 189 | return {} 190 | 191 | def _parse_response(self, response: Dict) -> BaseModel: 192 | """ 193 | Parses the response from the Bedrock model and converts it to a Pydantic model instance. 194 | 195 | Args: 196 | response (Dict): The raw response dictionary from the Bedrock model 197 | 198 | Returns: 199 | BaseModel: An instance of the configured Pydantic model containing the structured output 200 | 201 | Raises: 202 | MaxTokensReachedException: If the response was truncated due to reaching the max token limit 203 | and either no structured output was found or the output failed validation 204 | ValidationError: If the structured output fails Pydantic model validation (when not truncated) 205 | ValueError: If no structured output is found in a complete (non-truncated) response 206 | 207 | Notes: 208 | - Checks if response was truncated due to max tokens and prints warning if so 209 | - Extracts structured output from the tool use section of the response 210 | - Validates the output against the configured Pydantic model 211 | - Handles truncation cases by raising MaxTokensReachedException when appropriate 212 | """ 213 | stop_reason = response.get("stopReason") 214 | is_truncated = stop_reason == "max_tokens" 215 | 216 | if is_truncated: 217 | print( 218 | "WARNING: The response was 
truncated due to reaching the max token limit." 219 | ) 220 | 221 | output_message = response.get("output", {}).get("message", {}) 222 | content = output_message.get("content", []) 223 | 224 | tool_use = next((c["toolUse"] for c in content if "toolUse" in c), None) 225 | if tool_use: 226 | output_dict = tool_use["input"] 227 | try: 228 | parsed_output = self.pydantic_model(**output_dict) 229 | if is_truncated: 230 | print("Note: Parsed output may be incomplete due to truncation.") 231 | return parsed_output 232 | except ValidationError as e: 233 | print(f"Validation error: {e}") 234 | if is_truncated: 235 | raise MaxTokensReachedException( 236 | "Unable to parse the complete structure due to max tokens limit being reached." 237 | ) 238 | else: 239 | raise # Re-raise the original ValidationError if not due to truncation 240 | else: 241 | if is_truncated: 242 | raise MaxTokensReachedException( 243 | "Structured output not found in the response due to max tokens limit being reached." 244 | ) 245 | else: 246 | raise ValueError("Structured output not found in the response") 247 | 248 | 249 | def create_bedrock_structured_output( 250 | pydantic_model: Type[BaseModel], 251 | model_id: str, 252 | temperature: float = 0, 253 | max_tokens: Optional[int] = None, 254 | top_p: float = 1.0, 255 | region_name: str = "us-east-1", 256 | system_prompt: str = "", 257 | include_raw: bool = False, 258 | bedrock_client: Optional[boto3.client] = None, 259 | ) -> BedrockStructuredOutput: 260 | """ 261 | Creates a BedrockStructuredOutput instance for generating structured responses using Amazon Bedrock. 262 | 263 | Args: 264 | pydantic_model (Type[BaseModel]): The Pydantic model class that defines the structure of the output. 265 | model_id (str): The identifier of the Amazon Bedrock model to use. 266 | temperature (float, optional): Controls randomness in the model's output. Defaults to 0. 267 | max_tokens (Optional[int], optional): Maximum number of tokens in the response. 
Defaults to None. 268 | top_p (float, optional): Controls diversity of model output via nucleus sampling. Defaults to 1.0. 269 | region_name (str, optional): AWS region for the Bedrock service. Defaults to "us-east-1". 270 | system_prompt (str, optional): System prompt to guide the model's behavior. Defaults to "". 271 | include_raw (bool, optional): Whether to include raw response data. Defaults to False. 272 | bedrock_client (Optional[boto3.client], optional): Pre-configured Bedrock client. Defaults to None. 273 | 274 | Returns: 275 | BedrockStructuredOutput: An instance of BedrockStructuredOutput configured with the provided parameters. 276 | """ 277 | if bedrock_client is None: 278 | bedrock_client = boto3.client("bedrock-runtime", region_name=region_name) 279 | 280 | return BedrockStructuredOutput( 281 | pydantic_model, 282 | model_id, 283 | temperature, 284 | max_tokens, 285 | top_p, 286 | region_name, 287 | system_prompt, 288 | include_raw, 289 | bedrock_client, 290 | ) 291 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/test_advanced_web_search.py: -------------------------------------------------------------------------------- 1 | import json 2 | from advanced_web_search_lambda import lambda_handler 3 | 4 | # Mock event based on the provided test event 5 | mock_event = { 6 | "messageVersion": "1.0", 7 | "function": "advanced-web-search", 8 | "parameters": [ 9 | { 10 | "name": "search_query", 11 | "type": "string", 12 | "value": "ASML share price dip reason recent", 13 | } 14 | ], 15 | "sessionId": "510646607739287", 16 | "agent": { 17 | "name": "websearch_agent", 18 | "version": "DRAFT", 19 | "id": "Y3JFGHUJ6H", 20 | "alias": "TSTALIASID", 21 | }, 22 | "actionGroup": "action-group-advanced-web-search", 23 | "sessionAttributes": {}, 24 | "promptSessionAttributes": {}, 25 | "inputText": "Advanced web search. 
Why did ASML share price dip?", 26 | } 27 | 28 | 29 | # Mock context object 30 | class MockContext: 31 | def __init__(self): 32 | self.function_name = "test_advanced_web_search" 33 | self.memory_limit_in_mb = 128 34 | self.invoked_function_arn = ( 35 | "arn:aws:lambda:us-west-2:123456789012:function:test_advanced_web_search" 36 | ) 37 | 38 | 39 | # Create an instance of the mock context 40 | mock_context = MockContext() 41 | 42 | 43 | def test_advanced_web_search(): 44 | # Call the lambda_handler function with mock event and context 45 | response = lambda_handler(mock_event, mock_context) 46 | 47 | # Print the response 48 | print(json.dumps(response, indent=2)) 49 | 50 | # Add assertions here to verify the response structure and content 51 | assert "response" in response 52 | assert "actionGroup" in response["response"] 53 | assert "function" in response["response"] 54 | assert "functionResponse" in response["response"] 55 | assert "responseBody" in response["response"]["functionResponse"] 56 | assert "TEXT" in response["response"]["functionResponse"]["responseBody"] 57 | assert "body" in response["response"]["functionResponse"]["responseBody"]["TEXT"] 58 | 59 | 60 | if __name__ == "__main__": 61 | test_advanced_web_search() 62 | -------------------------------------------------------------------------------- /v2/lambda/advanced_web_search/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import boto3 5 | from botocore.exceptions import ClientError 6 | from typing import Optional 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | def is_env_var_set(env_var: str) -> bool: 14 | """ 15 | Check if an environment variable exists and has a non-empty/non-false value. 
16 | 17 | Args: 18 | env_var (str): The name of the environment variable to check 19 | 20 | Returns: 21 | bool: True if the environment variable exists and has a value other than 22 | empty string, "0", "false", or "False". False otherwise. 23 | """ 24 | return env_var in os.environ and os.environ[env_var] not in ( 25 | "", 26 | "0", 27 | "false", 28 | "False", 29 | ) 30 | 31 | 32 | def get_from_secretstore_or_env(key: str, region: Optional[str] = None) -> str: 33 | """ 34 | Retrieve a secret value either from environment variables or AWS Secrets Manager. 35 | 36 | First checks if the value exists as an environment variable. If it does, returns 37 | that value with a warning. If not, attempts to retrieve the value from AWS 38 | Secrets Manager in the specified region. 39 | 40 | Args: 41 | key (str): The key/name of the secret to retrieve 42 | region (Optional[str]): AWS region for Secrets Manager. If None, uses AWS_REGION 43 | environment variable 44 | 45 | Returns: 46 | str: The secret value 47 | 48 | Raises: 49 | Exception: If the secret cannot be retrieved from AWS Secrets Manager 50 | """ 51 | if is_env_var_set(key): 52 | logger.warning( 53 | f"getting value for {key} from environment var; recommended to use AWS Secrets Manager instead" 54 | ) 55 | return os.environ[key] 56 | 57 | session = boto3.session.Session() 58 | secrets_manager = session.client( 59 | service_name="secretsmanager", 60 | region_name=region if region else os.getenv("AWS_REGION"), 61 | ) 62 | logger.info(f"getting secret {key} from AWS Secrets Manager in region {region}.") 63 | try: 64 | secret_value = secrets_manager.get_secret_value(SecretId=key) 65 | except Exception as e: 66 | logger.error(f"could not get secret {key} from secrets manager: {e}") 67 | raise e 68 | 69 | secret: str = secret_value["SecretString"] 70 | 71 | return secret 72 | -------------------------------------------------------------------------------- /v2/lambda/websearch/config.py: 
# ---------------------------------------------------------------------------
# v2/lambda/websearch/config.py
# ---------------------------------------------------------------------------
import os

# The literals below are only local-testing fallbacks; deployed values come
# from the Lambda environment.
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
ACTION_GROUP_NAME = os.environ.get("ACTION_GROUP", "action-group-web-search-d213q")
FUNCTION_NAMES = ["tavily-ai-search", "google-search"]


# ---------------------------------------------------------------------------
# v2/lambda/websearch/exceptions.py
# ---------------------------------------------------------------------------
class SearchError(Exception):
    """Base exception for search operations"""


class ConfigurationError(SearchError):
    """Raised when there's a configuration issue"""


class APIError(SearchError):
    """Raised when an API call fails"""


# ---------------------------------------------------------------------------
# v2/lambda/websearch/google_search.py
# (module content is stored out-of-band in the repository dump)
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# v2/lambda/websearch/models.py
# ---------------------------------------------------------------------------
from typing import List, Dict, Any, Optional
from pydantic import BaseModel


class SearchParameters(BaseModel):
    """Single name/value parameter as delivered by the Bedrock agent."""

    name: str
    value: str


class SearchEvent(BaseModel):
    """Shape of the event the agent sends to the Lambda handler."""

    actionGroup: str
    function: str
    parameters: List[SearchParameters]
    messageVersion: str


class SearchResponse(BaseModel):
    """Shape of the response returned to the agent."""

    response: Dict[str, Any]
    messageVersion: str


# ---------------------------------------------------------------------------
# v2/lambda/websearch/response_formatter.py
# ---------------------------------------------------------------------------
import json
from typing import Dict, Any
def format_lambda_response(
    action_group: str,
    function: str,
    search_query: str,
    search_results: Dict[str, Any],
    message_version: str,
) -> Dict[str, Any]:
    """
    Build the Bedrock agent action-group response envelope.

    Args:
        action_group: Name of the invoking action group (echoed back).
        function: Name of the invoked function (echoed back).
        search_query: Query string used for the search.
        search_results: JSON-serializable search results to embed.
        message_version: messageVersion taken from the incoming event.

    Returns:
        Dict[str, Any]: {"response": ..., "messageVersion": ...} in the
        shape expected by Bedrock agents.
    """
    function_response_body = {
        "TEXT": {
            "body": f"Here are the top search results for the query '{search_query}': {json.dumps(search_results)} "
        }
    }

    action_response = {
        "actionGroup": action_group,
        "function": function,
        "functionResponse": {"responseBody": function_response_body},
    }

    return {"response": action_response, "messageVersion": message_version}


# ---------------------------------------------------------------------------
# v2/lambda/websearch/search_providers.py
# ---------------------------------------------------------------------------
import json
import http.client
import urllib.request
import urllib.parse
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from utils import logger
from exceptions import APIError


class SearchProvider(ABC):
    """Common interface for web-search backends."""

    @abstractmethod
    def __init__(self, api_key: str):
        self.api_key = api_key

    @abstractmethod
    def search(
        self,
        query: str,
        target_website: Optional[str] = None,
        search_type: str = "search",
        **kwargs,
    ) -> Dict[str, Any]:
        """Run a search and return the provider's (processed) response."""


class TavilySearchProvider(SearchProvider):
    """Search backend backed by the Tavily AI REST API."""

    def __init__(self, api_key: str):
        self.api_key = api_key

    def search(
        self,
        query: str,
        target_website: Optional[str] = None,
        search_depth: str = "advanced",
        max_results: int = 3,
        topic: str = "general",
        days: int = 3,
        include_answer: bool = False,
        include_raw_content: bool = False,
        include_images: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Query the Tavily search API and return the decoded JSON response.

        Raises:
            APIError: When the HTTP request or JSON decoding fails.
        """
        logger.info(f"Executing Tavily AI search with query={query}")

        base_url = "https://api.tavily.com/search"
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        payload = {
            "api_key": self.api_key,
            "query": query,
            "search_depth": search_depth,
            "topic": topic,
            "max_results": max_results,
            "include_images": include_images,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_domains": [target_website] if target_website else [],
            "exclude_domains": [],
        }

        # Only add days parameter if topic is news
        if topic == "news" and days:
            payload["days"] = days

        try:
            data = json.dumps(payload).encode("utf-8")
            request = urllib.request.Request(base_url, data=data, headers=headers)
            # Fix: close the HTTP response deterministically; the original
            # left it to garbage collection, leaking the connection.
            with urllib.request.urlopen(request) as response:
                response_data = json.loads(response.read().decode("utf-8"))
            logger.debug(f"Response from Tavily AI search: {response_data}")
            return response_data
        except Exception as e:
            logger.error(f"Failed to retrieve search results from Tavily AI Search: {str(e)}")
            raise APIError(f"Tavily AI Search error: {str(e)}")


class GoogleSearchProvider(SearchProvider):
    """Search backend backed by the Serper.dev Google Search API."""

    def __init__(self, api_key: str):
        self.api_key = api_key

    def search(
        self,
        query: str,
        target_website: Optional[str] = None,
        search_type: str = "search",
        time_period: Optional[str] = None,
        country_code: str = "us",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Query the Serper.dev API ("search" or "news" endpoint).

        Raises:
            APIError: When the HTTP call fails or returns a non-200 status.
        """
        logger.info(f"Executing Google search with query={query}")

        if target_website:
            query = f"site:{target_website} {query}"

        conn = http.client.HTTPSConnection("google.serper.dev")
        payload = {"q": query, "gl": country_code}

        if time_period:
            payload["tbs"] = time_period

        headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

        try:
            conn.request("POST", f"/{search_type}", json.dumps(payload), headers)
            res = conn.getresponse()
            data = res.read()
            response = json.loads(data.decode("utf-8"))

            if res.status != 200:
                error_msg = f"API request failed with status code {res.status}: {response.get('message', 'Unknown error')}"
                logger.error(error_msg)
                raise APIError(error_msg)

            return self._process_response(response, search_type)
        except APIError:
            # Fix: don't re-wrap (and double-log) the APIError raised above;
            # the original generic handler wrapped it a second time.
            raise
        except Exception as e:
            logger.error(f"Failed to retrieve search results from Google Search: {str(e)}")
            raise APIError(f"Google Search error: {str(e)}")
        finally:
            conn.close()

    def _process_response(self, response: Dict[str, Any], search_type: str) -> Dict[str, Any]:
        """Normalize a Serper response into {searchMetadata, results}."""
        processed_results = []

        if search_type == "search":
            for result in response.get("organic", []):
                processed_results.append(
                    {
                        "title": result.get("title"),
                        "link": result.get("link"),
                        "snippet": result.get("snippet"),
                        "position": result.get("position"),
                    }
                )
        elif search_type == "news":
            for result in response.get("news", []):
                processed_results.append(
                    {
                        "title": result.get("title"),
                        "link": result.get("link"),
                        "snippet": result.get("snippet"),
                        "date": result.get("date"),
                        "source": result.get("source"),
                    }
                )

        return {
            "searchMetadata": {
                "query": response.get("searchParameters", {}).get("q"),
                "totalResults": response.get("searchInformation", {}).get("totalResults"),
            },
            "results": processed_results,
        }


# ---------------------------------------------------------------------------
# v2/lambda/websearch/test_websearch.py
# ---------------------------------------------------------------------------
import json
from websearch_lambda import lambda_handler

# Event mirroring what a Bedrock agent sends to the action-group Lambda.
mock_event = {
    "messageVersion": "1.0",
    "function": "tavily-ai-search",
    "parameters": [
        {
            "name": "search_query",
            "type": "string",
            "value": "Latest developments in quantum computing",
        },
        {
            "name": "search_depth",
            "type": "string",
            "value": "advanced",
        },
        {
            "name": "max_results",
            "type": "number",
            "value": "3",
        },
        {
            "name": "topic",
            "type": "string",
            "value": "general",
        },
    ],
    "sessionId": "510646607739287",
    "agent": {
        "name": "websearch_agent",
        "version": "DRAFT",
        "id": "Y3JFGHUJ6H",
        "alias": "TSTALIASID",
    },
    "actionGroup": "action-group-web-search-d213q",
    "sessionAttributes": {},
    "promptSessionAttributes": {},
    "inputText": "Search for latest developments in quantum computing",
}


class MockContext:
    """Minimal stand-in for the Lambda context object."""

    def __init__(self):
        self.function_name = "test_websearch"
        self.memory_limit_in_mb = 128
        self.invoked_function_arn = (
            "arn:aws:lambda:us-west-2:123456789012:function:test_websearch"
        )


mock_context = MockContext()


def test_websearch():
    """Invoke the handler and check the Bedrock action-group response shape."""
    response = lambda_handler(mock_event, mock_context)

    print(json.dumps(response, indent=2))

    # Verify the response envelope expected by Bedrock agents.
    assert "response" in response
    inner = response["response"]
    assert "actionGroup" in inner
    assert "function" in inner
    assert "functionResponse" in inner
    assert "responseBody" in inner["functionResponse"]
    response_body = inner["functionResponse"]["responseBody"]
    assert "TEXT" in response_body
    assert "body" in response_body["TEXT"]


if __name__ == "__main__":
    test_websearch()


# next file: v2/lambda/websearch/utils.py
# ---------------------------------------------------------------------------
# v2/lambda/websearch/utils.py
# ---------------------------------------------------------------------------
import logging
import os
import boto3
from typing import Optional


def setup_logging() -> logging.Logger:
    """Configure and return a logger with proper formatting"""
    log_level = os.environ.get("LOG_LEVEL", "INFO").strip().upper()
    logging.basicConfig(
        format="[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
    )
    module_logger = logging.getLogger(__name__)
    module_logger.setLevel(log_level)
    return module_logger


def is_env_var_set(env_var: str) -> bool:
    """Check if environment variable is set and truthy"""
    value = os.environ.get(env_var)
    return value is not None and value not in ("", "0", "false", "False")


def get_from_secretstore_or_env(key: str, region: str) -> str:
    """Get value from Secrets Manager or environment variable"""
    if is_env_var_set(key):
        logger.warning(
            f"Getting value for {key} from environment var; recommended to use AWS Secrets Manager instead"
        )
        return os.environ[key]

    secrets_manager = boto3.session.Session().client(
        service_name="secretsmanager", region_name=region
    )
    try:
        secret_value = secrets_manager.get_secret_value(SecretId=key)
    except Exception as e:
        logger.error(f"Could not get secret {key} from secrets manager: {e}")
        raise

    return secret_value["SecretString"]


# Module-level logger shared by the other websearch modules.
logger = setup_logging()


# ---------------------------------------------------------------------------
# v2/lambda/websearch/validators.py
# ---------------------------------------------------------------------------
from typing import Tuple, Optional
from exceptions import ConfigurationError
from utils import logger


def validate_search_params(
    action_group: str,
    function: str,
    parameters: list,
    valid_action_group: str,
    valid_functions: list,
) -> Tuple[str, Optional[str]]:
    """
    Validate an incoming action-group event and extract its search inputs.

    Args:
        action_group: Action group named in the event.
        function: Function named in the event.
        parameters: List of {"name", "value"} parameter dicts.
        valid_action_group: The one action-group name this Lambda serves.
        valid_functions: Function names this Lambda serves.

    Returns:
        Tuple[str, Optional[str]]: (search_query, target_website or None).

    Raises:
        ConfigurationError: On an unexpected action group/function, or when
            search_query is missing or empty.
    """
    if action_group != valid_action_group:
        raise ConfigurationError(f"Invalid action group: {action_group}")

    if function not in valid_functions:
        raise ConfigurationError(f"Invalid function: {function}")

    def _param_value(name):
        # First matching parameter wins, mirroring the agent's contract.
        for param in parameters:
            if param["name"] == name:
                return param["value"]
        return None

    search_query = _param_value("search_query")
    if not search_query:
        raise ConfigurationError("Missing required parameter: search_query")

    target_website = _param_value("target_website")

    logger.debug(f"validate_search_params: {search_query=} {target_website=}")
    return search_query, target_website


# ---------------------------------------------------------------------------
# v2/lambda/websearch/websearch_lambda.py
# ---------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0 3 | import json 4 | from typing import Dict, Any, Optional 5 | 6 | from utils import logger, get_from_secretstore_or_env 7 | from models import SearchEvent, SearchResponse 8 | from search_providers import TavilySearchProvider, GoogleSearchProvider 9 | from config import AWS_REGION, ACTION_GROUP_NAME, FUNCTION_NAMES 10 | from validators import validate_search_params 11 | from response_formatter import format_lambda_response 12 | from exceptions import SearchError, ConfigurationError, APIError 13 | 14 | 15 | # Initialize API keys 16 | SERPER_API_KEY = get_from_secretstore_or_env("SERPER_API_KEY", AWS_REGION) 17 | TAVILY_API_KEY = get_from_secretstore_or_env("TAVILY_API_KEY", AWS_REGION) 18 | 19 | # Initialize search providers 20 | tavily_provider = TavilySearchProvider(TAVILY_API_KEY) 21 | google_provider = GoogleSearchProvider(SERPER_API_KEY) 22 | 23 | 24 | def extract_search_params(action_group, function, parameters): 25 | if action_group != ACTION_GROUP_NAME: 26 | logger.error( 27 | f"unexpected name '{action_group}'; expected valid action group name '{ACTION_GROUP_NAME}'" 28 | ) 29 | return None, None, None, None, None 30 | 31 | if function not in FUNCTION_NAMES: 32 | logger.error( 33 | f"unexpected function name '{function}'; valid function names are'{FUNCTION_NAMES}'" 34 | ) 35 | return None, None, None, None, None 36 | 37 | # Extract all parameters with defaults 38 | search_query = next( 39 | (param["value"] for param in parameters if param["name"] == "search_query"), 40 | None, 41 | ) 42 | target_website = next( 43 | (param["value"] for param in parameters if param["name"] == "target_website"), 44 | None, 45 | ) 46 | 47 | if function == "tavily-ai-search": 48 | search_depth = next( 49 | (param["value"] for param in parameters if param["name"] == "search_depth"), 50 | "advanced", 51 | ) 52 | max_results = int( 53 | next( 54 | ( 55 | param["value"] 56 | for param in parameters 57 | if param["name"] == "max_results" 58 | ), 
59 | "3", 60 | ) 61 | ) 62 | topic = next( 63 | (param["value"] for param in parameters if param["name"] == "topic"), 64 | "general", 65 | ) 66 | logger.debug( 67 | f"extract_search_params Tavily: {search_query=} {target_website=} {search_depth=} {max_results=} {topic=}" 68 | ) 69 | return search_query, target_website, search_depth, max_results, topic 70 | else: # google-search 71 | search_type = next( 72 | (param["value"] for param in parameters if param["name"] == "search_type"), 73 | "search", 74 | ) 75 | time_period = next( 76 | (param["value"] for param in parameters if param["name"] == "time_period"), 77 | None, 78 | ) 79 | country_code = next( 80 | (param["value"] for param in parameters if param["name"] == "country_code"), 81 | "us", 82 | ) 83 | logger.debug( 84 | f"extract_search_params Google: {search_query=} {target_website=} {search_type=} {time_period=} {country_code=}" 85 | ) 86 | return search_query, target_website, search_type, time_period, country_code 87 | 88 | 89 | def tavily_ai_search( 90 | search_query: str, 91 | target_website: Optional[str] = None, 92 | search_depth: str = "advanced", 93 | max_results: int = 3, 94 | topic: str = "general", 95 | days: int = 3, 96 | ) -> Dict[str, Any]: 97 | logger.info(f"executing Tavily AI search with {search_query=}") 98 | try: 99 | return tavily_provider.search( 100 | search_query, target_website, search_depth, max_results, topic, days 101 | ) 102 | except Exception as e: 103 | logger.error( 104 | f"failed to retrieve search results from Tavily AI Search: {str(e)}" 105 | ) 106 | return {"error": str(e)} 107 | 108 | 109 | def lambda_handler(event, _): # type: ignore 110 | logger.debug(f"lambda_handler {event=}") 111 | logger.debug(f"full event {event}") 112 | 113 | action_group = event["actionGroup"] 114 | function = event["function"] 115 | parameters = event.get("parameters", []) 116 | 117 | logger.info(f"lambda_handler: {action_group=} {function=}") 118 | 119 | search_query, target_website, search_depth, 
max_results, topic = ( 120 | extract_search_params(action_group, function, parameters) 121 | ) 122 | 123 | search_results = {} 124 | if function == "tavily-ai-search": 125 | search_results = tavily_ai_search( 126 | search_query, target_website, search_depth, max_results, topic 127 | ) 128 | elif function == "google-search": 129 | search_query, target_website, search_type, time_period, country_code = ( 130 | extract_search_params(action_group, function, parameters) 131 | ) 132 | try: 133 | search_results = google_provider.search( 134 | search_query, target_website, search_type, time_period, country_code 135 | ) 136 | except SearchError as e: 137 | logger.error(f"Google search failed: {str(e)}") 138 | search_results = {"error": str(e)} 139 | 140 | logger.debug(f"query results {search_results=}") 141 | 142 | # Prepare the response 143 | function_response_body = { 144 | "TEXT": { 145 | "body": f"Here are the top search results for the query '{search_query}': {json.dumps(search_results)} " 146 | } 147 | } 148 | 149 | action_response = { 150 | "actionGroup": action_group, 151 | "function": function, 152 | "functionResponse": {"responseBody": function_response_body}, 153 | } 154 | 155 | response = {"response": action_response, "messageVersion": event["messageVersion"]} 156 | 157 | logger.debug(f"lambda_handler: {response=}") 158 | 159 | return response 160 | -------------------------------------------------------------------------------- /v2/lambda_layer/advanced_websearch_libraries/requirements.txt: -------------------------------------------------------------------------------- 1 | langchain>=0.1.0 2 | langchain-community>=0.0.13 3 | pydantic>=2.5.2 4 | tavily-python>=0.2.8 5 | --------------------------------------------------------------------------------