├── requirements.txt ├── network_agent ├── tools │ ├── __init__.py │ ├── network_tools.py │ ├── general_tools.py │ ├── ec2_tools.py │ └── vpc_tools.py ├── _cli_example.py ├── main.py ├── tool_handler.py ├── chat_engine.py └── bedrock_utils.py ├── LICENSE ├── .gitignore ├── README.md └── infra └── infrastructure.yaml /requirements.txt: -------------------------------------------------------------------------------- 1 | altair==5.4.1 2 | attrs==24.2.0 3 | blinker==1.9.0 4 | boto3==1.35.66 5 | botocore==1.35.66 6 | cachetools==5.5.0 7 | certifi==2024.8.30 8 | charset-normalizer==3.4.0 9 | click==8.1.7 10 | gitdb==4.0.11 11 | GitPython==3.1.43 12 | idna==3.10 13 | Jinja2==3.1.4 14 | jmespath==1.0.1 15 | jsonschema==4.23.0 16 | jsonschema-specifications==2024.10.1 17 | markdown-it-py==3.0.0 18 | MarkupSafe==3.0.2 19 | mdurl==0.1.2 20 | narwhals==1.14.1 21 | numpy==2.1.3 22 | packaging==24.2 23 | pandas==2.2.3 24 | pillow==11.0.0 25 | protobuf==5.28.3 26 | pyarrow==18.0.0 27 | pydeck==0.9.1 28 | Pygments==2.18.0 29 | python-dateutil==2.9.0.post0 30 | pytz==2024.2 31 | referencing==0.35.1 32 | requests==2.32.3 33 | rich==13.9.4 34 | rpds-py==0.21.0 35 | s3transfer==0.10.4 36 | six==1.16.0 37 | smmap==5.0.1 38 | streamlit==1.40.1 39 | tenacity==9.0.0 40 | toml==0.10.2 41 | tornado==6.4.1 42 | typing_extensions==4.12.2 43 | tzdata==2024.2 44 | urllib3==2.2.3 45 | -------------------------------------------------------------------------------- /network_agent/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .vpc_tools import vpc_tools, handle_vpc_tool 2 | from .network_tools import network_tools, handle_network_tool 3 | from .ec2_tools import ec2_tools, handle_ec2_tool 4 | from .general_tools import general_tools, handle_general_tool 5 | 6 | 7 | def get_all_tools(): 8 | return vpc_tools + network_tools + ec2_tools + general_tools 9 | 10 | 11 | def handle_tool(tool_use): 12 | tool_name = tool_use['name'] 13 | if 
tool_name in [tool['toolSpec']['name'] for tool in vpc_tools]: 14 | return handle_vpc_tool(tool_use) 15 | elif tool_name in [tool['toolSpec']['name'] for tool in network_tools]: 16 | return handle_network_tool(tool_use) 17 | elif tool_name in [tool['toolSpec']['name'] for tool in ec2_tools]: 18 | return handle_ec2_tool(tool_use) 19 | elif tool_name in [tool['toolSpec']['name'] for tool in general_tools]: 20 | return handle_general_tool(tool_use) 21 | else: 22 | return {"error": f"Unknown tool: {tool_name}"} 23 | -------------------------------------------------------------------------------- /network_agent/_cli_example.py: -------------------------------------------------------------------------------- 1 | from chat_engine import chat, print_conversation 2 | from bedrock_utils import initialize_bedrock_client 3 | from tools import get_all_tools 4 | 5 | 6 | 7 | def main(): 8 | bedrock_client = initialize_bedrock_client() 9 | tools = get_all_tools() 10 | messages = [] 11 | 12 | print("Welcome to the AWS Network Assistant. 
You can ask about VPCs, Internet Gateways, NAT Gateways, Route Tables, and other network components.") 13 | print("Type 'exit', 'quit', or 'bye' to end the conversation.") 14 | 15 | while True: 16 | user_input = input("You: ") 17 | messages.append({"role": "user", "content": [{"text": user_input}]}) 18 | if user_input.lower() in ['exit', 'quit', 'bye']: 19 | break 20 | response = chat(user_input, messages, bedrock_client, tools) 21 | print("Assistant:", response[-1]["content"][0]["text"]) 22 | 23 | print("\nFinal Conversation History:") 24 | print_conversation(messages) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Du'An Lightfoot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /network_agent/main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from chat_engine import chat 3 | from bedrock_utils import initialize_bedrock_client 4 | from tools import get_all_tools 5 | 6 | 7 | def main(): 8 | st.title("AWS Network Assistant") 9 | st.write("Ask about VPCs, Internet Gateways, NAT Gateways, Route Tables, and other network components.") 10 | 11 | # Initialize the bedrock client and tools 12 | if 'bedrock_client' not in st.session_state: 13 | st.session_state.bedrock_client = initialize_bedrock_client() 14 | if 'tools' not in st.session_state: 15 | st.session_state.tools = get_all_tools() 16 | if 'messages' not in st.session_state: 17 | st.session_state.messages = [] 18 | 19 | # Create a chat input 20 | user_input = st.chat_input("Type your question here...") 21 | 22 | # Display chat history 23 | for message in st.session_state.messages: 24 | if message["role"] == "user": 25 | with st.chat_message("user"): 26 | st.write(message["content"][0]["text"]) 27 | else: 28 | with st.chat_message("assistant"): 29 | st.write(message["content"][0]["text"]) 30 | 31 | # Handle new user input 32 | if user_input: 33 | # Add user message to chat history 34 | st.session_state.messages.append({ 35 | "role": "user", 36 | "content": [{"text": user_input}] 37 | }) 38 | 39 | # Display user message 40 | with st.chat_message("user"): 41 | st.write(user_input) 42 | 43 | # Get and display assistant response 44 | with st.chat_message("assistant"): 45 | response = chat( 46 | user_input, 47 | st.session_state.messages, 48 | st.session_state.bedrock_client, 49 | st.session_state.tools 50 | ) 51 | st.write(response[-1]["content"][0]["text"]) 52 | 53 | # Add assistant response to chat history 54 | st.session_state.messages.extend(response) 55 | 56 | if __name__ == "__main__": 57 | main() 58 | 
-------------------------------------------------------------------------------- /network_agent/tool_handler.py: -------------------------------------------------------------------------------- 1 | # tool_handler.py 2 | from tools.vpc_tools import list_vpcs, check_internet_gateway, check_nat_gateway, get_route_tables 3 | from tools.network_tools import list_subnets, describe_network_acls 4 | from tools.ec2_tools import describe_instances, describe_security_groups 5 | from tools.general_tools import get_current_datetime, calculate_cidr_range 6 | 7 | 8 | def handle_tool_use(tool_use): 9 | """ 10 | Handle tool use requests from Claude. 11 | 12 | :param tool_use: Dictionary containing tool use details 13 | :return: Dictionary with the tool result in the format expected by Claude 14 | """ 15 | tool_name = tool_use['name'] 16 | input_data = tool_use['input'] 17 | region = input_data.get('region', 'us-west-2') 18 | 19 | if tool_name == "list_vpcs": 20 | result = list_vpcs(region=region) 21 | elif tool_name == "check_internet_gateway": 22 | result = check_internet_gateway(input_data['vpc_id'], region=region) 23 | elif tool_name == "check_nat_gateway": 24 | result = check_nat_gateway(input_data['vpc_id'], region=region) 25 | elif tool_name == "get_route_tables": 26 | result = get_route_tables(input_data['vpc_id'], region=region) 27 | elif tool_name == "list_subnets": 28 | result = list_subnets(input_data['vpc_id'], region=region) 29 | elif tool_name == "describe_network_acls": 30 | result = describe_network_acls(input_data['vpc_id'], region=region) 31 | elif tool_name == "describe_instances": 32 | result = describe_instances(region=region) 33 | elif tool_name == "describe_security_groups": 34 | result = describe_security_groups(region=region) 35 | elif tool_name == "get_current_datetime": 36 | result = get_current_datetime(input_data.get('timezone', 'UTC')) 37 | elif tool_name == "calculate_cidr_range": 38 | result = calculate_cidr_range(input_data['cidr_block']) 39 | else: 
40 | result = {"error": f"Unknown tool: {tool_name}"} 41 | 42 | return { 43 | "role": "user", 44 | "content": [ 45 | { 46 | "toolResult": { 47 | "toolUseId": tool_use['toolUseId'], 48 | "content": [{"json": result}], 49 | "status": "success" 50 | } 51 | } 52 | ] 53 | } 54 | -------------------------------------------------------------------------------- /network_agent/tools/network_tools.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | 4 | def list_subnets(vpc_id, region="us-west-2"): 5 | ec2 = boto3.client('ec2', region_name=region) 6 | response = ec2.describe_subnets(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]) 7 | subnets = [{'SubnetId': subnet['SubnetId'], 'CidrBlock': subnet['CidrBlock'], 'AvailabilityZone': subnet['AvailabilityZone']} for subnet in response['Subnets']] 8 | return { 9 | 'vpc_id': vpc_id, 10 | 'subnets': subnets, 11 | "region": region 12 | } 13 | 14 | 15 | def describe_network_acls(vpc_id, region="us-west-2"): 16 | ec2 = boto3.client('ec2', region_name=region) 17 | response = ec2.describe_network_acls(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]) 18 | nacls = [{'NetworkAclId': nacl['NetworkAclId'], 'IsDefault': nacl['IsDefault']} for nacl in response['NetworkAcls']] 19 | return { 20 | 'vpc_id': vpc_id, 21 | 'network_acls': nacls, 22 | "region": region 23 | } 24 | 25 | 26 | network_tools = [ 27 | { 28 | "toolSpec": { 29 | "name": "list_subnets", 30 | "description": "List subnets in a specified VPC", 31 | "inputSchema": { 32 | "json": { 33 | "type": "object", 34 | "properties": { 35 | "vpc_id": {"type": "string", "description": "VPC ID"}, 36 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 37 | }, 38 | "required": ["vpc_id"] 39 | } 40 | } 41 | } 42 | }, 43 | { 44 | "toolSpec": { 45 | "name": "describe_network_acls", 46 | "description": "Describe Network ACLs for a specified VPC", 47 | "inputSchema": { 48 | "json": { 49 | "type": "object", 50 | 
"properties": { 51 | "vpc_id": {"type": "string", "description": "VPC ID"}, 52 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 53 | }, 54 | "required": ["vpc_id"] 55 | } 56 | } 57 | } 58 | } 59 | ] 60 | 61 | 62 | def handle_network_tool(tool_use): 63 | tool_name = tool_use['name'] 64 | input_data = tool_use['input'] 65 | 66 | if tool_name == "list_subnets": 67 | result = list_subnets(input_data['vpc_id'], region=input_data.get('region', 'us-west-2')) 68 | elif tool_name == "describe_network_acls": 69 | result = describe_network_acls(input_data['vpc_id'], region=input_data.get('region', 'us-west-2')) 70 | else: 71 | result = {"error": f"Unknown network tool: {tool_name}"} 72 | 73 | return { 74 | "role": "user", 75 | "content": [ 76 | { 77 | "toolResult": { 78 | "toolUseId": tool_use['toolUseId'], 79 | "content": [{"json": result}], 80 | "status": "success" 81 | } 82 | } 83 | ] 84 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Network Whisperer Agent 2 | 3 | A Python-based network analysis agent for AWS infrastructure that provides comprehensive insights into your VPC configurations, ec2 and network resources, and other functions. 
4 | 5 | ## Capabilities 6 | 7 | The agent provides the following network analysis tools: 8 | 9 | ### VPC Tools 10 | - List all VPCs in a specified region 11 | - Check Internet Gateway configurations 12 | - Check NAT Gateway configurations 13 | - Retrieve and analyze Route Tables 14 | 15 | ### Network Tools 16 | - List and analyze subnets within a VPC 17 | - Describe and analyze Network ACLs 18 | 19 | ### EC2 Tools 20 | - List EC2 instances in a specified region 21 | - Analyze Security Groups and their inbound/outbound rules within a VPC 22 | 23 | ### General Tools 24 | - Get current datetime (with timezone support) 25 | - Calculate CIDR range information 26 | 27 | >NOTE: All tools can be found in `./network_agent/tools/`. 28 | 29 | ## Prerequisites 30 | 31 | - Python 3.x 32 | - AWS credentials configured, with access to Amazon Bedrock and to the EC2 describe APIs the tools call 33 | - Required Python packages (listed in `requirements.txt`) 34 | 35 | ## Installation 36 | 37 | 1. Clone the repository: 38 | 39 | ```bash 40 | git clone <repository-url> 41 | cd network_whisperer 42 | ``` 43 | 44 | 2. Set up a virtual environment: 45 | 46 | ```bash 47 | python -m venv venv 48 | source venv/bin/activate # On Windows use: venv\Scripts\activate 49 | ``` 50 | 51 | 3. Install required dependencies: 52 | 53 | ```bash 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | 4. Configure AWS credentials. Make sure credentials are available through either: 58 | 59 | - the AWS CLI configuration (`aws configure`), or 60 | 61 | - environment variables: 62 | 63 | ```bash 64 | export AWS_ACCESS_KEY_ID="your_access_key" 65 | export AWS_SECRET_ACCESS_KEY="your_secret_key" 66 | export AWS_DEFAULT_REGION="us-west-2" 67 | ``` 68 | 69 | ## Running the Application 70 | 71 | ### Using Streamlit 72 | 73 | 1. Change to the `./network_agent` directory (the app imports `chat_engine`, `bedrock_utils`, and `tools` as top-level modules): 74 | 75 | ```bash 76 | cd network_agent/ 77 | ``` 78 | 79 | 2.
Run the application: 80 | 81 | ```bash 82 | streamlit run main.py 83 | ``` 84 | 85 | This will: 86 | 87 | - Start the Streamlit server 88 | - Automatically open your default web browser to the application 89 | - Display the Network Whisperer interface 90 | 91 | >If the browser doesn't open automatically, you can access the application at http://localhost:8501 92 | 93 | ### Usage 94 | 95 | 1. To call a tool directly, import the tool handler from a script run inside `network_agent/` (its modules import `tools` as a top-level package): 96 | 97 | ```python 98 | from tool_handler import handle_tool_use 99 | 100 | # Example usage 101 | tool_request = { 102 | "name": "list_vpcs", 103 | "input": { 104 | "region": "us-west-2" 105 | }, 106 | "toolUseId": "unique_id" 107 | } 108 | 109 | result = handle_tool_use(tool_request) 110 | ``` 111 | 112 | 2. Available tool functions (defined in `network_agent/tools/`): 113 | 114 | ```python 115 | # VPC Operations 116 | list_vpcs(region="us-west-2") 117 | check_internet_gateway(vpc_id, region="us-west-2") 118 | check_nat_gateway(vpc_id, region="us-west-2") 119 | get_route_tables(vpc_id, region="us-west-2") 120 | 121 | # Network Operations 122 | list_subnets(vpc_id, region="us-west-2") 123 | describe_network_acls(vpc_id, region="us-west-2") 124 | 125 | # EC2 Operations 126 | describe_instances(region="us-west-2") 127 | describe_security_groups(vpc_id, region="us-west-2") 128 | 129 | # Utility Operations 130 | get_current_datetime(timezone_str="UTC") 131 | calculate_cidr_range(cidr) 132 | ``` 133 | 134 | **Default Configuration** 135 | - Default region: us-west-2 136 | - Default timezone: UTC 137 | 138 | **Error Handling** 139 | Errors are returned inside the `json` content of the `toolResult`, in the following format: 140 | 141 | ```json 142 | { 143 | "error": "Error message description" 144 | } 145 | ``` 146 | 147 | ### Best Practices 148 | 149 | 1. Always specify the region when making API calls 150 | 2. Use environment variables for sensitive information 151 | 3. Implement proper error handling in your main application 152 | 4.
Monitor API usage to stay within AWS service limits 153 | 154 | ## About me 155 | 156 | My passions lie in building cool stuff and impacting people's lives. I'm fortunate to weave all these elements together in my role as a Developer Advocate. On GitHub, I share my ongoing learning journey and the projects I'm building. Don't hesitate to reach out for a friendly hello or to ask any questions! 157 | 158 | My hangouts: 159 | - [LinkedIn](https://www.linkedin.com/in/duanlightfoot/) 160 | - [YouTube](https://www.youtube.com/@LabEveryday) -------------------------------------------------------------------------------- /infra/infrastructure.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: '2010-09-09' 2 | Description: 'Network Infrastructure for AI Agent Demo' 3 | Parameters: 4 | LatestAmiId: 5 | Description: AMI for the web server (default is the latest Amazon Linux 2) 6 | Type: 'AWS::SSM::Parameter::Value<AWS::EC2::Image::Id>' 7 | Default: '/aws/service/ami-amazon-linux-latest/amzn2-ami-hvm-x86_64-gp2' 8 | 9 | Resources: 10 | # S3 Bucket for flow logs 11 | FlowLogsBucket: 12 | Type: AWS::S3::Bucket 13 | Properties: 14 | VersioningConfiguration: 15 | Status: Enabled 16 | 17 | # VPC 1 in us-east-1 18 | VPC1: 19 | Type: AWS::EC2::VPC 20 | Properties: 21 | CidrBlock: 10.0.0.0/16 22 | EnableDnsHostnames: true 23 | EnableDnsSupport: true 24 | Tags: 25 | - Key: Name 26 | Value: Demo-VPC-1 27 | 28 | VPC1PublicSubnet: 29 | Type: AWS::EC2::Subnet 30 | Properties: 31 | VpcId: !Ref VPC1 32 | CidrBlock: 10.0.1.0/24 33 | AvailabilityZone: !Select [0, !GetAZs ""] 34 | Tags: 35 | - Key: Name 36 | Value: VPC1-Public-1 37 | 38 | # Internet Gateways 39 | VPC1InternetGateway: 40 | Type: AWS::EC2::InternetGateway 41 | Properties: 42 | Tags: 43 | - Key: Name 44 | Value: VPC1-IGW 45 | 46 | VPC1InternetGatewayAttachment: 47 | Type: AWS::EC2::VPCGatewayAttachment 48 | Properties: 49 | VpcId: !Ref VPC1 50 | InternetGatewayId: !Ref VPC1InternetGateway 51 | 52
| # Route Tables 53 | VPC1PublicRouteTable: 54 | Type: AWS::EC2::RouteTable 55 | Properties: 56 | VpcId: !Ref VPC1 57 | Tags: 58 | - Key: Name 59 | Value: VPC1-Public-RT 60 | 61 | VPC1PublicRoute: 62 | Type: AWS::EC2::Route 63 | DependsOn: VPC1InternetGatewayAttachment 64 | Properties: 65 | RouteTableId: !Ref VPC1PublicRouteTable 66 | DestinationCidrBlock: 0.0.0.0/0 67 | GatewayId: !Ref VPC1InternetGateway 68 | 69 | VPC1PublicSubnetRouteTableAssociation: 70 | Type: AWS::EC2::SubnetRouteTableAssociation 71 | Properties: 72 | SubnetId: !Ref VPC1PublicSubnet 73 | RouteTableId: !Ref VPC1PublicRouteTable 74 | 75 | WebServerSG1: 76 | Type: 'AWS::EC2::SecurityGroup' 77 | Properties: 78 | VpcId: !Ref VPC1 79 | GroupDescription: Enable SSH access via ports 22 IPv4 80 | SecurityGroupIngress: 81 | - Description: 'Allow HTTP IPv4 IN' 82 | IpProtocol: tcp 83 | FromPort: 80 84 | ToPort: 80 85 | CidrIp: 0.0.0.0/0 86 | 87 | # Flow log configuration 88 | FlowLog1: 89 | Type: AWS::EC2::FlowLog 90 | Properties: 91 | ResourceId: !Ref VPC1 92 | ResourceType: VPC 93 | TrafficType: ALL 94 | LogDestinationType: s3 95 | LogDestination: !Sub ${FlowLogsBucket.Arn}/vpc-flow-logs/ 96 | MaxAggregationInterval: 60 97 | LogFormat: '${version} ${account-id} ${interface-id} ${srcaddr} ${dstaddr} ${srcport} ${dstport} ${protocol} ${packets} ${bytes} ${start} ${end} ${action} ${log-status} ${vpc-id} ${subnet-id} ${instance-id} ${tcp-flags} ${type} ${pkt-srcaddr} ${pkt-dstaddr} ${region} ${az-id} ${sublocation-type} ${sublocation-id}' 98 | 99 | # CloudWatch Alarms 100 | BandwidthAlarm: 101 | Type: AWS::CloudWatch::Alarm 102 | Properties: 103 | AlarmName: HighBandwidthUsage 104 | MetricName: NetworkIn 105 | Namespace: AWS/EC2 106 | Statistic: Sum 107 | Period: 300 108 | EvaluationPeriods: 2 109 | Threshold: 5000000 110 | ComparisonOperator: GreaterThanThreshold 111 | 112 | 113 | # EC2 Instances 114 | WebServer1: 115 | Type: AWS::EC2::Instance 116 | Properties: 117 | InstanceType: t3.micro 118 | 
ImageId: !Ref LatestAmiId 119 | SubnetId: !Ref VPC1PublicSubnet 120 | SecurityGroupIds: 121 | - !Ref WebServerSG1 122 | UserData: 123 | Fn::Base64: !Sub | 124 | #!/bin/bash 125 | yum update -y 126 | yum install -y httpd 127 | systemctl start httpd 128 | systemctl enable httpd 129 | echo "

Hello from VPC 1

" > /var/www/html/index.html 130 | Tags: 131 | - Key: Name 132 | Value: WebServer-VPC1 133 | 134 | Outputs: 135 | VPC1Id: 136 | Description: VPC 1 ID 137 | Value: !Ref VPC1 138 | 139 | S3BucketName: 140 | Description: S3 Bucket for flow logs 141 | Value: !Ref FlowLogsBucket 142 | 143 | WebServer1PublicIP: 144 | Description: WebServer1 Public IP 145 | Value: !GetAtt WebServer1.PublicIpAddress 146 | -------------------------------------------------------------------------------- /network_agent/tools/general_tools.py: -------------------------------------------------------------------------------- 1 | # tools/general_tools.py 2 | from datetime import datetime 3 | import pytz 4 | from typing import Dict, Any 5 | 6 | 7 | def get_current_datetime(timezone_str: str = "UTC") -> Dict[str, str]: 8 | """ 9 | Get the current date and time in the specified timezone. 10 | 11 | This function returns the current date and time for a given timezone. 12 | If no timezone is specified, it defaults to UTC. 13 | 14 | Args: 15 | timezone_str (str): The timezone to use (e.g., "America/New_York", "Europe/London"). 16 | 17 | Returns: 18 | Dict[str, str]: A dictionary containing the formatted datetime and the timezone used. 19 | If an error occurs, it returns a dictionary with an error message. 20 | 21 | Raises: 22 | pytz.exceptions.UnknownTimeZoneError: If an invalid timezone is provided. 23 | """ 24 | try: 25 | tz = pytz.timezone(timezone_str) 26 | current_time = datetime.now(tz) 27 | return { 28 | "datetime": current_time.strftime("%Y-%m-%d %H:%M:%S %Z"), 29 | "timezone": timezone_str 30 | } 31 | except pytz.exceptions.UnknownTimeZoneError: 32 | return {"error": f"Unknown timezone: {timezone_str}"} 33 | 34 | 35 | def calculate_cidr_range(cidr: str) -> Dict[str, Any]: 36 | """ 37 | Calculate the range of IP addresses for a given CIDR notation. 
38 | 39 | This function takes a CIDR notation and returns details about the IP range, 40 | including the network address, broadcast address, number of addresses, and netmask. 41 | 42 | Args: 43 | cidr (str): The CIDR notation (e.g., "192.168.1.0/24"). 44 | 45 | Returns: 46 | Dict[str, Any]: A dictionary containing CIDR range details or an error message 47 | if the CIDR notation is invalid. 48 | 49 | Raises: 50 | ValueError: If the CIDR notation is invalid. 51 | """ 52 | try: 53 | from ipaddress import ip_network 54 | network = ip_network(cidr) 55 | return { 56 | "network_address": str(network.network_address), 57 | "broadcast_address": str(network.broadcast_address), 58 | "num_addresses": network.num_addresses, 59 | "netmask": str(network.netmask) 60 | } 61 | except ValueError as e: 62 | return {"error": f"Invalid CIDR notation: {str(e)}"} 63 | 64 | 65 | general_tools = [ 66 | { 67 | "toolSpec": { 68 | "name": "get_current_datetime", 69 | "description": "Get the current date and time in a specified timezone", 70 | "inputSchema": { 71 | "json": { 72 | "type": "object", 73 | "properties": { 74 | "utc": {"type": "string", "description": "The timezone to use (e.g., UTC)"} 75 | }, 76 | "required": [] 77 | } 78 | } 79 | } 80 | }, 81 | { 82 | "toolSpec": { 83 | "name": "calculate_cidr_range", 84 | "description": "Calculate the range of IP addresses for a given CIDR notation", 85 | "inputSchema": { 86 | "json": { 87 | "type": "object", 88 | "properties": { 89 | "cidr": {"type": "string", "description": "The CIDR notation (e.g., 192.168.1.0/24)"} 90 | }, 91 | "required": ["cidr"] 92 | } 93 | } 94 | } 95 | } 96 | ] 97 | 98 | 99 | def handle_general_tool(tool_use: Dict[str, Any]) -> Dict[str, Any]: 100 | """ 101 | Handle general tool use requests. 102 | 103 | Args: 104 | tool_use (Dict[str, Any]): A dictionary containing the tool use request details. 105 | 106 | Returns: 107 | Dict[str, Any]: The result of the tool operation or an error message. 
108 | """ 109 | tool_name = tool_use['name'] 110 | input_data = tool_use['input'] 111 | 112 | try: 113 | if tool_name == "get_current_datetime": 114 | result = get_current_datetime(input_data.get('timezone', 'UTC')) 115 | elif tool_name == "calculate_cidr_range": 116 | result = calculate_cidr_range(input_data['cidr']) 117 | else: 118 | return {"error": f"Unknown general tool: {tool_name}"} 119 | 120 | return { 121 | "toolResult": { 122 | "toolUseId": tool_use['toolUseId'], 123 | "content": [{"json": result}], 124 | "status": "success" 125 | } 126 | } 127 | except Exception as e: 128 | return { 129 | "toolResult": { 130 | "toolUseId": tool_use['toolUseId'], 131 | "content": [{"json": {"error": str(e)}}], 132 | "status": "error" 133 | } 134 | } -------------------------------------------------------------------------------- /network_agent/tools/ec2_tools.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | 4 | def describe_instances(region="us-west-2"): 5 | ec2 = boto3.client('ec2', region_name=region) 6 | response = ec2.describe_instances() 7 | instance_ids = [] 8 | for reservation in response["Reservations"]: 9 | for instance in reservation["Instances"]: 10 | instance_ids.append(instance["InstanceId"]) 11 | return { 12 | "InstanceIds": instance_ids, 13 | "region": region 14 | } 15 | 16 | 17 | # Describe security groups in a region 18 | def describe_security_groups(vpc_id, region="us-west-2"): 19 | ec2 = boto3.client('ec2', region_name=region) 20 | response = ec2.describe_security_groups( 21 | Filters=[ 22 | { 23 | 'Name': 'vpc-id', 24 | 'Values': [vpc_id] 25 | } 26 | ] 27 | ) 28 | 29 | security_groups = [] 30 | for sg in response['SecurityGroups']: 31 | inbound_rules = [] 32 | outbound_rules = [] 33 | 34 | # Format inbound rules 35 | for rule in sg['IpPermissions']: 36 | port_range = f"{rule.get('FromPort', 'All')} - {rule.get('ToPort', 'All')}" 37 | for ip_range in rule.get('IpRanges', []): 38 | 
inbound_rules.append({ 39 | 'Protocol': rule.get('IpProtocol', '-1'), 40 | 'Ports': port_range, 41 | 'Source': ip_range.get('CidrIp'), 42 | 'Description': ip_range.get('Description', '') 43 | }) 44 | 45 | # Include security group references 46 | for group in rule.get('UserIdGroupPairs', []): 47 | inbound_rules.append({ 48 | 'Protocol': rule.get('IpProtocol', '-1'), 49 | 'Ports': port_range, 50 | 'Source': f"sg-{group.get('GroupId')}", 51 | 'Description': group.get('Description', '') 52 | }) 53 | 54 | # Format outbound rules 55 | for rule in sg['IpPermissionsEgress']: 56 | port_range = f"{rule.get('FromPort', 'All')} - {rule.get('ToPort', 'All')}" 57 | for ip_range in rule.get('IpRanges', []): 58 | outbound_rules.append({ 59 | 'Protocol': rule.get('IpProtocol', '-1'), 60 | 'Ports': port_range, 61 | 'Destination': ip_range.get('CidrIp'), 62 | 'Description': ip_range.get('Description', '') 63 | }) 64 | 65 | # Include security group references 66 | for group in rule.get('UserIdGroupPairs', []): 67 | outbound_rules.append({ 68 | 'Protocol': rule.get('IpProtocol', '-1'), 69 | 'Ports': port_range, 70 | 'Destination': f"sg-{group.get('GroupId')}", 71 | 'Description': group.get('Description', '') 72 | }) 73 | 74 | security_groups.append({ 75 | 'GroupId': sg['GroupId'], 76 | 'GroupName': sg['GroupName'], 77 | 'Description': sg['Description'], 78 | 'InboundRules': inbound_rules, 79 | 'OutboundRules': outbound_rules 80 | }) 81 | 82 | return security_groups 83 | 84 | 85 | 86 | ec2_tools = [ 87 | { 88 | "toolSpec": { 89 | "name": "describe_instances", 90 | "description": "List ec2 instances in all a region", 91 | "inputSchema": { 92 | "json": { 93 | "type": "object", 94 | "properties": { 95 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 96 | }, 97 | "required": [] 98 | } 99 | } 100 | } 101 | }, 102 | { 103 | "toolSpec": { 104 | "name": "describe_security_groups", 105 | "description": "Describe security groups in a specified AWS region", 106 | 
"inputSchema": { 107 | "json": { 108 | "type": "object", 109 | "properties": { 110 | "vpc_id": {"type": "string", "description": "VPC ID"}, 111 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"}, 112 | }, 113 | "required": ["vpc_id"] 114 | } 115 | } 116 | } 117 | } 118 | ] 119 | 120 | 121 | def handle_ec2_tool(tool_use): 122 | tool_name = tool_use['name'] 123 | input_data = tool_use['input'] 124 | 125 | if tool_name == "describe_instances": 126 | result = describe_instances(region=input_data.get('region', 'us-west-2')) 127 | elif tool_name == "describe_security_groups": 128 | result = describe_security_groups(region=input_data.get('region', 'us-west-2')) 129 | else: 130 | result = {"error": f"Unknown network tool: {tool_name}"} 131 | 132 | return { 133 | "role": "user", 134 | "content": [ 135 | { 136 | "toolResult": { 137 | "toolUseId": tool_use['toolUseId'], 138 | "content": [{"json": result}], 139 | "status": "success" 140 | } 141 | } 142 | ] 143 | } 144 | -------------------------------------------------------------------------------- /network_agent/chat_engine.py: -------------------------------------------------------------------------------- 1 | # chat_engine.py 2 | import json 3 | import logging 4 | from typing import List, Dict, Any 5 | from tool_handler import handle_tool_use 6 | from bedrock_utils import converse_with_claude, create_converse_request 7 | 8 | # Set up logging 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='%(asctime)s | %(levelname)s | %(message)s', 12 | datefmt='%H:%M' 13 | ) 14 | logger = logging.getLogger(__name__) 15 | 16 | def chat(user_input: str, messages: List[Dict[str, Any]], bedrock_client: Any, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 17 | """ 18 | Main chat function to interact with Claude, handling tool use and maintaining conversation flow. 19 | 20 | Args: 21 | user_input (str): The user's input text. 22 | messages (List[Dict[str, Any]]): The conversation history. 
23 | bedrock_client (Any): The Bedrock client for making API calls. 24 | tools (List[Dict[str, Any]]): List of available tools for Claude to use. 25 | 26 | Returns: 27 | List[Dict[str, Any]]: Updated conversation history including Claude's responses and tool uses. 28 | 29 | This function manages the conversation flow, ensuring proper alternation between user and assistant roles, 30 | and handles any tool use requests from Claude. 31 | """ 32 | try: 33 | # Add user input to messages 34 | #messages.append({"role": "user", "content": [{"text": user_input}]}) 35 | 36 | while True: 37 | # Get Claude's response 38 | request = create_converse_request(messages, tools) 39 | response = converse_with_claude(bedrock_client, request) 40 | 41 | # Example Use Claude 3 Sonnet 42 | # request = create_converse_request(messages, tools, model_key="claude_3_sonnet") 43 | # response = converse_with_claude(bedrock_client, request, model_key="claude_3_sonnet") 44 | 45 | 46 | if not response or 'content' not in response: 47 | logger.error("Unexpected response format from Claude.") 48 | raise ValueError("Invalid response from Claude") 49 | 50 | # Process Claude's response 51 | assistant_message = {"role": "assistant", "content": []} 52 | for content in response['content']: 53 | if 'text' in content: 54 | logger.info(f"Claude: {content['text']}") 55 | assistant_message['content'].append({"text": content['text']}) 56 | elif 'toolUse' in content: 57 | tool_use = content['toolUse'] 58 | logger.info(f"Claude is using the {tool_use['name']} tool.") 59 | assistant_message['content'].append({"toolUse": tool_use}) 60 | 61 | # Add Claude's response to messages 62 | messages.append(assistant_message) 63 | 64 | # Check if Claude used a tool 65 | if any('toolUse' in item for item in assistant_message['content']): 66 | # Handle all tool uses 67 | user_message = {"role": "user", "content": []} 68 | for item in assistant_message['content']: 69 | if 'toolUse' in item: 70 | try: 71 | tool_result = 
handle_tool_use(item['toolUse']) 72 | user_message['content'].append(tool_result['content'][0]) 73 | except Exception as e: 74 | logger.error(f"Error in tool use: {str(e)}") 75 | user_message['content'].append({"toolResult": {"status": "error", "message": str(e)}}) 76 | 77 | # Add tool results as a user message 78 | messages.append(user_message) 79 | else: 80 | # If no tool was used, we're done 81 | break 82 | 83 | return messages 84 | except Exception as e: 85 | logger.error(f"An error occurred in the chat function: {str(e)}") 86 | raise 87 | 88 | def print_conversation(messages: List[Dict[str, Any]]) -> None: 89 | """ 90 | Print the entire conversation history in a readable format. 91 | 92 | Args: 93 | messages (List[Dict[str, Any]]): The conversation history to print. 94 | 95 | This function iterates through the conversation history and prints each message, 96 | including text content, tool uses, and tool results, in a formatted manner. 97 | """ 98 | print("\n=== Conversation History ===") 99 | for message in messages: 100 | role = message['role'] 101 | content = message['content'] 102 | 103 | print(f"\n{role.capitalize()}:") 104 | for item in content: 105 | if 'text' in item: 106 | print(f" {item['text']}") 107 | elif 'toolUse' in item: 108 | tool_use = item['toolUse'] 109 | print(item) 110 | print(f" [Tool Use] {tool_use['name']}") 111 | print(f" Input: {json.dumps(tool_use['input'], indent=2)}") 112 | elif 'toolResult' in item: 113 | tool_result = item['toolResult'] 114 | print(f" [Tool Result] Status: {tool_result.get('status', 'N/A')}") 115 | if 'toolUseId' in tool_result: 116 | print(f" ID: {tool_result['toolUseId']}") 117 | if 'content' in tool_result: 118 | print(f" Output: {json.dumps(tool_result['content'][0]['json'], indent=2)}") 119 | if 'message' in tool_result: 120 | print(f" Message: {tool_result['message']}") 121 | print("-" * 50) -------------------------------------------------------------------------------- /network_agent/tools/vpc_tools.py: 
-------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | 4 | def list_vpcs(region="us-west-2"): 5 | ec2 = boto3.client('ec2', region_name=region) 6 | response = ec2.describe_vpcs() 7 | vpcs = [{'VpcId': vpc['VpcId'], 'CidrBlock': vpc['CidrBlock'], 8 | 'IsDefault': vpc['IsDefault']} for vpc in response['Vpcs']] 9 | return { 10 | 'vpcs': vpcs, 11 | "region": region 12 | } 13 | 14 | def check_internet_gateway(vpc_id, region="us-west-2"): 15 | ec2 = boto3.client('ec2', region_name=region) 16 | response = ec2.describe_internet_gateways( 17 | Filters=[ 18 | { 19 | 'Name': 'attachment.vpc-id', 20 | 'Values': [vpc_id] 21 | } 22 | ] 23 | ) 24 | internet_gateways = [ 25 | { 26 | 'InternetGatewayId': ig['InternetGatewayId'], 27 | 'AttachedToVpc': vpc_id in [att['VpcId'] for att in ig['Attachments']] 28 | } for ig in response['InternetGateways'] 29 | ] 30 | return { 31 | 'vpc_id': vpc_id, 32 | 'internetGateways': internet_gateways 33 | } 34 | 35 | def check_nat_gateway(vpc_id, region="us-west-2"): 36 | ec2 = boto3.client('ec2', region_name=region) 37 | response = ec2.describe_nat_gateways( 38 | Filters=[ 39 | { 40 | 'Name': 'vpc-id', 41 | 'Values': [vpc_id] 42 | } 43 | ] 44 | ) 45 | nat_gateways = [ 46 | { 47 | 'NatGatewayId': natgw['NatGatewayId'], 48 | 'SubnetId': natgw['SubnetId'], 49 | 'State': natgw['State'], 50 | 'PublicIp': natgw['NatGatewayAddresses'][0]['PublicIp'] if natgw['NatGatewayAddresses'] else None 51 | } for natgw in response['NatGateways'] 52 | ] 53 | return { 54 | 'vpc_id': vpc_id, 55 | 'NatGateways': nat_gateways 56 | } 57 | 58 | def get_route_tables(vpc_id, region="us-west-2"): 59 | ec2 = boto3.client('ec2', region_name=region) 60 | response = ec2.describe_route_tables( 61 | Filters=[ 62 | { 63 | 'Name': 'vpc-id', 64 | 'Values': [vpc_id] 65 | } 66 | ] 67 | ) 68 | route_tables = [] 69 | for rt in response['RouteTables']: 70 | routes = [] 71 | for route in rt['Routes']: 72 | route_data = { 73 | 
'DestinationCidrBlock': route.get('DestinationCidrBlock'), 74 | 'GatewayId': route.get('GatewayId'), 75 | 'NatGatewayId': route.get('NatGatewayId'), 76 | 'InstanceId': route.get('InstanceId'), 77 | 'VpcPeeringConnectionId': route.get('VpcPeeringConnectionId'), 78 | 'NetworkInterfaceId': route.get('NetworkInterfaceId') 79 | } 80 | routes.append({k: v for k, v in route_data.items() if v is not None}) 81 | 82 | route_tables.append({ 83 | 'RouteTableId': rt['RouteTableId'], 84 | 'IsMain': any(assoc['Main'] for assoc in rt.get('Associations', [])), 85 | 'Routes': routes 86 | }) 87 | 88 | return { 89 | 'vpc_id': vpc_id, 90 | 'routeTables': route_tables 91 | } 92 | 93 | 94 | vpc_tools = [ 95 | { 96 | "toolSpec": { 97 | "name": "list_vpcs", 98 | "description": "List VPCs in a specified AWS region", 99 | "inputSchema": { 100 | "json": { 101 | "type": "object", 102 | "properties": { 103 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 104 | } 105 | } 106 | } 107 | } 108 | }, 109 | { 110 | "toolSpec": { 111 | "name": "check_internet_gateway", 112 | "description": "Check Internet Gateway for a specified VPC", 113 | "inputSchema": { 114 | "json": { 115 | "type": "object", 116 | "properties": { 117 | "vpc_id": {"type": "string", "description": "VPC ID"}, 118 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 119 | }, 120 | "required": ["vpc_id"] 121 | } 122 | } 123 | } 124 | }, 125 | { 126 | "toolSpec": { 127 | "name": "check_nat_gateway", 128 | "description": "Check NAT Gateway for a specified VPC", 129 | "inputSchema": { 130 | "json": { 131 | "type": "object", 132 | "properties": { 133 | "vpc_id": {"type": "string", "description": "VPC ID"}, 134 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 135 | }, 136 | "required": ["vpc_id"] 137 | } 138 | } 139 | } 140 | }, 141 | { 142 | "toolSpec": { 143 | "name": "get_route_tables", 144 | "description": "Get route tables for a specified VPC", 145 | 
"inputSchema": { 146 | "json": { 147 | "type": "object", 148 | "properties": { 149 | "vpc_id": {"type": "string", "description": "VPC ID"}, 150 | "region": {"type": "string", "description": "AWS region (e.g., us-west-2)"} 151 | }, 152 | "required": ["vpc_id"] 153 | } 154 | } 155 | } 156 | } 157 | ] 158 | 159 | 160 | def handle_vpc_tool(tool_use): 161 | tool_name = tool_use['name'] 162 | input_data = tool_use['input'] 163 | 164 | if tool_name == "list_vpcs": 165 | result = list_vpcs(region=input_data.get('region', 'us-west-2')) 166 | elif tool_name == "check_internet_gateway": 167 | result = check_internet_gateway(input_data['vpc_id'], region=input_data.get('region', 'us-west-2')) 168 | elif tool_name == "check_nat_gateway": 169 | result = check_nat_gateway(input_data['vpc_id'], region=input_data.get('region', 'us-west-2')) 170 | elif tool_name == "get_route_tables": 171 | result = get_route_tables(input_data['vpc_id'], region=input_data.get('region', 'us-west-2')) 172 | else: 173 | result = {"error": f"Unknown VPC tool: {tool_name}"} 174 | 175 | return { 176 | "role": "user", 177 | "content": [ 178 | { 179 | "toolResult": { 180 | "toolUseId": tool_use['toolUseId'], 181 | "content": [{"json": result}], 182 | "status": "success" 183 | } 184 | } 185 | ] 186 | } 187 | -------------------------------------------------------------------------------- /network_agent/bedrock_utils.py: -------------------------------------------------------------------------------- 1 | # bedrock_utils.py 2 | import boto3 3 | import json 4 | import logging 5 | from botocore.exceptions import BotoCoreError, ClientError 6 | from typing import Dict, List, Any, Optional 7 | 8 | # Set up logging 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='%(asctime)s | %(levelname)s | %(message)s', 12 | datefmt='%H:%M' 13 | ) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # Define available models 18 | AVAILABLE_MODELS = { 19 | "claude_3_sonnet": 
"anthropic.claude-3-sonnet-20240229-v1:0", 20 | "claude_3_haiku": "anthropic.claude-3-haiku-20240307-v1:0", 21 | "claude_2": "anthropic.claude-v2:1", 22 | # Add more models as needed 23 | } 24 | 25 | # Default model 26 | DEFAULT_MODEL = "claude_3_haiku" 27 | 28 | 29 | SYSTEM_MESSAGE = """ 30 | You are an AWS Network Assistant, designed to help with AWS networking tasks and queries. \ 31 | Your role is to provide accurate information and assist with AWS networking operations using the tools provided. Follow these guidelines strictly: 32 | 33 | 1. Tool Usage: 34 | - Always use the provided tools to fetch real data. Never invent or assume information. 35 | - If a tool requires information that hasn't been provided (e.g., region or VPC ID), ask the user for this information before using the tool. 36 | - Explain which tool you're using and why before you use it. 37 | 38 | 2. Responses: 39 | - Be honest and transparent. If you don't have enough information or if the tools don't provide the necessary data, say so clearly. 40 | - Never make up examples or hypothetical data. Stick to facts from tool outputs or general AWS knowledge. 41 | - If you're unsure about something, express your uncertainty and suggest using a tool to verify. 42 | 43 | 3. AWS Resources: 44 | - Remember that AWS resource IDs (like VPC IDs, subnet IDs, etc.) are specific and should never be invented. 45 | - If asked about specific resources without context, always ask for clarification or suggest using a tool to list the resources first. 46 | 47 | 4. Clarity and Education: 48 | - Explain AWS concepts clearly when relevant to the user's query. 49 | - If a user's request is unclear or could be interpreted in multiple ways, ask for clarification. 50 | 51 | 5. Tool Limitations: 52 | - Be aware of the limitations of your tools. If a user asks for something beyond your tools' capabilities, explain this limitation clearly. 53 | 54 | 6. 
Privacy and Security: 55 | - Do not ask for or store AWS credentials or sensitive information. 56 | - Remind users not to share sensitive information if they attempt to do so. 57 | 58 | Remember, your primary goal is to assist with AWS networking tasks accurately and safely. \ 59 | Always prioritize providing correct information over guessing or making assumptions. 60 | """ 61 | 62 | 63 | def initialize_bedrock_client(region_name: str = "us-west-2") -> boto3.client: 64 | """ 65 | Initialize and return a Bedrock runtime client. 66 | 67 | This function creates a boto3 client for the Bedrock runtime service. 68 | It's used to interact with the Bedrock API for model invocations. 69 | 70 | Args: 71 | region_name (str): The AWS region to connect to. Defaults to "us-east-1". 72 | 73 | Returns: 74 | boto3.client: A configured Bedrock runtime client. 75 | 76 | Raises: 77 | BotoCoreError: If there's an issue creating the Bedrock client. 78 | """ 79 | try: 80 | client = boto3.client("bedrock-runtime", region_name=region_name) 81 | logger.info(f"Bedrock client initialized for region: {region_name}") 82 | return client 83 | except BotoCoreError as e: 84 | logger.error(f"Failed to initialize Bedrock client: {str(e)}") 85 | raise 86 | 87 | 88 | def create_converse_request(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], 89 | max_tokens: int = 500, temperature: float = 0.7, top_p: float = 1, model_key: str = DEFAULT_MODEL 90 | ) -> Dict[str, Any]: 91 | """ 92 | Create a request object for the Bedrock converse API. 93 | """ 94 | try: 95 | model_id = AVAILABLE_MODELS.get(model_key) 96 | if not model_id: 97 | raise ValueError(f"Invalid model key: {model_key}. 
Available models are: {', '.join(AVAILABLE_MODELS.keys())}") 98 | 99 | request = { 100 | "modelId": model_id, 101 | "messages": messages, 102 | "system": [{"text": SYSTEM_MESSAGE}], 103 | "inferenceConfig": { 104 | "maxTokens": max_tokens, 105 | "temperature": temperature, 106 | "topP": top_p 107 | }, 108 | "toolConfig": { 109 | "tools": tools, 110 | "toolChoice": {"auto":{}} 111 | } 112 | } 113 | logger.debug(f"Created converse request for model: {model_id}") 114 | return request 115 | except Exception as e: 116 | logger.error(f"Error creating converse request: {str(e)}") 117 | raise 118 | 119 | 120 | 121 | 122 | def converse_with_claude(bedrock_client: boto3.client, request: Dict[str, Any], model_key: str = DEFAULT_MODEL) -> Optional[Dict[str, Any]]: 123 | """ 124 | Send a request to Claude via the Bedrock converse API. 125 | 126 | This function sends the prepared request to the Bedrock API and handles the response. 127 | 128 | Args: 129 | bedrock_client (boto3.client): The Bedrock runtime client. 130 | request (Dict[str, Any]): The prepared request payload. 131 | model_key (str): Key for the model to use. Defaults to DEFAULT_MODEL. 132 | 133 | Returns: 134 | Optional[Dict[str, Any]]: The model's response message, or None if an error occurred. 135 | 136 | Raises: 137 | ClientError: If there's an API-specific error from Bedrock. 138 | ValueError: If an invalid model key is provided. 139 | Exception: For any other unexpected errors. 140 | """ 141 | try: 142 | model_id = AVAILABLE_MODELS.get(model_key) 143 | if not model_id: 144 | raise ValueError(f"Invalid model key: {model_key}. 
Available models are: {', '.join(AVAILABLE_MODELS.keys())}") 145 | 146 | request["modelId"] = model_id # Ensure the correct model ID is used 147 | response = bedrock_client.converse(**request) 148 | logger.info(f"Successfully received response from Bedrock using model: {DEFAULT_MODEL}") 149 | return response['output']['message'] 150 | except ClientError as e: 151 | error_code = e.response['Error']['Code'] 152 | error_message = e.response['Error']['Message'] 153 | logger.error(f"Bedrock API error: {error_code} - {error_message}") 154 | if error_code == "ValidationException": 155 | logger.warning("Check if the conversation alternates correctly between user and assistant roles") 156 | raise 157 | except Exception as e: 158 | logger.error(f"Unexpected error in converse_with_claude: {str(e)}") 159 | raise --------------------------------------------------------------------------------