├── BedrockChatUI.ipynb ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Docker ├── Dockerfile ├── app.py └── requirements.txt ├── LICENSE ├── README.md ├── bedrock-chat.py ├── config.json ├── images ├── JP-lab.PNG ├── chat-flow.png ├── chat-preview.JPG ├── chatbot-snip.PNG ├── chatbot4.png ├── demo.mp4 ├── sg-rules.PNG └── studio-new-launcher.png ├── install_package.sh ├── model_id.json ├── pricing.json ├── prompt ├── chat.txt ├── doc_chat.txt ├── pyspark_debug_prompt.txt ├── pyspark_tool_prompt.txt ├── pyspark_tool_system.txt ├── pyspark_tool_template.json ├── python_debug_prompt.txt ├── python_tool_prompt.txt ├── python_tool_system.txt └── python_tool_template.json ├── req.txt └── utils ├── athena_handler_.py └── function_calling_utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /Docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.12 2 | 3 | # Copy function code 4 | COPY app.py ${LAMBDA_TASK_ROOT} 5 | 6 | # Install the function's dependencies 7 | COPY requirements.txt . 8 | RUN pip install -r requirements.txt 9 | 10 | # Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) 11 | CMD [ "app.lambda_handler" ] -------------------------------------------------------------------------------- /Docker/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import os 4 | import boto3 5 | import sys 6 | import re 7 | 8 | class CodeExecutionError(Exception): 9 | pass 10 | 11 | def local_code_executy(code_string): 12 | """ 13 | Execute a given Python code string in a secure, isolated environment and capture its output. 14 | 15 | Parameters: 16 | code_string (str): The Python code to be executed. 17 | 18 | Returns: 19 | dict: The output of the executed code, expected to be in JSON format. 20 | 21 | Raises: 22 | CodeExecutionError: Custom exception with detailed error information if code execution fails. 23 | 24 | 25 | Functionality: 26 | 1. Creates a temporary Python file with a unique name in the /tmp directory. 27 | 2. Writes the provided code to this temporary file. 28 | 3. Executes the temporary file using the Python interpreter. 29 | 4. Captures the output from a predefined output file ('/tmp/output.json'). 30 | 5. Cleans up temporary files after execution. 31 | 6. In case of execution errors, provides detailed error information including: 32 | - The error message and traceback. 33 | - The line number where the error occurred. 34 | - Context of the code around the error line. 35 | 36 | Note: 37 | - This function assumes that the executed code writes its output to '/tmp/output.json'. This (saving the output to local) is appended to the generated code in the main application. 
38 | """ 39 | # Create a unique filename in /tmp 40 | temp_file_path = f"/tmp/code_{os.urandom(16).hex()}.py" 41 | output_file_path = '/tmp/output.json' 42 | try: 43 | # Write the code to the temporary file 44 | with open(temp_file_path, 'w', encoding="utf-8") as temp_file: 45 | temp_file.write(code_string) 46 | 47 | # Execute the temporary file 48 | result = subprocess.run([sys.executable, temp_file_path], 49 | capture_output=True, text=True, check=True) 50 | 51 | with open(output_file_path, 'r', encoding="utf-8") as f: 52 | output = json.load(f) 53 | 54 | # Clean up temporary files 55 | os.remove(output_file_path) 56 | 57 | return output 58 | 59 | except subprocess.CalledProcessError as e: 60 | # An error occurred during execution 61 | full_error_message = e.stderr.strip() 62 | 63 | # Extract the traceback part of the error message 64 | traceback_match = re.search(r'(Traceback[\s\S]*)', full_error_message) 65 | if traceback_match: 66 | error_message = traceback_match.group(1) 67 | else: 68 | error_message = full_error_message # Fallback to full message if no traceback found 69 | 70 | # Parse the traceback to get the line number 71 | tb_lines = error_message.split('\n') 72 | line_no = None 73 | for line in reversed(tb_lines): 74 | if temp_file_path in line: 75 | match = re.search(r'line (\d+)', line) 76 | if match: 77 | line_no = int(match.group(1)) 78 | break 79 | 80 | # Construct error message with context 81 | error = f"Error: {error_message}\n" 82 | if line_no is not None: 83 | code_lines = code_string.split('\n') 84 | context_lines = 2 85 | start = max(0, line_no - 1 - context_lines) 86 | end = min(len(code_lines), line_no + context_lines) 87 | error += f"Error on line {line_no}:\n" 88 | for i, line in enumerate(code_lines[start:end], start=start+1): 89 | prefix = "-> " if i == line_no else " " 90 | error += f"{prefix}{i}: {line}\n" 91 | else: 92 | error += "Could not determine the exact line of the error.\n" 93 | error += "Full code:\n" 94 | for i, line in enumerate(code_string.split('\n'), start=1): 95 | error += f"{i}: {line}\n" 96 | 97 | raise CodeExecutionError(error) 98 | 99 | finally: 100 | # Clean up the temporary file 101 | if os.path.exists(temp_file_path): 102 | os.remove(temp_file_path) 103 | 104 | 105 | 106 | def execute_function_string(input_code): 107 | """ 108 | Execute a given Python code string 109 | Parameters: 110 | input_code (dict): A dictionary containing the following keys: 111 | - 'code' (str): The Python code to be executed. 112 | - 'dataset_name' (str or list, optional): Name(s) of the dataset(s) used in the code. 113 | Returns: 114 | The result of executing the code using the local_code_executy function. 115 | 116 | """ 117 | code_string = input_code['code'] 118 | return local_code_executy(code_string) 119 | 120 | 121 | def put_obj_in_s3_bucket_(docs, bucket, key_prefix): 122 | """Uploads a file to an S3 bucket and returns the S3 URI of the uploaded object. 123 | Args: 124 | docs (str): The local file path of the file to upload to S3. 125 | bucket (str): S3 bucket name, 126 | key_prefix (str): S3 key prefix. 127 | Returns: 128 | str: The S3 URI of the uploaded object, in the format "s3://{bucket_name}/{file_path}". 
129 | """ 130 | S3 = boto3.client('s3') 131 | if isinstance(docs, str): 132 | file_name = os.path.basename(docs) 133 | file_path = f"{key_prefix}/{docs}" 134 | S3.upload_file(f"/tmp/{docs}", bucket, file_path) 135 | else: 136 | file_name = os.path.basename(docs.name) 137 | file_path = f"{key_prefix}/{file_name}" 138 | S3.put_object(Body=docs.read(), Bucket=bucket, Key=file_path) 139 | return f"s3://{bucket}/{file_path}" 140 | 141 | 142 | 143 | def lambda_handler(event, context): 144 | try: 145 | input_data = json.loads(event) if isinstance(event, str) else event 146 | iterate = input_data.get('iterate', 0) 147 | bucket = input_data.get('bucket', '') 148 | s3_file_path = input_data.get('file_path', '') 149 | print(input_data, bucket, s3_file_path, iterate) 150 | result = execute_function_string(input_data) 151 | image_holder = [] 152 | plotly_holder = [] 153 | 154 | if isinstance(result, dict): 155 | for item, value in result.items(): 156 | if "image" in item and value is not None: 157 | if isinstance(value, list): 158 | for img in value: 159 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path) 160 | image_holder.append(image_path_s3) 161 | else: 162 | image_path_s3 = put_obj_in_s3_bucket_(value, bucket, s3_file_path) 163 | image_holder.append(image_path_s3) 164 | if "plotly-files" in item and value is not None: # Upload plotly objects to s3 165 | if isinstance(value, list): 166 | for img in value: 167 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path) 168 | plotly_holder.append(image_path_s3) 169 | else: 170 | image_path_s3 = put_obj_in_s3_bucket_(value, bucket, s3_file_path) 171 | plotly_holder.append(image_path_s3) 172 | 173 | tool_result = { 174 | "result": result, 175 | "image_dict": image_holder, 176 | "plotly": plotly_holder 177 | } 178 | print(tool_result) 179 | return { 180 | 'statusCode': 200, 181 | 'body': json.dumps(tool_result) 182 | } 183 | except Exception as e: 184 | print(e) 185 | return { 186 | 'statusCode': 500, 187 | 'body': json.dumps({'error': str(e)}) 188 | } -------------------------------------------------------------------------------- /Docker/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | scikit-learn 3 | pandas 4 | numpy 5 | seaborn 6 | matplotlib 7 | s3fs 8 | openpyxl 9 | plotly>=5.0.0,<6.0.0 10 | kaleido -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
 17 | 
 18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Bedrock Claude ChatBot
  2 | Bedrock Chat App is a Streamlit application that allows users to interact with various LLMs on Amazon Bedrock. It provides a conversational interface where users can ask questions, upload documents, and receive responses from the AI assistant.
  3 | 
  4 | 
  5 | 
  6 | 
  7 | READ THE FOLLOWING **PREREQUISITES** CAREFULLY.
  8 | 
  9 | ## Features
 10 | 
 11 | - **Conversational UI**: The app provides a chat-like interface for seamless interaction with the AI assistant.
 12 | - **Document Upload**: Users can upload various types of documents (PDF, CSV, TXT, PNG, JPG, XLSX, JSON, DOCX, Python scripts, etc.) to provide context for the AI assistant.
 13 | - **Caching**: Uploaded documents and extracted text are cached in an S3 bucket for improved performance. This serves as the object storage unit for the application, as documents are retrieved and loaded into the model to keep conversation context.
 14 | - **Chat History**: The app stores and retrieves chat history (including document metadata) to/from a DynamoDB table, allowing users to continue conversations across sessions.
 15 | - **Session Store**: The application utilizes DynamoDB to store and manage user and session information, enabling isolated conversations and state tracking for each user interaction.
 16 | - **Model Selection**: Users can select from a broad list of LLMs on Amazon Bedrock, including the latest models from Anthropic Claude, Amazon Nova, Meta Llama, DeepSeek, etc., for their queries, and can include additional models on Bedrock by modifying the `model_id.json` file. It incorporates the Bedrock Converse API, providing a standardized model interface.
 17 | - **Cost Tracking**: The application calculates and displays the cost associated with each chat session based on the input and output token counts and the pricing model defined in the `pricing.json` file.
 18 | - **Logging**: The items logged in the DynamoDB table include the user ID, session ID, messages, timestamps, uploaded document S3 paths, and input and output token counts. This helps to isolate user engagement statistics and track the various items being logged, as well as attribute the cost per user.
 19 | - **Tool Usage**: **`Advanced Data Analytics tool`** for processing and analyzing structured data (CSV, XLS, and XLSX formats) in an isolated and serverless environment.
 20 | - **Extensible Tool Integration**: This app can be modified to leverage the extensive Domain Specific Language (DSL) knowledge inherent in Large Language Models (LLMs) to implement a wide range of specialized tools. This capability is enhanced by the versatile execution environments provided by Docker containers and AWS Lambda, allowing for dynamic and adaptable implementation of various DSL-based functionalities. This approach enables the system to handle diverse domain-specific tasks efficiently, without the need for hardcoded, specialized modules for each domain.
 21 | 
 22 | There are two files of interest.
 23 | 1. A Jupyter Notebook that walks you through the ChatBot implementation cell by cell (Advanced Data Analytics is only available in the Streamlit chatbot).
 24 | 2. A Streamlit app that can be deployed to create a UI Chatbot.
 25 | 
 26 | ## Pre-Requisites
 27 | 1. [Amazon Bedrock Anthropic Claude Model Access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html)
 28 | 2. [S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/create-bucket-overview.html) to store uploaded documents and Textract output.
 29 | 3. Optional:
 30 |    - Create an Amazon DynamoDB table to store chat history (run the notebook **BedrockChatUI** to create a DynamoDB table). This is optional as there is a local disk storage option; however, I would recommend using Amazon DynamoDB.
 31 |    - Amazon Textract. This is optional as there is an option to use the python libraries [`pypdf2`](https://pypi.org/project/PyPDF2/) and [`pytesseract`](https://pypi.org/project/pytesseract/) for PDF and image processing. However, I would recommend using Amazon Textract for higher quality PDF and image processing. You will experience latency when using `pytesseract`.
 32 |    - [Amazon Elastic Container Registry](https://docs.aws.amazon.com/AmazonECR/latest/userguide/repository-create.html) to store custom docker images if using the `Advanced Data Analytics` feature with the AWS Lambda setup.
 33 | 
 34 | To use the **Advanced Analytics Feature**, this additional step is required (the ChatBot can still be used without enabling the `Advanced Analytics Feature`):
 35 | 
 36 | This feature can be powered by a **python** runtime on AWS Lambda and/or a **pyspark** runtime on Amazon Athena. Expand the appropriate section below to view the set-up instructions.
 37 | 
 38 | <details>
 39 | <summary>AWS Lambda Python Runtime Setup</summary>
 40 | 
 41 | ## AWS Lambda Function with Custom Python Image
 42 | 
 43 | 5. [Amazon Lambda](https://docs.aws.amazon.com/lambda/latest/dg/python-image.html#python-image-clients) function with a custom python image to execute python code for analytics.
 44 |    - Create a private ECR repository by following the link in step 3.
 45 |    - On your local machine or any related AWS services including [AWS CloudShell](https://docs.aws.amazon.com/cloudshell/latest/userguide/welcome.html), [Amazon Elastic Compute Cloud](https://aws.amazon.com/ec2/getting-started/), [Amazon SageMaker Studio](https://aws.amazon.com/blogs/machine-learning/accelerate-ml-workflows-with-amazon-sagemaker-studio-local-mode-and-docker-support/), etc., run the following CLI commands:
 46 |      - install git and clone this git repo `git clone [github_link]`
 47 |      - navigate into the Docker directory `cd Docker`
 48 |      - if using a local machine, authenticate with your [AWS credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-authentication.html)
 49 |      - install [AWS Command Line Interface (AWS CLI) version 2](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) if not already installed.
 50 |    - Follow the steps in the **Deploying the image** section under **Using an AWS base image for Python** in this [documentation guide](https://docs.aws.amazon.com/lambda/latest/dg/python-image.html#python-image-instructions). Replace the placeholders with the appropriate values. You can skip step `2` if you already created an ECR repository.
 51 |    - In step 6, in addition to the `AWSLambdaBasicExecutionRole` policy, **ONLY** grant [least-privileged read and write Amazon S3 policies](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_s3_rw-bucket.html) to the execution role. Scope down the policy to only include the necessary S3 bucket and S3 directory prefix where uploaded files will be stored and read from, as configured in the `config.json` file below.
 52 |    - In step 7, I recommend creating the Lambda function in an [Amazon Virtual Private Cloud (VPC)](https://docs.aws.amazon.com/lambda/latest/dg/configuration-vpc.html) without [internet access](https://docs.aws.amazon.com/vpc/latest/userguide/vpc-example-private-subnets-nat.html) and attaching Amazon S3 and Amazon CloudWatch [gateway](https://docs.aws.amazon.com/vpc/latest/privatelink/vpc-endpoints-s3.html) and [interface endpoints](https://docs.aws.amazon.com/vpc/latest/privatelink/create-interface-endpoint.html#create-interface-endpoint.html) accordingly. The following step 7 command can be modified to include VPC parameters:
 53 | ```
 54 | aws lambda create-function \
 55 |   --function-name YourFunctionName \
 56 |   --package-type Image \
 57 |   --code ImageUri=your-account-id.dkr.ecr.your-region.amazonaws.com/your-repo:tag \
 58 |   --role arn:aws:iam::your-account-id:role/YourLambdaExecutionRole \
 59 |   --vpc-config SubnetIds=subnet-xxxxxxxx,subnet-yyyyyyyy,SecurityGroupIds=sg-zzzzzzzz \
 60 |   --memory-size 512 \
 61 |   --timeout 300 \
 62 |   --region your-region
 63 | ```
 64 | 
 65 | Modify the placeholders as appropriate. I recommend keeping the `timeout` and `memory-size` params conservative, as they will affect cost. A good starting point for memory is `512` MB.
 66 |    - Ignore step 8. Once the function is deployed, you can smoke-test it with the invocation sketch shown after this section.
 67 | </details>
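Once the image is deployed as a Lambda function, you can smoke-test it directly before wiring it into the chatbot. The snippet below is an illustrative sketch rather than part of the application: the region, function name, bucket, and prefix are placeholders you must replace, and the payload shape simply mirrors what the handler in `Docker/app.py` expects (`code`, `bucket`, `file_path`), with the submitted code writing its result to `/tmp/output.json`.

```
import json
import boto3

lambda_client = boto3.client("lambda", region_name="your-region")  # placeholder region

# The generated analytics code is expected to write its output to /tmp/output.json,
# which the handler reads back and returns in the response body.
test_code = (
    "import json\n"
    "with open('/tmp/output.json', 'w') as f:\n"
    "    json.dump({'text': 'hello from the sandbox'}, f)\n"
)

payload = {
    "code": test_code,
    "bucket": "your-sandbox-bucket",  # placeholder: sandbox bucket from config.json
    "file_path": "path/to/files",     # placeholder: cache prefix from config.json
    "iterate": 0,
}

response = lambda_client.invoke(
    FunctionName="YourFunctionName",  # placeholder: the function created in step 7
    Payload=json.dumps(payload),
)
print(json.loads(response["Payload"].read()))
```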
 68 | 
 69 | <details>
 70 | <summary>Amazon Athena Spark Runtime Setup</summary>
 71 | 
 72 | ## Create Amazon Athena Spark WorkGroup
 73 | 
 74 | 5. Follow the instructions in [Get started with Apache Spark on Amazon Athena](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-getting-started.html) to create an Amazon Athena workgroup with Apache Spark. You `DO NOT` need to select `Turn on example notebook`.
 75 |    - Provide S3 permissions to the workgroup execution role for the S3 buckets configured with this application.
 76 |    - Note that the Amazon Athena Spark environment comes preinstalled with a select [few Python libraries](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark-preinstalled-python-libraries.html). A sketch of how code is submitted to such a workgroup is shown after this section.
 77 | 
 78 | </details>
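The chat application drives this workgroup through `utils/athena_handler_.py`, which is not shown in this excerpt. As a rough illustration of the underlying service flow (not the app's exact implementation), the sketch below starts a Spark session in the workgroup and submits a small PySpark calculation using the standard boto3 Athena session APIs; the region, workgroup name, and DPU count are placeholder assumptions.

```
import time
import boto3

athena = boto3.client("athena", region_name="your-region")  # placeholder region

# Start a Spark session in the workgroup created above (name is a placeholder).
session = athena.start_session(
    WorkGroup="your-spark-workgroup",
    EngineConfiguration={"MaxConcurrentDpus": 4},
)
session_id = session["SessionId"]

# Wait for the session to become idle (simplified: no handling of FAILED/DEGRADED states).
while athena.get_session_status(SessionId=session_id)["Status"]["State"] != "IDLE":
    time.sleep(5)

# Submit a small PySpark calculation to the session.
calc = athena.start_calculation_execution(
    SessionId=session_id,
    CodeBlock="df = spark.range(5)\nprint(df.count())",
)
calc_id = calc["CalculationExecutionId"]

# Poll until the calculation finishes, then print its state and stdout location in S3.
while True:
    status = athena.get_calculation_execution(CalculationExecutionId=calc_id)
    if status["Status"]["State"] in ("COMPLETED", "FAILED", "CANCELED"):
        break
    time.sleep(5)
print(status["Status"]["State"], status.get("Result", {}).get("StdOutS3Uri"))
```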
 79 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 80 | 
 81 | 
 82 | 
 83 | **⚠ IMPORTANT SECURITY NOTE:**
 84 | 
 85 | Enabling the **Advanced Analytics Feature** allows the LLM to generate Python code to analyze your dataset, which is automatically executed in a Lambda function environment. To mitigate potential risks:
 86 | 
 87 | 1. **VPC Configuration**:
 88 |    - It is recommended to place the Lambda function in an internet-free VPC.
 89 |    - Use Amazon S3 and CloudWatch gateway/interface endpoints for necessary access.
 90 | 
 91 | 2. **IAM Permissions**:
 92 |    - Scope down the AWS Lambda and/or Amazon Athena workgroup execution role to only the required Amazon S3 resources. This is in addition to the `AWSLambdaBasicExecutionRole` policy if using AWS Lambda.
 93 | 
 94 | 3. **Library Restrictions**:
 95 |    - Only libraries specified in `Docker/requirements.txt` will be available at runtime.
 96 |    - Modify this list carefully based on your needs.
 97 | 
 98 | 4. **Resource Allocation**:
 99 |    - Adjust AWS Lambda function `timeout` and `memory-size` based on data size and analysis complexity.
100 | 
101 | 5. **Production Considerations**:
102 |    - This application is designed for POC use.
103 |    - Implement additional security measures before deploying to production.
104 | 
105 | The goal is to limit the potential impact of generated code execution.
106 | 
107 | ## Configuration
108 | The application's behavior can be customized by modifying the `config.json` file. Here are the available options:
109 | 
110 | - `DynamodbTable`: The name of the DynamoDB table to use for storing chat history. Leave this field empty if you decide to use local storage for chat history.
111 | - `UserId`: The DynamoDB user ID for the application. Leave this field empty if you decide to use local storage for chat history.
112 | - `Bucket_Name`: The name of the S3 bucket used for caching documents and extracted text. This is required.
113 | - `max-output-token`: The maximum number of output tokens allowed for the AI assistant's response.
114 | - `chat-history-loaded-length`: The number of recent chat messages to load from the DynamoDB table or local storage.
115 | - `bedrock-region`: The AWS region where Amazon Bedrock is enabled.
116 | - `load-doc-in-chat-history`: A boolean flag indicating whether to load documents in the chat history. If `true`, all documents will be loaded into the chat history as context (this provides the AI with more context from previous chat history, at the cost of additional price and latency). If `false`, only the user query and response will be loaded into the chat history, and the AI will have no recollection of any document context from those chat conversations. When setting booleans in JSON, use all lowercase.
117 | - `AmazonTextract`: A boolean indicating whether to use Amazon Textract or python libraries for PDF and image processing. Set to `false` if you do not have access to Amazon Textract. When setting booleans in JSON, use all lowercase.
118 | - `csv-delimiter`: The delimiter to use when parsing structured content to string. Supported formats are "|", "\t", and ",".
119 | - `document-upload-cache-s3-path`: S3 bucket path to cache uploaded files. Do not include the bucket name, just the prefix without a trailing slash. For example "path/to/files".
120 | - `AmazonTextract-result-cache`: S3 bucket path to cache the Amazon Textract result. Do not include the bucket name, just the prefix without a trailing slash. For example "path/to/files".
121 | - `lambda-function`: Name of the Lambda function deployed in the steps above. This is required if using the `Advanced Analytics Tool` with AWS Lambda.
122 | - `input_s3_path`: S3 directory prefix, without the forward and trailing slash, to render the S3 objects in the Chat UI.
123 | - `input_bucket`: S3 bucket name where the files to be rendered on the screen are stored.
124 | - `input_file_ext`: comma-separated file extension names (without ".") for files in the S3 buckets to be rendered on the screen. By default `xlsx` and `csv` are included.
125 | - `athena-work-group-name`: Amazon Athena Spark workgroup name created above. This is required if using the `Advanced Analytics Tool` with Amazon Athena.
126 | 
127 | **⚠ IMPORTANT ADVISORY FOR ADVANCED ANALYTICS FEATURE**
128 | 
129 | When using the **Advanced Analytics Feature**, take the following precautions:
130 | 1. **Sandbox Environment**:
131 |    - Set `Bucket_Name` and `document-upload-cache-s3-path` to point to a separate, isolated "sandbox" S3 location.
132 |    - Grant read and write access as [documented](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_s3_rw-bucket.html) to this bucket/prefix resource to the Lambda execution role.
133 |    - Do NOT use your main storage path for these parameters. This isolation is crucial to avoid potential file overwrites, as the app will execute LLM-generated code.
134 | 2. **Input Data Safety**:
135 |    - `input_s3_path` and `input_bucket` are used for read-only operations and can safely point to your main data storage. The LLM is not aware of these parameters unless explicitly provided by the user during chat.
136 |    - Only grant read access to this bucket/prefix resource in the execution role attached to the Lambda function.
137 |    - **IMPORTANT**: Ensure `input_bucket` is different from `Bucket_Name`.
138 | 
139 | By following these guidelines, you mitigate the potential risk of unintended data modification or loss in your primary storage areas.
140 | 
141 | ## To run this Streamlit App on SageMaker Studio follow the steps below:
142 | 
143 | 
144 | 
145 | 
146 | ![Demo](https://github.com/aws-samples/bedrock-claude-chatbot/raw/main/images/demo.mp4)
147 | 
148 | If you have a SageMaker AI Studio Domain already set up, ignore the first item; however, item 2 is required.
149 | * [Set Up SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html)
150 | * The SageMaker execution role should have access to interact with [Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/api-setup.html), [S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-policy-language-overview.html), and optionally [Textract](https://docs.aws.amazon.com/aws-managed-policy/latest/reference/AmazonTextractFullAccess.html), [DynamoDB](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/iam-policy-specific-table-indexes.html), [AWS Lambda](https://docs.aws.amazon.com/lambda/latest/dg/access-control-identity-based.html) and [Amazon Athena](https://docs.aws.amazon.com/athena/latest/ug/managed-policies.html) if these services are used.
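Before launching the app (on SageMaker Studio or EC2, as described below), populate `config.json` with your own resource names. The snippet below is only an illustrative sketch: every value is a placeholder, and the key names should be cross-checked against the `config.json` shipped in this repository (for instance, the `bedrock-chat.py` excerpt later in this dump reads a `region` key for the Bedrock region).

```
{
  "DynamodbTable": "BedrockChatSessions",
  "UserId": "demo-user",
  "Bucket_Name": "your-sandbox-bucket",
  "max-output-token": 2000,
  "chat-history-loaded-length": 10,
  "bedrock-region": "us-east-1",
  "load-doc-in-chat-history": true,
  "AmazonTextract": true,
  "csv-delimiter": "|",
  "document-upload-cache-s3-path": "chatbot/uploads",
  "AmazonTextract-result-cache": "chatbot/textract",
  "lambda-function": "YourFunctionName",
  "input_s3_path": "path/to/files",
  "input_bucket": "your-input-bucket",
  "input_file_ext": "csv,xlsx",
  "athena-work-group-name": "your-spark-workgroup"
}
```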
151 | 
152 | ### On SageMaker AI Studio JupyterLab:
153 | * [Create a JupyterLab space](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-jl.html)
154 | * 
155 | * Open a terminal by clicking **File** -> **New** -> **Terminal**
156 | * Navigate into the cloned repository directory using the `cd bedrock-claude-chatbot` command and run the following commands to install the application python libraries:
157 |   - sudo apt update
158 |   - sudo apt upgrade -y
159 |   - chmod +x install_package.sh
160 |   - ./install_package.sh
161 |   - **NOTE**: If you run into this error `ERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: /opt/conda/lib/python3.10/site-packages/fsspec-2023.6.0.dist-info/METADATA`, I solved it by deleting the `fsspec` package by running the following command (this is due to having two versions of `fsspec` installed, 2023* and 2024*):
162 |     - `rm /opt/conda/lib/python3.10/site-packages/fsspec-2023.6.0.dist-info -rdf`
163 |     - pip install -U fsspec # fsspec 2024.9.0 should already be installed.
164 | * If you decide to use Python Libs for PDF and image processing, this requires tesseract-ocr. Run the following commands:
165 |   - sudo apt update -y
166 |   - sudo apt-get install tesseract-ocr-all -y
167 | * Run command `python3 -m streamlit run bedrock-chat.py --server.enableXsrfProtection false` to start the Streamlit server. Do not use the links generated by the command as they won't work in Studio.
168 | * Copy the URL of the SageMaker JupyterLab. It should look something like this https://qukigdtczjsdk.studio.us-east-1.sagemaker.aws/jupyterlab/default/lab/tree/healthlake/app_fhir.py. Replace everything after .../default/ with proxy/8501/, something like https://qukigdtczjsdk.studio.us-east-1.sagemaker.aws/jupyterlab/default/proxy/8501/. Make sure the port number (8501 in this case) matches the port number printed out when you run the `python3 -m streamlit run bedrock-chat.py --server.enableXsrfProtection false` command; the port number is the last 4 digits after the colon in the generated URL.
169 | 
170 | 
171 | ## To run this Streamlit App on AWS EC2 (I tested this on the Ubuntu Image)
172 | * [Create a new ec2 instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EC2_GetStarted.html)
173 | * Expose TCP port range 8500-8510 on Inbound connections of the attached Security group to the ec2 instance. TCP port 8501 is needed for Streamlit to work. See image below
174 | * 
175 | * The EC2 [instance profile role](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2_instance-profiles.html) must have the required permissions to access the services used by this application mentioned above.
176 | * [Connect to your ec2 instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstances.html)
177 | * Run the appropriate commands to update the ec2 instance (`sudo apt update` and `sudo apt upgrade` for Ubuntu)
178 | * Clone this git repo `git clone [github_link]` and `cd bedrock-claude-chatbot`
179 | * Install python3 and pip if not already installed, `sudo apt install python3` and `sudo apt install python3-pip`.
180 | * If you decide to use Python Libs for PDF and image processing, this requires tesseract-ocr. Run the following commands:
181 |   - If using Centos-OS or Amazon-Linux:
182 |     - sudo rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
183 |     - sudo yum -y update
184 |     - sudo yum install -y tesseract
185 |   - For Ubuntu or Debian:
186 |     - sudo apt-get install tesseract-ocr-all -y
187 | * Install the dependencies by running the following commands (use `yum` for Centos-OS or Amazon-Linux):
188 |   - sudo apt update
189 |   - sudo apt upgrade -y
190 |   - chmod +x install_package.sh
191 |   - ./install_package.sh
192 | * Run the command `tmux new -s mysession` to create a new session. Then, in the new session, `cd bedrock-claude-chatbot` into the **ChatBot** dir and run `python3 -m streamlit run bedrock-chat.py` to start the Streamlit app. This allows you to run the Streamlit application in the background and keep it running even if you disconnect from the terminal session.
193 | * Copy the **External URL** link generated and paste it in a new browser tab.
194 | * **⚠ NOTE:** The generated link is not secure! See [additional guidance](https://github.com/aws-samples/deploy-streamlit-app).
195 | To detach from the `tmux` session, in your ec2 terminal press `Ctrl+b`, then `d`. To kill the session, run `tmux kill-session -t mysession`.
196 | 
197 | ## Limitations and Future Updates
198 | 1. **Pricing**: Pricing is only calculated for the Bedrock models, not including the cost of any other AWS service used. In addition, the pricing information for the models is stored in a static `pricing.json` file. Manually update this file to reflect current [Bedrock pricing details](https://aws.amazon.com/bedrock/pricing/). Treat the cost implementation in this app as a rough estimate of the actual cost of interacting with the Bedrock models, as the actual cost reported in your account may differ.
199 | 
200 | 2. **Storage Encryption**: This application does not implement storing and reading files to and from S3 and/or DynamoDB using KMS keys for data-at-rest encryption.
201 | 
202 | 3. **Production-Ready**: For an enterprise and production-ready chatbot application architecture pattern, check out [Generative AI Application Builder on AWS](https://aws.amazon.com/solutions/implementations/generative-ai-application-builder-on-aws/) and [Bedrock-Claude-Chat](https://github.com/aws-samples/bedrock-claude-chat) for best practices and recommendations.
203 | 
204 | 4. **Tools Suite**: This application only includes a single tool. However, with the many niche applications of LLMs, a library of tools would make this application more robust.
205 | 
206 | ## Application Workflow Diagram
207 | 
208 | 
209 | 
210 | **Guidelines**
211 | - When a document is uploaded (and for as long as it stays uploaded), its content is attached to the user's query, and the chatbot's responses are grounded in the document (a separate prompt template is used). That chat conversation is tagged with the document name as metadata to be used in the chat history.
212 | - If the document is detached, the chat history will only contain the user's queries and the chatbot's responses, unless the `load-doc-in-chat-history` configuration parameter is enabled, in which case the document content will be retained in the chat history.
213 | - You can refer to documents by their names or format (PDF, WORD, IMAGE, etc.) when having a conversation with the AI.
214 | - The `chat-history-loaded-length` setting determines how many previous conversations the LLM will be aware of, including any attached documents (if the `load-doc-in-chat-history` option is enabled). A higher value for this setting means that the LLM will have access to more historical context, but it may also increase cost and potentially introduce latency, as more tokens will be passed to the LLM. For optimal performance and cost-effectiveness, it is recommended to set `chat-history-loaded-length` to a value between 5 and 10. This range strikes a balance between providing the LLM with sufficient historical context and minimizing the input payload size and associated costs.
215 | - ⚠️ When using the Streamlit app, any uploaded document will be persisted for the current chat conversation. This means that subsequent questions, as well as chat histories (if the `load-doc-in-chat-history` option is enabled), will have the document(s) as context, and the responses will be grounded in those document(s). However, this can increase cost and latency, as the input payload will be larger due to the loaded document(s) in every chat turn. Therefore, if you have the `load-doc-in-chat-history` option enabled, after the response to your first question about the uploaded document(s), it is recommended to remove the document(s) by clicking the **X** sign next to the uploaded file(s). The document(s) will be saved in the chat history, and you can ask follow-up questions about them, as the LLM will have knowledge of the document(s) from the chat history. On the other hand, if the `load-doc-in-chat-history` option is disabled and you want to keep asking follow-up questions about the document(s), leave the document(s) uploaded until you are done. This way, only the current chat turn will have the document(s) loaded, and not the entire chat history. The choice of whether to enable `load-doc-in-chat-history` depends on cost and latency. I would recommend enabling it for a smoother experience, following the aforementioned guidelines.
216 | -------------------------------------------------------------------------------- /bedrock-chat.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import boto3 3 | from botocore.config import Config 4 | import os 5 | import pandas as pd 6 | import time 7 | import json 8 | import io 9 | import re 10 | import openpyxl 11 | from python_calamine import CalamineWorkbook 12 | from openpyxl.cell import Cell 13 | import plotly.io as pio 14 | from openpyxl.worksheet.cell_range import CellRange 15 | from docx.table import _Cell 16 | from boto3.dynamodb.conditions import Key 17 | from pptx import Presentation 18 | from botocore.exceptions import ClientError 19 | from textractor import Textractor 20 | from textractor.data.constants import TextractFeatures 21 | from textractor.data.text_linearization_config import TextLinearizationConfig 22 | import pytesseract 23 | from PIL import Image 24 | import PyPDF2 25 | import chardet 26 | from docx import Document as DocxDocument 27 | from docx.oxml.text.paragraph import CT_P 28 | from docx.oxml.table import CT_Tbl 29 | from docx.document import Document 30 | from docx.text.paragraph import Paragraph 31 | from docx.table import Table as DocxTable 32 | import concurrent.futures 33 | from functools import partial 34 | import random 35 | from utils import function_calling_utils 36 | from urllib.parse import urlparse 37 | import plotly.graph_objects as go 38 | import numpy as np 39 | import base64 40 | config = Config( 41 | read_timeout=600, # Read timeout parameter 42 | retries=dict( 43 | max_attempts=10 # Handle retries 44 | ) 45 | ) 46 | 47 | st.set_page_config(initial_sidebar_state="auto") 48 | 49 | # Read app configurations 50 | with open('config.json','r',encoding='utf-8') as f: 51 | config_file = json.load(f) 52 | # pricing info 53 | with open('pricing.json','r',encoding='utf-8') as f: 54 | pricing_file = json.load(f) 55 | # Bedrock Model info 56 | with open('model_id.json','r',encoding='utf-8') as f: 57 | model_info = json.load(f) 58 | 59 | DYNAMODB = boto3.resource('dynamodb') 60 | COGNITO = boto3.client('cognito-idp') 61 | S3 = boto3.client('s3') 62 | LOCAL_CHAT_FILE_NAME = "chat-history.json" 63 | DYNAMODB_TABLE = config_file["DynamodbTable"] 64 | BUCKET = config_file["Bucket_Name"] 65 | OUTPUT_TOKEN = config_file["max-output-token"] 66 | S3_DOC_CACHE_PATH = config_file["document-upload-cache-s3-path"] 67 | TEXTRACT_RESULT_CACHE_PATH = config_file["AmazonTextract-result-cache"] 68 | LOAD_DOC_IN_ALL_CHAT_CONVO = config_file["load-doc-in-chat-history"] 69 | CHAT_HISTORY_LENGTH = config_file["chat-history-loaded-length"] 70 | DYNAMODB_USER = config_file["UserId"] 71 | REGION = config_file["region"] 72 | USE_TEXTRACT = config_file["AmazonTextract"] 73 | CSV_SEPERATOR = config_file["csv-delimiter"] 74 | INPUT_BUCKET = config_file["input_bucket"] 75 | INPUT_S3_PATH = config_file["input_s3_path"] 76 | INPUT_EXT = tuple(f".{x}" for x in config_file["input_file_ext"].split(',')) 77 | MODEL_DISPLAY_NAME = list(model_info.keys()) 78 | HYBRID_MODELS = ["sonnet-3.7", "sonnet-4", "opus-4"] # populate with list of hybrid reasoning models on Bedrock 79 | NON_VISION_MODELS = ["deepseek", "haiku-3.5", "nova-micro"] # populate with list of models not supporting image input on Bedrock 80 | NON_TOOL_SUPPORTING_MODELS = ["deepseek", "meta-scout", "meta-maverick"] # populate with list of models not supporting tool calling on Bedrock 81 | 82 | bedrock_runtime = boto3.client(service_name='bedrock-runtime', 
region_name=REGION, config=config) 83 | 84 | if 'messages' not in st.session_state: 85 | st.session_state['messages'] = [] 86 | if 'input_token' not in st.session_state: 87 | st.session_state['input_token'] = 0 88 | if 'output_token' not in st.session_state: 89 | st.session_state['output_token'] = 0 90 | if 'chat_hist' not in st.session_state: 91 | st.session_state['chat_hist'] = [] 92 | if 'user_sess' not in st.session_state: 93 | st.session_state['user_sess'] =str(time.time()) 94 | if 'chat_session_list' not in st.session_state: 95 | st.session_state['chat_session_list'] = [] 96 | if 'count' not in st.session_state: 97 | st.session_state['count'] = 0 98 | if 'userid' not in st.session_state: 99 | st.session_state['userid']= config_file["UserId"] 100 | if 'cost' not in st.session_state: 101 | st.session_state['cost'] = 0 102 | if 'reasoning_mode' not in st.session_state: 103 | st.session_state['reasoning_mode'] = False # Only activated when user selects anthropic 3.7 and toggles on thinking 104 | if 'athena-session' not in st.session_state: 105 | st.session_state['athena-session'] = "" 106 | 107 | def get_object_with_retry(bucket, key): 108 | max_retries = 5 109 | retries = 0 110 | backoff_base = 2 111 | max_backoff = 3 # Maximum backoff time in seconds 112 | s3 = boto3.client('s3') 113 | while retries < max_retries: 114 | try: 115 | response = s3.get_object(Bucket=bucket, Key=key) 116 | return response 117 | except ClientError as e: 118 | error_code = e.response['Error']['Code'] 119 | if error_code == 'DecryptionFailureException': 120 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 121 | print(f"Decryption failed, retrying in {sleep_time} seconds...") 122 | time.sleep(sleep_time) 123 | retries += 1 124 | elif e.response['Error']['Code'] == 'ModelStreamErrorException': 125 | if retries < max_retries: 126 | # Throttling, exponential backoff 127 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 128 | time.sleep(sleep_time) 129 | retries += 1 130 | else: 131 | raise e 132 | 133 | # If we reach this point, it means the maximum number of retries has been exceeded 134 | raise Exception(f"Failed to get object {key} from bucket {bucket} after {max_retries} retries.") 135 | 136 | def decode_numpy_array(obj): 137 | if isinstance(obj, dict) and 'dtype' in obj and 'bdata' in obj: 138 | dtype = np.dtype(obj['dtype']) 139 | return np.frombuffer(base64.b64decode(obj['bdata']), dtype=dtype) 140 | return obj 141 | 142 | # Decode the numpy arrays in the JSON data 143 | def decode_json(obj): 144 | if isinstance(obj, dict): 145 | return {k: decode_json(decode_numpy_array(v)) for k, v in obj.items()} 146 | elif isinstance(obj, list): 147 | return [decode_json(item) for item in obj] 148 | return obj 149 | 150 | def save_chat_local(file_path, new_data, session_id): 151 | """Store long term chat history Local Disk""" 152 | try: 153 | # Read the existing JSON data from the file 154 | with open(file_path, "r",encoding='utf-8') as file: 155 | existing_data = json.load(file) 156 | if session_id not in existing_data: 157 | existing_data[session_id]=[] 158 | except FileNotFoundError: 159 | # If the file doesn't exist, initialize an empty list 160 | existing_data = {session_id:[]} 161 | # Append the new data to the existing list 162 | from decimal import Decimal 163 | data = [{k: float(v) if isinstance(v, Decimal) else v for k, v in item.items()} for item in new_data] 164 | existing_data[session_id].extend(data) 165 | 166 | # Write the updated list back to 
the JSON file 167 | with open(file_path, "w",encoding="utf-8") as file: 168 | json.dump(existing_data, file) 169 | 170 | def load_chat_local(file_path,session_id): 171 | """Load long term chat history from Local""" 172 | try: 173 | # Read the existing JSON data from the file 174 | with open(file_path, "r",encoding='utf-8') as file: 175 | existing_data = json.load(file) 176 | if session_id in existing_data: 177 | existing_data=existing_data[session_id] 178 | else: 179 | existing_data=[] 180 | except FileNotFoundError: 181 | # If the file doesn't exist, initialize an empty list 182 | existing_data = [] 183 | return existing_data 184 | 185 | 186 | def process_files(files): 187 | """process uploaded files in parallel""" 188 | result_string="" 189 | errors = [] 190 | future_proxy_mapping = {} 191 | futures = [] 192 | 193 | with concurrent.futures.ProcessPoolExecutor() as executor: 194 | # Partial function to pass the handle_doc_upload_or_s3 function 195 | func = partial(handle_doc_upload_or_s3) 196 | for file in files: 197 | future = executor.submit(func, file) 198 | future_proxy_mapping[future] = file 199 | futures.append(future) 200 | 201 | # Collect the results and handle exceptions 202 | for future in concurrent.futures.as_completed(futures): 203 | file_url= future_proxy_mapping[future] 204 | try: 205 | result = future.result() 206 | doc_name=os.path.basename(file_url) 207 | result_string+=f"<{doc_name}>\n{result}\n\n" # tag documnets with names to enhance prompts 208 | except Exception as e: 209 | # Get the original function arguments from the Future object 210 | error = {'file': file_url, 'error': str(e)} 211 | errors.append(error) 212 | 213 | return errors, result_string 214 | 215 | def handle_doc_upload_or_s3(file, cutoff=None): 216 | """Handle various document format""" 217 | dir_name, ext = os.path.splitext(file) 218 | if ext.lower() in [".pdf", ".png", ".jpg",".tif",".jpeg"]: 219 | content=exract_pdf_text_aws(file) 220 | elif ".csv" == ext.lower(): 221 | content=parse_csv_from_s3(file,cutoff) 222 | elif ext.lower() in [".xlsx", ".xls"]: 223 | content=table_parser_utills(file,cutoff) 224 | elif ".json"==ext.lower(): 225 | obj=get_s3_obj_from_bucket_(file) 226 | content = json.loads(obj['Body'].read()) 227 | elif ext.lower() in [".txt",".py"]: 228 | obj=get_s3_obj_from_bucket_(file) 229 | content = obj['Body'].read() 230 | elif ".docx" == ext.lower(): 231 | obj=get_s3_obj_from_bucket_(file) 232 | content = obj['Body'].read() 233 | docx_buffer = io.BytesIO(content) 234 | content = extract_text_and_tables(docx_buffer) 235 | elif ".pptx" == ext.lower(): 236 | obj=get_s3_obj_from_bucket_(file) 237 | content = obj['Body'].read() 238 | docx_buffer = io.BytesIO(content) 239 | content = extract_text_from_pptx_s3(docx_buffer) 240 | # Implement any other file extension logic 241 | return content 242 | 243 | class InvalidContentError(Exception): 244 | pass 245 | 246 | def detect_encoding(s3_uri): 247 | """detect csv encoding""" 248 | s3 = boto3.client('s3') 249 | match = re.match("s3://(.+?)/(.+)", s3_uri) 250 | if match: 251 | bucket_name = match.group(1) 252 | key = match.group(2) 253 | response = s3.get_object(Bucket=bucket_name, Key=key) 254 | content = response['Body'].read() 255 | result = chardet.detect(content) 256 | return result['encoding'] 257 | 258 | def parse_csv_from_s3(s3_uri, cutoff): 259 | """read csv files""" 260 | try: 261 | # Detect the file encoding using chardet 262 | encoding = detect_encoding(s3_uri) 263 | # Sniff the delimiter and read the CSV file 264 | df = 
pd.read_csv(s3_uri, delimiter=None, engine='python', encoding=encoding) 265 | if cutoff: 266 | df=df.iloc[:20] 267 | return df.to_csv(index=False, sep=CSV_SEPERATOR) 268 | except Exception as e: 269 | raise InvalidContentError(f"Error: {e}") 270 | 271 | def iter_block_items(parent): 272 | if isinstance(parent, Document): 273 | parent_elm = parent.element.body 274 | elif isinstance(parent, _Cell): 275 | parent_elm = parent._tc 276 | else: 277 | raise ValueError("something's not right") 278 | 279 | for child in parent_elm.iterchildren(): 280 | if isinstance(child, CT_P): 281 | yield Paragraph(child, parent) 282 | elif isinstance(child, CT_Tbl): 283 | yield DocxTable(child, parent) 284 | 285 | def extract_text_and_tables(docx_path): 286 | """ Extract text from docx files""" 287 | document = DocxDocument(docx_path) 288 | content = "" 289 | current_section = "" 290 | section_type = None 291 | for block in iter_block_items(document): 292 | if isinstance(block, Paragraph): 293 | if block.text: 294 | if block.style.name == 'Heading 1': 295 | # Close the current section if it exists 296 | if current_section: 297 | content += f"{current_section}\n" 298 | current_section = "" 299 | section_type = None 300 | section_type ="h1" 301 | content += f"<{section_type}>{block.text}\n" 302 | elif block.style.name== 'Heading 3': 303 | # Close the current section if it exists 304 | if current_section: 305 | content += f"{current_section}\n" 306 | current_section = "" 307 | section_type = "h3" 308 | content += f"<{section_type}>{block.text}\n" 309 | elif block.style.name == 'List Paragraph': 310 | # Add to the current list section 311 | if section_type != "list": 312 | # Close the current section if it exists 313 | if current_section: 314 | content += f"{current_section}\n" 315 | section_type = "list" 316 | current_section = "" 317 | current_section += f"{block.text}\n" 318 | elif block.style.name.startswith('toc'): 319 | # Add to the current toc section 320 | if section_type != "toc": 321 | # Close the current section if it exists 322 | if current_section: 323 | content += f"{current_section}\n" 324 | section_type = "toc" 325 | current_section = "" 326 | current_section += f"{block.text}\n" 327 | else: 328 | # Close the current section if it exists 329 | if current_section: 330 | content += f"{current_section}\n" 331 | current_section = "" 332 | section_type = None 333 | 334 | # Append the passage text without tagging 335 | content += f"{block.text}\n" 336 | 337 | elif isinstance(block, DocxTable): 338 | # Add the current section before the table 339 | if current_section: 340 | content += f"{current_section}\n" 341 | current_section = "" 342 | section_type = None 343 | 344 | content += "\n" 345 | for row in block.rows: 346 | row_content = [] 347 | for cell in row.cells: 348 | cell_content = [] 349 | for nested_block in iter_block_items(cell): 350 | if isinstance(nested_block, Paragraph): 351 | cell_content.append(nested_block.text) 352 | elif isinstance(nested_block, DocxTable): 353 | nested_table_content = parse_nested_table(nested_block) 354 | cell_content.append(nested_table_content) 355 | row_content.append(CSV_SEPERATOR.join(cell_content)) 356 | content += CSV_SEPERATOR.join(row_content) + "\n" 357 | content += "
\n" 358 | 359 | # Add the final section 360 | if current_section: 361 | content += f"{current_section}\n" 362 | return content 363 | 364 | def parse_nested_table(table): 365 | nested_table_content = "\n" 366 | for row in table.rows: 367 | row_content = [] 368 | for cell in row.cells: 369 | cell_content = [] 370 | for nested_block in iter_block_items(cell): 371 | if isinstance(nested_block, Paragraph): 372 | cell_content.append(nested_block.text) 373 | elif isinstance(nested_block, DocxTable): 374 | nested_table_content += parse_nested_table(nested_block) 375 | row_content.append(CSV_SEPERATOR.join(cell_content)) 376 | nested_table_content += CSV_SEPERATOR.join(row_content) + "\n" 377 | nested_table_content += "
" 378 | return nested_table_content 379 | 380 | 381 | 382 | def extract_text_from_pptx_s3(pptx_buffer): 383 | """ Extract Text from pptx files""" 384 | presentation = Presentation(pptx_buffer) 385 | text_content = [] 386 | for slide in presentation.slides: 387 | slide_text = [] 388 | for shape in slide.shapes: 389 | if hasattr(shape, 'text'): 390 | slide_text.append(shape.text) 391 | text_content.append('\n'.join(slide_text)) 392 | return '\n\n'.join(text_content) 393 | 394 | def exract_pdf_text_aws(file): 395 | """extract text from PDFs using Amazon Textract or PyPDF2""" 396 | file_base_name = os.path.basename(file) 397 | dir_name, ext = os.path.splitext(file) 398 | # Checking if extracted doc content is in S3 399 | if USE_TEXTRACT: 400 | if [x for x in get_s3_keys(f"{TEXTRACT_RESULT_CACHE_PATH}/") if file_base_name in x]: 401 | response = get_object_with_retry(BUCKET, f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt") 402 | text = response['Body'].read().decode() 403 | return text 404 | else: 405 | 406 | extractor = Textractor() 407 | # Asynchronous call, you will experience some wait time. Try caching results for better experience 408 | if "pdf" in ext: 409 | print("Asynchronous call, you may experience some wait time.") 410 | document = extractor.start_document_analysis( 411 | file_source=file, 412 | features=[TextractFeatures.LAYOUT, TextractFeatures.TABLES], 413 | save_image=False, 414 | s3_output_path=f"s3://{BUCKET}/textract_output/" 415 | ) 416 | # Synchronous call 417 | else: 418 | document = extractor.analyze_document( 419 | file_source=file, 420 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES], 421 | save_image=False, 422 | ) 423 | config = TextLinearizationConfig( 424 | hide_figure_layout=False, 425 | hide_header_layout=False, 426 | table_prefix="", 427 | table_suffix="
", 428 | ) 429 | # Upload extracted content to s3 430 | S3.put_object(Body=document.get_text(config=config), Bucket=BUCKET, Key=f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt") 431 | return document.get_text(config=config) 432 | else: 433 | s3 = boto3.resource("s3") 434 | match = re.match("s3://(.+?)/(.+)", file) 435 | if match: 436 | bucket_name = match.group(1) 437 | key = match.group(2) 438 | if "pdf" in ext: 439 | pdf_bytes = io.BytesIO() 440 | s3.Bucket(bucket_name).download_fileobj(key, pdf_bytes) 441 | # Read the PDF from the BytesIO object 442 | pdf_bytes.seek(0) 443 | # Create a PDF reader object 444 | pdf_reader = PyPDF2.PdfReader(pdf_bytes) 445 | # Get the number of pages in the PDF 446 | num_pages = len(pdf_reader.pages) 447 | # Extract text from each page 448 | text = '' 449 | for page_num in range(num_pages): 450 | page = pdf_reader.pages[page_num] 451 | text += page.extract_text() 452 | else: 453 | img_bytes = io.BytesIO() 454 | s3.Bucket(bucket_name).download_fileobj(key, img_bytes) 455 | img_bytes.seek(0) 456 | image_stream = io.BytesIO(img_bytes) 457 | image = Image.open(image_stream) 458 | text = pytesseract.image_to_string(image) 459 | return text 460 | 461 | def strip_newline(cell): 462 | return str(cell).strip() 463 | 464 | def table_parser_openpyxl(file, cutoff): 465 | """convert xlsx files to python pandas handling merged cells""" 466 | # Read from S3 467 | s3 = boto3.client('s3') 468 | match = re.match("s3://(.+?)/(.+)", file) 469 | if match: 470 | bucket_name = match.group(1) 471 | key = match.group(2) 472 | obj = s3.get_object(Bucket=bucket_name, Key=key) 473 | # Read Excel file from S3 into a buffer 474 | xlsx_buffer = io.BytesIO(obj['Body'].read()) 475 | xlsx_buffer.seek(0) 476 | # Load workbook 477 | wb = openpyxl.load_workbook(xlsx_buffer) 478 | all_sheets_string = "" 479 | # Iterate over each sheet in the workbook 480 | for sheet_name in wb.sheetnames: 481 | # all_sheets_name.append(sheet_name) 482 | worksheet = wb[sheet_name] 483 | 484 | all_merged_cell_ranges: list[CellRange] = list( 485 | worksheet.merged_cells.ranges 486 | ) 487 | for merged_cell_range in all_merged_cell_ranges: 488 | merged_cell: Cell = merged_cell_range.start_cell 489 | worksheet.unmerge_cells(range_string=merged_cell_range.coord) 490 | for row_index, col_index in merged_cell_range.cells: 491 | cell: Cell = worksheet.cell(row=row_index, column=col_index) 492 | cell.value = merged_cell.value 493 | # Convert sheet data to a DataFrame 494 | df = pd.DataFrame(worksheet.values) 495 | df = df.map(strip_newline) 496 | if cutoff: 497 | df = df.iloc[:20] 498 | 499 | # Convert to string and tag by sheet name 500 | tabb=df.to_csv(sep=CSV_SEPERATOR, index=False, header=0) 501 | all_sheets_string+=f'<{sheet_name}>\n{tabb}\n\n' 502 | return all_sheets_string 503 | else: 504 | raise Exception(f"{file} not formatted as an S3 path") 505 | 506 | def calamaine_excel_engine(file,cutoff): 507 | # # Read from S3 508 | s3 = boto3.client('s3') 509 | match = re.match("s3://(.+?)/(.+)", file) 510 | if match: 511 | bucket_name = match.group(1) 512 | key = match.group(2) 513 | obj = s3.get_object(Bucket=bucket_name, Key=key) 514 | # Read Excel file from S3 into a buffer 515 | xlsx_buffer = io.BytesIO(obj['Body'].read()) 516 | xlsx_buffer.seek(0) 517 | all_sheets_string = "" 518 | # Load the Excel file 519 | workbook = CalamineWorkbook.from_filelike(xlsx_buffer) 520 | # Iterate over each sheet in the workbook 521 | for sheet_name in workbook.sheet_names: 522 | # Get the sheet by name 523 | sheet = 
workbook.get_sheet_by_name(sheet_name) 524 | df = pd.DataFrame(sheet.to_python(skip_empty_area=False)) 525 | df = df.map(strip_newline) 526 | if cutoff: 527 | df = df.iloc[:20] 528 | # print(df) 529 | tabb = df.to_csv(sep=CSV_SEPERATOR, index=False, header=0) 530 | all_sheets_string += f'<{sheet_name}>\n{tabb}\n\n' 531 | return all_sheets_string 532 | else: 533 | raise Exception(f"{file} not formatted as an S3 path") 534 | 535 | def table_parser_utills(file,cutoff): 536 | try: 537 | response = table_parser_openpyxl(file, cutoff) 538 | if response: 539 | return response 540 | else: 541 | return calamaine_excel_engine(file, cutoff) 542 | except Exception as e: 543 | try: 544 | return calamaine_excel_engine(file, cutoff) 545 | except Exception as e: 546 | raise Exception(str(e)) 547 | 548 | 549 | def put_db(params,messages): 550 | """Store long term chat history in DynamoDB""" 551 | chat_item = { 552 | "UserId": st.session_state['userid'], # user id 553 | "SessionId": params["session_id"], # User session id 554 | "messages": [messages], # 'messages' is a list of dictionaries 555 | "time":messages['time'] 556 | } 557 | 558 | existing_item = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId":params["session_id"]}) 559 | if "Item" in existing_item: 560 | existing_messages = existing_item["Item"]["messages"] 561 | chat_item["messages"] = existing_messages + [messages] 562 | response = DYNAMODB.Table(DYNAMODB_TABLE).put_item( 563 | Item=chat_item 564 | ) 565 | 566 | 567 | def get_chat_history_db(params,cutoff,vision_model): 568 | """process chat histories from local or DynamoDb to format expected by model input""" 569 | current_chat, chat_hist = [], [] 570 | if params['chat_histories'] and cutoff != 0: 571 | chat_hist = params['chat_histories'][-cutoff:] 572 | for d in chat_hist: 573 | if d['image'] and vision_model and LOAD_DOC_IN_ALL_CHAT_CONVO: 574 | content = [] 575 | for img in d['image']: 576 | s3 = boto3.client('s3') 577 | match = re.match("s3://(.+?)/(.+)", img) 578 | image_name = os.path.basename(img) 579 | _, ext = os.path.splitext(image_name) 580 | if "jpg" in ext: 581 | ext = ".jpeg" 582 | # if match: 583 | bucket_name = match.group(1) 584 | key = match.group(2) 585 | obj = s3.get_object(Bucket=bucket_name, Key=key) 586 | bytes_image = obj['Body'].read() 587 | content.extend([{"text": image_name}, { 588 | "image": { 589 | "format": f"{ext.lower().replace('.','')}", 590 | "source": {"bytes": bytes_image} 591 | } 592 | }]) 593 | content.extend([{"text":d['user']}]) 594 | current_chat.append({'role': 'user', 'content': content}) 595 | elif d['document'] and LOAD_DOC_IN_ALL_CHAT_CONVO: 596 | # Handle scenario where tool is used for dataset that is out of context for the model context length 597 | if 'tool_use_id' in d and d['tool_use_id']: 598 | doc = 'Here are the documents:\n' 599 | for docs in d['document']: 600 | uploads = handle_doc_upload_or_s3(docs,20) 601 | doc_name = os.path.basename(docs) 602 | doc += f"<{doc_name}>\n{uploads}\n\n" 603 | else: 604 | doc = 'Here are the documents:\n' 605 | for docs in d['document']: 606 | uploads = handle_doc_upload_or_s3(docs) 607 | doc_name = os.path.basename(docs) 608 | doc += f"<{doc_name}>\n{uploads}\n\n" 609 | if not vision_model and d["image"]: 610 | for docs in d['image']: 611 | uploads = handle_doc_upload_or_s3(docs) 612 | doc_name = os.path.basename(docs) 613 | doc += f"<{doc_name}>\n{uploads}\n\n" 614 | current_chat.append({'role': 'user', 'content': [{"text": doc+d['user']}]}) 615 | else: 616 
| current_chat.append({'role': 'user', 'content': [{"text": d['user']}]}) 617 | current_chat.append({'role': 'assistant', 'content': [{"text": d['assistant']}]}) 618 | else: 619 | chat_hist = [] 620 | return current_chat, chat_hist 621 | 622 | 623 | def get_s3_keys(prefix): 624 | """list all keys in an s3 path""" 625 | s3 = boto3.client('s3') 626 | keys = [] 627 | next_token = None 628 | while True: 629 | if next_token: 630 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix, ContinuationToken=next_token) 631 | else: 632 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix) 633 | if "Contents" in response: 634 | for obj in response['Contents']: 635 | key = obj['Key'] 636 | name = key[len(prefix):] 637 | keys.append(name) 638 | if "NextContinuationToken" in response: 639 | next_token = response["NextContinuationToken"] 640 | else: 641 | break 642 | return keys 643 | 644 | def parse_s3_uri(uri): 645 | """ 646 | Parse an S3 URI and extract the bucket name and key. 647 | 648 | :param uri: S3 URI (e.g., 's3://bucket-name/path/to/file.txt') 649 | :return: Tuple of (bucket_name, key) if valid, (None, None) if invalid 650 | """ 651 | pattern = r'^s3://([^/]+)/(.*)$' 652 | match = re.match(pattern, uri) 653 | if match: 654 | return match.groups() 655 | return (None, None) 656 | 657 | def copy_s3_object(source_uri, dest_bucket, dest_key): 658 | """ 659 | Copy an object from one S3 location to another. 660 | 661 | :param source_uri: S3 URI of the source object 662 | :param dest_bucket: Name of the destination bucket 663 | :param dest_key: Key to be used for the destination object 664 | :return: True if successful, False otherwise 665 | """ 666 | s3 = boto3.client('s3') 667 | 668 | # Parse the source URI 669 | source_bucket, source_key = parse_s3_uri(source_uri) 670 | if not source_bucket or not source_key: 671 | print(f"Invalid source URI: {source_uri}") 672 | return False 673 | 674 | try: 675 | # Create a copy source dictionary 676 | copy_source = { 677 | 'Bucket': source_bucket, 678 | 'Key': source_key 679 | } 680 | 681 | # Copy the object 682 | s3.copy_object(CopySource=copy_source, Bucket=dest_bucket, Key=f"{dest_key}/{source_key}") 683 | 684 | print(f"File copied from {source_uri} to s3://{dest_bucket}/{dest_key}/{source_key}") 685 | return f"s3://{dest_bucket}/{dest_key}/{source_key}" 686 | 687 | except ClientError as e: 688 | print(f"An error occurred: {e}") 689 | raise(e) 690 | # return False 691 | 692 | def plotly_to_png_bytes(s3_uri): 693 | """ 694 | Read a .plotly file from S3 given an S3 URI, convert it to a PNG image, and return the image as bytes. 
695 | 696 | :param s3_uri: S3 URI of the .plotly file (e.g., 's3://bucket-name/path/to/file.plotly') 697 | :return: PNG image as bytes 698 | """ 699 | # Parse S3 URI 700 | parsed_uri = urlparse(s3_uri) 701 | bucket_name = parsed_uri.netloc 702 | file_key = parsed_uri.path.lstrip('/') 703 | 704 | # Initialize S3 client 705 | s3_client = boto3.client('s3') 706 | 707 | try: 708 | # Read the .plotly file from S3 709 | response = s3_client.get_object(Bucket=bucket_name, Key=file_key) 710 | plotly_data = json.loads(response['Body'].read().decode('utf-8')) 711 | 712 | # Create a Figure object from the plotly data 713 | fig = go.Figure(data=plotly_data['data'], layout=plotly_data.get('layout', {})) 714 | 715 | # Convert the figure to PNG bytes 716 | img_bytes = fig.to_image(format="png") 717 | 718 | return img_bytes 719 | 720 | except Exception as e: 721 | print(f"An error occurred: {str(e)}") 722 | return None 723 | 724 | 725 | def get_s3_obj_from_bucket_(file): 726 | s3 = boto3.client('s3') 727 | match = re.match("s3://(.+?)/(.+)", file) 728 | if match: 729 | bucket_name = match.group(1) 730 | key = match.group(2) 731 | obj = s3.get_object(Bucket=bucket_name, Key=key) 732 | return obj 733 | 734 | def put_obj_in_s3_bucket_(docs): 735 | if isinstance(docs,str): 736 | s3_uri_pattern = r'^s3://([^/]+)/(.*?([^/]+)/?)$' 737 | if bool(re.match(s3_uri_pattern, docs)): 738 | file_uri=copy_s3_object(docs, BUCKET, S3_DOC_CACHE_PATH) 739 | return file_uri 740 | else: 741 | file_name = os.path.basename(docs.name) 742 | file_path = f"{S3_DOC_CACHE_PATH}/{file_name}" 743 | S3.put_object(Body=docs.read(), Bucket= BUCKET, Key=file_path) 744 | return f"s3://{BUCKET}/{file_path}" 745 | 746 | 747 | def bedrock_streemer(params,response, handler): 748 | """ stream response from bedrock Runtime""" 749 | text = '' 750 | think = "" 751 | signature = "" 752 | for chunk in response['stream']: 753 | if 'contentBlockDelta' in chunk: 754 | delta = chunk['contentBlockDelta']['delta'] 755 | # print(chunk) 756 | if 'text' in delta: 757 | text += delta['text'] 758 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True) 759 | if 'reasoningContent' in delta: 760 | if "text" in delta['reasoningContent']: 761 | think += delta['reasoningContent']['text'] 762 | handler.markdown('**MODEL REASONING**\n\n' + think.replace("$", "\\$"), unsafe_allow_html=True) 763 | elif "signature" in delta['reasoningContent']: 764 | signature = delta['reasoningContent']['signature'] 765 | 766 | elif "metadata" in chunk: 767 | if 'cacheReadInputTokens' in chunk['metadata']['usage']: 768 | print(f"\nCache Read Tokens: {chunk['metadata']['usage']['cacheReadInputTokens']}") 769 | print(f"Cache Write Tokens: {chunk['metadata']['usage']['cacheWriteInputTokens']}") 770 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"] 771 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"] 772 | latency = chunk['metadata']['metrics']["latencyMs"] 773 | pricing = st.session_state['input_token']*pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"] 774 | st.session_state['cost']+=pricing 775 | print(f"\nInput Tokens: {st.session_state['input_token']}\nOutput Tokens: {st.session_state['output_token']}\nLatency: {latency}ms") 776 | return text, think 777 | 778 | def bedrock_claude_(params, chat_history, system_message, prompt, 779 | model_id, image_path=None, handler=None): 780 | """ format user request and chat history and make a call to 
Bedrock Runtime""" 781 | chat_history_copy = chat_history[:] 782 | content = [] 783 | if image_path: 784 | if not isinstance(image_path, list): 785 | image_path = [image_path] 786 | for img in image_path: 787 | s3 = boto3.client('s3', region_name=REGION) 788 | match = re.match("s3://(.+?)/(.+)", img) 789 | image_name = os.path.basename(img) 790 | _, ext = os.path.splitext(image_name) 791 | if "jpg" in ext: 792 | ext = ".jpeg" 793 | bucket_name = match.group(1) 794 | key = match.group(2) 795 | if ".plotly" in key: 796 | print(key) 797 | bytes_image = plotly_to_png_bytes(img) 798 | ext = ".png" 799 | else: 800 | obj = s3.get_object(Bucket=bucket_name, Key=key) 801 | bytes_image = obj['Body'].read() 802 | content.extend([{"text":image_name},{ 803 | "image": { 804 | "format": f"{ext.lower().replace('.', '')}", 805 | "source": {"bytes": bytes_image} 806 | } 807 | }]) 808 | 809 | content.append({ 810 | "text": prompt 811 | }) 812 | chat_history_copy.append({"role": "user", 813 | "content": content}) 814 | system_message = [{"text": system_message}] 815 | 816 | if st.session_state['reasoning_mode']: 817 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id, 818 | inferenceConfig={"maxTokens": 18000, "temperature": 1}, 819 | system=system_message, 820 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}} 821 | ) 822 | else: 823 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id, 824 | inferenceConfig={"maxTokens": 4000 if "deepseek" not in model_id else 20000, 825 | "temperature": 0.5 if "deepseek" not in model_id else 0.6, 826 | }, 827 | system=system_message, 828 | ) 829 | answer, think=bedrock_streemer(params, response, handler) 830 | return answer, think 831 | 832 | def _invoke_bedrock_with_retries(params, current_chat, chat_template, question, model_id, image_path, handler): 833 | max_retries = 10 834 | backoff_base = 2 835 | max_backoff = 3 # Maximum backoff time in seconds 836 | retries = 0 837 | 838 | while True: 839 | try: 840 | response, think = bedrock_claude_(params,current_chat, chat_template, question, model_id, image_path, handler) 841 | return response, think 842 | except ClientError as e: 843 | if e.response['Error']['Code'] == 'ThrottlingException': 844 | if retries < max_retries: 845 | # Throttling, exponential backoff 846 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 847 | time.sleep(sleep_time) 848 | retries += 1 849 | else: 850 | raise e 851 | elif e.response['Error']['Code'] == 'ModelStreamErrorException': 852 | if retries < max_retries: 853 | # Throttling, exponential backoff 854 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 855 | time.sleep(sleep_time) 856 | retries += 1 857 | else: 858 | raise e 859 | elif e.response['Error']['Code'] == 'EventStreamError': 860 | if retries < max_retries: 861 | # Throttling, exponential backoff 862 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 863 | time.sleep(sleep_time) 864 | retries += 1 865 | else: 866 | raise e 867 | else: 868 | # Some other API error, rethrow 869 | raise 870 | 871 | def get_session_ids_by_user(table_name, user_id): 872 | """ 873 | Get Session Ids and corresponding top message for a user to populate the chat history drop down on the front end 874 | """ 875 | if DYNAMODB_TABLE: 876 | table = DYNAMODB.Table(table_name) 877 | message_list = {} 878 | session_ids = [] 879 | args = { 880 | 'KeyConditionExpression': 
Key('UserId').eq(user_id) 881 | } 882 | while True: 883 | response = table.query(**args) 884 | session_ids.extend([item['SessionId'] for item in response['Items']]) 885 | if 'LastEvaluatedKey' not in response: 886 | break 887 | args['ExclusiveStartKey'] = response['LastEvaluatedKey'] 888 | 889 | for session_id in session_ids: 890 | try: 891 | message_list[session_id] = DYNAMODB.Table(table_name).get_item(Key={"UserId": user_id, "SessionId": session_id})['Item']['messages'][0]['user'] 892 | except Exception as e: 893 | print(e) 894 | pass 895 | else: 896 | try: 897 | message_list={} 898 | # Read the existing JSON data from the file 899 | with open(LOCAL_CHAT_FILE_NAME, "r", encoding='utf-8') as file: 900 | existing_data = json.load(file) 901 | for session_id in existing_data: 902 | message_list[session_id]=existing_data[session_id][0]['user'] 903 | 904 | except FileNotFoundError: 905 | # If the file doesn't exist, initialize an empty list 906 | message_list = {} 907 | return message_list 908 | 909 | def list_csv_xlsx_in_s3_folder(bucket_name, folder_path): 910 | """ 911 | List all CSV and XLSX files in a specified S3 folder. 912 | 913 | :param bucket_name: Name of the S3 bucket 914 | :param folder_path: Path to the folder in the S3 bucket 915 | :return: List of CSV and XLSX file names in the folder 916 | """ 917 | s3 = boto3.client('s3') 918 | csv_xlsx_files = [] 919 | 920 | try: 921 | # Ensure the folder path ends with a '/' 922 | if not folder_path.endswith('/'): 923 | folder_path += '/' 924 | 925 | # List objects in the specified folder 926 | paginator = s3.get_paginator('list_objects_v2') 927 | page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=folder_path) 928 | 929 | for page in page_iterator: 930 | if 'Contents' in page: 931 | for obj in page['Contents']: 932 | # Get the file name 933 | file_name = obj['Key'] 934 | 935 | # Check if the file is a CSV or XLSX 936 | if file_name.lower().endswith(INPUT_EXT): 937 | csv_xlsx_files.append(os.path.basename(file_name)) 938 | # csv_xlsx_files.append(file_name) 939 | 940 | return csv_xlsx_files 941 | 942 | except ClientError as e: 943 | print(f"An error occurred: {e}") 944 | return [] 945 | 946 | def query_llm(params, handler): 947 | """ 948 | Handles users requests and routes to a native call or tool use, then stores sonversation to local or DynamoDB 949 | """ 950 | 951 | if not isinstance(params['upload_doc'], list): 952 | raise TypeError("documents must be in a list format") 953 | 954 | vision_model = True 955 | model = 'us.' + model_info[params['model']] 956 | if any(keyword in [params['model']] for keyword in NON_VISION_MODELS): 957 | vision_model = False 958 | 959 | # prompt template for when a user uploads a doc 960 | doc_path = [] 961 | image_path = [] 962 | full_doc_path = [] 963 | doc = "" 964 | if params['tools']: 965 | messages, tool, results, image_holder, doc_list, stop_reason, plotly_fig = function_calling_utils.function_caller_claude_(params, handler) 966 | if stop_reason != "tool_use": 967 | return messages 968 | elif stop_reason == "tool_use": 969 | prompt = f"""You are a conversational AI Assitant. 970 | I will provide you with an question on a dataset, a python code that implements the solution to the question and the result of that code solution. 
971 | Here is the question: 972 | 973 | {params['question']} 974 | 975 | 976 | Here is the python code: 977 | 978 | {tool['input']['code']} 979 | 980 | 981 | Here the result of the code: 982 | 983 | {results} 984 | 985 | 986 | After reading the user question, respond with a detailed analytical answer based entirely on the result from the code. Do NOT make up answers. 987 | When providing your respons: 988 | - Do not include any preamble, go straight to the answer. 989 | - It should not be obvious you are referencing the result.""" 990 | 991 | system_message="You always provide your response in a well presented format using markdown. Make use of tables, list etc. where necessary in your response, so information is well preseneted and easily read." 992 | answer, think = _invoke_bedrock_with_retries(params, [], system_message, prompt, model, image_holder, handler) 993 | 994 | chat_history = { 995 | "user": results["text"] if "text" in results else "", 996 | "assistant": answer, 997 | "image": image_holder, 998 | "document": [], # data_file,#doc_list, 999 | "plotly": plotly_fig, 1000 | "modelID": model, 1001 | "thinking": think, 1002 | "code": tool['input']['code'], 1003 | "time": str(time.time()), 1004 | "input_token": round(st.session_state['input_token']) , 1005 | "output_token": round(st.session_state['output_token']), 1006 | "tool_result_id": tool['toolUseId'], 1007 | "tool_name": '', 1008 | "tool_params": '' 1009 | } 1010 | # store convsation memory in DynamoDB table 1011 | if DYNAMODB_TABLE: 1012 | put_db(params, chat_history) 1013 | # use local disk for storage 1014 | else: 1015 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params["session_id"]) 1016 | return answer 1017 | else: 1018 | current_chat, chat_hist = get_chat_history_db(params, CHAT_HISTORY_LENGTH, vision_model) 1019 | if params['upload_doc'] or params['s3_objects']: 1020 | if params['upload_doc']: 1021 | doc = 'I have provided documents and/or images.\n' 1022 | for ids, docs in enumerate(params['upload_doc']): 1023 | file_name = docs.name 1024 | _, extensions = os.path.splitext(file_name) 1025 | docs = put_obj_in_s3_bucket_(docs) 1026 | full_doc_path.append(docs) 1027 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model: 1028 | image_path.append(docs) 1029 | continue 1030 | 1031 | if params['s3_objects']: 1032 | doc = 'I have provided documents and/or images.\n' 1033 | for ids, docs in enumerate(params['s3_objects']): 1034 | file_name = docs 1035 | _, extensions = os.path.splitext(file_name) 1036 | docs = put_obj_in_s3_bucket_(f"s3://{INPUT_BUCKET}/{INPUT_S3_PATH}/{docs}") 1037 | full_doc_path.append(docs) 1038 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model: 1039 | image_path.append(docs) 1040 | continue 1041 | 1042 | doc_path = [item for item in full_doc_path if item not in image_path] 1043 | errors, result_string = process_files(doc_path) 1044 | if errors: 1045 | st.error(errors) 1046 | doc += result_string 1047 | with open("prompt/doc_chat.txt", "r", encoding="utf-8") as f: 1048 | chat_template = f.read() 1049 | else: 1050 | # Chat template for open ended query 1051 | with open("prompt/chat.txt", "r", encoding="utf-8") as f: 1052 | chat_template = f.read() 1053 | 1054 | response, think = _invoke_bedrock_with_retries(params, current_chat, chat_template, 1055 | doc+params['question'], model, 1056 | image_path, handler) 1057 | # log the following items to dynamodb 1058 | chat_history = { 1059 | "user": params['question'], 1060 | "assistant": 
response, 1061 | "image": image_path, 1062 | "document": doc_path, 1063 | "modelID": model, 1064 | "thinking": think, 1065 | "time": str(time.time()), 1066 | "input_token": round(st.session_state['input_token']), 1067 | "output_token": round(st.session_state['output_token']) 1068 | } 1069 | # store convsation memory and user other items in DynamoDB table 1070 | if DYNAMODB_TABLE: 1071 | put_db(params, chat_history) 1072 | # use local memory for storage 1073 | else: 1074 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params["session_id"]) 1075 | return response 1076 | 1077 | 1078 | def get_chat_historie_for_streamlit(params): 1079 | """ 1080 | This function retrieves chat history stored in a dynamoDB table partitioned by a userID and sorted by a SessionID 1081 | """ 1082 | if DYNAMODB_TABLE: 1083 | chat_histories = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId":params["session_id"]}) 1084 | # st.write(chat_histories) 1085 | if "Item" in chat_histories: 1086 | chat_histories = chat_histories['Item']['messages'] 1087 | else: 1088 | chat_histories = [] 1089 | else: 1090 | chat_histories = load_chat_local(LOCAL_CHAT_FILE_NAME, params["session_id"]) 1091 | 1092 | # Constructing the desired list of dictionaries 1093 | formatted_data = [] 1094 | if chat_histories: 1095 | for entry in chat_histories: 1096 | image_files = [os.path.basename(x) for x in entry.get('image', [])] 1097 | doc_files = [os.path.basename(x) for x in entry.get('document', [])] 1098 | code_script = entry.get('code', "") 1099 | assistant_attachment = '\n\n'.join(image_files+doc_files) 1100 | # Get entries but dont show the Function calling unecessary parts in the chat dialogue on streamlit 1101 | if "tool_result_id" in entry and not entry["tool_result_id"]: 1102 | formatted_data.append({ 1103 | "role": "user", 1104 | "content": entry["user"], 1105 | "thinking": entry.get('thinking', "") 1106 | }) 1107 | elif not "tool_result_id" in entry : 1108 | formatted_data.append({ 1109 | "role": "user", 1110 | "content": entry["user"], 1111 | "thinking": entry.get('thinking', "") 1112 | }) 1113 | if "tool_use_id" in entry and not entry["tool_use_id"]: 1114 | formatted_data.append({ 1115 | "role": "assistant", 1116 | "content": entry["assistant"], 1117 | "attachment": assistant_attachment, 1118 | "code": code_script, 1119 | "thinking": entry.get('thinking', "") 1120 | # "image_output": entry.get('image', []) if entry["tool_result_id"] else [] 1121 | }) 1122 | elif "tool_use_id" not in entry: 1123 | formatted_data.append({ 1124 | "role": "assistant", 1125 | "content": entry["assistant"], 1126 | "attachment": assistant_attachment, 1127 | "code": code_script, 1128 | "code-result": entry["user"], 1129 | "image_output": entry.get('image', []) if "tool_result_id" in entry else [], 1130 | "plotly": entry.get('plotly', []) if "tool_result_id" in entry else [], 1131 | "thinking": entry.get('thinking', "") 1132 | }) 1133 | else: 1134 | chat_histories=[] 1135 | return formatted_data,chat_histories 1136 | 1137 | 1138 | 1139 | def get_key_from_value(dictionary, value): 1140 | return next((key for key, val in dictionary.items() if val == value), None) 1141 | 1142 | def chat_bedrock_(params): 1143 | st.title('Chatty AI Assitant 🙂') 1144 | params['chat_histories'] = [] 1145 | if params["session_id"].strip(): 1146 | st.session_state.messages, params['chat_histories'] = get_chat_historie_for_streamlit(params) 1147 | for message in st.session_state.messages: 1148 | 1149 | with 
st.chat_message(message["role"]): 1150 | if "```" in message["content"]: 1151 | st.markdown(message["content"], unsafe_allow_html=True) 1152 | else: 1153 | st.markdown(message["content"].replace("$", "\\$"), unsafe_allow_html=True) 1154 | if message["role"] == "assistant": 1155 | if message["plotly"]: 1156 | for item in message["plotly"]: 1157 | bucket_name, key = item.replace('s3://', '').split('/', 1) 1158 | image_bytes = get_object_with_retry(bucket_name, key) 1159 | content = image_bytes['Body'].read() 1160 | json_data = json.loads(content.decode('utf-8')) 1161 | try: 1162 | fig = pio.from_json(json.dumps(json_data)) 1163 | except Exception: 1164 | decoded_data = decode_json(json_data) 1165 | fig = go.Figure(data=decoded_data['data'], layout=decoded_data['layout']) 1166 | st.plotly_chart(fig) 1167 | 1168 | elif message["image_output"]: 1169 | for item in message["image_output"]: 1170 | bucket_name, key = item.replace('s3://', '').split('/', 1) 1171 | image_bytes = get_object_with_retry(bucket_name, key) 1172 | # image_bytes=base64.b64decode(message["image"][image_idx]) 1173 | image = Image.open(io.BytesIO(image_bytes['Body'].read())) 1174 | st.image(image) 1175 | if message["attachment"]: 1176 | with st.expander(label="**attachments**"): 1177 | st.markdown(message["attachment"]) 1178 | # st.markdown(message["image_output"]) 1179 | if message['code']: 1180 | with st.expander(label="**code snippet**"): 1181 | st.markdown(f'```python\n{message["code"]}', unsafe_allow_html=True) 1182 | with st.expander(label="**code result**"): 1183 | st.markdown(f'```python\n{message["code-result"]}', unsafe_allow_html=True) 1184 | if message['thinking']: 1185 | with st.expander(label="**MODEL REASONING**"): 1186 | st.markdown(message["thinking"].replace("$", "\\$"), unsafe_allow_html=True) 1187 | 1188 | if prompt := st.chat_input("Whats up?"): 1189 | st.session_state.messages.append({"role": "user", "content": prompt}) 1190 | with st.chat_message("user"): 1191 | st.markdown(prompt.replace("$", "\\$"), unsafe_allow_html=True ) 1192 | with st.chat_message("assistant"): 1193 | message_placeholder = st.empty() 1194 | params["question"] = prompt 1195 | answer = query_llm(params, message_placeholder) 1196 | message_placeholder.markdown(answer.replace("$", "\\$"), unsafe_allow_html=True ) 1197 | st.session_state.messages.append({"role": "assistant", "content": answer}) 1198 | st.rerun() 1199 | 1200 | def app_sidebar(): 1201 | with st.sidebar: 1202 | st.metric(label="Bedrock Session Cost", value=f"${round(st.session_state['cost'], 2)}") 1203 | st.write("-----") 1204 | button = st.button("New Chat", type="primary") 1205 | models = MODEL_DISPLAY_NAME 1206 | model = st.selectbox('**Model**', models) 1207 | if any(keyword in [model] for keyword in HYBRID_MODELS): 1208 | st.session_state['reasoning_mode'] = st.toggle("Reasoning Mode", value=False, key="thinking") 1209 | st.write(st.session_state['reasoning_mode']) 1210 | else: 1211 | st.session_state['reasoning_mode'] = False 1212 | runtime = "" 1213 | tools = "" 1214 | user_sess_id = get_session_ids_by_user(DYNAMODB_TABLE, st.session_state['userid']) 1215 | float_keys = {float(key): value for key, value in user_sess_id.items()} 1216 | sorted_messages = sorted(float_keys.items(), reverse=True) 1217 | sorted_messages.insert(0, (float(st.session_state['user_sess']), "New Chat")) 1218 | if button: 1219 | st.session_state['user_sess'] = str(time.time()) 1220 | sorted_messages.insert(0, (float(st.session_state['user_sess']), "New Chat")) 1221 | 
st.session_state['chat_session_list'] = dict(sorted_messages) 1222 | chat_items = st.selectbox("**Chat Sessions**", st.session_state['chat_session_list'].values(), key="chat_sessions") 1223 | session_id = get_key_from_value(st.session_state['chat_session_list'], chat_items) 1224 | if model not in NON_TOOL_SUPPORTING_MODELS: 1225 | tools = st.multiselect("**Tools**", ["Advanced Data Analytics"], 1226 | key="function_collen", default=None) 1227 | if "Advanced Data Analytics" in tools: 1228 | engines = ["pyspark", "python"] 1229 | runtime = st.select_slider( 1230 | "Runtime", engines, key="enginees" 1231 | ) 1232 | bucket_items = list_csv_xlsx_in_s3_folder(INPUT_BUCKET, INPUT_S3_PATH) 1233 | bucket_objects = st.multiselect("**Files**", bucket_items, key="objector", default=None) 1234 | file = st.file_uploader('Upload a document', accept_multiple_files=True, 1235 | help="pdf,csv,txt,png,jpg,xlsx,json,py doc format supported") 1236 | if file and LOAD_DOC_IN_ALL_CHAT_CONVO: 1237 | st.warning('You have set **load-doc-in-chat-history** to true. For better performance, remove uploaded file(s) (by clicking **X**) **AFTER** first query on uploaded files. See the README for more info', icon="⚠️") 1238 | params = {"model": model, "session_id": str(session_id), 1239 | "chat_item": chat_items, 1240 | "upload_doc": file, 1241 | "tools": tools, 1242 | 's3_objects': bucket_objects, 1243 | "engine": runtime 1244 | } 1245 | st.session_state['count'] = 1 1246 | return params 1247 | 1248 | 1249 | def main(): 1250 | params = app_sidebar() 1251 | chat_bedrock_(params) 1252 | if __name__ == '__main__': 1253 | main() 1254 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | {"DynamodbTable": "", "UserId": "user id", "Bucket_Name": "S3 bucket name (without s3:// prefix) to save uploaded files", "max-output-token":4000, "chat-history-loaded-length":10, "region":"aws region", "load-doc-in-chat-history":true, "AmazonTextract":false, "csv-delimiter":"|", "document-upload-cache-s3-path":"S3 prefix without trailing or foward slash at either ends", "AmazonTextract-result-cache":"S3 prefix without trailing or foward slash at either ends", "lambda-function":"lambda function name", "input_s3_path":"S3 prefix without trailing or foward slash at either ends","input_bucket":"S3 Bucket name (without s3:// prefix) to render files on app","input_file_ext":"csv,xlsx,parquet", "athena-work-group-name":"name of athena spark workgroup"} 2 | -------------------------------------------------------------------------------- /images/JP-lab.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/JP-lab.PNG -------------------------------------------------------------------------------- /images/chat-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chat-flow.png -------------------------------------------------------------------------------- /images/chat-preview.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chat-preview.JPG 
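A note on config.json shown above: utils/function_calling_utils.py (later in this listing) reads these keys into module-level constants, so every entry must be present before the app starts. The short sketch below is illustrative only and is not part of the repository; the key names are copied from the sample config, but the load_config helper and its error message are hypothetical.

import json

# Key names copied from the sample config.json above.
REQUIRED_KEYS = [
    "DynamodbTable", "UserId", "Bucket_Name", "max-output-token",
    "chat-history-loaded-length", "region", "load-doc-in-chat-history",
    "AmazonTextract", "csv-delimiter", "document-upload-cache-s3-path",
    "AmazonTextract-result-cache", "lambda-function", "input_s3_path",
    "input_bucket", "input_file_ext", "athena-work-group-name",
]

def load_config(path="config.json"):
    """Load the app configuration and fail fast if a required key is missing."""
    with open(path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise KeyError(f"config.json is missing required keys: {missing}")
    return cfg

if __name__ == "__main__":
    cfg = load_config()
    print(f"Uploads bucket: {cfg['Bucket_Name']} (region: {cfg['region']})")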
-------------------------------------------------------------------------------- /images/chatbot-snip.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chatbot-snip.PNG -------------------------------------------------------------------------------- /images/chatbot4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/chatbot4.png -------------------------------------------------------------------------------- /images/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/demo.mp4 -------------------------------------------------------------------------------- /images/sg-rules.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/sg-rules.PNG -------------------------------------------------------------------------------- /images/studio-new-launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/bedrock-claude-chatbot/d1ea1143554e92f6c58f84a52a3d68b77c5f106c/images/studio-new-launcher.png -------------------------------------------------------------------------------- /install_package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Path to the requirements.txt file 4 | REQUIREMENTS_FILE="req.txt" 5 | 6 | # Check if the requirements file exists 7 | if [ ! -f "$REQUIREMENTS_FILE" ]; then 8 | echo "requirements.txt file not found!" 9 | exit 1 10 | fi 11 | 12 | # Loop through each line in the requirements.txt file 13 | while IFS= read -r package || [ -n "$package" ]; do 14 | # Install the package using pip 15 | echo "Installing $package" 16 | pip install "$package" 17 | 18 | # Check if the installation was successful 19 | if [ $? -eq 0 ]; then 20 | echo "$package installed successfully" 21 | else 22 | echo "Failed to install $package" 23 | exit 1 24 | fi 25 | done < "$REQUIREMENTS_FILE" 26 | 27 | echo "All packages installed successfully." 
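install_package.sh above simply pip-installs each line of req.txt and stops on the first failure. For environments where a bash shell is not available, a rough Python equivalent of that loop is sketched below; it is illustrative only and not part of the repository.

import subprocess
import sys

def install_requirements(path="req.txt"):
    """Install each package listed in req.txt, mirroring install_package.sh."""
    with open(path, "r", encoding="utf-8") as f:
        packages = [line.strip() for line in f if line.strip()]
    for package in packages:
        print(f"Installing {package}")
        # Equivalent to `pip install "$package"`; abort on the first failure.
        result = subprocess.run([sys.executable, "-m", "pip", "install", package])
        if result.returncode != 0:
            raise SystemExit(f"Failed to install {package}")
    print("All packages installed successfully.")

if __name__ == "__main__":
    install_requirements()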
28 | -------------------------------------------------------------------------------- /model_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0", 3 | "sonnet-3.7": "anthropic.claude-3-7-sonnet-20250219-v1:0", 4 | "opus-4": "anthropic.claude-opus-4-20250514-v1:0", 5 | "sonnet-3.5-v2": "anthropic.claude-3-5-sonnet-20241022-v2:0", 6 | "haiku-3.5": "anthropic.claude-3-5-haiku-20241022-v1:0", 7 | "deepseek": "deepseek.r1-v1:0", 8 | "nova-pro": "amazon.nova-pro-v1:0", 9 | "nova-lite": "amazon.nova-lite-v1:0", 10 | "nova-micro": "amazon.nova-micro-v1:0", 11 | "nova-premier": "amazon.nova-premier-v1:0", 12 | "meta-maverick": "meta.llama4-maverick-17b-instruct-v1:0", 13 | "meta-scout": "meta.llama4-scout-17b-instruct-v1:0" 14 | } -------------------------------------------------------------------------------- /pricing.json: -------------------------------------------------------------------------------- 1 | {"sonnet-4": {"input": 0.000003, "output": 0.000015},"opus-4": {"input": 0.000015, "output": 0.000075},"sonnet-3.5-v1": {"input": 0.000003, "output": 0.000015},"sonnet-3.5-v2": {"input": 0.000003, "output": 0.000015},"haiku-3.5": {"input": 0.0000008, "output": 0.000004},"haiku": {"input": 0.00000025, "output": 0.00000125},"sonnet-3.7": {"input": 0.000003, "output": 0.000015},"meta-maverick": {"input": 0.00000024, "output": 0.00000097}, "meta-scout": {"input": 0.00000017, "output": 0.00000066},"nova-lite": {"input": 0.00000006, "output": 0.000000024},"nova-premier": {"input": 0.00000025, "output": 0.00000125},"nova-pro": {"input": 0.00000032, "output": 0.00000006},"deepseek": {"input": 0.000000135, "output": 0.00000054}, "nova-micro": {"input": 0.0000000035, "output": 0.000000014}} -------------------------------------------------------------------------------- /prompt/chat.txt: -------------------------------------------------------------------------------- 1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively. You are very attentive and respond in markdown format. -------------------------------------------------------------------------------- /prompt/doc_chat.txt: -------------------------------------------------------------------------------- 1 | You are a conversational assistant, expert at providing quality and accurate answers based on a document(s) and/or image(s) provided. You are very attentive and respond in markdown format. Take your time to read through the document(s) and/or image(s) carefully and pay attention to relevant areas pertaining to the question(s). Once done reading, provide an answer to the user question(s). -------------------------------------------------------------------------------- /prompt/pyspark_debug_prompt.txt: -------------------------------------------------------------------------------- 1 | I will provide you a pyspark code that analyzes a tabular data and an error relating to the code. 2 | Here is a subset (first 3 rows) of each dataset: 3 | 4 | {dataset} 5 | 6 | 7 | Here is the pyspark code to analyze the data: 8 | 9 | {code} 10 | 11 | 12 | Here is the thrown error: 13 | 14 | {error} 15 | 16 | 17 | Debug and fix the code. Think through where the potential bug is and what solution is needed, put all this thinking process in XML tags. 18 | The data files are stored in Amazon S3 (the XML tags point to the S3 URI) and must be read from S3. 
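The pyspark_debug_prompt.txt template above (and its Python counterpart further below) is driven by three placeholders: {dataset}, {code} and {error}. The snippet below sketches how such a template could be filled with str.format; the actual rendering and retry logic lives in utils/function_calling_utils.py, which is only partially shown in this listing, so the function name and the sample values here are hypothetical.

def render_debug_prompt(template_path, dataset_sample, code, error):
    """Fill the {dataset}, {code} and {error} placeholders of a debug prompt."""
    with open(template_path, "r", encoding="utf-8") as f:
        template = f.read()
    return template.format(dataset=dataset_sample, code=code, error=error)

# Hypothetical usage with made-up values:
prompt = render_debug_prompt(
    "prompt/pyspark_debug_prompt.txt",
    dataset_sample="<sales.csv>\nregion|revenue\nEMEA|1200\n</sales.csv>",
    code='df = spark.read.csv("s3a://my-bucket/data/sales.csv", header=True)',
    error="AnalysisException: Path does not exist: s3a://my-bucket/data/sales.csv",
)
print(prompt[:200])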
19 | 20 | Important considerations: 21 | - Always use the S3A file system when reading files from S3. 22 | - Always generate the full code. 23 | - Each plot must be saved as a PLOTLY (.plotly) file in the '/tmp' directory. 24 | - Use proper namespace management for Python and PySpark libraries. 25 | - DO NOT use a `try-except` block to handle exceptions when writing the correct code, as this prevents my front-end code from handling exceptions properly. 26 | - Use pio.write_json() to save images as ".plotly" files 27 | 28 | Additional info: The code must output a JSON object variable named "output" with the following keys: 29 | - 'text': Any text output generated by the Python code 30 | - 'plotly-files': Plotly objects saved as ".plotly" 31 | 32 | Reflect on your approach and the considerations provided within reflection XML tags. 33 | 34 | Provide the fixed code within XML tags and all Python top-level package names (separated by commas, no extra formatting) needed within XML tags -------------------------------------------------------------------------------- /prompt/pyspark_tool_prompt.txt: -------------------------------------------------------------------------------- 1 | Purpose: Analyze structured data and generate text and graphical outputs using PySpark code without interpreting results. 2 | Input: Structured data file(s) (CSV, PARQUET) 3 | Processing: 4 | - Read input files from Amazon S3: 5 | CSV files: load using CSV-specific methods, e.g., spark.read.csv("s3a://path/to/your/file.csv", header=True, inferSchema=True) 6 | PARQUET files: Load using Parquet-specific methods, e.g., spark.read.parquet("s3a://path/to/your/file.parquet") 7 | Process files according to their true type, not their sample representation. Each file must be read from Amazon S3. 8 | - Perform statistical analysis 9 | - When working with columns that have special characters (".", " ") etc., wrap column names in backticks "`". 10 | - Generate plots when possible. 11 | - If multiple data files are provided, always load all datasets for analysis. 12 | 13 | Visualization: 14 | - Ensure plots are clear and legible with a figure size of 10 by 12 inches. 15 | - Use Plotly for plots when possible and save Plotly objects as ".plotly" files, also in the /tmp directory 16 | - When generating plots, use contrasting colours for legibility. 17 | - Remember, you should save a .plotly version of each generated plot. 18 | 19 | Output: JSON object named "output" with: 20 | - 'text': All text-based results and printed information 21 | - 'plotly-files': Plotly objects saved as ".plotly" 22 | 23 | Important: 24 | - Generate code for analysis only, without interpreting results 25 | - Avoid making conclusive statements or recommendations 26 | - Present calculated statistics and generated visualizations without drawing conclusions 27 | - Save plots as .plotly files accordingly using pio.write_json() to the '/tmp' directory 28 | - Use proper namespace management for Python and PySpark libraries: 29 | - Import only the necessary modules or functions to avoid cluttering the namespace. 30 | - Use aliases for long module names (e.g., 'import pandas as pd'). 31 | - Avoid using 'from module import *' as it can lead to naming conflicts. Instead of 'from pyspark.sql import *', use 'from pyspark.sql import functions as F' 32 | - Group imports logically: standard library imports first, then third-party libraries, followed by local application imports. 
33 | - Use efficient, well-documented, PEP 8 compliant code 34 | - Follow data analysis and visualization best practices 35 | - Include plots whenever possible 36 | - Store all results in the 'output' JSON object 37 | - Ensure 'output' is the final variable assigned in the code, whether inside or outside a function 38 | 39 | Example: 40 | import plotly.io as pio 41 | from pyspark.sql import functions as F 42 | from pyspark.sql.window import Window 43 | ... REST of IMPORT 44 | 45 | # In Amazon Athena, the Spark context is already initialized as the "spark" variable; no need to initialize it 46 | 47 | # Read data 48 | df = spark.read.csv("s3a://path/to/file/file.csv") # for parquet use spark.read.parquet(..) 49 | 50 | ...REST OF CODE 51 | 52 | # Save plotly figures 53 | pio.write_json(fig, '/tmp/plot.plotly') 54 | 55 | # Prepare output 56 | output = { 57 | 'text': 'Statistical analysis results...\nOther printed output...', 58 | 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures 59 | } 60 | 61 | # No need to stop the Spark context -------------------------------------------------------------------------------- /prompt/pyspark_tool_system.txt: -------------------------------------------------------------------------------- 1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively. 2 | You will have access to a set of "tools" for handling specific requests; use your judgement to figure out if you need to use a tool and which tool to use. I will provide the tool description below that guides you on if and when to use a tool: 3 | 1. pyspark_function_tool: This tool is used to handle structured data files to perform any data analysis query and task on such files (CSV, PQ). Structured data will usually be tagged by the file name and will be in a CSV string. 4 | If a user query does not need a tool, go ahead and answer the question directly without using any tool. Do not include any preamble -------------------------------------------------------------------------------- /prompt/pyspark_tool_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "tools": [ 3 | { 4 | "toolSpec": { 5 | "name": "pyspark_function_tool", 6 | "description": "This tool allows you to analyze structured data files (CSV, PARQUET, XLSX, etc.) using the PySpark programming language. It can be used to answer questions or perform analyses on the data contained in these files. Use this tool when the user asks questions or requests analysis related to structured data files. Do not use this tool for any other query not related to analyzing tabular data files. 
When using this tool, first think through the user's request and understand the data types within the dataset; put your thoughts in XML tags as your scratch pad.", 7 | "inputSchema": { 8 | "json": { 9 | "type": "object", 10 | "properties": { 11 | "code": { 12 | "type": "string", 13 | "description": "" 14 | 15 | }, 16 | "dataset_name": { 17 | "type": "string", 18 | "description": "The file name of the structured dataset including its extension (CSV, PQ, etc.)" 19 | }, 20 | "python_packages": { 21 | "type": "string", 22 | "description": "Comma-separated list of Python libraries required to run the function" 23 | } 24 | }, 25 | "required": ["code","dataset_name","python_packages"] 26 | } 27 | } 28 | } 29 | } 30 | ] 31 | } -------------------------------------------------------------------------------- /prompt/python_debug_prompt.txt: -------------------------------------------------------------------------------- 1 | I will provide you with Python code that analyzes tabular data and an error relating to the code. 2 | Here is a subset (first few rows) of each dataset: 3 | 4 | {dataset} 5 | 6 | 7 | Here is the Python code to analyze the data: 8 | 9 | {code} 10 | 11 | 12 | Here is the thrown error: 13 | 14 | {error} 15 | 16 | 17 | Debug and fix the code. Think through where the potential bug is and what solution is needed, put all this thinking process in XML tags. 18 | The data files are stored in Amazon S3 (the XML tags point to the S3 URI) and must be read from S3. 19 | Images, if available in the code, must be saved in the '/tmp' directory. 20 | DO NOT use a `try-except` block to handle exceptions when writing the correct code, as this prevents my front-end code from handling exceptions properly. 21 | Additional info: The code must output a JSON object variable named "output" with the following keys: 22 | - 'text': Any text output generated by the Python code. 23 | - 'image': If the Python code generates any image outputs, image filenames (without the '/tmp' parent directory) will be mapped to this key. If no image is generated, no need for this key. (Must be in list format) 24 | - 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures 25 | Finally, the "output" variable must be saved as "output.json" in the "/tmp" dir. 26 | Provide the fixed code within XML tags and all Python top-level package names (separated by commas, no extra formatting) needed within XML tags -------------------------------------------------------------------------------- /prompt/python_tool_prompt.txt: -------------------------------------------------------------------------------- 1 | Purpose: Analyze structured data and generate text and graphical outputs using Python code without interpreting results. 2 | Input: Structured data file(s) (CSV, XLS, XLSX) 3 | Processing: 4 | - Read input files from Amazon S3: 5 | CSV files: pd.read_csv("s3://path/to/your/file.csv") 6 | XLS/XLSX files: Load using Excel-specific methods (e.g., pd.read_excel("s3://path/to/your/file.xlsx")) 7 | - Perform statistical analysis 8 | - Generate plots when possible 9 | 10 | Visualization: 11 | - Ensure plots are clear and legible with a figure size of 10 by 12 inches. 12 | - Save generated plots as PNG files in the /tmp directory. Use appropriate filenames based on the plot titles. 13 | - Always use fig.write_image() to save plotly figures as PNG. 14 | - Use Plotly for plots when possible and save Plotly objects as ".plotly" files, also in the /tmp directory 15 | - When generating plots, use contrasting colours for legibility. 
16 | - Remember, you should save a PNG and .plotly version of each generated plot. 17 | - When using Matplotlib, create a temporary directory and set it as MPLCONFIGDIR before importing any libraries to avoid permission issues in restricted environments. 18 | 19 | Output: JSON object named "output" with: 20 | - 'text': All text-based results and printed information 21 | - 'image': Filename(s) of PNG plot(s) 22 | - 'plotly-files': Plotly objects saved as ".plotly" 23 | - save 'output.json' in the '/tmp' directory 24 | 25 | Important: 26 | - Generate code for analysis only, without interpreting results 27 | - Avoid making conclusive statements or recommendations 28 | - Present calculated statistics and generated visualizations without drawing conclusions 29 | - Save plots as PNG and .plotly files accordingly 30 | 31 | Notes: 32 | - Take time to think about the code to be generated for the user query 33 | - Save plots as PNG and .plotly files in the '/tmp' directory 34 | - Use efficient, well-documented, PEP 8 compliant code 35 | - Follow data analysis and visualization best practices 36 | - Include plots whenever possible 37 | - Store all results in the 'output' JSON object 38 | - Ensure 'output' is the final variable assigned in the code, whether inside or outside a function 39 | 40 | Example: 41 | import plotly.io as pio 42 | import pandas as pd 43 | ... REST of IMPORT 44 | 45 | # Read the data 46 | df = pd.read_csv("s3://path/to/file/file.csv") 47 | 48 | ...REST OF CODE 49 | 50 | # Save plot as PNG 51 | fig.write_image("/tmp/plot.png") 52 | 53 | # Save plots as PLOTLY files 54 | pio.write_json(fig, '/tmp/plot.plotly') 55 | 56 | # Prepare output 57 | output = { 58 | 'text': '''Statistical analysis results...\nOther printed output...''', 59 | 'image': 'plot.png', # or ['plot1.png', 'plot2.png'] for multiple images 60 | 'plotly-files': 'plot.plotly' # or ['plot1.plotly', 'plot2.plotly'] for multiple plotly figures 61 | } 62 | 63 | # Save output as JSON file in the '/tmp' dir 64 | with open('/tmp/output.json', 'w') as f: 65 | json.dump(output, f) -------------------------------------------------------------------------------- /prompt/python_tool_system.txt: -------------------------------------------------------------------------------- 1 | You are a conversational AI assistant, proficient in delivering high-quality responses and resolving tasks effectively. 2 | You will have access to a set of "tools" for handling specific requests; use your judgement to figure out if you need to use a tool and which tool to use. I will provide the tool description below that guides you on if and when to use a tool: 3 | 1. python_function_tool: This tool is used to handle structured data files to perform any data analysis query and task on such files (CSV, XLSX, etc.). Structured data will usually be tagged by the file name and will be in a CSV string. 4 | 5 | If a user query does not need a tool, go ahead and answer the question directly without using any tool. Do not include any preamble -------------------------------------------------------------------------------- /prompt/python_tool_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "tools": [ 3 | { 4 | "toolSpec": { 5 | "name": "python_function_tool", 6 | "description": "This tool allows you to analyze structured data files (CSV, XLSX, etc.) using the Python programming language. It can be used to answer questions or perform analyses on the data contained in these files. 
Use this tool when the user asks questions or requests analyses related to structured data files. Do not use this tool for any other query not related to analyzing tabular data files.", 7 | "inputSchema": { 8 | "json": { 9 | "type": "object", 10 | "properties": { 11 | "code": { 12 | "type": "string", 13 | "description": "" 14 | 15 | }, 16 | "dataset_name": { 17 | "type": "string", 18 | "description": "The file name of the structured dataset including its extension (CSV, XLSX ..etc)" 19 | }, 20 | "python_packages": { 21 | "type": "string", 22 | "description": "Comma-separated list of Python libraries required to run the function" 23 | } 24 | }, 25 | "required": ["code","dataset_name","python_packages"] 26 | } 27 | } 28 | } 29 | } 30 | ] 31 | } -------------------------------------------------------------------------------- /req.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | setuptools 3 | streamlit 4 | boto3 5 | pymupdf 6 | pandas 7 | numpy 8 | amazon-textract-textractor 9 | openpyxl 10 | inflect 11 | pypdf2 12 | pytesseract 13 | python-pptx 14 | python-docx 15 | pillow 16 | openpyxl 17 | pydantic 18 | python-calamine 19 | s3fs 20 | plotly>=5.0.0,<6.0.0 21 | kaleido -------------------------------------------------------------------------------- /utils/athena_handler_.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | import time 4 | import streamlit as st 5 | import json 6 | 7 | def start_athena_session_( 8 | workgroup_name, 9 | description="Starting Athena session", 10 | coordinator_dpu_size=1, 11 | max_concurrent_dpus=60, 12 | default_executor_dpu_size=1, 13 | additional_configs=None, 14 | spark_properties=None, 15 | # notebook_version="Athena notebook version 1", 16 | session_idle_timeout_in_minutes=None, 17 | client_request_token=None 18 | ): 19 | """ 20 | Start an Athena session using boto3. 21 | 22 | Args: 23 | workgroup_name (str): The name of the workgroup. 24 | description (str): A description of the session. Default is "Starting Athena session". 25 | coordinator_dpu_size (int): The size of the coordinator DPU. Default is 1. 26 | max_concurrent_dpus (int): The maximum number of concurrent DPUs. Default is 20. 27 | default_executor_dpu_size (int): The default size of executor DPUs. Default is 1. 28 | additional_configs (dict): Additional configurations. Default is None. 29 | spark_properties (dict): Spark properties. Default is None. 30 | notebook_version (str): The version of the Athena notebook. Default is "Athena notebook version 1". 31 | session_idle_timeout_in_minutes (int): The idle timeout for the session in minutes. Default is None. 32 | client_request_token (str): A unique, case-sensitive identifier that you provide to ensure the idempotency of the request. Default is None. 33 | 34 | Returns: 35 | dict: A dictionary containing the SessionId and State of the started session. 
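    Example (illustrative; assumes a Spark-enabled Athena workgroup already exists and mirrors the call made later in this module; the workgroup name is hypothetical):
        session_info = start_athena_session_(
            workgroup_name="my-spark-workgroup",  # hypothetical workgroup name
            session_idle_timeout_in_minutes=60,
        )
        if session_info:
            print(session_info["SessionId"], session_info["State"])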
36 | """ 37 | 38 | # Create an Athena client 39 | athena_client = boto3.client('athena') 40 | 41 | # Define the engine configuration 42 | engine_configuration = { 43 | 'CoordinatorDpuSize': coordinator_dpu_size, 44 | 'MaxConcurrentDpus': max_concurrent_dpus, 45 | 'DefaultExecutorDpuSize': default_executor_dpu_size, 46 | } 47 | 48 | if additional_configs: 49 | engine_configuration['AdditionalConfigs'] = additional_configs 50 | 51 | if spark_properties: 52 | engine_configuration['SparkProperties'] = spark_properties 53 | 54 | # Prepare the request parameters 55 | request_params = { 56 | 'Description': description, 57 | 'WorkGroup': workgroup_name, 58 | 'EngineConfiguration': engine_configuration, 59 | # 'NotebookVersion': notebook_version 60 | } 61 | 62 | if session_idle_timeout_in_minutes is not None: 63 | request_params['SessionIdleTimeoutInMinutes'] = session_idle_timeout_in_minutes 64 | 65 | if client_request_token: 66 | request_params['ClientRequestToken'] = client_request_token 67 | 68 | try: 69 | # Start the Athena session 70 | response = athena_client.start_session(**request_params) 71 | 72 | # Extract relevant information 73 | session_info = { 74 | 'SessionId': response['SessionId'], 75 | 'State': response['State'] 76 | } 77 | 78 | print(f"Athena session started successfully.") 79 | print(f"Session ID: {session_info['SessionId']}") 80 | print(f"State: {session_info['State']}") 81 | 82 | return session_info 83 | 84 | except ClientError as e: 85 | print(f"An error occurred while starting the Athena session: {e.response['Error']['Message']}") 86 | return None 87 | 88 | 89 | class SessionFailedException(Exception): 90 | """Custom exception for when the session is in a FAILED state.""" 91 | pass 92 | 93 | class SessionTimeoutException(Exception): 94 | """Custom exception for when the session check times out.""" 95 | pass 96 | 97 | def wait_for_session_status(session_id, max_wait_seconds=300, check_interval_seconds=3): 98 | """ 99 | Wait for an Athena session to reach either IDLE or FAILED state. 100 | 101 | Args: 102 | session_id (str): The ID of the session to check. 103 | max_wait_seconds (int): Maximum time to wait in seconds. Default is 300 seconds (5 minutes). 104 | check_interval_seconds (int): Time to wait between status checks in seconds. Default is 10 seconds. 105 | 106 | Returns: 107 | bool: True if the session state is IDLE. 108 | 109 | Raises: 110 | SessionFailedException: If the session state is FAILED. 111 | SessionTimeoutException: If the maximum wait time is exceeded. 112 | ClientError: If there's an error in the AWS API call. 113 | """ 114 | 115 | athena_client = boto3.client('athena') 116 | start_time = time.time() 117 | 118 | while True: 119 | try: 120 | response = athena_client.get_session_status(SessionId=session_id) 121 | state = response['Status']['State'] 122 | 123 | print(f"Session {session_id} is in state: {state}") 124 | 125 | if state == 'IDLE': 126 | return True 127 | elif state == 'FAILED': 128 | reason = response['Status'].get('StateChangeReason', 'No reason provided') 129 | raise SessionFailedException(f"Session {session_id} has FAILED. 
Reason: {reason}") 130 | elif state == 'TERMINATED': 131 | # return f"Session {session_id} is in state: {state}" 132 | return False 133 | 134 | # Check if we've exceeded the maximum wait time 135 | if time.time() - start_time > max_wait_seconds: 136 | raise SessionTimeoutException(f"Timeout waiting for session {session_id} to become IDLE or FAILED") 137 | 138 | # Wait for the specified interval before checking again 139 | time.sleep(check_interval_seconds) 140 | 141 | except ClientError as e: 142 | print(f"An error occurred while checking the session status: {e.response['Error']['Message']}") 143 | raise 144 | 145 | def execute_athena_calculation(session_id, code_block, workgroup, max_wait_seconds=600, check_interval_seconds=5): 146 | """ 147 | Execute a calculation in Athena, wait for completion, and retrieve results. 148 | 149 | Args: 150 | session_id (str): The Athena session ID. 151 | code_block (str): The code to execute. 152 | max_wait_seconds (int): Maximum time to wait for execution in seconds. Default is 600 seconds (10 minutes). 153 | check_interval_seconds (int): Time to wait between status checks in seconds. Default is 10 seconds. 154 | 155 | Returns: 156 | dict: A dictionary containing execution results or error information. 157 | """ 158 | athena_client = boto3.client('athena') 159 | s3_client = boto3.client('s3') 160 | 161 | def start_calculation(): 162 | try: 163 | response = athena_client.start_calculation_execution( 164 | SessionId=session_id, 165 | CodeBlock=code_block, 166 | # ClientRequestToken=f"token-{time.time()}" # Unique token for idempotency 167 | ) 168 | return response['CalculationExecutionId'] 169 | except ClientError as e: 170 | print(f"Failed to start calculation: {e}") 171 | raise 172 | 173 | def check_calculation_status(calculation_id): 174 | try: 175 | response = athena_client.get_calculation_execution(CalculationExecutionId=calculation_id) 176 | return response['Status']['State'] 177 | except ClientError as e: 178 | print(f"Failed to get calculation status: {e}") 179 | raise 180 | 181 | def get_calculation_result(calculation_id): 182 | try: 183 | response = athena_client.get_calculation_execution(CalculationExecutionId=calculation_id) 184 | return response['Result'] 185 | except ClientError as e: 186 | print(f"Failed to get calculation result: {e}") 187 | raise 188 | 189 | def download_s3_file(s3_uri): 190 | try: 191 | bucket, key = s3_uri.replace("s3://", "").split("/", 1) 192 | response = s3_client.get_object(Bucket=bucket, Key=key) 193 | return response['Body'].read().decode('utf-8') 194 | except ClientError as e: 195 | print(f"Failed to download S3 file: {e}") 196 | return None 197 | 198 | if session_id: 199 | if not wait_for_session_status(session_id): 200 | # Start Session 201 | session = start_athena_session_( 202 | workgroup_name=workgroup, 203 | session_idle_timeout_in_minutes=60 204 | ) 205 | if wait_for_session_status(session['SessionId']): 206 | session_id = session['SessionId'] 207 | st.session_state['athena-session'] = session['SessionId'] 208 | 209 | else: 210 | session = start_athena_session_( 211 | workgroup_name=workgroup, 212 | session_idle_timeout_in_minutes=60 213 | ) 214 | if wait_for_session_status(session['SessionId']): 215 | session_id = session['SessionId'] 216 | st.session_state['athena-session'] = session['SessionId'] 217 | 218 | # Start the calculation 219 | calculation_id = start_calculation() 220 | print(f"Started calculation with ID: {calculation_id}") 221 | 222 | # Wait for the calculation to complete 223 | start_time = 
time.time() 224 | while True: 225 | status = check_calculation_status(calculation_id) 226 | print(f"Calculation status: {status}") 227 | 228 | if status in ['COMPLETED', 'FAILED', 'CANCELED']: 229 | break 230 | 231 | if time.time() - start_time > max_wait_seconds: 232 | print("Calculation timed out") 233 | return {"error": "Calculation timed out"} 234 | 235 | time.sleep(check_interval_seconds) 236 | 237 | # Get the calculation result 238 | result = get_calculation_result(calculation_id) 239 | 240 | if status == 'COMPLETED': 241 | # Download and return the result 242 | result_content = download_s3_file(result['ResultS3Uri']) 243 | print(result) 244 | return { 245 | "status": "COMPLETED", 246 | "result": result_content, 247 | "stdout": download_s3_file(result['StdOutS3Uri']), 248 | "stderr": download_s3_file(result['StdErrorS3Uri']) 249 | } 250 | elif status == 'FAILED': 251 | # Get the error file 252 | error_content = download_s3_file(result['StdErrorS3Uri']) 253 | return { 254 | "status": "FAILED", 255 | "error": error_content, 256 | "stdout": download_s3_file(result['StdOutS3Uri']) 257 | } 258 | else: 259 | return {"status": status} 260 | 261 | def send_athena_job(payload, workgroup): 262 | """ 263 | Send a Spark job payload to the specified URL. 264 | 265 | Args: 266 | payload (dict): The payload containing the Spark job details. 267 | 268 | Returns: 269 | dict: The response from the server. 270 | """ 271 | 272 | code_block = """ 273 | import boto3 274 | import os 275 | def put_obj_in_s3_bucket_(docs, bucket, key_prefix): 276 | S3 = boto3.client('s3') 277 | if isinstance(docs,str): 278 | file_name=os.path.basename(docs) 279 | file_path=f"{key_prefix}/{docs}" 280 | S3.upload_file(f"/tmp/{docs}", bucket, file_path) 281 | else: 282 | file_name=os.path.basename(docs.name) 283 | file_path=f"{key_prefix}/{file_name}" 284 | S3.put_object(Body=docs.read(),Bucket= BUCKET, Key=file_path) 285 | return f"s3://{bucket}/{file_path}" 286 | 287 | def handle_results_(result, bucket, s3_file_path): 288 | image_holder = [] 289 | if isinstance(result, dict): 290 | for item, value in result.items(): 291 | if "plotly-files" in item and value is not None: 292 | if isinstance(value, list): 293 | for img in value: 294 | image_path_s3 = put_obj_in_s3_bucket_(img, bucket, s3_file_path) 295 | image_holder.append(image_path_s3) 296 | else: 297 | image_path_s3 = put_obj_in_s3_bucket_(value,bucket,s3_file_path) 298 | image_holder.append(image_path_s3) 299 | 300 | tool_result = { 301 | "result": result, 302 | "plotly": image_holder 303 | } 304 | return tool_result 305 | 306 | # iterate = input_data.get('iterate', 0) 307 | bucket="BUCKET-NAME" 308 | s3_file_path="BUCKET-PATH" 309 | result_final = handle_results_(output, bucket, s3_file_path) 310 | print(f"{result_final}") 311 | """.replace("BUCKET-NAME",payload["bucket"]).replace("BUCKET-PATH", payload["file_path"]) 312 | 313 | try: 314 | result = execute_athena_calculation(st.session_state['athena-session'], payload["code"]+ "\n" +code_block, workgroup) 315 | print(json.dumps(result, indent=2)) 316 | return result 317 | except Exception as e: 318 | print(f"An error occurred: {str(e)}") 319 | return e -------------------------------------------------------------------------------- /utils/function_calling_utils.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import boto3 3 | from botocore.config import Config 4 | import os 5 | import pandas as pd 6 | import time 7 | import json 8 | from 
botocore.exceptions import ClientError 9 | import io 10 | import re 11 | from pptx import Presentation 12 | import random 13 | from python_calamine import CalamineWorkbook 14 | import chardet 15 | from docx.table import _Cell 16 | import concurrent.futures 17 | from functools import partial 18 | from textractor import Textractor 19 | from textractor.data.constants import TextractFeatures 20 | from textractor.data.text_linearization_config import TextLinearizationConfig 21 | import pytesseract 22 | from PIL import Image 23 | import PyPDF2 24 | from docx import Document as DocxDocument 25 | from docx.oxml.text.paragraph import CT_P 26 | from docx.oxml.table import CT_Tbl 27 | from docx.document import Document 28 | from docx.text.paragraph import Paragraph 29 | from docx.table import Table as DocxTable 30 | from utils.athena_handler_ import send_athena_job 31 | import ast 32 | from urllib.parse import urlparse 33 | import plotly.graph_objects as go 34 | 35 | config = Config( 36 | read_timeout=600, # Read timeout parameter 37 | connect_timeout=600, # Connection timeout parameter in seconds 38 | retries=dict( 39 | max_attempts=10 # Handle retries 40 | ) 41 | ) 42 | 43 | with open('config.json', 'r', encoding='utf-8') as f: 44 | config_file = json.load(f) 45 | 46 | # Bedrock Model info 47 | with open('model_id.json', 'r', encoding='utf-8') as f: 48 | model_info = json.load(f) 49 | 50 | S3 = boto3.client('s3') 51 | DYNAMODB = boto3.resource('dynamodb') 52 | LOCAL_CHAT_FILE_NAME = "chat-history.json" 53 | DYNAMODB_TABLE = config_file["DynamodbTable"] 54 | BUCKET = config_file["Bucket_Name"] 55 | OUTPUT_TOKEN = config_file["max-output-token"] 56 | S3_DOC_CACHE_PATH = config_file["document-upload-cache-s3-path"] 57 | TEXTRACT_RESULT_CACHE_PATH = config_file["AmazonTextract-result-cache"] 58 | LOAD_DOC_IN_ALL_CHAT_CONVO = config_file["load-doc-in-chat-history"] 59 | CHAT_HISTORY_LENGTH = config_file["chat-history-loaded-length"] 60 | DYNAMODB_USER = config_file["UserId"] 61 | REGION = config_file["region"] 62 | USE_TEXTRACT = config_file["AmazonTextract"] 63 | CSV_SEPERATOR = config_file["csv-delimiter"] 64 | LAMBDA_FUNC = config_file["lambda-function"] 65 | INPUT_S3_PATH = config_file["input_s3_path"] 66 | INPUT_BUCKET = config_file["input_bucket"] 67 | ATHENA_WORKGROUP_NAME = config_file["athena-work-group-name"] 68 | TEXT_ONLY_MODELS = ["deepseek", "haiku-3.5", "micro"] 69 | 70 | with open('pricing.json', 'r', encoding='utf-8') as f: 71 | pricing_file = json.load(f) 72 | 73 | 74 | 75 | def put_db(params, messages): 76 | """Store long term chat history in DynamoDB""" 77 | chat_item = { 78 | "UserId": st.session_state['userid'], # user id 79 | "SessionId": params["session_id"], # User session id 80 | "messages": [messages], # 'messages' is a list of dictionaries 81 | "time": messages['time'] 82 | } 83 | 84 | existing_item = DYNAMODB.Table(DYNAMODB_TABLE).get_item(Key={"UserId": st.session_state['userid'], "SessionId": params["session_id"]}) 85 | 86 | if "Item" in existing_item: 87 | existing_messages = existing_item["Item"]["messages"] 88 | chat_item["messages"] = existing_messages + [messages] 89 | 90 | response = DYNAMODB.Table(DYNAMODB_TABLE).put_item( 91 | Item=chat_item 92 | ) 93 | def save_chat_local(file_path, new_data,params): 94 | """Store long term chat history Local Disk""" 95 | try: 96 | # Read the existing JSON data from the file 97 | with open(file_path, "r",encoding='utf-8') as file: 98 | existing_data = json.load(file) 99 | if params["session_id"] not in existing_data: 100 | 
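# The local chat-history file is a single JSON object keyed by session id,
# e.g. {"<session_id>": [<chat entry dict>, ...]}: a new session id is seeded
# with an empty list before entries are appended, and Decimal values returned
# by DynamoDB are cast to float below so the records stay JSON-serializable.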
existing_data[params["session_id"]]=[] 101 | except FileNotFoundError: 102 | # If the file doesn't exist, initialize an empty list 103 | existing_data = {params["session_id"]:[]} 104 | # Append the new data to the existing list 105 | from decimal import Decimal 106 | data = [{k: float(v) if isinstance(v, Decimal) else v for k, v in item.items()} for item in new_data] 107 | existing_data[params["session_id"]].extend(data) 108 | # Write the updated list back to the JSON file 109 | with open(file_path, "w", encoding="utf-8") as file: 110 | json.dump(existing_data, file) 111 | 112 | def load_chat_local(file_path,params): 113 | """Load long term chat history from Local""" 114 | try: 115 | # Read the existing JSON data from the file 116 | with open(file_path, "r",encoding='utf-8') as file: 117 | existing_data = json.load(file) 118 | if params["session_id"] in existing_data: 119 | existing_data=existing_data[params["session_id"]] 120 | else: 121 | existing_data=[] 122 | except FileNotFoundError: 123 | # If the file doesn't exist, initialize an empty list 124 | existing_data = [] 125 | return existing_data 126 | 127 | 128 | def bedrock_streemer(params,response, handler): 129 | text='' 130 | think = "" 131 | signature = "" 132 | for chunk in response['stream']: 133 | if 'contentBlockDelta' in chunk: 134 | delta = chunk['contentBlockDelta']['delta'] 135 | # print(chunk) 136 | if 'text' in delta: 137 | text += delta['text'] 138 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True) 139 | if 'reasoningContent' in delta: 140 | if "text" in delta['reasoningContent']: 141 | think += delta['reasoningContent']['text'] 142 | handler.markdown('**MODEL REASONING**\n\n'+ think.replace("$", "\\$"), unsafe_allow_html=True) 143 | elif "signature" in delta['reasoningContent']: 144 | signature = delta['reasoningContent']['signature'] 145 | 146 | elif "metadata" in chunk: 147 | 148 | if 'cacheReadInputTokens' in chunk['metadata']['usage']: 149 | print(f"\nCache Read Tokens: {chunk['metadata']['usage']['cacheReadInputTokens']}") 150 | print(f"Cache Write Tokens: {chunk['metadata']['usage']['cacheWriteInputTokens']}") 151 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"] 152 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"] 153 | latency = chunk['metadata']['metrics']["latencyMs"] 154 | pricing = st.session_state['input_token'] * pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"] 155 | st.session_state['cost']+=pricing 156 | print(f"\nInput Tokens: {st.session_state['input_token']}\nOutput Tokens: {st.session_state['output_token']}\nLatency: {latency}ms") 157 | return text, think 158 | 159 | def bedrock_claude_(params, chat_history, system_message, prompt, model_id, image_path=None, handler=None): 160 | # st.write(chat_history) 161 | chat_history_copy = chat_history[:] 162 | content = [] 163 | if image_path: 164 | if not isinstance(image_path, list): 165 | image_path = [image_path] 166 | for img in image_path: 167 | s3 = boto3.client('s3') 168 | match = re.match("s3://(.+?)/(.+)", img) 169 | image_name = os.path.basename(img) 170 | _, ext = os.path.splitext(image_name) 171 | if "jpg" in ext: 172 | ext = ".jpeg" 173 | bucket_name = match.group(1) 174 | key = match.group(2) 175 | if ".plotly" in key: 176 | bytes_image = plotly_to_png_bytes(img) 177 | ext = ".png" 178 | else: 179 | obj = s3.get_object(Bucket=bucket_name, Key=key) 180 | bytes_image = obj['Body'].read() 181 | 
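# Each attachment becomes a pair of Converse-API content blocks: a text block
# carrying the file name, followed by an image block whose "format" is derived
# from the file extension ("jpg" is normalized to "jpeg" above and .plotly
# figures are rasterized to PNG) and whose "source" holds the raw bytes.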
content.extend([{"text": image_name}, { 182 | "image": { 183 | "format": f"{ext.lower().replace('.', '')}", 184 | "source": {"bytes": bytes_image} 185 | } 186 | }]) 187 | 188 | content.append({ 189 | "text": prompt 190 | }) 191 | chat_history_copy.append({"role": "user", 192 | "content": content}) 193 | system_message = [{"text": system_message}] 194 | 195 | config = Config( 196 | read_timeout=600, # Read timeout parameter 197 | retries=dict( 198 | max_attempts=10 ## Handle retries 199 | ) 200 | ) 201 | bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name=REGION, config=config) 202 | 203 | if st.session_state['reasoning_mode'] and "3-7" in model_id : 204 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id, 205 | inferenceConfig={"maxTokens": 18000, "temperature": 1,}, 206 | system=system_message, 207 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}} 208 | ) 209 | else: 210 | max_tokens = 8000 211 | if any(keyword in [params['model']] for keyword in ["haiku-3.5"]): 212 | max_tokens = 4000 213 | response = bedrock_runtime.converse_stream(messages=chat_history_copy, modelId=model_id, 214 | inferenceConfig={"maxTokens": max_tokens, "temperature": 0.5,}, 215 | system=system_message, 216 | ) 217 | 218 | answer, think = bedrock_streemer(params, response, handler) 219 | return answer, think 220 | 221 | def _invoke_bedrock_with_retries(params, current_chat, chat_template, question, model_id, image_path, handler): 222 | max_retries = 10 223 | backoff_base = 2 224 | max_backoff = 3 # Maximum backoff time in seconds 225 | retries = 0 226 | 227 | while True: 228 | try: 229 | response, think = bedrock_claude_(params, current_chat, chat_template, question, model_id, image_path, handler) 230 | return response, think 231 | except ClientError as e: 232 | if e.response['Error']['Code'] == 'ThrottlingException': 233 | if retries < max_retries: 234 | # Throttling, exponential backoff 235 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 236 | time.sleep(sleep_time) 237 | retries += 1 238 | else: 239 | raise e 240 | elif e.response['Error']['Code'] == 'ModelStreamErrorException': 241 | if retries < max_retries: 242 | # Throttling, exponential backoff 243 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 244 | time.sleep(sleep_time) 245 | retries += 1 246 | else: 247 | raise e 248 | elif e.response['Error']['Code'] == 'EventStreamError': 249 | if retries < max_retries: 250 | # Throttling, exponential backoff 251 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 252 | time.sleep(sleep_time) 253 | retries += 1 254 | else: 255 | raise e 256 | else: 257 | # Some other API error, rethrow 258 | raise 259 | 260 | def parse_s3_uri(uri): 261 | """ 262 | Parse an S3 URI and extract the bucket name and key. 263 | 264 | :param uri: S3 URI (e.g., 's3://bucket-name/path/to/file.txt') 265 | :return: Tuple of (bucket_name, key) if valid, (None, None) if invalid 266 | """ 267 | pattern = r'^s3://([^/]+)/(.*)$' 268 | match = re.match(pattern, uri) 269 | if match: 270 | return match.groups() 271 | return (None, None) 272 | 273 | def copy_s3_object(source_uri, dest_bucket, dest_key): 274 | """ 275 | Copy an object from one S3 location to another. 
276 | 277 | :param source_uri: S3 URI of the source object 278 | :param dest_bucket: Name of the destination bucket 279 | :param dest_key: Key to be used for the destination object 280 | :return: S3 URI of the copied object on success, False if the source URI is invalid 281 | """ 282 | s3 = boto3.client('s3') 283 | 284 | # Parse the source URI 285 | source_bucket, source_key = parse_s3_uri(source_uri) 286 | if not source_bucket or not source_key: 287 | print(f"Invalid source URI: {source_uri}") 288 | return False 289 | 290 | try: 291 | # Create a copy source dictionary 292 | copy_source = { 293 | 'Bucket': source_bucket, 294 | 'Key': source_key 295 | } 296 | 297 | # Copy the object 298 | s3.copy_object(CopySource=copy_source, Bucket=dest_bucket, Key=f"{dest_key}/{source_key}") 299 | return f"s3://{dest_bucket}/{dest_key}/{source_key}" 300 | 301 | except ClientError as e: 302 | print(f"An error occurred: {e}") 303 | raise e 304 | # return False 305 | 306 | class LibraryInstallationDetected(Exception): 307 | """Exception raised when potential library installation is detected.""" 308 | pass 309 | 310 | 311 | def check_for_library_installs(code_string): 312 | # Check for pip install commands using subprocess 313 | if re.search(r'subprocess\.(?:check_call|run|Popen)\s*\(\s*\[.*pip.*install', code_string): 314 | raise LibraryInstallationDetected("Potential library installation detected: subprocess-based pip install found in code.") 315 | 316 | # Check for pip as a module 317 | if re.search(r'pip\._internal\.main\(\[.*install', code_string) or re.search(r'pip\.main\(\[.*install', code_string): 318 | raise LibraryInstallationDetected("Potential library installation detected: pip module install found in code.") 319 | 320 | keywords = ["subprocess","pip","conda","install","easy_install","setup.py","pipenv", 321 | "git+","svn+","hg+","bzr+","requirements.txt","environment.yml","apt-get","yum","brew", 322 | "ensurepip","get-pip","pkg_resources","importlib","setuptools","distutils","venv","virtualenv", 323 | "pyenv"] 324 | 325 | # Convert the code string to lowercase for case-insensitive matching 326 | code_lower = code_string.lower() 327 | 328 | # Raise on any install-related keyword so the caller's guardrail actually blocks the code 329 | for keyword in keywords: 330 | if keyword in code_lower: 331 | raise LibraryInstallationDetected(f"Potential library installation detected: '{keyword}' found in code.") 332 | 337 | 338 | 339 | def put_obj_in_s3_bucket_(docs): 340 | """Uploads a file to an S3 bucket and returns the S3 URI of the uploaded object. 341 | Args: 342 | docs (str or file-like): A local file path, an existing S3 URI, or an uploaded file-like object. 343 | Returns: 344 | str: The S3 URI of the uploaded object, in the format "s3://{bucket_name}/{file_path}". 345 | """ 346 | if isinstance(docs,str): 347 | s3_uri_pattern = r'^s3://([^/]+)/(.*?([^/]+)/?)$' 348 | if bool(re.match(s3_uri_pattern, docs)): 349 | file_uri=copy_s3_object(docs, BUCKET, S3_DOC_CACHE_PATH) 350 | return file_uri 351 | else: 352 | file_name=os.path.basename(docs) 353 | file_path=f"{S3_DOC_CACHE_PATH}/{docs}" 354 | S3.upload_file(docs, BUCKET, file_path) 355 | return f"s3://{BUCKET}/{file_path}" 356 | else: 357 | file_name=os.path.basename(docs.name) 358 | file_path=f"{S3_DOC_CACHE_PATH}/{file_name}" 359 | S3.put_object(Body=docs.read(),Bucket= BUCKET, Key=file_path) 360 | return f"s3://{BUCKET}/{file_path}" 361 | 362 | 363 | def get_large_s3_obj_from_bucket_(file, max_bytes=1000000): 364 | """Retrieves a portion of an object from an S3 bucket given its S3 URI. 
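Only the first max_bytes bytes are fetched, via an HTTP Range request, which keeps schema previews of very large files cheap.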
365 | Args: 366 | file (str): The S3 URI of the object to retrieve, in the format "s3://{bucket_name}/{key}". 367 | max_bytes (int, optional): Maximum number of bytes to read from the beginning of the file. 368 | Returns: 369 | botocore.response.StreamingBody: The retrieved S3 object. 370 | """ 371 | s3 = boto3.client('s3') 372 | match = re.match("s3://(.+?)/(.+)", file) 373 | bucket_name = match.group(1) 374 | key = match.group(2) 375 | 376 | if max_bytes: 377 | # Read specific number of bytes 378 | obj = s3.get_object(Bucket=bucket_name, Key=key, Range=f'bytes=0-{max_bytes-1}') 379 | else: 380 | # Read the whole object if max_bytes is not specified 381 | obj = s3.get_object(Bucket=bucket_name, Key=key) 382 | 383 | return obj 384 | 385 | 386 | def get_s3_obj_from_bucket_(file): 387 | """Retrieves an object from an S3 bucket given its S3 URI. 388 | Args: 389 | file (str): The S3 URI of the object to retrieve, in the format "s3://{bucket_name}/{key}". 390 | Returns: 391 | botocore.response.StreamingBody: The retrieved S3 object. 392 | """ 393 | s3 = boto3.client('s3') 394 | match = re.match("s3://(.+?)/(.+)", file) 395 | bucket_name = match.group(1) 396 | key = match.group(2) 397 | obj = s3.get_object(Bucket=bucket_name, Key=key) 398 | return obj 399 | 400 | 401 | def iter_block_items(parent): 402 | if isinstance(parent, Document): 403 | parent_elm = parent.element.body 404 | elif isinstance(parent, _Cell): 405 | parent_elm = parent._tc 406 | else: 407 | raise ValueError("something's not right") 408 | 409 | for child in parent_elm.iterchildren(): 410 | if isinstance(child, CT_P): 411 | yield Paragraph(child, parent) 412 | elif isinstance(child, CT_Tbl): 413 | yield DocxTable(child, parent) 414 | 415 | def extract_text_and_tables(docx_path): 416 | """ Extract text from docx files""" 417 | document = DocxDocument(docx_path) 418 | content = "" 419 | current_section = "" 420 | section_type = None 421 | for block in iter_block_items(document): 422 | if isinstance(block, Paragraph): 423 | if block.text: 424 | if block.style.name == 'Heading 1': 425 | # Close the current section if it exists 426 | if current_section: 427 | content += f"{current_section}\n" 428 | current_section = "" 429 | section_type = None 430 | section_type ="h1" 431 | content += f"<{section_type}>{block.text}\n" 432 | elif block.style.name== 'Heading 3': 433 | # Close the current section if it exists 434 | if current_section: 435 | content += f"{current_section}\n" 436 | current_section = "" 437 | section_type = "h3" 438 | content += f"<{section_type}>{block.text}\n" 439 | 440 | elif block.style.name == 'List Paragraph': 441 | # Add to the current list section 442 | if section_type != "list": 443 | # Close the current section if it exists 444 | if current_section: 445 | content += f"{current_section}\n" 446 | section_type = "list" 447 | current_section = "" 448 | current_section += f"{block.text}\n" 449 | elif block.style.name.startswith('toc'): 450 | # Add to the current toc section 451 | if section_type != "toc": 452 | # Close the current section if it exists 453 | if current_section: 454 | content += f"{current_section}\n" 455 | section_type = "toc" 456 | current_section = "" 457 | current_section += f"{block.text}\n" 458 | else: 459 | # Close the current section if it exists 460 | if current_section: 461 | content += f"{current_section}\n" 462 | current_section = "" 463 | section_type = None 464 | 465 | # Append the passage text without tagging 466 | content += f"{block.text}\n" 467 | 468 | elif isinstance(block, 
DocxTable): 469 | # Add the current section before the table 470 | if current_section: 471 | content += f"{current_section}\n" 472 | current_section = "" 473 | section_type = None 474 | 475 | content += "\n" 476 | for row in block.rows: 477 | row_content = [] 478 | for cell in row.cells: 479 | cell_content = [] 480 | for nested_block in iter_block_items(cell): 481 | if isinstance(nested_block, Paragraph): 482 | cell_content.append(nested_block.text) 483 | elif isinstance(nested_block, DocxTable): 484 | nested_table_content = parse_nested_table(nested_block) 485 | cell_content.append(nested_table_content) 486 | row_content.append(CSV_SEPERATOR.join(cell_content)) 487 | content += CSV_SEPERATOR.join(row_content) + "\n" 488 | content += "
\n" 489 | 490 | # Add the final section 491 | if current_section: 492 | content += f"{current_section}\n" 493 | 494 | return content 495 | 496 | def parse_nested_table(table): 497 | nested_table_content = "\n" 498 | for row in table.rows: 499 | row_content = [] 500 | for cell in row.cells: 501 | cell_content = [] 502 | for nested_block in iter_block_items(cell): 503 | if isinstance(nested_block, Paragraph): 504 | cell_content.append(nested_block.text) 505 | elif isinstance(nested_block, DocxTable): 506 | nested_table_content += parse_nested_table(nested_block) 507 | row_content.append(CSV_SEPERATOR.join(cell_content)) 508 | nested_table_content += CSV_SEPERATOR.join(row_content) + "\n" 509 | nested_table_content += "
" 510 | return nested_table_content 511 | 512 | 513 | 514 | def extract_text_from_pptx_s3(pptx_buffer): 515 | """ Extract Text from pptx files""" 516 | presentation = Presentation(pptx_buffer) 517 | text_content = [] 518 | for slide in presentation.slides: 519 | slide_text = [] 520 | for shape in slide.shapes: 521 | if hasattr(shape, 'text'): 522 | slide_text.append(shape.text) 523 | text_content.append('\n'.join(slide_text)) 524 | return '\n\n'.join(text_content) 525 | 526 | def get_s3_keys(prefix): 527 | """list all keys in an s3 path""" 528 | s3 = boto3.client('s3') 529 | keys = [] 530 | next_token = None 531 | while True: 532 | if next_token: 533 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix, ContinuationToken=next_token) 534 | else: 535 | response = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix) 536 | if "Contents" in response: 537 | for obj in response['Contents']: 538 | key = obj['Key'] 539 | name = key[len(prefix):] 540 | keys.append(name) 541 | if "NextContinuationToken" in response: 542 | next_token = response["NextContinuationToken"] 543 | else: 544 | break 545 | return keys 546 | 547 | def get_object_with_retry(bucket, key): 548 | max_retries=5 549 | retries = 0 550 | backoff_base = 2 551 | max_backoff = 3 # Maximum backoff time in seconds 552 | s3 = boto3.client('s3') 553 | while retries < max_retries: 554 | try: 555 | response = s3.get_object(Bucket=bucket, Key=key) 556 | return response 557 | except ClientError as e: 558 | error_code = e.response['Error']['Code'] 559 | if error_code == 'DecryptionFailureException': 560 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 561 | print(f"Decryption failed, retrying in {sleep_time} seconds...") 562 | time.sleep(sleep_time) 563 | retries += 1 564 | elif e.response['Error']['Code'] == 'ModelStreamErrorException': 565 | if retries < max_retries: 566 | # Throttling, exponential backoff 567 | sleep_time = min(max_backoff, backoff_base ** retries + random.uniform(0, 1)) 568 | time.sleep(sleep_time) 569 | retries += 1 570 | else: 571 | raise e 572 | 573 | def exract_pdf_text_aws(file): 574 | file_base_name=os.path.basename(file) 575 | dir_name, ext = os.path.splitext(file) 576 | # Checking if extracted doc content is in S3 577 | if USE_TEXTRACT: 578 | if [x for x in get_s3_keys(f"{TEXTRACT_RESULT_CACHE_PATH}/") if file_base_name in x]: 579 | response = get_object_with_retry(BUCKET, f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt") 580 | text = response['Body'].read().decode() 581 | return text 582 | else: 583 | extractor = Textractor() 584 | # Asynchronous call, you will experience some wait time. Try caching results for better experience 585 | if "pdf" in ext: 586 | print("Asynchronous call, you may experience some wait time.") 587 | document = extractor.start_document_analysis( 588 | file_source=file, 589 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES], 590 | save_image=False, 591 | s3_output_path=f"s3://{BUCKET}/textract_output/" 592 | ) 593 | #Synchronous call 594 | else: 595 | document = extractor.analyze_document( 596 | file_source=file, 597 | features=[TextractFeatures.LAYOUT,TextractFeatures.TABLES], 598 | save_image=False, 599 | ) 600 | config = TextLinearizationConfig( 601 | hide_figure_layout=False, 602 | hide_header_layout=False, 603 | table_prefix="", 604 | table_suffix="
", 605 | ) 606 | # Upload extracted content to s3 607 | S3.put_object(Body=document.get_text(config=config), Bucket=BUCKET, Key=f"{TEXTRACT_RESULT_CACHE_PATH}/{file_base_name}.txt") 608 | return document.get_text(config=config) 609 | else: 610 | s3=boto3.resource("s3") 611 | match = re.match("s3://(.+?)/(.+)", file) 612 | if match: 613 | bucket_name = match.group(1) 614 | key = match.group(2) 615 | if "pdf" in ext: 616 | pdf_bytes = io.BytesIO() 617 | s3.Bucket(bucket_name).download_fileobj(key, pdf_bytes) 618 | # Read the PDF from the BytesIO object 619 | pdf_bytes.seek(0) 620 | # Create a PDF reader object 621 | pdf_reader = PyPDF2.PdfReader(pdf_bytes) 622 | # Get the number of pages in the PDF 623 | num_pages = len(pdf_reader.pages) 624 | # Extract text from each page 625 | text = '' 626 | for page_num in range(num_pages): 627 | page = pdf_reader.pages[page_num] 628 | text += page.extract_text() 629 | else: 630 | img_bytes = io.BytesIO() 631 | s3.Bucket(bucket_name).download_fileobj(key, img_bytes) 632 | img_bytes.seek(0) 633 | image_stream = io.BytesIO(image_bytes) 634 | image = Image.open(image_stream) 635 | text = pytesseract.image_to_string(image) 636 | return text 637 | 638 | def detect_encoding(s3_uri): 639 | """detect csv encoding""" 640 | s3 = boto3.client('s3') 641 | match = re.match("s3://(.+?)/(.+)", s3_uri) 642 | if match: 643 | bucket_name = match.group(1) 644 | key = match.group(2) 645 | response = get_large_s3_obj_from_bucket_(s3_uri) 646 | content = response['Body'].read() 647 | result = chardet.detect(content) 648 | df = content.decode(result['encoding']) 649 | return result['encoding'], df 650 | 651 | class InvalidContentError(Exception): 652 | pass 653 | 654 | def parse_csv_from_s3(s3_uri): 655 | """Here we are only loading the first 3 rows to the model. 3 rows is sufficient for the model to figure out the schema""" 656 | 657 | try: 658 | # Detect the file encoding using chardet 659 | encoding, content = detect_encoding(s3_uri) 660 | # Use StringIO to create a file-like object 661 | csv_file = io.StringIO(content) 662 | # Read the CSV file using pandas 663 | df = pd.read_csv(csv_file, delimiter=None, engine='python').iloc[:3] 664 | data_types = df.dtypes 665 | return f"{df.to_csv(index=False)}\n\nHere are the data types for each columns:\n{data_types}" 666 | 667 | except Exception as e: 668 | raise InvalidContentError(f"Error: {e}") 669 | 670 | 671 | def strip_newline(cell): 672 | return str(cell).strip() 673 | 674 | def table_parser_openpyxl(file): 675 | """ 676 | Here we are only loading the first 20 rows to the model and we are not massaging the dataset by merging empty cells. 
5 rows is sufficient for the model to figure out the schema 677 | """ 678 | # Read from S3 679 | s3 = boto3.client('s3', region_name=REGION) 680 | match = re.match("s3://(.+?)/(.+)", file) 681 | if match: 682 | bucket_name = match.group(1) 683 | key = match.group(2) 684 | obj = s3.get_object(Bucket=bucket_name, Key=key) 685 | # Read Excel file from S3 into a buffer 686 | xlsx_buffer = io.BytesIO(obj['Body'].read()) 687 | # Load workbook 688 | wb=pd.read_excel(xlsx_buffer,sheet_name=None, header=None) 689 | all_sheets_string="" 690 | # Iterate over each sheet in the workbook 691 | for sheet_name, sheet_data in wb.items(): 692 | df = pd.DataFrame(sheet_data) 693 | # Convert to string and tag by sheet name 694 | all_sheets_string+=f'Here is a data preview of this sheet (first 5 rows):<{sheet_name}>\n{df.iloc[:5].to_csv(index=False, header=False)}\n\n' 695 | return all_sheets_string 696 | else: 697 | raise Exception(f"{file} not formatted as an S3 path") 698 | 699 | def calamaine_excel_engine(file): 700 | """ 701 | Here we are only loading the first 20 rows to the model and we are not massaging the dataset by merging empty cells. 20 rows is sufficient for the model to figure out the schema 702 | """ 703 | # # Read from S3 704 | s3 = boto3.client('s3',region_name=REGION) 705 | match = re.match("s3://(.+?)/(.+)", file) 706 | if match: 707 | bucket_name = match.group(1) 708 | key = match.group(2) 709 | obj = s3.get_object(Bucket=bucket_name, Key=key) 710 | # Read Excel file from S3 into a buffer 711 | xlsx_buffer = io.BytesIO(obj['Body'].read()) 712 | xlsx_buffer.seek(0) 713 | all_sheets_string = "" 714 | # Load the Excel file 715 | workbook = CalamineWorkbook.from_filelike(xlsx_buffer) 716 | # Iterate over each sheet in the workbook 717 | for sheet_name in workbook.sheet_names: 718 | # Get the sheet by name 719 | sheet = workbook.get_sheet_by_name(sheet_name) 720 | df = pd.DataFrame(sheet.to_python(skip_empty_area=False)) 721 | df = df.map(strip_newline) 722 | all_sheets_string += f'Here is a data preview of this sheet (first 5 rows):\n\n<{sheet_name}>\n{df.iloc[:5].to_csv(index=False, header=0)}\n\n' 723 | return all_sheets_string 724 | else: 725 | raise Exception(f"{file} not formatted as an S3 path") 726 | 727 | def table_parser_utills(file): 728 | try: 729 | response = table_parser_openpyxl(file) 730 | if response: 731 | return response 732 | else: 733 | return calamaine_excel_engine(file) 734 | except Exception: 735 | try: 736 | return calamaine_excel_engine(file) 737 | except Exception as e: 738 | raise Exception(str(e)) 739 | 740 | def process_document_types(file): 741 | """Handle various document format""" 742 | dir_name, ext = os.path.splitext(file) 743 | if ".csv" == ext.lower(): 744 | content = parse_csv_from_s3(file) 745 | elif ext.lower() in [".txt", ".py"]: 746 | obj = get_s3_obj_from_bucket_(file) 747 | content = obj['Body'].read() 748 | elif ext.lower() in [".xlsx", ".xls"]: 749 | content = table_parser_utills(file) 750 | elif ext.lower() in [".pdf", ".png", ".jpg", ".tif", ".jpeg"]: 751 | content = exract_pdf_text_aws(file) 752 | elif ".json" == ext.lower(): 753 | obj = get_s3_obj_from_bucket_(file) 754 | content = json.loads(obj['Body'].read()) 755 | elif ".docx" == ext.lower(): 756 | obj = get_s3_obj_from_bucket_(file) 757 | content = obj['Body'].read() 758 | docx_buffer = io.BytesIO(content) 759 | content = extract_text_and_tables(docx_buffer) 760 | elif ".pptx" == ext.lower(): 761 | obj = get_s3_obj_from_bucket_(file) 762 | content = obj['Body'].read() 763 | docx_buffer = 
io.BytesIO(content) 764 | content = extract_text_from_pptx_s3(docx_buffer) 765 | 766 | # Implement any other file extension logic 767 | return content 768 | 769 | def plotly_to_png_bytes(s3_uri): 770 | """ 771 | Read a .plotly file from S3 given an S3 URI, convert it to a PNG image, and return the image as bytes. 772 | 773 | :param s3_uri: S3 URI of the .plotly file (e.g., 's3://bucket-name/path/to/file.plotly') 774 | :return: PNG image as bytes 775 | """ 776 | # Parse S3 URI 777 | parsed_uri = urlparse(s3_uri) 778 | bucket_name = parsed_uri.netloc 779 | file_key = parsed_uri.path.lstrip('/') 780 | 781 | # Initialize S3 client 782 | s3_client = boto3.client('s3') 783 | 784 | try: 785 | # Read the .plotly file from S3 786 | response = s3_client.get_object(Bucket=bucket_name, Key=file_key) 787 | plotly_data = json.loads(response['Body'].read().decode('utf-8')) 788 | 789 | # Create a Figure object from the plotly data 790 | fig = go.Figure(data=plotly_data['data'], layout=plotly_data.get('layout', {})) 791 | 792 | # Convert the figure to PNG bytes 793 | img_bytes = fig.to_image(format="png") 794 | 795 | return img_bytes 796 | 797 | except Exception as e: 798 | print(f"An error occurred: {str(e)}") 799 | return None 800 | 801 | 802 | def get_chat_history_db(params, cutoff,vision_model): 803 | """ 804 | Load chat history and attachments from DynamoDB and S3 accordingly 805 | 806 | parameters: 807 | params (dict): Application parameters 808 | cutoff (int): Custoff of Chat history to be loaded 809 | vision_model (bool): Boolean if Claude 3 model is used 810 | """ 811 | current_chat, chat_hist = [], [] 812 | if params['chat_histories'] and cutoff != 0: 813 | chat_hist = params['chat_histories'][-cutoff:] 814 | for ids, d in enumerate(chat_hist): 815 | if d['image'] and vision_model and LOAD_DOC_IN_ALL_CHAT_CONVO: 816 | content = [] 817 | for img in d['image']: 818 | s3 = boto3.client('s3') 819 | match = re.match("s3://(.+?)/(.+)", img) 820 | image_name = os.path.basename(img) 821 | _, ext = os.path.splitext(image_name) 822 | if "jpg" in ext: 823 | ext = ".jpeg" 824 | if match: 825 | bucket_name = match.group(1) 826 | key = match.group(2) 827 | if ".plotly" in key: 828 | bytes_image = plotly_to_png_bytes(img) 829 | ext = ".png" 830 | else: 831 | obj = s3.get_object(Bucket=bucket_name, Key=key) 832 | bytes_image = obj['Body'].read() 833 | content.extend( 834 | [ 835 | {"text": image_name}, 836 | {'image': 837 | { 838 | 'format': ext.lower().replace('.', ''), 839 | 'source': {'bytes': bytes_image} 840 | } 841 | } 842 | ] 843 | ) 844 | content.extend([{"text": d['user']}]) 845 | if 'tool_result_id' in d and d['tool_result_id']: 846 | user = [{'toolResult': {'toolUseId': d['tool_result_id'], 847 | 'content': content}}] 848 | current_chat.append({'role': 'user', 'content': user}) 849 | else: 850 | current_chat.append({'role': 'user', 'content': content}) 851 | elif d['document'] and LOAD_DOC_IN_ALL_CHAT_CONVO: 852 | doc = 'Here is a document showing sample rows:\n' 853 | for docs in d['document']: 854 | uploads = process_document_types(docs) 855 | doc_name = os.path.basename(docs) 856 | doc += f"<{doc_name}>\n{uploads}\n\n" 857 | if not vision_model and d["image"]: 858 | for docs in d['image']: 859 | uploads = process_document_types(docs) 860 | doc_name = os.path.basename(docs) 861 | doc += f"<{doc_name}>\n{uploads}\n\n" 862 | current_chat.append({'role': 'user', 'content': [{"text": doc + d['user']}]}) 863 | # do not have a tool return for the document section because the tool does not provide 
documents only images 864 | else: 865 | if 'tool_result_id' in d and d['tool_result_id']: 866 | user = [{'toolResult': {'toolUseId': d['tool_result_id'], 867 | 'content': [{'text': d['user']}]}}] 868 | current_chat.append({'role': 'user', 'content': user}) 869 | else: 870 | current_chat.append({'role': 'user', 'content': [{"text": d['user']}]}) 871 | 872 | if 'tool_use_id' in d and d['tool_use_id']: 873 | assistant = [{'toolUse': {'toolUseId': d['tool_use_id'], 874 | 'name': d['tool_name'], 875 | 'input': {'code': d['assistant'], 876 | "dataset_name": d['tool_params']['ds'], 877 | "python_packages": d['tool_params']['pp']}}} 878 | ] 879 | 880 | current_chat.append({'role': 'assistant', 'content': assistant}) 881 | else: 882 | current_chat.append({'role': 'assistant', 'content': [{"text": d['assistant']}]}) 883 | return current_chat, chat_hist 884 | 885 | def stream_messages(params, 886 | bedrock_client, 887 | model_id, 888 | messages, 889 | tool_config, 890 | system, 891 | temperature, 892 | handler): 893 | """ 894 | Sends a message to a model and streams the response. 895 | Args: 896 | bedrock_client: The Boto3 Bedrock runtime client. 897 | model_id (str): The model ID to use. 898 | messages (JSON) : The messages to send to the model. 899 | tool_config : Tool Information to send to the model. 900 | 901 | Returns: 902 | stop_reason (str): The reason why the model stopped generating text. 903 | message (JSON): The message that the model generated. 904 | 905 | """ 906 | 907 | if st.session_state['reasoning_mode']: 908 | response = bedrock_client.converse_stream( 909 | modelId=model_id, 910 | messages=messages, 911 | inferenceConfig={"maxTokens": 18000, "temperature": 1}, 912 | toolConfig=tool_config, 913 | system=system, 914 | additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 10000}} 915 | ) 916 | else: 917 | response = bedrock_client.converse_stream( 918 | modelId=model_id, 919 | messages=messages, 920 | inferenceConfig={"maxTokens": 4000, "temperature": temperature}, 921 | toolConfig=tool_config, 922 | system=system 923 | ) 924 | 925 | stop_reason = "" 926 | message = {} 927 | content = [] 928 | message['content'] = content 929 | text = '' 930 | tool_use = {} 931 | think = '' 932 | signature = '' 933 | 934 | for chunk in response['stream']: 935 | if 'messageStart' in chunk: 936 | message['role'] = chunk['messageStart']['role'] 937 | elif 'contentBlockStart' in chunk: 938 | tool = chunk['contentBlockStart']['start']['toolUse'] 939 | tool_use['toolUseId'] = tool['toolUseId'] 940 | tool_use['name'] = tool['name'] 941 | elif 'contentBlockDelta' in chunk: 942 | delta = chunk['contentBlockDelta']['delta'] 943 | if 'toolUse' in delta: 944 | if 'input' not in tool_use: 945 | tool_use['input'] = '' 946 | tool_use['input'] += delta['toolUse']['input'] 947 | elif 'text' in delta: 948 | text += delta['text'] 949 | if handler: 950 | handler.markdown(text.replace("$", "\\$"), unsafe_allow_html=True) 951 | elif 'reasoningContent' in delta: 952 | if "text" in delta['reasoningContent']: 953 | think += delta['reasoningContent']['text'] 954 | handler.markdown('**MODEL REASONING**\n\n' + think.replace("$", "\\$"), unsafe_allow_html=True) 955 | if "signature" in delta['reasoningContent']: 956 | signature = delta['reasoningContent']['signature'] 957 | 958 | elif 'contentBlockStop' in chunk: 959 | if 'input' in tool_use: 960 | tool_use['input'] = json.loads(tool_use['input']) 961 | content.append({'toolUse': tool_use}) 962 | else: 963 | content.append({'text': text}) 964 | text = '' 
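# Tool arguments arrive as partial JSON strings inside contentBlockDelta events,
# so they are buffered in tool_use['input'] and parsed with json.loads only when
# contentBlockStop closes the block; plain text deltas are rendered incrementally
# instead. The remaining branches read the stopReason from messageStop and the
# token usage / latency from the final metadata event, which feeds the running
# cost estimate kept in st.session_state.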
965 | 966 | elif 'messageStop' in chunk: 967 | stop_reason = chunk['messageStop']['stopReason'] 968 | elif "metadata" in chunk: 969 | st.session_state['input_token'] = chunk['metadata']['usage']["inputTokens"] 970 | st.session_state['output_token'] = chunk['metadata']['usage']["outputTokens"] 971 | latency = chunk['metadata']['metrics']["latencyMs"] 972 | pricing = st.session_state['input_token'] * pricing_file[f"{params['model']}"]["input"] + st.session_state['output_token'] * pricing_file[f"{params['model']}"]["output"] 973 | st.session_state['cost']+=pricing 974 | 975 | if tool_use: 976 | try: 977 | handler.markdown(f"{text}\n```python\n{message['content'][1]['toolUse']['input']['code']}", unsafe_allow_html=True ) 978 | except Exception: 979 | if len(message['content']) == 3: 980 | handler.markdown(f"{text}\n```python\n{message['content'][2]['toolUse']['input']['code']}", unsafe_allow_html=True ) 981 | else: 982 | handler.markdown(f"{text}\n```python\n{message['content'][0]['toolUse']['input']['code']}", unsafe_allow_html=True ) 983 | return stop_reason, message, st.session_state['input_token'], st.session_state['output_token'], think 984 | 985 | 986 | def self_crtique(params, code, error, dataset, handler=None): 987 | 988 | import re 989 | if params["engine"] == "pyspark": 990 | with open("prompt/pyspark_debug_prompt.txt", "r") as fo: 991 | prompt_template = fo.read() 992 | values = {"dataset": dataset, "code": code, "error": error} 993 | prompt = prompt_template.format(**values) 994 | system = "You are an expert pyspark debugger for Amazon Athena PySpark Runtime" 995 | else: 996 | with open("prompt/python_debug_prompt.txt", "r") as fo: 997 | prompt_template = fo.read() 998 | values = {"dataset": dataset, "code": code, "error": error} 999 | prompt = prompt_template.format(**values) 1000 | system = "You are an expert python debugger." 1001 | 1002 | model_id = 'us.' 
+ model_info[params['model']] 1003 | fixed_code, think = _invoke_bedrock_with_retries(params, [], system, prompt, model_id, [], handler) 1004 | code_pattern = r'(.*?)' 1005 | match = re.search(code_pattern, fixed_code, re.DOTALL) 1006 | code = match.group(1) 1007 | if handler: 1008 | handler.markdown(f"```python\n{code}", unsafe_allow_html=True) 1009 | lib_pattern = r'(.*?)' 1010 | match = re.search(lib_pattern, fixed_code, re.DOTALL) 1011 | if match: 1012 | libs = match.group(1) 1013 | else: 1014 | libs = '' 1015 | return code, libs 1016 | 1017 | def process_files(files): 1018 | result_string = "" 1019 | errors = [] 1020 | future_proxy_mapping = {} 1021 | futures = [] 1022 | 1023 | with concurrent.futures.ProcessPoolExecutor() as executor: 1024 | # Partial function to pass the process_document_types function 1025 | func = partial(process_document_types) 1026 | for file in files: 1027 | future = executor.submit(func, file) 1028 | future_proxy_mapping[future] = file 1029 | futures.append(future) 1030 | 1031 | # Collect the results and handle exceptions 1032 | for future in concurrent.futures.as_completed(futures): 1033 | file_url = future_proxy_mapping[future] 1034 | try: 1035 | result = future.result() 1036 | doc_name = os.path.basename(file_url) 1037 | result_string += f"\n{result}\n\n" 1038 | except Exception as e: 1039 | # Get the original function arguments from the Future object 1040 | error = {'file': file_url, 'error': str(e)} 1041 | errors.append(error) 1042 | 1043 | return errors, result_string 1044 | 1045 | def invoke_lambda(function_name, payload): 1046 | config = Config( 1047 | connect_timeout=600, 1048 | read_timeout=600, # Read timeout parameter 1049 | retries=dict( 1050 | max_attempts=0, # Handle retries 1051 | total_max_attempts=1 1052 | ) 1053 | ) 1054 | lambda_client = boto3.client('lambda', config=config) 1055 | response = lambda_client.invoke( 1056 | FunctionName=function_name, 1057 | InvocationType='RequestResponse', 1058 | Payload=json.dumps(payload) 1059 | ) 1060 | return json.loads(response['Payload'].read().decode('utf-8')) 1061 | 1062 | class CodeExecutionError(Exception): 1063 | pass 1064 | 1065 | def load_json_data(output_content): 1066 | try: 1067 | json_data = json.loads(output_content.replace("'", '"')) 1068 | return json_data 1069 | except json.JSONDecodeError: 1070 | try: 1071 | parsed_data = ast.literal_eval(output_content) 1072 | json_data = json.loads(json.dumps(parsed_data)) 1073 | return json_data 1074 | except (SyntaxError, ValueError): 1075 | try: 1076 | # Replace escaped single quotes with double quotes, but handle nested quotes carefully 1077 | modified_content = output_content 1078 | # First, ensure the outer structure is properly quoted 1079 | if modified_content.startswith("'") and modified_content.endswith("'"): 1080 | modified_content = modified_content[1:-1] # Remove outer quotes 1081 | # Replace escaped single quotes with double quotes 1082 | modified_content = modified_content.replace("\'", '"') 1083 | json_data = json.loads(modified_content) 1084 | return json_data 1085 | except json.JSONDecodeError as e: 1086 | print(f"Error parsing JSON using all methods: {e}") 1087 | raise e 1088 | 1089 | def function_caller_claude_(params, handler=None): 1090 | """ 1091 | Entrypoint for streaming tool use example. 
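    Uploads any attached files to S3, builds the Converse request (including image
    blocks for vision-capable models), and streams the model response with the
    engine-specific tool configuration. When the model stops with a tool_use
    reason, the generated code is executed on AWS Lambda (Python engine) or
    Amazon Athena (PySpark engine) with up to five self-correction retries, and
    the exchange is persisted to DynamoDB or the local chat-history file.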
1092 | """ 1093 | current_chat, chat_hist = get_chat_history_db(params, CHAT_HISTORY_LENGTH, True) 1094 | if current_chat and 'toolResult' in current_chat[0]['content'][0]: 1095 | if 'toolUseId' in current_chat[0]['content'][0]['toolResult']: 1096 | del current_chat[0:2] 1097 | 1098 | vision_model = True 1099 | model_id = 'us.' + model_info[params['model']] 1100 | if any(keyword in [params['model']] for keyword in TEXT_ONLY_MODELS): 1101 | vision_model = False 1102 | 1103 | full_doc_path = [] 1104 | image_path = [] 1105 | for ids, docs in enumerate(params['upload_doc']): 1106 | file_name = docs.name 1107 | _, extensions = os.path.splitext(file_name) 1108 | s3_file_name = put_obj_in_s3_bucket_(docs) 1109 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model: 1110 | image_path.append(s3_file_name) 1111 | continue 1112 | full_doc_path.append(s3_file_name) 1113 | 1114 | if params['s3_objects']: 1115 | for ids, docs in enumerate(params['s3_objects']): 1116 | file_name = docs 1117 | _, extensions = os.path.splitext(file_name) 1118 | docs = f"s3://{INPUT_BUCKET}/{INPUT_S3_PATH}/{docs}" 1119 | full_doc_path.append(docs) 1120 | if extensions.lower() in [".jpg", ".jpeg", ".png", ".gif", ".webp"] and vision_model: 1121 | image_path.append(docs) 1122 | continue 1123 | 1124 | errors, result_string = process_files(full_doc_path) 1125 | if errors: 1126 | st.error(errors) 1127 | question = params['question'] 1128 | if result_string and ('.csv' in result_string or '.parquet' in result_string or '.xlsx' in result_string): 1129 | input_text = f"Here is a subset (first few rows) of the data from each dataset tagged by each file name:\n{result_string}\n{question}" 1130 | elif result_string and not ('.csv' in result_string or '.xlsx' in result_string or '.parquet' in result_string): 1131 | doc = 'I have provided documents and/or images tagged by their file names:\n' 1132 | input_text = f"{doc}{result_string}\n{question}" 1133 | else: 1134 | input_text = question 1135 | bedrock_client = boto3.client(service_name='bedrock-runtime', region_name=REGION, config=config) 1136 | # Create the initial message from the user input. 1137 | content = [] 1138 | if image_path: 1139 | for img in image_path: 1140 | s3 = boto3.client('s3') 1141 | match = re.match("s3://(.+?)/(.+)", img) 1142 | image_name = os.path.basename(img) 1143 | _, ext = os.path.splitext(image_name) 1144 | if "jpg" in ext: 1145 | ext = ".jpeg" 1146 | bucket_name = match.group(1) 1147 | key = match.group(2) 1148 | if ".plotly" in key: 1149 | bytes_image = plotly_to_png_bytes(img) 1150 | ext = ".png" 1151 | else: 1152 | obj = s3.get_object(Bucket=bucket_name, Key=key) 1153 | bytes_image = obj['Body'].read() 1154 | content.extend([{"text": image_name}, { 1155 | "image": { 1156 | "format": f"{ext.lower().replace('.', '')}", 1157 | "source": {"bytes": bytes_image} 1158 | } 1159 | }]) 1160 | content.append({"text": input_text}) 1161 | messages = [{ 1162 | "role": "user", 1163 | "content": content 1164 | }] 1165 | # Define the tool and prompt template to send to the model. 
1166 | if params["engine"] == "pyspark": 1167 | with open("prompt/pyspark_tool_template.json", "r") as f: 1168 | tool_config = json.load(f) 1169 | with open("prompt/pyspark_tool_prompt.txt", "r") as fo: 1170 | description = fo.read() 1171 | with open("prompt/pyspark_tool_system.txt", "r") as fod: 1172 | system_prompt = fod.read() 1173 | tool_config['tools'][0]['toolSpec']['inputSchema']['json']['properties']['code']['description'] = description 1174 | else: 1175 | with open("prompt/python_tool_template.json", "r") as f: 1176 | tool_config = json.load(f) 1177 | with open("prompt/python_tool_prompt.txt", "r") as fo: 1178 | description = fo.read() 1179 | with open("prompt/python_tool_system.txt", "r") as fod: 1180 | system_prompt = fod.read() 1181 | tool_config['tools'][0]['toolSpec']['inputSchema']['json']['properties']['code']['description'] = description 1182 | 1183 | system = [ 1184 | { 1185 | 'text': system_prompt 1186 | } 1187 | ] 1188 | current_chat.extend(messages) 1189 | # Send the message and get the tool use request from response. 1190 | stop_reason, message, input_tokens, output_tokens, think = stream_messages( 1191 | params, bedrock_client, model_id, current_chat, tool_config, system, 0.1, handler) 1192 | messages.append(message) 1193 | if stop_reason != "tool_use": 1194 | chat_history = { 1195 | "user": question, 1196 | "assistant": message['content'][0]['text'] if message['content'][0]['text'] else message['content'][1]['text'], 1197 | "image": image_path, 1198 | "document": full_doc_path, 1199 | "thinking": think, 1200 | "modelID": model_id, 1201 | "time": str(time.time()), 1202 | "input_token": round(input_tokens), 1203 | "output_token": round(output_tokens) 1204 | } 1205 | if DYNAMODB_TABLE: 1206 | put_db(params, chat_history) 1207 | # use local disk for storage 1208 | else: 1209 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params) 1210 | 1211 | return message['content'][0]['text'], "", "", "", full_doc_path, stop_reason, "" 1212 | elif stop_reason == "tool_use": 1213 | 1214 | self_correction_retry = 5 1215 | for content in message['content']: 1216 | if 'toolUse' in content: 1217 | tool = content['toolUse'] 1218 | if tool['name'] == tool_config['tools'][0]['toolSpec']['name']: 1219 | # Preliminary Guardrail to check that the code does not have any install commands 1220 | check_for_library_installs(tool['input']['code']) 1221 | i = 0 1222 | while i < self_correction_retry: 1223 | try: 1224 | payload = { 1225 | "python_packages": tool['input']['python_packages'], 1226 | "code": tool['input']['code'], 1227 | "dataset_name": tool['input']['dataset_name'], 1228 | "iterate": i, 1229 | "bucket": BUCKET, 1230 | "file_path": S3_DOC_CACHE_PATH 1231 | } 1232 | if params["engine"] == "pyspark": 1233 | tool_execution_response = send_athena_job(payload, ATHENA_WORKGROUP_NAME) # Execute generated code in Amazon Athena 1234 | if "error" not in tool_execution_response: 1235 | pattern = r'(.*?)' 1236 | match = re.search(pattern, tool_execution_response['stdout'], re.DOTALL) 1237 | if match: 1238 | output_content = match.group(1) 1239 | # Parse the extracted content to JSON 1240 | try: 1241 | json_data = load_json_data(output_content) 1242 | image_holder = json_data.get('plotly', []) 1243 | results = json_data['result'] 1244 | plotly_obj = json_data.get('plotly', []) 1245 | except json.JSONDecodeError as e: 1246 | print(f"Error parsing JSON: {e}") 1247 | raise e 1248 | else: 1249 | print("No tags found in the input string.") 1250 | break 1251 | else: 1252 | raise 
Exception(tool_execution_response) 1253 | else: 1254 | # payload['code'] = tool['input']['code'] 1255 | tool_execution_response = invoke_lambda(LAMBDA_FUNC, payload) # Execute generated code in AWS Lambda 1256 | if tool_execution_response.get('statusCode') == 200: 1257 | json_data = json.loads(tool_execution_response['body']) 1258 | image_holder = json_data.get('image_dict', []) 1259 | results = json_data['result'] 1260 | plotly_obj = json_data.get('plotly', []) 1261 | break 1262 | else: 1263 | raise Exception(tool_execution_response.get('body')) 1264 | except Exception as err: 1265 | print(f"ERROR: {err}") 1266 | with st.spinner(f'**Self Correction {i+1}**'): 1267 | tool['input']['code'], tool['input']['python_packages'] = self_crtique(params, tool['input']['code'], err, result_string, handler) 1268 | i += 1 1269 | if i == self_correction_retry: 1270 | raise CodeExecutionError("Request Failed due to exceed on self-correction trials") 1271 | 1272 | if 'text' in message['content'][0] and message['content'][0]['text']: 1273 | code = tool['input']['code'] 1274 | ds = message['content'][1]['toolUse']['input']['dataset_name'] 1275 | pp = message['content'][1]['toolUse']['input']['python_packages'] 1276 | tool_ids = message['content'][1]['toolUse']['toolUseId'] 1277 | tool_name = message['content'][1]['toolUse']['name'] 1278 | else: 1279 | try: 1280 | code = tool['input']['code'] 1281 | ds = message['content'][0]['toolUse']['input']['dataset_name'] 1282 | pp = message['content'][0]['toolUse']['input']['python_packages'] 1283 | tool_ids = message['content'][0]['toolUse']['toolUseId'] 1284 | tool_name = message['content'][0]['toolUse']['name'] 1285 | except Exception: 1286 | code = tool['input']['code'] 1287 | ds = message['content'][-1]['toolUse']['input']['dataset_name'] 1288 | pp = message['content'][-1]['toolUse']['input']['python_packages'] 1289 | tool_ids = message['content'][-1]['toolUse']['toolUseId'] 1290 | tool_name = message['content'][-1]['toolUse']['name'] 1291 | 1292 | chat_history = {"user": question, 1293 | "assistant": code, 1294 | "image": [], 1295 | "document": full_doc_path, 1296 | "modelID": model_id, 1297 | "time": str(time.time()), 1298 | "input_token": round(input_tokens), 1299 | "output_token": round(output_tokens), 1300 | "tool_use_id": tool_ids, 1301 | "tool_name": tool_name, 1302 | "tool_params": {"ds": ds, "pp": pp} 1303 | } 1304 | 1305 | if DYNAMODB_TABLE: 1306 | put_db(params, chat_history) 1307 | # use local disk for storage 1308 | else: 1309 | save_chat_local(LOCAL_CHAT_FILE_NAME, [chat_history], params) 1310 | return "", tool, results, image_holder, full_doc_path, stop_reason, plotly_obj 1311 | --------------------------------------------------------------------------------
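A minimal sketch of the Bedrock Converse tool-use round trip that utils/function_calling_utils.py implements in streaming form: declare a tool spec, let the model stop with "tool_use", run the returned code, and feed the output back as a toolResult. The model id, region, tool name and the in-process exec() runner below are illustrative assumptions only; the application itself streams with converse_stream, loads its tool spec from prompt/*_tool_template.json, and executes generated code on AWS Lambda or Amazon Athena with self-correction retries.

import boto3

MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"   # assumption: any Converse model with tool-use support
bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")  # assumption: region

tool_config = {
    "tools": [{
        "toolSpec": {
            "name": "run_python",   # hypothetical name; the app loads its real spec from prompt/*_tool_template.json
            "description": "Run Python code and return what it prints.",
            "inputSchema": {"json": {
                "type": "object",
                "properties": {"code": {"type": "string", "description": "Python code to execute."}},
                "required": ["code"],
            }},
        }
    }]
}

messages = [{"role": "user", "content": [{"text": "Use the tool to compute 12 * 34 and print the result."}]}]
response = bedrock.converse(modelId=MODEL_ID, messages=messages, toolConfig=tool_config)

if response["stopReason"] == "tool_use":
    # Keep the assistant turn (which contains the toolUse block) in the history.
    messages.append(response["output"]["message"])
    tool_use = next(b["toolUse"] for b in response["output"]["message"]["content"] if "toolUse" in b)

    # Toy stand-in executor: the application ships this code to AWS Lambda or an
    # Athena Spark session instead of exec()-ing it in-process.
    import io, contextlib
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        exec(tool_use["input"]["code"], {})
    tool_output = buffer.getvalue() or "executed"

    # Return the tool result under the same toolUseId and ask for the final answer.
    messages.append({"role": "user", "content": [{
        "toolResult": {
            "toolUseId": tool_use["toolUseId"],
            "content": [{"text": tool_output}],
            "status": "success",
        }
    }]})
    final = bedrock.converse(modelId=MODEL_ID, messages=messages, toolConfig=tool_config)
    print(final["output"]["message"]["content"][0]["text"])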