├── .gitignore
├── tutorial
│   └── img
│       ├── screenshot-dark.png
│       ├── top_5_customers.jpg
│       ├── data_agent_design.jpg
│       └── top_5_customers_screenshot.png
├── metadata
│   └── sfdc_metadata_loader
│       ├── requirements.txt
│       ├── __init__.py
│       ├── sfdc_metadata.py
│       └── sfdc_metadata_loader.py
├── src
│   ├── agents
│   │   ├── data_agent
│   │   │   ├── tools
│   │   │   │   ├── __init__.py
│   │   │   │   ├── utils.py
│   │   │   │   ├── chart_evaluator.py
│   │   │   │   ├── crm_business_analyst.py
│   │   │   │   ├── data_engineer.py
│   │   │   │   └── bi_engineer.py
│   │   │   ├── prompts
│   │   │   │   ├── __init__.py
│   │   │   │   ├── sql_correction.py
│   │   │   │   ├── chart_evaluator.py
│   │   │   │   ├── root_agent.py
│   │   │   │   ├── crm_business_analyst.py
│   │   │   │   ├── bi_engineer.py
│   │   │   │   └── data_engineer.py
│   │   │   ├── __init__.py
│   │   │   └── agent.py
│   │   └── __init__.py
│   ├── Dockerfile
│   ├── requirements.txt
│   ├── shared
│   │   ├── __init__.py
│   │   ├── config_env.py
│   │   └── firestore_session_service.py
│   ├── __init__.py
│   ├── web
│   │   ├── __init__.py
│   │   ├── fast_api_runner.py
│   │   ├── main.py
│   │   ├── images
│   │   │   ├── logo.svg
│   │   │   └── logo-dark.svg
│   │   ├── agent_runtime_client.py
│   │   ├── fast_api_app.py
│   │   └── web.py
│   └── .env-template
├── utils
│   ├── __init__.py
│   ├── deploy_to_cloud_run.py
│   └── deploy_demo_data.py
├── deploy_to_cloud_run.sh
├── run_local.sh
├── app.json
├── README.md
└── LICENSE

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .local
3 | .venv
4 | data
5 | .env
6 | .vscode/
7 | .DS_Store
--------------------------------------------------------------------------------
/tutorial/img/screenshot-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vladkol/crm-data-agent/HEAD/tutorial/img/screenshot-dark.png
--------------------------------------------------------------------------------
/tutorial/img/top_5_customers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vladkol/crm-data-agent/HEAD/tutorial/img/top_5_customers.jpg
--------------------------------------------------------------------------------
/metadata/sfdc_metadata_loader/requirements.txt:
--------------------------------------------------------------------------------
1 | google-cloud-bigquery
2 | # google-cloud-secret-manager
3 | simple-salesforce
--------------------------------------------------------------------------------
/tutorial/img/data_agent_design.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vladkol/crm-data-agent/HEAD/tutorial/img/data_agent_design.jpg
--------------------------------------------------------------------------------
/tutorial/img/top_5_customers_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vladkol/crm-data-agent/HEAD/tutorial/img/top_5_customers_screenshot.png
--------------------------------------------------------------------------------
/src/agents/data_agent/tools/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(".")
5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
--------------------------------------------------------------------------------
/metadata/sfdc_metadata_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import sys
3 |
4 | sys.path.append(".")
5 | sys.path.append(str(Path(__file__).parent))
6 | sys.path.append(str(Path(__file__).parent.parent))
--------------------------------------------------------------------------------
/src/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim
2 |
3 | WORKDIR /app
4 | COPY . /app
5 |
6 | RUN pip install --no-cache-dir -r requirements.txt
7 |
8 | EXPOSE 8080
9 |
10 | CMD uvicorn --app-dir web fast_api_runner:api_app --port 8000 --workers 16 & python3 web/main.py agents/data_agent local & wait
11 |
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | a2a-sdk==0.2.*
2 | google-adk==1.3.*
3 | streamlit
4 | tabulate
5 | platformdirs
6 | Flask
7 | gunicorn
8 | scikit-learn
9 | google-genai==1.20.*
10 | pillow
11 | pandas
12 | google-cloud-aiplatform[adk]==1.97.*
13 | google-cloud-bigquery[pandas]
14 | google-cloud-bigquery-storage
15 | google-cloud-firestore==2.21.*
16 | altair==5.5.*
17 | vl-convert-python
18 | python-dotenv
19 | cloudpickle==3.1.1
20 | pydantic==2.11.*
21 | yfinance
22 | matplotlib
23 |
--------------------------------------------------------------------------------
/src/shared/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """shared module init"""
15 |
--------------------------------------------------------------------------------
/src/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """agents init module"""
15 |
16 |
--------------------------------------------------------------------------------
/src/agents/data_agent/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Prompt templates module."""
15 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """utils module init"""
15 |
16 | import os
17 | import sys
18 |
19 | sys.path.append(".")
20 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
21 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """src module init"""
15 |
16 | from pathlib import Path
17 | import sys
18 |
19 | sys.path.append(".")
20 | sys.path.append(str(Path(__file__).parent))
21 | sys.path.append(str(Path(__file__).parent.parent))
--------------------------------------------------------------------------------
/src/web/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """agents module init""" 15 | 16 | from pathlib import Path 17 | import sys 18 | 19 | sys.path.append(".") 20 | sys.path.append(str(Path(__file__).parent)) 21 | sys.path.append(str(Path(__file__).parent.parent)) 22 | -------------------------------------------------------------------------------- /src/agents/data_agent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """QnA Agent init module""" 15 | 16 | import os 17 | import sys 18 | 19 | sys.path.append(".") 20 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 21 | 22 | from .agent import root_agent 23 | 24 | __version__ = "0.1.0" 25 | -------------------------------------------------------------------------------- /deploy_to_cloud_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 17 | pushd "${SCRIPT_DIR}" &> /dev/null || exit 18 | 19 | SERVICE_NAME="crm-data-agent" 20 | python3 utils/deploy_to_cloud_run.py "${SERVICE_NAME}" 21 | 22 | 23 | popd &> /dev/null || exit 24 | -------------------------------------------------------------------------------- /run_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 |
16 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
17 | pushd "${SCRIPT_DIR}/src" &> /dev/null || exit
18 | uvicorn --app-dir web fast_api_runner:api_app --port 8000 & python3 web/main.py "agents/data_agent" "local" & wait
19 | popd &> /dev/null || exit
20 |
--------------------------------------------------------------------------------
/src/.env-template:
--------------------------------------------------------------------------------
1 | AGENT_NAME="crm_data_agent" # Agent Name. [REQUIRED]
2 | GOOGLE_GENAI_USE_VERTEXAI=1 # Only tested with Gemini in Vertex AI. [REQUIRED]
3 |
4 | GOOGLE_CLOUD_PROJECT="" # Project Id for Vertex AI and Cloud Run. [REQUIRED]
5 | GOOGLE_CLOUD_LOCATION="us-central1" # Cloud Region for Vertex AI and Cloud Run. [REQUIRED]
6 | AI_STORAGE_BUCKET="" # Storage bucket for Artifacts Service and Vertex AI deployment operations. [REQUIRED]
7 |
8 | BQ_LOCATION="US" # BigQuery location [REQUIRED]
9 | SFDC_BQ_DATASET="sfdc_data" # BigQuery Dataset with Salesforce data in SFDC_DATA_PROJECT_ID project. [REQUIRED]
10 | # BQ_PROJECT_ID= # Project Id for executing BigQuery queries. (if empty, defaults to GOOGLE_CLOUD_PROJECT)
11 | # SFDC_DATA_PROJECT_ID= # Project Id for Salesforce Data in BigQuery. (if empty, defaults to BQ_PROJECT_ID/GOOGLE_CLOUD_PROJECT)
12 |
13 | FIRESTORE_SESSION_DATABASE="" # Firestore database name for storing session data [REQUIRED]
14 |
15 | SFDC_METADATA_FILE="sfdc_metadata.json" # Salesforce Metadata file path. Do not change it if not using custom metadata.
16 |
--------------------------------------------------------------------------------
/utils/deploy_to_cloud_run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from pathlib import Path
17 | import os
18 | import subprocess
19 | import sys
20 |
21 | sys.path.append(str(Path(__file__).parent.parent))
22 | from src.shared.config_env import prepare_environment
23 |
24 | prepare_environment()
25 |
26 | cmd_line = f"""
27 | gcloud run deploy "{sys.argv[1]}"
28 | --project="{os.environ['GOOGLE_CLOUD_PROJECT']}"
29 | --region="{os.environ['GOOGLE_CLOUD_LOCATION']}"
30 | --port=8080
31 | --cpu=8
32 | --memory=32Gi
33 | --cpu-boost
34 | --no-allow-unauthenticated
35 | --no-cpu-throttling
36 | --source "./src"
37 | --timeout 1h
38 | -q
39 | """.replace("\n", " ").strip()
40 |
41 | subprocess.run(cmd_line, shell=True)
42 |
--------------------------------------------------------------------------------
/utils/deploy_demo_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from pathlib import Path
17 | import os
18 | import subprocess
19 | import sys
20 |
21 | sys.path.append(str(Path(__file__).parent.parent))
22 | from src.shared.config_env import prepare_environment
23 |
24 | prepare_environment()
25 |
26 | git_cmd_line = "git clone --depth 1 --no-tags https://github.com/vladkol/sfdc-kittycorn ./data && rm -rf ./data/.git"
27 | subprocess.run(git_cmd_line, shell=True)
28 |
29 | _python_cmd_line = (f"\"{sys.executable}\" ./data/deploy_to_my_project.py "
30 |                     f"--project {os.environ['SFDC_DATA_PROJECT_ID']} "
31 |                     f"--dataset {os.environ['SFDC_BQ_DATASET']} "
32 |                     f"--location {os.environ['BQ_LOCATION']}")
33 | subprocess.run(_python_cmd_line, shell=True)
34 |
--------------------------------------------------------------------------------
/src/agents/data_agent/prompts/sql_correction.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """SQL Correction prompt template."""
15 | # flake8: noqa
16 | # pylint: disable=all
17 |
18 | instruction = """
19 | You are a BigQuery SQL Correction Tool. Your task is to analyze incoming BigQuery SQL queries, identify errors based on syntax and the provided schema, and output a corrected, fully executable query.
20 |
21 | **Context:**
22 | * **Platform:** Google BigQuery
23 | * **Project ID:** `{data_project_id}`
24 | * **Dataset Name:** `{dataset}`
25 |
26 | **Schema:**
27 | You MUST operate exclusively within the following database schema for the `{data_project_id}.{dataset}` dataset. All table and field references must conform to this structure:
28 |
29 | ```json
30 | {sfdc_metadata}
31 | ```
32 | """
33 |
34 | prompt = """
35 | ```sql
36 | {validating_query}
37 | ```
38 |
39 | Fix the error below. Do not simply exclude entities if it affects the algorithm.
40 | !!! Do not repeat yourself !!!
41 |
42 | ERROR: {validator_result}
43 | """
--------------------------------------------------------------------------------
/src/web/fast_api_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Agent Runner App - FastAPI service"""
15 |
16 | import logging
17 | import os
18 | from pathlib import Path
19 | import sys
20 |
21 | from google.adk.artifacts import GcsArtifactService
22 | from fast_api_app import get_fast_api_app
23 |
24 | sys.path.append(str(Path(__file__).parent.parent))
25 | from shared.config_env import prepare_environment
26 |
27 | #################### Initialization ####################
28 | logging.getLogger().setLevel(logging.INFO)
29 | os.environ["AGENT_DIR"] = str(Path(__file__).parent.parent /
30 |                               "agents" /
31 |                               "data_agent")
32 | prepare_environment()
33 | ########################################################
34 |
35 | api_app = get_fast_api_app(
36 |     agent_dir=os.environ["AGENT_DIR"],
37 |     trace_to_cloud=False,
38 |     artifact_service=GcsArtifactService(
39 |         bucket_name=os.environ["AI_STORAGE_BUCKET"]
40 |     )
41 | )
42 |
43 |
44 |
--------------------------------------------------------------------------------
/src/agents/data_agent/prompts/chart_evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Chart Evaluator prompt template."""
15 | # flake8: noqa
16 | # pylint: disable=all
17 |
18 | prompt = """
19 | **Instructions**:
20 |
21 | The image is a BI chart or a dashboard that shows data supporting an answer to a question below.
22 |
23 | Number of rows in the data source is: {data_row_count}.
24 | Make sure labels and values are readable.
25 |
26 | After looking at a chart, decide if it's good or not good (nothing in between).
27 | If not good, provide a reason with a longer explanation of what needs to be worked on.
28 |
29 | The chart must be comfortable to read on a 2K screen of 16 inch size with ability to zoom in and out.
30 | Do not make comments about choice of dimensions, metrics, grouping or data cardinality.
31 | You can only criticize readability, composition, color choices, font size, etc.
32 |
33 | **Exceptions:**
34 | * The chart may require interaction (selection parameters). Default selection may make the rendered chart lack data. Be tolerant of that.
35 | * The chart may be hard to read due to the density of elements. If the density problem can be solved by selecting a parameter value, then assume it is ok, and let it slide.
36 |
37 | **QUESTION:**
38 | ```
39 | {question}
40 | ```
41 |
42 | This is chart json code in Vega-Lite (data removed):
43 |
44 | ```json
45 | {chart_json}
46 | ```
47 |
48 | """
49 |
--------------------------------------------------------------------------------
/src/web/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Agent web bootstrap script"""
16 |
17 | import os
18 | from pathlib import Path
19 | import sys
20 |
21 | import streamlit.web.bootstrap as bootstrap
22 |
23 | sys.path.append(str(Path(__file__).parent.parent))
24 | from shared.config_env import prepare_environment
25 |
26 | flag_options = {
27 |     "server.headless": True,
28 |     "server.enableCORS": False,
29 |     "server.enableXsrfProtection": False,
30 |     "server.fileWatcherType": None,
31 |     "server.port": int(os.getenv("PORT", 8080)),
32 |     "server.enableWebsocketCompression": True,
33 |     "browser.gatherUsageStats": False,
34 |     "client.toolbarMode": "minimal",
35 |     "global.developmentMode": False,
36 |     "theme.font": "Inter, Verdana",
37 |     "theme.base": "dark",
38 |     "logger.level": "info",
39 | }
40 |
41 | if __name__ == "__main__":
42 |     if len(sys.argv) >= 2:
43 |         agent_dir = sys.argv[1]
44 |     else:
45 |         agent_dir = "."
46 |     if len(sys.argv) >= 3:
47 |         target_env = sys.argv[2]
48 |     else:
49 |         target_env = "local"
50 |     print(f"Target runtime environment: {target_env}")
51 |
52 |     agent_dir = os.path.abspath(agent_dir)
53 |     print(f"Agent directory: {agent_dir}")
54 |     os.environ["AGENT_DIR"] = agent_dir
55 |     os.environ["RUNTIME_ENVIRONMENT"] = target_env
56 |     prepare_environment()
57 |     app_script_path = os.path.join(os.path.dirname(__file__), "web.py")
58 |     bootstrap.load_config_options(flag_options)
59 |     bootstrap.run(
60 |         app_script_path,
61 |         False,
62 |         [],
63 |         flag_options,
64 |     )
65 |
--------------------------------------------------------------------------------
/app.json:
--------------------------------------------------------------------------------
1 | {
2 |     "options": {
3 |         "allow-unauthenticated": true,
4 |         "memory": "32Gi",
5 |         "cpu": "8",
6 |         "port": 8080,
7 |         "http2": false
8 |     },
9 |     "build": {
10 |         "skip": true
11 |     },
12 |     "hooks": {
13 |         "prebuild": {
14 |             "commands": [
15 |                 "gcloud config set project ${GOOGLE_CLOUD_PROJECT} -q",
16 |                 "gcloud services enable aiplatform.googleapis.com cloudbuild.googleapis.com run.googleapis.com firestore.googleapis.com bigquery.googleapis.com --project=${GOOGLE_CLOUD_PROJECT} || true",
17 |                 "pip install -q -U google-cloud-bigquery google-cloud-bigquery-storage",
18 |                 "git clone --depth 1 --no-tags https://github.com/vladkol/sfdc-kittycorn ./data; rm -rf ./data/.git/"
19 |             ]
20 |         },
21 |         "postbuild": {
22 |             "commands": [
23 |                 "echo GOOGLE_CLOUD_PROJECT=\"${GOOGLE_CLOUD_PROJECT}\" >> src/.env",
24 |                 "echo GOOGLE_CLOUD_LOCATION=\"${GOOGLE_CLOUD_REGION}\" >> src/.env",
25 |                 "echo GOOGLE_GENAI_USE_VERTEXAI=1 >> src/.env",
26 |                 "echo AGENT_NAME=\"crm-data-agent\" >> src/.env",
27 |                 "echo FIRESTORE_SESSION_DATABASE=\"agent-sessions\" >> src/.env",
28 |                 "echo AI_STORAGE_BUCKET=\"crm-data-agent-artifacts--${GOOGLE_CLOUD_PROJECT}\" >> src/.env",
29 |                 "echo BQ_LOCATION=\"US\" >> src/.env",
30 |                 "echo SFDC_BQ_DATASET=\"sfdc_data\" >> src/.env",
31 |                 "source src/.env; gcloud firestore databases describe --database=${FIRESTORE_SESSION_DATABASE} --project=${GOOGLE_CLOUD_PROJECT} -q &> /dev/null || gcloud firestore databases create --database=${FIRESTORE_SESSION_DATABASE} --project=${GOOGLE_CLOUD_PROJECT} --location=${GOOGLE_CLOUD_REGION}",
32 |                 "source src/.env; gcloud storage buckets describe gs://${AI_STORAGE_BUCKET} -q &> /dev/null || gcloud storage buckets create gs://${AI_STORAGE_BUCKET} --project=${GOOGLE_CLOUD_PROJECT} --location=${GOOGLE_CLOUD_REGION}",
33 |                 "source src/.env; python3 ./data/deploy_to_my_project.py --project ${GOOGLE_CLOUD_PROJECT} --dataset ${SFDC_BQ_DATASET} --location ${BQ_LOCATION}",
34 |                 "docker build -t ${IMAGE_URL} ./src"
35 |             ]
36 |         }
37 |     }
38 | }
39 |
--------------------------------------------------------------------------------
/metadata/sfdc_metadata_loader/sfdc_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Salesforce CRM metadata extractor"""
15 |
16 | import json
17 | import pathlib
18 | import threading
19 | import typing
20 |
21 |
22 | class SFDCMetadata:
23 |     """Salesforce CRM metadata client"""
24 |
25 |     def __init__(
26 |         self,
27 |         project_id: str,
28 |         dataset_name: str,
29 |         metadata_file: typing.Optional[str] = None
30 |     ) -> None:
31 |         """
32 |         Args:
33 |             project_id (str): GCP project id of BigQuery data.
34 |             dataset_name (str): BigQuery dataset name.
35 |             metadata_file (str, Optional): path to metadata file.
36 |                 If not provided,
37 |                 it will be "{project_id}__{dataset_name}.json"
38 |                 in the current directory.
39 |         """
40 |         self.project_id = project_id
41 |         self.dataset_name = dataset_name
42 |         if metadata_file:
43 |             self._metadata_file_name = metadata_file
44 |         else:
45 |             self._metadata_file_name = (f"{self.project_id}__"
46 |                                         f"{self.dataset_name}.json")
47 |         self._metadata = {}
48 |         self._lock = threading.Lock()
49 |
50 |     def get_metadata(self) -> typing.Dict[str, typing.Any]:
51 |         """Extract metadata from Salesforce CRM"""
52 |         if len(self._metadata) > 0:
53 |             return self._metadata
54 |
55 |         with self._lock:
56 |             if len(self._metadata) == 0:
57 |                 metadata_path = pathlib.Path(self._metadata_file_name)
58 |                 if metadata_path.exists():
59 |                     self._metadata = json.loads(
60 |                         metadata_path.read_text(encoding="utf-8"))
61 |                 else:
62 |                     raise FileNotFoundError(self._metadata_file_name)
63 |
64 |         return self._metadata
65 |
66 |
--------------------------------------------------------------------------------
/src/agents/data_agent/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
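# Note: _GlobalGemini (below) rewrites bare "gemini-2*" model IDs into full
# "projects/{project}/locations/global/..." resource paths, so requests are
# served by the Vertex AI "global" endpoint. get_genai_client() and
# get_gemini_model() lazily build a single shared client behind a module lock.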
14 | """Utils for agents""" 15 | 16 | from functools import cached_property 17 | import os 18 | from typing_extensions import override 19 | from threading import Lock 20 | from typing import AsyncGenerator 21 | 22 | from google import genai 23 | from google.adk.models import Gemini 24 | from google.adk.models.llm_request import LlmRequest 25 | from google.adk.models.llm_response import LlmResponse 26 | 27 | _lock = Lock() 28 | 29 | _gemini = None 30 | _llm_client = None 31 | 32 | class _GlobalGemini(Gemini): 33 | @override 34 | async def generate_content_async( 35 | self, llm_request: LlmRequest, stream: bool = False 36 | ) -> AsyncGenerator[LlmResponse, None]: 37 | if not llm_request.model: 38 | llm_request.model = "gemini-flash-2.0" 39 | if ( 40 | llm_request.model.startswith("gemini-2") 41 | and "/" not in llm_request.model 42 | ): 43 | project = os.environ["GOOGLE_CLOUD_PROJECT"] 44 | llm_request.model = (f"projects/{project}/locations/global/" 45 | "publishers/google/" 46 | f"models/{llm_request.model}") 47 | async for response in super().generate_content_async( 48 | llm_request, stream 49 | ): 50 | yield response 51 | 52 | @cached_property 53 | def api_client(self) -> genai.Client: 54 | """Provides the api client. 55 | 56 | Returns: 57 | The api client. 58 | """ 59 | original_client = super().api_client 60 | return genai.Client( 61 | vertexai=original_client.vertexai, 62 | location="global", 63 | ) 64 | 65 | 66 | def get_genai_client(model_id: str = "gemini-flash-2.0") -> genai.Client: 67 | global _gemini 68 | global _llm_client 69 | if _llm_client: 70 | return _llm_client 71 | with _lock: 72 | if _llm_client: 73 | return _llm_client 74 | _gemini = _GlobalGemini(model=model_id) 75 | _gemini.api_client._api_client.location = "global" 76 | _llm_client = _gemini.api_client 77 | return _llm_client 78 | 79 | def get_gemini_model(model_id: str) -> Gemini: 80 | global _gemini 81 | get_genai_client() 82 | res = _gemini.model_copy() # type: ignore 83 | res.model = model_id 84 | return res 85 | 86 | def get_shared_lock() -> Lock: 87 | return _lock -------------------------------------------------------------------------------- /src/agents/data_agent/tools/chart_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Chart Evaluator Sub-tool""" 15 | 16 | from google.adk.tools import ToolContext 17 | from google.genai.types import Content, GenerateContentConfig, Part, SafetySetting 18 | 19 | from pydantic import BaseModel 20 | 21 | from .utils import get_genai_client 22 | from prompts.chart_evaluator import prompt as chart_evaluator_prompt 23 | 24 | 25 | CHART_EVALUATOR_MODEL_ID = "gemini-2.0-flash-001" 26 | 27 | class EvaluationResult(BaseModel): 28 | is_good: bool 29 | reason: str 30 | 31 | 32 | def evaluate_chart(png_image: bytes, 33 | chart_json_text: str, 34 | question: str, 35 | data_row_count: int, 36 | tool_context: ToolContext) -> EvaluationResult: 37 | """ 38 | This is an experienced Business Intelligence UX designer. 39 | They look at a chart or a dashboard, and can tell if it the right one for the question. 40 | 41 | Parameters: 42 | * png_image (str) - png image of the chart or a dashboard 43 | * question (str) - question this chart is supposed to answer 44 | 45 | """ 46 | 47 | prompt = chart_evaluator_prompt.format(data_row_count=data_row_count, 48 | chart_json=chart_json_text, 49 | question=question) 50 | 51 | image_part = Part.from_bytes(mime_type="image/png", data=png_image) 52 | eval_result = get_genai_client().models.generate_content( 53 | model=CHART_EVALUATOR_MODEL_ID, 54 | contents=Content( 55 | role="user", 56 | parts=[ 57 | image_part, # type: ignore 58 | Part.from_text(text=prompt) 59 | ] 60 | ), 61 | config=GenerateContentConfig( 62 | response_schema=EvaluationResult, 63 | response_mime_type="application/json", 64 | system_instruction=f""" 65 | You are an experienced Business Intelligence UX designer. 66 | You can look at a chart or a dashboard, and tell if it the right one for the question. 67 | """.strip(), 68 | temperature=0.1, 69 | top_p=0.0, 70 | seed=1, 71 | safety_settings=[ 72 | SafetySetting( 73 | category="HARM_CATEGORY_DANGEROUS_CONTENT", # type: ignore 74 | threshold="BLOCK_ONLY_HIGH", # type: ignore 75 | ), 76 | ]) 77 | ) 78 | 79 | return eval_result.parsed # type: ignore 80 | -------------------------------------------------------------------------------- /src/agents/data_agent/tools/crm_business_analyst.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Business Analyst Agent""" 15 | 16 | import uuid 17 | 18 | from google.adk.agents import LlmAgent 19 | from google.adk.agents.callback_context import CallbackContext 20 | from google.adk.models import LlmResponse 21 | from google.adk.planners import BuiltInPlanner 22 | 23 | from google.genai.types import ( 24 | GenerateContentConfig, 25 | Part, 26 | SafetySetting, 27 | ThinkingConfig 28 | ) 29 | 30 | from prompts.crm_business_analyst import (system_instruction 31 | as crm_business_analyst_instruction) 32 | from tools.utils import get_gemini_model 33 | 34 | 35 | BUSINESS_ANALYST_AGENT_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 36 | 37 | 38 | async def after_model_callback(callback_context: CallbackContext, 39 | llm_response: LlmResponse) -> LlmResponse | None: 40 | if not llm_response.content or not llm_response.content.parts: 41 | return 42 | for p in llm_response.content.parts: 43 | if p.text and p.text.strip(): 44 | await callback_context.save_artifact( 45 | f"analysis_{uuid.uuid4().hex}.md", 46 | Part.from_bytes( 47 | mime_type="text/markdown", 48 | data=p.text.encode("utf-8") 49 | ) 50 | ) 51 | 52 | 53 | crm_business_analyst_agent = LlmAgent( 54 | model=get_gemini_model(BUSINESS_ANALYST_AGENT_MODEL_ID), 55 | name="crm_business_analyst", 56 | description=""" 57 | This is your Senior Business Analyst. 58 | 59 | They can analyze your questions about business 60 | no matter what form these questions are in. 61 | 62 | Questions may be different: 63 | - Directly linked to business data (e.g. "Revenue by country") 64 | - Open to interpretation (e.g. "Who are my best customers"). 65 | 66 | They figure out what metrics, dimensions and KPIs 67 | could be used to answer the question. 68 | 69 | They may offer a few options. 70 | """, 71 | instruction=crm_business_analyst_instruction, 72 | generate_content_config=GenerateContentConfig( 73 | temperature=0.0, 74 | top_p=0.0, 75 | seed=1, 76 | safety_settings=[ 77 | SafetySetting( 78 | category="HARM_CATEGORY_DANGEROUS_CONTENT", # type: ignore 79 | threshold="BLOCK_ONLY_HIGH", # type: ignore 80 | ), 81 | ] 82 | ), 83 | planner=BuiltInPlanner( 84 | thinking_config=ThinkingConfig(thinking_budget=32768) 85 | ), 86 | after_model_callback=after_model_callback 87 | ) 88 | -------------------------------------------------------------------------------- /src/shared/config_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Configuration Environment Variables Loader""" 15 | 16 | import os 17 | from dotenv import load_dotenv, dotenv_values 18 | from pathlib import Path 19 | import logging 20 | import sys 21 | 22 | 23 | _env_requirements = { 24 | "GOOGLE_CLOUD_PROJECT": None, # None value means it has to be non-empty 25 | "GOOGLE_CLOUD_LOCATION": None, 26 | 27 | # `$`` at the beginning refers to another variable 28 | "BQ_PROJECT_ID": "$GOOGLE_CLOUD_PROJECT", 29 | "SFDC_DATA_PROJECT_ID": "$BQ_PROJECT_ID", 30 | "SFDC_BQ_DATASET": None, 31 | "FIRESTORE_SESSION_DATABASE": None, 32 | "BQ_LOCATION": "US", 33 | "SFDC_METADATA_FILE": "sfdc_metadata.json", # default value 34 | "AI_STORAGE_BUCKET": None, 35 | } 36 | _prepared = False 37 | 38 | def _get_dotenv_file() -> str: 39 | dotenv_path = Path(__file__).parent.parent / ".env" 40 | if not dotenv_path.exists(): 41 | logging.warning(f"{dotenv_path} not found.") 42 | return "" 43 | return str(dotenv_path) 44 | 45 | def get_env_values() -> dict: 46 | env_file = _get_dotenv_file() 47 | if not env_file: 48 | logging.warning(".env file not found. Trying to re-construct from values.") 49 | env_dict = {} 50 | for name in _env_requirements: 51 | env_dict[name] = os.environ.get(name, None) 52 | return env_dict 53 | values = dotenv_values(_get_dotenv_file()).copy() 54 | for v in values: 55 | if v in os.environ: 56 | values[v] = os.environ[v] 57 | if ( 58 | "FIREBASE_SESSION_DATABASE" in values 59 | and "FIRESTORE_SESSION_DATABASE" not in values 60 | ): 61 | values["FIRESTORE_SESSION_DATABASE"] = values[ 62 | "FIREBASE_SESSION_DATABASE" 63 | ] 64 | return values 65 | 66 | 67 | def prepare_environment(): 68 | global _prepared 69 | if _prepared: 70 | return 71 | env_file = _get_dotenv_file() 72 | if not env_file: 73 | logging.warning(".env file not found.") 74 | else: 75 | load_dotenv(dotenv_path=_get_dotenv_file(), override=True) 76 | if ( 77 | "FIREBASE_SESSION_DATABASE" in os.environ 78 | and "FIRESTORE_SESSION_DATABASE" not in os.environ 79 | ): 80 | os.environ["FIRESTORE_SESSION_DATABASE"] = os.environ[ 81 | "FIREBASE_SESSION_DATABASE" 82 | ] 83 | for name, val in _env_requirements.items(): 84 | if name in os.environ and len(os.environ[name].strip()) > 0: 85 | continue 86 | if val is None or val.strip() == "": 87 | logging.error((f"{name} environment variable must be set" 88 | "(check .env file).")) 89 | sys.exit(1) 90 | elif val.startswith("$"): 91 | ref_name = val[1:] 92 | os.environ[name] = os.environ[ref_name] 93 | else: 94 | os.environ[name] = val 95 | 96 | from google.cloud.aiplatform import init 97 | init(location="global") 98 | 99 | _prepared = True 100 | -------------------------------------------------------------------------------- /src/agents/data_agent/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Root agent""" 15 | 16 | from pathlib import Path 17 | import sys 18 | from typing import Optional 19 | 20 | from google.genai.types import ( 21 | Content, 22 | GenerateContentConfig, 23 | SafetySetting, 24 | ThinkingConfig 25 | ) 26 | from google.adk.agents import LlmAgent 27 | from google.adk.agents.callback_context import CallbackContext 28 | from google.adk.models import LlmResponse, LlmRequest 29 | from google.adk.planners import BuiltInPlanner 30 | from google.adk.tools.agent_tool import AgentTool 31 | 32 | sys.path.append(str(Path(__file__).parent.parent.parent)) 33 | from shared.config_env import prepare_environment 34 | 35 | from prompts.root_agent import system_instruction as root_agent_instruction 36 | from tools.bi_engineer import bi_engineer_tool 37 | from tools.crm_business_analyst import crm_business_analyst_agent 38 | from tools.data_engineer import data_engineer 39 | from tools.utils import get_gemini_model 40 | 41 | 42 | ROOT_AGENT_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 43 | 44 | 45 | async def before_model_callback(callback_context: CallbackContext, 46 | llm_request: LlmRequest) -> LlmResponse | None: 47 | chart_image_name = callback_context.state.get("chart_image_name", None) 48 | if chart_image_name: 49 | callback_context.state["chart_image_name"] = "" 50 | llm_request.contents[0].parts.append( # type: ignore 51 | await callback_context.load_artifact( 52 | filename=chart_image_name)) # type: ignore 53 | return None 54 | 55 | 56 | async def before_agent_callback(callback_context: CallbackContext) -> Optional[Content]: 57 | pass 58 | 59 | 60 | async def after_model_callback(callback_context: CallbackContext, 61 | llm_response: LlmResponse) -> LlmResponse | None: 62 | pass 63 | 64 | 65 | ########################### AGENT ########################### 66 | prepare_environment() 67 | 68 | root_agent = LlmAgent( 69 | model=get_gemini_model(ROOT_AGENT_MODEL_ID), 70 | name="data_agent", 71 | output_key="output", 72 | description="CRM Data Analytics Consultant", 73 | instruction=root_agent_instruction, 74 | before_model_callback=before_model_callback, 75 | after_model_callback=after_model_callback, 76 | before_agent_callback=before_agent_callback, 77 | tools=[ 78 | AgentTool(crm_business_analyst_agent), 79 | data_engineer, 80 | bi_engineer_tool, 81 | ], 82 | planner=BuiltInPlanner( 83 | thinking_config=ThinkingConfig(thinking_budget=32768) 84 | ), 85 | generate_content_config=GenerateContentConfig( 86 | temperature = 0.0001, 87 | top_p = 0.0, 88 | seed=256, 89 | safety_settings=[ 90 | SafetySetting( 91 | category="HARM_CATEGORY_DANGEROUS_CONTENT", # type: ignore 92 | threshold="BLOCK_ONLY_HIGH", # type: ignore 93 | ), 94 | ], 95 | ) 96 | ) 97 | -------------------------------------------------------------------------------- /src/agents/data_agent/prompts/root_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Root Agent prompt template."""
15 | # flake8: noqa
16 | # pylint: disable=all
17 |
18 | system_instruction = """
19 | **// Persona & Role //**
20 |
21 | You are a highly capable Executive Assistant and Business Consultant, acting as the central coordinator for answering data-driven questions.
22 | You possess an MBA and diverse business experience (small/large companies, local/international).
23 | Context:
24 | * **Mindset:** You approach problems with rigorous **first-principles thinking**. You are data-driven and results-oriented.
25 | * **Team:** You delegate tasks effectively to your specialized team:
26 |     1. **CRM Business Analyst (BA):** Defines metrics and data requirements.
27 |     2. **Data Engineer (DE):** Extracts data from the Data Warehouse via SQL.
28 |     3. **BI Engineer (BI):** Executes SQL, returns data, and creates visualizations.
29 | * **Data Source:** Your team works with CRM data replicated to a central Data Warehouse.
30 | * **Today's Objective:** To accurately answer the user's question by orchestrating your team and leveraging the CRM data in the Warehouse.
31 |
32 | **// Core Workflow & Instructions //**
33 |
34 | Follow these steps meticulously to answer the user's query. Remember, your teammates are stateless and require all necessary context for each interaction.
35 |
36 | 1. **Understand & Consult BA:**
37 |     * Receive the user's question.
38 |     * **Action:** Explain the user's question clearly to the **CRM Business Analyst**. Also pass the exact user question.
39 |     * **Goal:** Request their expert suggestion on relevant data points, metrics, KPIs, dimensions, and potential filters needed to answer the question effectively. Ask for the *rationale* behind their suggestions.
40 |     * **Constraint:** The BA can answer the same question only once.
41 |
42 | 2. **Instruct Data Engineer:**
43 |     * **Action:** Pass the finalized detailed plan to the **Data Engineer**.
44 |     * **Rule:** "Conceptual Data Steps" part must be passed as is. Add more details and clarifications as necessary.
45 |     * **Goal:** Ask the DE to write and execute the SQL query to retrieve this data.
46 |
47 | 3. **Oversee Data Extraction:**
48 |     * Receive the SQL query and execution status/result summary from the **Data Engineer**.
49 |     * **Action:** Confirm the DE successfully executed *a* query based on your plan.
50 |
51 | 4. **Engage BI Engineer:**
52 |     * **Action:** Call the **BI Engineer**. Pass the SQL query file name received from the Data Engineer in Step 3.
53 |     * **Rule:** `notes` must be empty for the very first request.
54 |     * **Goal:** Instruct the BI Engineer to:
55 |         * Execute the provided SQL query against the Data Warehouse.
56 |         * Return the resulting data (e.g., summary table).
57 |         * Generate an appropriate chart/visualization for the data.
58 |
59 | 5. **Interpret Results:**
60 |     * Receive the data and chart from the **BI Engineer**.
61 |     * **Action:** Analyze the results. Connect the findings back to the BA's initial suggestions (Step 1) and the original user question. Identify key insights, trends, or answers revealed by the data.
62 |
63 | 6. **Formulate Final Answer:**
64 |     * **Action:** Synthesize your findings into the final response for the user.
65 |
66 | **// Context & Constraints //**
67 |
68 | * **Stateless Teammates:** Your BA, DE, and BI tools have NO memory of previous interactions.
    You MUST provide all necessary context (e.g., user question, refined plan, specific SQL query) in each call.
69 | * **Mandatory Tool Usage:** You must interact with each teammate (BA, DE, BI) at least once by following the workflow steps above. Do not ask any of the teammates the same question twice.
70 | * **Date Filters:** Avoid applying date filters unless explicitly part of the user's request and confirmed in Step 1.
71 | * **SQL Integrity:** Do not modify the DE's SQL.
72 | * **Insufficient Data Handling:** If the BA, DE, or BI Engineer indicates at any step that there isn't enough data, the required data doesn't exist, or the query fails irrecoverably, accept their assessment. Proceed directly to formulating the final answer, stating clearly that the question cannot be answered confidently due to data limitations, and explain why based on the teammate's feedback.
73 |
74 | > If you are confident that the user's follow-up question can be answered using the same data, you can skip the BA and DE steps.
75 |
76 | **// Output Format //**
77 |
78 | * **Three-Part Answer:** Provide the answer in three sections:
79 |     1. **Confidence:** How confident you are in the answer and why. Your confidence must be based on whether the answer provides data and insights to answer the detailed question as it was formulated by the BA.
80 |     2. **Business Summary & Next Steps:** Provide a concise summary of the findings in business terms. Suggest potential next steps, actions, or further questions based on the results (or lack thereof).
81 |     3. **Detailed Findings:** Explain the results, referencing the key metrics/KPIs suggested by the BA and the data/chart provided by the BI Engineer. Include your interpretation from Step 5.
82 |
83 | * **Insufficient Data Output:** If you determined the question couldn't be answered due to data limitations, state this clearly in both sections of your answer, explaining the reason provided by your teammate.
84 | """
--------------------------------------------------------------------------------
/src/web/images/logo.svg:
--------------------------------------------------------------------------------
(SVG markup stripped in extraction)
--------------------------------------------------------------------------------
/src/web/images/logo-dark.svg:
--------------------------------------------------------------------------------
(SVG markup stripped in extraction)
--------------------------------------------------------------------------------
/src/agents/data_agent/prompts/crm_business_analyst.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Business Analyst prompt template."""
15 | # flake8: noqa
16 | # pylint: disable=all
17 |
18 | system_instruction = """
19 | **Persona:**
20 | You ARE a Senior Business Analyst with deep, cross-functional experience spanning customer support, CRM consulting, and core business analysis.
    Your expertise allows you to bridge the gap between ambiguous business questions and actionable insights derived from CRM data. You think critically and focus on business value.
21 |
22 | **Core Task:**
23 | Analyze incoming business questions, regardless of their format (specific data requests or open-ended inquiries). Your goal is to translate these questions into concrete analysis plans using conceptual CRM data.
24 |
25 | **Input:**
26 | You will receive a business question.
27 |
28 | **Mandatory Process Steps:**
29 |
30 | 1. **Interpret the Question:**
31 |     * Apply first-principles thinking to understand the underlying business need.
32 |     * If the question is ambiguous, identify and list 2-3 plausible interpretations.
33 |     * Assess if historical data is necessary or the snapshot tables are sufficient.
34 |     * Choose an interpretation that makes the most sense in terms of the insights it would provide to the user based on the question as a whole as well as the choice of words.
35 |     * State the interpretation you will proceed with for the subsequent steps.
36 |
37 | 2. **Identify Relevant Metrics & Dimensions:**
38 |     * Based on your chosen interpretation, determine the most relevant KPIs, metrics, and dimensions needed to answer the question.
39 |     * Offer a primary suggestion and 1-2 alternative options where applicable.
40 |     * Clearly state *why* these are relevant to the business question.
41 |
42 | 3. **Define Calculation Approaches (Linked to CRM Data):**
43 |     * For each key metric/KPI identified:
44 |         * Propose 1-3 potential calculation methods.
45 |         * **Crucially:** Explicitly link each calculation method to the available **CRM Objects** ([Customers, Contacts, Opportunities, Leads, Tasks, Events, Support Cases, Users]). Describe *how* data from these objects would conceptually contribute to the calculation (e.g., "Count of 'Opportunities' where 'Status' is 'Closed Won' associated with 'Customers' in 'X Industry'").
46 |
47 | 4. **Outline Conceptual Data Retrieval Strategy:**
48 |     * Describe a high-level, conceptual sequence of steps to gather the necessary data *conceptually* from the CRM objects. This is about the *logic*, not the technical execution. (e.g., "1. Identify relevant 'Customers'. 2. Find associated 'Opportunities'. 3. Filter 'Opportunities' by 'Status' and 'Close Date'. 4. Aggregate 'Opportunity Amount'.").
49 |
50 | **Output Format:**
51 | Structure your answer clearly:
52 |
53 | * **1. Interpretation(s):** State your understanding(s) of the question. If multiple, specify which one you are using.
54 | * **2. Key Metrics, KPIs & Dimensions:** List the identified items with brief rationale.
55 | * **3. Calculation Options & CRM Links:** Detail the calculation methods and their connection to specific CRM objects.
56 | * **4. Conceptual Data Steps:** Provide the logical sequence for data retrieval.
57 |
58 | **Critical Constraints & Guidelines:**
59 |
60 | * **CRM Data Scope:** Your *only* available data concepts are: **[Customers, Contacts, Opportunities, Leads, Tasks, Events, Support Cases, Users]**. Treat these as conceptual business objects, *not* specific database tables or schemas.
61 | * **NO Data Engineering:** **ABSOLUTELY DO NOT** refer to databases, tables, SQL, ETL, specific data modeling techniques, or any data engineering implementation details. Keep the language focused on business logic and generic CRM concepts.
62 | * **Focus on Business Value:** Prioritize metrics and dimensions that provide actionable insights and directly address the business question. 63 | * **Data Volume:** The results will be directly presented to the user. Aim for a digestible number of result rows, unless the user's question assumes otherwise. 64 | * **Aggregation:** Always consider appropriate aggregation levels (e.g., by month, by region, by customer segment), unless the question explicitly states otherwise. 65 | * **Dimension Handling:** 66 | * Refer to dimension *types* (e.g., 'Country' associated with Customer, 'Industry' associated with Customer, 'Date Created' for Lead, 'Status' for Opportunity). 67 | * Do **NOT** filter on specific dimension *values* (e.g., "USA", "Technology", "Q1 2023") *unless* the original question explicitly requires it. 68 | * Only apply date/time filters if the question *explicitly* specifies a period (e.g., "last quarter's revenue", "support cases created this year"). Otherwise, assume analysis across all available time. 69 | * **Data Limitations:** If the question fundamentally cannot be answered using *only* the listed CRM objects, clearly state this and explain what conceptual information is missing (e.g., "Product Cost information is not available in the specified CRM objects"). 70 | 71 | **Your goal is to provide a clear, actionable business analysis plan based *only* on the conceptual CRM data available.** 72 | """ 73 | -------------------------------------------------------------------------------- /src/agents/data_agent/prompts/bi_engineer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """BI Engineer prompt template.""" 15 | # flake8: noqa 16 | # pylint: disable=all 17 | 18 | prompt = """ 19 | You are an experienced Business Intelligence engineer tasked with creating a data visualization. 20 | You have good imagination, strong UX design skills, and a decent data engineering background. 21 | 22 | **Context:** 23 | 1. **Original Business Question:** ```{original_business_question}``` 24 | 2. **Specific Question Answered by Data:** ```{question_that_sql_result_can_answer}``` 25 | 3. **SQL Query Used:** 26 | ```sql 27 | {sql_code} 28 | ``` 29 | 4. **Resulting schema (from dataframe):** 30 | ``` 31 | {columns_string} 32 | ``` 33 | 5. **Resulting Data Preview (first {dataframe_preview_len} rows):** 34 | ``` 35 | {dataframe_head} 36 | ``` 37 | 6. **Total Rows in Result:** `{dataframe_len}` 38 | 39 | **Your Task:** 40 | Generate a complete Vega-Lite chart (as JSON text) that effectively visualizes the provided data to answer the `{question_that_sql_result_can_answer}`. 41 | 42 | **Approach:** 43 | 44 | 1. Create a design of the chart, describing how every piece is connected to the data. 45 | 2. Generate correct Vega-Lite code that implements the plan. 46 | 47 | **Key Requirements & Rules:** 48 | 49 | 1.
**Chart Type Selection:** 50 | * Choose the most appropriate chart type based on the data structure (column types, number of rows: `{dataframe_len}`) and the question being answered. 51 | * If the number of rows is over 50 (it is {dataframe_len}), then consider grouping, aggregating, or clustering in some other way. You may also consider a bubble chart in such cases. 52 | * The number of data rows is `{dataframe_len}`. If it's `1`, generate a "text" mark displaying the key metric(s) with clear descriptive label(s). 53 | 54 | 2. **Data Encoding:** 55 | * **Use Provided Data:** Map columns directly from the provided data (`Resulting Data Preview` above shows available column names). 56 | * **Prioritize Readability:** Use descriptive entity names (e.g., `CustomerName`) for axes, legends, and tooltips instead of identifiers (e.g., `CustomerID`) whenever available. Look for columns ending in `Name`, `Label`, `Category`, etc. 57 | * **Correct Data Types:** Accurately map data columns to Vega-Lite types (`quantitative`, `temporal`, `nominal`, `ordinal`). 58 | * **Parameters:** Add selection parameters for geographical or "categorical" dimensions when such dimensions are present in the data: 59 | * Name these parameters as f"{{dimension_column_name}}__selection" (the `dimension_column_name` part is case-sensitive). 60 | * Add a filter transform as {{"filter": f"datum.{{dimension_column_name}} === {{dimension_column_name}}__selection || {{dimension_column_name}}__selection == null"}} 61 | * Only one filter node per dimension is allowed. Only "input": "select" is allowed. 62 | * Do not allow any other transforms on such dimensions. 63 | * Remember that the chosen dimension may have more values than you see. 64 | * Prioritize geographical dimensions. 65 | * Avoid directly referring to `bind.options` as a whole. 66 | * **Axes & Legends:** 67 | * Use clear, concise axis titles and legend titles. 68 | * Ensure legends accurately represent the encoding (e.g., color, shape) and include units (e.g., "$", "%", "Count") if applicable and known. 69 | * Format axes appropriately (e.g., date formats, currency formats). 70 | 71 | 3. **Data Transformation & Refinement:** 72 | * **Data representation:** Do not use the "dataset" element, only "data". 73 | * **Sorting:** Apply meaningful sorting (e.g., bars by value descending/ascending, time series chronologically) to enhance interpretation. 74 | * **Filtering:** Consider adding a `transform` filter *only if* it clarifies the visualization by removing irrelevant data (e.g., nulls, zeros if not meaningful) *without* compromising the answer to the question. 75 | * **Many rows:** If dealing with high-cardinality dimensions, consider using a chart type and grouping that make the chart easy to understand. 76 | 77 | 4. **Chart Aesthetics & Formatting:** 78 | * **Title:** Provide a clear, descriptive title for the chart that summarizes its main insight or content relevant to the question. 79 | * **Readability:** Ensure all labels (axes, data points, legends) are easily readable and do not overlap. Rotate or reposition labels if necessary. Use tooltips to show details on hover. 80 | * **Dashboard-Ready:** Design the chart to be clear and effective when viewed as part of a larger dashboard. Aim for simplicity and avoid clutter. 81 | * **Sizing and Scaling:** 82 | - Define a reasonable `width` and `height` suitable for a typical dashboard component. *Minimum* width is 1152. *Minimum* height is 648.
83 | - The chart must be comfortable to view on a 16-inch 4K screen, with the ability to zoom in and out. 84 | - Consider using `autosize: fit` properties if appropriate. 85 | - Avoid making the chart excessively large or small. 86 | - If using `vconcat` or `hconcat`, adjust width and height accordingly to accommodate all series. 87 | 88 | 5. **Strict Technical Constraints:** 89 | * **Vega-Lite Version:** MUST use the Vega-Lite {vega_lite_schema_version} schema. 90 | * **Valid Syntax:** Ensure the generated JSON is syntactically correct and adheres strictly to the Vega-Lite specification. DO NOT use properties or features from other versions or invent new ones. 91 | * **Output:** `vega_lite_json` is JSON code based on the Vega-Lite schema below. 92 | {notes_text} 93 | 94 | 6. **Vega-Lite Schema and Library:** 95 | * Vega-Altair will be used for visualization. 96 | * You MUST strictly follow the Vega-Lite {vega_lite_schema_version} schema below. 97 | 98 | Vega-Lite {vega_lite_schema_version} schema: 99 | 100 | ```json 101 | {vega_lite_spec} 102 | ``` 103 | """ 104 | -------------------------------------------------------------------------------- /src/web/agent_runtime_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Agent Runtime Client""" 15 | 16 | from abc import ABC, abstractmethod 17 | import json 18 | import logging 19 | from typing import AsyncGenerator, Union, Optional 20 | from typing_extensions import override 21 | 22 | import requests 23 | 24 | from google.adk.events import Event 25 | from google.adk.sessions import Session 26 | from google.genai.types import Content, Part 27 | 28 | from pydantic import ValidationError 29 | 30 | 31 | MAX_RUN_RETRIES = 10 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | class AgentRuntime(ABC): 36 | def __init__(self, session: Session): 37 | self.session = session 38 | 39 | @abstractmethod 40 | async def stream_query(self, message: str) -> AsyncGenerator[Event, None]: 41 | pass 42 | 43 | @abstractmethod 44 | def is_streaming(self) -> bool: 45 | pass 46 | 47 | 48 | async def sse_client(url, request, headers): 49 | """ 50 | A very minimal SSE client using only the requests library. 51 | Yields the data content from SSE messages. 52 | Handles multi-line 'data:' fields for a single event. 53 | """ 54 | if not headers: 55 | headers = {} 56 | headers["Accept"] = "text/event-stream" 57 | headers["Cache-Control"] = "no-cache" 58 | try: 59 | # stream=True is essential for SSE 60 | # timeout=None can be used for very long-lived connections, 61 | # but be aware of potential indefinite blocking if the server misbehaves. 62 | # A specific timeout (e.g., (3.05, 60)) for connect and read can be safer.
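        # Per the requests convention the tuple below is (connect timeout, read timeout):
        # 60 seconds to establish the connection, and up to 7 days between received
        # bytes, since agent runs can stream events for a very long time.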
63 | with requests.post(url, json=request, stream=True, headers=headers, timeout=(60, 60*60*24*7)) as response: 64 | response.raise_for_status() # Raise an exception for HTTP error codes (4xx or 5xx) 65 | logger.info(f"Connected to SSE stream at {url}") 66 | 67 | current_event_data_lines = [] 68 | for line_bytes in response.iter_lines(): # iter_lines gives bytes 69 | if not line_bytes: # An empty line signifies the end of an event 70 | if current_event_data_lines: 71 | # Join accumulated data lines for the event 72 | full_data_string = "\n".join(current_event_data_lines) 73 | yield full_data_string 74 | current_event_data_lines = [] # Reset for the next event 75 | continue # Skip further processing for this empty line 76 | 77 | # Decode bytes to string (SSE is typically UTF-8) 78 | line = line_bytes.decode('utf-8') 79 | 80 | if line.startswith(':'): # Comment line, ignore 81 | continue 82 | 83 | # We are only interested in 'data:' lines for this minimal client 84 | if line.startswith('data:'): 85 | # Strip "data:" prefix and any leading/trailing whitespace from the value part 86 | data_value = line[len('data:'):].lstrip() 87 | current_event_data_lines.append(data_value) 88 | 89 | # Other SSE fields like 'event:', 'id:', 'retry:' are ignored here. 90 | 91 | # If the stream ends and there's pending data (no final empty line) 92 | if current_event_data_lines: 93 | full_data_string = "\n".join(current_event_data_lines) 94 | yield full_data_string 95 | 96 | except requests.exceptions.RequestException as e: 97 | logger.error(f"Error connecting or streaming SSE: {e}") 98 | except KeyboardInterrupt: 99 | logging.warning("SSE stream manually interrupted.") 100 | finally: 101 | logging.info("SSE client finished.") 102 | 103 | class FastAPIEngineRuntime(AgentRuntime): 104 | def __init__(self, 105 | session: Session, 106 | server_url: Optional[str] = None ): 107 | super().__init__(session) 108 | if not server_url: 109 | server_url = "http://127.0.0.1:8000" 110 | self.server_url = server_url 111 | self.streaming = False 112 | self.connection = None 113 | 114 | 115 | @override 116 | async def stream_query( 117 | self, 118 | message: Union[str, Content] 119 | ) -> AsyncGenerator[Event, None]: 120 | self.streaming = True 121 | try: 122 | if not message: 123 | content = None 124 | if message and isinstance(message, str): 125 | content = Content( 126 | parts=[ 127 | Part.from_text(text=message) 128 | ], 129 | role="user" 130 | ) 131 | else: 132 | content = message 133 | if content: 134 | content_dict = content.model_dump() 135 | else: 136 | content_dict = None 137 | request = { 138 | "app_name": self.session.app_name, 139 | "user_id": self.session.user_id, 140 | "session_id": self.session.id, 141 | "new_message": content_dict, 142 | "streaming": False 143 | } 144 | 145 | async for event_str in sse_client(f"{self.server_url}/run_sse", 146 | request=request, 147 | headers=None): 148 | try: 149 | yield Event.model_validate_json(event_str) 150 | except ValidationError as e: 151 | try: 152 | # trying to parse as if it was a json with "error" field. 
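                        # (Some failures are reported by the server as a plain
                        # {"error": ...} JSON object rather than a serialized Event.)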
153 | err_json = json.loads(event_str) 154 | if "error" in err_json: 155 | print(f"#### RUNTIME ERROR: {err_json['error']}") 156 | continue 157 | except json.JSONDecodeError: 158 | print(f"VALIDATION ERROR: {e}") 159 | print("### DATA ###:\n" + event_str) 160 | print("\n\n######################################\n\n") 161 | pass 162 | finally: 163 | self.streaming = False 164 | 165 | @override 166 | def is_streaming(self) -> bool: 167 | return self.streaming -------------------------------------------------------------------------------- /src/agents/data_agent/prompts/data_engineer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data Engineer prompt template.""" 15 | # flake8: noqa 16 | # pylint: disable=all 17 | 18 | system_instruction=""" 19 | **Persona:** Act as an expert Senior Data Engineer. 20 | 21 | **Core Expertise & Environment:** 22 | * **Domain:** Deep expertise in CRM data, specifically Salesforce objects, relationships, and common business processes (Sales, Service, Marketing). 23 | * **Technology Stack:** Google Cloud Platform (GCP), primarily Google BigQuery. 24 | * **Data Source:** Assume access to a BigQuery Data Warehouse containing replicated data from Salesforce CRM (e.g., tables mirroring standard objects like Account, Contact, Opportunity, Lead, Case, Task, Event, User, etc., and potentially custom objects). 25 | * **Language/Dialect:** Proficient in writing high-quality, performant SQL specifically for the Google BigQuery dialect (Standard SQL). 26 | 27 | **Key Responsibilities & Workflow:** 28 | 1. **Analyze Request:** Carefully interpret user questions, analytical tasks, or data manipulation steps. Understand the underlying business goal. 29 | 2. **Assess Feasibility:** 30 | * Critically evaluate if the request can likely be fulfilled using standard Salesforce data structures typically found in a data warehouse. 31 | * Identify potential data gaps or ambiguities based on the request and common Salesforce schemas. 32 | * **Crucially:** If feasibility is uncertain or requires specific assumptions (e.g., availability of a specific field, a particular data relationship), explicitly state these assumptions or ask clarifying questions *before* generating SQL. 33 | 3. **Create a Plan:** 34 | * Assess your choice of data: tables, dimensions, metrics. Also understand what they really mean for the business. 35 | * Plan how you will use them in your implementation. 36 | * Remember that you also have some historical data for some objects. 37 | 4. **Generate SQL:** 38 | * Produce clean, well-formatted, and efficient BigQuery SQL code. 39 | * Prioritize readability (using CTEs, meaningful aliases, comments for complex logic). 40 | * Optimize for performance within BigQuery (e.g., consider join strategies, filtering early), assuming standard table structures unless otherwise specified.
41 | * Handle potential data nuances where appropriate (e.g., NULL values, data types). 42 | * Implement necessary aggregations, partitioning, and filtering. 43 | * Refer to fields with table aliases. 44 | 5. **Explain & Justify:** Briefly explain the logic of the generated SQL, especially for complex queries. Justify design choices or assumptions made. If a request is deemed infeasible, clearly explain why based on typical data limitations. 45 | 46 | **Output Expectations:** 47 | * Primary output should be accurate and runnable BigQuery SQL code. 48 | * Include necessary explanations, assumptions, or feasibility assessments alongside the code. 49 | * Maintain a professional, precise, and helpful tone. 50 | 51 | **Constraints:** 52 | * You do not have access to live data or specific table schemas beyond general knowledge of Salesforce and BigQuery best practices. Base feasibility on common patterns. 53 | * Focus on generating SQL and explanations, not on executing queries or performing data analysis yourself. 54 | """ 55 | 56 | ################################################################################ 57 | 58 | prompt = """ 59 | **User Request:** 60 | 61 | ``` 62 | {request} 63 | ``` 64 | 65 | **Task:** 66 | 67 | Analyze the request and generate the BigQuery SQL query required to fulfill it. 68 | Adhere strictly to the context, rules, and schema provided below. 69 | If the request seems infeasible with the given schema or requires significant assumptions, 70 | state them clearly before providing the SQL. 71 | 72 | **Context & Rules:** 73 | 74 | 0. **Style:** 75 | * Do not over-complicate SQL. Make it easy to read. 76 | * When using complex expressions, pay attention to how they actually work. 77 | 78 | 1. **Target Environment:** 79 | * BigQuery Project ID: `{data_project_id}` 80 | * BigQuery Dataset: `{dataset}` 81 | * **Constraint:** You MUST fully qualify all table names (e.g., `{data_project_id}.{dataset}.YourTable`). 82 | 83 | 2. **Currency Conversion (Mandatory if handling multi-currency monetary values):** 84 | * **Objective:** Convert amounts to US Dollars (USD). 85 | * **Table:** Use `{data_project_id}.{dataset}.DatedConversionRate`. 86 | * **Logic:** 87 | * Join using the currency identifier (`IsoCode` column in `DatedConversionRate`). 88 | * Filter rates based on the relevant date from your primary data, ensuring it falls between `StartDate` (inclusive) and `NextStartDate` (exclusive) in `DatedConversionRate`. 89 | * Calculate the USD amount: `OriginalAmount / ConversionRate`. (Note: `ConversionRate` is defined as `USD / IsoCode`). 90 | 91 | 3. **Geographical Dimension Handling (Apply ONLY if filtering or grouping on these dimensions):** 92 | * **Principle:** Account for common variations in geographical names. 93 | * **Countries:** Use multiple forms including ISO codes (e.g., `Country IN ('US', 'USA', 'United States')`). 94 | * **States/Provinces:** Use multiple forms including abbreviations (e.g., `State IN ('FL', 'Florida')`, `State IN ('TX', 'Texas')`). 95 | * **Multiple Values:** Combine all forms when checking multiple locations (e.g., `State IN ('TX', 'Texas', 'FL', 'Florida')`). 96 | 97 | 4. **Filtering on dimensions:** 98 | * **Value semantics:** Whenever filtering on a column with a `possible_values` property, make sure you map your filter values to one or more values from `possible_values`.
99 | * **Arbitrary text values:** Avoid filtering on text columns using arbitrary values (not one of `possible_values`, or when `possible_values` is missing) unless such a value is given by the user. 100 | 101 | 5. **Data Schema:** 102 | * The authoritative source for available tables and columns is the JSON structure below. 103 | * **Constraint:** ONLY use tables and columns defined within this schema. 104 | 105 | 6. **Multi-statement SQL queries:** 106 | * If a multi-statement query is necessary, you must construct the query so that all intended results are returned by the last statement. 107 | 108 | **Output:** 109 | Provide the complete and runnable BigQuery SQL query. Include brief explanations for complex logic or any assumptions made. 110 | 111 | **Schema Definition:** 112 | 113 | Each item in the dictionary below represents a table in BigQuery. 114 | Keys are table names. 115 | In values, `salesforce_name` is the Salesforce.com object name of the respective Salesforce object. 116 | `salesforce_label` is the UI label in Salesforce.com. 117 | `columns` contains detailed column definitions. 118 | 119 | ```json 120 | {sfdc_metadata} 121 | ``` 122 | """ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRM Data Q&A Agent - Advanced RAG with NL2SQL over Salesforce Data 2 | 3 | | | | 4 | | ------------ | ------------- | 5 | | | This is a 📊 Data Analytics Agent that grounds its conversation in Salesforce data replicated to a Data Warehouse in BigQuery. | 6 | 7 | The agent demonstrates an advanced [Retrieval-Augmented Generation](https://cloud.google.com/use-cases/retrieval-augmented-generation) workflow 8 | in a multi-agentic system with contextualized Natural-Language-to-SQL 9 | components powered by the Long Context and In-Context Learning capabilities of [Gemini 2.5 Pro](https://deepmind.google/technologies/gemini). 10 | 11 | 🚀 **Blog post**: [Forget vibe coding, vibe Business Intelligence is here!](https://medium.com/@vladkol_eqwu/business-intelligence-in-ai-era-how-agents-and-gemini-unlock-your-data-ce158081c678) 12 | 13 | The agent is built with [Agent Development Kit](https://google.github.io/adk-docs/). 14 | 15 | * The agent interprets questions about the state of the business as it is reflected in the CRM, rather than directly referring to Salesforce data entities. 16 | * It generates a SQL query to gather the data necessary to answer the question. 17 | * It creates interactive [Vega-Lite](https://vega.github.io/vega-lite/) diagrams. 18 | * It analyzes the results, provides key insights, and recommends actions. 19 | 20 | 21 | What are our best lead sources in every country? 22 | 23 | 24 | ## Agent Development Kit 25 | 26 | 27 | 28 | The agent is built using [Agent Development Kit](https://google.github.io/adk-docs/) (ADK) - a flexible 29 | and modular framework for developing and deploying AI agents. 30 | 31 | The sample also demonstrates: 32 | 33 | * How to build a Web UI for ADK-based data agents using [streamlit](https://streamlit.io/). 34 | * How to use [Artifact Services](https://google.github.io/adk-docs/artifacts/) with ADK. 35 | * How to stream and interpret session [events](https://google.github.io/adk-docs/events/) (see the sketch after this list). 36 | * How to create and use a custom [Session Service](https://google.github.io/adk-docs/sessions/session/).
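For orientation, here is a minimal sketch of how an ADK event stream can be consumed to pick out the final text answer. This is not code from this repository: the `final_answer` helper and its wiring are illustrative, and it assumes a `Runner` has already been constructed with a session service and agent.

```python
# Hypothetical sketch: consume an ADK event stream and keep the final answer.
from google.adk.runners import Runner  # runner construction omitted here

async def final_answer(runner: Runner, user_id: str,
                       session_id: str, message) -> str:
    answer = ""
    async for event in runner.run_async(user_id=user_id,
                                        session_id=session_id,
                                        new_message=message):
        # Intermediate events carry tool calls and partial model output;
        # the final response event carries the text to show the user.
        if event.is_final_response() and event.content and event.content.parts:
            answer = "".join(part.text or "" for part in event.content.parts)
    return answer
```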
37 | 38 | ## 🕵🏻‍♀️ Simple questions are complex 39 | 40 | Top 5 customers by impact in the US this year 41 | 42 | ### Examples of questions the agent can answer 43 | 44 | * "Top 5 customers in every country" 45 | * "What are our best lead sources?" 46 | * or, more specifically, "What are our best lead sources by value?" 47 | * Lead conversion trends in the US. 48 | 49 | ### High-Level Design 50 | 51 | Top 5 customers in every country 52 | 53 | ## 🚀 Deploy and Run 54 | 55 | To deploy the sample with demo data to a publicly available Cloud Run service, 56 | use the `Run on Google Cloud` button below. 57 | 58 | [![Run on Google Cloud](https://deploy.cloud.run/button.svg)](https://console.cloud.google.com/cloudshell/?cloudshell_git_repo=https://github.com/vladkol/crm-data-agent&cloudshell_image=gcr.io/cloudrun/button&show=terminal&utm_campaign=CDR_0xc245fc42_default_b417442301&utm_medium=external&utm_source=blog) 59 | 60 | You need a Google Cloud Project with a [Billing Account](https://console.cloud.google.com/billing?utm_campaign=CDR_0xc245fc42_default_b417442301&utm_medium=external&utm_source=blog). 61 | 62 | ### Manual deployment 63 | 64 | * Clone this repository: 65 | 66 | ```bash 67 | git clone https://github.com/vladkol/crm-data-agent && cd crm-data-agent 68 | ``` 69 | 70 | * Create a Python virtual environment 71 | 72 | > [uv](https://docs.astral.sh/uv/) makes it easy: `uv venv .venv --python 3.11 && source .venv/bin/activate` 73 | 74 | * Install dependencies 75 | 76 | `pip install -r src/requirements.txt` 77 | 78 | or, with `uv`: 79 | 80 | `uv pip install -r src/requirements.txt` 81 | 82 | * Create a `.env` file in the `src` directory. Set configuration values as described below. 83 | 84 | > [src/.env-template](src/.env-template) is a template to use for your `.env` file. 85 | 86 | ### Configuration variables 87 | 88 | > `src/.env` must be created and variables specified before taking further steps in deployment, local or cloud. 89 | 90 | `GOOGLE_CLOUD_PROJECT` - [REQUIRED] Project Id of a Google Cloud Project that will be used with Vertex AI (and Cloud Run if deployed). 91 | 92 | `GOOGLE_CLOUD_LOCATION` - [REQUIRED] Google Cloud Region to use with Vertex AI (and Cloud Run if deployed). 93 | 94 | `AI_STORAGE_BUCKET` - [REQUIRED] Cloud Storage Bucket for the ADK Artifact Service and for staging Vertex AI assets. 95 | Please create one. 96 | 97 | `BQ_LOCATION` - [REQUIRED] BigQuery location of the Salesforce datasets. 98 | 99 | `SFDC_BQ_DATASET` - [REQUIRED] Name of the Salesforce dataset (in project *SFDC_DATA_PROJECT_ID*). 100 | 101 | `FIRESTORE_SESSION_DATABASE` - [REQUIRED] Name of a Firestore database. Please create one. ADK will store its session data here. 102 | 103 | `BQ_PROJECT_ID` - *[OPTIONAL]* Project Id of a Google Cloud Project that will be used for running BigQuery query jobs. If not defined, the `GOOGLE_CLOUD_PROJECT` value will be used. 104 | 105 | `SFDC_DATA_PROJECT_ID` - *[OPTIONAL]* Project Id of a Google Cloud Project of the Salesforce dataset. 106 | If not defined, the `BQ_PROJECT_ID` value will be used. 107 | 108 | `SFDC_METADATA_FILE` - *[OPTIONAL]* Salesforce Metadata file (do not change this value if using the demo data). 109 | 110 | > If you are deploying a demo, do not set `BQ_PROJECT_ID` and `SFDC_DATA_PROJECT_ID`. 111 | > All resources will be created in the GOOGLE_CLOUD_PROJECT project.
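For illustration, a demo-oriented `src/.env` could look like this. All values are placeholders -- use your own project, bucket, and database names:

```bash
GOOGLE_CLOUD_PROJECT=my-gcp-project
GOOGLE_CLOUD_LOCATION=us-central1
AI_STORAGE_BUCKET=my-adk-artifacts-bucket
BQ_LOCATION=US
SFDC_BQ_DATASET=sfdc_demo
FIRESTORE_SESSION_DATABASE=crm-agent-sessions
# Optional overrides -- leave unset for a demo deployment:
# BQ_PROJECT_ID=my-query-project
# SFDC_DATA_PROJECT_ID=my-data-project
```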
112 | 113 | **If you deploy the agent to Cloud Run**, its service account must have the following roles: 114 | 115 | * BigQuery Job User (`roles/bigquery.jobUser`) in the BQ_PROJECT_ID project (or GOOGLE_CLOUD_PROJECT, if BQ_PROJECT_ID is not defined). 116 | * BigQuery Data Viewer (`roles/bigquery.dataViewer`) for the SFDC_BQ_DATASET dataset. 117 | * Storage Object User (`roles/storage.objectUser`) for the AI_STORAGE_BUCKET bucket. 118 | * Vertex AI User (`roles/aiplatform.user`) in the GOOGLE_CLOUD_PROJECT project. 119 | 120 | ### Enable APIs in your project 121 | 122 | ```bash 123 | gcloud services enable \ 124 | aiplatform.googleapis.com \ 125 | cloudbuild.googleapis.com \ 126 | run.googleapis.com \ 127 | firestore.googleapis.com \ 128 | bigquery.googleapis.com \ 129 | --project=[GOOGLE_CLOUD_PROJECT] 130 | 131 | ``` 132 | 133 | > Replace `[GOOGLE_CLOUD_PROJECT]` with the GOOGLE_CLOUD_PROJECT value you put in the `src/.env` file. 134 | 135 | ### Deploy Salesforce Data 136 | 137 | #### Demo data 138 | 139 | Run the `utils/deploy_demo_data.py` script. 140 | 141 | > **Note**: Demo data contains records dated 2020-2022. If you ask questions with "last year" or "6 months ago", they will likely return no data. 142 | 143 | #### Real Salesforce Data 144 | 145 | Create a [BigQuery Data Transfer for Salesforce](https://cloud.google.com/bigquery/docs/salesforce-transfer). 146 | 147 | Make sure you transfer the following objects: 148 | 149 | * Account 150 | * Case 151 | * CaseHistory 152 | * Contact 153 | * CurrencyType 154 | * DatedConversionRate 155 | * Event 156 | * Lead 157 | * Opportunity 158 | * OpportunityHistory 159 | * RecordType 160 | * Task 161 | * User 162 | 163 | #### Deployment with your custom Salesforce.com metadata 164 | 165 | *COMING SOON!* 166 | 167 | This will allow you to use your customized metadata in addition to analyzing your real data replicated to BigQuery. 168 | 169 | ### Run Locally 170 | 171 | * Run `./run_local.sh` 172 | * Open `http://localhost:8080` in your browser. 173 | 174 | #### Deploy and Run in Cloud Run 175 | 176 | * Run `./deploy_to_cloud_run.sh` 177 | 178 | > This deployment uses the default Compute Engine service account for Cloud Run. 179 | To make changes in how the deployment is done, adjust the `gcloud` command in [deploy_to_cloud_run.py](utils/deploy_to_cloud_run.py). 180 | 181 | **Cloud Run Authentication Note**: 182 | 183 | By default, this script deploys a Cloud Run service that requires authentication. 184 | You can switch to unauthenticated mode in [Cloud Run](https://console.cloud.google.com/run) or configure a [Load Balancer and Identity-Aware Proxy](https://cloud.google.com/iap/docs/enabling-cloud-run) (recommended). 185 | 186 | ## 📃 License 187 | 188 | This repository is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. 189 | 190 | ## 🗒️ Disclaimers 191 | 192 | This is not an officially supported Google product. This project is not eligible for the [Google Open Source Software Vulnerability Rewards Program](https://bughunters.google.com/open-source-security). 193 | 194 | Code and data from this repository are intended for demonstration purposes only. They are not intended for use in a production environment.
195 | -------------------------------------------------------------------------------- /src/agents/data_agent/tools/data_engineer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data Engineer Agent""" 15 | 16 | from functools import cache 17 | import json 18 | import os 19 | from pathlib import Path 20 | import uuid 21 | from typing import Tuple 22 | 23 | from pydantic import BaseModel 24 | 25 | from google.cloud.exceptions import BadRequest, NotFound 26 | from google.cloud.bigquery import Client, QueryJobConfig 27 | from google.genai.types import (Content, 28 | GenerateContentConfig, 29 | Part, 30 | SafetySetting) 31 | from google.adk.tools import ToolContext 32 | 33 | from .utils import get_genai_client 34 | from prompts.data_engineer import (system_instruction 35 | as data_engineer_instruction, 36 | prompt as data_engineer_prompt) 37 | from prompts.sql_correction import (instruction as sql_correction_instruction, 38 | prompt as sql_correction_prompt) 39 | 40 | DATA_ENGINEER_AGENT_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 41 | SQL_VALIDATOR_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 42 | _DEFAULT_METADATA_FILE = "sfdc_metadata.json" 43 | 44 | @cache 45 | def _init_environment(): 46 | global _bq_project_id, _data_project_id, _location, _dataset 47 | global _sfdc_metadata, _sfdc_metadata_dict 48 | 49 | _bq_project_id = os.environ["BQ_PROJECT_ID"] 50 | _data_project_id = os.environ["SFDC_DATA_PROJECT_ID"] 51 | _location = os.environ["BQ_LOCATION"] 52 | _dataset = os.environ["SFDC_BQ_DATASET"] 53 | _sfdc_metadata_path = os.environ.get("SFDC_METADATA_FILE", 54 | _DEFAULT_METADATA_FILE) 55 | if not Path(_sfdc_metadata_path).exists(): 56 | if "/" not in _sfdc_metadata_path: 57 | _sfdc_metadata_path = str(Path(__file__).parent.parent / 58 | _sfdc_metadata_path) 59 | 60 | _sfdc_metadata = Path(_sfdc_metadata_path).read_text(encoding="utf-8") 61 | _sfdc_metadata_dict = json.loads(_sfdc_metadata) 62 | 63 | # Only keep metadata for tables that exist in the dataset. 64 | _final_dict = {} 65 | client = Client(_bq_project_id, location=_location) 66 | for table in client.list_tables(f"{_data_project_id}.{_dataset}"): 67 | if table.table_id in _sfdc_metadata_dict: 68 | table_dict = _sfdc_metadata_dict[table.table_id] 69 | _final_dict[table.table_id] = table_dict 70 | table_obj = client.get_table(f"{_data_project_id}.{_dataset}." 71 | f"{table.table_id}") 72 | for f in table_obj.schema: 73 | if f.name in table_dict["columns"]: 74 | table_dict["columns"][f.name]["field_type"] = f.field_type 75 | 76 | _sfdc_metadata = json.dumps(_final_dict, indent=2) 77 | _sfdc_metadata_dict = _final_dict 78 | 79 | 80 | def _sql_validator(sql_code: str) -> Tuple[str, str]: 81 | """SQL Validator. Validates a BigQuery SQL query using the BigQuery client. 82 | May also change the query to correct known errors in-place.
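    ("Known errors" here means fully-qualified Salesforce object names: they
    are rewritten to the corresponding BigQuery table ids from the metadata
    dictionary before validation.)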
83 | 84 | Args: 85 | sql_code (str): BigQuery SQL code to validate. 86 | 87 | Returns: 88 | tuple(str,str): 89 | str: "SUCCESS" if SQL is valid, error text otherwise. 90 | str: modified SQL code (always update your original query with it). 91 | """ 92 | print("Running SQL validator.") 93 | sql_code_to_run = sql_code 94 | for k,v in _sfdc_metadata_dict.items(): 95 | sfdc_name = v["salesforce_name"] 96 | full_name = f"`{_data_project_id}.{_dataset}.{sfdc_name}`" 97 | sql_code_to_run = sql_code_to_run.replace( 98 | full_name, 99 | f"`{_data_project_id}.{_dataset}.{k}`" 100 | ) 101 | 102 | client = Client(project=_bq_project_id, location=_location) 103 | try: 104 | dataset_location = client.get_dataset( 105 | f"{_data_project_id}.{_dataset}").location 106 | job_config = QueryJobConfig(dry_run=True, use_query_cache=False) 107 | client.query(sql_code_to_run, # dry-run the corrected query 108 | job_config=job_config, 109 | location=dataset_location).result() 110 | except (BadRequest, NotFound) as ex: 111 | err_text = ex.args[0].strip() 112 | return f"ERROR: {err_text}", sql_code_to_run 113 | return "SUCCESS", sql_code_to_run 114 | 115 | 116 | class SQLResult(BaseModel): 117 | sql_code: str 118 | sql_code_file_name: str 119 | error: str = "" 120 | 121 | 122 | ######## AGENT ######## 123 | async def data_engineer(request: str, tool_context: ToolContext) -> SQLResult: 124 | """ 125 | This is your Senior Data Engineer. 126 | They have extensive experience in working with CRM data. 127 | They write clean and efficient SQL in its BigQuery dialect. 128 | When given a question or a set of steps, 129 | they can understand whether the problem can be solved with the data you have. 130 | The result is a BigQuery SQL Query. 131 | """ 132 | _init_environment() 133 | prompt = data_engineer_prompt.format( 134 | request=request, 135 | data_project_id=_data_project_id, 136 | dataset=_dataset, 137 | sfdc_metadata=_sfdc_metadata 138 | ) 139 | 140 | sql_code_result = get_genai_client().models.generate_content( 141 | model=DATA_ENGINEER_AGENT_MODEL_ID, 142 | contents=Content( 143 | role="user", 144 | parts=[ 145 | Part.from_text(text=prompt) 146 | ] 147 | ), 148 | config=GenerateContentConfig( 149 | response_schema=SQLResult, 150 | response_mime_type="application/json", 151 | system_instruction=data_engineer_instruction, 152 | temperature=0.0, 153 | top_p=0.0, 154 | seed=1, 155 | safety_settings=[ 156 | SafetySetting( 157 | category="HARM_CATEGORY_DANGEROUS_CONTENT", # type: ignore 158 | threshold="BLOCK_ONLY_HIGH", # type: ignore 159 | ), 160 | ] 161 | ) 162 | ) 163 | sql_result: SQLResult = sql_code_result.parsed # type: ignore 164 | sql = sql_result.sql_code 165 | 166 | print(f"SQL Query candidate: {sql}") 167 | 168 | MAX_FIX_ATTEMPTS = 32 169 | validating_query = sql 170 | is_good = False 171 | chat_session = None # create the correction chat once, reuse it across attempts 172 | 173 | for __ in range(MAX_FIX_ATTEMPTS): 174 | validator_result, validating_query = _sql_validator(validating_query) 175 | print(f"SQL Query candidate: {validating_query}") 176 | if validator_result == "SUCCESS": 177 | is_good = True 178 | break 179 | print(f"ERROR: {validator_result}") 180 | if not chat_session: 181 | chat_session = get_genai_client().chats.create( 182 | model=SQL_VALIDATOR_MODEL_ID, 183 | config=GenerateContentConfig( 184 | response_schema=SQLResult, 185 | response_mime_type="application/json", 186 | system_instruction=sql_correction_instruction.format( 187 | data_project_id=_data_project_id, 188 | dataset=_dataset, 189 | sfdc_metadata=_sfdc_metadata 190 | ), 191 | temperature=0.0, 192 | top_p=0.000001, 193 |
seed=0, 194 | safety_settings=[ 195 | SafetySetting( 196 | category="HARM_CATEGORY_DANGEROUS_CONTENT", # type: ignore 197 | threshold="BLOCK_ONLY_HIGH", # type: ignore 198 | ), 199 | ] 200 | ) 201 | ) 202 | correcting_prompt = sql_correction_prompt.format( 203 | validating_query=validating_query, 204 | validator_result=validator_result 205 | ) 206 | corr_result = chat_session.send_message(correcting_prompt).parsed 207 | validating_query = corr_result.sql_code # type: ignore 208 | if is_good: 209 | print(f"Final result: {validating_query}") 210 | # sql_markdown = f"```sql\n{validating_query}\n```" 211 | sql_file_prefix = f"query_{uuid.uuid4().hex}" 212 | # await tool_context.save_artifact( 213 | # f"{sql_file_prefix}.md", 214 | # Part.from_bytes( 215 | # mime_type="text/markdown", 216 | # data=sql_markdown.encode("utf-8") 217 | # ) 218 | # ) 219 | await tool_context.save_artifact( 220 | f"{sql_file_prefix}.sql", 221 | Part.from_bytes( 222 | mime_type="text/x-sql", 223 | data=validating_query.encode("utf-8") 224 | ) 225 | ) 226 | return SQLResult( 227 | sql_code=validating_query, 228 | sql_code_file_name=f"{sql_file_prefix}.sql", 229 | ) 230 | else: 231 | return SQLResult( 232 | sql_code="-- no query", 233 | sql_code_file_name="none.sql", 234 | error=f"## Could not create a valid query in {MAX_FIX_ATTEMPTS}" 235 | " attempts.") 236 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/shared/firestore_session_service.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Firestore Session Service implementation""" 15 | 16 | import logging 17 | from typing import Any 18 | from typing import Optional 19 | import uuid 20 | 21 | from google.adk.events.event import Event 22 | from google.adk.sessions.base_session_service import (BaseSessionService, 23 | GetSessionConfig, 24 | ListSessionsResponse, 25 | Session, 26 | State) 27 | from google.api_core import exceptions 28 | from google.cloud.firestore import (Client, 29 | CollectionReference, 30 | DocumentReference, 31 | Query, 32 | SERVER_TIMESTAMP) 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | class FirestoreSessionService(BaseSessionService): 38 | def __init__(self, 39 | database: str, 40 | sessions_collection: str = "/", 41 | project_id: Optional[str] = None): 42 | 43 | self.client = Client(project_id, database=database) 44 | self.sessions_collection = sessions_collection 45 | 46 | @staticmethod 47 | def _clean_app_name(name: str) -> str: 48 | return name.rsplit("/", 1)[-1] 49 | 50 | 51 | def _get_session_path(self, 52 | *, 53 | app_name: str, 54 | user_id: str, 55 | session_id: str) -> str: 56 | return (f"{self.sessions_collection}" 57 | f"/agents/{FirestoreSessionService._clean_app_name(app_name)}" 58 | f"/users/{user_id}" 59 | f"/sessions/{session_id}").strip("/") 60 | 61 | def _get_session_doc(self, 62 | *, 63 | app_name: str, 64 | user_id: str, 65 | session_id: str) -> DocumentReference: 66 | sessions_collection = self._get_sessions_collection( 67 | app_name=FirestoreSessionService._clean_app_name(app_name), 68 | user_id=user_id 69 | ) 70 | return sessions_collection.document(session_id) 71 | 72 | def _get_events_collection(self, 73 | *, 74 | app_name: str, 75 | user_id: str, 76 | session_id: str) -> CollectionReference: 77 | return self._get_session_doc( 78 | app_name=FirestoreSessionService._clean_app_name(app_name), 79 | user_id=user_id, 80 | session_id=session_id 81 | ).collection("events") 82 | 83 | def _get_sessions_collection(self, 84 | *, 85 | app_name: str, 86 | user_id: str) -> CollectionReference: 87 | session_parent_path = self._get_session_path( 88 | app_name=FirestoreSessionService._clean_app_name(app_name), 89 | user_id=user_id, 90 | session_id="" 91 | ).strip("/") 92 | return self.client.collection(session_parent_path) 93 | 94 | def _delete_collection( 95 | self, 96 | coll_ref: CollectionReference, 97 | batch_size: int = 100, 98 | ): 99 | if batch_size < 1: 100 | batch_size = 1 101 | 102 | docs = coll_ref.list_documents(page_size=batch_size) 103 | deleted = 0 104 | 105 | for doc in docs: 106 | print(f"Deleting doc {doc.id} => {doc.get().to_dict()}") 107 | doc.delete() 108 | deleted = deleted + 1 109 | 110 | if deleted >= batch_size: 111 | return self._delete_collection(coll_ref, batch_size) 112 | 113 | 114 | async def create_session( 115 | self, 116 | *, 117 | app_name: str, 118 | user_id: str, 119 | state: Optional[dict[str, Any]] = None, 120 | session_id: Optional[str] = None, 121 | ) -> Session: 122 | if not session_id: 123 | session_id = uuid.uuid4().hex 124 | app_name = FirestoreSessionService._clean_app_name(app_name) 
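        # Sessions are stored under
        # {sessions_collection}/agents/{app_name}/users/{user_id}/sessions/{session_id},
        # so the app name is first reduced to its bare form (last path segment).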
125 | logger.info(f"Creating session {app_name}/{user_id}/{session_id}.") 126 | session = Session(id=session_id, 127 | app_name=app_name, 128 | user_id=user_id, 129 | state=state or {}, 130 | events=[]) 131 | doc = self._get_session_doc( 132 | app_name=app_name, 133 | user_id=user_id, 134 | session_id=session_id 135 | ) 136 | session_dict = session.model_dump() 137 | session_dict.pop("events", None) 138 | session_dict["last_update_time"] = SERVER_TIMESTAMP 139 | session.last_update_time = doc.create( 140 | session_dict 141 | ).update_time.timestamp() # type: ignore 142 | return session 143 | 144 | async def get_session( 145 | self, 146 | *, 147 | app_name: str, 148 | user_id: str, 149 | session_id: str, 150 | config: Optional[GetSessionConfig] = None, 151 | ) -> Optional[Session]: 152 | """Gets a session.""" 153 | app_name = FirestoreSessionService._clean_app_name(app_name) 154 | logger.info(f"Loading session {app_name}/{user_id}/{session_id}.") 155 | session_doc = self._get_session_doc( 156 | app_name=app_name, 157 | user_id=user_id, 158 | session_id=session_id 159 | ) 160 | doc_obj = session_doc.get() 161 | session_dict = doc_obj.to_dict() or {} 162 | if not session_dict or "id" not in session_dict: 163 | raise FileNotFoundError( 164 | f"Session {app_name}/{user_id}/{session_id} not found." 165 | ) 166 | if "state" not in session_dict: 167 | session_dict["state"] = {} 168 | if "last_update_time" in session_dict: 169 | session_dict["last_update_time"] = session_dict[ 170 | "last_update_time" 171 | ].timestamp() 172 | # Backwards compatibility 173 | if "__STATE_::RUNNING_QUERY" in session_dict["state"]: 174 | session_dict["state"]["RUNNING_QUERY"] = session_dict.pop( 175 | "__STATE_::RUNNING_QUERY" 176 | ) 177 | session_dict = { 178 | k: v for k, v in session_dict.items() 179 | if not k.startswith("__STATE_::") 180 | } 181 | session = Session.model_validate(session_dict, strict=False) 182 | session.events = [] 183 | query = None 184 | events_collection = self._get_events_collection( 185 | app_name=app_name, 186 | user_id=user_id, 187 | session_id=session_id 188 | ) 189 | if config and config.after_timestamp: 190 | query = events_collection.where( 191 | "timestamp", 192 | ">", 193 | config.after_timestamp 194 | ).order_by("timestamp") 195 | if config and config.num_recent_events: 196 | if not query: 197 | query = events_collection.order_by("timestamp") 198 | query = query.limit_to_last(config.num_recent_events) 199 | if not query: 200 | query = events_collection.order_by("timestamp") 201 | for doc in query.stream(): 202 | session.events.append( 203 | Event.model_validate( 204 | doc.to_dict(), 205 | strict=False 206 | ) 207 | ) 208 | return session 209 | 210 | async def list_sessions( 211 | self, *, app_name: str, user_id: str 212 | ) -> ListSessionsResponse: 213 | sessions_result = [] 214 | app_name = FirestoreSessionService._clean_app_name(app_name) 215 | sessions = self._get_sessions_collection( 216 | app_name=app_name, 217 | user_id=user_id, 218 | ).order_by("last_update_time", direction=Query.DESCENDING).stream() 219 | for doc in sessions: 220 | session = Session(id=doc.id, 221 | app_name=app_name, 222 | user_id=user_id, 223 | state={}, 224 | events=[], 225 | last_update_time=doc.update_time.timestamp() 226 | ) 227 | sessions_result.append(session) 228 | return ListSessionsResponse(sessions=sessions_result) 229 | 230 | async def delete_session( 231 | self, *, app_name: str, user_id: str, session_id: str 232 | ) -> None: 233 | app_name = FirestoreSessionService._clean_app_name(app_name) 
234 | logger.info(f"Deleting session {app_name}/{user_id}/{session_id}.") 235 | self._get_sessions_collection( 236 | app_name=app_name, 237 | user_id=user_id, 238 | ).document(session_id).delete() 239 | self._delete_collection( 240 | self._get_events_collection( 241 | app_name=app_name, 242 | user_id=user_id, 243 | session_id=session_id 244 | ) 245 | ) 246 | 247 | async def close_session(self, *, session: Session): 248 | """Closes a session.""" 249 | # No closed sessions supported. 250 | pass 251 | 252 | async def append_event(self, session: Session, event: Event) -> Event: 253 | """Appends an event to a session object.""" 254 | if event.partial: 255 | return event 256 | await self.__update_session_state(session, event) 257 | session.events.append(event) 258 | return event 259 | 260 | async def __update_session_state(self, session: Session, event: Event): 261 | """Updates the session state based on the event.""" 262 | collection = self._get_events_collection( 263 | app_name=session.app_name, 264 | user_id=session.user_id, 265 | session_id=session.id 266 | ) 267 | collection.document(event.id).create(event.model_dump()) 268 | if not event.actions or not event.actions.state_delta: 269 | return 270 | if not session.state: 271 | session.state = {} 272 | updated = False 273 | state_change_dict = {} 274 | for key, value in event.actions.state_delta.items(): 275 | if key.startswith(State.TEMP_PREFIX): 276 | continue 277 | state_change_dict[f"state.{key}"] = value 278 | session.state[key] = value 279 | updated = True 280 | state_change_dict["last_update_time"] = SERVER_TIMESTAMP 281 | while updated: # Writing to Firestore only if updated 282 | try: 283 | session_doc = self._get_session_doc( 284 | app_name=session.app_name, 285 | user_id=session.user_id, 286 | session_id=session.id 287 | ) 288 | session.last_update_time = session_doc.update( 289 | field_updates=state_change_dict 290 | ).update_time.timestamp() # type: ignore 291 | break 292 | except exceptions.FailedPrecondition: 293 | pass 294 | -------------------------------------------------------------------------------- /src/agents/data_agent/tools/bi_engineer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
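# How this module fits together, in miniature (an illustrative sketch of the
# generate/validate loop implemented in bi_engineer_tool below, not code that
# runs here):
#
#   spec = chart_chat.send_message(chart_prompt)     # draft a Vega-Lite spec
#   while not renders_to_png(spec):                  # prove it by rendering
#       spec = fix_chat.send_message(render_error)   # feed the error back
#   if not evaluate_chart(png).is_good:              # judge the final image
#       spec = chart_chat.send_message(feedback)     # and revise again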
14 | """BI Engineer Agent""" 15 | 16 | from datetime import date, datetime 17 | from functools import cache 18 | import io 19 | import json 20 | import os 21 | 22 | from pydantic import BaseModel 23 | 24 | from google.adk.tools import ToolContext 25 | from google.genai.types import ( 26 | GenerateContentConfig, 27 | Part, 28 | SafetySetting, 29 | ThinkingConfig 30 | ) 31 | from google.cloud.bigquery import Client, QueryJobConfig 32 | from google.cloud.exceptions import BadRequest, NotFound 33 | 34 | import altair as alt 35 | from altair.vegalite.schema import core as alt_core 36 | import pandas as pd 37 | 38 | from .utils import get_genai_client 39 | from prompts.bi_engineer import prompt as bi_engineer_prompt 40 | from tools.chart_evaluator import evaluate_chart 41 | 42 | 43 | MAX_RESULT_ROWS_DISPLAY = 50 44 | BI_ENGINEER_AGENT_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 45 | BI_ENGINEER_FIX_AGENT_MODEL_ID = "gemini-2.5-pro" # "gemini-2.5-pro-preview-05-06" 46 | 47 | 48 | @cache 49 | def _init_environment(): 50 | global _bq_project_id, _data_project_id, _location, _dataset 51 | _bq_project_id = os.environ["BQ_PROJECT_ID"] 52 | _data_project_id = os.environ["SFDC_DATA_PROJECT_ID"] 53 | _location = os.environ["BQ_LOCATION"] 54 | _dataset = os.environ["SFDC_BQ_DATASET"] 55 | 56 | class VegaResult(BaseModel): 57 | vega_lite_json: str 58 | 59 | 60 | def _enhance_parameters(vega_chart: dict, df: pd.DataFrame) -> dict: 61 | """ 62 | Makes sure all chart parameters with "select" equal to "point" 63 | have the same option values as respective dimensions. 64 | 65 | Args: 66 | vega_chart_json (str): _description_ 67 | df (pd.DataFrame): _description_ 68 | 69 | Returns: 70 | str: _description_ 71 | """ 72 | if "params" not in vega_chart: 73 | return vega_chart 74 | if "params" not in vega_chart or "'transform':" not in str(vega_chart): 75 | print("Cannot enhance parameters because one or " 76 | "more of these are missing: " 77 | "[params, transform]") 78 | return vega_chart 79 | print("Enhancing parameters...") 80 | params_list = vega_chart["params"] 81 | params = { p["name"]: p for p in params_list } 82 | for p in params: 83 | if not p.endswith("__selection"): 84 | continue 85 | print(f"Parameter {p}") 86 | param_dict = params[p] 87 | column_name = p.split("__selection")[0] 88 | if column_name not in df.columns: 89 | print(f"Column {column_name} not found in dataframe.") 90 | continue 91 | field_values = df[column_name].unique().tolist() 92 | if None not in field_values: 93 | field_values.insert(0, None) 94 | none_index = 0 95 | else: 96 | none_index = field_values.index(None) 97 | param_dict["value"] = None 98 | param_dict["bind"] = {"input": "select"} 99 | param_dict["bind"]["options"] = field_values 100 | field_labels = field_values.copy() 101 | field_labels[none_index] = "[All]" 102 | param_dict["bind"]["labels"] = field_labels 103 | param_dict["bind"]["name"] = column_name 104 | print(f"Yay! 
105 |     return vega_chart
106 | 
107 | 
108 | def _create_chat(model: str, history: list, max_thinking: bool = False):
109 |     return get_genai_client().chats.create(
110 |         model=model,
111 |         config=GenerateContentConfig(
112 |             temperature=0.1,
113 |             top_p=0.0,
114 |             seed=256,
115 |             response_schema=VegaResult,
116 |             response_mime_type="application/json",
117 |             safety_settings=[
118 |                 SafetySetting(
119 |                     category="HARM_CATEGORY_DANGEROUS_CONTENT",  # type: ignore
120 |                     threshold="BLOCK_ONLY_HIGH",  # type: ignore
121 |                 ),
122 |             ],
123 |             thinking_config=(
124 |                 ThinkingConfig(thinking_budget=32768) if max_thinking
125 |                 else None),
126 |             max_output_tokens=65536
127 |         ),
128 |         history=history
129 |     )
130 | 
131 | 
132 | 
133 | def _fix_df_dates(df: pd.DataFrame) -> pd.DataFrame:
134 |     """
135 |     Converts all columns of type date, datetime, or datetimetz in a
136 |     Pandas DataFrame to ISO 8601 string format.
137 | 
138 |     Args:
139 |         df (pd.DataFrame): The input DataFrame.
140 | 
141 |     Returns:
142 |         pd.DataFrame: A DataFrame with date/datetime columns converted to
143 |                       ISO formatted strings.
144 |     """
145 |     df = df.copy()  # Work on a copy to avoid side effects
146 |     # --- Process native pandas datetime types ---
147 |     datetime_cols = df.select_dtypes(
148 |         include=["datetime", "datetimetz", "dbdate"]
149 |     ).columns
150 |     for col in datetime_cols:
151 |         # 1. Convert each value to an ISO string
152 |         iso_values = df[col].apply(lambda x: x.isoformat() if pd.notnull(x) else None)
153 |         # 2. Explicitly cast the column to the modern 'string' dtype
154 |         df[col] = iso_values.astype("string")
155 | 
156 |     # --- Process object columns that might contain date/datetime objects ---
157 |     object_cols = df.select_dtypes(include=['object']).columns
158 |     for col in object_cols:
159 |         # Heuristic to find columns that contain date/datetime objects
160 |         first_valid_index = df[col].first_valid_index()
161 |         if first_valid_index is not None and isinstance(df[col].loc[first_valid_index], (date, datetime)):
162 |             # 1. Convert each value to an ISO string
163 |             iso_values = df[col].apply(
164 |                 lambda x: x.isoformat()
165 |                 if isinstance(x, (date, datetime))
166 |                 else x
167 |             )
168 |             # 2. Explicitly cast the column to the modern 'string' dtype
169 |             df[col] = iso_values.astype("string")
170 |     return df
171 | 
172 | 
173 | def _json_date_serial(obj):
174 |     """JSON serializer for objects not serializable by default json code"""
175 |     if isinstance(obj, (datetime, date)):
176 |         return obj.isoformat()
177 |     raise TypeError("Type %s not serializable" % type(obj))
178 | 
179 | def _safe_json(json_str: str) -> str:
180 |     json_str = "{" + json_str.strip().split("{", 1)[-1]
181 |     json_str = json_str.rsplit("}", 1)[0] + "}"
182 |     json_dict = json.loads(json_str)
183 |     return json.dumps(json_dict, default=_json_date_serial)
184 | 
185 | async def bi_engineer_tool(original_business_question: str,
186 |                            question_that_sql_result_can_answer: str,
187 |                            sql_file_name: str,
188 |                            notes: str,
189 |                            tool_context: ToolContext) -> str:
190 |     """Senior BI Engineer. Executes SQL code and builds a chart.
191 | 
192 |     Args:
193 |         original_business_question (str): Original business question.
194 |         question_that_sql_result_can_answer (str):
195 |             Specific question or sub-question that the SQL result can answer.
196 |         sql_file_name (str): File name of the BigQuery SQL code to execute.
197 |         notes (str): Important notes about the chart. Empty unless the user stated something directly related to the chart. 
198 | 199 | Returns: 200 | str: Chart image id and the result of executing the SQL code 201 | in CSV format (first 50 rows). 202 | """ 203 | _init_environment() 204 | sql_code_part = await tool_context.load_artifact(sql_file_name) 205 | sql_code = sql_code_part.inline_data.data.decode("utf-8") # type: ignore 206 | client = Client(project=_bq_project_id, location=_location) 207 | try: 208 | dataset_location = client.get_dataset( 209 | f"{_data_project_id}.{_dataset}").location 210 | job_config = QueryJobConfig(use_query_cache=False) 211 | df: pd.DataFrame = client.query(sql_code, 212 | job_config=job_config, 213 | location=dataset_location).result().to_dataframe() 214 | df = _fix_df_dates(df) 215 | except (BadRequest, NotFound) as ex: 216 | err_text = ex.args[0].strip() 217 | return f"BIGQUERY ERROR: {err_text}" 218 | 219 | if notes: 220 | notes_text = f"\n\n**Important notes about the chart:** \n{notes}\n\n" 221 | else: 222 | notes_text = "" 223 | 224 | vega_lite_spec = json.dumps( 225 | alt_core.load_schema(), 226 | indent=1, 227 | sort_keys=False 228 | ) 229 | chart_prompt = bi_engineer_prompt.format( 230 | original_business_question=original_business_question, 231 | question_that_sql_result_can_answer=question_that_sql_result_can_answer, 232 | sql_code=sql_code, 233 | notes_text=notes_text, 234 | columns_string=df.dtypes.to_string(), 235 | dataframe_preview_len=min(10,len(df)), 236 | dataframe_len=len(df), 237 | dataframe_head=df.head(10).to_string(), 238 | vega_lite_spec=vega_lite_spec, 239 | vega_lite_schema_version=alt.SCHEMA_VERSION.split(".")[0] 240 | ) 241 | 242 | vega_chart_json = "" 243 | vega_fix_chat = None 244 | while True: 245 | vega_chat = _create_chat(BI_ENGINEER_AGENT_MODEL_ID, []) 246 | chart_results = vega_chat.send_message(chart_prompt) 247 | chart_model = chart_results.parsed # type: ignore 248 | if chart_model: 249 | break 250 | chart_json = chart_model.vega_lite_json # type: ignore 251 | 252 | for _ in range(5): # 5 tries to make a good chart 253 | for _ in range(10): 254 | try: 255 | vega_dict = json.loads(_safe_json(chart_json)) # type: ignore 256 | vega_dict["data"] = {"values": []} 257 | vega_dict.pop("datasets", None) 258 | vega_chart = alt.Chart.from_dict(vega_dict) 259 | with io.BytesIO() as tmp: 260 | vega_chart.save(tmp, "png") 261 | vega_dict = _enhance_parameters(vega_dict, df) 262 | vega_chart_json = json.dumps(vega_dict, indent=1) 263 | vega_chart = alt.Chart.from_dict(vega_dict) 264 | vega_chart.data = df 265 | with io.BytesIO() as file: 266 | vega_chart.save(file, "png") 267 | error_reason = "" 268 | break 269 | except Exception as ex: 270 | message = f""" 271 | {chart_json} 272 | 273 | {df.dtypes.to_string()} 274 | 275 | You made a mistake! 276 | Fix the issues. Redesign the chart if it promises a better result. 
277 | 
278 | ERROR {type(ex).__name__}: {str(ex)}
279 |                 """.strip()
280 |                 error_reason = message
281 |                 print(message)
282 |                 if not vega_fix_chat:
283 |                     vega_fix_chat = _create_chat(BI_ENGINEER_FIX_AGENT_MODEL_ID,
284 |                                                  vega_chat.get_history(),
285 |                                                  True)
286 |                 print("Fixing...")
287 |                 chart_json = vega_fix_chat.send_message(
288 |                     message
289 |                 ).parsed.vega_lite_json  # type: ignore
290 | 
291 |         if not error_reason:
292 |             with io.BytesIO() as file:
293 |                 vega_chart.data = df
294 |                 vega_chart.save(file, "png")
295 |                 file.seek(0)
296 |                 png_data = file.getvalue()
297 |             evaluate_chart_result = evaluate_chart(
298 |                 png_data,
299 |                 vega_chart_json,
300 |                 question_that_sql_result_can_answer,
301 |                 len(df),
302 |                 tool_context)
303 |             if not evaluate_chart_result or evaluate_chart_result.is_good:
304 |                 break
305 |             error_reason = evaluate_chart_result.reason
306 | 
307 |         if not error_reason:
308 |             break
309 | 
310 |         print(f"Feedback:\n{error_reason}.\n\nWorking on another version...")
311 |         history = (vega_fix_chat.get_history()
312 |                    if vega_fix_chat
313 |                    else vega_chat.get_history())
314 |         vega_chat = _create_chat(BI_ENGINEER_AGENT_MODEL_ID, history)
315 |         chart_json = vega_chat.send_message(f"""
316 | Fix the chart based on the feedback.
317 | Only output Vega-Lite json.
318 | 
319 | **Feedback on the chart below**
320 | {error_reason}
321 | 
322 | 
323 | **CHART**
324 | 
325 | ```json
326 | {vega_chart_json}
327 | ```
328 | """).parsed.vega_lite_json  # type: ignore
329 | 
330 |     print("Done working on a chart.")
331 |     if error_reason:
332 |         print(f"Chart is still not good: {error_reason}")
333 |     else:
334 |         print("And the chart seems good to me.")
335 |     data_file_name = f"{tool_context.invocation_id}.parquet"
336 |     parquet_bytes = df.to_parquet()
337 |     await tool_context.save_artifact(filename=data_file_name,
338 |                                      artifact=Part.from_bytes(
339 |                                          data=parquet_bytes,
340 |                                          mime_type="application/parquet"))
341 |     file_name = f"{tool_context.invocation_id}.vg"
342 |     await tool_context.save_artifact(filename=file_name,
343 |                                      artifact=Part.from_bytes(
344 |                                          mime_type="application/json",
345 |                                          data=vega_chart_json.encode("utf-8")))
346 |     with io.BytesIO() as file:
347 |         vega_chart.save(file, "png", ppi=72)
348 |         file.seek(0)
349 |         data = file.getvalue()
350 |     new_image_name = f"{tool_context.invocation_id}.png"
351 |     await tool_context.save_artifact(filename=new_image_name,
352 |                                      artifact=Part.from_bytes(
353 |                                          mime_type="image/png",
354 |                                          data=data))
355 |     tool_context.state["chart_image_name"] = new_image_name
356 | 
357 |     csv = df.head(MAX_RESULT_ROWS_DISPLAY).to_csv(index=False)
358 |     if len(df) > MAX_RESULT_ROWS_DISPLAY:
359 |         csv_message = f"**FIRST {MAX_RESULT_ROWS_DISPLAY} OF {len(df)} ROWS OF DATA**:"
360 |     else:
361 |         csv_message = "**DATA**:"
362 | 
363 |     return f"chart_image_id: `{new_image_name}`\n\n{csv_message}\n\n```csv\n{csv}\n```\n"
364 | 
--------------------------------------------------------------------------------
/metadata/sfdc_metadata_loader/sfdc_metadata_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Salesforce CRM metadata extractor"""
15 | 
16 | import json
17 | import pathlib
18 | import threading
19 | import typing
20 | 
21 | from urllib.parse import unquote, urlparse, parse_qs
22 | 
23 | from google.cloud import bigquery
24 | 
25 | from simple_salesforce import Salesforce  # type: ignore
26 | 
27 | _system_fields_description_formats = {
28 |     "Id": "Id of %s",
29 |     "OwnerId": "Id of the User who owns this %s",
30 |     "IsArchived": "Is this %s archived",
31 |     "IsDeleted": "Is this %s deleted",
32 |     "Name": "Name of %s",
33 |     "Description": "Description of %s",
34 |     "CreatedById": "Id of the User who created this %s",
35 |     "LastModifiedById": "Id of the last User who modified this %s",
36 | }
37 | 
38 | _extra_descriptions_path = "sfdc_extra_descriptions.json"
39 | 
40 | from .sfdc_metadata import SFDCMetadata
41 | 
42 | class SFDCMetadataBuilder(SFDCMetadata):
43 |     """Salesforce CRM metadata extractor"""
44 | 
45 |     def __init__(
46 |         self,
47 |         sfdc_auth_parameters: typing.Union[str, typing.Dict[str, str]],
48 |         bq_client: bigquery.Client,
49 |         project_id: str,
50 |         dataset_name: str,
51 |         metadata_file: typing.Optional[str] = None,
52 |         table_to_object_mapping: typing.Optional[typing.Dict[str, str]] = None
53 |     ) -> None:
54 |         """
55 |         Args:
56 |             sfdc_auth_parameters (typing.Union[str, typing.Dict[str, str]]):
57 |                 May be a string or a dictionary of strings.
58 |                 - If a string, it should be a Secret Manager secret version name
59 |                 (projects/PROJECT_NUMBER/secrets/SECRET_NAME/versions/latest)
60 |                 The secret value may be an Airflow connection string for Salesforce (salesforce://)
61 |                 or JSON text. JSON text will be converted to a dictionary
62 |                 (see dictionary details below).
63 |                 - If a dictionary, it must contain a valid combination of these parameters,
64 |                 as described here https://github.com/simple-salesforce/simple-salesforce/blob/1d28fa18438d3840140900d4c00799798bad57b8/simple_salesforce/api.py#L65
65 |                 * domain -- The domain to use when connecting to Salesforce. Use
66 |                             common domains, such as 'login' or 'test', or
67 |                             Salesforce My domain. If not used, will default to
68 |                             'login'.
69 | 
70 |                 -- Password Authentication:
71 |                 * username -- the Salesforce username to use for authentication
72 |                 * password -- the password for the username
73 |                 * security_token -- the security token for the username
74 | 
75 |                 -- OAuth 2.0 Connected App Token Authentication:
76 |                 * consumer_key -- the consumer key generated for the user
77 |                 * consumer_secret -- the consumer secret generated for the user
78 | 
79 |                 -- OAuth 2.0 JWT Bearer Token Authentication:
80 |                 * consumer_key -- the consumer key generated for the user
81 | 
82 |                 Then either
83 |                 * privatekey_file -- the path to the private key file used
84 |                                      for signing the JWT token
85 |                 OR
86 |                 * privatekey -- the private key to use
87 |                                 for signing the JWT token
88 | 
89 |                 -- Direct Session and Instance Access:
90 |                 * session_id -- Access token for this session
91 | 
92 |                 Then either
93 |                 * instance -- Domain of your Salesforce instance, e.g. 
94 |                               `na1.salesforce.com`
95 |                 OR
96 |                 * instance_url -- Full URL of your instance, e.g.
97 |                                   `https://na1.salesforce.com`
98 | 
99 |             bq_client (bigquery.Client): BigQuery client
100 |             project_id (str): GCP project id of BigQuery data.
101 |             dataset_name (str): BigQuery dataset name.
102 |             table_to_object_mapping: optional dictionary for mapping BigQuery
103 |                 table names to SFDC object names.
104 |         """
105 |         super().__init__(project_id, dataset_name, metadata_file)
106 |         self.bq_client = bq_client
107 |         self.table_to_object_mapping = table_to_object_mapping
108 | 
109 |         if isinstance(sfdc_auth_parameters, str):
110 |             # sfdc_auth_parameters is a path to a Secret Manager secret
111 |             # "projects/PROJECT_NUMBER/secrets/SECRET_NAME/versions/latest"
112 |             from google.cloud import secretmanager  # type: ignore
113 |             sm_client = secretmanager.SecretManagerServiceClient()
114 |             secret_response = sm_client.access_secret_version(
115 |                 name=sfdc_auth_parameters)
116 |             secret_payload = secret_response.payload.data.decode("utf-8")
117 |             if secret_payload.startswith("salesforce://"):
118 |                 # Airflow connection string
119 |                 secret_payload = unquote(
120 |                     secret_payload.replace("salesforce://", ""))
121 |                 username = None
122 |                 password = ""
123 |                 url_parts = secret_payload.rsplit("@", 1)
124 |                 if len(url_parts) > 1:
125 |                     parsed = urlparse(url_parts[1])
126 |                     username, password = url_parts[0].split(":", 1)
127 |                 else:
128 |                     parsed = urlparse(secret_payload)
129 |                 url_query_dict = parse_qs(parsed.query)
130 |                 auth_dict = {k: v[0] for k, v in url_query_dict.items()}
131 |                 if username:
132 |                     auth_dict["username"] = username
133 |                     auth_dict["password"] = password
134 |                 auth_dict["instance_url"] = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
135 |             else:
136 |                 # Just a json string
137 |                 auth_dict = json.loads(secret_payload)
138 |         else:
139 |             # This is already a dictionary
140 |             auth_dict = sfdc_auth_parameters
141 | 
142 |         for k in list(auth_dict.keys()):
143 |             if k != k.lower() and k != "organizationId":
144 |                 auth_dict[k.lower()] = auth_dict.pop(k)
145 | 
146 |         for k in ["consumer_key", "consumer_secret", "security_token",
147 |                   "session_id", "instance_url", "client_id", "privatekey_file"]:
148 |             no_underscore = k.replace("_", "")
149 |             if no_underscore in auth_dict:
150 |                 auth_dict[k] = auth_dict.pop(no_underscore)
151 | 
152 |         if "domain" in auth_dict:
153 |             if "." 
not in auth_dict["domain"]: 154 | auth_dict["domain"] += ".my" 155 | elif auth_dict["domain"].endswith(".salesforce.com"): 156 | auth_dict["domain"] = auth_dict["domain"].replace( 157 | ".salesforce.com", "") 158 | 159 | # auth_dict["version"] = "61.0" 160 | self.sfdc_connection = Salesforce(**auth_dict) # type: ignore 161 | self._metadata = {} 162 | self._lock = threading.Lock() 163 | 164 | def get_metadata(self) -> typing.Dict[str, typing.Any]: 165 | """Extract metadata from Salesforce CRM""" 166 | if len(self._metadata) > 0: 167 | return self._metadata 168 | 169 | with self._lock: 170 | if len(self._metadata) == 0: 171 | metadata_path = pathlib.Path(self._metadata_file_name) 172 | if metadata_path.exists(): 173 | self._metadata = json.loads( 174 | metadata_path.read_text(encoding="utf-8")) 175 | else: 176 | self._extract_metadata() 177 | self._enhance_metadata() 178 | metadata_path.write_text(json.dumps( 179 | self._metadata, indent=2)) 180 | return self._metadata 181 | 182 | def _enhance_metadata(self) -> bool: 183 | file_path = pathlib.Path(__file__).parent / pathlib.Path(_extra_descriptions_path) 184 | if not file_path.exists(): 185 | return False 186 | extra_dict = json.loads(file_path.read_text(encoding="utf-8")) 187 | for k in self._metadata.keys(): 188 | if k not in extra_dict: 189 | continue 190 | extra_cols = extra_dict[k] 191 | columns = self._metadata[k]["columns"] 192 | for fk in columns.keys(): 193 | if fk in extra_cols: 194 | columns[fk]["sfdc_description"] = extra_cols[fk] 195 | return True 196 | 197 | def _extract_metadata(self) -> bool: 198 | dataset = self.bq_client.get_dataset(self.dataset_name) 199 | tables = [] 200 | tables_light = list(self.bq_client.list_tables(dataset)) 201 | tables_names = [table.table_id for table in tables_light] 202 | tables_names_lower = [table.lower() for table in tables_names] 203 | 204 | table_metadatas = {} 205 | 206 | results = self.sfdc_connection.restful(f"sobjects") 207 | if not results or "sobjects" not in results: 208 | raise Exception(f"Invalid response from Salesforce: {results}") 209 | sobjects = results["sobjects"] 210 | for sobject in sobjects: 211 | singular = sobject["name"].lower() 212 | plural = sobject["labelPlural"].lower() 213 | if singular in tables_names_lower: 214 | index = tables_names_lower.index(singular) 215 | elif plural in tables_names_lower: 216 | index = tables_names_lower.index(plural) 217 | else: 218 | continue 219 | table_name = tables_names[index] 220 | table = self.bq_client.get_table( 221 | f"{dataset.project}.{dataset.dataset_id}.{table_name}") 222 | tables.append(table) 223 | 224 | results = self.sfdc_connection.restful( 225 | f"sobjects/{sobject['name']}/describe") 226 | if not results or "fields" not in results: 227 | raise Exception( 228 | f"Invalid response from Salesforce: {results}") 229 | 230 | table_metadata = {} 231 | table_metadata["salesforce_name"] = results["name"] 232 | table_metadata["salesforce_label"] = results["label"] 233 | table_metadata["important_notes_and_rules"] = "" 234 | table_metadata["salesforce_fields"] = results["fields"] 235 | table_metadata["bq_table"] = table 236 | table_metadatas[table_name] = table_metadata 237 | 238 | for _, table_metadata in table_metadatas.items(): 239 | table = table_metadata["bq_table"] 240 | schema = [f.to_api_repr() for f in table.schema] 241 | sfdc_fields = table_metadata["salesforce_fields"] 242 | sfdc_field_names = [f["name"] for f in sfdc_fields] 243 | table_metadata["columns"] = {} 244 | for index, f in enumerate(schema): 245 | 
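                # Each BigQuery column is matched to its Salesforce field
                # descriptor to build a rich description: lookup fields record
                # which object types they can refer to, picklist fields carry
                # their allowed values, and well-known system fields (Id,
                # OwnerId, ...) get templated descriptions. Columns ending in
                # "_Type" are the companion columns of polymorphic lookups,
                # e.g. a `Who_Type` column stores which object type the
                # corresponding `WhoId` column points to.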
bq_field_name = f["name"] 246 | field_complex_description = "" 247 | possible_values = [] 248 | reference = {} 249 | if bq_field_name.endswith("_Type"): 250 | # Handling polymorphic type description 251 | reference_field_name = bq_field_name[:-len( 252 | "_Type")] 253 | id_reference_filed_name = f"{reference_field_name}Id" 254 | field_full_description = ( 255 | "Type of object " 256 | f"`{id_reference_filed_name}` column refers to.") 257 | sfdc_label = f"Type of {reference_field_name}" 258 | else: 259 | if bq_field_name not in sfdc_field_names: 260 | continue 261 | sfdc_field_index = sfdc_field_names.index( 262 | bq_field_name) 263 | sfdc_field = sfdc_fields[sfdc_field_index] 264 | ref_to = sfdc_field.get("referenceTo", []) 265 | if len(ref_to) > 0: 266 | reference["refers_to"] = ref_to 267 | if len(ref_to) > 1: 268 | if sfdc_field["relationshipName"]: 269 | type_field = (sfdc_field["relationshipName"] + 270 | "_Type") 271 | field_complex_description = ( 272 | "Id of an object of one of types: [" + 273 | ",".join(ref_to) + 274 | "]. Object type is stored in " + 275 | f"`{type_field}` column.") 276 | reference["reference_type_column"] = type_field 277 | else: 278 | ref_to_name = ref_to[0] 279 | if ref_to_name == table_metadata["salesforce_name"]: 280 | field_complex_description = ( 281 | f"Id of another {ref_to_name}.") 282 | else: 283 | field_complex_description = ( 284 | f"Id of {ref_to_name}." 285 | ) 286 | if "picklistValues" in sfdc_field and len(sfdc_field["picklistValues"]) > 0: 287 | for v in sfdc_field["picklistValues"]: 288 | possible_values.append({ 289 | "value": v['value'], 290 | "value_label": v['label'] or v['value'] 291 | }) 292 | if sfdc_field["name"] in _system_fields_description_formats: 293 | field_name_long = _system_fields_description_formats[ 294 | sfdc_field["name"]] % table_metadata["salesforce_name"] 295 | else: 296 | field_name_long = ( 297 | sfdc_field["inlineHelpText"] or sfdc_field["label"]) 298 | sfdc_label = sfdc_field["label"] 299 | field_full_description = f"{sfdc_field['name']} ({field_name_long})" 300 | if field_complex_description: 301 | field_full_description += f"\n{field_complex_description}" 302 | nullable = f.get("nillable", True) 303 | field_metadata = {} 304 | field_metadata = { 305 | "field_name": bq_field_name, 306 | "field_type": f["type"], 307 | "field_label": sfdc_label, 308 | "sfdc_description": field_full_description, 309 | "is_nullable": nullable, 310 | } 311 | if len(reference) > 0: 312 | field_metadata["reference"] = reference 313 | if len(possible_values) > 0: 314 | field_metadata["possible_values"] = possible_values 315 | 316 | table_metadata["columns"][bq_field_name] = field_metadata 317 | table_metadata.pop("salesforce_fields") 318 | table_metadata.pop("bq_table") 319 | self._metadata = table_metadatas 320 | return True 321 | -------------------------------------------------------------------------------- /src/web/fast_api_app.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import asyncio
16 | from contextlib import asynccontextmanager
17 | import importlib
18 | import inspect
19 | import logging
20 | import os
21 | from pathlib import Path
22 | import sys
23 | import traceback
24 | import typing
25 | from typing import Any
26 | from typing import List
27 | from typing import Literal
28 | from typing import Optional
29 | 
30 | from fastapi import FastAPI
31 | from fastapi import HTTPException
32 | from fastapi import Query
33 | from fastapi.middleware.cors import CORSMiddleware
34 | from fastapi.responses import StreamingResponse
35 | from fastapi.websockets import WebSocket
36 | from fastapi.websockets import WebSocketDisconnect
37 | from google.genai import types
38 | from opentelemetry import trace
39 | from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
40 | from opentelemetry.sdk.trace import export
41 | from opentelemetry.sdk.trace import ReadableSpan
42 | from opentelemetry.sdk.trace import TracerProvider
43 | from pydantic import BaseModel
44 | from pydantic import ValidationError
45 | from starlette.types import Lifespan
46 | 
47 | from google.adk.agents import RunConfig
48 | from google.adk.agents.live_request_queue import LiveRequest
49 | from google.adk.agents.live_request_queue import LiveRequestQueue
50 | from google.adk.agents.llm_agent import Agent
51 | from google.adk.agents.run_config import StreamingMode
52 | from google.adk.artifacts import BaseArtifactService, InMemoryArtifactService
53 | from google.adk.events.event import Event
54 | from google.adk.memory import InMemoryMemoryService
55 | from google.adk.runners import Runner
56 | from google.adk.sessions import Session
57 | 
58 | sys.path.append(str(Path(__file__).parent.parent))
59 | from shared.firestore_session_service import (
60 |     FirestoreSessionService as SessionService
61 | )
62 | 
63 | logger = logging.getLogger(__name__)
64 | 
65 | class ApiServerSpanExporter(export.SpanExporter):
66 | 
67 |     def __init__(self, trace_dict):
68 |         self.trace_dict = trace_dict
69 | 
70 |     def export(
71 |         self, spans: typing.Sequence[ReadableSpan]
72 |     ) -> export.SpanExportResult:
73 |         for span in spans:
74 |             if (span.name == "call_llm"
75 |                     or span.name == "send_data"
76 |                     or span.name.startswith("tool_response")):
77 |                 attributes = dict(span.attributes)  # type: ignore
78 |                 attributes["trace_id"] = span.get_span_context().trace_id  # type: ignore
79 |                 attributes["span_id"] = span.get_span_context().span_id  # type: ignore
80 |                 # Keyed by event id so /debug/trace/{event_id} can find it.
81 |                 if "gcp.vertex.agent.event_id" in attributes:
82 |                     self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
83 |         return export.SpanExportResult.SUCCESS
84 | 
85 |     def force_flush(self, timeout_millis: int = 30000) -> bool:
86 |         return True
87 | 
88 | 
89 | class AgentRunRequest(BaseModel):
90 |     app_name: str
91 |     user_id: str
92 |     session_id: str
93 |     new_message: types.Content
94 |     streaming: bool = False
95 | 
96 | 
97 | def get_fast_api_app(
98 |     *,
99 |     agent_dir: str,
100 |     allow_origins: Optional[list[str]] = None,
101 |     trace_to_cloud: bool = False,
102 |     lifespan: Optional[Lifespan[FastAPI]] = None,
103 |     artifact_service: Optional[BaseArtifactService] = None
104 | ) -> FastAPI:
105 |     # InMemory tracing dict.
106 |     trace_dict: dict[str, Any] = {}
107 | 
108 |     # Set up tracing in the FastAPI server. 
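    # Two processors end up on the provider: a SimpleSpanProcessor feeding
    # the in-memory ApiServerSpanExporter above (synchronous export, which is
    # fine for a local debug endpoint), and optionally a BatchSpanProcessor
    # exporting to Cloud Trace (batched in the background, the usual choice
    # for a network exporter).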
108 | provider = TracerProvider() 109 | provider.add_span_processor( 110 | export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) 111 | ) 112 | if trace_to_cloud: 113 | if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): 114 | processor = export.BatchSpanProcessor( 115 | CloudTraceSpanExporter(project_id=project_id) 116 | ) 117 | provider.add_span_processor(processor) 118 | else: 119 | logging.warning( 120 | "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will" 121 | " not be enabled." 122 | ) 123 | 124 | trace.set_tracer_provider(provider) 125 | 126 | exit_stacks = [] 127 | 128 | @asynccontextmanager 129 | async def internal_lifespan(app: FastAPI): 130 | if lifespan: 131 | async with lifespan(app) as lifespan_context: 132 | yield 133 | 134 | if exit_stacks: 135 | for stack in exit_stacks: 136 | await stack.aclose() 137 | else: 138 | yield 139 | 140 | # Run the FastAPI server. 141 | app = FastAPI(lifespan=internal_lifespan) 142 | 143 | if allow_origins: 144 | app.add_middleware( 145 | CORSMiddleware, 146 | allow_origins=allow_origins, 147 | allow_credentials=True, 148 | allow_methods=["*"], 149 | allow_headers=["*"], 150 | ) 151 | 152 | if agent_dir not in sys.path: 153 | sys.path.append(agent_dir) 154 | 155 | runner_dict = {} 156 | root_agent_dict = {} 157 | 158 | # Build the Artifact service 159 | artifact_service = artifact_service or InMemoryArtifactService() 160 | memory_service = InMemoryMemoryService() 161 | 162 | # Build the Session service 163 | session_service = SessionService( 164 | database=os.environ["FIRESTORE_SESSION_DATABASE"], 165 | sessions_collection=os.getenv("FIRESTORE_SESSION_COLLECTION", "/") 166 | ) 167 | 168 | 169 | @app.get("/debug/trace/{event_id}") 170 | def get_trace_dict(event_id: str) -> Any: 171 | event_dict = trace_dict.get(event_id, None) 172 | if event_dict is None: 173 | raise HTTPException(status_code=404, detail="Trace not found") 174 | return event_dict 175 | 176 | @app.get( 177 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}", 178 | response_model_exclude_none=True, 179 | ) 180 | async def get_session(app_name: str, user_id: str, session_id: str) -> Session: 181 | session = await session_service.get_session( 182 | app_name=app_name, user_id=user_id, session_id=session_id 183 | ) 184 | if not session: 185 | raise HTTPException(status_code=404, detail="Session not found") 186 | return session 187 | 188 | @app.get( 189 | "/apps/{app_name}/users/{user_id}/sessions", 190 | response_model_exclude_none=True, 191 | ) 192 | async def list_sessions(app_name: str, user_id: str) -> list[Session]: 193 | return [ 194 | session 195 | for session in (await session_service.list_sessions( 196 | app_name=app_name, user_id=user_id 197 | )).sessions 198 | ] 199 | 200 | @app.post( 201 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}", 202 | response_model_exclude_none=True, 203 | ) 204 | async def create_session_with_id( 205 | app_name: str, 206 | user_id: str, 207 | session_id: str, 208 | state: Optional[dict[str, Any]] = None, 209 | ) -> Session: 210 | app_name = app_name 211 | if ( 212 | await session_service.get_session( 213 | app_name=app_name, user_id=user_id, session_id=session_id 214 | ) 215 | is not None 216 | ): 217 | logger.warning("Session already exists: %s", session_id) 218 | raise HTTPException( 219 | status_code=400, detail=f"Session already exists: {session_id}" 220 | ) 221 | 222 | logger.info("New session created: %s", session_id) 223 | return await session_service.create_session( 224 | 
app_name=app_name, user_id=user_id, state=state, session_id=session_id 225 | ) 226 | 227 | @app.post( 228 | "/apps/{app_name}/users/{user_id}/sessions", 229 | response_model_exclude_none=True, 230 | ) 231 | async def create_session( 232 | app_name: str, 233 | user_id: str, 234 | state: Optional[dict[str, Any]] = None, 235 | ) -> Session: 236 | logger.info("New session created") 237 | return await session_service.create_session( 238 | app_name=app_name, user_id=user_id, state=state 239 | ) 240 | 241 | @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") 242 | async def delete_session(app_name: str, user_id: str, session_id: str): 243 | await session_service.delete_session( 244 | app_name=app_name, user_id=user_id, session_id=session_id 245 | ) 246 | 247 | @app.get( 248 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", 249 | response_model_exclude_none=True, 250 | ) 251 | async def load_artifact( 252 | app_name: str, 253 | user_id: str, 254 | session_id: str, 255 | artifact_name: str, 256 | version: Optional[int] = Query(None), 257 | ) -> Optional[types.Part]: 258 | artifact = await artifact_service.load_artifact( 259 | app_name=app_name, 260 | user_id=user_id, 261 | session_id=session_id, 262 | filename=artifact_name, 263 | version=version, 264 | ) 265 | if not artifact: 266 | raise HTTPException(status_code=404, detail="Artifact not found") 267 | return artifact 268 | 269 | @app.get( 270 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", 271 | response_model_exclude_none=True, 272 | ) 273 | async def load_artifact_version( 274 | app_name: str, 275 | user_id: str, 276 | session_id: str, 277 | artifact_name: str, 278 | version_id: int, 279 | ) -> Optional[types.Part]: 280 | artifact = await artifact_service.load_artifact( 281 | app_name=app_name, 282 | user_id=user_id, 283 | session_id=session_id, 284 | filename=artifact_name, 285 | version=version_id, 286 | ) 287 | if not artifact: 288 | raise HTTPException(status_code=404, detail="Artifact not found") 289 | return artifact 290 | 291 | @app.get( 292 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", 293 | response_model_exclude_none=True, 294 | ) 295 | async def list_artifact_names( 296 | app_name: str, user_id: str, session_id: str 297 | ) -> list[str]: 298 | return await artifact_service.list_artifact_keys( 299 | app_name=app_name, user_id=user_id, session_id=session_id 300 | ) 301 | 302 | @app.get( 303 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", 304 | response_model_exclude_none=True, 305 | ) 306 | async def list_artifact_versions( 307 | app_name: str, user_id: str, session_id: str, artifact_name: str 308 | ) -> list[int]: 309 | return await artifact_service.list_versions( 310 | app_name=app_name, 311 | user_id=user_id, 312 | session_id=session_id, 313 | filename=artifact_name, 314 | ) 315 | 316 | @app.delete( 317 | "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", 318 | ) 319 | async def delete_artifact( 320 | app_name: str, user_id: str, session_id: str, artifact_name: str 321 | ): 322 | await artifact_service.delete_artifact( 323 | app_name=app_name, 324 | user_id=user_id, 325 | session_id=session_id, 326 | filename=artifact_name, 327 | ) 328 | 329 | @app.post("/run", response_model_exclude_none=True) 330 | async def agent_run(req: AgentRunRequest) -> list[Event]: 331 | session = await session_service.get_session( 332 | 
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id 333 | ) 334 | if not session: 335 | raise HTTPException(status_code=404, detail="Session not found") 336 | runner = await _get_runner_async(req.app_name) 337 | events = [ 338 | event 339 | async for event in runner.run_async( 340 | user_id=req.user_id, 341 | session_id=req.session_id, 342 | new_message=req.new_message, 343 | ) 344 | ] 345 | logger.info("Generated %s events in agent run: %s", len(events), events) 346 | return events 347 | 348 | @app.post("/run_sse") 349 | async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: 350 | # SSE endpoint 351 | session = await session_service.get_session( 352 | app_name=req.app_name, user_id=req.user_id, session_id=req.session_id 353 | ) 354 | if not session: 355 | raise HTTPException(status_code=404, detail="Session not found") 356 | 357 | # Convert the events to properly formatted SSE 358 | async def event_generator(): 359 | try: 360 | stream_mode = (StreamingMode.SSE if req.streaming 361 | else StreamingMode.NONE) 362 | runner = await _get_runner_async(req.app_name) 363 | async for event in runner.run_async( 364 | user_id=req.user_id, 365 | session_id=req.session_id, 366 | new_message=req.new_message, 367 | run_config=RunConfig(streaming_mode=stream_mode), 368 | ): 369 | # Format as SSE data 370 | sse_event = event.model_dump_json(exclude_none=True, by_alias=True) 371 | logger.info("Generated event in agent run streaming: %s", sse_event) 372 | yield f"data: {sse_event}\n\n" 373 | except Exception as e: 374 | logger.exception("Error in event_generator: %s", e) 375 | # You might want to yield an error event here 376 | yield f'data: {{"error": "{str(e)}"}}\n\n' 377 | 378 | # Returns a streaming response with the proper media type for SSE 379 | return StreamingResponse( 380 | event_generator(), 381 | media_type="text/event-stream", 382 | ) 383 | 384 | 385 | @app.websocket("/run_live") 386 | async def agent_live_run( 387 | websocket: WebSocket, 388 | app_name: str, 389 | user_id: str, 390 | session_id: str, 391 | modalities: List[Literal["TEXT", "AUDIO"]] = Query( 392 | default=["TEXT", "AUDIO"] 393 | ), # Only allows "TEXT" or "AUDIO" 394 | ) -> None: 395 | await websocket.accept() 396 | session = await session_service.get_session( 397 | app_name=app_name, user_id=user_id, session_id=session_id 398 | ) 399 | if not session: 400 | # Accept first so that the client is aware of connection establishment, 401 | # then close with a specific code. 402 | await websocket.close(code=1002, reason="Session not found") 403 | return 404 | 405 | live_request_queue = LiveRequestQueue() 406 | 407 | async def forward_events(): 408 | runner = await _get_runner_async(app_name) 409 | async for event in runner.run_live( 410 | session=session, live_request_queue=live_request_queue 411 | ): 412 | await websocket.send_text( 413 | event.model_dump_json(exclude_none=True, by_alias=True) 414 | ) 415 | 416 | async def process_messages(): 417 | try: 418 | while True: 419 | data = await websocket.receive_text() 420 | # Validate and send the received message to the live queue. 421 | live_request_queue.send(LiveRequest.model_validate_json(data)) 422 | except ValidationError as ve: 423 | logger.error("Validation error in process_messages: %s", ve) 424 | 425 | # Run both tasks concurrently and cancel all if one fails. 
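        # asyncio.wait(..., return_when=FIRST_EXCEPTION) returns as soon as
        # either coroutine raises (or both complete); task.result() on the
        # finished tasks re-raises that exception here so it can be turned
        # into a websocket close code, and the finally block cancels
        # whichever task is still running.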
426 |         tasks = [
427 |             asyncio.create_task(forward_events()),
428 |             asyncio.create_task(process_messages()),
429 |         ]
430 |         done, pending = await asyncio.wait(
431 |             tasks, return_when=asyncio.FIRST_EXCEPTION
432 |         )
433 |         try:
434 |             # This will re-raise any exception from the completed tasks.
435 |             for task in done:
436 |                 task.result()
437 |         except WebSocketDisconnect:
438 |             logger.info("Client disconnected during process_messages.")
439 |         except Exception as e:
440 |             logger.exception("Error during live websocket communication: %s", e)
441 |             traceback.print_exc()
442 |             WEBSOCKET_INTERNAL_ERROR_CODE = 1011
443 |             WEBSOCKET_MAX_BYTES_FOR_REASON = 123
444 |             await websocket.close(
445 |                 code=WEBSOCKET_INTERNAL_ERROR_CODE,
446 |                 reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
447 |             )
448 |         finally:
449 |             for task in pending:
450 |                 task.cancel()
451 | 
452 |     async def _get_root_agent_async(app_name: str) -> Agent:
453 |         """Returns the root agent for the given app."""
454 |         if app_name in root_agent_dict:
455 |             return root_agent_dict[app_name]
456 |         agent_module_name = str(
457 |             Path(agent_dir).relative_to(os.getcwd())
458 |         ).replace("/", ".")
459 |         agent_module = importlib.import_module(agent_module_name)
460 |         if getattr(agent_module.agent, "root_agent", None):
461 |             root_agent = agent_module.agent.root_agent
462 |         else:
463 |             raise ValueError(f'Unable to find "root_agent" from {app_name}.')
464 | 
465 |         # Handle an awaitable root agent and await for the actual agent.
466 |         if inspect.isawaitable(root_agent):
467 |             try:
468 |                 agent, exit_stack = await root_agent
469 |                 exit_stacks.append(exit_stack)
470 |                 root_agent = agent
471 |             except Exception as e:
472 |                 raise RuntimeError(f"error getting root agent, {e}") from e
473 | 
474 |         root_agent_dict[app_name] = root_agent
475 |         return root_agent
476 | 
477 |     async def _get_runner_async(app_name: str) -> Runner:
478 |         """Returns the runner for the given app."""
479 |         if app_name in runner_dict:
480 |             return runner_dict[app_name]
481 |         root_agent = await _get_root_agent_async(app_name)
482 |         runner = Runner(
483 |             app_name=app_name,
484 |             agent=root_agent,
485 |             artifact_service=artifact_service,
486 |             session_service=session_service,
487 |             memory_service=memory_service,
488 |         )
489 |         runner_dict[app_name] = runner
490 |         return runner
491 | 
492 |     return app
493 | 
--------------------------------------------------------------------------------
/src/web/web.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License. 
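# Rendering model used throughout this app, sketched for orientation (an
# illustration of what _render_chat/_process_event below do, not code that
# runs here):
#
#   for event in session.events:
#       show text parts as chat bubbles (thoughts become "> " blockquotes)
#       show function calls/responses in expanders
#       for filename, version in event.actions.artifact_delta.items():
#           load the artifact and render it (PNG image, Vega-Lite chart
#           backed by a sibling .parquet file, JSON, markdown, SQL, CSV)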
14 | """Agent streamlit web app""" 15 | 16 | import asyncio 17 | import base64 18 | from io import BytesIO 19 | import json 20 | import logging 21 | import os 22 | import subprocess 23 | from time import time 24 | 25 | import pandas as pd 26 | import streamlit as st 27 | 28 | import yfinance as yf 29 | import matplotlib.pyplot as plt 30 | 31 | from google.genai.types import Content, Part 32 | 33 | from google.adk.events import Event, EventActions 34 | from google.adk.sessions import Session 35 | from shared.firestore_session_service import (FirestoreSessionService 36 | as SessionService) 37 | 38 | from PIL import Image 39 | 40 | from google.adk.artifacts import GcsArtifactService 41 | from agent_runtime_client import FastAPIEngineRuntime 42 | 43 | MAX_RUN_RETRIES = 10 44 | DEFAULT_USER_ID = "user@ai" 45 | DEFAULT_AGENT_NAME = "default-agent" 46 | DEFAULT_TICKERS = [ 47 | "GOOGL", 48 | "MSFT", 49 | "AMZN", 50 | "^GSPC", 51 | "^DJI", 52 | "^IXIC", 53 | ] 54 | 55 | 56 | logging.getLogger().setLevel(logging.INFO) 57 | 58 | user_agent = st.context.headers.get("User-Agent", "") 59 | if " Mobile" in user_agent: 60 | initial_sidebar_state = "collapsed" 61 | else: 62 | initial_sidebar_state = "expanded" 63 | 64 | st.set_page_config(layout="wide", 65 | page_icon=":material/bar_chart:", 66 | page_title="📊 Enterprise Data Agent", 67 | initial_sidebar_state=initial_sidebar_state) 68 | 69 | material_theme_style = """ 70 | 71 | 72 | 73 | 74 | 523 | """ 524 | st.markdown(material_theme_style, unsafe_allow_html=True) 525 | 526 | st.markdown(""" 527 | 528 |

[stripped HTML block: page header with a "leaderboard" material icon and the title "Chat with your Data"]

529 |
530 | """.strip(), unsafe_allow_html=True) 531 | st.subheader("This Agent can perform Data Analytics tasks " 532 | "over Salesforce data in BigQuery.") 533 | st.markdown("[github.com/vladkol/crm-data-agent]" 534 | "(https://goo.gle/cloud-crm-data-agent?utm_campaign=CDR_0xc245fc42_default_b417442301&utm_medium=external&utm_source=blog)") 535 | st.markdown("

[stripped HTML: "Examples of questions:" heading]

", unsafe_allow_html=True) 536 | st.markdown(""" 537 | 543 | """.strip(), unsafe_allow_html=True) 544 | 545 | hide_streamlit_style = """ 546 | 559 | """ 560 | st.markdown(hide_streamlit_style, unsafe_allow_html=True) 561 | 562 | ######################### Tickers data ######################### 563 | 564 | # --- DATA FETCHING & HELPERS --- 565 | @st.cache_resource(ttl="5min", show_spinner=False) # Cache data for 5 minutes 566 | def get_ticker_data(symbols: list[str]) -> list[dict]: 567 | """Fetches historical and current data for given ticker symbols.""" 568 | results = [] 569 | try: 570 | with st.spinner(f"Getting tickers data..."): 571 | results = [] 572 | tickers = yf.Tickers(symbols) 573 | 574 | # Get historical data for the sparkline (last 7 days, 1-hour interval) 575 | _ = tickers.history( 576 | period="7d", 577 | interval="1h", 578 | ) 579 | _ = tickers.history( 580 | period="2d", 581 | ) 582 | 583 | for s in symbols: 584 | ticker = tickers.tickers[s] 585 | hist = ticker.history( 586 | period="7d", 587 | interval="1h", 588 | ) 589 | if hist.empty: 590 | continue 591 | daily_hist = ticker.history( 592 | period="2d", 593 | ) 594 | 595 | # Get data for price and change (last 2 days) 596 | if len(daily_hist) < 2: 597 | # Fallback if not enough data for change calculation 598 | price = hist["Close"].iloc[-1] 599 | change = 0 600 | percent_change = 0 601 | else: 602 | price = daily_hist["Close"].iloc[-1] 603 | prev_close = daily_hist["Close"].iloc[-2] 604 | change = price - prev_close 605 | percent_change = (change / prev_close) * 100 606 | 607 | # Use the provided friendly name or fallback to the ticker info 608 | info = ticker.info 609 | company_name = info.get('longName', s.upper()) 610 | symbol_display = info.get('symbol', s).replace('^', '') 611 | 612 | results.append({ 613 | "symbol_display": symbol_display, 614 | "name": company_name, 615 | "history": hist['Close'], 616 | "price": price, 617 | "change": change, 618 | "percent_change": percent_change 619 | }) 620 | except Exception as e: 621 | st.error(f"Error when fetching stock data", icon="⚠️") 622 | return results 623 | 624 | ######################### Event rendering ######################### 625 | # Add a callback function to handle feedback 626 | async def handle_feedback(feedback_key: str, feedback_type: str): 627 | """ 628 | This function is called when a feedback button is clicked. 629 | """ 630 | with st.spinner(text="Sending..."): 631 | if feedback_type == "like": 632 | st.toast( 633 | "Great! 
Thanks for your feedback!", 634 | icon=":material/thumb_up:" 635 | ) 636 | else: 637 | st.toast( 638 | "Thank you for the feedback!", 639 | icon=":material/thumb_down:" 640 | ) 641 | await st.session_state.session_service.append_event( 642 | session=st.session_state.adk_session, 643 | event=Event( 644 | author="InternalUpdater", 645 | actions= EventActions( 646 | state_delta={ 647 | f"{feedback_key}": feedback_type 648 | } 649 | ) 650 | ) 651 | ) 652 | 653 | @st.fragment 654 | def _process_function_calls(function_calls): 655 | title = f"⚡ {', '.join([fc.name for fc in function_calls])}" 656 | with st.expander(title): 657 | for fc in function_calls: 658 | title = f'**{fc.name}**' 659 | if fc.id: 660 | title += f' ({fc.id})' 661 | st.write(title) 662 | st.write(fc.args) 663 | 664 | 665 | @st.fragment 666 | def _process_function_responses(function_responses): 667 | title = f"✔️ {', '.join([fr.name for fr in function_responses])}" 668 | with st.expander(title): 669 | for fr in function_responses: 670 | title = f'**{fr.name}**' 671 | if fr.id: 672 | title += f' ({fr.id})' 673 | st.write(title) 674 | st.write(fr.response) 675 | 676 | 677 | async def _process_event(event: Event) -> bool: 678 | if not event: 679 | return False 680 | session = st.session_state.adk_session 681 | artifact_service = st.session_state.artifact_service 682 | 683 | function_calls = [] 684 | function_responses = [] 685 | 686 | if event.content and event.content.parts: 687 | content = event.content 688 | for part in content.parts: # type: ignore 689 | if part.text and not part.text.strip(): 690 | continue 691 | if part.thought and part.text: 692 | msg = '\n'.join('> %s' % f 693 | for f in part.text.strip().split()) # type: ignore 694 | else: 695 | msg = part.text or "" 696 | msg = msg.strip() 697 | if part.function_call: 698 | function_calls.append(part.function_call) 699 | elif part.function_response: 700 | function_responses.append(part.function_response) 701 | if msg: 702 | if content.role == "model": 703 | msg_role = "ai" 704 | elif content.role == "user": 705 | msg_role = "human" 706 | else: 707 | msg_role = "assistant" 708 | with st.chat_message( 709 | msg_role, 710 | avatar=(":material/person:" if msg_role == "human" 711 | else ":material/robot_2:") 712 | ): 713 | st.markdown(msg, unsafe_allow_html=True) 714 | # Add feedback buttons for AI responses 715 | if msg_role in ["ai", "assistant"]: 716 | feedback_key = f"feedback_{event.id}" 717 | 718 | feedback_status = None 719 | feedback_states = session.state 720 | if feedback_key in feedback_states: 721 | feedback_status = feedback_states[feedback_key] 722 | 723 | # Adjust column width for better spacing 724 | col1, col2, _ = st.columns([0.05, 0.05, 0.9]) 725 | like_button_type = "secondary" 726 | dislike_button_type = "secondary" 727 | 728 | with col1: 729 | if feedback_status == "like": 730 | like_button_type = "primary" 731 | if st.button( 732 | "", 733 | icon=":material/thumb_up:", 734 | type=like_button_type, 735 | key=f"like_{feedback_key}" 736 | ): 737 | feedback_status = "like" 738 | await handle_feedback(feedback_key, "like") 739 | st.rerun() 740 | 741 | with col2: 742 | if feedback_status == "dislike": 743 | dislike_button_type = "primary" 744 | if st.button( 745 | "", 746 | icon=":material/thumb_down:", 747 | type=dislike_button_type, 748 | key=f"dislike_{feedback_key}" 749 | ): 750 | feedback_status = "dislike" 751 | await handle_feedback(feedback_key, "dislike") 752 | st.rerun() 753 | 754 | 755 | 756 | 757 | 758 | if event.actions.artifact_delta: 759 | for 
filename, version in event.actions.artifact_delta.items(): 760 | artifact = await artifact_service.load_artifact( 761 | app_name=session.app_name, user_id=session.user_id, 762 | session_id=session.id, filename=filename, version=version 763 | ) 764 | if not artifact.inline_data or not artifact.inline_data.data: 765 | continue 766 | if (artifact.inline_data.mime_type.startswith('image/')): 767 | # skip images with the invocation id filename 768 | if filename.startswith(f"{event.invocation_id}."): 769 | continue 770 | with BytesIO(artifact.inline_data.data) as image_io: 771 | with Image.open(image_io) as img: 772 | st.image(img) 773 | elif (artifact.inline_data.mime_type == 774 | "application/vnd.vegalite.v5+json" 775 | or filename.endswith(".vg") 776 | and (artifact.inline_data.mime_type in 777 | ["application/json", "text/plain"]) 778 | ): 779 | # find a parquet file to supply the chart with data 780 | data_file_name = filename.rsplit(".", 1)[0] + ".parquet" 781 | parquet_file = await artifact_service.load_artifact( 782 | app_name=session.app_name, 783 | user_id=session.user_id, 784 | session_id=session.id, 785 | filename=data_file_name, 786 | version=version) 787 | if parquet_file and parquet_file.inline_data: 788 | pq_bytes = parquet_file.inline_data.data # type: ignore 789 | else: 790 | pq_bytes = None 791 | text = artifact.inline_data.data.decode("utf-8") 792 | chart_dict = json.loads(text) 793 | if pq_bytes: 794 | with BytesIO(pq_bytes) as pq_file: 795 | df = pd.read_parquet(pq_file) 796 | st.dataframe(df) 797 | chart_dict.pop("data", None) 798 | else: 799 | df = None 800 | st.vega_lite_chart(data=df, 801 | spec=chart_dict, 802 | use_container_width=False) 803 | elif artifact.inline_data.mime_type == "application/json": 804 | st.json(artifact.inline_data.data.decode("utf-8")) 805 | elif artifact.inline_data.mime_type == "text/markdown": 806 | st.markdown(artifact.inline_data.data.decode("utf-8"), 807 | unsafe_allow_html=True) 808 | elif artifact.inline_data.mime_type == "text/x-sql": 809 | st.markdown("```sql\n" + 810 | artifact.inline_data.data.decode("utf-8") + 811 | "\n```\n", 812 | unsafe_allow_html=True) 813 | elif artifact.inline_data.mime_type == "text/csv": 814 | st.markdown( 815 | "```csv\n" + 816 | artifact.inline_data.data.decode("utf-8") + "\n```", 817 | unsafe_allow_html=True) 818 | elif artifact.inline_data.mime_type.startswith("text/"): 819 | st.text(artifact.inline_data.data.decode("utf-8")) 820 | 821 | if function_calls: 822 | _process_function_calls(function_calls) 823 | if function_responses: 824 | _process_function_responses(function_responses) 825 | return True 826 | 827 | 828 | async def _render_chat(events): 829 | for event in events: 830 | await _process_event(event) 831 | 832 | 833 | ######################### Configuration management ######################### 834 | 835 | def _get_user_id() -> str: 836 | """Retrieves user id (email) from the environment 837 | for using with the session service 838 | 839 | Returns: 840 | str: user id for the session service 841 | """ 842 | if "agent_user_id" in st.session_state: 843 | return st.session_state["agent_user_id"] 844 | 845 | user_id = st.context.headers.get( 846 | "X-Goog-Authenticated-User-Email", "").split(":", 1)[-1] 847 | if not user_id: 848 | try: 849 | user_id = ( 850 | subprocess.check_output( 851 | ( 852 | "gcloud config list account " 853 | "--format \"value(core.account)\" " 854 | f"--project {os.environ['GOOGLE_CLOUD_PROJECT']} " 855 | "-q" 856 | ), 857 | shell=True, 858 | ) 859 | .decode() 860 | .strip() 
def _get_user_id() -> str:
    """Retrieves the user id (email) from the environment
    for use with the session service.

    Returns:
        str: user id for the session service
    """
    if "agent_user_id" in st.session_state:
        return st.session_state["agent_user_id"]

    user_id = st.context.headers.get(
        "X-Goog-Authenticated-User-Email", "").split(":", 1)[-1]
    if not user_id:
        try:
            user_id = (
                subprocess.check_output(
                    (
                        "gcloud config list account "
                        "--format \"value(core.account)\" "
                        f"--project {os.environ['GOOGLE_CLOUD_PROJECT']} "
                        "-q"
                    ),
                    shell=True,
                )
                .decode()
                .strip()
            )
        except subprocess.CalledProcessError:
            user_id = ""
    if not user_id:
        user_id = DEFAULT_USER_ID
    st.session_state["agent_user_id"] = user_id
    st.session_state["agent_user_name"] = user_id
    return user_id


async def _initialize_configuration():
    if "adk_configured" in st.session_state:
        return st.session_state.adk_configured
    agent_app_name = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID",
                               os.getenv("AGENT_NAME", DEFAULT_AGENT_NAME))
    vertex_ai_bucket = os.environ["AI_STORAGE_BUCKET"]
    session_service = SessionService(
        database=os.environ["FIRESTORE_SESSION_DATABASE"],
        sessions_collection=os.getenv("FIRESTORE_SESSION_COLLECTION", "/")
    )
    artifact_service = GcsArtifactService(
        bucket_name=vertex_ai_bucket
    )
    st.session_state.artifact_service = artifact_service
    st.session_state.session_service = session_service
    st.session_state.app_name = agent_app_name
    st.session_state.adk_configured = True
    st.session_state.last_prompt = ""


######################### Session management #########################

async def _create_session() -> Session:
    if "adk_session" not in st.session_state:
        session = await st.session_state.session_service.create_session(
            app_name=st.session_state.app_name,
            user_id=_get_user_id()
        )
        st.session_state.adk_session = session
        st.session_state.all_adk_sessions = (
            st.session_state.get("all_adk_sessions") or [])
        st.session_state.all_adk_sessions.insert(0, session)
    return st.session_state.adk_session


async def _get_all_sessions() -> list[Session]:
    if "all_adk_sessions" in st.session_state:
        return st.session_state.all_adk_sessions
    sessions_response = await st.session_state.session_service.list_sessions(
        app_name=st.session_state.app_name,
        user_id=_get_user_id())
    sessions = sessions_response.sessions
    st.session_state.all_adk_sessions = sessions or []
    return st.session_state.all_adk_sessions

### Watchlist ###

def create_sparkline_svg(data, color):
    fig, ax = plt.subplots(figsize=(4, 1))
    ax.plot(data.index, data.values, color=color, linewidth=2)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.tick_params(axis='both', which='both', length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)
    fig.patch.set_alpha(0.0)
    ax.patch.set_alpha(0.0)
    svg_buffer = BytesIO()
    fig.savefig(svg_buffer, format='svg', bbox_inches='tight',
                pad_inches=0, transparent=True)
    plt.close(fig)
    # Encode the SVG as base64 so it can be embedded as a data URI
    svg_base64 = base64.b64encode(svg_buffer.getvalue()).decode("utf-8")
    return f"data:image/svg+xml;base64,{svg_base64}"


def load_watchlist():
    with open(os.path.join(os.path.dirname(__file__), "images/logo.svg")) as f:
        svg = base64.b64encode(f.read().encode("utf-8")).decode("utf-8")
    # Sidebar header: app logo and title
    st.markdown(f"""
<div style="display: flex; align-items: center; gap: 8px;">
    <img src="data:image/svg+xml;base64,{svg}" height="28" alt="logo"/>
    <span style="font-size: 1.2em; font-weight: 600;">Enterprise Data Agent</span>
</div>
""".strip(), unsafe_allow_html=True)
    st.markdown("### Watchlist")
    for data in get_ticker_data(DEFAULT_TICKERS):
        if data:
            is_positive = data['change'] >= 0
            color = "#26A69A" if is_positive else "#EF5350"
            arrow_class = "arrow-up" if is_positive else "arrow-down"
            change_class = "positive" if is_positive else "negative"
            arrow_char = "▲" if is_positive else "▼"
            sparkline_uri = create_sparkline_svg(data['history'], color=color)

            # Create the HTML structure for one ticker row
            html = f"""
<div class="ticker-row" style="display: flex; align-items: center; gap: 8px;">
    <div class="ticker-info" style="flex: 1;">
        <div class="ticker-symbol" style="font-weight: 600;">{data['symbol_display']}</div>
        <div class="ticker-name" style="font-size: 0.8em; opacity: 0.7;">{data['name']}</div>
    </div>
    <div class="ticker-sparkline">
        <img src="{sparkline_uri}" alt="Sparkline chart"/>
    </div>
    <div class="ticker-price" style="text-align: right;">
        <div class="price">{data['price']:.2f}</div>
        <div class="change {change_class}" style="color: {color};">
            {data['percent_change']:+.2f}%
            <span class="{arrow_class}">{arrow_char}</span>
        </div>
    </div>
</div>
"""
            st.html(html)

######################### Agent Request Handler #########################

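# ask_agent streams a single turn through the runtime. It persists a
# RUNNING_QUERY flag in session state around the call and retries the whole
# run (up to MAX_RUN_RETRIES) when the model fails with
# MALFORMED_FUNCTION_CALL before emitting any valid model event.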
970 | """ 971 | st.html(html) 972 | 973 | ######################### Agent Request Handler ######################### 974 | 975 | async def ask_agent(question: str): 976 | start = time() 977 | session = st.session_state.adk_session 978 | st.session_state.last_prompt = question 979 | content = Content(parts=[ 980 | Part.from_text(text=question) 981 | ],role="user") 982 | 983 | user_event = Event(author="user", content=content) 984 | await _render_chat([user_event]) 985 | 986 | runtime_name = os.environ["RUNTIME_ENVIRONMENT"].lower() 987 | if runtime_name == "local": 988 | runtime = FastAPIEngineRuntime(session) 989 | else: 990 | ValueError(f"`{runtime_name}` is not a valid runtime name.") 991 | 992 | model_events_cnt = 0 # Count valid model events in this run 993 | await st.session_state.session_service.append_event( 994 | session=st.session_state.adk_session, 995 | event=Event( 996 | author="user", 997 | actions=EventActions( 998 | state_delta={ 999 | "RUNNING_QUERY": True, 1000 | "user_name": st.session_state.get("agent_user_name", "") 1001 | } 1002 | ) 1003 | ) 1004 | ) 1005 | try: 1006 | st.session_state.thinking = True 1007 | for _ in range(MAX_RUN_RETRIES): 1008 | if model_events_cnt > 0: 1009 | break 1010 | async for event in runtime.stream_query(question): 1011 | # If no valid model events in this run, but got an func call error, 1012 | # retry the run 1013 | if (event.error_code 1014 | and event.error_code == "MALFORMED_FUNCTION_CALL" 1015 | and model_events_cnt == 0): 1016 | print("Retrying the run") 1017 | break 1018 | if event.content and event.content.role == "model": 1019 | model_events_cnt += 1 1020 | await _render_chat([event]) 1021 | await st.session_state.session_service.append_event( 1022 | session=st.session_state.adk_session, 1023 | event=Event( 1024 | author="user", 1025 | actions=EventActions( 1026 | state_delta={ 1027 | "RUNNING_QUERY": False 1028 | } 1029 | ) 1030 | ) 1031 | ) 1032 | # Re-retrieve the session 1033 | st.session_state.adk_session = await st.session_state.session_service.get_session( 1034 | app_name=session.app_name, 1035 | user_id=session.user_id, 1036 | session_id=session.id 1037 | ) 1038 | finally: 1039 | st.session_state.thinking = False 1040 | end = time() 1041 | st.text(f"Flow duration: {end - start:.2f}s") 1042 | 1043 | 1044 | ######################### Streamlit main flow ######################### 1045 | 1046 | async def app(): 1047 | top = st.container() 1048 | 1049 | if "adk_configured" not in st.session_state: 1050 | with st.spinner("Initializing...", show_time=False): 1051 | await _initialize_configuration() 1052 | sessions_list = await _get_all_sessions() 1053 | session_ids = [s.id for s in sessions_list] 1054 | session_service = st.session_state.session_service 1055 | current_session = None 1056 | current_index = 0 1057 | if "session" in st.query_params: 1058 | selected_session_id = st.query_params["session"] 1059 | if selected_session_id in session_ids: 1060 | selected_session_id = st.query_params["session"] 1061 | if ( 1062 | "adk_session" in st.session_state 1063 | and st.session_state.adk_session.id != selected_session_id 1064 | ): 1065 | st.session_state.pop("adk_session") 1066 | current_index = session_ids.index(selected_session_id) 1067 | else: 1068 | st.query_params.pop("session") 1069 | selected_session_id = -1 1070 | else: 1071 | selected_session_id = -1 1072 | if "adk_session" in st.session_state: 1073 | current_session = st.session_state.adk_session 1074 | elif selected_session_id != -1: 1075 | selected_session = 
async def app():
    top = st.container()

    if "adk_configured" not in st.session_state:
        with st.spinner("Initializing...", show_time=False):
            await _initialize_configuration()
    sessions_list = await _get_all_sessions()
    session_ids = [s.id for s in sessions_list]
    session_service = st.session_state.session_service
    current_session = None
    current_index = 0
    if "session" in st.query_params:
        selected_session_id = st.query_params["session"]
        if selected_session_id in session_ids:
            # Switch to the session from the URL if it differs
            if (
                "adk_session" in st.session_state
                and st.session_state.adk_session.id != selected_session_id
            ):
                st.session_state.pop("adk_session")
            current_index = session_ids.index(selected_session_id)
        else:
            st.query_params.pop("session")
            selected_session_id = -1
    else:
        selected_session_id = -1
    if "adk_session" in st.session_state:
        current_session = st.session_state.adk_session
    elif selected_session_id != -1:
        selected_session = sessions_list[current_index]
        with st.spinner("Loading...", show_time=False):
            current_session = await session_service.get_session(
                app_name=selected_session.app_name,
                user_id=selected_session.user_id,
                session_id=selected_session.id
            )
            st.session_state.adk_session = current_session

    if not current_session:
        with st.spinner("Creating a new session...", show_time=False):
            current_session = await _create_session()
            st.session_state.adk_session = current_session
            st.rerun()
    else:
        st.query_params["session"] = current_session.id

    with st.sidebar:
        load_watchlist()
        with st.popover("Sessions"):
            if st.button("New Session"):
                st.query_params["session"] = "none"
                st.session_state.pop("adk_session", None)
                st.rerun()

            sessions_list = st.session_state.all_adk_sessions
            session_ids = [s.id for s in sessions_list]
            selected_option = st.selectbox("Select a session:",
                                           session_ids,
                                           index=current_index)
            if selected_option and selected_option != current_session.id:  # type: ignore
                with st.spinner("Loading...", show_time=False):
                    st.query_params["session"] = selected_option
                    st.rerun()
    with top:
        await _render_chat(st.session_state.adk_session.events)  # type: ignore
    question = st.chat_input(
        "Ask a question about your data.",
        disabled=st.session_state.get("thinking", False)
    )
    if question and "question" not in current_session.state:
        current_session.state["question"] = question
    with top:
        with st.spinner("Thinking...", show_time=True):
            if question:
                await ask_agent(question)


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(app())
--------------------------------------------------------------------------------