├── data └── .gitkeep ├── tests ├── __init__.py ├── test_utils.py ├── test_chat_agent.py ├── test_gis_work_tool.py ├── test_geocode_tool.py └── test_action_summarizer.py ├── geospatial_agent ├── __init__.py ├── cli │ ├── __init__.py │ └── main.py ├── agent │ ├── __init__.py │ ├── geo_chat │ │ ├── __init__.py │ │ ├── tools │ │ │ ├── __init__.py │ │ │ ├── geocode_tool.py │ │ │ └── gis_work_tool.py │ │ └── chat_agent.py │ ├── geospatial │ │ ├── __init__.py │ │ ├── planner │ │ │ ├── __init__.py │ │ │ ├── prompts.py │ │ │ └── planner.py │ │ ├── solver │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── prompts.py │ │ │ ├── op_graph.py │ │ │ └── solver.py │ │ └── agent.py │ ├── action_summarizer │ │ ├── __init__.py │ │ ├── prompts.py │ │ └── action_summarizer.py │ └── shared.py └── shared │ ├── __init__.py │ ├── prompts.py │ ├── bedrock.py │ ├── utils.py │ ├── location.py │ └── shim.py ├── pytest.ini ├── .env ├── CODE_OF_CONDUCT.md ├── Makefile ├── .gitignore ├── Dockerfile ├── LICENSE ├── pyproject.toml ├── CONTRIBUTING.md └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /geospatial_agent/shared/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geo_chat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geo_chat/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/planner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/solver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /geospatial_agent/agent/action_summarizer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --ignore=manual_tests 3 | -------------------------------------------------------------------------------- /.env: -------------------------------------------------------------------------------- 1 | API_KEY_NAME=AgentAPIKey 2 | MAP_NAME=AgentMap 3 | 
PLACE_INDEX_NAME=AgentPlaceIndex 4 | 5 | -------------------------------------------------------------------------------- /geospatial_agent/shared/prompts.py: -------------------------------------------------------------------------------- 1 | GIS_AGENT_ROLE_INTRO = r'You are a geospatial data scientist and an expert python developer.' 2 | HUMAN_STOP_SEQUENCE = '\n\nHuman' 3 | HUMAN_ROLE = HUMAN_STOP_SEQUENCE 4 | ASSISTANT_ROLE = "\n\nAssistant" 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build: install test 2 | 3 | install: 4 | poetry install 5 | 6 | test: 7 | poetry run coverage run -m pytest 8 | 9 | create-session: 10 | mkdir -p geospatial-agent-session-storage/$(SESSION_ID)/data 11 | mkdir -p geospatial-agent-session-storage/$(SESSION_ID)/generated 12 | cp data/$(FILE_NAME) geospatial-agent-session-storage/$(SESSION_ID)/data 13 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/solver/constants.py: -------------------------------------------------------------------------------- 1 | NODE_TYPE_ATTRIBUTE = "node_type" 2 | NODE_NAME_ATTRIBUTE = "node_name" 3 | NODE_OUTPUT_ATTRIBUTE = "output" 4 | NODE_INPUT_ATTRIBUTE = "input" 5 | NODE_DESCRIPTION_ATTRIBUTE = "description" 6 | NODE_DATA_PATH_ATTRIBUTE = "data_path" 7 | NODE_TYPE_OPERATION = "operation" 8 | 
NODE_TYPE_DATA = "data" 9 | NODE_TYPE_OPERATION_TYPE = "operation_type" 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build 12 | 13 | # Environments 14 | .venv 15 | env/ 16 | venv/ 17 | ENV/ 18 | env.bak/ 19 | venv.bak/ 20 | 21 | # PyCharm 22 | .idea/ 23 | generated/ 24 | .DS_Store 25 | geospatial-agent-session-storage 26 | .run 27 | 28 | .coverage 29 | 30 | # rtx configurations 31 | .rtx.toml 32 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/docker/library/python:3.11.5-slim 2 | 3 | RUN apt-get update && apt-get install make 4 | RUN apt-get install -y python3-setuptools groff 5 | 6 | RUN pip3 install "poetry" 7 | RUN pip3 --no-cache-dir install --upgrade awscli 8 | 9 | ENV WORK_DIR="/var/task" 10 | WORKDIR ${WORK_DIR} 11 | 12 | COPY poetry.lock pyproject.toml README.md .env Makefile ${WORK_DIR}/ 13 | 14 | COPY geospatial_agent ${WORK_DIR}/geospatial_agent 15 | COPY tests ${WORK_DIR}/tests 16 | 17 | RUN poetry install 18 | ADD data ${WORK_DIR}/data 19 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from assertpy import assert_that 2 | 3 | from geospatial_agent.shared.utils import extract_code, ExtractionException 4 | 5 | 6 | def test_extract_code_works_with_python_markdown_blocks(): 7 | # Get code of method_for_code using inspect.getsource 8 | method_code = """ 9 | def method_for_code(): 10 | pass 11 | """ 12 | python_block_code = f'```python\n{method_code}\n```' 13 | 
def get_claude_v2(max_tokens_to_sample=8100, temperature=0.001):
    """Build a LangChain ``Bedrock`` LLM bound to ``anthropic.claude-v2``.

    Args:
        max_tokens_to_sample: Forwarded to the model as ``max_tokens_to_sample``.
        temperature: Forwarded to the model as ``temperature``.

    Returns:
        A ``Bedrock`` LLM instance that uses the client from
        ``get_bedrock_client()``.
    """
    client = get_bedrock_client()
    llm = Bedrock(
        model_id="anthropic.claude-v2",
        client=client,
        model_kwargs={
            "max_tokens_to_sample": max_tokens_to_sample,
            "temperature": temperature,
        },
    )
    return llm


def get_bedrock_client(region_name: str = "us-east-1") -> BaseClient:
    """Create a ``bedrock-runtime`` client with adaptive retries.

    Args:
        region_name: AWS region to create the client in. Defaults to
            ``"us-east-1"``, the value that was previously hard-coded, so
            existing callers are unaffected.

    Returns:
        A botocore client for the ``bedrock-runtime`` service.
    """
    # AWS_PROFILE is optional; when it is unset, ``profile_name`` is None.
    profile = os.environ.get("AWS_PROFILE", None)
    session = boto3.Session(profile_name=profile)
    # Up to 10 attempts, using botocore's "adaptive" retry mode.
    cfg = Config(retries={'max_attempts': 10, 'mode': 'adaptive'})
    client: BaseClient = session.client("bedrock-runtime", region_name=region_name, config=cfg)
    return client
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "geospatial-agent" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Amazon Location Team"] 6 | readme = "README.md" 7 | packages = [{ include = "geospatial_agent" }] 8 | 9 | [tool.poetry.dependencies] 10 | python = ">= 3.10.13, <= 3.11.5" 11 | langchain = ">= 0.1.0" 12 | langchain-community = ">= 0.2.9" 13 | python-json-logger = "^2.0.7" 14 | click = "^8.1.6" 15 | networkx = "^3.1" 16 | pydot = "^1.4.2" 17 | matplotlib = "^3.7.2" 18 | pydispatcher = "^2.0.7" 19 | contextily = "^1.3.0" 20 | mapclassify = "^2.6.0" 21 | pydeck = { extras = ["jupyter"], version = "^0.8.0" } 22 | idna = "^3.4" 23 | anyio = "^3.7.1" 24 | importlib-metadata = "^6.8.0" 25 | sniffio = "^1.3.0" 26 | typing-extensions = "^4.7.1" 27 | zipp = "^3.16.2" 28 | geopandas = "^0.13.2" 29 | pygments = "^2.16.1" 30 | boto3 = "^1.28.63" 31 | botocore = "^1.31.63" 32 | python-dotenv 
class ExtractionException(Exception):
    """Raised when expected content cannot be extracted from an LLM response."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


def extract_content_xml(tag: str, response: str) -> str:
    """Return the stripped text between the first ``<tag>``/``</tag>`` pair.

    Args:
        tag: Tag name to look for. It is matched literally.
        response: Text to search; the content may span multiple lines.

    Returns:
        The content of the first matching element, with surrounding
        whitespace removed.

    Raises:
        ExtractionException: If no ``<tag>...</tag>`` pair is present.
    """
    # Escape the tag so regex metacharacters in it (e.g. ".") match literally,
    # and use a raw f-string: "/" needs no escaping and "\/" is an invalid
    # escape sequence in a normal string literal.
    escaped_tag = re.escape(tag)
    pattern = rf"<{escaped_tag}>(.*?)</{escaped_tag}>"
    match = re.search(pattern, response, re.DOTALL)
    if match:
        return match.group(1).strip()
    raise ExtractionException(f"Failed to extract {tag} from response")


def extract_code(response):
    """Extract python code from LLM response.

    Returns the stripped content of the first fenced block, which may be
    opened with either ```` ```python ```` or a bare ```` ``` ````.

    Raises:
        ExtractionException: If the response contains no fenced block.
    """
    python_code_match = re.search(r"```(?:python)?(.*?)```", response, re.DOTALL)
    if python_code_match:
        return python_code_match.group(1).strip()
    raise ExtractionException("Failed to extract python code from response")
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from assertpy import assert_that 4 | from langchain.tools import Tool 5 | 6 | from geospatial_agent.agent.geo_chat.chat_agent import GeoChatAgent 7 | from geospatial_agent.agent.geo_chat.tools.geocode_tool import GEOCODE_TOOL 8 | from geospatial_agent.agent.geo_chat.tools.gis_work_tool import GIS_WORK_TOOL 9 | from geospatial_agent.shared.location import ENV_PLACE_INDEX_NAME 10 | 11 | test_place_index = 'test_place_index' 12 | 13 | 14 | def test_initializing_geo_chat_agent_does_not_raise_exception(): 15 | os.environ[ENV_PLACE_INDEX_NAME] = test_place_index 16 | 17 | geo_chat_agent = GeoChatAgent() 18 | assert_that(geo_chat_agent).is_not_none() 19 | 20 | 21 | def test_invoking_geo_chat_agent_does_not_raise_exception(mocker): 22 | mocker.patch(f'{GeoChatAgent.__module__}.geocode_tool', 23 | return_value=Tool.from_function(func=lambda q: "geocoded response", name=GEOCODE_TOOL, 24 | description="test description")) 25 | mocker.patch(f'{GeoChatAgent.__module__}.gis_work_tool', 26 | return_value=Tool.from_function(func=lambda q: "gis work complete", name=GIS_WORK_TOOL, 27 | description="test description")) 28 | 29 | mocker.patch(f'{GeoChatAgent.__module__}.AgentExecutor.run', return_value="The agent has finished running!") 30 | 31 | geo_chat_agent = GeoChatAgent() 32 | output = geo_chat_agent.invoke( 33 | agent_input="test input", session_id="test_session_id", storage_mode="test_storage_mode") 34 | 35 | assert_that(output).is_equal_to("The agent has finished running!") 36 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geo_chat/tools/geocode_tool.py: -------------------------------------------------------------------------------- 1 | from langchain.tools import tool 2 | from pydispatch import dispatcher 3 | 4 | from geospatial_agent.shared.location import get_location_client, get_place_index_name 5 | 6 | 
def geocode_tool(location_client=None, place_index_name: str = ""):
    """Create the geocoding tool backed by a place index.

    Args:
        location_client: Client exposing ``search_place_index_for_text``.
            When falsy, ``get_location_client()`` supplies one.
        place_index_name: Place index to query. When empty,
            ``get_place_index_name()`` supplies it.

    Returns:
        The ``@tool``-decorated function registered under ``GEOCODE_TOOL``.
    """
    location_client = location_client or get_location_client()
    place_index_name = place_index_name or get_place_index_name()

    @tool(GEOCODE_TOOL)
    def geocode_tool_func(query: str) -> str:
        """\
        A tool that geocodes a given address using the AWS Location service.
        The input is a string that could be an address, area, neighborhood, city, or country.
        The output is list of places that match the input. Each place comes with a label and a
        pair of coordinates following the format [longitude, latitude] that represents the physical location of the input.
        """

        try:
            search_result = location_client.search_place_index_for_text(
                IndexName=place_index_name,
                MaxResults=10,
                Text=query
            )

            # One "label: point" line per returned place, each newline-terminated.
            return "".join(
                f"{entry['Place']['Label']}: {entry['Place']['Geometry']['Point']}" + "\n"
                for entry in search_result['Results']
            )
        except Exception as e:
            # Best effort: publish the failure as a signal and hand the agent a
            # plain observation instead of raising.
            dispatcher.send(signal=GEOCODE_TOOL_FAILED, sender=GEOCODE_TOOL, event_data=e)
            return "Observation: The tool did not find any results."

    return geocode_tool_func
def get_api_key() -> str:
    """Return the API key value for the key named by ``API_KEY_NAME``.

    Reads the key name from the ``API_KEY_NAME`` environment variable and
    looks it up with ``describe_key`` on the location client.

    Raises:
        LocationConfigurationError: If ``API_KEY_NAME`` is unset or empty, or
            if creating the client or calling ``describe_key`` fails (the
            original error is chained as the cause).
    """
    key_name = os.environ.get("API_KEY_NAME")
    if not key_name:
        raise LocationConfigurationError("API_KEY_NAME environment variable is not set")

    try:
        key_description = get_location_client().describe_key(KeyName=key_name)
    except Exception as err:
        raise LocationConfigurationError("Error getting API Key") from err

    # Read outside the try block, so a response without "Key" raises KeyError
    # rather than LocationConfigurationError.
    return key_description["Key"]
def gis_work_tool(session_id: str, storage_mode: str, action_summarizer=None, gis_agent=None):
    """Create the LangChain ``Tool`` that delegates GIS work to the geospatial agent.

    Args:
        session_id: Session identifier passed to both the summarizer and the agent.
        storage_mode: Storage mode passed to the summarizer, and to the default
            ``GeospatialAgent`` when one has to be created.
        action_summarizer: Object exposing ``invoke(user_input=..., session_id=...,
            storage_mode=...)``. Defaults to a new ``ActionSummarizer``.
        gis_agent: Object exposing ``invoke(action_summary=..., session_id=...)``.
            Defaults to a new ``GeospatialAgent``.

    Returns:
        A ``Tool`` named ``GIS_WORK_TOOL`` wrapping a single-argument function.
    """
    # The wrapped function takes only ``user_input``; ``session_id`` is bound by
    # closure, so the description must not advertise it as a tool input.
    desc = f"""\
    A tool that invokes a {GeospatialAgent.__name__} if the user action requires geospatial analysis to be done on user provided data.
    {GeospatialAgent.__name__} description: {GeospatialAgent.__doc__}

    It accepts a single input: user_input.

    An example query might look like the following:
    Draw time series choropleth map of weather temperature change over major cities of the world.

    Data Locations:
    1. Climate Change: Earth Surface Temperature Data location since 1940s data location: GlobalLandTemperaturesByCity.csv

    A qualified action for the tool has the following requirements:
    1. A geospatial analysis action such as heatmap, choropleth, or time series.
    2. A data location such as a scheme://URI or just a file name such as data.csv.

    DO NOT invoke this tool unless both of these requirements are met.

    This tool will invoke the GIS agent to perform the geospatial analysis on the data.
    The return is freeform string or a URL to the result of the analysis."""

    if action_summarizer is None:
        action_summarizer = ActionSummarizer()

    if gis_agent is None:
        gis_agent = GeospatialAgent(storage_mode=storage_mode)

    def gis_work_tool_func(user_input: str):
        """Summarize the requested action, run the GIS agent, and report its outputs."""
        action_summary = action_summarizer.invoke(
            user_input=user_input, session_id=session_id, storage_mode=storage_mode)
        output = gis_agent.invoke(action_summary=action_summary, session_id=session_id)

        # Each fragment ends with a space so adjacent sentences do not run
        # together ("...if applicable.Generated code path" previously).
        return (f"Observation: GIS Agent has completed it's work. I should list the generated code file path, and "
                f"generated visualization file path from the code output, if applicable. "
                f"Generated code path = {output.assembled_code_file_path}. "
                f"Generated code output = {output.assembled_code_output}.")

    return Tool.from_function(func=gis_work_tool_func, name=GIS_WORK_TOOL, description=desc)
def execute_assembled_code(assembled_code):
    """Execute ``assembled_code`` and capture everything it prints.

    The code runs with this module's ``globals()`` as both its global and
    local namespace, so names it defines persist in this module afterwards.

    Args:
        assembled_code: Python source (or a code object) to execute.

    Returns:
        A ``(output, namespace)`` tuple: the text written to ``sys.stdout``
        during execution, and this module's globals dict.

    Raises:
        Exception: Whatever the executed code raises propagates unchanged;
            ``sys.stdout`` is restored first and captured output is discarded.
    """
    # NOTE(security): this runs arbitrary code with full interpreter
    # privileges. Only pass code from a trusted source or run in a sandbox.
    captured_output = StringIO()
    previous_stdout = sys.stdout
    sys.stdout = captured_output
    try:
        exec(assembled_code, globals(), globals())
    finally:
        # Always restore the real stdout, including when the code raises.
        sys.stdout = previous_stdout

    return captured_output.getvalue(), globals()
Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. 
Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/planner/prompts.py: -------------------------------------------------------------------------------- 1 | _graph_generation_instructions = """ 2 | Generate a graph, the data structure only, whose nodes are 3 | 1. A series of consecutive steps and 4 | 2. 
# Requirements listed in the planning prompt, one string per requirement.
# Every element needs its own trailing comma: adjacent string literals without
# one are concatenated by Python into a single item.
_graph_requirement_list = [
    "Create a single NetworkX graph instance. Graph in NetworkX will represent steps and data.",
    "No disconnected components are allowed.",
    "There are two types of nodes: operation node and data node.",
    "A data node is always followed by an operation node. An operation node is always followed by a data node.",
    "Operation node accepts data nodes as parameters and writes data nodes as outputs to next operation",
    "Input of an Operation node is the data node output of previous operations, except for data loading or collection.",
    "First operations are data loading or collection, and the last operation output is the final answer.",
    "Use geopandas for spatial data if the goal is to make a map or visualization.",
    "Succinctly name all nodes.",
    "Produce a concise graph with minimum amount of steps.",
    "Nodes should have these attributes: node_type (data or operation), data_path (only for data node), operation_type (only for operation node), and description.",
    "operation_type is a single word tag to categorize the operations. For example, visualization, map, plot, load, transform, and spatial_join.",
    "Do not generate code to implement the steps.",
    "Only use the provided data. Use external, only from Github if needed.",
    "Only use columns or attributes noted in Data locations section. Do NOT assume attributes or columns.",
    "Put your reply into a Python code block enclosed by ```python and ```.",
]
test_place_index = 'test_place_index'
test_query = 'test query'


def test_initializing_geocode_tool_does_not_throw_error():
    """geocode_tool() can be built when the place index env variable is set."""
    os.environ[ENV_PLACE_INDEX_NAME] = test_place_index

    tool = geocode_tool()
    assert_that(tool).is_not_none()


def test_initializing_geocode_tool_throws_error_if_place_index_env_not_set():
    """geocode_tool() raises when the place index env variable is empty."""
    os.environ[ENV_PLACE_INDEX_NAME] = ""

    assert_that(geocode_tool).raises(Exception).when_called_with()


def test_invoking_geocode_tool_throws_no_error_if_results_returned():
    """The tool renders a stubbed search result as '<label>: <point>'."""
    os.environ[ENV_PLACE_INDEX_NAME] = test_place_index

    # Stubbed boto3 client for Amazon Location.
    location = boto3.client('location')
    search_place_index_for_text_response = {
        'Results': [
            {
                'Distance': 123.0,
                'Place': {
                    'Label': 'test-address-label',
                    'Geometry': {
                        'Point': [-37.71133, 144.86304]
                    },
                },
                'PlaceId': 'test-id',
                'Relevance': 123.0
            },
        ],
        'Summary': {
            'Text': 'test-summary',
            'DataSource': 'test-datasource',
        }
    }
    expected_params = {'IndexName': test_place_index, 'MaxResults': 10, 'Text': test_query}

    # The context manager activates the stubber and always deactivates it,
    # even when the tool call raises.
    with Stubber(location) as stubber:
        stubber.add_response('search_place_index_for_text', search_place_index_for_text_response, expected_params)

        tool = geocode_tool(location_client=location, place_index_name=test_place_index)
        response = tool(test_query).strip()

    place = search_place_index_for_text_response['Results'][0]
    response_string = f"{place['Place']['Label']}: {place['Place']['Geometry']['Point']}"

    assert_that(response).is_equal_to(response_string)


def test_invoking_geocode_tool_returns_no_results_observation_if_location_client_errors():
    """The tool answers with the 'no results' observation when the client errors."""
    os.environ[ENV_PLACE_INDEX_NAME] = test_place_index

    # Stubbed boto3 client for Amazon Location.
    location = boto3.client('location')
    expected_params = {'IndexName': test_place_index, 'MaxResults': 10, 'Text': test_query}

    with Stubber(location) as stubber:
        # expected_params was previously built but never handed to the stubber,
        # so the request parameters of the failing call were not checked.
        stubber.add_client_error(
            'search_place_index_for_text',
            service_error_code='TestServiceErrorCode',
            service_message='Test error message',
            http_status_code=500,
            expected_params=expected_params,
        )

        tool = geocode_tool(location_client=location, place_index_name=test_place_index)
        response = tool(test_query).strip()

    assert_that(response).is_equal_to("Observation: The tool did not find any results.")
class PlannerException(Exception):
    """Raised when the planner cannot produce graph plan code."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


def gen_task_name(llm: LLM, task: str) -> str:
    """Ask the LLM for a unix folder name for the task and prefix it with the current unix time."""
    prompt = PromptTemplate.from_template(_task_name_generation_prompt).format(
        human_role="Human",
        assistant_role="Assistant",
        task_definition=task,
    )
    suggested_name = llm.predict(text=prompt, stop=[HUMAN_STOP_SEQUENCE]).strip()
    # The timestamp is taken after the LLM call and truncated to whole seconds.
    return f'{int(time.time())}_{suggested_name}'


def gen_plan_graph(llm: LLM, task_definition: str, data_locations_instructions: str) -> str:
    """Return the plan graph, as python code, for a task definition.

    Raises:
        PlannerException: if anything fails while generating the plan; the
            original exception is attached as the cause.
    """
    try:
        return _gen_plan_graph_code(llm, task_definition, data_locations_instructions)
    except Exception as cause:
        raise PlannerException("Failed to generate graph plan code for task") from cause


def _gen_plan_graph_code(llm: LLM, task_definition: str, data_locations_instructions: str):
    """Prompt the LLM for a plan graph and return the code extracted from its reply."""
    requirements_text = _get_graph_requirements()
    planning_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate.from_template(_planning_graph_task_prompt_template),
    )
    llm_reply = planning_chain.run(
        human_role="Human",
        planner_role_intro=GIS_AGENT_ROLE_INTRO,
        graph_generation_instructions=_graph_generation_instructions,
        task_definition=task_definition.strip("\n").strip(),
        graph_requirements=requirements_text,
        graph_reply_example=_graph_reply_example,
        data_locations_instructions=data_locations_instructions,
        assistant_role="Assistant",
        stop=[HUMAN_STOP_SEQUENCE],
    )
    # Only the code extracted from the reply is returned to the caller.
    return extract_code(llm_reply)


def _get_graph_requirements() -> str:
    """Return the planning graph requirements as a 1-based numbered list, one per line."""
    numbered = []
    for position, requirement in enumerate(_graph_requirement_list, start=1):
        numbered.append(f"{position}. {requirement}")
    return '\n'.join(numbered)
Do not add any extra text.", 9 | "If no file paths are found file_paths will be an empty string list.", 10 | "If the file path is a HTTP(S) link, use the full link as output.", 11 | "If the file path is not a URI, add agent:// to the beginning of the filepath.", 12 | "If there are multiple file paths, add all file paths in the output. Follow the rules above for each filepath.", 13 | "File paths are case sensitive. It can have spaces, hyphens, underscores, and periods." 14 | ] 15 | 16 | _ACTION_SUMMARY_PROMPT = """\ 17 | {role_intro} 18 | {human_role}: A message is provided below. 19 | Your task is to extract the intended user action and all file paths from the message. Meet the requirements written below: 20 | 21 | Requirements: 22 | {requirements} 23 | 24 | 25 | Message: {message} 26 | 27 | {assistant_role}: 28 | """ 29 | 30 | # Read File # 31 | 32 | DATA_FRAMES_VARIABLE_NAME = "dataframes" 33 | 34 | _READ_FILE_REQUIREMENTS = [ 35 | "Read each file using geopandas. Each file could be csv, shapefile, or GeoJSON. Otherwise, throw a ValueError.", 36 | "Return a list of python dictionaries with keys: file_url, resolved_file_url, data_frame, column_names.", 37 | "Use built-in function resolved_file_url = get_data_file_url(file_url, session_id) to get downloadable URLs. Do not add import statement for this function.", 38 | "Take 3 random rows with no missing values to each data_frame.", 39 | f"After writing the function, call the function in the end and store the list of data_frame in a global variable named {DATA_FRAMES_VARIABLE_NAME}.", 40 | "Do not use any try except block.", 41 | "Put your reply into a Python code block(enclosed by ```python and ```) without any extra surrounding text.", 42 | "Use pandas, geopandas, numpy, and builtins to solve the problem. Do not use any external data sources or libraries." 43 | ] 44 | 45 | _READ_FILE_PROMPT = """\ 46 | {role_intro} 47 | {human_role}: You are provided a set of file URLs. 
You need to generate a Python function that meets the following requirements: 48 | 49 | Requirements: 50 | {requirements} 51 | 52 | Session Id: {session_id} 53 | Storage Mode: {storage_mode} 54 | 55 | File Urls: 56 | {file_urls} 57 | 58 | As 59 | {assistant_role}: 60 | """ 61 | 62 | # Generate Data Summary # 63 | 64 | _DATA_SUMMARY_REQUIREMENTS = [ 65 | "The summary should be at maximum two sentences.", 66 | "The first sentence should be summary of the data in the table from the aspect of the user action.", 67 | "If there is no geometry column in the table, the second sentence should note column names that can be used to generate a geometry column in geopandas.", 68 | "Write summary without any extra surrounding text." 69 | ] 70 | 71 | _DATA_SUMMARY_PROMPT = """\ 72 | {role_intro} 73 | {human_role}: You are provided with a table with some rows data. Your task is to generate a summary that describes the data in the table following the requirements below: 74 | 75 | Requirements: 76 | {requirements} 77 | 78 | Intended user action: {action} 79 | 80 | The table has following columns: 81 | {columns} 82 | 83 | Table: 84 | {table} 85 | 86 | 87 | {assistant_role}: 88 | """ 89 | -------------------------------------------------------------------------------- /geospatial_agent/cli/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import langchain 5 | from pydispatch import dispatcher 6 | from pygments import highlight 7 | from pygments.formatters.terminal import TerminalFormatter 8 | from pygments.lexers.python import PythonLexer 9 | 10 | from geospatial_agent.agent.geo_chat.chat_agent import GeoChatAgent 11 | from geospatial_agent.agent.shared import ALL_SIGNALS, AgentSignal, EventType 12 | from geospatial_agent.shared.shim import LOCAL_STORAGE_MODE, LocalStorage 13 | 14 | import uuid 15 | 16 | from geospatial_agent.shared.utils import get_exception_messages 17 | 18 | from dotenv import 
load_dotenv 19 | 20 | langchain.verbose = False 21 | 22 | load_dotenv() 23 | 24 | 25 | def set_langchain_verbose(is_verbose: bool): 26 | langchain.verbose = is_verbose 27 | 28 | 29 | def get_chatbot_response(user_input: str, session_id: str, verbose: bool, storage_mode: str): 30 | try: 31 | set_langchain_verbose(verbose) 32 | 33 | geo_chat_agent = GeoChatAgent() 34 | response = geo_chat_agent.invoke(agent_input=user_input, storage_mode=storage_mode, session_id=session_id) 35 | return response 36 | 37 | except Exception as e: 38 | click.echo(click.style(f"An unhandled error occurred during chatbot conversation: {str(e)}", fg="red")) 39 | 40 | 41 | @click.command() 42 | @click.option('--verbose', is_flag=True, help='Enable verbose mode') 43 | @click.option('--session-id', help='Session id of the conversation') 44 | @click.option('--profile', help='AWS Profile to use') 45 | def chat(session_id: str, verbose: bool, profile: str): 46 | """ 47 | Simple conversational chatbot running in the terminal. 48 | Type your message, press Enter, and the chatbot will respond. 49 | Type 'exit' to quit the chatbot. 50 | """ 51 | 52 | click.echo("Agent: Hi! I'm Agent Smith! Your conversational geospatial agent. 
How can I help you today?") 53 | 54 | if session_id == "" or session_id is None: 55 | session_id = uuid.uuid4().hex 56 | 57 | click.echo(f"Agent: Creating a new session {session_id}") 58 | storage = LocalStorage() 59 | storage.create_session_storage(session_id=session_id) 60 | 61 | # If profile is not none or empty, set _PROFILE to profile 62 | if profile is not None and profile != "": 63 | os.environ["AWS_PROFILE"] = profile 64 | click.echo(f"Agent: Using AWS profile: {profile}") 65 | 66 | def print_signal(sender, event_data): 67 | # Check if event_data is instance of Exception 68 | if isinstance(event_data, Exception): 69 | exception_message = get_exception_messages(event_data) 70 | click.echo( 71 | click.style( 72 | f"\nAn error occurred during chatbot conversation: {str(exception_message)}\n===============\n", 73 | fg="red")) 74 | 75 | elif isinstance(event_data, AgentSignal): 76 | click.echo(click.style(f"\n{sender}: \n{event_data.event_message}", fg="cyan")) 77 | if event_data.event_type == EventType.PythonCode: 78 | highlighted_code = highlight(event_data.event_data, PythonLexer(), TerminalFormatter()) 79 | print(highlighted_code) 80 | click.echo(click.style("\n===============\n", fg="cyan")) 81 | 82 | for signal in ALL_SIGNALS: 83 | dispatcher.connect(receiver=print_signal, signal=signal) 84 | 85 | while True: 86 | user_input = click.prompt("You", type=str) 87 | if user_input.lower() == 'exit': 88 | click.echo("Agent: Goodbye! 
Have a great day!") 89 | break 90 | 91 | response = get_chatbot_response( 92 | user_input=user_input, session_id=session_id, verbose=verbose, storage_mode=LOCAL_STORAGE_MODE) 93 | if response: 94 | click.echo(f"Agent: {response}") 95 | 96 | 97 | def main(): 98 | chat() 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /geospatial_agent/shared/shim.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | 4 | from geospatial_agent.shared.location import get_map_style_uri 5 | 6 | 7 | def get_shim_imports() -> str: 8 | shim_map_style_import = f'from {location_map_style.__module__} import {location_map_style.__name__} \n' \ 9 | f'from {get_data_file_url.__module__} import {get_data_file_url.__name__}\n' \ 10 | f'from {get_local_file_path.__module__} import {get_local_file_path.__name__}\n' 11 | return shim_map_style_import 12 | 13 | 14 | def location_map_style(): 15 | return get_map_style_uri() 16 | 17 | 18 | LOCAL_STORAGE_MODE = 'local' 19 | 20 | 21 | def get_data_file_url(file_path: str, session_id: str) -> str: 22 | if not file_path.startswith("agent://"): 23 | return file_path 24 | 25 | storage = LocalStorage() 26 | return storage.get_data_file_url(file_path=file_path, session_id=session_id) 27 | 28 | 29 | def get_local_file_path(file_path: str, session_id: str, task_name: str = "") -> str: 30 | storage = LocalStorage() 31 | file_url = storage.get_generated_file_url(file_path=file_path, session_id=session_id, task_name=task_name) 32 | print(f"Resolved local file file_url = {file_url}") 33 | return file_url 34 | 35 | 36 | class Storage(ABC): 37 | @abstractmethod 38 | def create_session_storage(self, session_id: str): 39 | pass 40 | 41 | @abstractmethod 42 | def get_data_file_url(self, file_path_or_url: str, session_id: str) -> str: 43 | pass 44 | 45 | @abstractmethod 46 | def 
class LocalStorage(Storage):
    """Storage implementation backed by the local file system.

    Session content lives under
    ``<current working directory>/geospatial-agent-session-storage/<session_id>``,
    in a ``data`` folder and a ``generated`` folder.
    """

    def get_data_file_url(self, file_path: str, session_id: str) -> str:
        """Resolve a data file reference to a local path.

        ``agent://<name>`` is mapped to ``<root>/<session_id>/data/<name>``
        (existence is not checked). Any other value is returned unchanged.

        Raises:
            FileNotFoundError: if a non-``agent://`` path does not exist.
        """
        if file_path.startswith("agent://"):
            filename = file_path.removeprefix("agent://")
            return os.path.join(self._get_root_folder(), session_id, "data", filename)

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} not found")
        return file_path

    def get_generated_file_url(self, file_path: str, session_id: str, task_name: str = "") -> str:
        """Return ``<root>/<session_id>/generated/<task_name>/<file_path>``.

        A leading ``agent://`` prefix is stripped from ``file_path`` first.
        """
        # removeprefix() is a no-op when the prefix is absent.
        file_path = file_path.removeprefix("agent://")

        # NOTE(review): the previous `if not os.path.abspath(file_path): raise ValueError(...)`
        # guard was removed because os.path.abspath() always returns a non-empty
        # string, so it could never fire. Behaviour is unchanged: an absolute
        # file_path still makes os.path.join() discard the session folder.
        # Confirm whether absolute paths should be rejected here.
        return os.path.join(self._get_root_folder(), session_id, "generated", task_name, file_path)

    def write_file(self, file_name: str, session_id: str, task_name="", content="") -> str:
        """Write ``content`` to the generated-file location and return its path.

        Raises:
            ValueError: if ``content`` is an empty string.
        """
        if content == "":
            raise ValueError("To write a local file, content parameter must be provided")

        file_path = self.get_generated_file_url(
            file_path=file_name, session_id=session_id, task_name=task_name
        )

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        with open(file_path, 'w') as file:
            file.write(content)

        return file_path

    def create_session_storage(self, session_id: str):
        """Create the ``data`` and ``generated`` folders of a session if missing."""
        root = self._get_root_folder()
        for folder_name in ("data", "generated"):
            os.makedirs(os.path.join(root, session_id, folder_name), exist_ok=True)

    @staticmethod
    def _get_root_folder():
        """Return the session storage root under the current working directory."""
        return os.path.join(os.getcwd(), "geospatial-agent-session-storage")
# Format instructions for the zero-shot agent. Doubled braces keep
# {tool_names} as a literal placeholder inside this f-string.
# The tool example previously labelled 49.28696 as the longitude and -123.1432
# as the latitude; the labels are now in the correct order (latitude 49.28696,
# longitude -123.1432), and the "use find a tool" typo is fixed.
_FORMAT_INSTRUCTIONS = f"""Use the following format:
{HUMAN_ROLE}:
Question: The input question or query you MUST answer.
Thought: You MUST always think about what to do.
Action: You SHOULD select an action to take. Actions can be one of [{{tool_names}}]. If no tool is selected, keep conversing with the user.
Action Input: The input to the action.
Observation: The output or result of the action. (this Thought/Action/Action Input/Observation can repeat N times).
Thought: I now know the final answer.
Final Answer: the final answer to the original input question.

Example of using a tool:
{HUMAN_ROLE}:
Question: I want to know the latitude and longitude of English Bay, Vancouver.
{ASSISTANT_ROLE}:
Thought: I should find a tool from [{{tool_names}}]. The selected tool is geocode.
Action: Geocode.
Action Input: English Bay, Vancouver.
Observation: The latitude is 49.28696 and longitude is -123.1432.
Thought: I now know the latitude and longitude of English Bay, Vancouver.
Final Answer: The latitude and longitude of English Bay, Vancouver are 49.28696 and -123.1432.

Example of not finding a tool:
{HUMAN_ROLE}:
Question: Hello, how are you?
{ASSISTANT_ROLE}:
Thought: I should greet the user in response.
Action: I do not have any tool to use here. Keep conversing with the user.
Action Input: Hello, how are you?
Observation: Hello, I am doing fine! If you have a geospatial query or action, I can help you with that.
Thought: I now have the final response for the user.
Final Answer: Hello, I am doing fine! If you have a geospatial query or action, I can help you with that.
"""
4 | The agent utilizes plan-and-solve prompting with Anthropic's Claude 2 model from Amazon Bedrock. 5 | 6 | ## Getting Started 7 | First, we need to install all dependencies. This agent uses `poetry` to manage dependencies. 8 | `poetry` is a dependency management and packaging tool for Python. 9 | 10 | First, we need to install poetry. If you are on OSX: 11 | ```bash 12 | brew install poetry 13 | poetry --version 14 | ``` 15 | 16 | If you are on any other OS, [please follow the official poetry documentation for installation](https://python-poetry.org/docs/). 17 | 18 | Once poetry is installed, we will install all the dependencies using poetry. To install dependencies: 19 | ```bash 20 | poetry install 21 | ``` 22 | 23 | Next, let's give the unit tests a run: 24 | ```bash 25 | poetry run pytest 26 | ``` 27 | 28 | You can also run unit tests with coverage: 29 | 30 | ```bash 31 | poetry run coverage run -m pytest 32 | ``` 33 | 34 | If the tests pass, we are ready to go! 35 | 36 | Just in case, there is a Makefile to make our life better. You can also run ```make``` for all of these: 37 | ```bash 38 | make build 39 | make install 40 | make test 41 | ``` 42 | 43 | ## Testing the agent 44 | ### Setting up the infra for Amazon Location 45 | Before testing the agent, we will need to create the following in the backing AWS account: 46 | 1. An Amazon Location Place Index Resource created from Amazon Location Console. 47 | 2. An Amazon Location Map Resource created from Amazon Location Console. 48 | 3. An Amazon Location API Key created from Amazon Location Console. This API Key should have access permissions to 1 and 2. 49 | 50 | After they are created, open the `.env` file and set the following env variables: 51 | ```env 52 | API_KEY_NAME= 53 | MAP_NAME= 54 | PLACE_INDEX_NAME= 55 | ``` 56 | 57 | By default, a set of placeholders is used.
58 | 59 | ### Setting up the infra for AWS Bedrock 60 | Finally, the AWS account we are going to use with the agent MUST have Claude V2 foundational model 61 | access from Bedrock. At the time of writing, this requires going to the Bedrock console and explicitly 62 | clicking a button to ask for access to Claude 2. The request is automatically accepted in a couple of 63 | moments. 64 | 65 | Now we have all the resources we need! 66 | 67 | ### Using the right credential 68 | The agent runs locally on your machine. Use local AWS credentials that have access to the Amazon Bedrock InvokeModel API. 69 | Additionally, they should have access to the Amazon Location SearchPlaceIndexForText API. 70 | 71 | 72 | ### Downloading the data 73 | In this sample, we will generate a heatmap from the Airbnb database. Download 74 | [the Airbnb 2023 Open Dataset for New York from here](http://data.insideairbnb.com/united-states/ny/new-york-city/2023-10-01/visualisations/listings.csv). 75 | Store the file inside the `data` folder. 76 | 77 | ```bash 78 | wget http://data.insideairbnb.com/united-states/ny/new-york-city/2023-10-01/visualisations/listings.csv 79 | cp listings.csv data/listings.csv 80 | ``` 81 | 82 | Then run the following to create a session. We are using a guid named `3c18d48c-9c9b-488f-8229-e2e8016fa851` 83 | as an example session id. This will create a session with `listings.csv` stored inside the `data` folder. 84 | 85 | ```bash 86 | SESSION_ID="3c18d48c-9c9b-488f-8229-e2e8016fa851" FILE_NAME="listings.csv" make create-session 87 | ``` 88 | 89 | ### Starting the agent 90 | The agent can run inside or outside a docker container. We recommend running the agent inside a docker container. 91 | This way, the generated code cannot create any unexpected side effects. 92 | 93 | We can build the docker image for the agent by: 94 | 95 | ```bash 96 | docker build -t agent .
97 | ``` 98 | 99 | Then, we can shell into docker by: 100 | 101 | ```bash 102 | docker run --rm -it -v ~/.aws:/root/.aws --entrypoint bash agent 103 | ``` 104 | 105 | Then, start the agent inside docker: 106 | ```bash 107 | poetry run agent --session-id 3c18d48c-9c9b-488f-8229-e2e8016fa851 108 | ``` 109 | 110 | If you want to use an AWS profile, there is a `--profile` flag available: 111 | ```bash 112 | poetry run agent --session-id 3c18d48c-9c9b-488f-8229-e2e8016fa851 --profile some-aws-profile 113 | ``` 114 | 115 | The agent will write all generated content under the `geospatial-agent-session-storage` folder. 116 | 117 | Now, when prompted by the agent, we can use the following input to generate the heatmap. 118 | ``` 119 | I've uploaded the file listings.csv. Draw a heatmap of Airbnb listing price. 120 | ``` 121 | 122 | And then, let the agent do its thing! 123 | 124 | ### Limitations 125 | This agent can work on datasets other than the one used for training, but its performance is not guaranteed to be perfect 126 | on all datasets and tasks. 127 | 128 | The agent can generate high-level plans to solve problems. However, to translate those plans into executable code, 129 | it relies on the Claude 2 model, which is good at writing Python code using built-ins but has limitations. 130 | Claude 2 scored 74% on the HumanEval Python test. 131 | 132 | To reliably write code dealing with geospatial data, the agent would need additional training focused on libraries 133 | like geopandas, matplotlib, and pydeck. In particular, knowledge of geopandas is crucial for spatial joins. 134 | 135 | While the agent can plan spatial joins, it often fails to write functioning geopandas code to join columns from two 136 | dataframes. Common issues include data type mismatches and coordinate system incompatibilities.
We tried addressing 137 | these with prompting, but without further tuning or a model specialized for geopandas, success rates across diverse 138 | datasets will be limited. 139 | 140 | This agent relies on generative AI, which has inherent limitations. Retrying once or twice may yield better results 141 | if an interaction is unsatisfactory. -------------------------------------------------------------------------------- /tests/test_action_summarizer.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import pytest 3 | from assertpy import assert_that 4 | from langchain.llms import FakeListLLM 5 | from pydantic import ValidationError 6 | 7 | from geospatial_agent.agent.action_summarizer.action_summarizer \ 8 | import ActionContext, ActionSummarizer, FileSummary 9 | from geospatial_agent.agent.action_summarizer.prompts import DATA_FRAMES_VARIABLE_NAME 10 | 11 | 12 | def test_initializing_action_summarizer_does_not_raise_exception(): 13 | action_summarizer = ActionSummarizer() 14 | assert_that(action_summarizer).is_not_none() 15 | 16 | 17 | def test_action_context_extraction_from_llm_response_does_not_raise_exception(): 18 | user_input = "Build me a heatmap. I have uploaded data.csv" 19 | expected_action_context = ActionContext(action="Build me a heatmap", file_paths=["agent://data.csv"]) 20 | responses = [expected_action_context.json()] 21 | 22 | # Creating a Fake LLM for mocking responses. 23 | fake_llm = FakeListLLM(responses=responses) 24 | action_summarizer = ActionSummarizer(llm=fake_llm) 25 | 26 | context = action_summarizer._extract_action_context(user_input=user_input) 27 | assert_that(context).is_equal_to(expected_action_context) 28 | 29 | 30 | def test_action_context_extraction_from_llm_response_raise_exception_if_action_is_missing(): 31 | user_input = "Build me a heatmap. 
I have uploaded data.csv" 32 | expected_action_context = ActionContext.construct( 33 | ActionContext.__fields_set__, action=None, file_paths=["agent://data.csv"]) 34 | responses = [expected_action_context.json()] 35 | 36 | # Creating a Fake LLM for mocking responses. 37 | fake_llm = FakeListLLM(responses=responses) 38 | action_summarizer = ActionSummarizer(llm=fake_llm) 39 | 40 | with pytest.raises(ValidationError) as exec_info: 41 | action_summarizer._extract_action_context(user_input=user_input) 42 | 43 | assert_that(exec_info.value.raw_errors).is_length(1) 44 | assert_that(exec_info.value.raw_errors[0]._loc).is_equal_to('action') 45 | assert_that(exec_info.value.__str__()).contains('none is not an allowed value') 46 | 47 | 48 | def test_action_context_extraction_from_llm_response_raise_exception_if_file_paths_is_missing(): 49 | user_input = "Build me a heatmap. I have uploaded data.csv" 50 | expected_action_context = ActionContext.construct( 51 | ActionContext.__fields_set__, action="Build me a heatmap", file_paths=None) 52 | responses = [expected_action_context.json()] 53 | 54 | # Creating a Fake LLM for mocking responses. 
def test_generating_file_summary_for_action_does_not_raise_exception():
    """The LLM reply is stored verbatim as ``file_summary`` on the single input summary."""
    canned_reply = "test file summary"
    summarizer = ActionSummarizer(llm=FakeListLLM(responses=[canned_reply]))

    frame = pandas.DataFrame(data=[[1, 2, 3], [4, 5, 6]])
    pending = [
        FileSummary(file_url="agent://data.csv", data_frame=frame, column_names=["a", "b", "c"]),
    ]

    summarized = summarizer._gen_file_summaries_for_action(
        action="Build me a heatmap", file_summaries=pending)

    assert_that(summarized).is_length(1)
    assert_that(summarized[0].file_summary).is_equal_to(canned_reply)
def test_invoking_action_summarizer_does_not_raise_exception():
    """Run the whole ``invoke`` pipeline against canned LLM replies and verify its result.

    The fake LLM answers, in order: the action context JSON, the file-reading
    code, and the per-file summary text. The original test discarded the
    returned summary; it is now checked so a silently wrong result fails.
    """
    user_input = "Build me a heatmap. I have uploaded data.csv"
    session_id = "session_id"

    code = f"""
```python
import pandas
from geospatial_agent.agent.action_summarizer.action_summarizer import FileSummary

def test_code():
    file_summaries = [FileSummary(
        file_url="agent://data.csv",
        data_frame=pandas.DataFrame(data=[[1, 2, 3], [4, 5, 6]]),
        column_names=["a", "b", "c"]
    ).dict()]
    return file_summaries

{DATA_FRAMES_VARIABLE_NAME} = test_code()
```"""
    test_file_summary = "test file summary"

    expected_action_context = ActionContext(action="Build me a heatmap", file_paths=["agent://data.csv"])
    responses = [expected_action_context.json(), code, test_file_summary]
    fake_llm = FakeListLLM(responses=responses)

    action_summarizer = ActionSummarizer(llm=fake_llm)
    action_summary = action_summarizer.invoke(user_input=user_input, session_id=session_id,
                                              storage_mode='test_storage_mode')

    # Pin the observable outcome of the pipeline, not merely the absence of an exception.
    assert_that(action_summary.action).is_equal_to(expected_action_context.action)
    assert_that(action_summary.file_summaries).is_length(1)
    assert_that(action_summary.file_summaries[0].file_url).is_equal_to("agent://data.csv")
    assert_that(action_summary.file_summaries[0].column_names).is_equal_to(["a", "b", "c"])
    assert_that(action_summary.file_summaries[0].file_summary).is_equal_to(test_file_summary)
# Pre-written requirements offered to the LLM when it selects requirements for an
# operation. The requirement-generation prompt caps the selection at 20 entries and
# tells the model never to re-phrase them, so each entry must read correctly as-is.
predefined_operation_requirements = [
    "Do not change the given variable names and paths from the graph.",
    "Write code with necessary import statements of necessary libraries. Do not add imports for built-in functions. Follow Pep8 styling.",
    'Write code into a Python code block(enclosed by ```python and ```).',
    "For visualization that requires a map, always use pydeck library unless otherwise specified.",
    "For visualization, if pydeck is used, save the pydeck deck to a HTML file.",
    "For visualization that does not require a map, use other plotting libraries such as matplotlib, and pyplot.",
    "When using GeoPandas to load a zipped shapefile from a URL, use gpd.read_file(URL). Do not download and unzip the file.",
    "When doing spatial analysis, if necessary, and if there are multiple layers ONLY, convert all involved spatial layers into the same map projection.",
    "Conduct map projection conversion only for spatial data layers defined with geopandas GeoDataFrame. Do not do map projection with pandas DataFrame",
    "While joining DataFrame and GeoDataFrame, use common columns. Do not convert DataFrame to GeoDataFrame.",
    "When joining tables, convert the involved columns to string type without leading zeros. Convert integers to floats if necessary.",
    "Show units for graphs or maps.",
    "When doing spatial joins, retain at least 1 geometry column",
    "While using GeoPandas for spatial joining, remember the arguments. The arguments are: geopandas.sjoin(left_df, right_df, how='inner', predicate='intersects', lsuffix='left', rsuffix='right', **kwargs)",
    # The quoted snippet must be valid Python; the original had the colon inside the quotes.
    "Do not write a main function with \"if __name__ == '__main__':\"",
    "Only a single python function is to be written in this task. Do not write tests or a main function.",
    "Use the built-in functions or attribute. Do not make up fake built-in functions.",
    "Do not use any library or package that is not necessary for the task.",
    "Do not use try except block without re-throwing the exception.",
    "Point function requires importing shapely library."
]
# Example reply shown to the LLM in the code-generation prompt. It must itself be a
# well-formed fenced block whose code is internally consistent, because the model
# imitates it: the parameter is resolved to a downloadable URL, and that resolved
# URL (not an unrelated name) is what gets read.
operation_reply_example = """
```python
import pandas as pd

def Load_csv(csv_url="agent://data.csv"):
    # Description: Load a CSV file from a given URL
    # csv_url: CSV file URL
    # Get downloadable url from csv_url
    file_url = get_data_file_url(csv_url, session_id)
    tract_population_df = pd.read_csv(file_url)
    return tract_population_df
```
"""
{operation_requirements} 124 | 125 | Data locations: {data_locations_instructions} 126 | Session Id: {session_id} 127 | Task Name: {task_name} 128 | Storage Mode: {storage_mode} 129 | 130 | Your reply example: {operation_reply_example} 131 | 132 | Pydeck usage example: 133 | {operation_pydeck_example} 134 | 135 | This function is a operation node in a solution graph for the question/task, the Python code to build the graph is: 136 | {graph_code} 137 | 138 | The ancestor function code is below. Follow the generated file names and attribute names: 139 | {ancestor_operation_code} 140 | 141 | The descendant function (if any) definitions for the question are (node_name is function name): 142 | {descendant_operations_definition} 143 | 144 | 145 | {assistant_role}: 146 | """ 147 | -------------------------------------------------------------------------------- /geospatial_agent/agent/geospatial/solver/op_graph.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import networkx 4 | 5 | from geospatial_agent.agent.geospatial.solver.constants import NODE_TYPE_ATTRIBUTE, NODE_TYPE_OPERATION, \ 6 | NODE_DESCRIPTION_ATTRIBUTE, NODE_TYPE_DATA, NODE_DATA_PATH_ATTRIBUTE, NODE_TYPE_OPERATION_TYPE 7 | 8 | 9 | class OperationNode: 10 | def __init__(self, 11 | function_definition: str, 12 | return_line: str, 13 | description: str, 14 | node_name: str, 15 | param_names: set, 16 | return_names: set, 17 | operation_type: str = "", 18 | code_gen_response: str = "", 19 | operation_code: str = "", 20 | reviewed_code: str = "", 21 | operation_prompt: str = ""): 22 | self.function_definition = function_definition 23 | self.return_line = return_line 24 | self.description = description 25 | self.operation_type = operation_type 26 | self.node_name = node_name 27 | self.param_names = param_names 28 | self.return_names = return_names 29 | self.code_gen_response = code_gen_response 30 | self.operation_code = operation_code 31 
class OperationsParser:
    """Extract operation nodes and their wiring from a solution graph.

    On construction the parser collects the names of all operation nodes,
    builds an ``OperationNode`` for each of them, and records which nodes are
    graph inputs (no predecessors) and graph outputs (no successors).

    Raises:
        OperationsParserException: if an input or output node is not a data node.
        KeyError: if a node lacks the node-type or description attribute.
    """

    def __init__(self, graph: networkx.DiGraph):
        self.graph = graph

        self.op_node_names = self._get_operation_node_names()
        self.operation_nodes = self._get_operation_nodes(self.op_node_names)
        self.output_node_names = self._get_output_node_names()
        self.input_node_names = self._get_input_node_names()

    def get_ancestors(self, node_name) -> Sequence[OperationNode]:
        """Return the operation nodes that are ancestors of ``node_name``.

        The result keeps the order of ``self.operation_nodes``.
        """
        # networkx.ancestors returns a set, so the membership test below is O(1).
        ancestor_names = networkx.ancestors(self.graph, node_name)
        return [op_node for op_node in self.operation_nodes if op_node.node_name in ancestor_names]

    def get_descendants(self, node_name) -> Sequence[OperationNode]:
        """Return the operation nodes that are descendants of ``node_name``.

        The result keeps the order of ``self.operation_nodes``.
        """
        descendant_names = networkx.descendants(self.graph, node_name)
        return [op_node for op_node in self.operation_nodes if op_node.node_name in descendant_names]

    def stringify_nodes(self, nodes: Sequence[OperationNode]) -> str:
        """Returns all operation nodes attributes stringified as a new line delimited string"""
        return '\n'.join(str(op_node.__dict__) for op_node in nodes)

    def _get_operation_nodes(self, op_node_names) -> Sequence[OperationNode]:
        """Build an ``OperationNode`` for every name in ``op_node_names``.

        Raises:
            OperationsParserException: if a named node is not an operation node.
        """
        op_nodes = []
        for op in op_node_names:
            node_dict = self.graph.nodes[op]

            node_type = node_dict[NODE_TYPE_ATTRIBUTE]
            if node_type != NODE_TYPE_OPERATION:
                raise OperationsParserException(f"Node {op} is not an operation node")

            function_def, param_names = self._get_func_def_str(op)

            # The successors of an operation are the values it returns.
            successors = list(self.graph.successors(op))
            return_str = 'return ' + ', '.join(successors)

            op_nodes.append(OperationNode(
                function_definition=function_def,
                return_line=return_str,
                description=node_dict[NODE_DESCRIPTION_ATTRIBUTE],
                operation_type=node_dict.get(NODE_TYPE_OPERATION_TYPE, ""),
                node_name=op,
                param_names=param_names,
                return_names=set(successors)
            ))
        return op_nodes

    def _get_operation_node_names(self):
        """Return the names of all nodes whose node type is the operation type."""
        return [
            node_name
            for node_name in self.graph.nodes()
            if self.graph.nodes[node_name][NODE_TYPE_ATTRIBUTE] == NODE_TYPE_OPERATION
        ]

    def _get_output_node_names(self):
        """Returns output nodes from the graph. Output nodes are the nodes without successors.

        Raises:
            OperationsParserException: if a node without successors is not a data node.
        """
        output_nodes = []
        for node_name in self.graph.nodes():
            if self.graph.out_degree(node_name) != 0:
                continue
            if self.graph.nodes[node_name][NODE_TYPE_ATTRIBUTE] != NODE_TYPE_DATA:
                raise OperationsParserException(f"Node {node_name} is not an {NODE_TYPE_DATA} node")
            output_nodes.append(node_name)
        return output_nodes

    def _get_input_node_names(self):
        """Returns input nodes from the graph. Input nodes are the nodes without predecessors.

        Raises:
            OperationsParserException: if a node without predecessors is not a data node.
        """
        input_nodes = []
        for node_name in self.graph.nodes():
            if self.graph.in_degree(node_name) != 0:
                continue
            if self.graph.nodes[node_name][NODE_TYPE_ATTRIBUTE] != NODE_TYPE_DATA:
                raise OperationsParserException(f"Node {node_name} is not an {NODE_TYPE_DATA} node")
            input_nodes.append(node_name)
        return input_nodes

    def _get_func_def_str(self, node):
        """
        Returns function definition string with function name, parameters and default values of parameters.

        Returns:
            A ``(func_def, param_names)`` tuple, where ``func_def`` looks like
            ``name(a, b='path')`` and ``param_names`` is the set of predecessor names.
        """

        # INFO: To generate a function definition from the solution graph, we need to find the parameters of the
        # function, and the return value. We start with looking for the predecessors of the node.
        # Because the parameters are the predecessors.
        plain_params = []
        default_params = []
        param_names = set()

        for data_node in self.graph.predecessors(node):
            param_node = self.graph.nodes[data_node]

            # INFO: The parameter node may have a data_path attribute specifying the location of its data,
            # like a URL or filepath, which should be used if present; otherwise the parameter is emitted
            # without a default value.
            data_path = param_node.get(NODE_DATA_PATH_ATTRIBUTE, '')
            param_names.add(data_node)

            if data_path != "":
                # repr() escapes quotes and backslashes, so the default stays a valid Python
                # string literal (e.g. for Windows paths); simple paths render exactly as before.
                default_params.append(f"{data_node}={str(data_path)!r}")
            else:
                plain_params.append(str(data_node))

        # Parameters without defaults must come first for the signature to be valid Python.
        # Joining avoids a trailing ", " and any clean-up that could touch the paths themselves.
        all_parameters_str = ', '.join(plain_params + default_params)
        func_def = f'{node}({all_parameters_str})'

        return func_def, param_names
class GeospatialAgent:
    """A geospatial data scientist and a python developer agent written by Amazon Location Service."""

    _assembled_code_file_name = "assembled_code.py"

    def __init__(self, storage_mode: str):
        self.llm = get_claude_v2()
        self.local_storage = LocalStorage()
        self.storage_mode = storage_mode

    def invoke(self, action_summary: ActionSummary, session_id: str) -> GISAgentResponse:
        """Plan, generate, assemble and execute code for the summarized action.

        Steps: generate a task name, generate and execute the plan-graph code,
        persist the graph, let the ``Solver`` generate and assemble operation
        code, write the assembled code to local storage and execute it.
        Progress is announced through ``dispatcher`` signals.

        Raises:
            GISAgentException: if any step fails; the original error is chained as the cause.
        """
        try:
            # INFO: Generating a task name from the action summary action
            task_name = gen_task_name(self.llm, action_summary.action)
            dispatcher.send(signal=SIGNAL_TASK_NAME_GENERATED,
                            sender=SENDER_GEOSPATIAL_AGENT,
                            event_data=AgentSignal(
                                event_source=SENDER_GEOSPATIAL_AGENT,
                                event_message=f"I will use task name {task_name} to gather all generated artifacts.",
                            ))

            data_locations_instructions = self._get_data_locations_instructions(action_summary)

            # INFO: Generating the graph plan to write code
            graph_plan_code = gen_plan_graph(self.llm,
                                             task_definition=action_summary.action,
                                             data_locations_instructions=data_locations_instructions)
            dispatcher.send(
                signal=SIGNAL_GRAPH_CODE_GENERATED,
                sender=SENDER_GEOSPATIAL_AGENT,
                event_data=AgentSignal(
                    event_source=SENDER_GEOSPATIAL_AGENT,
                    event_message='Generated plan graph code.',
                    event_type=EventType.PythonCode,
                    event_data=graph_plan_code
                ))

            # INFO: Executing the graph plan code and get the graph object and the repl output
            graph, repl_output = self._execute_plan_graph_code(graph_plan_code)
            # Called for its side effect of persisting the graph; the returned path is not needed here.
            self._write_local_graph_file(graph, session_id=session_id, task_name=task_name)

            solver = Solver(
                llm=self.llm,
                graph=graph,
                graph_code=graph_plan_code,
                session_id=session_id,
                storage_mode=self.storage_mode,
                task_definition=action_summary.action,
                task_name=task_name,
                data_locations_instructions=data_locations_instructions)

            op_defs = solver.solve()
            assembled_code = solver.assemble()

            dispatcher.send(signal=SIGNAL_ASSEMBLED_CODE_EXECUTING,
                            sender=SENDER_GEOSPATIAL_AGENT,
                            event_data=AgentSignal(
                                event_source=SENDER_GEOSPATIAL_AGENT,
                                event_message="Saving and executing assembled code",
                            ))

            code_file_abs_path = self._write_local_code_file(assembled_code=assembled_code, session_id=session_id,
                                                             task_name=task_name)

            code_output, _ = execute_assembled_code(assembled_code)
            if code_output is not None:
                dispatcher.send(signal=SIGNAL_ASSEMBLED_CODE_EXECUTED,
                                sender=SENDER_GEOSPATIAL_AGENT,
                                event_data=AgentSignal(
                                    event_source=SENDER_GEOSPATIAL_AGENT,
                                    event_message=code_output
                                ))

            return GISAgentResponse(
                graph_plan_code=graph_plan_code,
                graph=graph,
                repl_output=repl_output,
                op_defs=op_defs,
                assembled_code=assembled_code,
                assembled_code_output=code_output,
                assembled_code_file_path=code_file_abs_path,
            )
        except GISAgentException:
            # Already carries a specific message; do not bury it under the generic one below.
            raise
        except Exception as e:
            raise GISAgentException(message="Error occurred while executing the graph plan code") from e

    @staticmethod
    def _get_data_locations_instructions(action_summary):
        """Describe every file of ``action_summary`` as location, column names and summary lines.

        Each file contributes three newline-terminated lines; the per-file
        sections are concatenated in order. No files yields an empty string.
        """
        return "".join(
            f"File Location: {file_summary.file_url}\n"
            f"Column Names: {file_summary.column_names}\n"
            f"Summary: {file_summary.file_summary}\n"
            for file_summary in action_summary.file_summaries
        )

    def _write_local_graph_file(self, graph, session_id: str, task_name: str) -> str:
        """Write ``graph`` as GraphML under local storage and return the file's absolute path."""
        graph_file_path = self.local_storage.get_generated_file_url(
            file_path="plan_graph.graphml", session_id=session_id, task_name=task_name)

        parent_dir = os.path.dirname(graph_file_path)
        # exist_ok avoids the check-then-create race; an empty parent means the
        # current directory, which os.makedirs cannot be asked to create.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        networkx.write_graphml(graph, graph_file_path, named_key_ids=False)
        return os.path.abspath(graph_file_path)

    def _write_local_code_file(self, session_id: str, assembled_code: str, task_name: str):
        """Persist the assembled code via local storage and return what ``write_file`` returns."""
        return self.local_storage.write_file(
            file_name=self._assembled_code_file_name,
            session_id=session_id,
            task_name=task_name,
            content=assembled_code
        )

    @staticmethod
    def _execute_plan_graph_code(graph_plan_code) -> tuple[networkx.DiGraph, str]:
        """Returns the plan graph object by executing the graph plan code.

        The executed code is expected to bind the graph to a global named ``G``.

        Raises:
            GISAgentException: if the executed code did not define ``G``.
        """
        output, _globals = execute_assembled_code(graph_plan_code)
        try:
            graph: networkx.DiGraph = _globals['G']
        except KeyError as e:
            raise GISAgentException(
                message="Generated plan graph code did not define the graph variable 'G'") from e
        return graph, output
class ActionSummarizer:
    """Action summarizer acts on raw user messages with the following traits
    1. It is a geospatial query or analysis such as "Draw me a heatmap".
    2. Has URLS of data to be used for the analysis.

    ActionSummarizer generates an ActionSummary: the desired action plus one
    FileSummary per referenced data file.
    """

    # Upper bound on the characters of serialized table data embedded into the summary prompt.
    _MAX_TABLE_CHARS = 4000

    def __init__(self, llm=None):
        # Fall back to the default model only when the caller did not inject one (tests inject a fake).
        self.llm = get_claude_v2() if llm is None else llm

    @staticmethod
    def _numbered(items) -> str:
        """Render ``items`` as a newline separated ``1. ...`` style list."""
        return "\n".join(f"{index + 1}. {item}" for index, item in enumerate(items))

    def invoke(self, user_input: str, session_id: str, storage_mode: str) -> ActionSummary:
        """Turn a raw user message into an ``ActionSummary``.

        Extracts the action and file paths, generates and executes code that
        reads the files, then asks the LLM for a summary of each file.
        Progress is announced through ``dispatcher`` signals.

        Raises:
            ActionSummarizerException: if any step fails; other errors are wrapped and chained.
        """
        try:
            action_context = self._extract_action_context(user_input)
            dispatcher.send(signal=SIGNAL_ACTION_CONTEXT_GENERATED,
                            sender=SENDER_ACTION_SUMMARIZER,
                            event_data=AgentSignal(
                                event_type=EventType.Message,
                                event_source=SENDER_ACTION_SUMMARIZER,
                                event_message=f'Detected desired action {action_context.action}. And file paths: {action_context.file_paths}.'
                            ))

            read_file_code = self._gen_file_read_code(action_context, session_id, storage_mode)
            dispatcher.send(signal=SIGNAL_FILE_READ_CODE_GENERATED,
                            sender=SENDER_ACTION_SUMMARIZER,
                            event_data=AgentSignal(
                                event_type=EventType.PythonCode,
                                event_source=SENDER_ACTION_SUMMARIZER,
                                event_message='Generated code to read and understand data schema.',
                                event_data=read_file_code
                            ))

            data_files_summary = self._gen_file_summaries_from_executing_code(read_file_code)
            dispatcher.send(signal=SIGNAL_FILE_READ_CODE_EXECUTED,
                            sender=SENDER_ACTION_SUMMARIZER,
                            event_data=AgentSignal(
                                event_type=EventType.Message,
                                event_source=SENDER_ACTION_SUMMARIZER,
                                event_message='Successfully executed code to read and understand data schema.',
                            ))

            file_summaries = self._gen_file_summaries_for_action(action_context.action, data_files_summary)
            return ActionSummary(action=action_context.action, file_summaries=file_summaries)

        except ActionSummarizerException:
            # Already specific; propagate unchanged.
            raise
        except Exception as e:
            raise ActionSummarizerException(
                message=f"Failed to extract dataframes from data reading code. Original exception: {e}") from e

    def _gen_file_summaries_for_action(self, action: str, file_summaries: List[FileSummary]) -> List[FileSummary]:
        """Ask the LLM for a summary of each file and store it on the item's ``file_summary``.

        The items are updated in place and the same list is returned.
        """
        # The requirements, prompt template and chain do not depend on the item: build them once.
        requirements_str = self._numbered(_DATA_SUMMARY_REQUIREMENTS)
        file_summary_template: PromptTemplate = PromptTemplate.from_template(_DATA_SUMMARY_PROMPT)
        chain = LLMChain(llm=self.llm, prompt=file_summary_template)

        for item in file_summaries:
            # Only a bounded prefix of the serialized table is sent to the model.
            gdf_str = item.data_frame.to_json()[:self._MAX_TABLE_CHARS]

            item.file_summary = chain.run(
                role_intro=_ROLE_INTRO,
                human_role=HUMAN_ROLE,
                requirements=requirements_str,
                action=action,
                columns=item.column_names,
                table=gdf_str,
                assistant_role=ASSISTANT_ROLE,
                stop=[HUMAN_STOP_SEQUENCE]
            ).strip()

        return file_summaries

    def _gen_file_read_code(self, action_context: ActionContext, session_id: str, storage_mode: str) -> str:
        """Ask the LLM for code that reads the files of ``action_context`` and return the extracted code."""
        file_urls_str = self._numbered(action_context.file_paths)
        requirements_str = self._numbered(_READ_FILE_REQUIREMENTS)
        read_file_template: PromptTemplate = PromptTemplate.from_template(_READ_FILE_PROMPT)

        chain = LLMChain(llm=self.llm, prompt=read_file_template)
        read_file_code_response = chain.run(
            role_intro=_ROLE_INTRO,
            human_role=HUMAN_ROLE,
            requirements=requirements_str,
            session_id=session_id,
            storage_mode=storage_mode,
            assistant_role=ASSISTANT_ROLE,
            file_urls=file_urls_str,
            stop=[HUMAN_STOP_SEQUENCE]
        ).strip()

        return extract_code(read_file_code_response)

    @staticmethod
    def _gen_file_summaries_from_executing_code(code: str) -> List[FileSummary]:
        """Execute the file-reading ``code`` and build ``FileSummary`` objects from its result.

        The code is run with the shim imports prepended and is expected to bind
        a sequence of dicts to the global named by ``DATA_FRAMES_VARIABLE_NAME``.

        Raises:
            ActionSummarizerException: if that global is missing, is empty, or
                an entry has a non-string file URL or non-list column names.
        """
        assembled_code = f'{get_shim_imports()}\n{code}'
        _, _globals = execute_assembled_code(assembled_code)

        try:
            dataframes = _globals[DATA_FRAMES_VARIABLE_NAME]
        except KeyError as e:
            # Report a missing result variable in domain terms rather than as a bare KeyError.
            raise ActionSummarizerException(
                message="Failed to generate file summaries from executing code. "
                        f"Variable {DATA_FRAMES_VARIABLE_NAME} not found in globals") from e

        file_summaries = [FileSummary(**data) for data in dataframes]

        if len(file_summaries) == 0:
            raise ActionSummarizerException(
                message="Failed to generate file summaries from executing code. "
                        "No dataframes found in globals")

        for item in file_summaries:
            if not isinstance(item.file_url, str):
                raise ActionSummarizerException(
                    message="Failed to generate file summaries from executing code. "
                            f"Found {type(item.file_url)} instead of str")
            if not isinstance(item.column_names, list):
                raise ActionSummarizerException(
                    message="Failed to generate file summaries from executing code. "
                            f"Found {type(item.column_names)} instead of list")

        return file_summaries

    def _extract_action_context(self, user_input: str) -> ActionContext:
        """Ask the LLM to extract the action and file paths from ``user_input``.

        Raises:
            pydantic.ValidationError: if the reply does not validate as an ``ActionContext``.
            ValueError: if the reply is rejected as invalid JSON by the handler below.
        """
        filepaths_extract_template: PromptTemplate = PromptTemplate.from_template(_ACTION_SUMMARY_PROMPT)
        requirements_str = self._numbered(_ACTION_SUMMARY_REQUIREMENTS)

        chain = LLMChain(llm=self.llm, prompt=filepaths_extract_template)
        action_summary = chain.run(
            role_intro=_ROLE_INTRO,
            human_role=HUMAN_ROLE,
            requirements=requirements_str,
            assistant_role=ASSISTANT_ROLE,
            message=user_input,
            stop=[HUMAN_STOP_SEQUENCE]
        ).strip()

        try:
            return ActionContext.parse_raw(action_summary)
        except json.JSONDecodeError as e:
            # NOTE(review): pydantic appears to report malformed JSON as ValidationError,
            # so this branch is presumably only a safety net - confirm before removing.
            raise ValueError("Invalid JSON format.") from e
class OperationCodeGenOutput:
    """Artefacts produced while generating code for a single operation."""

    def __init__(self,
                 operation_prompt: str,
                 operation_code_gen_response: str,
                 operation_code: str):
        # The fully formatted prompt for the operation.
        self.operation_prompt = operation_prompt
        # The raw (stripped) LLM response.
        self.operation_code_gen_response = operation_code_gen_response
        # The code extracted from the LLM response.
        self.operation_code = operation_code


class InvalidStateError(Exception):
    """Raised when the solver cannot proceed with the data it was given."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


class Solver:
    """Generates code for each operation of a plan graph and assembles it."""

    def __init__(self,
                 llm: "LLM",
                 graph: "networkx.DiGraph",
                 graph_code: str,
                 session_id: str,
                 storage_mode: str,
                 task_definition: str,
                 task_name: str,
                 data_locations_instructions: str):
        self.llm = llm
        self.graph = graph
        self.graph_code = graph_code
        self.session_id = session_id
        self.storage_mode = storage_mode
        self.task_def = task_definition
        self.task_name = task_name
        self.data_locations_instructions = data_locations_instructions
        self.operation_parser = OperationsParser(graph)

    @staticmethod
    def _numbered(lines) -> str:
        """Render ``lines`` as a newline-separated list numbered from 1."""
        return '\n'.join(f"{position}. {line}" for position, line in enumerate(lines, start=1))

    @staticmethod
    def _execution_order(encountered) -> list:
        """Turn output-to-input encounter order into a duplicate-free execution order.

        The encounter order is reversed, and only the first occurrence of each
        name in that reversed sequence is kept. Every name therefore still
        runs after everything that preceded it before de-duplication.
        """
        return list(dict.fromkeys(reversed(list(encountered))))

    @staticmethod
    def _string_literal(value) -> str:
        """Return ``value`` as a double-quoted, escaped string literal for generated code."""
        return json.dumps(str(value), ensure_ascii=False)

    @staticmethod
    def _parse_requirement_list(requirement_json: str) -> list:
        """Decode the LLM's requirement JSON, insisting on a JSON array.

        Raises:
            InvalidStateError: If the decoded value is not a list.
        """
        parsed = json.loads(requirement_json)
        if not isinstance(parsed, list):
            raise InvalidStateError(
                f"Expected a JSON list of operation requirements, got {type(parsed).__name__}")
        return parsed

    def solve(self):
        """Generate code for every operation node and store it on the node.

        A ``SIGNAL_OPERATION_CODE_GENERATED`` signal is dispatched after each
        operation.

        Returns:
            The operation nodes, each updated with its prompt, raw LLM
            response and extracted code.
        """
        op_nodes = self.operation_parser.operation_nodes

        for idx, op_node in enumerate(op_nodes):
            operation_code_gen_output = self.gen_operation_code(op_node)
            dispatcher.send(signal=SIGNAL_OPERATION_CODE_GENERATED,
                            sender=SENDER_GEOSPATIAL_AGENT,
                            event_data=AgentSignal(
                                event_source=SENDER_GEOSPATIAL_AGENT,
                                event_message=f"{idx + 1} / {len(op_nodes)}: Generated code for operation {op_node.node_name}",
                                event_data=operation_code_gen_output.operation_code,
                                event_type=EventType.PythonCode
                            ))

            # INFO: Updating Operation Nodes with generated code
            op_node.operation_prompt = operation_code_gen_output.operation_prompt
            op_node.code_gen_response = operation_code_gen_output.operation_code_gen_response
            op_node.operation_code = operation_code_gen_output.operation_code

        return op_nodes

    def assemble(self):
        """Stitch the generated operation code into one program.

        The result is the shim imports, then every operation's code (head),
        then the session variables and one call line per operation (tail).
        A ``SIGNAL_TAIL_CODE_GENERATED`` signal carrying the tail is dispatched.

        Returns:
            The assembled source code as a string.

        Raises:
            InvalidStateError: If the graph marks a node as an operation but
                the operation parser has no node of that name.
        """
        output_node_names = self.operation_parser.output_node_names

        # First node wins on duplicate names, as with a linear search.
        ops_by_name = {}
        for node in self.operation_parser.operation_nodes:
            ops_by_name.setdefault(node.node_name, node)

        reverse_graph = self.graph.reverse(copy=True)

        # Walk from each output back towards its inputs, recording operation
        # nodes in the order they are met. bfs_edges yields one edge per
        # input, so an operation with several inputs is met several times;
        # the duplicates are removed below.
        # NOTE(review): an operation with no incoming edge in the original
        # graph is never the first element of an edge here and so is not
        # collected - confirm such nodes cannot occur.
        encountered = []
        for output_node in output_node_names:
            for from_node_name, _ in networkx.bfs_edges(reverse_graph, source=output_node):
                node_attributes = self.graph.nodes[from_node_name]
                if node_attributes.get(NODE_TYPE_ATTRIBUTE, None) == NODE_TYPE_OPERATION:
                    encountered.append(from_node_name)

        definitions = []
        call_lines = []
        for node_name in self._execution_order(encountered):
            op_node = ops_by_name.get(node_name)
            if op_node is None:
                raise InvalidStateError(
                    f"Graph node '{node_name}' is marked as an operation but has no parsed operation node")

            definitions.append("\n" + op_node.operation_code + "\n")
            call_lines.append(f'{", ".join(op_node.return_names)}={op_node.function_definition}\n')

        # The head end of the code
        head = "".join(definitions)

        # The tail end of the code: session id, task name and storage mode
        # come first, escaped so a quote or backslash in a value cannot break
        # the generated program.
        tail = f'\nsession_id = {self._string_literal(self.session_id)}\n' + \
               f'task_name = {self._string_literal(self.task_name)}\n' + \
               f'storage_mode = {self._string_literal(self.storage_mode)}\n' + \
               "".join(call_lines)

        dispatcher.send(signal=SIGNAL_TAIL_CODE_GENERATED,
                        sender=SENDER_GEOSPATIAL_AGENT,
                        event_data=AgentSignal(
                            event_source=SENDER_GEOSPATIAL_AGENT,
                            event_message="Generated final code block.",
                            event_data=tail,
                            event_type=EventType.PythonCode
                        ))

        assembled_code = head + "\n" + tail
        return f'{get_shim_imports()}\n{assembled_code}'

    def get_operation_requirement(self, op_node: "OperationNode") -> list[str]:
        """Ask the LLM for the requirements that code for ``op_node`` must meet.

        Returns:
            ``shim_instructions`` followed by the requirements decoded from
            the ``json`` section of the LLM response.

        Raises:
            InvalidStateError: If that section does not decode to a list.
        """
        task_def = self.task_def.strip("\n").strip()
        op_properties = [
            f'The function description is: {op_node.description}',
            f'The type of work done in this function is: {op_node.operation_type}',
            f'This function is one step to solve the question/task: {task_def}'
        ]

        op_req_gen_prompt_template: PromptTemplate = PromptTemplate.from_template(operation_requirement_gen_task_prefix)

        chain = LLMChain(llm=self.llm, prompt=op_req_gen_prompt_template)
        req_gen_response = chain.run(
            human_role=HUMAN_ROLE,
            operation_req_gen_intro=operation_code_gen_intro,
            operation_name=op_node.node_name,
            pre_requirements=self._numbered(predefined_operation_requirements),
            operation_properties=self._numbered(op_properties),
            assistant_role=ASSISTANT_ROLE,
            stop=[HUMAN_STOP_SEQUENCE]
        ).strip()

        operation_requirement_json = extract_content_xml("json", req_gen_response)
        operation_requirement_list: List[str] = self._parse_requirement_list(operation_requirement_json)
        return shim_instructions + operation_requirement_list

    def gen_operation_code(self, op_node: "OperationNode") -> OperationCodeGenOutput:
        """Generate the code for a single operation node.

        Returns:
            The formatted prompt, the raw (stripped) LLM response and the
            code extracted from it.
        """
        operation_requirement_list = self.get_operation_requirement(op_node)

        node_name = op_node.node_name

        # Get ancestors operations functions. For operations that has ancestors, this will also come with LLM
        # generated code for the operations.
        ancestor_op_nodes = self.operation_parser.get_ancestors(node_name)
        ancestor_op_nodes_code = '\n'.join([ancestor.operation_code for ancestor in ancestor_op_nodes])

        descendant_op_node = self.operation_parser.get_descendants(node_name)
        descendant_op_node_defs = self.operation_parser.stringify_nodes(descendant_op_node)

        pre_requirements = [
            f'The function description is: {op_node.description}',
            f'The function definition is: {op_node.function_definition}',
            f'The function return line is: {op_node.return_line}'
        ]

        operation_requirements_str = self._numbered(pre_requirements + operation_requirement_list)

        # Built once so the recorded prompt and the prompt actually sent to
        # the LLM cannot drift apart.
        prompt_inputs = dict(
            human_role=HUMAN_ROLE,
            operation_code_gen_intro=operation_code_gen_intro,
            operation_task_prefix=operation_task_prefix,
            operation_description=op_node.description,
            task_definition=self.task_def.strip("\n").strip(),
            graph_code=self.graph_code,
            data_locations_instructions=self.data_locations_instructions,
            session_id=self.session_id,
            task_name=self.task_name,
            storage_mode=self.storage_mode,
            operation_reply_example=operation_reply_example,
            operation_pydeck_example=operation_pydeck_example,
            operation_requirements=operation_requirements_str,
            ancestor_operation_code=ancestor_op_nodes_code,
            descendant_operations_definition=str(descendant_op_node_defs),
            assistant_role=ASSISTANT_ROLE,
        )

        op_code_gen_prompt_template: PromptTemplate = PromptTemplate.from_template(operation_code_gen_prompt_template)
        op_code_gen_prompt = op_code_gen_prompt_template.format(**prompt_inputs)

        chain = LLMChain(llm=self.llm, prompt=op_code_gen_prompt_template)
        code_gen_response = chain.run(**prompt_inputs, stop=[HUMAN_STOP_SEQUENCE]).strip()

        operation_code = extract_code(code_gen_response)

        return OperationCodeGenOutput(
            operation_prompt=op_code_gen_prompt,
            operation_code_gen_response=code_gen_response,
            operation_code=operation_code
        )