├── BedrockChatUI.ipynb
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Docker
│   ├── Dockerfile
│   ├── app.py
│   └── requirements.txt
├── LICENSE
├── README.md
├── bedrock-chat.py
├── config.json
├── images
│   ├── JP-lab.PNG
│   ├── chat-flow.png
│   ├── chat-preview.JPG
│   ├── chatbot-snip.PNG
│   ├── chatbot4.png
│   ├── demo.mp4
│   ├── sg-rules.PNG
│   └── studio-new-launcher.png
├── install_package.sh
├── model_id.json
├── pricing.json
├── prompt
│   ├── chat.txt
│   ├── doc_chat.txt
│   ├── pyspark_debug_prompt.txt
│   ├── pyspark_tool_prompt.txt
│   ├── pyspark_tool_system.txt
│   ├── pyspark_tool_template.json
│   ├── python_debug_prompt.txt
│   ├── python_tool_prompt.txt
│   ├── python_tool_system.txt
│   └── python_tool_template.json
├── req.txt
└── utils
    ├── athena_handler_.py
    └── function_calling_utils.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to work on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/Docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM public.ecr.aws/lambda/python:3.12
2 |
3 | # Copy function code
4 | COPY app.py ${LAMBDA_TASK_ROOT}
5 |
6 | # Install the function's dependencies
7 | COPY requirements.txt .
8 | RUN pip install -r requirements.txt
9 |
10 | # Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
11 | CMD [ "app.lambda_handler" ]
--------------------------------------------------------------------------------
/Docker/app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 | import os
4 | import boto3
5 | import sys
6 | import re
7 |
8 | class CodeExecutionError(Exception):
9 | pass
10 |
11 | def local_code_executy(code_string):
12 | """
13 | Execute a given Python code string in a secure, isolated environment and capture its output.
14 |
15 | Parameters:
16 | code_string (str): The Python code to be executed.
17 |
18 | Returns:
19 | dict: The output of the executed code, expected to be in JSON format.
20 |
21 | Raises:
22 | CodeExecutionError: Custom exception with detailed error information if code execution fails.
23 |
24 |
25 | Functionality:
26 | 1. Creates a temporary Python file with a unique name in the /tmp directory.
27 | 2. Writes the provided code to this temporary file.
28 | 3. Executes the temporary file using the Python interpreter.
29 | 4. Captures the output from a predefined output file ('/tmp/output.json').
30 | 5. Cleans up temporary files after execution.
31 | 6. In case of execution errors, provides detailed error information including:
32 | - The error message and traceback.
33 | - The line number where the error occurred.
34 | - Context of the code around the error line.
35 |
36 | Note:
37 | - This function assumes that the executed code writes its output to '/tmp/output.json'. The code that saves the output locally is appended to the generated code by the main application.
38 | """
39 | # Create a unique filename in /tmp
40 | temp_file_path = f"/tmp/code_{os.urandom(16).hex()}.py"
41 | output_file_path = '/tmp/output.json'
42 | try:
43 | # Write the code to the temporary file
44 | with open(temp_file_path, 'w', encoding="utf-8") as temp_file:
45 | temp_file.write(code_string)
46 |
47 | # Execute the temporary file
48 | result = subprocess.run([sys.executable, temp_file_path],
49 | capture_output=True, text=True, check=True)
50 |
51 | with open(output_file_path, 'r', encoding="utf-8") as f:
52 | output = json.load(f)
53 |
54 | # Clean up temporary files
55 | os.remove(output_file_path)
56 |
57 | return output
58 |
59 | except subprocess.CalledProcessError as e:
60 | # An error occurred during execution
61 | full_error_message = e.stderr.strip()
62 |
63 | # Extract the traceback part of the error message
64 | traceback_match = re.search(r'(Traceback[\s\S]*)', full_error_message)
65 | if traceback_match:
66 | error_message = traceback_match.group(1)
67 | else:
68 | error_message = full_error_message # Fallback to full message if no traceback found
69 |
70 | # Parse the traceback to get the line number
71 | tb_lines = error_message.split('\n')
72 | line_no = None
73 | for line in reversed(tb_lines):
74 | if temp_file_path in line:
75 | match = re.search(r'line (\d+)', line)
76 | if match:
77 | line_no = int(match.group(1))
78 | break
79 |
80 | # Construct error message with context
81 | error = f"Error: {error_message}\n"
82 | if line_no is not None:
83 | code_lines = code_string.split('\n')
84 | context_lines = 2
85 | start = max(0, line_no - 1 - context_lines)
86 | end = min(len(code_lines), line_no + context_lines)
87 | error += f"Error on line {line_no}:\n"
88 | for i, line in enumerate(code_lines[start:end], start=start+1):
89 | prefix = "-> " if i == line_no else " "
90 | error += f"{prefix}{i}: {line}\n"
91 | else:
92 | error += "Could not determine the exact line of the error.\n"
93 | error += "Full code:\n"
94 | for i, line in enumerate(code_string.split('\n'), start=1):
95 | error += f"{i}: {line}\n"
96 |
97 | raise CodeExecutionError(error)
98 |
99 | finally:
100 | # Clean up the temporary file
101 | if os.path.exists(temp_file_path):
102 | os.remove(temp_file_path)
103 |
104 |
105 |
106 | def execute_function_string(input_code):
107 | """
108 | Execute a given Python code string
109 | Parameters:
110 | input_code (dict): A dictionary containing the following keys:
111 | - 'code' (str): The Python code to be executed.
112 | - 'dataset_name' (str or list, optional): Name(s) of the dataset(s) used in the code.
113 | Returns:
114 | The result of executing the code using the local_code_executy function.
115 |
116 | """
117 | code_string = input_code['code']
118 | return local_code_executy(code_string)
119 |
120 |
121 | def put_obj_in_s3_bucket_(docs, bucket, key_prefix):
122 | """Uploads a file to an S3 bucket and returns the S3 URI of the uploaded object.
123 | Args:
124 | docs (str): The local file path of the file to upload to S3.
125 | bucket (str): S3 bucket name,
126 | key_prefix (str): S3 key prefix.
127 | Returns:
128 | str: The S3 URI of the uploaded object, in the format "s3://{bucket_name}/{file_path}".
129 | """
130 | S3 = boto3.client('s3')
131 | if isinstance(docs, str):
132 | file_name = os.path.basename(docs)
133 | file_path = f"{key_prefix}/{docs}"
134 | S3.upload_file(f"/tmp/{docs}", bucket, file_path)
135 | else:
136 | file_name = os.path.basename(docs.name)
137 | file_path = f"{key_prefix}/{file_name}"
138 | S3.put_object(Body=docs.read(), Bucket=bucket, Key=file_path)
139 | return f"s3://{bucket}/{file_path}"
140 |
141 |
142 |
143 | def lambda_handler(event, context):
144 | try:
145 | input_data = json.loads(event) if isinstance(event, str) else event
146 | iterate = input_data.get('iterate', 0)
147 | bucket = input_data.get('bucket', '')
148 | s3_file_path = input_data.get('file_path', '')
149 | print(input_data, bucket, s3_file_path, iterate)
150 | result = execute_function_string(input_data)
151 | image_holder = []
152 | plotly_holder = []
153 |
154 | if isinstance(result, dict):
155 | for item, value in result.items():
156 | if "image" in item and value is not None:
157 | if isinstance(value, list):
158 | for img in value:
159 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path)
160 | image_holder.append(image_path_s3)
161 | else:
162 | image_path_s3 = put_obj_in_s3_bucket_(value, bucket, s3_file_path)
163 | image_holder.append(image_path_s3)
164 | if "plotly-files" in item and value is not None: # Upload plotly objects to s3
165 | if isinstance(value, list):
166 | for img in value:
167 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path)
168 | plotly_holder.append(image_path_s3)
169 | else:
170 | image_path_s3 = put_obj_in_s3_bucket_(value, bucket, s3_file_path)
171 | plotly_holder.append(image_path_s3)
172 |
173 | tool_result = {
174 | "result": result,
175 | "image_dict": image_holder,
176 | "plotly": plotly_holder
177 | }
178 | print(tool_result)
179 | return {
180 | 'statusCode': 200,
181 | 'body': json.dumps(tool_result)
182 | }
183 | except Exception as e:
184 | print(e)
185 | return {
186 | 'statusCode': 500,
187 | 'body': json.dumps({'error': str(e)})
188 | }
--------------------------------------------------------------------------------
/Docker/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3
2 | scikit-learn
3 | pandas
4 | numpy
5 | seaborn
6 | matplotlib
7 | s3fs
8 | openpyxl
9 | plotly>=5.0.0,<6.0.0
10 | kaleido
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bedrock Claude ChatBot
2 | Bedrock Chat App is a Streamlit application that allows users to interact with various LLMs on Amazon Bedrock. It provides a conversational interface where users can ask questions, upload documents, and receive responses from the AI assistant.
3 |
4 |
5 |
6 |
7 | READ THE FOLLOWING **PREREQUISITES** CAREFULLY.
8 |
9 | ## Features
10 |
11 | - **Conversational UI**: The app provides a chat-like interface for seamless interaction with the AI assistant.
12 | - **Document Upload**: Users can upload various types of documents (PDF, CSV, TXT, PNG, JPG, XLSX, JSON, DOCX, Python scripts etc) to provide context for the AI assistant.
13 | - **Caching**: Uploaded documents and extracted text are cached in an S3 bucket for improved performance. This serves as the object storage unit for the application as documents are retrieved and loaded into the model to keep conversation context.
14 | - **Chat History**: The app stores and retrieves chat history (including document metadata) to/from a DynamoDB table, allowing users to continue conversations across sessions.
15 | - **Session Store**: The application utilizes DynamoDB to store and manage user and session information, enabling isolated conversations and state tracking for each user interaction.
16 | - **Model Selection**: Users can select from a broad list of LLMs on Amazon Bedrock, including the latest models from Anthropic Claude, Amazon Nova, Meta Llama, DeepSeek, etc., for their queries, and can include additional models on Bedrock by modifying the `model_id.json` file. It incorporates the Bedrock Converse API, which provides a standardized model interface (a minimal sketch follows this list).
17 | - **Cost Tracking**: The application calculates and displays the cost associated with each chat session based on the input and output token counts and the pricing model defined in the `pricing.json` file.
18 | - **Logging**: The items logged in the DynamoDB table include the user ID, session ID, messages, timestamps, uploaded documents' S3 paths, and input and output token counts. This helps to isolate user engagement statistics and track the various items being logged, as well as attribute the cost per user.
19 | - **Tool Usage**: **`Advanced Data Analytics tool`** for processing and analyzing structured data (CSV, XLS and XLSX formats) in an isolated, serverless environment.
20 | - **Extensible Tool Integration**: This app can be modified to leverage the extensive Domain Specific Language (DSL) knowledge inherent in Large Language Models (LLMs) to implement a wide range of specialized tools. This capability is enhanced by the versatile execution environments provided by Docker containers and AWS Lambda, allowing for dynamic and adaptable implementation of various DSL-based functionalities. This approach enables the system to handle diverse domain-specific tasks efficiently, without the need for hardcoded, specialized modules for each domain.
21 |
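For orientation, the Converse API gives every supported model the same request and response shape, which is what makes the model list easy to extend. A minimal sketch is shown below; the region and model ID are placeholders rather than values taken from this repo:

```python
import boto3

# Placeholder region and model ID; use any Converse-compatible model enabled in your account.
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

response = bedrock_runtime.converse(
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": [{"text": "Summarize this repo in one sentence."}]}],
    inferenceConfig={"maxTokens": 256},
)
print(response["output"]["message"]["content"][0]["text"])
```

Because only `modelId` changes between models, adding an entry to `model_id.json` is usually all that is needed to expose a new model in the UI.
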
22 | There are two files of interest.
23 | 1. A Jupyter Notebook that walks you through the ChatBot implementation cell by cell (Advanced Data Analytics is only available in the Streamlit chatbot).
24 | 2. A Streamlit app that can be deployed to create a UI Chatbot.
25 |
26 | ## Pre-Requisites
27 | 1. [Amazon Bedrock Anthropic Claude Model Access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html)
28 | 2. [S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/create-bucket-overview.html) to store uploaded documents and Textract output.
29 | 3. Optional:
30 | - Create an Amazon DynamoDB table to store chat history (run the notebook **BedrockChatUI** to create a DynamoDB table; a minimal table-creation sketch is also shown after this list). This is optional as there is a local disk storage option; however, I would recommend using Amazon DynamoDB.
31 | - Amazon Textract. This is optional as there is an option to use the python libraries [`pypdf2`](https://pypi.org/project/PyPDF2/) and [`pytesseract`](https://pypi.org/project/pytesseract/) for PDF and image processing. However, I would recommend using Amazon Textract for higher quality PDF and image processing. You will experience latency when using `pytesseract`.
32 | - [Amazon Elastic Container Registry](https://docs.aws.amazon.com/AmazonECR/latest/userguide/repository-create.html) to store custom docker images if using the `Advanced Data Analytics` feature with the AWS Lambda setup.
33 |
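If you prefer not to run the notebook, a minimal table-creation sketch is shown below. It assumes the key schema the app uses when reading and writing chat items (`UserId` as the partition key, `SessionId` as the sort key); the table name is a placeholder:

```python
import boto3

dynamodb = boto3.client("dynamodb")

# "ChatHistory" is a placeholder; use the name you put in the DynamodbTable field of config.json.
dynamodb.create_table(
    TableName="ChatHistory",
    KeySchema=[
        {"AttributeName": "UserId", "KeyType": "HASH"},      # partition key
        {"AttributeName": "SessionId", "KeyType": "RANGE"},  # sort key
    ],
    AttributeDefinitions=[
        {"AttributeName": "UserId", "AttributeType": "S"},
        {"AttributeName": "SessionId", "AttributeType": "S"},
    ],
    BillingMode="PAY_PER_REQUEST",  # on-demand capacity, so no throughput guessing
)
```
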
34 | To use the **Advanced Analytics Feature**, this additional step is required (ChatBot can still be used without enabling `Advanced Analytics Feature`):
35 |
36 | This feature can be powered by a **python** runtime on AWS Lambda and/or a **pyspark** runtime on Amazon Athena. Expand the appropriate section below to view the set-up instructions.
37 |
38 |
39 | AWS Lambda Python Runtime Setup
40 |
41 | ## AWS Lambda Function with Custom Python Image
42 |
43 | 5. [AWS Lambda](https://docs.aws.amazon.com/lambda/latest/dg/python-image.html#python-image-clients) function with a custom Python image to execute Python code for analytics.
44 | - Create a private ECR repository by following the link in step 3.
45 | - On your local machine or any related AWS services including [AWS CloudShell](https://docs.aws.amazon.com/cloudshell/latest/userguide/welcome.html), [Amazon Elastic Compute Cloud](https://aws.amazon.com/ec2/getting-started/), [Amazon SageMaker Studio](https://aws.amazon.com/blogs/machine-learning/accelerate-ml-workflows-with-amazon-sagemaker-studio-local-mode-and-docker-support/) etc., run the following CLI commands:
46 | - install git and clone this git repo `git clone [github_link]`
47 | - navigate into the Docker directory `cd Docker`
48 | - if using local machine, authenticate with your [AWS credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-authentication.html)
49 | - install [AWS Command Line Interface (AWS CLI) version 2](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) if not already installed.
50 | - Follow the steps in the **Deploying the image** section under **Using an AWS base image for Python** in this [documentation guide](https://docs.aws.amazon.com/lambda/latest/dg/python-image.html#python-image-instructions). Replace the placeholders with the appropriate values. You can skip step `2` if you already created an ECR repository.
51 | - In step 6, in addition to the `AWSLambdaBasicExecutionRole` policy, **ONLY** grant [least-privileged read and write Amazon S3 policies](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_s3_rw-bucket.html) to the execution role. Scope down the policy to only include the necessary S3 bucket and S3 directory prefix where uploaded files will be stored and read from, as configured in the `config.json` file below.
52 | - In step 7, I recommend creating the Lambda function in an [Amazon Virtual Private Cloud (VPC)](https://docs.aws.amazon.com/lambda/latest/dg/configuration-vpc.html) without [internet access](https://docs.aws.amazon.com/vpc/latest/userguide/vpc-example-private-subnets-nat.html) and attaching Amazon S3 and Amazon CloudWatch [gateway](https://docs.aws.amazon.com/vpc/latest/privatelink/vpc-endpoints-s3.html) and [interface endpoints](https://docs.aws.amazon.com/vpc/latest/privatelink/create-interface-endpoint.html#create-interface-endpoint.html) accordingly. The following step 7 command can be modified to include VPC parameters:
53 | ```
54 | aws lambda create-function \
55 | --function-name YourFunctionName \
56 | --package-type Image \
57 | --code ImageUri=your-account-id.dkr.ecr.your-region.amazonaws.com/your-repo:tag \
58 | --role arn:aws:iam::your-account-id:role/YourLambdaExecutionRole \
59 | --vpc-config SubnetIds=subnet-xxxxxxxx,subnet-yyyyyyyy,SecurityGroupIds=sg-zzzzzzzz \
60 | --memory-size 512 \
61 | --timeout 300 \
62 | --region your-region
63 | ```
64 |
65 | Modify the placeholders as appropriate. I recommend keeping the `timeout` and `memory-size` parameters conservative, as they affect cost. A good starting point for memory is `512` MB.
66 | - Ignore step 8. A quick way to smoke-test the deployed function is sketched below.
67 |
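Once the function is created, you can smoke-test it with a direct invocation. The sketch below mirrors the event keys that the handler in `Docker/app.py` reads (`code`, `bucket`, `file_path`, `iterate`); the function name, region, and bucket are placeholders:

```python
import json
import boto3

lambda_client = boto3.client("lambda", region_name="your-region")  # placeholder region

# The handler executes `code` and returns whatever the code wrote to /tmp/output.json.
event = {
    "code": "import json\nwith open('/tmp/output.json', 'w') as f:\n    json.dump({'answer': 42}, f)",
    "bucket": "your-sandbox-bucket",  # placeholder sandbox bucket (see config.json guidance below)
    "file_path": "path/to/files",     # S3 prefix used for any generated image/plotly artifacts
    "iterate": 0,
}

response = lambda_client.invoke(
    FunctionName="YourFunctionName",  # placeholder
    Payload=json.dumps(event),
)
print(json.loads(response["Payload"].read()))
```

A successful call returns `statusCode` 200 with the execution result in `body`; errors come back as `statusCode` 500 with the error message.
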
68 |
69 |
70 | Amazon Athena Spark Runtime Setup
71 |
72 | ## Create Amazon Athena Spark WorkGroup
73 |
74 | 5. Follow the instructions [Get started with Apache Spark on Amazon Athena](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-getting-started.html) to create an Amazon Athena workgroup with Apache Spark. You `DO NOT` need to select `Turn on example notebook`.
75 | - Provide S3 permissions to the workgroup execution role for the S3 buckets configured with this application.
76 | - Note that the Amazon Athena Spark environment comes preinstalled with a select [few python libraries](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-preinstalled-python-libraries.html).
77 |
78 |
79 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
80 |
81 |
82 |
83 | **⚠ IMPORTANT SECURITY NOTE:**
84 |
85 | Enabling the **Advanced Analytics Feature** allows the LLM to generate Python code to analyze your dataset, which is automatically executed in a Lambda function environment. To mitigate potential risks:
86 |
87 | 1. **VPC Configuration**:
88 | - It is recommended to place the Lambda function in an internet-free VPC.
89 | - Use Amazon S3 and CloudWatch gateway/interface endpoints for necessary access.
90 |
91 | 2. **IAM Permissions**:
92 | - Scope down the AWS Lambda and/or Amazon Athena workgroup execution role to only Amazon S3 and the required S3 resources. This is in addition to `AWSLambdaBasicExecutionRole` policy if using AWS Lambda.
93 |
94 | 3. **Library Restrictions**:
95 | - Only libraries specified in `Docker/requirements.txt` will be available at runtime.
96 | - Modify this list carefully based on your needs.
97 |
98 | 4. **Resource Allocation**:
99 | - Adjust AWS Lambda function `timeout` and `memory-size` based on data size and analysis complexity.
100 |
101 | 5. **Production Considerations**:
102 | - This application is designed for POC use.
103 | - Implement additional security measures before deploying to production.
104 |
105 | The goal is to limit the potential impact of generated code execution.
106 |
107 | ## Configuration
108 | The application's behavior can be customized by modifying the `config.json` file. Here are the available options (an illustrative example follows the list):
109 |
110 | - `DynamodbTable`: The name of the DynamoDB table to use for storing chat history. Leave this field empty if you decide to use local storage for chat history.
111 | - `UserId`: The DynamoDB user ID for the application. Leave this field empty if you decide to use local storage for chat history.
112 | - `Bucket_Name`: The name of the S3 bucket used for caching documents and extracted text. This is required.
113 | - `max-output-token`: The maximum number of output tokens allowed for the AI assistant's response.
114 | - `chat-history-loaded-length`: The number of recent chat messages to load from the DynamoDB table or Local storage.
115 | - `bedrock-region`: The AWS region where Amazon Bedrock is enabled.
116 | - `load-doc-in-chat-history`: A boolean flag indicating whether to load documents in the chat history. If `true`, all documents will be loaded in the chat history as context (this gives the AI more context from previous chat turns, at the cost of additional price and latency). If `false`, only the user query and response will be loaded in the chat history, and the AI will have no recollection of any document context from those chat conversations. When setting booleans in JSON, use all lowercase.
117 | - `AmazonTextract`: A boolean indicating whether to use Amazon Textract or python libraries for PDF and image processing. Set to `false` if you do not have access to Amazon Textract. When setting booleans in JSON, use all lowercase.
118 | - `csv-delimiter`: The delimiter to use when parsing structured content to string. Supported formats are "|", "\t", and ",".
119 | - `document-upload-cache-s3-path`: S3 bucket path to cache uploaded files. Do not include the bucket name, just the prefix without a trailing slash. For example "path/to/files".
120 | - `AmazonTextract-result-cache`: S3 bucket path to cache Amazon Textract result. Do not include the bucket name, just the prefix without a trailing slash. For example "path/to/files".
121 | - `lambda-function`: Name of the Lambda function deployed in the steps above. This is required if using the `Advanced Analytics Tool` with AWS Lambda.
122 | - `input_s3_path`: S3 directory prefix, without leading or trailing slashes, for the S3 objects to render in the Chat UI.
123 | - `input_bucket`: S3 bucket name where the files to be rendered on the screen are stored.
124 | - `input_file_ext`: Comma-separated file extension names (without ".") for files in the S3 buckets to be rendered on the screen. By default `xlsx` and `csv` are included.
125 | - `athena-work-group-name`: Amazon Athena Spark workgroup name created above. This is required if using the `Advanced Analytics Tool` with Amazon Athena.
126 |
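Putting these options together, an illustrative `config.json` might look like the following. All values are placeholders and the key names follow the list above; cross-check them against the `config.json` shipped in this repo, since `bedrock-chat.py` reads the keys verbatim:

```json
{
    "DynamodbTable": "ChatHistory",
    "UserId": "demo-user",
    "Bucket_Name": "your-sandbox-bucket",
    "max-output-token": 2000,
    "chat-history-loaded-length": 10,
    "bedrock-region": "us-east-1",
    "load-doc-in-chat-history": true,
    "AmazonTextract": true,
    "csv-delimiter": "|",
    "document-upload-cache-s3-path": "path/to/files",
    "AmazonTextract-result-cache": "path/to/textract",
    "lambda-function": "YourFunctionName",
    "input_s3_path": "path/to/input",
    "input_bucket": "your-input-bucket",
    "input_file_ext": "csv,xlsx",
    "athena-work-group-name": "your-spark-workgroup"
}
```
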
127 | **⚠ IMPORTANT ADVISORY FOR ADVANCED ANALYTICS FEATURE**
128 |
129 | When using the **Advanced Analytics Feature**, take the following precautions:
130 | 1. **Sandbox Environment**:
131 | - Set `Bucket_Name` and `document-upload-cache-s3-path` to point to a separate, isolated "sandbox" S3 location.
132 | - Grant read and write access as [documented](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_s3_rw-bucket.html) to this bucket/prefix resource to the lambda execution role.
133 | - Do NOT use your main storage path for these parameters. This isolation is crucial to avoid potential file overwrites, as the app will execute LLM-generated code.
134 | 2. **Input Data Safety**:
135 | - `input_s3_path` and `input_bucket` are used for read-only operations and can safely point to your main data storage. The LLM is not aware of these parameters unless they are explicitly provided by the user during chat.
136 | - Only grant read access to this bucket/prefix resource in the execution role attached to the Lambda function.
137 | - **IMPORTANT**: Ensure `input_bucket` is different from `Bucket_Name`.
138 |
139 | By following these guidelines, you mitigate the potential risk of unintended data modification or loss in your primary storage areas.
140 |
141 | ## To run this Streamlit App on SageMaker Studio, follow the steps below:
142 |
143 |
144 |
145 |
146 | 
147 |
148 | If you have a SageMaker AI Studio domain already set up, ignore the first item; however, item 2 is required.
149 | * [Set Up SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html)
150 | * The SageMaker execution role should have access to interact with [Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/api-setup.html), [S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-policy-language-overview.html) and, optionally, [Textract](https://docs.aws.amazon.com/aws-managed-policy/latest/reference/AmazonTextractFullAccess.html), [DynamoDB](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/iam-policy-specific-table-indexes.html), [AWS Lambda](https://docs.aws.amazon.com/lambda/latest/dg/access-control-identity-based.html) and [Amazon Athena](https://docs.aws.amazon.com/athena/latest/ug/managed-policies.html) if these services are used.
151 |
152 | ### On SageMaker AI Studio JupyterLab:
153 | * [Create a JupyterLab space](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-jl.html)
155 | * Open a terminal by clicking **File** -> **New** -> **Terminal**
156 | * Clone this repo, navigate into the cloned repository directory using the `cd bedrock-claude-chatbot` command, and run the following commands to install the application Python libraries:
157 | - sudo apt update
158 | - sudo apt upgrade -y
159 | - chmod +x install_package.sh
160 | - ./install_package.sh
161 | - **NOTE**: If you run into this error `ERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: /opt/conda/lib/python3.10/site-packages/fsspec-2023.6.0.dist-info/METADATA`, I solved it by deleting the older `fsspec` package with the following command (the error is due to having two versions of `fsspec` installed, 2023* and 2024*):
162 | - `rm /opt/conda/lib/python3.10/site-packages/fsspec-2023.6.0.dist-info -rdf`
163 | - pip install -U fsspec # fsspec 2024.9.0 should already be installed.
164 | * If you decide to use Python libs for PDF and image processing, this requires tesseract-ocr. Run the following commands:
165 | - sudo apt update -y
166 | - sudo apt-get install tesseract-ocr-all -y
167 | * Run command `python3 -m streamlit run bedrock-chat.py --server.enableXsrfProtection false` to start the Streamlit server. Do not use the links generated by the command as they won't work in studio.
168 | * Copy the URL of the SageMaker JupyterLab. It should look something like this https://qukigdtczjsdk.studio.us-east-1.sagemaker.aws/jupyterlab/default/lab/tree/healthlake/app_fhir.py. Replace everything after .../default/ with proxy/8501/, something like https://qukigdtczjsdk.studio.us-east-1.sagemaker.aws/jupyterlab/default/proxy/8501/. Make sure the port number (8501 in this case) matches with the port number printed out when you run the `python3 -m streamlit run bedrock-chat.py --server.enableXsrfProtection false` command; port number is the last 4 digits after the colon in the generated URL.
169 |
170 |
171 | ## To run this Streamlit App on AWS EC2 (I tested this on the Ubuntu Image)
172 | * [Create a new ec2 instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EC2_GetStarted.html)
173 | * Expose TCP port range 8500-8510 on Inbound connections of the attached Security group to the ec2 instance. TCP port 8501 is needed for Streamlit to work. See image below
174 | * See `images/sg-rules.PNG` in this repository for an example of the inbound rules.
175 | * EC2 [instance profile role](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2_instance-profiles.html) has the required permissions to access the services used by this application mentioned above.
176 | * [Connect to your ec2 instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstances.html)
177 | * Run the appropriate commands to update the ec2 instance (`sudo apt update` and `sudo apt upgrade` for Ubuntu).
178 | * Clone this git repo `git clone [github_link]` and `cd bedrock-claude-chatbot`
179 | * Install python3 and pip if not already installed, `sudo apt install python3` and `sudo apt install python3-pip`.
180 | * If you decide to use Python libs for PDF and image processing, this requires tesseract-ocr. Run the following commands:
181 | - If using Centos-OS or Amazon-Linux:
182 | - sudo rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
183 | - sudo yum -y update
184 | - sudo yum install -y tesseract
185 | - For Ubuntu or Debian:
186 | - sudo apt-get install tesseract-ocr-all -y
187 | * Install the dependencies by running the following commands (use `yum` for Centos-OS or Amazon-Linux):
188 | - sudo apt update
189 | - sudo apt upgrade -y
190 | - chmod +x install_package.sh
191 | - ./install_package.sh
192 | * Run command `tmux new -s mysession` to create a new session. Then, in the new session, `cd bedrock-claude-chatbot` into the **ChatBot** dir and run `python3 -m streamlit run bedrock-chat.py` to start the Streamlit app. This allows you to run the Streamlit application in the background and keep it running even if you disconnect from the terminal session.
193 | * Copy the **External URL** link generated and paste in a new browser tab.
194 | * **⚠ NOTE:** The generated link is not secure! For additional guidance, see [deploy-streamlit-app](https://github.com/aws-samples/deploy-streamlit-app).
195 | To detach from the `tmux` session, press `Ctrl+b`, then `d` in your ec2 terminal. To kill the session, run `tmux kill-session -t mysession`.
196 |
197 | ## Limitations and Future Updates
198 | 1. **Pricing**: Pricing is only calculated for the Bedrock models and does not include the cost of any other AWS service used. In addition, the pricing information for the models is stored in a static `pricing.json` file; manually update the file to reflect current [Bedrock pricing details](https://aws.amazon.com/bedrock/pricing/). Treat the cost reported by this app as a rough estimate of the actual cost of interacting with the Bedrock models, as the actual cost reported in your account may differ.
199 |
200 | 2. **Storage Encryption**: This application does not implement storing and reading files to and from S3 and/or DynamoDB using KMS keys for data at rest encryption.
201 |
202 | 3. **Production-Ready**: For an enterprise and production-ready chatbot application architecture pattern, check out [Generative AI Application Builder on AWS](https://aws.amazon.com/solutions/implementations/generative-ai-application-builder-on-aws/) and [Bedrock-Claude-Chat](https://github.com/aws-samples/bedrock-claude-chat) for best practices and recommendations.
203 |
204 | 4. **Tools Suite**: This application only includes a single tool. However, with the many niche applications of LLMs, a library of tools would make this application more robust.
205 |
206 | ## Application Workflow Diagram
207 |
208 |
209 |
210 | **Guidelines**
211 | - When a document is uploaded (and for every chat turn while it stays uploaded), its content is attached to the user's query, and the chatbot's responses are grounded in the document (a separate prompt template is used). That chat conversation is tagged with the document name as metadata to be used in the chat history.
212 | - If the document is detached, the chat history will only contain the user's queries and the chatbot's responses, unless the `load-doc-in-chat-history` configuration parameter is enabled, in which case the document content will be retained in the chat history.
213 | - You can refer to documents by their names or format (PDF, Word, image, etc.) when having a conversation with the AI.
214 | - The `chat-history-loaded-length` setting determines how many previous conversations the LLM will be aware of, including any attached documents (if the `load-doc-in-chat-history` option is enabled). A higher value for this setting means that the LLM will have access to more historical context, but it may also increase the cost and potentially introduce latency, as more tokens will be inputted into the LLM. For optimal performance and cost-effectiveness, it's recommended to set the 'chat-history-loaded-length' to a value between 5 and 10. This range strikes a balance between providing the LLM with sufficient historical context while minimizing the input payload size and associated costs.
215 | - ⚠️ When using the Streamlit app, any uploaded document will be persisted for the current chat conversation. This means that subsequent questions, as well as chat histories (if the 'load-doc-in-chat-history' option is enabled), will have the document(s) as context, and the responses will be grounded in that document(s). However, this can increase the cost and latency, as the input payload will be larger due to the loaded document(s) in every chat turn. Therefore if you have the `load-doc-in-chat-history` option enabled, after your first question response with the uploaded document(s), it is recommended to remove the document(s) by clicking the **X** sign next to the uploaded file(s). The document(s) will be saved in the chat history, and you can ask follow-up questions about it, as the LLM will have knowledge of the document(s) from the chat history. On the other hand, if the `load-doc-in-chat-history` option is disabled, and you want to keep asking follow-up questions about the document(s), leave the document(s) uploaded until you are done. This way, only the current chat turn will have the document(s) loaded, and not the entire chat history. The choice between enabling `load-doc-in-chat-history` or not is dependent on cost and latency. I would recommend enabling for a smoother experience following the aforementioned guidelines.
216 |
--------------------------------------------------------------------------------
/bedrock-chat.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import boto3
3 | from botocore.config import Config
4 | import os
5 | import pandas as pd
6 | import time
7 | import json
8 | import io
9 | import re
10 | import openpyxl
11 | from python_calamine import CalamineWorkbook
12 | from openpyxl.cell import Cell
13 | import plotly.io as pio
14 | from openpyxl.worksheet.cell_range import CellRange
15 | from docx.table import _Cell
16 | from boto3.dynamodb.conditions import Key
17 | from pptx import Presentation
18 | from botocore.exceptions import ClientError
19 | from textractor import Textractor
20 | from textractor.data.constants import TextractFeatures
21 | from textractor.data.text_linearization_config import TextLinearizationConfig
22 | import pytesseract
23 | from PIL import Image
24 | import PyPDF2
25 | import chardet
26 | from docx import Document as DocxDocument
27 | from docx.oxml.text.paragraph import CT_P
28 | from docx.oxml.table import CT_Tbl
29 | from docx.document import Document
30 | from docx.text.paragraph import Paragraph
31 | from docx.table import Table as DocxTable
32 | import concurrent.futures
33 | from functools import partial
34 | import random
35 | from utils import function_calling_utils
36 | from urllib.parse import urlparse
37 | import plotly.graph_objects as go
38 | import numpy as np
39 | import base64
40 | config = Config(
41 | read_timeout=600, # Read timeout parameter
42 | retries=dict(
43 | max_attempts=10 # Handle retries
44 | )
45 | )
46 |
47 | st.set_page_config(initial_sidebar_state="auto")
48 |
49 | # Read app configurations
50 | with open('config.json','r',encoding='utf-8') as f:
51 | config_file = json.load(f)
52 | # pricing info
53 | with open('pricing.json','r',encoding='utf-8') as f:
54 | pricing_file = json.load(f)
55 | # Bedrock Model info
56 | with open('model_id.json','r',encoding='utf-8') as f:
57 | model_info = json.load(f)
58 |
59 | DYNAMODB = boto3.resource('dynamodb')
60 | COGNITO = boto3.client('cognito-idp')
61 | S3 = boto3.client('s3')
62 | LOCAL_CHAT_FILE_NAME = "chat-history.json"
63 | DYNAMODB_TABLE = config_file["DynamodbTable"]
64 | BUCKET = config_file["Bucket_Name"]
65 | OUTPUT_TOKEN = config_file["max-output-token"]
66 | S3_DOC_CACHE_PATH = config_file["document-upload-cache-s3-path"]
67 | TEXTRACT_RESULT_CACHE_PATH = config_file["AmazonTextract-result-cache"]
68 | LOAD_DOC_IN_ALL_CHAT_CONVO = config_file["load-doc-in-chat-history"]
69 | CHAT_HISTORY_LENGTH = config_file["chat-history-loaded-length"]
70 | DYNAMODB_USER = config_file["UserId"]
71 | REGION = config_file["region"]
72 | USE_TEXTRACT = config_file["AmazonTextract"]
73 | CSV_SEPERATOR = config_file["csv-delimiter"]
74 | INPUT_BUCKET = config_file["input_bucket"]
75 | INPUT_S3_PATH = config_file["input_s3_path"]
76 | INPUT_EXT = tuple(f".{x}" for x in config_file["input_file_ext"].split(','))
77 | MODEL_DISPLAY_NAME = list(model_info.keys())
78 | HYBRID_MODELS = ["sonnet-3.7", "sonnet-4", "opus-4"] # populate with list of hybrid reasoning models on Bedrock
79 | NON_VISION_MODELS = ["deepseek", "haiku-3.5", "nova-micro"] # populate with list of models not supporting image input on Bedrock
80 | NON_TOOL_SUPPORTING_MODELS = ["deepseek", "meta-scout", "meta-maverick"] # populate with list of models not supporting tool calling on Bedrock
81 |
82 | bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name=REGION, config=config)
83 |
84 | if 'messages' not in st.session_state:
85 | st.session_state['messages'] = []
86 | if 'input_token' not in st.session_state:
87 | st.session_state['input_token'] = 0
88 | if 'output_token' not in st.session_state:
89 | st.session_state['output_token'] = 0
90 | if 'chat_hist' not in st.session_state:
91 | st.session_state['chat_hist'] = []
92 | if 'user_sess' not in st.session_state:
93 | st.session_state['user_sess'] =str(time.time())
94 | if 'chat_session_list' not in st.session_state:
95 | st.session_state['chat_session_list'] = []
96 | if 'count' not in st.session_state:
97 | st.session_state['count'] = 0
98 | if 'userid' not in st.session_state:
99 | st.session_state['userid']= config_file["UserId"]
100 | if 'cost' not in st.session_state:
101 | st.session_state['cost'] = 0
102 | if 'reasoning_mode' not in st.session_state:
103 | st.session_state['reasoning_mode'] = False # Only activated when user selects anthropic 3.7 and toggles on thinking
104 | if 'athena-session' not in st.session_state:
105 | st.session_state['athena-session'] = ""
106 |
107 | def get_object_with_retry(bucket, key):
108 | max_retries = 5
109 | retries = 0
110 | backoff_base = 2
111 | max_backoff = 3 # Maximum backoff time in seconds
112 | s3 = boto3.client('s3')
113 | while retries < max_retries:
114 | try:
115 | response = s3.get_object(Bucket=bucket, Key=key)
116 | return response
117 | except ClientError as e:
118 | error_code = e.response['Error']['Code']
119 | if error_code == 'DecryptionFailureException':
120 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
121 | print(f"Decryption failed, retrying in {sleep_time} seconds...")
122 | time.sleep(sleep_time)
123 | retries += 1
124 | elif e.response['Error']['Code'] == 'ModelStreamErrorException':
125 | if retries < max_retries:
126 | # Throttling, exponential backoff
127 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
128 | time.sleep(sleep_time)
129 | retries += 1
130 | else:
131 | raise e
132 |
133 | # If we reach this point, it means the maximum number of retries has been exceeded
134 | raise Exception(f"Failed to get object {key} from bucket {bucket} after {max_retries} retries.")
135 |
136 | def decode_numpy_array(obj):
137 | if isinstance(obj, dict) and 'dtype' in obj and 'bdata' in obj:
138 | dtype = np.dtype(obj['dtype'])
139 | return np.frombuffer(base64.b64decode(obj['bdata']), dtype=dtype)
140 | return obj
141 |
142 | # Decode the numpy arrays in the JSON data
143 | def decode_json(obj):
144 | if isinstance(obj, dict):
145 | return {k: decode_json(decode_numpy_array(v)) for k, v in obj.items()}
146 | elif isinstance(obj, list):
147 | return [decode_json(item) for item in obj]
148 | return obj
149 |
150 | def save_chat_local(file_path, new_data, session_id):
151 | """Store long term chat history Local Disk"""
152 | try:
153 | # Read the existing JSON data from the file
154 | with open(file_path, "r",encoding='utf-8') as file:
155 | existing_data = json.load(file)
156 | if session_id not in existing_data:
157 | existing_data[session_id]=[]
158 | except FileNotFoundError:
159 | # If the file doesn't exist, initialize an empty list
160 | existing_data = {session_id:[]}
161 | # Append the new data to the existing list
162 | from decimal import Decimal
163 | data = [{k: float(v) if isinstance(v, Decimal) else v for k, v in item.items()} for item in new_data]
164 | existing_data[session_id].extend(data)
165 |
166 | # Write the updated list back to the JSON file
167 | with open(file_path, "w",encoding="utf-8") as file:
168 | json.dump(existing_data, file)
169 |
170 | def load_chat_local(file_path,session_id):
171 | """Load long term chat history from Local"""
172 | try:
173 | # Read the existing JSON data from the file
174 | with open(file_path, "r",encoding='utf-8') as file:
175 | existing_data = json.load(file)
176 | if session_id in existing_data:
177 | existing_data=existing_data[session_id]
178 | else:
179 | existing_data=[]
180 | except FileNotFoundError:
181 | # If the file doesn't exist, initialize an empty list
182 | existing_data = []
183 | return existing_data
184 |
185 |
186 | def process_files(files):
187 | """process uploaded files in parallel"""
188 | result_string=""
189 | errors = []
190 | future_proxy_mapping = {}
191 | futures = []
192 |
193 | with concurrent.futures.ProcessPoolExecutor() as executor:
194 | # Partial function to pass the handle_doc_upload_or_s3 function
195 | func = partial(handle_doc_upload_or_s3)
196 | for file in files:
197 | future = executor.submit(func, file)
198 | future_proxy_mapping[future] = file
199 | futures.append(future)
200 |
201 | # Collect the results and handle exceptions
202 | for future in concurrent.futures.as_completed(futures):
203 | file_url= future_proxy_mapping[future]
204 | try:
205 | result = future.result()
206 | doc_name=os.path.basename(file_url)
207 | result_string+=f"<{doc_name}>\n{result}\n</{doc_name}>\n" # tag documents with names to enhance prompts
208 | except Exception as e:
209 | # Get the original function arguments from the Future object
210 | error = {'file': file_url, 'error': str(e)}
211 | errors.append(error)
212 |
213 | return errors, result_string
214 |
215 | def handle_doc_upload_or_s3(file, cutoff=None):
216 | """Handle various document format"""
217 | dir_name, ext = os.path.splitext(file)
218 | if ext.lower() in [".pdf", ".png", ".jpg",".tif",".jpeg"]:
219 | content=exract_pdf_text_aws(file)
220 | elif ".csv" == ext.lower():
221 | content=parse_csv_from_s3(file,cutoff)
222 | elif ext.lower() in [".xlsx", ".xls"]:
223 | content=table_parser_utills(file,cutoff)
224 | elif ".json"==ext.lower():
225 | obj=get_s3_obj_from_bucket_(file)
226 | content = json.loads(obj['Body'].read())
227 | elif ext.lower() in [".txt",".py"]:
228 | obj=get_s3_obj_from_bucket_(file)
229 | content = obj['Body'].read()
230 | elif ".docx" == ext.lower():
231 | obj=get_s3_obj_from_bucket_(file)
232 | content = obj['Body'].read()
233 | docx_buffer = io.BytesIO(content)
234 | content = extract_text_and_tables(docx_buffer)
235 | elif ".pptx" == ext.lower():
236 | obj=get_s3_obj_from_bucket_(file)
237 | content = obj['Body'].read()
238 | docx_buffer = io.BytesIO(content)
239 | content = extract_text_from_pptx_s3(docx_buffer)
240 | # Implement any other file extension logic
241 | return content
242 |
243 | class InvalidContentError(Exception):
244 | pass
245 |
246 | def detect_encoding(s3_uri):
247 | """detect csv encoding"""
248 | s3 = boto3.client('s3')
249 | match = re.match("s3://(.+?)/(.+)", s3_uri)
250 | if match:
251 | bucket_name = match.group(1)
252 | key = match.group(2)
253 | response = s3.get_object(Bucket=bucket_name, Key=key)
254 | content = response['Body'].read()
255 | result = chardet.detect(content)
256 | return result['encoding']
257 |
258 | def parse_csv_from_s3(s3_uri, cutoff):
259 | """read csv files"""
260 | try:
261 | # Detect the file encoding using chardet
262 | encoding = detect_encoding(s3_uri)
263 | # Sniff the delimiter and read the CSV file
264 | df = pd.read_csv(s3_uri, delimiter=None, engine='python', encoding=encoding)
265 | if cutoff:
266 | df=df.iloc[:20]
267 | return df.to_csv(index=False, sep=CSV_SEPERATOR)
268 | except Exception as e:
269 | raise InvalidContentError(f"Error: {e}")
270 |
271 | def iter_block_items(parent):
272 | if isinstance(parent, Document):
273 | parent_elm = parent.element.body
274 | elif isinstance(parent, _Cell):
275 | parent_elm = parent._tc
276 | else:
277 | raise ValueError("something's not right")
278 |
279 | for child in parent_elm.iterchildren():
280 | if isinstance(child, CT_P):
281 | yield Paragraph(child, parent)
282 | elif isinstance(child, CT_Tbl):
283 | yield DocxTable(child, parent)
284 |
285 | def extract_text_and_tables(docx_path):
286 | """ Extract text from docx files"""
287 | document = DocxDocument(docx_path)
288 | content = ""
289 | current_section = ""
290 | section_type = None
291 | for block in iter_block_items(document):
292 | if isinstance(block, Paragraph):
293 | if block.text:
294 | if block.style.name == 'Heading 1':
295 | # Close the current section if it exists
296 | if current_section:
297 | content += f"{current_section}</{section_type}>\n"
298 | current_section = ""
299 | section_type = None
300 | section_type ="h1"
301 | content += f"<{section_type}>{block.text}</{section_type}>\n"
302 | elif block.style.name== 'Heading 3':
303 | # Close the current section if it exists
304 | if current_section:
305 | content += f"{current_section}</{section_type}>\n"
306 | current_section = ""
307 | section_type = "h3"
308 | content += f"<{section_type}>{block.text}</{section_type}>\n"
309 | elif block.style.name == 'List Paragraph':
310 | # Add to the current list section
311 | if section_type != "list":
312 | # Close the current section if it exists
313 | if current_section:
314 | content += f"{current_section}</{section_type}>\n"
315 | section_type = "list"
316 | current_section = ""
317 | current_section += f"{block.text}\n"
318 | elif block.style.name.startswith('toc'):
319 | # Add to the current toc section
320 | if section_type != "toc":
321 | # Close the current section if it exists
322 | if current_section:
323 | content += f"{current_section}</{section_type}>\n"
324 | section_type = "toc"
325 | current_section = ""
326 | current_section += f"{block.text}\n"
327 | else:
328 | # Close the current section if it exists
329 | if current_section:
330 | content += f"{current_section}</{section_type}>\n"
331 | current_section = ""
332 | section_type = None
333 |
334 | # Append the passage text without tagging
335 | content += f"{block.text}\n"
336 |
337 | elif isinstance(block, DocxTable):
338 | # Add the current section before the table
339 | if current_section:
340 | content += f"{current_section}</{section_type}>\n"
341 | current_section = ""
342 | section_type = None
343 |
344 | content += "<table>\n"
345 | for row in block.rows:
346 | row_content = []
347 | for cell in row.cells:
348 | cell_content = []
349 | for nested_block in iter_block_items(cell):
350 | if isinstance(nested_block, Paragraph):
351 | cell_content.append(nested_block.text)
352 | elif isinstance(nested_block, DocxTable):
353 | nested_table_content = parse_nested_table(nested_block)
354 | cell_content.append(nested_table_content)
355 | row_content.append(CSV_SEPERATOR.join(cell_content))
356 | content += CSV_SEPERATOR.join(row_content) + "\n"
357 | content += "</table>\n"
358 |
359 | # Add the final section
360 | if current_section:
361 | content += f"{current_section}</{section_type}>\n"
362 | return content
363 |
364 | def parse_nested_table(table):
365 | nested_table_content = "<table>\n"
366 | for row in table.rows:
367 | row_content = []
368 | for cell in row.cells:
369 | cell_content = []
370 | for nested_block in iter_block_items(cell):
371 | if isinstance(nested_block, Paragraph):
372 | cell_content.append(nested_block.text)
373 | elif isinstance(nested_block, DocxTable):
374 | nested_table_content += parse_nested_table(nested_block)
375 | row_content.append(CSV_SEPERATOR.join(cell_content))
376 | nested_table_content += CSV_SEPERATOR.join(row_content) + "\n"
377 | nested_table_content += "</table>"
378 | return nested_table_content
379 |
380 |
381 |
382 | def extract_text_from_pptx_s3(pptx_buffer):
383 | """ Extract Text from pptx files"""
384 | presentation = Presentation(pptx_buffer)
385 | text_content = []
386 | for slide in presentation.slides:
387 | slide_text = []
388 | for shape in slide.shapes:
389 | if hasattr(shape, 'text'):
390 | slide_text.append(shape.text)
391 | text_content.append('\n'.join(slide_text))
392 | return '\n\n'.join(text_content)
393 |
394 | def exract_pdf_text_aws(file):
395 | """extract text from PDFs using Amazon Textract or PyPDF2"""
396 | file_base_name = os.path.basename(file)
397 | dir_name, ext = os.path.splitext(file)
398 | # Checking if extracted doc content is in S3
399 | if USE_TEXTRACT:
400 | if [x for x in get_s3_keys(f"{TEXTRACT_RESULT_CACHE_PATH}/") if file_base_name in x]:
401 | response = get_object_with_retry(BUCKET, f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt")
402 | text = response['Body'].read().decode()
403 | return text
404 | else:
405 |
406 | extractor = Textractor()
407 | # Asynchronous call, you will experience some wait time. Try caching results for better experience
408 | if "pdf" in ext:
409 | print("Asynchronous call, you may experience some wait time.")
410 | document = extractor.start_document_analysis(
411 | file_source=file,
412 | features=[TextractFeatures.LAYOUT, TextractFeatures.TABLES],
413 | save_image=False,
414 | s3_output_path=f"s3://{BUCKET}/textract_output/"
415 | )
416 | # Synchronous call
417 | else:
418 | document = extractor.analyze_document(
419 | file_source=file,
420 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES],
421 | save_image=False,
422 | )
423 | config = TextLinearizationConfig(
424 | hide_figure_layout=False,
425 | hide_header_layout=False,
426 | table_prefix="<table>",
427 | table_suffix="</table>",
428 | )
429 | # Upload extracted content to s3
430 | S3.put_object(Body=document.get_text(config=config), Bucket=BUCKET, Key=f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt")
431 | return document.get_text(config=config)
432 | else:
433 | s3 = boto3.resource("s3")
434 | match = re.match("s3://(.+?)/(.+)", file)
435 | if match:
436 | bucket_name = match.group(1)
437 | key = match.group(2)
438 | if "pdf" in ext:
439 | pdf_bytes = io.BytesIO()
440 | s3.Bucket(bucket_name).download_fileobj(key, pdf_bytes)
441 | # Read the PDF from the BytesIO object
442 | pdf_bytes.seek(0)
443 | # Create a PDF reader object
444 | pdf_reader = PyPDF2.PdfReader(pdf_bytes)
445 | # Get the number of pages in the PDF
446 | num_pages = len(pdf_reader.pages)
447 | # Extract text from each page
448 | text = ''
449 | for page_num in range(num_pages):
450 | page = pdf_reader.pages[page_num]
451 | text += page.extract_text()
452 | else:
453 | img_bytes = io.BytesIO()
454 | s3.Bucket(bucket_name).download_fileobj(key, img_bytes)
455 | img_bytes.seek(0)
456 | image_stream = io.BytesIO(img_bytes.getvalue())
457 | image = Image.open(image_stream)
458 | text = pytesseract.image_to_string(image)
459 | return text
460 |
461 | def strip_newline(cell):
462 | return str(cell).strip()
463 |
464 | def table_parser_openpyxl(file, cutoff):
465 | """convert xlsx files to python pandas handling merged cells"""
466 | # Read from S3
467 | s3 = boto3.client('s3')
468 | match = re.match("s3://(.+?)/(.+)", file)
469 | if match:
470 | bucket_name = match.group(1)
471 | key = match.group(2)
472 | obj = s3.get_object(Bucket=bucket_name, Key=key)
473 | # Read Excel file from S3 into a buffer
474 | xlsx_buffer = io.BytesIO(obj['Body'].read())
475 | xlsx_buffer.seek(0)
476 | # Load workbook
477 | wb = openpyxl.load_workbook(xlsx_buffer)
478 | all_sheets_string = ""
479 | # Iterate over each sheet in the workbook
480 | for sheet_name in wb.sheetnames:
481 | # all_sheets_name.append(sheet_name)
482 | worksheet = wb[sheet_name]
483 |
484 | all_merged_cell_ranges: list[CellRange] = list(
485 | worksheet.merged_cells.ranges
486 | )
487 | for merged_cell_range in all_merged_cell_ranges:
488 | merged_cell: Cell = merged_cell_range.start_cell
489 | worksheet.unmerge_cells(range_string=merged_cell_range.coord)
490 | for row_index, col_index in merged_cell_range.cells:
491 | cell: Cell = worksheet.cell(row=row_index, column=col_index)
492 | cell.value = merged_cell.value
493 | # Convert sheet data to a DataFrame
494 | df = pd.DataFrame(worksheet.values)
495 | df = df.map(strip_newline)
496 | if cutoff:
497 | df = df.iloc[:20]
498 |
499 | # Convert to string and tag by sheet name
500 | tabb=df.to_csv(sep=CSV_SEPERATOR, index=False, header=0)
501 | all_sheets_string += f'<{sheet_name}>\n{tabb}\n</{sheet_name}>\n'
502 | return all_sheets_string
503 | else:
504 | raise Exception(f"{file} not formatted as an S3 path")
505 |
506 | def calamaine_excel_engine(file,cutoff):
507 | # # Read from S3
508 | s3 = boto3.client('s3')
509 | match = re.match("s3://(.+?)/(.+)", file)
510 | if match:
511 | bucket_name = match.group(1)
512 | key = match.group(2)
513 | obj = s3.get_object(Bucket=bucket_name, Key=key)
514 | # Read Excel file from S3 into a buffer
515 | xlsx_buffer = io.BytesIO(obj['Body'].read())
516 | xlsx_buffer.seek(0)
517 | all_sheets_string = ""
518 | # Load the Excel file
519 | workbook = CalamineWorkbook.from_filelike(xlsx_buffer)
520 | # Iterate over each sheet in the workbook
521 | for sheet_name in workbook.sheet_names:
522 | # Get the sheet by name
523 | sheet = workbook.get_sheet_by_name(sheet_name)
524 | df = pd.DataFrame(sheet.to_python(skip_empty_area=False))
525 | df = df.map(strip_newline)
526 | if cutoff:
527 | df = df.iloc[:20]
528 | # print(df)
529 | tabb = df.to_csv(sep=CSV_SEPERATOR, index=False, header=0)
530 | all_sheets_string += f'<{sheet_name}>\n{tabb}\n</{sheet_name}>\n'
531 | return all_sheets_string
532 | else:
533 | raise Exception(f"{file} not formatted as an S3 path")
534 |
535 | def table_parser_utills(file,cutoff):
536 | try:
537 | response = table_parser_openpyxl(file, cutoff)
538 | if response:
539 | return response
540 | else:
541 | return calamaine_excel_engine(file, cutoff)
542 | except Exception as e:
543 | try:
544 | return calamaine_excel_engine(file, cutoff)
545 | except Exception as e:
546 | raise Exception(str(e))
547 |
548 |
549 | def put_db(params,messages):
550 | """Store long term chat history in DynamoDB"""
551 | chat_item = {
552 | "UserId": st.session_state['userid'], # user id
553 | "SessionId": params["session_id"], # User session id
554 | "messages": [messages], # 'messages' is a list of dictionaries
555 | "time":messages['time']
556 | }
557 |
558 | existing_item = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId":params["session_id"]})
559 | if "Item" in existing_item:
560 | existing_messages = existing_item["Item"]["messages"]
561 | chat_item["messages"] = existing_messages + [messages]
562 | response = DYNAMODB.Table(DYNAMODB_TABLE).put_item(
563 | Item=chat_item
564 | )
565 |
566 |
567 | def get_chat_history_db(params,cutoff,vision_model):
568 | """process chat histories from local or DynamoDb to format expected by model input"""
569 | current_chat, chat_hist = [], []
570 | if params['chat_histories'] and cutoff != 0:
571 | chat_hist = params['chat_histories'][-cutoff:]
572 | for d in chat_hist:
573 | if d['image'] and vision_model and LOAD_DOC_IN_ALL_CHAT_CONVO:
574 | content = []
575 | for img in d['image']:
576 | s3 = boto3.client('s3')
577 | match = re.match("s3://(.+?)/(.+)", img)
578 | image_name = os.path.basename(img)
579 | _, ext = os.path.splitext(image_name)
580 | if "jpg" in ext:
581 | ext = ".jpeg"
582 | # if match:
583 | bucket_name = match.group(1)
584 | key = match.group(2)
585 | obj = s3.get_object(Bucket=bucket_name, Key=key)
586 | bytes_image = obj['Body'].read()
587 | content.extend([{"text": image_name}, {
588 | "image": {
589 | "format": f"{ext.lower().replace('.','')}",
590 | "source": {"bytes": bytes_image}
591 | }
592 | }])
593 | content.extend([{"text":d['user']}])
594 | current_chat.append({'role': 'user', 'content': content})
595 | elif d['document'] and LOAD_DOC_IN_ALL_CHAT_CONVO:
596 | # Handle scenario where a tool is used for a dataset that exceeds the model context length
597 | if 'tool_use_id' in d and d['tool_use_id']:
598 | doc = 'Here are the documents:\n'
599 | for docs in d['document']:
600 | uploads = handle_doc_upload_or_s3(docs,20)
601 | doc_name = os.path.basename(docs)
602 | doc += f"<{doc_name}>\n{uploads}\n{doc_name}>\n"
603 | else:
604 | doc = 'Here are the documents:\n'
605 | for docs in d['document']:
606 | uploads = handle_doc_upload_or_s3(docs)
607 | doc_name = os.path.basename(docs)
608 | doc += f"<{doc_name}>\n{uploads}\n{doc_name}>\n"
609 | if not vision_model and d["image"]:
610 | for docs in d['image']:
611 | uploads = handle_doc_upload_or_s3(docs)
612 | doc_name = os.path.basename(docs)
613 | doc += f"<{doc_name}>\n{uploads}\n{doc_name}>\n"
614 | current_chat.append({'role': 'user', 'content': [{"text": doc+d['user']}]})
615 | else:
616 | current_chat.append({'role': 'user', 'content': [{"text": d['user']}]})
617 | current_chat.append({'role': 'assistant', 'content': [{"text": d['assistant']}]})
618 | else:
619 | chat_hist = []
620 | return current_chat, chat_hist
621 |
622 |
623 | def get_s3_keys(prefix):
624 | """list all keys in an s3 path"""
625 | s3 = boto3.client('s3')
626 | keys = []
627 | next_token = None
628 | while True:
629 | if next_token:
630 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix, ContinuationToken=next_token)
631 | else:
632 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix)
633 | if "Contents" in response:
634 | for obj in response['Contents']:
635 | key = obj['Key']
636 | name = key[len(prefix):]
637 | keys.append(name)
638 | if "NextContinuationToken" in response:
639 | next_token = response["NextContinuationToken"]
640 | else:
641 | break
642 | return keys
643 |
644 | def parse_s3_uri(uri):
645 | """
646 | Parse an S3 URI and extract the bucket name and key.
647 |
648 | :param uri: S3 URI (e.g., 's3://bucket-name/path/to/file.txt')
649 | :return: Tuple of (bucket_name, key) if valid, (None, None) if invalid
650 | """
651 | pattern = r'^s3://([^/]+)/(.*)$'
652 | match = re.match(pattern, uri)
653 | if match:
654 | return match.groups()
655 | return (None, None)
656 |
657 | def copy_s3_object(source_uri, dest_bucket, dest_key):
658 | """
659 | Copy an object from one S3 location to another.
660 |
661 | :param source_uri: S3 URI of the source object
662 | :param dest_bucket: Name of the destination bucket
663 | :param dest_key: Key to be used for the destination object
664 | :return: True if successful, False otherwise
665 | """
666 | s3 = boto3.client('s3')
667 |
668 | # Parse the source URI
669 | source_bucket, source_key = parse_s3_uri(source_uri)
670 | if not source_bucket or not source_key:
671 | print(f"Invalid source URI: {source_uri}")
672 | return False
673 |
674 | try:
675 | # Create a copy source dictionary
676 | copy_source = {
677 | 'Bucket': source_bucket,
678 | 'Key': source_key
679 | }
680 |
681 | # Copy the object
682 | s3.copy_object(CopySource=copy_source, Bucket=dest_bucket, Key=f"{dest_key}/{source_key}")
683 |
684 | print(f"File copied from {source_uri} to s3://{dest_bucket}/{dest_key}/{source_key}")
685 | return f"s3://{dest_bucket}/{dest_key}/{source_key}"
686 |
687 | except ClientError as e:
688 | print(f"An error occurred: {e}")
689 | raise(e)
690 | # return False
691 |
692 | def plotly_to_png_bytes(s3_uri):
693 | """
694 | Read a .plotly file from S3 given an S3 URI, convert it to a PNG image, and return the image as bytes.
695 |
696 | :param s3_uri: S3 URI of the .plotly file (e.g., 's3://bucket-name/path/to/file.plotly')
697 | :return: PNG image as bytes
698 | """
699 | # Parse S3 URI
700 | parsed_uri = urlparse(s3_uri)
701 | bucket_name = parsed_uri.netloc
702 | file_key = parsed_uri.path.lstrip('/')
703 |
704 | # Initialize S3 client
705 | s3_client = boto3.client('s3')
706 |
707 | try:
708 | # Read the .plotly file from S3
709 | response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
710 | plotly_data = json.loads(response['Body'].read().decode('utf-8'))
711 |
712 | # Create a Figure object from the plotly data
713 | fig = go.Figure(data=plotly_data['data'], layout=plotly_data.get('layout', {}))
714 |
715 | # Convert the figure to PNG bytes
716 | img_bytes = fig.to_image(format="png")
717 |
718 | return img_bytes
719 |
720 | except Exception as e:
721 | print(f"An error occurred: {str(e)}")
722 | return None
723 |
724 |
725 | def get_s3_obj_from_bucket_(file):
726 | s3 = boto3.client('s3')
727 | match = re.match("s3://(.+?)/(.+)", file)
728 | if match:
729 | bucket_name = match.group(1)
730 | key = match.group(2)
731 | obj = s3.get_object(Bucket=bucket_name, Key=key)
732 | return obj
733 |
734 | def put_obj_in_s3_bucket_(docs):
735 | if isinstance(docs,str):
736 | s3_uri_pattern = r'^s3://([^/]+)/(.*?([^/]+)/?)$'
737 | if bool(re.match(s3_uri_pattern, docs)):
738 | file_uri=copy_s3_object(docs, BUCKET, S3_DOC_CACHE_PATH)
739 | return file_uri
740 | else:
741 | file_name = os.path.basename(docs.name)
742 | file_path = f"{S3_DOC_CACHE_PATH}/{file_name}"
743 | S3.put_object(Body=docs.read(), Bucket= BUCKET, Key=file_path)
744 | return f"s3://{BUCKET}/{file_path}"
745 |
746 |
747 | def bedrock_streemer(params,response, handler):
748 | """ stream response from bedrock Runtime"""
749 | text = ''
750 | think = ""
751 | signature = ""
752 | for chunk in response['stream']:
753 | if 'contentBlockDelta' in chunk:
754 | delta = chunk['contentBlockDelta']['delta']
755 | # print(chunk)
756 | if 'text' in delta:
757 | text += delta['text']
758 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True)
759 | if 'reasoningContent' in delta:
760 | if "text" in delta['reasoningContent']:
761 | think += delta['reasoningContent']['text']
762 | handler.markdown('**MODEL REASONING**\n\n' + think.replace("$", "\\$"), unsafe_allow_html=True)
763 | elif "signature" in delta['reasoningContent']:
764 | signature = delta['reasoningContent']['signature']
765 |
766 | elif "metadata" in chunk:
767 | if 'cacheReadInputTokens' in chunk['metadata']['usage']:
768 | print(f"\nCache Read Tokens: {chunk['metadata']['usage']['cacheReadInputTokens']}")
769 | print(f"Cache Write Tokens: {chunk['metadata']['usage']['cacheWriteInputTokens']}")
770 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"]
771 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"]
772 | latency = chunk['metadata']['metrics']["latencyMs"]
773 | pricing = st.session_state['input_token']*pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"]
774 | st.session_state['cost']+=pricing
775 | print(f"\nInput Tokens: {st.session_state['input_token']}\nOutput Tokens: {st.session_state['output_token']}\nLatency: {latency}ms")
776 | return text, think
777 |
778 | def bedrock_claude_(params, chat_history, system_message, prompt,
779 | model_id, image_path=None, handler=None):
780 | """ format user request and chat history and make a call to Bedrock Runtime"""
781 | chat_history_copy = chat_history[:]
782 | content = []
783 | if image_path:
784 | if not isinstance(image_path, list):
785 | image_path = [image_path]
786 | for img in image_path:
787 | s3 = boto3.client('s3', region_name=REGION)
788 | match = re.match("s3://(.+?)/(.+)", img)
789 | image_name = os.path.basename(img)
790 | _, ext = os.path.splitext(image_name)
791 | if "jpg" in ext:
792 | ext = ".jpeg"
793 | bucket_name = match.group(1)
794 | key = match.group(2)
795 | if ".plotly" in key:
796 | print(key)
797 | bytes_image = plotly_to_png_bytes(img)
798 | ext = ".png"
799 | else:
800 | obj = s3.get_object(Bucket=bucket_name, Key=key)
801 | bytes_image = obj['Body'].read()
802 | content.extend([{"text":image_name},{
803 | "image": {
804 | "format": f"{ext.lower().replace('.', '')}",
805 | "source": {"bytes": bytes_image}
806 | }
807 | }])
808 |
809 | content.append({
810 | "text": prompt
811 | })
812 | chat_history_copy.append({"role": "user",
813 | "content": content})
814 | system_message = [{"text": system_message}]
815 |
816 | if st.session_state['reasoning_mode']:
817 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id,
818 | inferenceConfig={"maxTokens": 18000, "temperature": 1},
819 | system=system_message,
820 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}}
821 | )
822 | else:
823 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id,
824 | inferenceConfig={"maxTokens": 4000 if "deepseek" not in model_id else 20000,
825 | "temperature": 0.5 if "deepseek" not in model_id else 0.6,
826 | },
827 | system=system_message,
828 | )
829 | answer, think=bedrock_streemer(params, response, handler)
830 | return answer, think
831 |
832 | def _invoke_bedrock_with_retries(params, current_chat, chat_template, question, model_id, image_path, handler):
833 | max_retries = 10
834 | backoff_base = 2
835 | max_backoff = 3 # Maximum backoff time in seconds
836 | retries = 0
837 |
838 | while True:
839 | try:
840 | response, think = bedrock_claude_(params,current_chat, chat_template, question, model_id, image_path, handler)
841 | return response, think
842 | except ClientError as e:
843 | if e.response['Error']['Code'] == 'ThrottlingException':
844 | if retries < max_retries:
845 | # Throttling, exponential backoff
846 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
847 | time.sleep(sleep_time)
848 | retries += 1
849 | else:
850 | raise e
851 | elif e.response['Error']['Code'] == 'ModelStreamErrorException':
852 | if retries < max_retries:
853 | # Throttling, exponential backoff
854 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
855 | time.sleep(sleep_time)
856 | retries += 1
857 | else:
858 | raise e
859 | elif e.response['Error']['Code'] == 'EventStreamError':
860 | if retries < max_retries:
861 | # Throttling, exponential backoff
862 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
863 | time.sleep(sleep_time)
864 | retries += 1
865 | else:
866 | raise e
867 | else:
868 | # Some other API error, rethrow
869 | raise
870 |
871 | def get_session_ids_by_user(table_name, user_id):
872 | """
873 | Get Session Ids and corresponding top message for a user to populate the chat history drop down on the front end
874 | """
875 | if DYNAMODB_TABLE:
876 | table = DYNAMODB.Table(table_name)
877 | message_list = {}
878 | session_ids = []
879 | args = {
880 | 'KeyConditionExpression': Key('UserId').eq(user_id)
881 | }
882 | while True:
883 | response = table.query(**args)
884 | session_ids.extend([item['SessionId'] for item in response['Items']])
885 | if 'LastEvaluatedKey' not in response:
886 | break
887 | args['ExclusiveStartKey'] = response['LastEvaluatedKey']
888 |
889 | for session_id in session_ids:
890 | try:
891 | message_list[session_id] = DYNAMODB.Table(table_name).get_item(Key={"UserId": user_id, "SessionId": session_id})['Item']['messages'][0]['user']
892 | except Exception as e:
893 | print(e)
894 | pass
895 | else:
896 | try:
897 | message_list={}
898 | # Read the existing JSON data from the file
899 | with open(LOCAL_CHAT_FILE_NAME, "r", encoding='utf-8') as file:
900 | existing_data = json.load(file)
901 | for session_id in existing_data:
902 | message_list[session_id]=existing_data[session_id][0]['user']
903 |
904 | except FileNotFoundError:
905 | # If the file doesn't exist, initialize an empty list
906 | message_list = {}
907 | return message_list
908 |
909 | def list_csv_xlsx_in_s3_folder(bucket_name, folder_path):
910 | """
911 | List all CSV and XLSX files in a specified S3 folder.
912 |
913 | :param bucket_name: Name of the S3 bucket
914 | :param folder_path: Path to the folder in the S3 bucket
915 | :return: List of CSV and XLSX file names in the folder
916 | """
917 | s3 = boto3.client('s3')
918 | csv_xlsx_files = []
919 |
920 | try:
921 | # Ensure the folder path ends with a '/'
922 | if not folder_path.endswith('/'):
923 | folder_path += '/'
924 |
925 | # List objects in the specified folder
926 | paginator = s3.get_paginator('list_objects_v2')
927 | page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=folder_path)
928 |
929 | for page in page_iterator:
930 | if 'Contents' in page:
931 | for obj in page['Contents']:
932 | # Get the file name
933 | file_name = obj['Key']
934 |
935 | # Check if the file is a CSV or XLSX
936 | if file_name.lower().endswith(INPUT_EXT):
937 | csv_xlsx_files.append(os.path.basename(file_name))
938 | # csv_xlsx_files.append(file_name)
939 |
940 | return csv_xlsx_files
941 |
942 | except ClientError as e:
943 | print(f"An error occurred: {e}")
944 | return []
945 |
946 | def query_llm(params, handler):
947 | """
948 | Handles user requests and routes them to a native call or tool use, then stores the conversation locally or in DynamoDB
949 | """
950 |
951 | if not isinstance(params['upload_doc'], list):
952 | raise TypeError("documents must be in a list format")
953 |
954 | vision_model = True
955 | model = 'us.' + model_info[params['model']]
956 | if any(keyword in [params['model']] for keyword in NON_VISION_MODELS):
957 | vision_model = False
958 |
959 | # prompt template for when a user uploads a doc
960 | doc_path = []
961 | image_path = []
962 | full_doc_path = []
963 | doc = ""
964 | if params['tools']:
965 | messages, tool, results, image_holder, doc_list, stop_reason, plotly_fig = function_calling_utils.function_caller_claude_(params, handler)
966 | if stop_reason != "tool_use":
967 | return messages
968 | elif stop_reason == "tool_use":
969 | prompt = f"""You are a conversational AI Assitant.
970 | I will provide you with an question on a dataset, a python code that implements the solution to the question and the result of that code solution.
971 | Here is the question:
972 |
973 | {params['question']}
974 |
975 |
976 | Here is the python code:
977 |
978 | {tool['input']['code']}
979 |
980 |
981 | Here is the result of the code:
982 |
983 | {results}
984 |
985 |
986 | After reading the user question, respond with a detailed analytical answer based entirely on the result from the code. Do NOT make up answers.
987 | When providing your response:
988 | - Do not include any preamble, go straight to the answer.
989 | - It should not be obvious you are referencing the result."""
990 |
991 | system_message="You always provide your response in a well presented format using markdown. Make use of tables, list etc. where necessary in your response, so information is well preseneted and easily read."
992 | answer, think = _invoke_bedrock_with_retries(params, [], system_message, prompt, model, image_holder, handler)
993 |
994 | chat_history = {
995 | "user": results["text"] if "text" in results else "",
996 | "assistant": answer,
997 | "image": image_holder,
998 | "document": [], # data_file,#doc_list,
999 | "plotly": plotly_fig,
1000 | "modelID": model,
1001 | "thinking": think,
1002 | "code": tool['input']['code'],
1003 | "time": str(time.time()),
1004 | "input_token": round(st.session_state['input_token']) ,
1005 | "output_token": round(st.session_state['output_token']),
1006 | "tool_result_id": tool['toolUseId'],
1007 | "tool_name": '',
1008 | "tool_params": ''
1009 | }
1010 | # store conversation memory in DynamoDB table
1011 | if DYNAMODB_TABLE:
1012 | put_db(params, chat_history)
1013 | # use local disk for storage
1014 | else:
1015 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params["session_id"])
1016 | return answer
1017 | else:
1018 | current_chat, chat_hist = get_chat_history_db(params, CHAT_HISTORY_LENGTH, vision_model)
1019 | if params['upload_doc'] or params['s3_objects']:
1020 | if params['upload_doc']:
1021 | doc = 'I have provided documents and/or images.\n'
1022 | for ids, docs in enumerate(params['upload_doc']):
1023 | file_name = docs.name
1024 | _, extensions = os.path.splitext(file_name)
1025 | docs = put_obj_in_s3_bucket_(docs)
1026 | full_doc_path.append(docs)
1027 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model:
1028 | image_path.append(docs)
1029 | continue
1030 |
1031 | if params['s3_objects']:
1032 | doc = 'I have provided documents and/or images.\n'
1033 | for ids, docs in enumerate(params['s3_objects']):
1034 | file_name = docs
1035 | _, extensions = os.path.splitext(file_name)
1036 | docs = put_obj_in_s3_bucket_(f"s3://{INPUT_BUCKET}/{INPUT_S3_PATH}/{docs}")
1037 | full_doc_path.append(docs)
1038 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model:
1039 | image_path.append(docs)
1040 | continue
1041 |
1042 | doc_path = [item for item in full_doc_path if item not in image_path]
1043 | errors, result_string = process_files(doc_path)
1044 | if errors:
1045 | st.error(errors)
1046 | doc += result_string
1047 | with open("prompt/doc_chat.txt", "r", encoding="utf-8") as f:
1048 | chat_template = f.read()
1049 | else:
1050 | # Chat template for open ended query
1051 | with open("prompt/chat.txt", "r", encoding="utf-8") as f:
1052 | chat_template = f.read()
1053 |
1054 | response, think = _invoke_bedrock_with_retries(params, current_chat, chat_template,
1055 | doc+params['question'], model,
1056 | image_path, handler)
1057 | # log the following items to dynamodb
1058 | chat_history = {
1059 | "user": params['question'],
1060 | "assistant": response,
1061 | "image": image_path,
1062 | "document": doc_path,
1063 | "modelID": model,
1064 | "thinking": think,
1065 | "time": str(time.time()),
1066 | "input_token": round(st.session_state['input_token']),
1067 | "output_token": round(st.session_state['output_token'])
1068 | }
1069 | # store conversation memory and other user items in DynamoDB table
1070 | if DYNAMODB_TABLE:
1071 | put_db(params, chat_history)
1072 | # use local memory for storage
1073 | else:
1074 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params["session_id"])
1075 | return response
1076 |
1077 |
1078 | def get_chat_historie_for_streamlit(params):
1079 | """
1080 | This function retrieves chat history stored in a DynamoDB table, partitioned by UserId and sorted by SessionId
1081 | """
1082 | if DYNAMODB_TABLE:
1083 | chat_histories = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId":params["session_id"]})
1084 | # st.write(chat_histories)
1085 | if "Item" in chat_histories:
1086 | chat_histories = chat_histories['Item']['messages']
1087 | else:
1088 | chat_histories = []
1089 | else:
1090 | chat_histories = load_chat_local(LOCAL_CHAT_FILE_NAME, params["session_id"])
1091 |
1092 | # Constructing the desired list of dictionaries
1093 | formatted_data = []
1094 | if chat_histories:
1095 | for entry in chat_histories:
1096 | image_files = [os.path.basename(x) for x in entry.get('image', [])]
1097 | doc_files = [os.path.basename(x) for x in entry.get('document', [])]
1098 | code_script = entry.get('code', "")
1099 | assistant_attachment = '\n\n'.join(image_files+doc_files)
1100 | # Get entries but don't show the unnecessary function-calling parts in the chat dialogue on Streamlit
1101 | if "tool_result_id" in entry and not entry["tool_result_id"]:
1102 | formatted_data.append({
1103 | "role": "user",
1104 | "content": entry["user"],
1105 | "thinking": entry.get('thinking', "")
1106 | })
1107 | elif not "tool_result_id" in entry :
1108 | formatted_data.append({
1109 | "role": "user",
1110 | "content": entry["user"],
1111 | "thinking": entry.get('thinking', "")
1112 | })
1113 | if "tool_use_id" in entry and not entry["tool_use_id"]:
1114 | formatted_data.append({
1115 | "role": "assistant",
1116 | "content": entry["assistant"],
1117 | "attachment": assistant_attachment,
1118 | "code": code_script,
1119 | "thinking": entry.get('thinking', "")
1120 | # "image_output": entry.get('image', []) if entry["tool_result_id"] else []
1121 | })
1122 | elif "tool_use_id" not in entry:
1123 | formatted_data.append({
1124 | "role": "assistant",
1125 | "content": entry["assistant"],
1126 | "attachment": assistant_attachment,
1127 | "code": code_script,
1128 | "code-result": entry["user"],
1129 | "image_output": entry.get('image', []) if "tool_result_id" in entry else [],
1130 | "plotly": entry.get('plotly', []) if "tool_result_id" in entry else [],
1131 | "thinking": entry.get('thinking', "")
1132 | })
1133 | else:
1134 | chat_histories=[]
1135 | return formatted_data,chat_histories
1136 |
1137 |
1138 |
1139 | def get_key_from_value(dictionary, value):
1140 | return next((key for key, val in dictionary.items() if val == value), None)
1141 |
1142 | def chat_bedrock_(params):
1143 | st.title('Chatty AI Assistant 🙂')
1144 | params['chat_histories'] = []
1145 | if params["session_id"].strip():
1146 | st.session_state.messages, params['chat_histories'] = get_chat_historie_for_streamlit(params)
1147 | for message in st.session_state.messages:
1148 |
1149 | with st.chat_message(message["role"]):
1150 | if "```" in message["content"]:
1151 | st.markdown(message["content"], unsafe_allow_html=True)
1152 | else:
1153 | st.markdown(message["content"].replace("$", "\\$"), unsafe_allow_html=True)
1154 | if message["role"] == "assistant":
1155 | if message["plotly"]:
1156 | for item in message["plotly"]:
1157 | bucket_name, key = item.replace('s3://', '').split('/', 1)
1158 | image_bytes = get_object_with_retry(bucket_name, key)
1159 | content = image_bytes['Body'].read()
1160 | json_data = json.loads(content.decode('utf-8'))
1161 | try:
1162 | fig = pio.from_json(json.dumps(json_data))
1163 | except Exception:
1164 | decoded_data = decode_json(json_data)
1165 | fig = go.Figure(data=decoded_data['data'], layout=decoded_data['layout'])
1166 | st.plotly_chart(fig)
1167 |
1168 | elif message["image_output"]:
1169 | for item in message["image_output"]:
1170 | bucket_name, key = item.replace('s3://', '').split('/', 1)
1171 | image_bytes = get_object_with_retry(bucket_name, key)
1172 | # image_bytes=base64.b64decode(message["image"][image_idx])
1173 | image = Image.open(io.BytesIO(image_bytes['Body'].read()))
1174 | st.image(image)
1175 | if message["attachment"]:
1176 | with st.expander(label="**attachments**"):
1177 | st.markdown(message["attachment"])
1178 | # st.markdown(message["image_output"])
1179 | if message['code']:
1180 | with st.expander(label="**code snippet**"):
1181 | st.markdown(f'```python\n{message["code"]}', unsafe_allow_html=True)
1182 | with st.expander(label="**code result**"):
1183 | st.markdown(f'```python\n{message["code-result"]}', unsafe_allow_html=True)
1184 | if message['thinking']:
1185 | with st.expander(label="**MODEL REASONING**"):
1186 | st.markdown(message["thinking"].replace("$", "\\$"), unsafe_allow_html=True)
1187 |
1188 | if prompt := st.chat_input("What's up?"):
1189 | st.session_state.messages.append({"role": "user", "content": prompt})
1190 | with st.chat_message("user"):
1191 | st.markdown(prompt.replace("$", "\\$"), unsafe_allow_html=True )
1192 | with st.chat_message("assistant"):
1193 | message_placeholder = st.empty()
1194 | params["question"] = prompt
1195 | answer = query_llm(params, message_placeholder)
1196 | message_placeholder.markdown(answer.replace("$", "\\$"), unsafe_allow_html=True )
1197 | st.session_state.messages.append({"role": "assistant", "content": answer})
1198 | st.rerun()
1199 |
1200 | def app_sidebar():
1201 | with st.sidebar:
1202 | st.metric(label="Bedrock Session Cost", value=f"${round(st.session_state['cost'], 2)}")
1203 | st.write("-----")
1204 | button = st.button("New Chat", type="primary")
1205 | models = MODEL_DISPLAY_NAME
1206 | model = st.selectbox('**Model**', models)
1207 | if any(keyword in [model] for keyword in HYBRID_MODELS):
1208 | st.session_state['reasoning_mode'] = st.toggle("Reasoning Mode", value=False, key="thinking")
1209 | st.write(st.session_state['reasoning_mode'])
1210 | else:
1211 | st.session_state['reasoning_mode'] = False
1212 | runtime = ""
1213 | tools = ""
1214 | user_sess_id = get_session_ids_by_user(DYNAMODB_TABLE, st.session_state['userid'])
1215 | float_keys = {float(key): value for key, value in user_sess_id.items()}
1216 | sorted_messages = sorted(float_keys.items(), reverse=True)
1217 | sorted_messages.insert(0, (float(st.session_state['user_sess']), "New Chat"))
1218 | if button:
1219 | st.session_state['user_sess'] = str(time.time())
1220 | sorted_messages.insert(0, (float(st.session_state['user_sess']), "New Chat"))
1221 | st.session_state['chat_session_list'] = dict(sorted_messages)
1222 | chat_items = st.selectbox("**Chat Sessions**", st.session_state['chat_session_list'].values(), key="chat_sessions")
1223 | session_id = get_key_from_value(st.session_state['chat_session_list'], chat_items)
1224 | if model not in NON_TOOL_SUPPORTING_MODELS:
1225 | tools = st.multiselect("**Tools**", ["Advanced Data Analytics"],
1226 | key="function_collen", default=None)
1227 | if "Advanced Data Analytics" in tools:
1228 | engines = ["pyspark", "python"]
1229 | runtime = st.select_slider(
1230 | "Runtime", engines, key="enginees"
1231 | )
1232 | bucket_items = list_csv_xlsx_in_s3_folder(INPUT_BUCKET, INPUT_S3_PATH)
1233 | bucket_objects = st.multiselect("**Files**", bucket_items, key="objector", default=None)
1234 | file = st.file_uploader('Upload a document', accept_multiple_files=True,
1235 | help="pdf,csv,txt,png,jpg,xlsx,json,py doc format supported")
1236 | if file and LOAD_DOC_IN_ALL_CHAT_CONVO:
1237 | st.warning('You have set **load-doc-in-chat-history** to true. For better performance, remove uploaded file(s) (by clicking **X**) **AFTER** first query on uploaded files. See the README for more info', icon="⚠️")
1238 | params = {"model": model, "session_id": str(session_id),
1239 | "chat_item": chat_items,
1240 | "upload_doc": file,
1241 | "tools": tools,
1242 | 's3_objects': bucket_objects,
1243 | "engine": runtime
1244 | }
1245 | st.session_state['count'] = 1
1246 | return params
1247 |
1248 |
1249 | def main():
1250 | params = app_sidebar()
1251 | chat_bedrock_(params)
1252 | if __name__ == '__main__':
1253 | main()
1254 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {"DynamodbTable": "", "UserId": "user id", "Bucket_Name": "S3 bucket name (without s3:// prefix) to save uploaded files", "max-output-token":4000, "chat-history-loaded-length":10, "region":"aws region", "load-doc-in-chat-history":true, "AmazonTextract":false, "csv-delimiter":"|", "document-upload-cache-s3-path":"S3 prefix without trailing or foward slash at either ends", "AmazonTextract-result-cache":"S3 prefix without trailing or foward slash at either ends", "lambda-function":"lambda function name", "input_s3_path":"S3 prefix without trailing or foward slash at either ends","input_bucket":"S3 Bucket name (without s3:// prefix) to render files on app","input_file_ext":"csv,xlsx,parquet", "athena-work-group-name":"name of athena spark workgroup"}
2 |
--------------------------------------------------------------------------------
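The config template above only describes what each key expects. As a rough, hedged illustration, the sketch below shows how these keys might be read at start-up and mapped onto the constants that `bedrock-chat.py` uses (`DYNAMODB_TABLE`, `BUCKET`, `CSV_SEPERATOR`, `USE_TEXTRACT`, `LOAD_DOC_IN_ALL_CHAT_CONVO`, etc.); the exact key-to-constant mapping in the app is an assumption here, and the example values are hypothetical.

```python
# Minimal sketch (not part of the repo) of consuming config.json at start-up.
# The key-to-constant mapping is assumed; example values are hypothetical.
import json

with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

DYNAMODB_TABLE = config["DynamodbTable"]        # "" -> chat history is saved to local disk instead
BUCKET = config["Bucket_Name"]                  # e.g. "my-chatbot-bucket" (no "s3://" prefix)
REGION = config["region"]                       # e.g. "us-east-1"
CSV_SEPERATOR = config["csv-delimiter"]         # delimiter used when serializing Excel sheets to CSV strings
USE_TEXTRACT = config["AmazonTextract"]         # True -> use Amazon Textract for PDF/image extraction
LOAD_DOC_IN_ALL_CHAT_CONVO = config["load-doc-in-chat-history"]
CHAT_HISTORY_LENGTH = config["chat-history-loaded-length"]
```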
/images/JP-lab.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/JP-lab.PNG
--------------------------------------------------------------------------------
/images/chat-flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chat-flow.png
--------------------------------------------------------------------------------
/images/chat-preview.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chat-preview.JPG
--------------------------------------------------------------------------------
/images/chatbot-snip.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chatbot-snip.PNG
--------------------------------------------------------------------------------
/images/chatbot4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chatbot4.png
--------------------------------------------------------------------------------
/images/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/demo.mp4
--------------------------------------------------------------------------------
/images/sg-rules.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/sg-rules.PNG
--------------------------------------------------------------------------------
/images/studio-new-launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/studio-new-launcher.png
--------------------------------------------------------------------------------
/install_package.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Path to the requirements file
4 | REQUIREMENTS_FILE="req.txt"
5 |
6 | # Check if the requirements file exists
7 | if [ ! -f "$REQUIREMENTS_FILE" ]; then
8 | echo "requirements.txt file not found!"
9 | exit 1
10 | fi
11 |
12 | # Loop through each line in the requirements.txt file
13 | while IFS= read -r package || [ -n "$package" ]; do
14 | # Install the package using pip
15 | echo "Installing $package"
16 | pip install "$package"
17 |
18 | # Check if the installation was successful
19 | if [ $? -eq 0 ]; then
20 | echo "$package installed successfully"
21 | else
22 | echo "Failed to install $package"
23 | exit 1
24 | fi
25 | done < "$REQUIREMENTS_FILE"
26 |
27 | echo "All packages installed successfully."
28 |
--------------------------------------------------------------------------------
/model_id.json:
--------------------------------------------------------------------------------
1 | {
2 | "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
3 | "sonnet-3.7": "anthropic.claude-3-7-sonnet-20250219-v1:0",
4 | "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
5 | "sonnet-3.5-v2": "anthropic.claude-3-5-sonnet-20241022-v2:0",
6 | "haiku-3.5": "anthropic.claude-3-5-haiku-20241022-v1:0",
7 | "deepseek": "deepseek.r1-v1:0",
8 | "nova-pro": "amazon.nova-pro-v1:0",
9 | "nova-lite": "amazon.nova-lite-v1:0",
10 | "nova-micro": "amazon.nova-micro-v1:0",
11 | "nova-premier": "amazon.nova-premier-v1:0",
12 | "meta-maverick": "meta.llama4-maverick-17b-instruct-v1:0",
13 | "meta-scout": "meta.llama4-scout-17b-instruct-v1:0"
14 | }
--------------------------------------------------------------------------------
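This file maps the display names shown in the app's model selector to Bedrock model IDs. In `query_llm` the selected ID is prefixed with `'us.'` (see `model = 'us.' + model_info[params['model']]` in `bedrock-chat.py`) to target a US cross-region inference profile. A minimal sketch of that lookup, assuming the file is loaded with `json.load`:

```python
# Minimal sketch: resolve a display name from model_id.json to the inference
# profile ID that the app passes to the Bedrock Converse API.
import json

with open("model_id.json", "r", encoding="utf-8") as f:
    model_info = json.load(f)

display_name = "sonnet-3.7"                 # one of the keys above
model_id = "us." + model_info[display_name]
print(model_id)                             # us.anthropic.claude-3-7-sonnet-20250219-v1:0
```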
/pricing.json:
--------------------------------------------------------------------------------
1 | {"sonnet-4": {"input": 0.000003, "output": 0.000015},"opus-4": {"input": 0.000015, "output": 0.000075},"sonnet-3.5-v1": {"input": 0.000003, "output": 0.000015},"sonnet-3.5-v2": {"input": 0.000003, "output": 0.000015},"haiku-3.5": {"input": 0.0000008, "output": 0.000004},"haiku": {"input": 0.00000025, "output": 0.00000125},"sonnet-3.7": {"input": 0.000003, "output": 0.000015},"meta-maverick": {"input": 0.00000024, "output": 0.00000097}, "meta-scout": {"input": 0.00000017, "output": 0.00000066},"nova-lite": {"input": 0.00000006, "output": 0.000000024},"nova-premier": {"input": 0.00000025, "output": 0.00000125},"nova-pro": {"input": 0.00000032, "output": 0.00000006},"deepseek": {"input": 0.000000135, "output": 0.00000054}, "nova-micro": {"input": 0.0000000035, "output": 0.000000014}}
--------------------------------------------------------------------------------
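Prices are expressed in USD per token. `bedrock_streemer` multiplies the token counts reported in the Converse stream metadata by these rates to keep the running "Bedrock Session Cost" metric shown in the sidebar. A small sketch of that arithmetic, using made-up token counts:

```python
# Minimal sketch of the cost calculation performed in bedrock_streemer():
# per-token prices from pricing.json times the input/output token counts.
import json

with open("pricing.json", "r", encoding="utf-8") as f:
    pricing_file = json.load(f)

model = "sonnet-3.7"
input_tokens, output_tokens = 1200, 450     # hypothetical usage for a single request
cost = (input_tokens * pricing_file[model]["input"]
        + output_tokens * pricing_file[model]["output"])
print(f"${cost:.6f}")                       # this amount is added to st.session_state['cost']
```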
/prompt/chat.txt:
--------------------------------------------------------------------------------
1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively. You are very attentive and respond in markdown format.
--------------------------------------------------------------------------------
/prompt/doc_chat.txt:
--------------------------------------------------------------------------------
1 | You are a conversational assistant, expert at providing quality and accurate answers based on a document(s) and/or image(s) provided. You are very attentive and respond in markdown format. Take your time to read through the document(s) and/or image(s) carefully and pay attention to relevant areas pertaining to the question(s). Once done reading, provide an answer to the user question(s).
--------------------------------------------------------------------------------
/prompt/pyspark_debug_prompt.txt:
--------------------------------------------------------------------------------
1 | I will provide you with PySpark code that analyzes tabular data and an error relating to the code.
2 | Here is a subset (first 3 rows) of each dataset:
3 |
4 | {dataset}
5 |
6 |
7 | Here is the pyspark code to analyze the data:
8 |
9 | {code}
10 |
11 |
12 | Here is the thrown error:
13 |
14 | {error}
15 |
16 |
17 | Debug and fix the code. Think through where the potential bug is and what solution is needed, put all this thinking process in XML tags.
18 | The data files are stored in Amazon S3 (the XML tags point to the S3 URI) and must be read from S3.
19 |
20 | Important considerations:
21 | - Always use S3A file system when reading files from S3.
22 | - Always generate the full code.
23 | - Each plot must be saved as a Plotly (.plotly) file in the '/tmp' directory.
24 | - Use proper namespace management for python and pyspark libraries.
25 | - DO NOT use a `try-except` block to handle exceptions when writing the corrected code; this prevents my front-end code from handling exceptions properly.
26 | - Use pio.write_json() to save images as ".plotly" files
27 |
28 | Additional info: The code must output a JSON object variable name "output" with following keys:
29 | - 'text': Any text output generated by the Python code
30 | - 'plotly-files': Plotly objects saved as ".plotly"
31 |
32 | Reflect on your approach and considerations provided in reflection XML tags.
33 |
34 | Provide the fixed code within XML tags and all Python top-level package names (separated by commas, no extra formatting) needed within XML tags
--------------------------------------------------------------------------------
/prompt/pyspark_tool_prompt.txt:
--------------------------------------------------------------------------------
1 | Purpose: Analyze structured data and generate text and graphical outputs using pyspark code without interpreting results.
2 | Input: Structured data file(s) (CSV, PARQUET)
3 | Processing:
4 | - Read input files from Amazon S3:
5 | CSV files: load using CSV-specific methods, e.g., spark.read.csv("s3a://path/to/your/file.csv", header=True, inferSchema=True)
6 | PARQUET files: Load using Parquet-specific methods, e.g., spark.read.parquet("s3a://path/to/your/file.parquet")
7 | Process files according to their true type, not their sample representation. Each file must be read from Amazon S3.
8 | - Perform statistical analysis
9 | - When working with columns that have special characters (".", " ", etc.), wrap column names in backticks "`".
10 | - Generate plots when possible.
11 | - If multiple data files are provided, always load all datasets for analysis.
12 |
13 | Visualization:
14 | - Ensure plots are clear and legible with a figure size of 10 by 12 inches.
15 | - Use Plotly for plots when possible and save plotly objects as ".plotly" also in /tmp directory
16 | - When generating plots, use contrasting colours for legibility.
17 | - Remember, you should save .plotly version for each generated plot.
18 |
19 | Output: JSON object named "output" with:
20 | - 'text': All text-based results and printed information
21 | - 'plotly-files': Plotly objects saved as ".plotly"
22 |
23 | Important:
24 | - Generate code for analysis only, without interpreting results
25 | - Avoid making conclusive statements or recommendations
26 | - Present calculated statistics and generated visualizations without drawing conclusions
27 | - Save plots as .plotly files accordingly using pio.write_json() to the '/tmp' directory
28 | - Use proper namespace management for Python and PySpark libraries:
29 | - Import only the necessary modules or functions to avoid cluttering the namespace.
30 | - Use aliases for long module names (e.g., 'import pandas as pd').
31 | - Avoid using 'from module import *' as it can lead to naming conflicts. Instead of 'from pyspark.sql import *' do 'from pyspark.sql import functions as F'
32 | - Group imports logically: standard library imports first, then third-party libraries, followed by local application imports.
33 | - Use efficient, well-documented, PEP 8 compliant code
34 | - Follow data analysis and visualization best practices
35 | - Include plots whenever possible
36 | - Store all results in the 'output' JSON object
37 | - Ensure 'output' is the final variable assigned in the code, whether inside or outside a function
38 |
39 | Example:
40 | import plotly.io as pio
41 | from pyspark.sql import functions as F
42 | from pyspark.sql.window import Window
43 | ... REST of IMPORT
44 |
45 | # In Amazon Athena, Spark context is already initialized as "spark" variable, no need to initialize
46 |
47 | # Read data
48 | df = spark.read.csv("s3a://path/to/file/file.csv") # for parquet use spark.read.parquet(..)
49 |
50 | ...REST OF CODE
51 |
52 | #Save plotly figures
53 | pio.write_json(fig, '/tmp/plot.plotly')
54 |
55 | # Prepare output
56 | output = {
57 | 'text': 'Statistical analysis results...\nOther printed output...',
58 | 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures
59 | }
60 |
61 | # No Need to stop Spark context
--------------------------------------------------------------------------------
/prompt/pyspark_tool_system.txt:
--------------------------------------------------------------------------------
1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively.
2 | You will have access to a set of "tools" for handling specific requests; use your judgement to figure out if you need to use a tool and which tool to use. I will provide the tool description below that guides you on if and when to use a tool:
3 | 1. pyspark_function_tool: This tool is used to handle structured data files to perform any data analysis query and task on such files (CSV, PQ). Structured data will usually be tagged by the file name and will be in a CSV string.
4 | If a user query does not need a tool, go ahead and answer the question directly without using any tool. Do not include any preamble
--------------------------------------------------------------------------------
/prompt/pyspark_tool_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "tools": [
3 | {
4 | "toolSpec": {
5 | "name": "pyspark_function_tool",
6 | "description": "This tool allows you to analyze structured data files (CSV, PARQUET, XLSX, etc.) using PySpark programming language. It can be used to answer questions or perform analyses on the data contained in these files. Use this tool when the user asks questions or requests analysis related to structured data files. Do not use this tool for any other query not related to analyzing tabular data files. When using this tool, first think through the user ask and understand the data types within the dataset, put your thoughts in XML tags as your scratch pad.",
7 | "inputSchema": {
8 | "json": {
9 | "type": "object",
10 | "properties": {
11 | "code": {
12 | "type": "string",
13 | "description": ""
14 |
15 | },
16 | "dataset_name": {
17 | "type": "string",
18 | "description": "The file name of the structured dataset including its extension (CSV, PQ ..etc)"
19 | },
20 | "python_packages": {
21 | "type": "string",
22 | "description": "Comma-separated list of Python libraries required to run the function"
23 | }
24 | },
25 | "required": ["code","dataset_name","python_packages"]
26 | }
27 | }
28 | }
29 | }
30 | ]
31 | }
--------------------------------------------------------------------------------
/prompt/python_debug_prompt.txt:
--------------------------------------------------------------------------------
1 | I will provide you with Python code that analyzes tabular data and an error relating to the code.
2 | Here is a subset (first few rows) of each dataset:
3 |
4 | {dataset}
5 |
6 |
7 | Here is the python code to analyze the data:
8 |
9 | {code}
10 |
11 |
12 | Here is the thrown error:
13 |
14 | {error}
15 |
16 |
17 | Debug and fix the code. Think through where the potential bug is and what solution is needed, put all this thinking process in XML tags.
18 | The data files are stored in Amazon S3 (the XML tags point to the S3 URI) and must be read from S3.
19 | Images, if available in the code, must be saved in the '/tmp' directory.
20 | DO NOT use a `try-except` block to handle exceptions when writing the corrected code; this prevents my front-end code from handling exceptions properly.
21 | Additional info: The code must output a JSON object variable name "output" with following keys:
22 | - 'text': Any text output generated by the Python code.
23 | - 'image': If the Python code generates any image outputs, image filenames (without the '/tmp' parent directory) will be mapped to this key. If no image is generated, no need for this key. (Must be in list format)
24 | - 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures
25 | Finally "output" variable must be saved as "output.json" to "/tmp" dir.
26 | Provide the fixed code within XML tags and all Python top-level package names (separated by commas, no extra formatting) needed within XML tags
--------------------------------------------------------------------------------
/prompt/python_tool_prompt.txt:
--------------------------------------------------------------------------------
1 | Purpose: Analyze structured data and generate text and graphical outputs using python code without interpreting results.
2 | Input: Structured data file(s) (CSV, XLS, XLSX)
3 | Processing:
4 | - Read input files from Amazon S3:
5 | CSV files: pd.read_csv("s3://path/to/your/file.csv")
6 | XLS/XLSX files: Load using Excel-specific methods (e.g., pd.read_excel("s3://path/to/your/file.xlsx"))
7 | - Perform statistical analysis
8 | - Generate plots when possible
9 |
10 | Visualization:
11 | - Ensure plots are clear and legible with a figure size of 10 by 12 inches.
12 | - Save generated plots as PNG files in /tmp directory. Use appropriate filenames based on plot titles.
13 | - Always use fig.write_image() to save plotly figures as PNG.
14 | - Use Plotly for plots when possible and save plotly objects as ".plotly" also in /tmp directory
15 | - When generating plots, use contrasting colours for legibility.
16 | - Remember, you should save a PNG and .plotly version for each generated plot.
17 | - When using Matplotlib, create a temporary directory and set it as MPLCONFIGDIR before importing any libraries to avoid permission issues in restricted environments.
18 |
19 | Output: JSON object named "output" with:
20 | - 'text': All text-based results and printed information
21 | - 'image': Filename(s) of PNG plot(s)
22 | - 'plotly-files': Plotly objects saved as ".plotly"
23 | - save 'output.json' in '/tmp' directory
24 |
25 | Important:
26 | - Generate code for analysis only, without interpreting results
27 | - Avoid making conclusive statements or recommendations
28 | - Present calculated statistics and generated visualizations without drawing conclusions
29 | - Save plots as PNG and .plotly files accordingly
30 |
31 | Notes:
32 | - Take time to think about the code to be generated for the user query
33 | - Save plots as PNG and .plotly files in '/tmp' directory
34 | - Use efficient, well-documented, PEP 8 compliant code
35 | - Follow data analysis and visualization best practices
36 | - Include plots whenever possible
37 | - Store all results in the 'output' JSON object
38 | - Ensure 'output' is the final variable assigned in the code, whether inside or outside a function
39 |
40 | Example:
41 | import plotly.io as pio
42 | import pandas as pd
43 | ... REST of IMPORT
44 |
45 | # Read the data
46 | df = pd.read_csv("s3://path/to/file/file.csv")
47 |
48 | ...REST OF CODE
49 |
50 | #Save plot as PNG
51 | fig.write_image("/tmp/plot.png")
52 |
53 | #Save plots as PLOTLY files
54 | pio.write_json(fig, '/tmp/plot.plotly')
55 |
56 | # Prepare output
57 | output = {
58 | 'text': '''Statistical analysis results...\nOther printed output...''',
59 | 'image': 'plot.png',  # or ['plot1.png', 'plot2.png'] for multiple images
60 | 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures
61 | }
62 |
63 | # Save output as JSON file in '/tmp' dir
64 | with open('/tmp/output.json', 'w') as f:
65 | json.dump(output, f)
--------------------------------------------------------------------------------
/prompt/python_tool_system.txt:
--------------------------------------------------------------------------------
1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively.
2 | You will have access to a set of "tools" for handling specific requests; use your judgement to figure out if you need to use a tool and which tool to use. I will provide the tool description below that guides you on if and when to use a tool:
3 | 1. python_function_tool: This tool is used to handle structured data files to perform any data analysis query and task on such files (CSV, XLSX, etc.). Structured data will usually be tagged by the file name and will be in a CSV string.
4 |
5 | If a user query does not need a tool, go ahead and answer the question directly without using any tool. Do not include any preamble
--------------------------------------------------------------------------------
/prompt/python_tool_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "tools": [
3 | {
4 | "toolSpec": {
5 | "name": "python_function_tool",
6 | "description": "This tool allows you to analyze structured data files (CSV, XLSX, etc.) using Python programming language. It can be used to answer questions or perform analyses on the data contained in these files. Use this tool when the user asks questions or requests analyses related to structured data files. Do not use this tool for any other query not related to analyzing tabular data files.",
7 | "inputSchema": {
8 | "json": {
9 | "type": "object",
10 | "properties": {
11 | "code": {
12 | "type": "string",
13 | "description": ""
14 |
15 | },
16 | "dataset_name": {
17 | "type": "string",
18 | "description": "The file name of the structured dataset including its extension (CSV, XLSX ..etc)"
19 | },
20 | "python_packages": {
21 | "type": "string",
22 | "description": "Comma-separated list of Python libraries required to run the function"
23 | }
24 | },
25 | "required": ["code","dataset_name","python_packages"]
26 | }
27 | }
28 | }
29 | }
30 | ]
31 | }
--------------------------------------------------------------------------------
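Both tool templates above follow the `toolConfig` shape expected by the Bedrock Converse API. The actual wiring lives in `utils/function_calling_utils.py` (only partially shown here), so the following is a hedged sketch of how such a template could be passed to `converse`; the region and user prompt are hypothetical:

```python
# Minimal sketch (assumption: the real call is made inside
# utils/function_calling_utils.py) of passing python_tool_template.json to the
# Bedrock Converse API. When the model decides to use the tool, stopReason is
# "tool_use" and the response contains a toolUse block with generated code.
import json
import boto3

bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")  # hypothetical region

with open("prompt/python_tool_template.json", "r", encoding="utf-8") as f:
    tool_config = json.load(f)

with open("prompt/python_tool_system.txt", "r", encoding="utf-8") as f:
    system_text = f.read()

response = bedrock_runtime.converse(
    modelId="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[{"role": "user", "content": [{"text": "Plot monthly sales from sales.csv"}]}],
    system=[{"text": system_text}],
    toolConfig=tool_config,
)
print(response["stopReason"])               # "tool_use" when the tool was selected
```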
/req.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | setuptools
3 | streamlit
4 | boto3
5 | pymupdf
6 | pandas
7 | numpy
8 | amazon-textract-textractor
9 | openpyxl
10 | inflect
11 | pypdf2
12 | pytesseract
13 | python-pptx
14 | python-docx
15 | pillow
16 | openpyxl
17 | pydantic
18 | python-calamine
19 | s3fs
20 | plotly>=5.0.0,<6.0.0
21 | kaleido
--------------------------------------------------------------------------------
/utils/athena_handler_.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | from botocore.exceptions import ClientError
3 | import time
4 | import streamlit as st
5 | import json
6 |
7 | def start_athena_session_(
8 | workgroup_name,
9 | description="Starting Athena session",
10 | coordinator_dpu_size=1,
11 | max_concurrent_dpus=60,
12 | default_executor_dpu_size=1,
13 | additional_configs=None,
14 | spark_properties=None,
15 | # notebook_version="Athena notebook version 1",
16 | session_idle_timeout_in_minutes=None,
17 | client_request_token=None
18 | ):
19 | """
20 | Start an Athena session using boto3.
21 |
22 | Args:
23 | workgroup_name (str): The name of the workgroup.
24 | description (str): A description of the session. Default is "Starting Athena session".
25 | coordinator_dpu_size (int): The size of the coordinator DPU. Default is 1.
26 | max_concurrent_dpus (int): The maximum number of concurrent DPUs. Default is 60.
27 | default_executor_dpu_size (int): The default size of executor DPUs. Default is 1.
28 | additional_configs (dict): Additional configurations. Default is None.
29 | spark_properties (dict): Spark properties. Default is None.
30 | notebook_version (str): The version of the Athena notebook. Default is "Athena notebook version 1".
31 | session_idle_timeout_in_minutes (int): The idle timeout for the session in minutes. Default is None.
32 | client_request_token (str): A unique, case-sensitive identifier that you provide to ensure the idempotency of the request. Default is None.
33 |
34 | Returns:
35 | dict: A dictionary containing the SessionId and State of the started session.
36 | """
37 |
38 | # Create an Athena client
39 | athena_client = boto3.client('athena')
40 |
41 | # Define the engine configuration
42 | engine_configuration = {
43 | 'CoordinatorDpuSize': coordinator_dpu_size,
44 | 'MaxConcurrentDpus': max_concurrent_dpus,
45 | 'DefaultExecutorDpuSize': default_executor_dpu_size,
46 | }
47 |
48 | if additional_configs:
49 | engine_configuration['AdditionalConfigs'] = additional_configs
50 |
51 | if spark_properties:
52 | engine_configuration['SparkProperties'] = spark_properties
53 |
54 | # Prepare the request parameters
55 | request_params = {
56 | 'Description': description,
57 | 'WorkGroup': workgroup_name,
58 | 'EngineConfiguration': engine_configuration,
59 | # 'NotebookVersion': notebook_version
60 | }
61 |
62 | if session_idle_timeout_in_minutes is not None:
63 | request_params['SessionIdleTimeoutInMinutes'] = session_idle_timeout_in_minutes
64 |
65 | if client_request_token:
66 | request_params['ClientRequestToken'] = client_request_token
67 |
68 | try:
69 | # Start the Athena session
70 | response = athena_client.start_session(**request_params)
71 |
72 | # Extract relevant information
73 | session_info = {
74 | 'SessionId': response['SessionId'],
75 | 'State': response['State']
76 | }
77 |
78 | print(f"Athena session started successfully.")
79 | print(f"Session ID: {session_info['SessionId']}")
80 | print(f"State: {session_info['State']}")
81 |
82 | return session_info
83 |
84 | except ClientError as e:
85 | print(f"An error occurred while starting the Athena session: {e.response['Error']['Message']}")
86 | return None
87 |
88 |
89 | class SessionFailedException(Exception):
90 | """Custom exception for when the session is in a FAILED state."""
91 | pass
92 |
93 | class SessionTimeoutException(Exception):
94 | """Custom exception for when the session check times out."""
95 | pass
96 |
97 | def wait_for_session_status(session_id, max_wait_seconds=300, check_interval_seconds=3):
98 | """
99 | Wait for an Athena session to reach either IDLE or FAILED state.
100 |
101 | Args:
102 | session_id (str): The ID of the session to check.
103 | max_wait_seconds (int): Maximum time to wait in seconds. Default is 300 seconds (5 minutes).
104 | check_interval_seconds (int): Time to wait between status checks in seconds. Default is 3 seconds.
105 |
106 | Returns:
107 | bool: True if the session state is IDLE.
108 |
109 | Raises:
110 | SessionFailedException: If the session state is FAILED.
111 | SessionTimeoutException: If the maximum wait time is exceeded.
112 | ClientError: If there's an error in the AWS API call.
113 | """
114 |
115 | athena_client = boto3.client('athena')
116 | start_time = time.time()
117 |
118 | while True:
119 | try:
120 | response = athena_client.get_session_status(SessionId=session_id)
121 | state = response['Status']['State']
122 |
123 | print(f"Session {session_id} is in state: {state}")
124 |
125 | if state == 'IDLE':
126 | return True
127 | elif state == 'FAILED':
128 | reason = response['Status'].get('StateChangeReason', 'No reason provided')
129 | raise SessionFailedException(f"Session {session_id} has FAILED. Reason: {reason}")
130 | elif state == 'TERMINATED':
131 | # return f"Session {session_id} is in state: {state}"
132 | return False
133 |
134 | # Check if we've exceeded the maximum wait time
135 | if time.time() - start_time > max_wait_seconds:
136 | raise SessionTimeoutException(f"Timeout waiting for session {session_id} to become IDLE or FAILED")
137 |
138 | # Wait for the specified interval before checking again
139 | time.sleep(check_interval_seconds)
140 |
141 | except ClientError as e:
142 | print(f"An error occurred while checking the session status: {e.response['Error']['Message']}")
143 | raise
144 |
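# Illustrative sketch (not part of the original module): how start_athena_session_ and
# wait_for_session_status are typically combined. The workgroup name is a placeholder.
def _example_start_session_and_wait(workgroup_name="example-spark-workgroup"):
    session = start_athena_session_(workgroup_name=workgroup_name, session_idle_timeout_in_minutes=60)
    if session and wait_for_session_status(session['SessionId']):
        return session['SessionId']  # session is IDLE and ready to accept calculations
    return None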
145 | def execute_athena_calculation(session_id, code_block, workgroup, max_wait_seconds=600, check_interval_seconds=5):
146 | """
147 | Execute a calculation in Athena, wait for completion, and retrieve results.
148 |
149 | Args:
150 | session_id (str): The Athena session ID.
151 | code_block (str): The code to execute.
152 | max_wait_seconds (int): Maximum time to wait for execution in seconds. Default is 600 seconds (10 minutes).
153 | check_interval_seconds (int): Time to wait between status checks in seconds. Default is 5 seconds.
154 |
155 | Returns:
156 | dict: A dictionary containing execution results or error information.
157 | """
158 | athena_client = boto3.client('athena')
159 | s3_client = boto3.client('s3')
160 |
161 | def start_calculation():
162 | try:
163 | response = athena_client.start_calculation_execution(
164 | SessionId=session_id,
165 | CodeBlock=code_block,
166 | # ClientRequestToken=f"token-{time.time()}" # Unique token for idempotency
167 | )
168 | return response['CalculationExecutionId']
169 | except ClientError as e:
170 | print(f"Failed to start calculation: {e}")
171 | raise
172 |
173 | def check_calculation_status(calculation_id):
174 | try:
175 | response = athena_client.get_calculation_execution(CalculationExecutionId=calculation_id)
176 | return response['Status']['State']
177 | except ClientError as e:
178 | print(f"Failed to get calculation status: {e}")
179 | raise
180 |
181 | def get_calculation_result(calculation_id):
182 | try:
183 | response = athena_client.get_calculation_execution(CalculationExecutionId=calculation_id)
184 | return response['Result']
185 | except ClientError as e:
186 | print(f"Failed to get calculation result: {e}")
187 | raise
188 |
189 | def download_s3_file(s3_uri):
190 | try:
191 | bucket, key = s3_uri.replace("s3://", "").split("/", 1)
192 | response = s3_client.get_object(Bucket=bucket, Key=key)
193 | return response['Body'].read().decode('utf-8')
194 | except ClientError as e:
195 | print(f"Failed to download S3 file: {e}")
196 | return None
197 |
198 | if session_id:
199 | if not wait_for_session_status(session_id):
200 | # Start Session
201 | session = start_athena_session_(
202 | workgroup_name=workgroup,
203 | session_idle_timeout_in_minutes=60
204 | )
205 | if wait_for_session_status(session['SessionId']):
206 | session_id = session['SessionId']
207 | st.session_state['athena-session'] = session['SessionId']
208 |
209 | else:
210 | session = start_athena_session_(
211 | workgroup_name=workgroup,
212 | session_idle_timeout_in_minutes=60
213 | )
214 | if wait_for_session_status(session['SessionId']):
215 | session_id = session['SessionId']
216 | st.session_state['athena-session'] = session['SessionId']
217 |
218 | # Start the calculation
219 | calculation_id = start_calculation()
220 | print(f"Started calculation with ID: {calculation_id}")
221 |
222 | # Wait for the calculation to complete
223 | start_time = time.time()
224 | while True:
225 | status = check_calculation_status(calculation_id)
226 | print(f"Calculation status: {status}")
227 |
228 | if status in ['COMPLETED', 'FAILED', 'CANCELED']:
229 | break
230 |
231 | if time.time() - start_time > max_wait_seconds:
232 | print("Calculation timed out")
233 | return {"error": "Calculation timed out"}
234 |
235 | time.sleep(check_interval_seconds)
236 |
237 | # Get the calculation result
238 | result = get_calculation_result(calculation_id)
239 |
240 | if status == 'COMPLETED':
241 | # Download and return the result
242 | result_content = download_s3_file(result['ResultS3Uri'])
243 | print(result)
244 | return {
245 | "status": "COMPLETED",
246 | "result": result_content,
247 | "stdout": download_s3_file(result['StdOutS3Uri']),
248 | "stderr": download_s3_file(result['StdErrorS3Uri'])
249 | }
250 | elif status == 'FAILED':
251 | # Get the error file
252 | error_content = download_s3_file(result['StdErrorS3Uri'])
253 | return {
254 | "status": "FAILED",
255 | "error": error_content,
256 | "stdout": download_s3_file(result['StdOutS3Uri'])
257 | }
258 | else:
259 | return {"status": status}
260 |
261 | def send_athena_job(payload, workgroup):
262 | """
263 | Execute a Spark job payload in an Amazon Athena Spark session.
264 |
265 | Args:
266 | payload (dict): The payload containing the Spark job details (code, bucket, file_path).
267 | workgroup (str): The Athena Spark workgroup to run the job in.
268 | Returns:
269 | dict: The execution result returned by execute_athena_calculation.
270 | """
271 |
272 | code_block = """
273 | import boto3
274 | import os
275 | def put_obj_in_s3_bucket_(docs, bucket, key_prefix):
276 | S3 = boto3.client('s3')
277 | if isinstance(docs,str):
278 | file_name=os.path.basename(docs)
279 | file_path=f"{key_prefix}/{docs}"
280 | S3.upload_file(f"/tmp/{docs}", bucket, file_path)
281 | else:
282 | file_name=os.path.basename(docs.name)
283 | file_path=f"{key_prefix}/{file_name}"
284 | S3.put_object(Body=docs.read(), Bucket=bucket, Key=file_path)
285 | return f"s3://{bucket}/{file_path}"
286 |
287 | def handle_results_(result, bucket, s3_file_path):
288 | image_holder = []
289 | if isinstance(result, dict):
290 | for item, value in result.items():
291 | if "plotly-files" in item and value is not None:
292 | if isinstance(value, list):
293 | for img in value:
294 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path)
295 | image_holder.append(image_path_s3)
296 | else:
297 | image_path_s3 = put_obj_in_s3_bucket_(value,bucket,s3_file_path)
298 | image_holder.append(image_path_s3)
299 |
300 | tool_result = {
301 | "result": result,
302 | "plotly": image_holder
303 | }
304 | return tool_result
305 |
306 | # iterate = input_data.get('iterate', 0)
307 | bucket="BUCKET-NAME"
308 | s3_file_path="BUCKET-PATH"
309 | result_final = handle_results_(output, bucket, s3_file_path)
310 | print(f"<output>{result_final}</output>")  # wrapped in a marker tag so the caller can extract the result (tag name assumed)
311 | """.replace("BUCKET-NAME",payload["bucket"]).replace("BUCKET-PATH", payload["file_path"])
312 |
313 | try:
314 | result = execute_athena_calculation(st.session_state['athena-session'], payload["code"]+ "\n" +code_block, workgroup)
315 | print(json.dumps(result, indent=2))
316 | return result
317 | except Exception as e:
318 | print(f"An error occurred: {str(e)}")
319 | return e
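# Illustrative sketch (not part of the original module): the payload shape send_athena_job
# expects. Bucket, prefix and code are placeholders, and an Athena session id must already
# be stored in st.session_state['athena-session'].
def _example_send_athena_job():
    payload = {
        "code": "output = {'result': 'hello from Athena Spark', 'plotly-files': None}",
        "bucket": "example-results-bucket",  # placeholder bucket
        "file_path": "athena/results",       # placeholder key prefix
    }
    return send_athena_job(payload, workgroup="example-spark-workgroup")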
--------------------------------------------------------------------------------
/utils/function_calling_utils.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import boto3
3 | from botocore.config import Config
4 | import os
5 | import pandas as pd
6 | import time
7 | import json
8 | from botocore.exceptions import ClientError
9 | import io
10 | import re
11 | from pptx import Presentation
12 | import random
13 | from python_calamine import CalamineWorkbook
14 | import chardet
15 | from docx.table import _Cell
16 | import concurrent.futures
17 | from functools import partial
18 | from textractor import Textractor
19 | from textractor.data.constants import TextractFeatures
20 | from textractor.data.text_linearization_config import TextLinearizationConfig
21 | import pytesseract
22 | from PIL import Image
23 | import PyPDF2
24 | from docx import Document as DocxDocument
25 | from docx.oxml.text.paragraph import CT_P
26 | from docx.oxml.table import CT_Tbl
27 | from docx.document import Document
28 | from docx.text.paragraph import Paragraph
29 | from docx.table import Table as DocxTable
30 | from utils.athena_handler_ import send_athena_job
31 | import ast
32 | from urllib.parse import urlparse
33 | import plotly.graph_objects as go
34 |
35 | config = Config(
36 | read_timeout=600, # Read timeout parameter
37 | connect_timeout=600, # Connection timeout parameter in seconds
38 | retries=dict(
39 | max_attempts=10 # Handle retries
40 | )
41 | )
42 |
43 | with open('config.json', 'r', encoding='utf-8') as f:
44 | config_file = json.load(f)
45 |
46 | # Bedrock Model info
47 | with open('model_id.json', 'r', encoding='utf-8') as f:
48 | model_info = json.load(f)
49 |
50 | S3 = boto3.client('s3')
51 | DYNAMODB = boto3.resource('dynamodb')
52 | LOCAL_CHAT_FILE_NAME = "chat-history.json"
53 | DYNAMODB_TABLE = config_file["DynamodbTable"]
54 | BUCKET = config_file["Bucket_Name"]
55 | OUTPUT_TOKEN = config_file["max-output-token"]
56 | S3_DOC_CACHE_PATH = config_file["document-upload-cache-s3-path"]
57 | TEXTRACT_RESULT_CACHE_PATH = config_file["AmazonTextract-result-cache"]
58 | LOAD_DOC_IN_ALL_CHAT_CONVO = config_file["load-doc-in-chat-history"]
59 | CHAT_HISTORY_LENGTH = config_file["chat-history-loaded-length"]
60 | DYNAMODB_USER = config_file["UserId"]
61 | REGION = config_file["region"]
62 | USE_TEXTRACT = config_file["AmazonTextract"]
63 | CSV_SEPERATOR = config_file["csv-delimiter"]
64 | LAMBDA_FUNC = config_file["lambda-function"]
65 | INPUT_S3_PATH = config_file["input_s3_path"]
66 | INPUT_BUCKET = config_file["input_bucket"]
67 | ATHENA_WORKGROUP_NAME = config_file["athena-work-group-name"]
68 | TEXT_ONLY_MODELS = ["deepseek", "haiku-3.5", "micro"]
69 |
70 | with open('pricing.json', 'r', encoding='utf-8') as f:
71 | pricing_file = json.load(f)
72 |
73 |
74 |
75 | def put_db(params, messages):
76 | """Store long term chat history in DynamoDB"""
77 | chat_item = {
78 | "UserId": st.session_state['userid'], # user id
79 | "SessionId": params["session_id"], # User session id
80 | "messages": [messages], # 'messages' is a list of dictionaries
81 | "time": messages['time']
82 | }
83 |
84 | existing_item = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId": params["session_id"]})
85 |
86 | if "Item" in existing_item:
87 | existing_messages = existing_item["Item"]["messages"]
88 | chat_item["messages"] = existing_messages + [messages]
89 |
90 | response = DYNAMODB.Table(DYNAMODB_TABLE).put_item(
91 | Item=chat_item
92 | )
93 | def save_chat_local(file_path, new_data,params):
94 | """Store long term chat history Local Disk"""
95 | try:
96 | # Read the existing JSON data from the file
97 | with open(file_path, "r",encoding='utf-8') as file:
98 | existing_data = json.load(file)
99 | if params["session_id"] not in existing_data:
100 | existing_data[params["session_id"]]=[]
101 | except FileNotFoundError:
102 | # If the file doesn't exist, initialize an empty list
103 | existing_data = {params["session_id"]:[]}
104 | # Append the new data to the existing list
105 | from decimal import Decimal
106 | data = [{k: float(v) if isinstance(v, Decimal) else v for k, v in item.items()} for item in new_data]
107 | existing_data[params["session_id"]].extend(data)
108 | # Write the updated list back to the JSON file
109 | with open(file_path, "w", encoding="utf-8") as file:
110 | json.dump(existing_data, file)
111 |
112 | def load_chat_local(file_path,params):
113 | """Load long term chat history from Local"""
114 | try:
115 | # Read the existing JSON data from the file
116 | with open(file_path, "r",encoding='utf-8') as file:
117 | existing_data = json.load(file)
118 | if params["session_id"] in existing_data:
119 | existing_data=existing_data[params["session_id"]]
120 | else:
121 | existing_data=[]
122 | except FileNotFoundError:
123 | # If the file doesn't exist, initialize an empty list
124 | existing_data = []
125 | return existing_data
126 |
127 |
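# Illustrative sketch (not part of the original module): round-tripping chat history on
# local disk with the two helpers above. Session id and message fields are placeholders.
def _example_local_chat_roundtrip():
    params = {"session_id": "demo-session"}
    message = {"user": "hello", "assistant": "hi there", "time": str(time.time())}
    save_chat_local(LOCAL_CHAT_FILE_NAME, [message], params)
    return load_chat_local(LOCAL_CHAT_FILE_NAME, params)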
128 | def bedrock_streemer(params,response, handler):
129 | text=''
130 | think = ""
131 | signature = ""
132 | for chunk in response['stream']:
133 | if 'contentBlockDelta' in chunk:
134 | delta = chunk['contentBlockDelta']['delta']
135 | # print(chunk)
136 | if 'text' in delta:
137 | text += delta['text']
138 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True)
139 | if 'reasoningContent' in delta:
140 | if "text" in delta['reasoningContent']:
141 | think += delta['reasoningContent']['text']
142 | handler.markdown('**MODEL REASONING**\n\n'+ think.replace("$", "\\$"), unsafe_allow_html=True)
143 | elif "signature" in delta['reasoningContent']:
144 | signature = delta['reasoningContent']['signature']
145 |
146 | elif "metadata" in chunk:
147 |
148 | if 'cacheReadInputTokens' in chunk['metadata']['usage']:
149 | print(f"\nCache Read Tokens: {chunk['metadata']['usage']['cacheReadInputTokens']}")
150 | print(f"Cache Write Tokens: {chunk['metadata']['usage']['cacheWriteInputTokens']}")
151 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"]
152 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"]
153 | latency = chunk['metadata']['metrics']["latencyMs"]
154 | pricing = st.session_state['input_token'] * pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"]
155 | st.session_state['cost']+=pricing
156 | print(f"\nInput Tokens: {st.session_state['input_token']}\nOutput Tokens: {st.session_state['output_token']}\nLatency: {latency}ms")
157 | return text, think
158 |
159 | def bedrock_claude_(params, chat_history, system_message, prompt, model_id, image_path=None, handler=None):
160 | # st.write(chat_history)
161 | chat_history_copy = chat_history[:]
162 | content = []
163 | if image_path:
164 | if not isinstance(image_path, list):
165 | image_path = [image_path]
166 | for img in image_path:
167 | s3 = boto3.client('s3')
168 | match = re.match("s3://(.+?)/(.+)", img)
169 | image_name = os.path.basename(img)
170 | _, ext = os.path.splitext(image_name)
171 | if "jpg" in ext:
172 | ext = ".jpeg"
173 | bucket_name = match.group(1)
174 | key = match.group(2)
175 | if ".plotly" in key:
176 | bytes_image = plotly_to_png_bytes(img)
177 | ext = ".png"
178 | else:
179 | obj = s3.get_object(Bucket=bucket_name, Key=key)
180 | bytes_image = obj['Body'].read()
181 | content.extend([{"text": image_name}, {
182 | "image": {
183 | "format": f"{ext.lower().replace('.', '')}",
184 | "source": {"bytes": bytes_image}
185 | }
186 | }])
187 |
188 | content.append({
189 | "text": prompt
190 | })
191 | chat_history_copy.append({"role": "user",
192 | "content": content})
193 | system_message = [{"text": system_message}]
194 |
195 | config = Config(
196 | read_timeout=600, # Read timeout parameter
197 | retries=dict(
198 | max_attempts=10 ## Handle retries
199 | )
200 | )
201 | bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name=REGION, config=config)
202 |
203 | if st.session_state['reasoning_mode'] and "3-7" in model_id :
204 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id,
205 | inferenceConfig={"maxTokens": 18000, "temperature": 1,},
206 | system=system_message,
207 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}}
208 | )
209 | else:
210 | max_tokens = 8000
211 | if any(keyword in params['model'] for keyword in ["haiku-3.5"]):
212 | max_tokens = 4000
213 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id,
214 | inferenceConfig={"maxTokens": max_tokens, "temperature": 0.5,},
215 | system=system_message,
216 | )
217 |
218 | answer, think = bedrock_streemer(params, response, handler)
219 | return answer, think
220 |
221 | def _invoke_bedrock_with_retries(params, current_chat, chat_template, question, model_id, image_path, handler):
222 | max_retries = 10
223 | backoff_base = 2
224 | max_backoff = 3 # Maximum backoff time in seconds
225 | retries = 0
226 |
227 | while True:
228 | try:
229 | response, think = bedrock_claude_(params, current_chat, chat_template, question, model_id, image_path, handler)
230 | return response, think
231 | except ClientError as e:
232 | if e.response['Error']['Code'] in ('ThrottlingException', 'ModelStreamErrorException', 'EventStreamError'):
233 | if retries < max_retries:
234 | # Transient error: retry with capped exponential backoff plus jitter
235 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
236 | time.sleep(sleep_time)
237 | retries += 1
238 | else:
239 | raise e
256 | else:
257 | # Some other API error, rethrow
258 | raise
259 |
260 | def parse_s3_uri(uri):
261 | """
262 | Parse an S3 URI and extract the bucket name and key.
263 |
264 | :param uri: S3 URI (e.g., 's3://bucket-name/path/to/file.txt')
265 | :return: Tuple of (bucket_name, key) if valid, (None, None) if invalid
266 | """
267 | pattern = r'^s3://([^/]+)/(.*)$'
268 | match = re.match(pattern, uri)
269 | if match:
270 | return match.groups()
271 | return (None, None)
272 |
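# Illustrative check (not part of the original module) of parse_s3_uri on placeholder URIs.
def _example_parse_s3_uri_checks():
    assert parse_s3_uri("s3://my-bucket/path/to/file.txt") == ("my-bucket", "path/to/file.txt")
    assert parse_s3_uri("not-an-s3-uri") == (None, None)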
273 | def copy_s3_object(source_uri, dest_bucket, dest_key):
274 | """
275 | Copy an object from one S3 location to another.
276 |
277 | :param source_uri: S3 URI of the source object
278 | :param dest_bucket: Name of the destination bucket
279 | :param dest_key: Key to be used for the destination object
280 | :return: True if successful, False otherwise
281 | """
282 | s3 = boto3.client('s3')
283 |
284 | # Parse the source URI
285 | source_bucket, source_key = parse_s3_uri(source_uri)
286 | if not source_bucket or not source_key:
287 | print(f"Invalid source URI: {source_uri}")
288 | return False
289 |
290 | try:
291 | # Create a copy source dictionary
292 | copy_source = {
293 | 'Bucket': source_bucket,
294 | 'Key': source_key
295 | }
296 |
297 | # Copy the object
298 | s3.copy_object(CopySource=copy_source, Bucket=dest_bucket, Key=f"{dest_key}/{source_key}")
299 | return f"s3://{dest_bucket}/{dest_key}/{source_key}"
300 |
301 | except ClientError as e:
302 | print(f"An error occurred: {e}")
303 | raise(e)
304 | # return False
305 |
306 | class LibraryInstallationDetected(Exception):
307 | """Exception raised when potential library installation is detected."""
308 | pass
309 |
310 |
311 | def check_for_library_installs(code_string):
312 | # Check for pip install commands using subprocess
313 | if re.search(r'subprocess\.(?:check_call|run|Popen)\s*\(\s*\[.*pip.*install', code_string):
314 | raise LibraryInstallationDetected("Potential library installation detected: subprocess pip install found in code.")
315 |
316 | # Check for pip as a module
317 | if re.search(r'pip\._internal\.main\(\[.*install', code_string) or re.search(r'pip\.main\(\[.*install', code_string):
318 | raise LibraryInstallationDetected("Potential library installation detected: pip module install call found in code.")
319 |
320 | keywords = ["subprocess","pip","conda","install","easy_install","setup.py","pipenv",
321 | "git+","svn+","hg+","bzr+","requirements.txt","environment.yml","apt-get","yum","brew",
322 | "ensurepip","get-pip","pkg_resources","importlib","setuptools","distutils","venv","virtualenv",
323 | "pyenv"]
324 |
325 | # Convert the code string to lowercase for case-insensitive matching
326 | code_lower = code_string.lower()
327 |
328 | # Check each keyword and raise so the guardrail blocks execution of the generated code
329 | for keyword in keywords:
330 | if keyword in code_lower:
331 | raise LibraryInstallationDetected(f"Potential library installation detected: '{keyword}' found in code.")
337 |
338 |
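# Illustrative sketch (not part of the original module): the guardrail above is meant to
# raise before generated code containing install commands is executed. The snippet is a
# placeholder.
def _example_guardrail_usage():
    generated_code = "import os\nos.system('pip install requests')"
    try:
        check_for_library_installs(generated_code)
    except LibraryInstallationDetected as err:
        return f"Blocked: {err}"
    return "Code passed the guardrail"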
339 | def put_obj_in_s3_bucket_(docs):
340 | """Uploads a file to an S3 bucket and returns the S3 URI of the uploaded object.
341 | Args:
342 | docs (str): The local file path of the file to upload to S3.
343 | Returns:
344 | str: The S3 URI of the uploaded object, in the format "s3://{bucket_name}/{file_path}".
345 | """
346 | if isinstance(docs,str):
347 | s3_uri_pattern = r'^s3://([^/]+)/(.*?([^/]+)/?)$'
348 | if bool(re.match(s3_uri_pattern, docs)):
349 | file_uri=copy_s3_object(docs, BUCKET, S3_DOC_CACHE_PATH)
350 | return file_uri
351 | else:
352 | file_name=os.path.basename(docs)
353 | file_path=f"{S3_DOC_CACHE_PATH}/{docs}"
354 | S3.upload_file(docs, BUCKET, file_path)
355 | return f"s3://{BUCKET}/{file_path}"
356 | else:
357 | file_name=os.path.basename(docs.name)
358 | file_path=f"{S3_DOC_CACHE_PATH}/{file_name}"
359 | S3.put_object(Body=docs.read(),Bucket= BUCKET, Key=file_path)
360 | return f"s3://{BUCKET}/{file_path}"
361 |
362 |
363 | def get_large_s3_obj_from_bucket_(file, max_bytes=1000000):
364 | """Retrieves a portion of an object from an S3 bucket given its S3 URI.
365 | Args:
366 | file (str): The S3 URI of the object to retrieve, in the format "s3://{bucket_name}/{key}".
367 | max_bytes (int, optional): Maximum number of bytes to read from the beginning of the file.
368 | Returns:
369 | dict: The S3 GetObject response; the payload is available as response['Body'].
370 | """
371 | s3 = boto3.client('s3')
372 | match = re.match("s3://(.+?)/(.+)", file)
373 | bucket_name = match.group(1)
374 | key = match.group(2)
375 |
376 | if max_bytes:
377 | # Read specific number of bytes
378 | obj = s3.get_object(Bucket=bucket_name, Key=key, Range=f'bytes=0-{max_bytes-1}')
379 | else:
380 | # Read the whole object if max_bytes is not specified
381 | obj = s3.get_object(Bucket=bucket_name, Key=key)
382 |
383 | return obj
384 |
385 |
386 | def get_s3_obj_from_bucket_(file):
387 | """Retrieves an object from an S3 bucket given its S3 URI.
388 | Args:
389 | file (str): The S3 URI of the object to retrieve, in the format "s3://{bucket_name}/{key}".
390 | Returns:
391 | dict: The S3 GetObject response; the payload is available as response['Body'].
392 | """
393 | s3 = boto3.client('s3')
394 | match = re.match("s3://(.+?)/(.+)", file)
395 | bucket_name = match.group(1)
396 | key = match.group(2)
397 | obj = s3.get_object(Bucket=bucket_name, Key=key)
398 | return obj
399 |
400 |
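# Illustrative sketch (not part of the original module): reading a small text object with
# the helper above. The URI is a placeholder.
def _example_read_text_object(uri="s3://example-bucket/notes.txt"):
    obj = get_s3_obj_from_bucket_(uri)
    return obj['Body'].read().decode('utf-8')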
401 | def iter_block_items(parent):
402 | if isinstance(parent, Document):
403 | parent_elm = parent.element.body
404 | elif isinstance(parent, _Cell):
405 | parent_elm = parent._tc
406 | else:
407 | raise ValueError("parent must be a Document or _Cell instance")
408 |
409 | for child in parent_elm.iterchildren():
410 | if isinstance(child, CT_P):
411 | yield Paragraph(child, parent)
412 | elif isinstance(child, CT_Tbl):
413 | yield DocxTable(child, parent)
414 |
415 | def extract_text_and_tables(docx_path):
416 | """ Extract text from docx files"""
417 | document = DocxDocument(docx_path)
418 | content = ""
419 | current_section = ""
420 | section_type = None
421 | for block in iter_block_items(document):
422 | if isinstance(block, Paragraph):
423 | if block.text:
424 | if block.style.name == 'Heading 1':
425 | # Close the current section if it exists
426 | if current_section:
427 | content += f"{current_section}</{section_type}>\n"
428 | current_section = ""
429 | section_type = None
430 | section_type ="h1"
431 | content += f"<{section_type}>{block.text}</{section_type}>\n"
432 | elif block.style.name== 'Heading 3':
433 | # Close the current section if it exists
434 | if current_section:
435 | content += f"{current_section}</{section_type}>\n"
436 | current_section = ""
437 | section_type = "h3"
438 | content += f"<{section_type}>{block.text}</{section_type}>\n"
439 |
440 | elif block.style.name == 'List Paragraph':
441 | # Add to the current list section
442 | if section_type != "list":
443 | # Close the current section if it exists
444 | if current_section:
445 | content += f"{current_section}</{section_type}>\n"
446 | section_type = "list"
447 | current_section = ""
448 | current_section += f"{block.text}\n"
449 | elif block.style.name.startswith('toc'):
450 | # Add to the current toc section
451 | if section_type != "toc":
452 | # Close the current section if it exists
453 | if current_section:
454 | content += f"{current_section}</{section_type}>\n"
455 | section_type = "toc"
456 | current_section = ""
457 | current_section += f"{block.text}\n"
458 | else:
459 | # Close the current section if it exists
460 | if current_section:
461 | content += f"{current_section}</{section_type}>\n"
462 | current_section = ""
463 | section_type = None
464 |
465 | # Append the passage text without tagging
466 | content += f"{block.text}\n"
467 |
468 | elif isinstance(block, DocxTable):
469 | # Add the current section before the table
470 | if current_section:
471 | content += f"{current_section}</{section_type}>\n"
472 | current_section = ""
473 | section_type = None
474 |
475 | content += "<table>\n"
476 | for row in block.rows:
477 | row_content = []
478 | for cell in row.cells:
479 | cell_content = []
480 | for nested_block in iter_block_items(cell):
481 | if isinstance(nested_block, Paragraph):
482 | cell_content.append(nested_block.text)
483 | elif isinstance(nested_block, DocxTable):
484 | nested_table_content = parse_nested_table(nested_block)
485 | cell_content.append(nested_table_content)
486 | row_content.append(CSV_SEPERATOR.join(cell_content))
487 | content += CSV_SEPERATOR.join(row_content) + "\n"
488 | content += "</table>\n"
489 |
490 | # Add the final section
491 | if current_section:
492 | content += f"{current_section}</{section_type}>\n"
493 |
494 | return content
495 |
496 | def parse_nested_table(table):
497 | nested_table_content = "<table>\n"
498 | for row in table.rows:
499 | row_content = []
500 | for cell in row.cells:
501 | cell_content = []
502 | for nested_block in iter_block_items(cell):
503 | if isinstance(nested_block, Paragraph):
504 | cell_content.append(nested_block.text)
505 | elif isinstance(nested_block, DocxTable):
506 | nested_table_content += parse_nested_table(nested_block)
507 | row_content.append(CSV_SEPERATOR.join(cell_content))
508 | nested_table_content += CSV_SEPERATOR.join(row_content) + "\n"
509 | nested_table_content += "</table>\n"
510 | return nested_table_content
511 |
512 |
513 |
514 | def extract_text_from_pptx_s3(pptx_buffer):
515 | """ Extract Text from pptx files"""
516 | presentation = Presentation(pptx_buffer)
517 | text_content = []
518 | for slide in presentation.slides:
519 | slide_text = []
520 | for shape in slide.shapes:
521 | if hasattr(shape, 'text'):
522 | slide_text.append(shape.text)
523 | text_content.append('\n'.join(slide_text))
524 | return '\n\n'.join(text_content)
525 |
526 | def get_s3_keys(prefix):
527 | """list all keys in an s3 path"""
528 | s3 = boto3.client('s3')
529 | keys = []
530 | next_token = None
531 | while True:
532 | if next_token:
533 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix, ContinuationToken=next_token)
534 | else:
535 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix)
536 | if "Contents" in response:
537 | for obj in response['Contents']:
538 | key = obj['Key']
539 | name = key[len(prefix):]
540 | keys.append(name)
541 | if "NextContinuationToken" in response:
542 | next_token = response["NextContinuationToken"]
543 | else:
544 | break
545 | return keys
546 |
547 | def get_object_with_retry(bucket, key):
548 | max_retries=5
549 | retries = 0
550 | backoff_base = 2
551 | max_backoff = 3 # Maximum backoff time in seconds
552 | s3 = boto3.client('s3')
553 | while retries < max_retries:
554 | try:
555 | response = s3.get_object(Bucket=bucket, Key=key)
556 | return response
557 | except ClientError as e:
558 | error_code = e.response['Error']['Code']
559 | if error_code == 'DecryptionFailureException':
560 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
561 | print(f"Decryption failed, retrying in {sleep_time} seconds...")
562 | time.sleep(sleep_time)
563 | retries += 1
564 | elif error_code == 'ModelStreamErrorException' and retries < max_retries:
565 | # Transient stream error: exponential backoff
566 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1))
567 | time.sleep(sleep_time)
568 | retries += 1
569 | else:
570 | # Unexpected error or retries exhausted: re-raise instead of looping indefinitely
571 | raise e
572 |
572 |
573 | def exract_pdf_text_aws(file):
574 | file_base_name=os.path.basename(file)
575 | dir_name, ext = os.path.splitext(file)
576 | # Checking if extracted doc content is in S3
577 | if USE_TEXTRACT:
578 | if [x for x in get_s3_keys(f"{TEXTRACT_RESULT_CACHE_PATH}/") if file_base_name in x]:
579 | response = get_object_with_retry(BUCKET, f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt")
580 | text = response['Body'].read().decode()
581 | return text
582 | else:
583 | extractor = Textractor()
584 | # Asynchronous call, you will experience some wait time. Try caching results for better experience
585 | if "pdf" in ext:
586 | print("Asynchronous call, you may experience some wait time.")
587 | document = extractor.start_document_analysis(
588 | file_source=file,
589 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES],
590 | save_image=False,
591 | s3_output_path=f"s3://{BUCKET}/textract_output/"
592 | )
593 | #Synchronous call
594 | else:
595 | document = extractor.analyze_document(
596 | file_source=file,
597 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES],
598 | save_image=False,
599 | )
600 | config = TextLinearizationConfig(
601 | hide_figure_layout=False,
602 | hide_header_layout=False,
603 | table_prefix="<table>",
604 | table_suffix="</table>",
605 | )
606 | # Upload extracted content to s3
607 | S3.put_object(Body=document.get_text(config=config), Bucket=BUCKET, Key=f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt")
608 | return document.get_text(config=config)
609 | else:
610 | s3=boto3.resource("s3")
611 | match = re.match("s3://(.+?)/(.+)", file)
612 | if match:
613 | bucket_name = match.group(1)
614 | key = match.group(2)
615 | if "pdf" in ext:
616 | pdf_bytes = io.BytesIO()
617 | s3.Bucket(bucket_name).download_fileobj(key, pdf_bytes)
618 | # Read the PDF from the BytesIO object
619 | pdf_bytes.seek(0)
620 | # Create a PDF reader object
621 | pdf_reader = PyPDF2.PdfReader(pdf_bytes)
622 | # Get the number of pages in the PDF
623 | num_pages = len(pdf_reader.pages)
624 | # Extract text from each page
625 | text = ''
626 | for page_num in range(num_pages):
627 | page = pdf_reader.pages[page_num]
628 | text += page.extract_text()
629 | else:
630 | img_bytes = io.BytesIO()
631 | s3.Bucket(bucket_name).download_fileobj(key, img_bytes)
632 | img_bytes.seek(0)
633 | image_stream = img_bytes  # the image is already buffered in img_bytes (image_bytes was undefined)
634 | image = Image.open(image_stream)
635 | text = pytesseract.image_to_string(image)
636 | return text
637 |
638 | def detect_encoding(s3_uri):
639 | """detect csv encoding"""
640 | s3 = boto3.client('s3')
641 | match = re.match("s3://(.+?)/(.+)", s3_uri)
642 | if match:
643 | bucket_name = match.group(1)
644 | key = match.group(2)
645 | response = get_large_s3_obj_from_bucket_(s3_uri)
646 | content = response['Body'].read()
647 | result = chardet.detect(content)
648 | df = content.decode(result['encoding'])
649 | return result['encoding'], df
650 |
651 | class InvalidContentError(Exception):
652 | pass
653 |
654 | def parse_csv_from_s3(s3_uri):
655 | """Here we are only loading the first 3 rows to the model. 3 rows is sufficient for the model to figure out the schema"""
656 |
657 | try:
658 | # Detect the file encoding using chardet
659 | encoding, content = detect_encoding(s3_uri)
660 | # Use StringIO to create a file-like object
661 | csv_file = io.StringIO(content)
662 | # Read the CSV file using pandas
663 | df = pd.read_csv(csv_file, delimiter=None, engine='python').iloc[:3]
664 | data_types = df.dtypes
665 | return f"{df.to_csv(index=False)}\n\nHere are the data types for each column:\n{data_types}"
666 |
667 | except Exception as e:
668 | raise InvalidContentError(f"Error: {e}")
669 |
670 |
671 | def strip_newline(cell):
672 | return str(cell).strip()
673 |
674 | def table_parser_openpyxl(file):
675 | """
676 | Here we are only loading the first 5 rows to the model and we are not massaging the dataset by merging empty cells. 5 rows is sufficient for the model to figure out the schema
677 | """
678 | # Read from S3
679 | s3 = boto3.client('s3', region_name=REGION)
680 | match = re.match("s3://(.+?)/(.+)", file)
681 | if match:
682 | bucket_name = match.group(1)
683 | key = match.group(2)
684 | obj = s3.get_object(Bucket=bucket_name, Key=key)
685 | # Read Excel file from S3 into a buffer
686 | xlsx_buffer = io.BytesIO(obj['Body'].read())
687 | # Load workbook
688 | wb=pd.read_excel(xlsx_buffer,sheet_name=None, header=None)
689 | all_sheets_string=""
690 | # Iterate over each sheet in the workbook
691 | for sheet_name, sheet_data in wb.items():
692 | df = pd.DataFrame(sheet_data)
693 | # Convert to string and tag by sheet name
694 | all_sheets_string+=f'Here is a data preview of this sheet (first 5 rows):<{sheet_name}>\n{df.iloc[:5].to_csv(index=False, header=False)}\n</{sheet_name}>\n'
695 | return all_sheets_string
696 | else:
697 | raise Exception(f"{file} not formatted as an S3 path")
698 |
699 | def calamaine_excel_engine(file):
700 | """
701 | Here we are only loading the first 5 rows to the model and we are not massaging the dataset by merging empty cells. 5 rows is sufficient for the model to figure out the schema
702 | """
703 | # # Read from S3
704 | s3 = boto3.client('s3',region_name=REGION)
705 | match = re.match("s3://(.+?)/(.+)", file)
706 | if match:
707 | bucket_name = match.group(1)
708 | key = match.group(2)
709 | obj = s3.get_object(Bucket=bucket_name, Key=key)
710 | # Read Excel file from S3 into a buffer
711 | xlsx_buffer = io.BytesIO(obj['Body'].read())
712 | xlsx_buffer.seek(0)
713 | all_sheets_string = ""
714 | # Load the Excel file
715 | workbook = CalamineWorkbook.from_filelike(xlsx_buffer)
716 | # Iterate over each sheet in the workbook
717 | for sheet_name in workbook.sheet_names:
718 | # Get the sheet by name
719 | sheet = workbook.get_sheet_by_name(sheet_name)
720 | df = pd.DataFrame(sheet.to_python(skip_empty_area=False))
721 | df = df.map(strip_newline)
722 | all_sheets_string += f'Here is a data preview of this sheet (first 5 rows):\n\n<{sheet_name}>\n{df.iloc[:5].to_csv(index=False, header=0)}\n</{sheet_name}>\n'
723 | return all_sheets_string
724 | else:
725 | raise Exception(f"{file} not formatted as an S3 path")
726 |
727 | def table_parser_utills(file):
728 | try:
729 | response = table_parser_openpyxl(file)
730 | if response:
731 | return response
732 | else:
733 | return calamaine_excel_engine(file)
734 | except Exception:
735 | try:
736 | return calamaine_excel_engine(file)
737 | except Exception as e:
738 | raise Exception(str(e))
739 |
740 | def process_document_types(file):
741 | """Handle various document format"""
742 | dir_name, ext = os.path.splitext(file)
743 | if ".csv" == ext.lower():
744 | content = parse_csv_from_s3(file)
745 | elif ext.lower() in [".txt", ".py"]:
746 | obj = get_s3_obj_from_bucket_(file)
747 | content = obj['Body'].read()
748 | elif ext.lower() in [".xlsx", ".xls"]:
749 | content = table_parser_utills(file)
750 | elif ext.lower() in [".pdf", ".png", ".jpg", ".tif", ".jpeg"]:
751 | content = exract_pdf_text_aws(file)
752 | elif ".json" == ext.lower():
753 | obj = get_s3_obj_from_bucket_(file)
754 | content = json.loads(obj['Body'].read())
755 | elif ".docx" == ext.lower():
756 | obj = get_s3_obj_from_bucket_(file)
757 | content = obj['Body'].read()
758 | docx_buffer = io.BytesIO(content)
759 | content = extract_text_and_tables(docx_buffer)
760 | elif ".pptx" == ext.lower():
761 | obj = get_s3_obj_from_bucket_(file)
762 | content = obj['Body'].read()
763 | docx_buffer = io.BytesIO(content)
764 | content = extract_text_from_pptx_s3(docx_buffer)
765 |
766 | # Implement any other file extension logic
767 | return content
768 |
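# Illustrative sketch (not part of the original module): the dispatcher above keys off the
# file extension of an S3 URI. The URI is a placeholder; for a CSV it returns the first
# rows plus the column dtypes.
def _example_process_csv(uri="s3://example-bucket/uploads/sales.csv"):
    return process_document_types(uri)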
769 | def plotly_to_png_bytes(s3_uri):
770 | """
771 | Read a .plotly file from S3 given an S3 URI, convert it to a PNG image, and return the image as bytes.
772 |
773 | :param s3_uri: S3 URI of the .plotly file (e.g., 's3://bucket-name/path/to/file.plotly')
774 | :return: PNG image as bytes
775 | """
776 | # Parse S3 URI
777 | parsed_uri = urlparse(s3_uri)
778 | bucket_name = parsed_uri.netloc
779 | file_key = parsed_uri.path.lstrip('/')
780 |
781 | # Initialize S3 client
782 | s3_client = boto3.client('s3')
783 |
784 | try:
785 | # Read the .plotly file from S3
786 | response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
787 | plotly_data = json.loads(response['Body'].read().decode('utf-8'))
788 |
789 | # Create a Figure object from the plotly data
790 | fig = go.Figure(data=plotly_data['data'], layout=plotly_data.get('layout', {}))
791 |
792 | # Convert the figure to PNG bytes
793 | img_bytes = fig.to_image(format="png")
794 |
795 | return img_bytes
796 |
797 | except Exception as e:
798 | print(f"An error occurred: {str(e)}")
799 | return None
800 |
801 |
802 | def get_chat_history_db(params, cutoff,vision_model):
803 | """
804 | Load chat history and attachments from DynamoDB and S3 accordingly
805 |
806 | Args:
807 | params (dict): Application parameters
808 | cutoff (int): Cutoff of chat history entries to load
809 | vision_model (bool): True if a vision-capable model is used
810 | """
811 | current_chat, chat_hist = [], []
812 | if params['chat_histories'] and cutoff != 0:
813 | chat_hist = params['chat_histories'][-cutoff:]
814 | for ids, d in enumerate(chat_hist):
815 | if d['image'] and vision_model and LOAD_DOC_IN_ALL_CHAT_CONVO:
816 | content = []
817 | for img in d['image']:
818 | s3 = boto3.client('s3')
819 | match = re.match("s3://(.+?)/(.+)", img)
820 | image_name = os.path.basename(img)
821 | _, ext = os.path.splitext(image_name)
822 | if "jpg" in ext:
823 | ext = ".jpeg"
824 | if match:
825 | bucket_name = match.group(1)
826 | key = match.group(2)
827 | if ".plotly" in key:
828 | bytes_image = plotly_to_png_bytes(img)
829 | ext = ".png"
830 | else:
831 | obj = s3.get_object(Bucket=bucket_name, Key=key)
832 | bytes_image = obj['Body'].read()
833 | content.extend(
834 | [
835 | {"text": image_name},
836 | {'image':
837 | {
838 | 'format': ext.lower().replace('.', ''),
839 | 'source': {'bytes': bytes_image}
840 | }
841 | }
842 | ]
843 | )
844 | content.extend([{"text": d['user']}])
845 | if 'tool_result_id' in d and d['tool_result_id']:
846 | user = [{'toolResult': {'toolUseId': d['tool_result_id'],
847 | 'content': content}}]
848 | current_chat.append({'role': 'user', 'content': user})
849 | else:
850 | current_chat.append({'role': 'user', 'content': content})
851 | elif d['document'] and LOAD_DOC_IN_ALL_CHAT_CONVO:
852 | doc = 'Here is a document showing sample rows:\n'
853 | for docs in d['document']:
854 | uploads = process_document_types(docs)
855 | doc_name = os.path.basename(docs)
856 | doc += f"<{doc_name}>\n{uploads}\n</{doc_name}>\n"
857 | if not vision_model and d["image"]:
858 | for docs in d['image']:
859 | uploads = process_document_types(docs)
860 | doc_name = os.path.basename(docs)
861 | doc += f"<{doc_name}>\n{uploads}\n</{doc_name}>\n"
862 | current_chat.append({'role': 'user', 'content': [{"text": doc + d['user']}]})
863 | # do not have a tool return for the document section because the tool does not provide documents only images
864 | else:
865 | if 'tool_result_id' in d and d['tool_result_id']:
866 | user = [{'toolResult': {'toolUseId': d['tool_result_id'],
867 | 'content': [{'text': d['user']}]}}]
868 | current_chat.append({'role': 'user', 'content': user})
869 | else:
870 | current_chat.append({'role': 'user', 'content': [{"text": d['user']}]})
871 |
872 | if 'tool_use_id' in d and d['tool_use_id']:
873 | assistant = [{'toolUse': {'toolUseId': d['tool_use_id'],
874 | 'name': d['tool_name'],
875 | 'input': {'code': d['assistant'],
876 | "dataset_name": d['tool_params']['ds'],
877 | "python_packages": d['tool_params']['pp']}}}
878 | ]
879 |
880 | current_chat.append({'role': 'assistant', 'content': assistant})
881 | else:
882 | current_chat.append({'role': 'assistant', 'content': [{"text": d['assistant']}]})
883 | return current_chat, chat_hist
884 |
885 | def stream_messages(params,
886 | bedrock_client,
887 | model_id,
888 | messages,
889 | tool_config,
890 | system,
891 | temperature,
892 | handler):
893 | """
894 | Sends a message to a model and streams the response.
895 | Args:
896 | bedrock_client: The Boto3 Bedrock runtime client.
897 | model_id (str): The model ID to use.
898 | messages (JSON) : The messages to send to the model.
899 | tool_config : Tool Information to send to the model.
900 |
901 | Returns:
902 | stop_reason (str): The reason why the model stopped generating text.
903 | message (JSON): The message that the model generated.
904 |
905 | """
906 |
907 | if st.session_state['reasoning_mode']:
908 | response = bedrock_client.converse_stream(
909 | modelId=model_id,
910 | messages=messages,
911 | inferenceConfig={"maxTokens": 18000, "temperature": 1},
912 | toolConfig=tool_config,
913 | system=system,
914 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}}
915 | )
916 | else:
917 | response = bedrock_client.converse_stream(
918 | modelId=model_id,
919 | messages=messages,
920 | inferenceConfig={"maxTokens": 4000, "temperature": temperature},
921 | toolConfig=tool_config,
922 | system=system
923 | )
924 |
925 | stop_reason = ""
926 | message = {}
927 | content = []
928 | message['content'] = content
929 | text = ''
930 | tool_use = {}
931 | think = ''
932 | signature = ''
933 |
934 | for chunk in response['stream']:
935 | if 'messageStart' in chunk:
936 | message['role'] = chunk['messageStart']['role']
937 | elif 'contentBlockStart' in chunk:
938 | tool = chunk['contentBlockStart']['start']['toolUse']
939 | tool_use['toolUseId'] = tool['toolUseId']
940 | tool_use['name'] = tool['name']
941 | elif 'contentBlockDelta' in chunk:
942 | delta = chunk['contentBlockDelta']['delta']
943 | if 'toolUse' in delta:
944 | if 'input' not in tool_use:
945 | tool_use['input'] = ''
946 | tool_use['input'] += delta['toolUse']['input']
947 | elif 'text' in delta:
948 | text += delta['text']
949 | if handler:
950 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True)
951 | elif 'reasoningContent' in delta:
952 | if "text" in delta['reasoningContent']:
953 | think += delta['reasoningContent']['text']
954 | handler.markdown('**MODEL REASONING**\n\n' + think.replace("$", "\\$"), unsafe_allow_html=True)
955 | if "signature" in delta['reasoningContent']:
956 | signature = delta['reasoningContent']['signature']
957 |
958 | elif 'contentBlockStop' in chunk:
959 | if 'input' in tool_use:
960 | tool_use['input'] = json.loads(tool_use['input'])
961 | content.append({'toolUse': tool_use})
962 | else:
963 | content.append({'text': text})
964 | text = ''
965 |
966 | elif 'messageStop' in chunk:
967 | stop_reason = chunk['messageStop']['stopReason']
968 | elif "metadata" in chunk:
969 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"]
970 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"]
971 | latency = chunk['metadata']['metrics']["latencyMs"]
972 | pricing = st.session_state['input_token'] * pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"]
973 | st.session_state['cost']+=pricing
974 |
975 | if tool_use:
976 | try:
977 | handler.markdown(f"{text}\n```python\n{message['content'][1]['toolUse']['input']['code']}", unsafe_allow_html=True )
978 | except (IndexError, KeyError):
979 | if len(message['content']) == 3:
980 | handler.markdown(f"{text}\n```python\n{message['content'][2]['toolUse']['input']['code']}", unsafe_allow_html=True )
981 | else:
982 | handler.markdown(f"{text}\n```python\n{message['content'][0]['toolUse']['input']['code']}", unsafe_allow_html=True )
983 | return stop_reason, message, st.session_state['input_token'], st.session_state['output_token'], think
984 |
985 |
986 | def self_crtique(params, code, error, dataset, handler=None):
987 |
988 | import re
989 | if params["engine"] == "pyspark":
990 | with open("prompt/pyspark_debug_prompt.txt", "r") as fo:
991 | prompt_template = fo.read()
992 | values = {"dataset": dataset, "code": code, "error": error}
993 | prompt = prompt_template.format(**values)
994 | system = "You are an expert pyspark debugger for Amazon Athena PySpark Runtime"
995 | else:
996 | with open("prompt/python_debug_prompt.txt", "r") as fo:
997 | prompt_template = fo.read()
998 | values = {"dataset": dataset, "code": code, "error": error}
999 | prompt = prompt_template.format(**values)
1000 | system = "You are an expert python debugger."
1001 |
1002 | model_id = 'us.' + model_info[params['model']]
1003 | fixed_code, think = _invoke_bedrock_with_retries(params, [], system, prompt, model_id, [], handler)
1004 | code_pattern = r'<code>(.*?)</code>'  # <code> tag assumed from the debug prompt format
1005 | match = re.search(code_pattern, fixed_code, re.DOTALL)
1006 | code = match.group(1)
1007 | if handler:
1008 | handler.markdown(f"```python\n{code}", unsafe_allow_html=True)
1009 | lib_pattern = r'<libraries>(.*?)</libraries>'  # tag name assumed from the debug prompt format
1010 | match = re.search(lib_pattern, fixed_code, re.DOTALL)
1011 | if match:
1012 | libs = match.group(1)
1013 | else:
1014 | libs = ''
1015 | return code, libs
1016 |
1017 | def process_files(files):
1018 | result_string = ""
1019 | errors = []
1020 | future_proxy_mapping = {}
1021 | futures = []
1022 |
1023 | with concurrent.futures.ProcessPoolExecutor() as executor:
1024 | # Partial function to pass the process_document_types function
1025 | func = partial(process_document_types)
1026 | for file in files:
1027 | future = executor.submit(func, file)
1028 | future_proxy_mapping[future] = file
1029 | futures.append(future)
1030 |
1031 | # Collect the results and handle exceptions
1032 | for future in concurrent.futures.as_completed(futures):
1033 | file_url = future_proxy_mapping[future]
1034 | try:
1035 | result = future.result()
1036 | doc_name = os.path.basename(file_url)
1037 | result_string += f"<{doc_name}>\n{result}\n</{doc_name}>\n"
1038 | except Exception as e:
1039 | # Get the original function arguments from the Future object
1040 | error = {'file': file_url, 'error': str(e)}
1041 | errors.append(error)
1042 |
1043 | return errors, result_string
1044 |
1045 | def invoke_lambda(function_name, payload):
1046 | config = Config(
1047 | connect_timeout=600,
1048 | read_timeout=600, # Read timeout parameter
1049 | retries=dict(
1050 | max_attempts=0, # Handle retries
1051 | total_max_attempts=1
1052 | )
1053 | )
1054 | lambda_client = boto3.client('lambda', config=config)
1055 | response = lambda_client.invoke(
1056 | FunctionName=function_name,
1057 | InvocationType='RequestResponse',
1058 | Payload=json.dumps(payload)
1059 | )
1060 | return json.loads(response['Payload'].read().decode('utf-8'))
1061 |
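# Illustrative sketch (not part of the original module): invoking the code-execution Lambda
# with a payload shaped like the one built later in function_caller_claude_. All values are
# placeholders; the function name comes from config.json via LAMBDA_FUNC.
def _example_invoke_lambda():
    payload = {
        "python_packages": "",
        "code": "print('hello')",
        "dataset_name": "",
        "iterate": 0,
        "bucket": BUCKET,
        "file_path": S3_DOC_CACHE_PATH,
    }
    return invoke_lambda(LAMBDA_FUNC, payload)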
1062 | class CodeExecutionError(Exception):
1063 | pass
1064 |
1065 | def load_json_data(output_content):
1066 | try:
1067 | json_data = json.loads(output_content.replace("'", '"'))
1068 | return json_data
1069 | except json.JSONDecodeError:
1070 | try:
1071 | parsed_data = ast.literal_eval(output_content)
1072 | json_data = json.loads(json.dumps(parsed_data))
1073 | return json_data
1074 | except (SyntaxError, ValueError):
1075 | try:
1076 | # Replace escaped single quotes with double quotes, but handle nested quotes carefully
1077 | modified_content = output_content
1078 | # First, ensure the outer structure is properly quoted
1079 | if modified_content.startswith("'") and modified_content.endswith("'"):
1080 | modified_content = modified_content[1:-1] # Remove outer quotes
1081 | # Replace escaped single quotes with double quotes
1082 | modified_content = modified_content.replace("\'", '"')
1083 | json_data = json.loads(modified_content)
1084 | return json_data
1085 | except json.JSONDecodeError as e:
1086 | print(f"Error parsing JSON using all methods: {e}")
1087 | raise e
1088 |
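# Illustrative check (not part of the original module): an apostrophe inside a value defeats
# the naive quote replacement, so load_json_data falls back to ast.literal_eval.
def _example_load_json_data():
    sample = "{'result': \"it's ok\", 'plotly': []}"  # placeholder stdout content
    return load_json_data(sample)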
1089 | def function_caller_claude_(params, handler=None):
1090 | """
1091 | Entrypoint for streaming tool use example.
1092 | """
1093 | current_chat, chat_hist = get_chat_history_db(params, CHAT_HISTORY_LENGTH, True)
1094 | if current_chat and 'toolResult' in current_chat[0]['content'][0]:
1095 | if 'toolUseId' in current_chat[0]['content'][0]['toolResult']:
1096 | del current_chat[0:2]
1097 |
1098 | vision_model = True
1099 | model_id = 'us.' + model_info[params['model']]
1100 | if any(keyword in params['model'] for keyword in TEXT_ONLY_MODELS):
1101 | vision_model = False
1102 |
1103 | full_doc_path = []
1104 | image_path = []
1105 | for ids, docs in enumerate(params['upload_doc']):
1106 | file_name = docs.name
1107 | _, extensions = os.path.splitext(file_name)
1108 | s3_file_name = put_obj_in_s3_bucket_(docs)
1109 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model:
1110 | image_path.append(s3_file_name)
1111 | continue
1112 | full_doc_path.append(s3_file_name)
1113 |
1114 | if params['s3_objects']:
1115 | for ids, docs in enumerate(params['s3_objects']):
1116 | file_name = docs
1117 | _, extensions = os.path.splitext(file_name)
1118 | docs = f"s3://{INPUT_BUCKET}/{INPUT_S3_PATH}/{docs}"
1119 | full_doc_path.append(docs)
1120 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model:
1121 | image_path.append(docs)
1122 | continue
1123 |
1124 | errors, result_string = process_files(full_doc_path)
1125 | if errors:
1126 | st.error(errors)
1127 | question = params['question']
1128 | if result_string and ('.csv' in result_string or '.parquet' in result_string or '.xlsx' in result_string):
1129 | input_text = f"Here is a subset (first few rows) of the data from each dataset tagged by each file name:\n{result_string}\n{question}"
1130 | elif result_string and not ('.csv' in result_string or '.xlsx' in result_string or '.parquet' in result_string):
1131 | doc = 'I have provided documents and/or images tagged by their file names:\n'
1132 | input_text = f"{doc}{result_string}\n{question}"
1133 | else:
1134 | input_text = question
1135 | bedrock_client = boto3.client(service_name='bedrock-runtime', region_name=REGION, config=config)
1136 | # Create the initial message from the user input.
1137 | content = []
1138 | if image_path:
1139 | for img in image_path:
1140 | s3 = boto3.client('s3')
1141 | match = re.match("s3://(.+?)/(.+)", img)
1142 | image_name = os.path.basename(img)
1143 | _, ext = os.path.splitext(image_name)
1144 | if "jpg" in ext:
1145 | ext = ".jpeg"
1146 | bucket_name = match.group(1)
1147 | key = match.group(2)
1148 | if ".plotly" in key:
1149 | bytes_image = plotly_to_png_bytes(img)
1150 | ext = ".png"
1151 | else:
1152 | obj = s3.get_object(Bucket=bucket_name, Key=key)
1153 | bytes_image = obj['Body'].read()
1154 | content.extend([{"text": image_name}, {
1155 | "image": {
1156 | "format": f"{ext.lower().replace('.', '')}",
1157 | "source": {"bytes": bytes_image}
1158 | }
1159 | }])
1160 | content.append({"text": input_text})
1161 | messages = [{
1162 | "role": "user",
1163 | "content": content
1164 | }]
1165 | # Define the tool and prompt template to send to the model.
1166 | if params["engine"] == "pyspark":
1167 | with open("prompt/pyspark_tool_template.json", "r") as f:
1168 | tool_config = json.load(f)
1169 | with open("prompt/pyspark_tool_prompt.txt", "r") as fo:
1170 | description = fo.read()
1171 | with open("prompt/pyspark_tool_system.txt", "r") as fod:
1172 | system_prompt = fod.read()
1173 | tool_config['tools'][0]['toolSpec']['inputSchema']['json']['properties']['code']['description'] = description
1174 | else:
1175 | with open("prompt/python_tool_template.json", "r") as f:
1176 | tool_config = json.load(f)
1177 | with open("prompt/python_tool_prompt.txt", "r") as fo:
1178 | description = fo.read()
1179 | with open("prompt/python_tool_system.txt", "r") as fod:
1180 | system_prompt = fod.read()
1181 | tool_config['tools'][0]['toolSpec']['inputSchema']['json']['properties']['code']['description'] = description
1182 |
1183 | system = [
1184 | {
1185 | 'text': system_prompt
1186 | }
1187 | ]
1188 | current_chat.extend(messages)
1189 | # Send the message and get the tool use request from response.
1190 | stop_reason, message, input_tokens, output_tokens, think = stream_messages(
1191 | params, bedrock_client, model_id, current_chat, tool_config, system, 0.1, handler)
1192 | messages.append(message)
1193 | if stop_reason != "tool_use":
1194 | chat_history = {
1195 | "user": question,
1196 | "assistant": message['content'][0]['text'] if message['content'][0]['text'] else message['content'][1]['text'],
1197 | "image": image_path,
1198 | "document": full_doc_path,
1199 | "thinking": think,
1200 | "modelID": model_id,
1201 | "time": str(time.time()),
1202 | "input_token": round(input_tokens),
1203 | "output_token": round(output_tokens)
1204 | }
1205 | if DYNAMODB_TABLE:
1206 | put_db(params, chat_history)
1207 | # use local disk for storage
1208 | else:
1209 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params)
1210 |
1211 | return message['content'][0]['text'], "", "", "", full_doc_path, stop_reason, ""
1212 | elif stop_reason == "tool_use":
1213 |
1214 | self_correction_retry = 5
1215 | for content in message['content']:
1216 | if 'toolUse' in content:
1217 | tool = content['toolUse']
1218 | if tool['name'] == tool_config['tools'][0]['toolSpec']['name']:
1219 | # Preliminary Guardrail to check that the code does not have any install commands
1220 | check_for_library_installs(tool['input']['code'])
1221 | i = 0
1222 | while i < self_correction_retry:
1223 | try:
1224 | payload = {
1225 | "python_packages": tool['input']['python_packages'],
1226 | "code": tool['input']['code'],
1227 | "dataset_name": tool['input']['dataset_name'],
1228 | "iterate": i,
1229 | "bucket": BUCKET,
1230 | "file_path": S3_DOC_CACHE_PATH
1231 | }
1232 | if params["engine"] == "pyspark":
1233 | tool_execution_response = send_athena_job(payload, ATHENA_WORKGROUP_NAME) # Execute generated code in Amazon Athena
1234 | if "error" not in tool_execution_response:
1235 | pattern = r'<output>(.*?)</output>'  # marker tag assumed; must match the tag printed by the injected Athena code
1236 | match = re.search(pattern, tool_execution_response['stdout'], re.DOTALL)
1237 | if match:
1238 | output_content = match.group(1)
1239 | # Parse the extracted content to JSON
1240 | try:
1241 | json_data = load_json_data(output_content)
1242 | image_holder = json_data.get('plotly', [])
1243 | results = json_data['result']
1244 | plotly_obj = json_data.get('plotly', [])
1245 | except json.JSONDecodeError as e:
1246 | print(f"Error parsing JSON: {e}")
1247 | raise e
1248 | else:
1249 | print("No