├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── README_CN.md ├── README_JP.md ├── application ├── .dockerignore ├── .env.cntemplate ├── .env.template ├── .streamlit │ └── config.toml ├── Dockerfile ├── Dockerfile-api ├── Index.py ├── api │ ├── __init__.py │ ├── enum.py │ ├── exception_handler.py │ ├── main.py │ ├── schemas.py │ └── service.py ├── config_files │ └── stauth_config.yaml ├── docker-compose-build.sh ├── docker-compose.yml ├── generate_streamlit_password.py ├── initial_data │ └── README.md ├── main.py ├── nlq │ ├── __init__.py │ ├── business │ │ ├── __init__.py │ │ ├── connection.py │ │ ├── datasource │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── clickhouse.py │ │ │ ├── default.py │ │ │ ├── factory.py │ │ │ └── mysql.py │ │ ├── log_feedback.py │ │ ├── log_store.py │ │ ├── login_user.py │ │ ├── model.py │ │ ├── nlq_chain.py │ │ ├── profile.py │ │ ├── suggested_question.py │ │ ├── user_profile.py │ │ └── vector_store.py │ ├── core │ │ ├── __init__.py │ │ ├── chat_context.py │ │ ├── state.py │ │ └── state_machine.py │ └── data_access │ │ ├── __init__.py │ │ ├── database.py │ │ ├── dynamo_connection.py │ │ ├── dynamo_model.py │ │ ├── dynamo_profile.py │ │ ├── dynamo_query_log.py │ │ ├── dynamo_suggested_question.py │ │ ├── dynamo_user_profile.py │ │ ├── opensearch.py │ │ └── opensearch_query_log.py ├── opensearch_deploy.py ├── pages │ ├── 10_📚_User_Authorization.py │ ├── 1_🌍_Generative_BI_Playground.py │ ├── 2_🪙_Data_Connection_Management.py │ ├── 3_🪙_Data_Profile_Management.py │ ├── 4_🪙_Schema_Description_Management.py │ ├── 5_🪙_Prompt_Management.py │ ├── 6_📚_Index_Management.py │ ├── 7_📚_Entity_Management.py │ ├── 8_📚_Agent_Cot_Management.py │ ├── 9_🪙_SageMaker_Model_Management.py │ └── mainpage.py ├── requirements-api.txt ├── requirements.txt ├── static │ ├── RESTful.html │ ├── WebSocket.html │ └── components │ │ ├── JsonViewer │ │ ├── json-viewer.css │ │ └── json-viewer.js │ │ ├── bootstrap │ │ ├── 
bootstrap.min.css │ │ ├── bootstrap.min.css.map │ │ ├── bootstrap.min.js │ │ └── bootstrap.min.js.map │ │ ├── jquery-3.7.1.min.js │ │ └── marked.min.js ├── tests │ └── unit_tests │ │ └── test_row_level_security.py └── utils │ ├── __init__.py │ ├── apis.py │ ├── auth.py │ ├── constant.py │ ├── database.py │ ├── domain.py │ ├── env_var.py │ ├── llm.py │ ├── logging.py │ ├── navigation.py │ ├── opensearch.py │ ├── prompt.py │ ├── prompts │ ├── __init__.py │ ├── check_prompt.py │ ├── generate_prompt.py │ ├── guidance_prompt.py │ └── table_prompt.py │ ├── text_search.py │ └── tool.py ├── assets ├── add_database_connect.png ├── add_index_sample.png ├── add_schema_management.png ├── architecture.png ├── aws_architecture.png ├── bedrock_model_access.png ├── create_data_profile.png ├── interface.png ├── logic.png ├── react_deploy.png ├── screenshot-genbi.png ├── streamlit_deploy.png ├── streamlit_front.png ├── update_data_profile.png ├── update_schema_management.png └── user_front_end_cn.png ├── report-front-end ├── .env ├── .env.template ├── .eslintrc.cjs ├── .gitignore ├── .prettierrc ├── Dockerfile ├── docker-entry.sh ├── index.html ├── package.json ├── postcss.config.js ├── public │ ├── favicon.ico │ ├── manifest.json │ └── smile-logo.png ├── src │ ├── app.scss │ ├── app.tsx │ ├── components │ │ ├── BaseAppLayout.tsx │ │ ├── Login │ │ │ ├── CognitoLogin │ │ │ │ ├── aws-config.ts │ │ │ │ ├── index.tsx │ │ │ │ └── layout-with-cognito.css │ │ │ ├── CustomLogin │ │ │ │ ├── index.tsx │ │ │ │ └── style.scss │ │ │ └── index.tsx │ │ ├── PanelConfigs │ │ │ ├── index.tsx │ │ │ └── style.scss │ │ ├── PanelSideNav │ │ │ ├── index.tsx │ │ │ ├── style.scss │ │ │ └── types.ts │ │ ├── SectionChat │ │ │ ├── ChartRenderer.tsx │ │ │ ├── ChatInput.tsx │ │ │ ├── CustomQuestions.tsx │ │ │ ├── ExpandableSectionWithDivider.tsx │ │ │ ├── MessageRenderer │ │ │ │ ├── AiMessage.tsx │ │ │ │ ├── EntitySelect.tsx │ │ │ │ └── index.tsx │ │ │ ├── ResultRenderer.tsx │ │ │ ├── chat.module.scss │ │ │ ├── 
index.tsx │ │ │ └── types.ts │ │ └── TopNav │ │ │ ├── index.tsx │ │ │ └── style.scss │ ├── hooks │ │ └── useGlobalContext.ts │ ├── main.tsx │ ├── utils │ │ ├── api │ │ │ ├── API.ts │ │ │ └── WebSocket.ts │ │ ├── constants.ts │ │ └── helpers │ │ │ ├── storage.ts │ │ │ ├── store.ts │ │ │ ├── tools.ts │ │ │ └── types.ts │ └── vite-env.d.ts ├── tsconfig.json ├── tsconfig.node.json ├── vite.config.ts └── yarn.lock └── source └── resources ├── .npmignore ├── bin └── main.ts ├── cdk-config.json ├── cdk.json ├── jest.config.js ├── lib ├── aos │ └── aos-stack.ts ├── cognito │ └── cognito-stack.ts ├── ecs │ └── ecs-stack.ts ├── main-stack.ts ├── model │ └── llm-stack.ts ├── rds │ └── rds-stack.ts ├── redshift │ └── redshfit-stack.ts └── vpc │ └── vpc-stack.ts ├── package.json └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | # PyCharm 155 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 157 | # and can be added to the global gitignore or merged into this file. For a more nuclear 158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
159 | #.idea/ 160 | 161 | *_test.py 162 | # CDK 163 | node_modules/ 164 | .env.local 165 | .env.local.* 166 | 167 | # CDK asset staging directory 168 | cdk.out 169 | cdk.context.json 170 | **/cdk.out 171 | package-lock.json 172 | !source/resources/lib/ 173 | 174 | # Model Artifact 175 | internlm2-chat-7b/ 176 | sqlcoder-7b-2/ 177 | bge-m3/ 178 | 179 | .DS_Store 180 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
17 | 18 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # AWS上的生成式BI应用 2 | 3 | ## 1、介绍 4 | 5 | 6 | 这是一个在AWS上使用 Amazon Bedrock、Amazon OpenSearch 和 RAG 技术的生成式BI应用。 7 | 8 | 9 | 10 | - 系统架构图 11 | 12 | 13 | ![img.png](./assets/aws_architecture.png) 14 | 15 | - 数据流程图 16 | 17 | ![Screenshot](./assets/logic.png) 18 | 19 | 20 | [用户操作手册](https://github.com/aws-samples/generative-bi-using-rag/wiki/%E7%B3%BB%E7%BB%9F%E7%AE%A1%E7%90%86%E5%91%98%E6%93%8D%E4%BD%9C) 21 | 22 | [项目数据流程图](https://github.com/aws-samples/generative-bi-using-rag/wiki/%E6%9E%B6%E6%9E%84%E5%9B%BE) 23 | 24 | 25 | ## 目录 26 | 27 | 1. [Overview](#overview) 28 | - [Cost](#cost) 29 | 2. [Prerequisites](#prerequisites) 30 | - [Operating System](#operating-system) 31 | 3. [Workshop](#workshop) 32 | 4. [Deployment Steps](#deployment-steps) 33 | 5. [Deployment Validation](#deployment-validation) 34 | 6. [Running the Guidance](#running-the-guidance) 35 | 7. [Next Steps](#next-steps) 36 | 8. 
[Cleanup](#cleanup) 37 | 38 | ## 概述 39 | 40 | 41 | 这是一个在AWS上针对自定义数据源(RDS/Redshift)启用生成式BI功能的框架。它提供以下关键特性: 42 | 43 | - 通过自然语言查询自定义数据源的Text-to-SQL功能。 44 | - 用户友好的界面,可添加、编辑和管理数据源、表和列描述。 45 | - 通过集成历史问题答案排名和实体识别来提高性能。 46 | - 自定义业务信息,包括实体信息,公式,SQL样本,复杂业务问题分析思路等。 47 | - 增加agent任务拆分功能,能够处理复杂的归因分析问题。 48 | - 直观的问答界面,可深入了解底层的Text-to-SQL机制。 49 | - 简单的代理设计界面,可通过对话方式处理复杂查询。 50 | 51 | 52 | 53 | ### 费用 54 | 55 | 截至2024年5月,在 us-west-2 区域使用默认设置运行这个 Guidance 的成本大约为每月$1337.8,处理2000个请求。 56 | 57 | 58 | 59 | ### 费用示例 60 | 61 | 下表提供了在美国东部(弗吉尼亚北部)地区部署此 Guidance 时,使用默认参数一个月的样本成本明细。 62 | 63 | 64 | | AWS service | Dimensions | Cost [USD] per Month | 65 | | ----------- | ------------ | ------------ | 66 | | Amazon ECS | v0.75 CPU 5GB | $11.51 | 67 | | Amazon DynamoDB | 25 provisioned write & read capacity units per month | $ 14.04 | 68 | | Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 90.00 | 69 | | Amazon OpenSearch Service | 1 domain with m5.large.search | $ 103.66 | 70 | 71 | 72 | 73 | ### 前提条件 74 | 75 | ### 操作系统 76 | 77 | CDK 经过优化,最适合在 **Amazon Linux 2023 AMI** 上启动。在其他操作系统上部署可能需要额外的步骤。 78 | 79 | ### AWS 账户要求 80 | 81 | - VPC 82 | - IAM role with specific permissions 83 | - Amazon Bedrock 84 | - Amazon ECS 85 | - Amazon DynamoDB 86 | - Amazon Cognito 87 | - Amazon OpenSearch Service 88 | - Amazon Elastic Load Balancing 89 | - Amazon SageMaker (Optional, if you need customized models to be deployed) 90 | - Amazon Secrets Manager 91 | 92 | ### 支持的区域 93 | 94 | us-west-2, us-east-2, us-east-1, ap-south-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, eu-central-1, eu-west-1, eu-west-3, 以及其他支持bedrock的区域 95 | 96 | ## Workshop 97 | 98 | 更多更详细的使用说明,请查看下方的Workshop 99 | 100 | 🔥🔥🔥 [The Workshop Content](https://catalog.us-east-1.prod.workshops.aws/workshops/37b20322-fc96-4716-8e51-4568b0641448) 101 | 102 | 103 | ## 部署步骤 104 | 105 | ### 1. 
准备 CDK 先决条件 106 | 107 | 请按照 [CDK Workshop](https://cdkworkshop.com/15-prerequisites.html) 中的说明安装 CDK 工具包。确保您的环境有权限创建资源。 108 | 109 | ### 2. Set a password for the GenBI Admin Web UI 110 | 111 | 对于 GenBI 管理员 Web UI,默认密码为[empty],需要为 GenBI 管理员 Web UI 设置密码,您可以修改如下文件 112 | 113 | ```application/config_files/stauth_config.yaml``` 114 | 115 | 下面是一个示例 116 | 117 | ```yaml 118 | credentials: 119 | usernames: 120 | jsmith: 121 | email: jsmith@gmail.com 122 | name: John Smith 123 | password: XXXXXX # To be replaced with hashed password 124 | rbriggs: 125 | email: rbriggs@gmail.com 126 | name: Rebecca Briggs 127 | password: XXXXXX # To be replaced with hashed password 128 | cookie: 129 | expiry_days: 30 130 | key: random_signature_key # Must be string 131 | name: random_cookie_name 132 | preauthorized: 133 | emails: 134 | - melsby@gmail.com 135 | ``` 136 | 137 | 将密码'XXXXXX'改为哈希密码 138 | 139 | 使用以下 Python 代码生成 XXXXXX。我们需要 Python 3.8 及以上版本来运行以下代码: 140 | 141 | ```python 142 | from streamlit_authenticator.utilities.hasher import Hasher 143 | hashed_passwords = Hasher(['password123']).generate() 144 | ``` 145 | 146 | ### 3. 
部署CDK 147 | 148 | 对于global区别,执行如下命令: 149 | 150 | ``` 151 | cd generative-bi-using-rag/source/resources 152 | 153 | npm install aws-cdk-lib 154 | ``` 155 | 156 | 部署 CDK 堆栈,如果需要,请将区域更改为您自己的区域,例如 us-west-2、us-east-1 等: 157 | 158 | ``` 159 | export AWS_ACCOUNT_ID=XXXXXXXXXXXX 160 | export AWS_REGION=us-west-2 161 | 162 | cdk bootstrap aws://$AWS_ACCOUNT_ID/$AWS_REGION 163 | cdk deploy GenBiMainStack --require-approval never 164 | 165 | ``` 166 | 167 | 当部署成功时,您可以看到如下信息 168 | ``` 169 | GenBiMainStack.AOSDomainEndpoint = XXXXX.us-west-2.es.amazonaws.com 170 | GenBiMainStack.APIEndpoint = XXXXX.us-west-2.elb.amazonaws.com 171 | GenBiMainStack.FrontendEndpoint = XXXXX.us-west-2.elb.amazonaws.com 172 | GenBiMainStack.StreamlitEndpoint = XXXXX.us-west-2.elb.amazonaws.com 173 | ``` 174 | 175 | 176 | ## 运行Guidance 177 | 178 | 在部署 CDK 堆栈后,等待大约 40 分钟完成初始化。然后在浏览器中打开 Web UI: https://your-public-dns 179 | 180 | ## 清除 181 | - 删除CDK堆栈: 182 | ``` 183 | cdk destroy GenBiMainStack 184 | ``` 185 | -------------------------------------------------------------------------------- /README_JP.md: -------------------------------------------------------------------------------- 1 | # AWS上でのRAGを使用した生成ビジネスインテリジェンス 2 | [中文文檔](README_CN.md) | [日本語ドキュメント](README_JP.md) 3 | 4 | ここに記載されているのはCDKのみの導入ガイドです。手動導入や詳細ガイドについては、[中国語の手動導入ガイド](https://github.com/aws-samples/generative-bi-using-rag/wiki/%E8%B0%83%E8%AF%95%E7%95%8C%E9%9D%A2%E4%BB%A5%E5%8F%8AAPI%E9%83%A8%E7%BD%B2)を参照してください。 5 | 6 | ![Screenshot](./assets/interface.png) 7 | 8 | ## 紹介 9 | 10 | Amazon Bedrock、Amazon OpenSearchとRAG技術を使用した生成ビジネスインテリジェンスのデモ。 11 | 12 | ![Screenshot](./assets/aws_architecture.png) 13 | *AWS上のリファレンスアーキテクチャ* 14 | 15 | ![Screenshot](./assets/logic.png) 16 | *設計論理* 17 | 18 | [ユーザー操作マニュアル](https://github.com/aws-samples/generative-bi-using-rag/wiki/%E7%B3%BB%E7%BB%9F%E7%AE%A1%E7%90%86%E5%91%98%E6%93%8D%E4%BD%9C) 19 | 20 | [プロジェクトデータフローチャート](https://github.com/aws-samples/generative-bi-using-rag/wiki/%E6%9E%B6%E6%9E%84%E5%9B%B3) 
21 | 22 | ## 目次 23 | 1. [概要](#overview) 24 | - [コスト](#cost) 25 | 2. [前提条件](#prerequisites) 26 | - [オペレーティングシステム](#operating-system) 27 | 3. [ワークショップ](#workshop) 28 | 4. [デプロイ手順](#deployment-steps) 29 | 5. [デプロイの検証](#deployment-validation) 30 | 6. [ガイダンスの実行](#running-the-guidance) 31 | 7. [次のステップ](#next-steps) 32 | 8. [クリーンアップ](#cleanup) 33 | 34 | ## 概要 35 | このフレームワークは、AWSでホストされているカスタムデータソース(RDS/Redshift)に対してGenerative BIの機能を可能にするように設計されています。主な機能は以下の通りです。 36 | 37 | - 自然言語を使ってカスタムデータソースを問い合わせるためのText-to-SQLの機能 38 | - データソース、テーブル、列の説明を追加、編集、管理するためのユーザーフレンドリーなインターフェース 39 | - 過去の質問と回答のランキングとエンティティ認識の統合による性能向上 40 | - エンティティ情報、数式、SQLサンプル、複雑なビジネス問題の分析アイデアなどのビジネス情報をカスタマイズ可能 41 | - 複雑な帰属分析問題を処理するためのエージェントタスク分割機能の追加 42 | - 基礎となるText-to-SQLメカニズムの洞察を提供する直感的な質問応答UI 43 | - 対話型アプローチで複雑な質問に対処するためのシンプルなエージェント設計インターフェース 44 | 45 | ### コスト 46 | 47 | 2024年5月現在、デフォルト設定でこのガイダンスを_us-west-2_リージョンで実行する場合、2000リクエストを処理するのに約1,337.8ドル/月のコストがかかります。 48 | 49 | ### サンプルコストテーブル 50 | 51 | 以下の表は、デフォルトのパラメーターでこのガイダンスをUSイースト(バージニア北部)リージョンに1か月間デプロイした場合のサンプルコスト内訳を示しています。 52 | 53 | | AWSサービス | 内訳 | コスト[USD]/月 | 54 | | ----------- | ------------ | ------------ | 55 | | Amazon ECS | vCPU 0.75、5GB | $11.51 | 56 | | Amazon DynamoDB | プロビジョンドライト&リードキャパシティユニット25個/月 | $14.04 | 57 | | Amazon Bedrock | 2000リクエスト/月、リクエストあたり10000入力トークン、1000出力トークン | $90.00 | 58 | | Amazon OpenSearch Service | m5.large.searchインスタンス×1ドメイン | $103.66 | 59 | 60 | ## 前提条件 61 | 62 | ### オペレーティングシステム 63 | "CDKは **** 上で最適に動作するよう最適化されています。他のOSでのデプロイには追加の手順が必要になる可能性があります。" 64 | 65 | ### AWS アカウントの要件 66 | 67 | - VPC 68 | - 特定の権限を持つIAMロール 69 | - Amazon Bedrock 70 | - Amazon ECS 71 | - Amazon DynamoDB 72 | - Amazon Cognito 73 | - Amazon OpenSearch Service 74 | - Amazon Elastic Load Balancing 75 | - Amazon SageMaker (オプション、カスタムモデルをデプロイする場合) 76 | - Amazon Secrets Manager 77 | 78 | ### サポートされているリージョン 79 | 80 | 
us-west-2、us-east-2、us-east-1、ap-south-1、ap-southeast-1、ap-southeast-2、ap-northeast-1、eu-central-1、eu-west-1、eu-west-3、またはガイダンスで使用されているサービス(bedrock)がサポートされている他のリージョン。 81 | 82 | ## ワークショップ 83 | 84 | より詳細な使用手順については、以下のワークショップを参照してください。 85 | 86 | 🔥🔥🔥 [ワークショップコンテンツ](https://catalog.us-east-1.prod.workshops.aws/workshops/37b20322-fc96-4716-8e51-4568b0641448) 87 | 88 | 89 | ## デプロイの手順 90 | 91 | ### 1. CDK の前提条件を準備する 92 | [CDK ワークショップ](https://cdkworkshop.com/15-prerequisites.html)の手順に従って、CDK ツールキットをインストールしてください。環境にリソースを作成する権限があることを確認してください。 93 | 94 | ### 2. GenBI 管理ウェブ UI のパスワードを設定する 95 | 96 | GenBI 管理ウェブ UI のデフォルトのパスワードは[空白]です。GenBI 管理ウェブ UI のパスワードを設定する必要がある場合は、以下のファイルでパスワードを更新できます。 97 | ```application/config_files/stauth_config.yaml``` 98 | 99 | 例: 100 | 101 | ```yaml 102 | credentials: 103 | usernames: 104 | jsmith: 105 | email: jsmith@gmail.com 106 | name: John Smith 107 | password: XXXXXX # ハッシュ化されたパスワードに置き換える 108 | rbriggs: 109 | email: rbriggs@gmail.com 110 | name: Rebecca Briggs 111 | password: XXXXXX # ハッシュ化されたパスワードに置き換える 112 | cookie: 113 | expiry_days: 30 114 | key: random_signature_key # 文字列でなければならない 115 | name: random_cookie_name 116 | preauthorized: 117 | emails: 118 | - melsby@gmail.com 119 | ``` 120 | 121 | パスワード 'XXXXXX' をハッシュ化されたパスワードに変更します。 122 | 123 | 以下の Python コードを使用して XXXXXX を生成します。Python 3.8 以上が必要です。 124 | ```python 125 | from streamlit_authenticator.utilities.hasher import Hasher 126 | hashed_passwords = Hasher(['password123']).generate() 127 | ``` 128 | 129 | ### 3. 
CDK スタックをデプロイする 130 | グローバルリージョンの場合、以下のコマンドを実行します。 131 | 132 | CDK プロジェクトのディレクトリに移動: 133 | ``` 134 | cd generative-bi-using-rag/source/resources 135 | 136 | npm install aws-cdk-lib 137 | ``` 138 | CDK スタックをデプロイします。必要に応じてリージョンを変更してください(例: us-west-2、us-east-1 など)。 139 | ``` 140 | export AWS_ACCOUNT_ID=XXXXXXXXXXXX 141 | export AWS_REGION=us-west-2 142 | 143 | cdk bootstrap aws://$AWS_ACCOUNT_ID/$AWS_REGION 144 | cdk deploy GenBiMainStack --require-approval never 145 | 146 | ``` 147 | デプロイが成功すると、以下のように表示されます。 148 | ``` 149 | GenBiMainStack.AOSDomainEndpoint = XXXXX.us-west-2.es.amazonaws.com 150 | GenBiMainStack.APIEndpoint = XXXXX.us-west-2.elb.amazonaws.com 151 | GenBiMainStack.FrontendEndpoint = XXXXX.us-west-2.elb.amazonaws.com 152 | GenBiMainStack.StreamlitEndpoint = XXXXX.us-west-2.elb.amazonaws.com 153 | ``` 154 | 155 | ## Guidance の実行 156 | 157 | CDK スタックがデプロイされた後、初期化が完了するのを約40分待ってから、ブラウザで Web UI を開きます: https://your-public-dns 158 | 159 | ## クリーンアップ 160 | - CDK スタックを削除する: 161 | ``` 162 | cdk destroy GenBiMainStack 163 | ``` -------------------------------------------------------------------------------- /application/.dockerignore: -------------------------------------------------------------------------------- 1 | initial_data/ 2 | .git/ -------------------------------------------------------------------------------- /application/.env.cntemplate: -------------------------------------------------------------------------------- 1 | RDS_MYSQL_USERNAME=llmdata 2 | RDS_MYSQL_PASSWORD=llmdata 3 | RDS_MYSQL_HOST=mysql-db 4 | RDS_MYSQL_PORT=3306 5 | RDS_MYSQL_DBNAME=llm 6 | 7 | 8 | OPENSEARCH_TYPE=service 9 | AOS_AWS_REGION=cn-north-1 10 | AOS_INDEX=uba 11 | AOS_INDEX_NER=uba_ner 12 | AOS_INDEX_AGENT=uba_agent 13 | 14 | 15 | BEDROCK_REGION=cn-north-1 16 | RDS_REGION_NAME=cn-north-1 17 | AWS_DEFAULT_REGION=cn-north-1 18 | DYNAMODB_AWS_REGION=cn-north-1 19 | 20 | SAGEMAKER_ENDPOINT_EMBEDDING=embedding-bge-m3-3ab71 21 | 
SAGEMAKER_ENDPOINT_INTENT=llm-internlm2-chat-7b-3ab71 22 | SAGEMAKER_ENDPOINT_SQL=sql-sqlcoder-7b-2-7e5b6 23 | SAGEMAKER_ENDPOINT_EXPLAIN=llm-internlm2-chat-7b-3ab71 24 | 25 | EMBEDDING_DIMENSION=1024 26 | 27 | # If you need to use ak/sk to access bedrock, please configure bedrock's ak/sk to Secrets Manager, Examples are as follows 28 | # BEDROCK_SECRETS_AK_SK=bedrock-ak-sk 29 | 30 | BEDROCK_SECRETS_AK_SK= 31 | 32 | OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url 33 | OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user 34 | 35 | ENABLE_USER_PROFILE_MAP=False -------------------------------------------------------------------------------- /application/.env.template: -------------------------------------------------------------------------------- 1 | RDS_MYSQL_USERNAME=llmdata 2 | RDS_MYSQL_PASSWORD=llmdata 3 | RDS_MYSQL_HOST=mysql-db 4 | RDS_MYSQL_PORT=3306 5 | RDS_MYSQL_DBNAME=llm 6 | 7 | # possible value: 'service', 'docker'. Will route to Secrets Manager if set to 'service'. Will use env vars below if set to 'docker' 8 | OPENSEARCH_TYPE=service 9 | AOS_HOST=opensearch-node1 10 | AOS_PORT=9200 11 | AOS_AWS_REGION=us-west-2 12 | AOS_DOMAIN=llm-data-analytics 13 | AOS_INDEX=uba 14 | AOS_INDEX_NER=uba_ner 15 | AOS_INDEX_AGENT=uba_agent 16 | AOS_USER=admin 17 | AOS_PASSWORD=admin 18 | 19 | BEDROCK_REGION=us-west-2 20 | RDS_REGION_NAME=us-west-2 21 | AWS_DEFAULT_REGION=us-west-2 22 | DYNAMODB_AWS_REGION=us-west-2 23 | 24 | EMBEDDING_DIMENSION=1536 25 | BEDROCK_EMBEDDING_MODEL=amazon.titan-embed-text-v1 26 | 27 | # If you need to use ak/sk to access bedrock, please configure bedrock's ak/sk to Secrets Manager, Examples are as follows 28 | # BEDROCK_SECRETS_AK_SK=bedrock-ak-sk 29 | 30 | BEDROCK_SECRETS_AK_SK= 31 | 32 | OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url 33 | OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user 34 | 35 | # SAGEMAKER_ENDPOINT_EMBEDDING= 36 | 37 | 38 | VITE_COGNITO_REGION= 39 | VITE_COGNITO_USER_POOL_ID= 40 | 
VITE_COGNITO_USER_POOL_WEB_CLIENT_ID= 41 | 42 | ENABLE_USER_PROFILE_MAP=False -------------------------------------------------------------------------------- /application/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [client] 2 | showSidebarNavigation = false 3 | -------------------------------------------------------------------------------- /application/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/docker/library/python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | RUN adduser --disabled-password --gecos '' appuser 6 | 7 | WORKDIR /app 8 | 9 | COPY requirements.txt /app/ 10 | 11 | ARG AWS_REGION 12 | ENV AWS_REGION=${AWS_REGION} 13 | 14 | # Print the AWS_REGION for verification 15 | RUN echo "Current AWS Region: $AWS_REGION" 16 | 17 | # Install dependencies using the appropriate PyPI source based on AWS region 18 | RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ 19 | pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \ 20 | else \ 21 | pip3 install -r requirements.txt; \ 22 | fi 23 | 24 | COPY . 
/app/ 25 | 26 | # set streamlit config via env vars 27 | ENV STREAMLIT_SERVER_ENABLE_STATIC_SERVING=false 28 | ENV STREAMLIT_LOGGER_LEVEL="info" 29 | ENV STREAMLIT_CLIENT_TOOLBAR_MODE="viewer" 30 | ENV STREAMLIT_CLIENT_SHOW_ERROR_DETAILS=false 31 | ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false 32 | ENV STREAMLIT_THEME_BASE="light" 33 | 34 | EXPOSE 8501 35 | 36 | USER appuser 37 | 38 | HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health 39 | 40 | ENTRYPOINT ["streamlit", "run", "Index.py", "--server.port=8501", "--server.address=0.0.0.0"] 41 | -------------------------------------------------------------------------------- /application/Dockerfile-api: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/docker/library/python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | COPY . /app/ 6 | 7 | ARG AWS_REGION 8 | ENV AWS_REGION=${AWS_REGION} 9 | 10 | # Print the AWS_REGION for verification 11 | RUN echo "Current AWS Region: $AWS_REGION" 12 | 13 | # Install dependencies using the appropriate PyPI source based on AWS region 14 | RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ 15 | pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \ 16 | else \ 17 | pip3 install -r requirements-api.txt; \ 18 | fi 19 | 20 | EXPOSE 8000 21 | 22 | ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0"] 23 | -------------------------------------------------------------------------------- /application/Index.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import streamlit as st 4 | from utils.navigation import get_authenticator 5 | 6 | st.set_page_config( 7 | page_title="Intelligent BI", 8 | page_icon="👋", 9 | ) 10 | 11 | authenticator = get_authenticator() 12 | name, authentication_status, username = authenticator.login('main') 13 | 14 | if st.session_state['authentication_status']: 15 | time.sleep(0.5) 16 
| st.session_state['auth_name'] = name 17 | st.session_state['auth_username'] = username 18 | st.switch_page("pages/mainpage.py") 19 | elif st.session_state['authentication_status'] is False: 20 | st.error('Username/password is incorrect') 21 | elif st.session_state['authentication_status'] is None: 22 | st.warning('Please enter your username and password') 23 | -------------------------------------------------------------------------------- /application/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/application/api/__init__.py -------------------------------------------------------------------------------- /application/api/enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | from utils.constant import BEDROCK_MODEL_IDS 3 | 4 | 5 | @unique 6 | class ErrorEnum(Enum): 7 | SUCCEEDED = {1: "Operation succeeded"} 8 | NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"} 9 | INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{BEDROCK_MODEL_IDS}"} 10 | INVAILD_SESSION_ID = {1003: f"Invalid session id."} 11 | PROFILE_NOT_FOUND = {1004: "Profile name not found."} 12 | UNKNOWN_ERROR = {9999: "Unknown error."} 13 | 14 | def get_code(self): 15 | return list(self.value.keys())[0] 16 | 17 | def get_message(self): 18 | return list(self.value.values())[0] 19 | 20 | 21 | @unique 22 | class ContentEnum(Enum): 23 | EXCEPTION = "exception" 24 | COMMON = "common" 25 | STATE = "state" 26 | END = "end" -------------------------------------------------------------------------------- /application/api/exception_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fastapi.responses import JSONResponse 3 | from fastapi import status, FastAPI, 
Request, Response 4 | from fastapi.exceptions import RequestValidationError 5 | from .enum import ErrorEnum 6 | import traceback 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | def response_error(code: int, message: str, status_code: int = status.HTTP_400_BAD_REQUEST) -> Response: 11 | headers = {} 12 | return JSONResponse( 13 | content={ 14 | 'code': code, 15 | 'message': message, 16 | }, 17 | headers=headers, 18 | status_code=status_code, 19 | ) 20 | 21 | 22 | def biz_exception(app: FastAPI): 23 | # customize request validation error 24 | @app.exception_handler(RequestValidationError) 25 | async def val_exception_handler(req: Request, rve: RequestValidationError, code: int = status.HTTP_422_UNPROCESSABLE_ENTITY): 26 | lst = [] 27 | for error in rve.errors(): 28 | lst.append('{}=>{}'.format('.'.join(error['loc']), error['msg'])) 29 | return response_error(code, ' , '.join(lst)) 30 | 31 | # customize business error 32 | @app.exception_handler(BizException) 33 | async def biz_exception_handler(req: Request, exc: BizException): 34 | return response_error(exc.code, exc.message) 35 | 36 | # system error 37 | @app.exception_handler(Exception) 38 | async def exception_handler(req: Request, exc: Exception): 39 | if isinstance(exc, BizException): 40 | return 41 | error_msg = traceback.format_exc() 42 | logger.error(error_msg) 43 | return response_error(ErrorEnum.UNKNOWN_ERROR.get_code(), error_msg, status.HTTP_500_INTERNAL_SERVER_ERROR) 44 | 45 | 46 | class BizException(Exception): 47 | def __init__(self, error_message: ErrorEnum): 48 | self.code = error_message.get_code() 49 | self.message = error_message.get_message() 50 | 51 | 52 | def __msg__(self): 53 | return self.message 54 | -------------------------------------------------------------------------------- /application/api/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | from pydantic import BaseModel 3 | 4 | 5 | 
class Question(BaseModel):
    """Request body for the ask/query endpoints: the user's query plus
    LLM sampling parameters and per-request feature flags."""
    query: str
    # default Bedrock model when the caller does not choose one
    bedrock_model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0"
    use_rag_flag: bool = True
    visualize_results_flag: bool = True
    intent_ner_recognition_flag: bool = True
    agent_cot_flag: bool = True
    profile_name: str
    explain_gen_process_flag: bool = True
    gen_suggested_question_flag: bool = False
    answer_with_insights: bool = False
    # NOTE(review): declared float although 250 looks like a token/candidate count — confirm intended type
    top_k: float = 250
    top_p: float = 0.9
    max_tokens: int = 2048
    temperature: float = 0.01
    # presumably the number of previous turns carried as chat context — confirm against caller
    context_window: int = 5
    session_id: str = "-1"
    user_id: str = "admin"
    username: str = ''
    query_rewrite: str = ""
    previous_intent: str = ""
    entity_user_select: dict = {}
    entity_retrieval: list = []


class Example(BaseModel):
    """A question/answer example together with a score."""
    score: float
    question: str
    answer: str


class HistoryRequest(BaseModel):
    """Request for all chat history of a user on a profile."""
    user_id: str
    profile_name: str
    log_type: str = "chat_history"


class HistorySessionRequest(BaseModel):
    """Request for the chat history of a single session."""
    session_id: str
    user_id: str
    profile_name: str
    log_type: str = "chat_history"


class QueryEntity(BaseModel):
    """A natural-language query paired with its SQL."""
    query: str
    sql: str


class FeedBackInput(BaseModel):
    """User feedback on a generated answer, including optional error details
    and a corrected SQL reference."""
    feedback_type: str
    data_profiles: str
    query: str
    query_intent: str
    query_answer: str
    session_id: str = "-1"
    user_id: str = "admin"
    error_description: str = ""
    error_categories: str = ""
    correct_sql_reference: str = ""


class Option(BaseModel):
    """Available data profiles and Bedrock model ids."""
    data_profiles: list[str]
    bedrock_model_ids: list[str]


class CustomQuestion(BaseModel):
    """A list of pre-configured sample questions."""
    custom_question: list[str]


class ChartEntity(BaseModel):
    """A chart specification: chart type plus the rows to plot."""
    chart_type: str
    chart_data: list[Any]


class SQLSearchResult(BaseModel):
    """Result of a text-to-SQL run: the SQL, its data, how to display it,
    the generation explanation, analysis text, and chart specs."""
    sql: str
    sql_data: list[Any]
    data_show_type: str
    sql_gen_process: str
    data_analyse: str
    sql_data_chart: list[ChartEntity]


class TaskSQLSearchResult(BaseModel):
    """SQL search result for one sub-task of an agent plan."""
    sub_task_query: str
    sql_search_result: SQLSearchResult


class KnowledgeSearchResult(BaseModel):
    """Answer text produced from knowledge search (no SQL data)."""
    knowledge_response: str


class AgentSearchResult(BaseModel):
    """Aggregated agent results: per-sub-task SQL results plus a summary."""
    agent_sql_search_result: list[TaskSQLSearchResult]
    agent_summary: str


class AskReplayResult(BaseModel):
    """Rewritten query returned to the user."""
    query_rewrite: str


class AskEntitySelect(BaseModel):
    """Entity-disambiguation payload: candidate entity info plus the
    retrieved entities."""
    entity_select_info: dict[str, Any]
    entity_retrieval: list[Any]


class Answer(BaseModel):
    """Complete answer envelope returned to the front end, bundling the
    intent, each kind of search result, suggestions and error info."""
    query: str
    query_rewrite: str = ""
    query_intent: str
    knowledge_search_result: KnowledgeSearchResult
    sql_search_result: SQLSearchResult
    agent_search_result: AgentSearchResult
    ask_rewrite_result: AskReplayResult
    suggested_question: list[str]
    ask_entity_select: AskEntitySelect
    error_log: dict[str, Any]


class Message(BaseModel):
    """One chat message; content is either plain text or a structured Answer."""
    type: str
    content: Union[str, Answer]


class HistoryMessage(BaseModel):
    """All messages belonging to one session."""
    session_id: str
    messages: list[Message]


class ChatHistory(BaseModel):
    """Chat history across multiple sessions."""
    messages: list[HistoryMessage]
#!/bin/bash
# Rebuild and restart the GenBI docker-compose stack:
# stop/remove the known containers if present, rebuild all images,
# start the stack detached, then prune dangling images.

# Stop and remove a container by name (no-op if it does not exist).
# BUGFIX: the original referenced $container_name, which was never set
# (the `read` was commented out), so every message printed an empty name;
# it also printed a false "已重新启动" message before anything restarted.
stop_and_remove() {
    local container_name="$1"
    local container_id
    container_id=$(docker ps -aq --filter="name=${container_name}")

    if [ -n "$container_id" ]; then
        # 停止容器
        echo "正在停止容器 ${container_name}..."
        docker stop "$container_id"

        # 删除容器
        echo "正在删除容器 ${container_name}..."
        docker rm "$container_id"

        echo "容器 ${container_name} 已成功停止和删除."
    else
        echo "没有找到名称为 ${container_name} 的容器."
    fi
}

for name in nlq-webserver nlq-api react-front-end; do
    stop_and_remove "$name"
done

docker-compose build

docker-compose up -d

# remove dangling images left over from the rebuild
docker images -q --filter "dangling=true" | xargs -r docker rmi
opensearch-dashboards: 49 | image: public.ecr.aws/opensearchproject/opensearch-dashboards:2.11.1 50 | read_only: true 51 | container_name: opensearch-dashboards 52 | ports: 53 | - 5601:5601 54 | expose: 55 | - "5601" 56 | environment: 57 | OPENSEARCH_HOSTS: '["https://opensearch-node1:9200","https://opensearch-node2:9200"]' 58 | networks: 59 | - opensearch-net 60 | mysql-db: 61 | # 指定容器的名称 62 | container_name: nlq-mysql 63 | # 指定镜像和版本 64 | image: public.ecr.aws/docker/library/mysql:8.0 65 | ports: 66 | - "3306:3306" 67 | restart: always 68 | environment: 69 | # 配置root密码 70 | MYSQL_ROOT_PASSWORD: password 71 | MYSQL_DATABASE: llm 72 | MYSQL_USER: llmdata 73 | MYSQL_PASSWORD: llmdata 74 | volumes: 75 | # 挂载数据目录 76 | - mysql-data:/var/lib/mysql 77 | # 挂载配置文件目录 78 | #- "./mysql/config:/etc/mysql/conf.d" 79 | - ./initial_data:/opt/data 80 | networks: 81 | - opensearch-net 82 | streamlit-demo: 83 | container_name: nlq-webserver 84 | build: . 85 | env_file: 86 | - .env 87 | ports: 88 | - "80:8501" 89 | - "8765:8765" 90 | expose: 91 | - "8501" 92 | volumes: 93 | - ./config_files:/app/config_files 94 | - ./deployment:/app/deployment 95 | networks: 96 | - opensearch-net 97 | front-end: 98 | container_name: react-front-end 99 | build: 100 | context: ../report-front-end 101 | dockerfile: Dockerfile 102 | restart: always 103 | ports: 104 | - "3000:80" 105 | expose: 106 | - "80" 107 | networks: 108 | - opensearch-net 109 | api: 110 | container_name: nlq-api 111 | build: 112 | context: . 
from streamlit_authenticator.utilities.hasher import Hasher


def main():
    """Prompt for a password and print its hash for config_files/stauth_config.yaml."""
    plain_text = input("please enter the password: ")
    hashed = Hasher([plain_text]).generate()
    print("hashed_passwords: ", hashed[0])


if __name__ == "__main__":
    main()
from fastapi import FastAPI, status, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse, Response
from api.exception_handler import biz_exception
from api.main import router
from fastapi.middleware.cors import CORSMiddleware
from api import service
from api.schemas import Option
from utils.auth import authenticate, skipAuthentication

MAX_CHAT_WINDOW_SIZE = 10 * 2
app = FastAPI(title='GenBI')

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


def _apply_cors_headers(response):
    """Attach permissive CORS headers (mirrors the CORSMiddleware config above).

    Extracted because the same three headers were set in three places.
    """
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
    response.headers["Access-Control-Allow-Headers"] = "*"
    return response


@app.middleware("http")
async def http_authenticate(request: Request, call_next):
    """Token-authentication middleware.

    Health-check paths and CORS preflight bypass auth. Otherwise the
    X-Access-Token / X-Id-Token / X-Refresh-Token headers are validated via
    utils.auth.authenticate, unless skipAuthentication disables auth.
    """
    # health-check endpoints are always open
    if request.url.path == "/" or request.url.path == "/ping":
        return await call_next(request)

    # answer CORS preflight directly
    if request.method == "OPTIONS":
        return _apply_cors_headers(Response(status_code=status.HTTP_200_OK))

    if not skipAuthentication:
        access_token = request.headers.get("X-Access-Token")
        id_token = request.headers.get("X-Id-Token")
        refresh_token = request.headers.get("X-Refresh-Token")
        auth_result = authenticate(access_token, id_token, refresh_token)
    else:
        auth_result = {'X-Status-Code': status.HTTP_200_OK}

    if not skipAuthentication and auth_result["X-Status-Code"] != status.HTTP_200_OK:
        # authentication failed: return the failure status with CORS headers
        return _apply_cors_headers(Response(status_code=auth_result["X-Status-Code"]))

    username = auth_result["X-User-Name"] if not skipAuthentication else "admin"
    response = await call_next(request)
    if not skipAuthentication:
        # propagate the resolved user back to the caller
        response.headers["X-User-Name"] = username
    return _apply_cors_headers(response)


# Global exception capture
biz_exception(app)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.include_router(router)


# changed from "/" to "/test" to avoid health check fails in ECS
@app.get("/test", status_code=status.HTTP_302_FOUND)
def index():
    """Redirect to the demo WebSocket page."""
    return RedirectResponse("static/WebSocket.html")


# health check
@app.get("/")
def health():
    return {"status": "ok"}


@app.get("/ping")
def ping():
    return {"status": "ok"}


@app.get("/option", response_model=Option)
def option():
    """Return the available data profiles and Bedrock model ids."""
    return service.get_option()
from nlq.data_access.dynamo_connection import ConnectConfigDao, ConnectConfigEntity
from nlq.data_access.database import RelationDatabase
from utils.logging import getLogger

logger = getLogger()


class ConnectionManagement:
    """Business-layer facade for database connection configurations.

    Config records are read/written through ConnectConfigDao; live database
    metadata (schemas, tables, definitions, URLs) is fetched through
    RelationDatabase.
    """

    # shared DAO instance for connection configuration records
    connection_config_dao = ConnectConfigDao()

    @classmethod
    def get_all_connections(cls):
        """Return the names of all stored connections."""
        logger.info('get all connections...')
        return [conn.conn_name for conn in cls.connection_config_dao.get_db_list()]

    @classmethod
    def add_connection(cls, conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name, comment):
        """Persist a new connection configuration."""
        cls.connection_config_dao.add_url_db(conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name, comment)
        logger.info(f"Connection {conn_name} added")

    @classmethod
    def get_conn_config_by_name(cls, conn_name):
        """Return the stored ConnectConfigEntity for conn_name."""
        return cls.connection_config_dao.get_by_name(conn_name)

    @classmethod
    def update_connection(cls, conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name, comment):
        """Update an existing connection configuration in place."""
        cls.connection_config_dao.update_db_info(conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name,
                                                 comment)
        logger.info(f"Connection {conn_name} updated")

    @classmethod
    def delete_connection(cls, conn_name):
        """Delete a connection configuration; logs a warning if the DAO reports failure."""
        if cls.connection_config_dao.delete(conn_name):
            logger.info(f"Connection {conn_name} deleted")
        else:
            logger.warning(f"Failed to delete Connection {conn_name}")

    @classmethod
    def get_table_name_by_config(cls, conn_config: ConnectConfigEntity, schema_names):
        """List table names for the given schemas of a connection."""
        return RelationDatabase.get_all_tables_by_connection(conn_config, schema_names)

    @classmethod
    def get_all_schemas_by_config(cls, conn_config: ConnectConfigEntity):
        """List schema names reachable through the connection."""
        return RelationDatabase.get_all_schema_names_by_connection(conn_config)

    @classmethod
    def get_table_definition_by_config(cls, conn_config: ConnectConfigEntity, schema_names, table_names):
        """Fetch table definitions for the given schemas/tables."""
        return RelationDatabase.get_table_definition_by_connection(conn_config, schema_names, table_names)

    @classmethod
    def get_db_url_by_name(cls, conn_name):
        """Build the database URL for a named connection."""
        conn_config = cls.get_conn_config_by_name(conn_name)
        return RelationDatabase.get_db_url_by_connection(conn_config)

    @classmethod
    def get_db_password_host_by_name(cls, conn_name):
        """Return password/host details for a named connection."""
        conn_config = cls.get_conn_config_by_name(conn_name)
        return RelationDatabase.get_password_host_by_connection(conn_config)

    @classmethod
    def get_db_type_by_name(cls, conn_name):
        """Return the db_type string stored for a named connection."""
        conn_config = cls.get_conn_config_by_name(conn_name)
        return conn_config.db_type
class RowLevelSecurityMode:
    """Supported row-level-security strategies."""
    NONE = None
    TABLE_REPLACE = 'TABLE_REPLACE'


class DataSourceBase(ABC):
    """Base class for data sources, providing row-level-security (RLS)
    rewriting of generated SQL.

    RLS config is YAML of the form:
      {'tables': [{'table_name': 'table_a', 'columns': [
          {'column_name': 'username', 'column_value': '$login_user.username'}]}]}
    """

    def __init__(self):
        pass

    @abstractmethod
    def support_row_level_security(self) -> bool:
        """Whether this data source supports row-level security."""
        pass

    @abstractmethod
    def row_level_security_mode(self) -> RowLevelSecurityMode:
        """The RLS strategy used by this data source."""
        pass

    @staticmethod
    def validate_row_level_security_config(rls_config: str) -> bool:
        """Validate an RLS YAML config by attempting a dry-run conversion."""
        try:
            DataSourceBase.convert_rls_yaml_to_table_subquery(LoginUser('validate_user'), yaml.safe_load(rls_config))
            return True
        except Exception:
            logger.error(f'Failed to validate RLS config:\n{rls_config}')
            return False

    def row_level_security_control(self, sql: str, rls_config: str, login_user: LoginUser) -> str:
        """Rewrite sql so that each RLS-configured table is shadowed by a
        filtered CTE. On any failure the original SQL is returned unchanged
        (best-effort, errors are logged)."""
        replaced_sql = sql
        rls_config_obj = {'tables': []}
        try:
            if rls_config:
                rls_config_obj = yaml.safe_load(rls_config)
            table_statements = self.convert_rls_yaml_to_table_subquery(login_user, rls_config_obj)

            logger.info(f'original SQL: {sql}')
            replaced_sql = self.replace_table_with_cte(sql, table_statements)
            logger.info(f'RLS applied SQL: {replaced_sql}')
        except Exception:
            logger.exception('Failed to apply RLS config')
            logger.info(f'{sql=}')
            logger.info(f'{rls_config=}')
            logger.info(f'{login_user=}')

        return replaced_sql

    @staticmethod
    def convert_rls_yaml_to_table_subquery(login_user, rls_config_obj):
        """Build, for each configured table, a filtered subquery string.

        '$login_user.username' values are substituted with the current user's
        name. Returns {table_name: "(SELECT * FROM table WHERE ...)"}.
        """
        table_statements = {}
        for table in rls_config_obj['tables']:
            table_name = table['table_name']

            condition = ''
            for column in table['columns']:
                column_name = column['column_name']
                column_value = column['column_value']

                if column_value == '$login_user.username' and login_user is not None:
                    column_value = login_user.get_username()

                # BUGFIX: was `condition += condition + ...`, which duplicated the
                # accumulated condition on every column after the first.
                # NOTE(review): values are interpolated directly into SQL; config and
                # usernames are trusted here — confirm they cannot contain quotes.
                if not condition:
                    condition = f"{column_name} = '{column_value}'"
                else:
                    condition += f" AND {column_name} = '{column_value}'"

            table_statements[table_name] = f'(SELECT * FROM {table_name} WHERE {condition})'
        return table_statements

    @staticmethod
    def replace_table_with_cte(sql, table_config: dict):
        """Prepend a WITH clause whose CTEs shadow the RLS-configured tables.

        NOTE(review): splits on the first literal 'with'/'WITH' substring, so a
        lowercase 'with' inside an identifier or string literal would confuse
        this — confirm the upstream SQL shape before hardening.
        """
        cte_sql = ''
        sql_splits = ['']
        origin_sql_has_cte = False
        if 'with' in sql:
            sql_splits = sql.split('with')
            origin_sql_has_cte = True
        elif 'WITH' in sql:
            sql_splits = sql.split('WITH')
            origin_sql_has_cte = True
        else:
            sql_splits.append(sql)

        for table_name, sub_query in table_config.items():
            if '.' in table_name:
                # 如果表名包含schema name(格式: schema.table), 则将.替换成__
                # (a CTE name cannot contain a dot)
                schema_name, table_name_alone = table_name.split('.')
                table_name_replaced = table_name.replace('.', '__')
                if table_name in sql_splits[1]:
                    # 替换带schema的表名
                    sql_splits[1] = re.sub(r'\b{}\b'.format(table_name), table_name_replaced, sql_splits[1])
                elif table_name_alone in sql_splits[1]:
                    # 替换不带schema的表名
                    sql_splits[1] = re.sub(r'\b{}\b'.format(table_name_alone), table_name_replaced, sql_splits[1])

                table_name = table_name_replaced
            cte_sql += f"/* rls applied */ {table_name} AS {sub_query},\n"
        # trim the trailing ',\n'; keep the ',' when the original SQL already
        # had a WITH clause, since its own first CTE follows ours
        if origin_sql_has_cte:
            cte_sql = cte_sql[:-1]
        else:
            cte_sql = cte_sql[:-2]

        return f'''WITH
{cte_sql}
{sql_splits[1]}'''

    def post_sql_generation(self, sql: str, rls_config: str = None, login_user: LoginUser = None) -> str:
        """Apply RLS rewriting when this source supports it and a config is
        given; otherwise return the input SQL unchanged."""
        if self.row_level_security_mode() != RowLevelSecurityMode.NONE and rls_config is not None:
            return self.row_level_security_control(sql, rls_config, login_user)
        # 默认直接返回输入的SQL
        return sql
class DataSourceFactory:
    """Creates the DataSourceBase implementation matching a database type."""

    @staticmethod
    def get_data_source(data_source_type) -> DataSourceBase:
        """Return the data source for the given db type string; unknown types
        fall back to the default (no RLS) implementation."""
        registry = {
            "mysql": MySQLDataSource,
            "clickhouse": ClickHouseDataSource,
        }
        source_cls = registry.get(data_source_type, DefaultDataSoruce)
        return source_cls()

    @staticmethod
    def apply_row_level_security_for_sql(db_type: str, sql: str, rls_config: str, username: str):
        """Apply row-level-security rewriting to sql on behalf of username."""
        data_source = DataSourceFactory.get_data_source(db_type)
        return data_source.post_sql_generation(
            sql,
            rls_config=rls_config,
            login_user=LoginUser(username))
import json
import logging

from nlq.data_access.dynamo_query_log import DynamoQueryLogDao
from utils.logging import getLogger

logger = getLogger()


class FeedBackManagement:
    """Persists user feedback entries through the DynamoDB query-log DAO."""

    # NOTE(review): attribute name is misspelled ("dynammo") but kept as-is —
    # renaming a public class attribute would break external callers
    dynammo_log_dao = DynamoQueryLogDao()

    @classmethod
    def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str,
                            log_type='chat_history'):
        """Write one feedback/log record via the Dynamo DAO."""
        cls.dynammo_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id,
                                    sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str,
                                    log_type=log_type)
import json

from nlq.data_access.opensearch_query_log import OpenSearchQueryLogDao
from utils.logging import getLogger

logger = getLogger()


class LogManagement:
    """Business layer over the query-log store (chat history)."""

    # query_log_dao = DynamoQueryLogDao()
    # OpenSearch is the active backing store for query logs
    query_log_dao = OpenSearchQueryLogDao()

    @classmethod
    def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str,
                            log_type='chat_history'):
        """Append one log record (query, generated SQL, intent, payload)."""
        cls.query_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id,
                                  sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str,
                                  log_type=log_type)

    @classmethod
    def get_history(cls, user_id, profile_name, log_type="chat_history"):
        """Return the full history for a user/profile."""
        history_list = cls.query_log_dao.get_history_by_user_profile(user_id, profile_name, log_type)
        return history_list

    @classmethod
    def get_history_by_session(cls, session_id, user_id, profile_name, size, log_type):
        """Return a flat ["user:...", "assistant:..."] transcript for a session.

        The assistant side is the logged answer's 'query_rewrite' field,
        parsed from the JSON-encoded 'log_info'.
        """
        user_query_history = []
        history_list = cls.query_log_dao.get_logs_by_session(profile_name=profile_name,
                                                             session_id=session_id,
                                                             user_id=user_id,
                                                             size=size,
                                                             log_type=log_type)
        for log in history_list:
            logger.info("the opensearch log is : {log}".format(log=log))
            answer = json.loads(log['log_info'])
            user_query_history.append("user:" + log['query'])
            user_query_history.append("assistant:" + answer['query_rewrite'])
        return user_query_history

    @classmethod
    def get_all_sessions(cls, user_id, profile_name, log_type):
        """Return all session records for a user/profile."""
        session_list = cls.query_log_dao.get_all_history(profile_name=profile_name,
                                                         user_id=user_id,
                                                         log_type=log_type)
        return session_list

    @classmethod
    def get_all_history_by_session(cls, session_id, user_id, profile_name, size, log_type):
        """Return the raw log records of one session (unformatted)."""
        history_list = cls.query_log_dao.get_logs_by_session(profile_name=profile_name,
                                                             session_id=session_id,
                                                             user_id=user_id,
                                                             size=size,
                                                             log_type=log_type)
        return history_list

    @classmethod
    def delete_history_by_session(cls, user_id, profile_name, session_id, log_type="chat_history"):
        """Delete a session's history.

        NOTE(review): log_type is accepted but not forwarded to the DAO —
        confirm whether the DAO call should filter by log type.
        """
        return cls.query_log_dao.delete_history_by_session(user_id, profile_name, session_id)
from nlq.data_access.dynamo_model import ModelConfigDao, ModelConfigEntity
from utils.logging import getLogger

logger = getLogger()


class ModelManagement:
    """Business layer for model configurations stored via ModelConfigDao."""

    model_config_dao = ModelConfigDao()

    @classmethod
    def get_all_models(cls):
        """Return all configured model ids."""
        logger.info('get all models...')
        return [conn.model_id for conn in cls.model_config_dao.get_model_list()]

    @classmethod
    def get_all_models_with_info(cls):
        """Return {model_id: {model_region, prompt_template, input_payload, output_format}}."""
        logger.info('get all models with info...')
        model_map = {}
        for model in cls.model_config_dao.get_model_list():
            model_map[model.model_id] = {
                'model_region': model.model_region,
                'prompt_template': model.prompt_template,
                'input_payload': model.input_payload,
                'output_format': model.output_format
            }
        return model_map

    @classmethod
    def add_model(cls, model_id, model_region, prompt_template, input_payload, output_format):
        """Create and persist a model configuration."""
        entity = ModelConfigEntity(model_id, model_region, prompt_template, input_payload, output_format)
        cls.model_config_dao.add(entity)
        logger.info(f"Model {model_id} added")

    @classmethod
    def get_model_by_id(cls, model_id):
        """Fetch one model configuration by id."""
        return cls.model_config_dao.get_by_id(model_id)

    @classmethod
    def update_model(cls, model_id, model_region, prompt_template, input_payload, output_format):
        """Overwrite an existing model configuration."""
        entity = ModelConfigEntity(model_id, model_region, prompt_template, input_payload, output_format)
        cls.model_config_dao.update(entity)
        logger.info(f"Model {model_id} updated")

    @classmethod
    def delete_model(cls, model_id):
        """Delete a model configuration by id."""
        cls.model_config_dao.delete(model_id)
        # BUGFIX: log message previously said "updated" on deletion
        logger.info(f"Model {model_id} deleted")
import pandas as pd
import re
import logging

logger = logging.getLogger(__name__)


class NLQChain:
    """State of one NL-to-SQL interaction for a data profile: the question,
    retrieved samples, the LLM response, the extracted SQL and the executed
    result DataFrame."""

    def __init__(self, profile):
        self.question = ''
        self.profile = profile
        self.retrieve_samples = []
        self.generated_sql_response = ''
        self.executed_result_df: pd.DataFrame | None = None
        self.visualization_config_change: bool = False
        self.sql = ''

    def set_question(self, question):
        """Set the current question; a changed question invalidates the cached
        samples, LLM response and executed results."""
        if self.question != question:
            self.retrieve_samples = []
            self.generated_sql_response = ''
            self.executed_result_df = None
            self.question = question

    def get_question(self):
        return self.question

    def get_profile(self):
        return self.profile

    def get_retrieve_samples(self):
        return self.retrieve_samples

    def set_retrieve_samples(self, retrieve_samples):
        self.retrieve_samples = retrieve_samples

    def set_generated_sql_response(self, sql_response):
        self.generated_sql_response = sql_response

    def get_generated_sql_response(self):
        return self.generated_sql_response

    def set_generated_sql(self, sql):
        self.sql = sql

    def get_generated_sql(self):
        """Return the SQL: the explicitly-set value if present, otherwise the
        text between <sql> and </sql> tags in the LLM response ('' if absent)."""
        if self.sql != "":
            return self.sql
        try:
            # BUGFIX: the tag delimiters had been lost — str.split("") raises
            # ValueError("empty separator"); restore the <sql>...</sql> markers
            # the LLM response is prompted to emit
            return self.generated_sql_response.split("<sql>")[1].split("</sql>")[0]
        except IndexError:
            logger.error("No SQL found in the LLM's response")
            return ""

    def get_generated_sql_explain(self):
        """Return the explanation text after the closing </sql> tag, or the
        whole response when no tag is present."""
        index = self.generated_sql_response.find("</sql>")
        if index != -1:
            return self.generated_sql_response[index + len("</sql>"):]
        else:
            return self.generated_sql_response

    def set_executed_result_df(self, df):
        self.executed_result_df = df

    def get_executed_result_df(self, profile, force_execute_query=True):
        """Return cached results; when absent and force_execute_query is True,
        execute the generated SQL against the profile's database. Returns an
        empty DataFrame when no SQL is available."""
        if self.executed_result_df is None and force_execute_query:
            # imported lazily so this module can be used without the DB layer
            # (these were module-level imports originally)
            from nlq.business.connection import ConnectionManagement
            from utils.apis import query_from_sql_pd

            db_url = profile['db_url']
            if not db_url:
                conn_name = profile['conn_name']
                db_url = ConnectionManagement.get_db_url_by_name(conn_name)
            sql = self.get_generated_sql()
            if sql == "":
                return pd.DataFrame()
            self.executed_result_df = query_from_sql_pd(
                p_db_url=db_url,
                query=sql)

        return self.executed_result_df

    def set_visualization_config_change(self, change_value=True):
        self.visualization_config_change = change_value

    def is_visualization_config_changed(self):
        return self.visualization_config_change
profile_map 34 | 35 | @classmethod 36 | def add_profile(cls, profile_name, conn_name, schemas, tables, comment, db_type: str): 37 | entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, db_type=db_type) 38 | cls.profile_config_dao.add(entity) 39 | logger.info(f"Profile {profile_name} added") 40 | 41 | @classmethod 42 | def get_profile_by_name(cls, profile_name): 43 | return cls.profile_config_dao.get_by_name(profile_name) 44 | 45 | @classmethod 46 | def update_profile(cls, profile_name, conn_name, schemas, tables, comment, tables_info, db_type, rls_enable, rls_config): 47 | all_profiles = ProfileManagement.get_all_profiles_with_info() 48 | prompt_map = all_profiles[profile_name]["prompt_map"] 49 | entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, tables_info, prompt_map, 50 | db_type=db_type, 51 | enable_row_level_security=rls_enable, row_level_security_config=rls_config) 52 | cls.profile_config_dao.update(entity) 53 | logger.info(f"Profile {profile_name} updated") 54 | 55 | @classmethod 56 | def update_prompt_map(cls, profile_name, prompt_map): 57 | profile_info = ProfileManagement.get_profile_by_name(profile_name) 58 | entity = ProfileConfigEntity(profile_name, profile_info.conn_name, profile_info.schemas, profile_info.tables, profile_info.comments, 59 | tables_info=profile_info.tables_info, prompt_map=prompt_map, 60 | db_type=profile_info.db_type, 61 | enable_row_level_security=profile_info.enable_row_level_security, 62 | row_level_security_config=profile_info.row_level_security_config) 63 | cls.profile_config_dao.update(entity) 64 | logger.info(f"Profile {profile_name} updated") 65 | 66 | @classmethod 67 | def delete_profile(cls, profile_name): 68 | cls.profile_config_dao.delete(profile_name) 69 | logger.info(f"Profile {profile_name} updated") 70 | 71 | @classmethod 72 | def update_table_def(cls, profile_name, tables_info, merge_before_update=False): 73 | if merge_before_update: 74 | old_profile = 
cls.get_profile_by_name(profile_name) 75 | old_tables_info = old_profile.tables_info 76 | if old_tables_info is not None: 77 | # print(old_tables_info) 78 | for table_name, table_info in tables_info.items(): 79 | # copy annotation to new table info if old table has annotation 80 | if table_name in old_tables_info and 'tbl_a' in old_tables_info[table_name]: 81 | table_info['tbl_a'] = old_tables_info[table_name]['tbl_a'] 82 | table_info['col_a'] = old_tables_info[table_name]['col_a'] 83 | 84 | logger.info('tables info merged', tables_info) 85 | 86 | cls.profile_config_dao.update_table_def(profile_name, tables_info) 87 | logger.info(f"Table definition updated") 88 | 89 | @classmethod 90 | def update_table_prompt_map(cls, profile_name, prompt_map): 91 | cls.profile_config_dao.update_table_prompt_map(profile_name, prompt_map) 92 | logger.info(f"System and user prompt updated") 93 | -------------------------------------------------------------------------------- /application/nlq/business/suggested_question.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from nlq.data_access.dynamo_suggested_question import SuggestedQuestionDao, SuggestedQuestionEntity 3 | from datetime import datetime, timezone 4 | from utils.constant import PROFILE_QUESTION_TABLE_NAME, ACTIVE_PROMPT_NAME, DEFAULT_PROMPT_NAME 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class SuggestedQuestionManagement: 9 | sq_dao = SuggestedQuestionDao() 10 | 11 | @classmethod 12 | def get_prompt_by_name(cls, prompt_name: str): 13 | return cls.sq_dao.get_by_name(prompt_name) 14 | 15 | @classmethod 16 | def update_prompt(cls, prompt: str): 17 | current_time = datetime.now(timezone.utc) 18 | formatted_time = current_time.strftime("%Y-%m-%dT%H:%M:%SZ") 19 | logger.info(f"Creation time: %s", formatted_time) 20 | entity = SuggestedQuestionEntity(prompt, formatted_time) 21 | cls.sq_dao.update(entity) 22 | logger.info("Prompt updated") 23 | 24 | @classmethod 
25 | def reset_to_default(cls): 26 | response = cls.sq_dao.get_by_name(DEFAULT_PROMPT_NAME) 27 | logger.info(response.prompt) 28 | return response.prompt 29 | -------------------------------------------------------------------------------- /application/nlq/business/user_profile.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from nlq.data_access.dynamo_user_profile import UserProfileConfigDao, UserProfileConfigEntity 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | class UserProfileManagement: 7 | user_profile_config_dao = UserProfileConfigDao() 8 | 9 | @classmethod 10 | def get_all_user_profiles(cls): 11 | logger.info('get all profiles...') 12 | logger.info(cls.user_profile_config_dao.get_user_profile_list()) 13 | logger.info('=' * 20) 14 | return [item.user_id for item in cls.user_profile_config_dao.get_user_profile_list()] 15 | 16 | @classmethod 17 | def get_all_user_profiles_with_info(cls): 18 | logger.info('get all user_profiles with info...') 19 | get_user_profile_list = cls.user_profile_config_dao.get_user_profile_list() 20 | user_profile_map = {} 21 | for user_profile in get_user_profile_list: 22 | user_profile_map[user_profile.user_id] = { 23 | 'profile_name_list': user_profile.profile_name_list 24 | } 25 | return user_profile_map 26 | 27 | @classmethod 28 | def add_user_profile(cls, user_id, profile_name): 29 | print(user_id) 30 | print(profile_name) 31 | entity = UserProfileConfigEntity(user_id, profile_name) 32 | print(entity) 33 | cls.user_profile_config_dao.add(entity) 34 | logger.info(f"User Profile {user_id} added") 35 | 36 | @classmethod 37 | def get_user_profile_by_id(cls, user_id): 38 | return cls.user_profile_config_dao.get_by_name(user_id) 39 | 40 | @classmethod 41 | def update_user_profile(cls, user_id, profile_name_list): 42 | entity = UserProfileConfigEntity(user_id,profile_name_list) 43 | old_user_profile_list = cls.get_user_profile_by_id(user_id) 44 | if old_user_profile_list: 
45 | cls.user_profile_config_dao.update(entity) 46 | else: 47 | cls.user_profile_config_dao.add(entity) 48 | 49 | logger.info(f"User Profile {user_id} updated") 50 | 51 | @classmethod 52 | def delete_user_profile(cls, user_id): 53 | cls.user_profile_config_dao.delete(user_id) 54 | logger.info(f"User Profile {user_id} updated") 55 | -------------------------------------------------------------------------------- /application/nlq/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/application/nlq/core/__init__.py -------------------------------------------------------------------------------- /application/nlq/core/chat_context.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Dict, Any 3 | 4 | @dataclass 5 | class ProcessingContext: 6 | search_box: str 7 | query_rewrite: str 8 | session_id: str 9 | user_id: str 10 | username: str 11 | selected_profile: str 12 | database_profile: Dict[str, Any] 13 | model_type: str 14 | use_rag_flag: bool 15 | intent_ner_recognition_flag: bool 16 | agent_cot_flag: bool 17 | explain_gen_process_flag: bool 18 | visualize_results_flag: bool 19 | data_with_analyse: bool 20 | gen_suggested_question_flag: bool 21 | auto_correction_flag: bool 22 | context_window: int 23 | entity_same_name_select: Dict[str, Any] 24 | user_query_history: List[str] 25 | opensearch_info: Dict[str, Any] 26 | previous_state: str = "INITIAL" 27 | entity_retrieval: List[str] = field(default_factory=list) 28 | entity_user_select: List[str] = field(default_factory=list) 29 | -------------------------------------------------------------------------------- /application/nlq/core/state.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 
| class QueryState(Enum): 5 | INITIAL = auto() 6 | ENTITY_RETRIEVAL = auto() 7 | QA_RETRIEVAL = auto() 8 | SQL_GENERATION = auto() 9 | INTENT_RECOGNITION = auto() 10 | SEARCH_INTENT = auto() 11 | AGENT_SEARCH = auto() 12 | REJECT_INTENT = auto() 13 | KNOWLEDGE_SEARCH = auto() 14 | EXECUTE_QUERY = auto() 15 | ANALYZE_DATA = auto() 16 | AGENT_TASK = auto() 17 | AGENT_DATA_SUMMARY = auto() 18 | ASK_ENTITY_SELECT = auto() 19 | ASK_QUERY_REWRITE = auto() 20 | QUERY_REWRITE = auto() 21 | USER_SELECT_ENTITY = auto() 22 | DATA_VISUALIZATION = auto() 23 | ERROR = auto() 24 | COMPLETE = auto() -------------------------------------------------------------------------------- /application/nlq/data_access/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/application/nlq/data_access/__init__.py -------------------------------------------------------------------------------- /application/nlq/data_access/dynamo_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import boto3 4 | from botocore.exceptions import ClientError 5 | 6 | from utils.logging import getLogger 7 | 8 | 9 | logger = getLogger() 10 | 11 | # DynamoDB table name 12 | MODEL_CONFIG_TABLE_NAME = 'NlqModelConfig' 13 | DYNAMODB_AWS_REGION = os.environ.get('DYNAMODB_AWS_REGION') 14 | 15 | class ModelConfigEntity: 16 | 17 | def __init__(self, model_id: str, model_region: str, prompt_template: str, 18 | input_payload: str, output_format: str): 19 | self.model_id = model_id 20 | self.model_region = model_region 21 | self.prompt_template = prompt_template 22 | self.input_payload = input_payload 23 | self.output_format = output_format 24 | 25 | 26 | def to_dict(self): 27 | """Convert to DynamoDB item format""" 28 | base_props = { 29 | 'model_id': self.model_id, 30 | 'model_region': self.model_region, 31 | 
'prompt_template': self.prompt_template, 32 | 'input_payload': self.input_payload, 33 | 'output_format': self.output_format 34 | } 35 | return base_props 36 | 37 | 38 | class ModelConfigDao: 39 | 40 | def __init__(self, table_name_prefix=''): 41 | self.dynamodb = boto3.resource('dynamodb', region_name=DYNAMODB_AWS_REGION) 42 | self.table_name = table_name_prefix + MODEL_CONFIG_TABLE_NAME 43 | if not self.exists(): 44 | self.create_table() 45 | self.table = self.dynamodb.Table(self.table_name) 46 | 47 | def exists(self): 48 | """ 49 | Determines whether a table exists. As a side effect, stores the table in 50 | a member variable. 51 | 52 | :param table_name: The name of the table to check. 53 | :return: True when the table exists; otherwise, False. 54 | """ 55 | try: 56 | table = self.dynamodb.Table(self.table_name) 57 | table.load() 58 | exists = True 59 | except ClientError as err: 60 | if err.response["Error"]["Code"] == "ResourceNotFoundException": 61 | exists = False 62 | logger.info("Table does not exist") 63 | else: 64 | logger.error( 65 | "Couldn't check for existence of %s. Here's why: %s: %s", 66 | self.table_name, 67 | err.response["Error"]["Code"], 68 | err.response["Error"]["Message"], 69 | ) 70 | raise 71 | # else: 72 | # self.table = table 73 | return exists 74 | 75 | def create_table(self): 76 | try: 77 | self.table = self.dynamodb.create_table( 78 | TableName=self.table_name, 79 | KeySchema=[ 80 | {"AttributeName": "model_id", "KeyType": "HASH"}, # Partition key 81 | # {"AttributeName": "title", "KeyType": "RANGE"}, # Sort key 82 | ], 83 | AttributeDefinitions=[ 84 | {"AttributeName": "model_id", "AttributeType": "S"}, 85 | # {"AttributeName": "conn_name", "AttributeType": "S"}, 86 | ], 87 | BillingMode='PAY_PER_REQUEST', 88 | ) 89 | self.table.wait_until_exists() 90 | logger.info(f"DynamoDB Table {self.table_name} created") 91 | except ClientError as err: 92 | print(type(err)) 93 | logger.error( 94 | "Couldn't create table %s. 
Here's why: %s: %s", 95 | self.table_name, 96 | err.response["Error"]["Code"], 97 | err.response["Error"]["Message"], 98 | ) 99 | raise 100 | 101 | def get_by_id(self, model_id): 102 | response = self.table.get_item(Key={'model_id': model_id}) 103 | if 'Item' in response: 104 | return ModelConfigEntity(**response['Item']) 105 | 106 | def add(self, entity): 107 | self.table.put_item(Item=entity.to_dict()) 108 | 109 | def update(self, entity): 110 | self.table.put_item(Item=entity.to_dict()) 111 | 112 | def delete(self, model_id): 113 | self.table.delete_item(Key={'model_id': model_id}) 114 | return True 115 | 116 | def get_model_list(self): 117 | response = self.table.scan() 118 | return [ModelConfigEntity(**item) for item in response['Items']] 119 | 120 | -------------------------------------------------------------------------------- /application/nlq/data_access/dynamo_suggested_question.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | 3 | from utils.logging import getLogger 4 | from utils.prompt import SUGGESTED_QUESTION_PROMPT_CLAUDE3 5 | import boto3 6 | import os 7 | from botocore.exceptions import ClientError 8 | from utils.constant import PROFILE_QUESTION_TABLE_NAME, ACTIVE_PROMPT_NAME, DEFAULT_PROMPT_NAME 9 | 10 | logger = getLogger() 11 | 12 | class SuggestedQuestionEntity: 13 | 14 | def __init__(self, prompt: str, create_time: str, prompt_name: str = ACTIVE_PROMPT_NAME): 15 | self.prompt_name = prompt_name 16 | self.prompt = prompt 17 | self.create_time = create_time 18 | 19 | def to_dict(self): 20 | """Convert to DynamoDB item format""" 21 | base_props = { 22 | 'prompt_name': self.prompt_name, 23 | 'prompt': self.prompt, 24 | 'create_time': self.create_time 25 | } 26 | return base_props 27 | 28 | class SuggestedQuestionDao: 29 | 30 | def __init__(self, table_name_prefix=''): 31 | self.dynamodb = boto3.resource('dynamodb', region_name=os.getenv("DYNAMODB_AWS_REGION")) 32 | 
self.table_name = table_name_prefix + PROFILE_QUESTION_TABLE_NAME 33 | if not self.exists(): 34 | self.create_table() 35 | self.table = self.dynamodb.Table(self.table_name) 36 | 37 | def exists(self): 38 | """ 39 | Determines whether a table exists. As a side effect, stores the table in 40 | a member variable. 41 | 42 | :param table_name: The name of the table to check. 43 | :return: True when the table exists; otherwise, False. 44 | """ 45 | try: 46 | table = self.dynamodb.Table(self.table_name) 47 | table.load() 48 | exists = True 49 | except ClientError as err: 50 | if err.response["Error"]["Code"] == "ResourceNotFoundException": 51 | exists = False 52 | logger.info("Table does not exist") 53 | else: 54 | logger.error( 55 | "Couldn't check for existence of %s. Here's why: %s: %s", 56 | self.table_name, 57 | err.response["Error"]["Code"], 58 | err.response["Error"]["Message"], 59 | ) 60 | raise 61 | 62 | return exists 63 | 64 | def create_table(self): 65 | try: 66 | self.table = self.dynamodb.create_table( 67 | TableName=self.table_name, 68 | KeySchema=[ 69 | {"AttributeName": "prompt_name", "KeyType": "HASH"}, 70 | ], 71 | AttributeDefinitions=[ 72 | {"AttributeName": "prompt_name", "AttributeType": "S"}, 73 | ], 74 | ProvisionedThroughput={ 75 | "ReadCapacityUnits": 2, 76 | "WriteCapacityUnits": 1, 77 | }, 78 | ) 79 | self.table.wait_until_exists() 80 | 81 | # Add default prompt 82 | current_time = datetime.now(timezone.utc) 83 | formatted_time = current_time.strftime("%Y-%m-%dT%H:%M:%SZ") 84 | item = { 85 | "prompt_name": DEFAULT_PROMPT_NAME, 86 | "prompt": SUGGESTED_QUESTION_PROMPT_CLAUDE3, 87 | "create_time": formatted_time, 88 | } 89 | self.table.put_item(Item=item) 90 | 91 | # Add active prompt 92 | current_time = datetime.now(timezone.utc) 93 | formatted_time = current_time.strftime("%Y-%m-%dT%H:%M:%SZ") 94 | item = { 95 | "prompt_name": ACTIVE_PROMPT_NAME, 96 | "prompt": SUGGESTED_QUESTION_PROMPT_CLAUDE3, 97 | "create_time": formatted_time, 98 | } 99 | 
self.table.put_item(Item=item) 100 | logger.info("Item added successfully to table %s.", self.table_name) 101 | except ClientError as err: 102 | logger.error(type(err)) 103 | logger.error( 104 | "Couldn't create table %s. Here's why: %s: %s", 105 | self.table_name, 106 | err.response["Error"]["Code"], 107 | err.response["Error"]["Message"], 108 | ) 109 | raise 110 | 111 | def get_by_name(self, prompt_name): 112 | response = self.table.get_item(Key={'prompt_name': prompt_name}) 113 | if 'Item' in response: 114 | return SuggestedQuestionEntity(**response['Item']) 115 | 116 | def update(self, entity): 117 | self.table.put_item(Item=entity.to_dict()) 118 | -------------------------------------------------------------------------------- /application/opensearch_deploy.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dotenv import load_dotenv 3 | import os 4 | import boto3 5 | import logging 6 | 7 | from nlq.business.vector_store import VectorStore 8 | from utils.opensearch import get_opensearch_cluster_client, opensearch_index_init 9 | from utils.env_var import opensearch_info 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | load_dotenv() 14 | 15 | SAGEMAKER_ENDPOINT_EMBEDDING = os.getenv('SAGEMAKER_ENDPOINT_EMBEDDING', '') 16 | 17 | BEDROCK_REGION = os.getenv('BEDROCK_REGION') 18 | 19 | 20 | 21 | def index_to_opensearch(): 22 | 23 | opensearch_client = get_opensearch_cluster_client(opensearch_info["domain"], opensearch_info["host"], opensearch_info["port"], 24 | opensearch_info["username"], opensearch_info["password"], opensearch_info["region"]) 25 | 26 | def create_vector_embedding_with_bedrock(text, index_name, bedrock_client): 27 | payload = {"inputText": f"{text}"} 28 | body = json.dumps(payload) 29 | modelId = "amazon.titan-embed-text-v1" 30 | accept = "application/json" 31 | contentType = "application/json" 32 | 33 | response = bedrock_client.invoke_model( 34 | body=body, modelId=modelId, 
accept=accept, contentType=contentType 35 | ) 36 | response_body = json.loads(response.get("body").read()) 37 | 38 | embedding = response_body.get("embedding") 39 | return {"_index": index_name, "text": text, "vector_field": embedding} 40 | 41 | def get_bedrock_client(region): 42 | bedrock_client = boto3.client("bedrock-runtime", region_name=region) 43 | return bedrock_client 44 | 45 | def create_vector_embedding_with_sagemaker(text, index_name, sagemaker_client): 46 | model_kwargs = {} 47 | model_kwargs["batch_size"] = 12 48 | model_kwargs["max_length"] = 512 49 | model_kwargs["return_type"] = "dense" 50 | 51 | response_model = sagemaker_client.invoke_endpoint( 52 | EndpointName=SAGEMAKER_ENDPOINT_EMBEDDING, 53 | Body=json.dumps({"inputs": [text], **model_kwargs}), 54 | ContentType="application/json", 55 | ) 56 | # 中文instruction => 为这个句子生成表示以用于检索相关文章: 57 | json_str = response_model["Body"].read().decode("utf8") 58 | json_obj = json.loads(json_str) 59 | embeddings = json_obj["sentence_embeddings"] 60 | return {"_index": index_name, "text": text, "vector_field": embeddings["dense_vecs"][0]} 61 | 62 | def get_sagemaker_client(): 63 | sagemaker_client = boto3.client("sagemaker-runtime") 64 | return sagemaker_client 65 | 66 | # Initialize 67 | if SAGEMAKER_ENDPOINT_EMBEDDING: 68 | dimension = 1024 69 | sagemaker_client = get_sagemaker_client() 70 | else: 71 | dimension = 1536 72 | bedrock_client = get_bedrock_client(BEDROCK_REGION) 73 | 74 | opensearch_index_flag = opensearch_index_init() 75 | if not opensearch_index_flag: 76 | logger.info("OpenSearch Index Create Fail") 77 | else: 78 | current_profile = "entity_insert_test" 79 | entity = "环比" 80 | comment = "环比增长率是指本期和上期相比较的增长率,计算公式为:环比增长率 =(本期数-上期数)/ 上期数 ×100%" 81 | VectorStore.add_entity_sample(current_profile, entity, comment) 82 | 83 | 84 | if __name__ == "__main__": 85 | index_to_opensearch() -------------------------------------------------------------------------------- 
/application/pages/4_🪙_Schema_Description_Management.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from dotenv import load_dotenv 3 | from nlq.business.profile import ProfileManagement 4 | from utils.logging import getLogger 5 | from utils.navigation import make_sidebar 6 | 7 | logger = getLogger() 8 | 9 | def main(): 10 | load_dotenv() 11 | logger.info('start schema management') 12 | st.set_page_config(page_title="Schema Management", ) 13 | make_sidebar() 14 | 15 | if 'current_profile' not in st.session_state: 16 | st.session_state['current_profile'] = '' 17 | 18 | if "update_profile" not in st.session_state: 19 | st.session_state.update_profile = False 20 | 21 | if "profiles_list" not in st.session_state: 22 | st.session_state["profiles_list"] = [] 23 | 24 | if 'profiles' not in st.session_state: 25 | all_profiles = ProfileManagement.get_all_profiles_with_info() 26 | st.session_state['profiles'] = all_profiles 27 | st.session_state["profiles_list"] = list(all_profiles.keys()) 28 | 29 | if st.session_state.update_profile: 30 | logger.info("session_state update_profile get_all_profiles_with_info") 31 | all_profiles = ProfileManagement.get_all_profiles_with_info() 32 | st.session_state["profiles_list"] = list(all_profiles.keys()) 33 | st.session_state['profiles'] = all_profiles 34 | st.session_state.update_profile = False 35 | 36 | with st.sidebar: 37 | st.title("Schema Management") 38 | all_profiles_list = st.session_state["profiles_list"] 39 | if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list: 40 | profile_index = all_profiles_list.index(st.session_state.current_profile) 41 | current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index) 42 | else: 43 | current_profile = st.selectbox("My Data Profiles", all_profiles_list, 44 | index=None, 45 | placeholder="Please select data profile...", key='current_profile_name') 46 
| 47 | if current_profile is not None: 48 | st.session_state['current_profile'] = current_profile 49 | profile_detail = ProfileManagement.get_profile_by_name(current_profile) 50 | 51 | selected_table = st.selectbox("Tables", profile_detail.tables, index=None, placeholder="Please select a table") 52 | if selected_table is not None: 53 | table_info = profile_detail.tables_info[selected_table] 54 | if table_info is not None: 55 | table_ddl = table_info['ddl'] 56 | table_desc = table_info['description'] 57 | table_anno = table_info.get('tbl_a') 58 | column_anno = table_info.get('col_a') 59 | 60 | st.caption(f'Table description: {table_desc}') 61 | tbl_annotation = st.text_input('Table annotation', table_anno) 62 | 63 | if column_anno is not None: 64 | col_annotation_text = column_anno 65 | col_annotation = st.text_area('Column annotation', col_annotation_text, height=500) 66 | else: 67 | col_annotation = st.text_area('Column annotation', table_ddl, height=400, help='''e.g. CREATE TABLE employees ( 68 | id INT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique identifier for each employee', 69 | name VARCHAR(100) NOT NULL COMMENT 'Employee name', 70 | position VARCHAR(50) NOT NULL COMMENT 'Job position, 2 possible values: 'Engineer', 'Manager', 71 | salary DECIMAL(10, 2) COMMENT 'Salary in USD, e.g., 1000.00', 72 | date DATE NOT NULL COMMENT 'Date of joining the company' 73 | ... 
74 | ); 75 | ''') 76 | if st.button('Save', type='primary'): 77 | st.session_state.update_profile = True 78 | origin_tables_info = profile_detail.tables_info 79 | origin_table_info = origin_tables_info[selected_table] 80 | origin_table_info['tbl_a'] = tbl_annotation 81 | origin_table_info['col_a'] = col_annotation 82 | ProfileManagement.update_table_def(current_profile, origin_tables_info) 83 | st.success('saved.') 84 | else: 85 | st.info('Please select data profile in the left sidebar.') 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /application/pages/mainpage.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from nlq.business.vector_store import VectorStore 4 | from utils.navigation import make_sidebar 5 | from utils.opensearch import opensearch_index_init 6 | from utils.prompts.check_prompt import check_model_id_prompt 7 | 8 | 9 | st.set_page_config( 10 | page_title="Generative BI", 11 | page_icon="👋", 12 | ) 13 | 14 | make_sidebar() 15 | 16 | st.write("## Welcome to Generative BI using RAG on Amazon Web Services!👋") 17 | 18 | st.sidebar.success("Select a demo above.") 19 | 20 | st.markdown( 21 | """ 22 | In the data analysis scenario, analysts often need to write multi-round, complex query statements to obtain business insights. 23 | 24 | Amazon Web Services has built an intelligent data analysis assistant solution to address this scenario. Leveraging the powerful natural language understanding capabilities of large language models, non-technical users can query and analyze data through natural language, without needing to master SQL or other professional skills, helping business users obtain data insights and improve decision-making efficiency. 25 | 26 | This guide is based on services such as Amazon Bedrock, Amazon OpenSearch, and Amazon DynamoDB. 
27 | """ 28 | ) 29 | 30 | # Check OpenSearch Index Init and Test Embedding Insert 31 | opensearch_index_init = opensearch_index_init() 32 | if not opensearch_index_init: 33 | st.info("The OpenSearch Index is Error, Please Create OpenSearch Index First!!!") 34 | else: 35 | current_profile = "entity_insert_test" 36 | entity = "Month on month ratio" 37 | comment = "The month on month growth rate refers to the growth rate compared to the previous period, and the calculation formula is: month on month growth rate=(current period number - previous period number)/previous period number x 100%" 38 | VectorStore.add_entity_sample(current_profile, entity, comment) 39 | 40 | check_model_id_prompt() -------------------------------------------------------------------------------- /application/requirements-api.txt: -------------------------------------------------------------------------------- 1 | fastapi~=0.110.1 2 | uvicorn~=0.21.1 3 | websockets==12.0 4 | tabulate~=0.9.0 5 | boto3~=1.33.4 6 | psycopg2-binary==2.9.9 7 | SQLAlchemy==1.4.52 8 | opensearch-py==2.4.2 9 | PyMySQL==1.1.1 10 | python-dotenv~=1.0.0 11 | plotly~=5.18.0 12 | cryptography==42.0.4 13 | langchain~=0.1.11 14 | langchain-core~=0.1.30 15 | sqlparse~=0.4.2 16 | pandas==2.0.3 17 | openpyxl 18 | starrocks==1.0.6 19 | clickhouse-sqlalchemy==0.2.6 20 | sagemaker 21 | python-jose 22 | sqlalchemy-redshift~=0.8.14 23 | numpy==1.26.4 24 | pyhive==0.7.0 25 | thrift==0.20.0 26 | thrift-sasl==0.4.3 27 | sqlalchemy-bigquery==1.11.0 -------------------------------------------------------------------------------- /application/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit~=1.38.0 2 | streamlit-authenticator~=0.3.3 3 | boto3~=1.33.4 4 | psycopg2-binary==2.9.9 5 | SQLAlchemy==1.4.52 6 | opensearch-py==2.4.2 7 | PyMySQL==1.1.1 8 | python-dotenv~=1.0.0 9 | plotly~=5.18.0 10 | cryptography==42.0.4 11 | langchain~=0.1.11 12 | langchain-core~=0.1.30 13 | 
sqlparse~=0.4.2 14 | debugpy 15 | pandas==2.0.3 16 | openpyxl 17 | starrocks==1.0.6 18 | clickhouse-sqlalchemy==0.2.6 19 | sagemaker 20 | fastapi~=0.110.1 21 | sqlalchemy-redshift~=0.8.14 22 | numpy==1.26.4 23 | pyhive==0.7.0 24 | thrift==0.20.0 25 | thrift-sasl==0.4.3 26 | sqlalchemy-bigquery==1.11.0 -------------------------------------------------------------------------------- /application/static/components/JsonViewer/json-viewer.css: -------------------------------------------------------------------------------- 1 | .json-viewer { 2 | color: #000; 3 | padding-left: 20px; 4 | } 5 | 6 | .json-viewer ul { 7 | list-style-type: none; 8 | margin: 0; 9 | margin: 0 0 0 1px; 10 | border-left: 1px dotted #ccc; 11 | padding-left: 2em; 12 | } 13 | 14 | .json-viewer .hide { 15 | display: none; 16 | } 17 | 18 | .json-viewer .type-string { 19 | color: #0B7500; 20 | } 21 | 22 | .json-viewer .type-date { 23 | color: #CB7500; 24 | } 25 | 26 | .json-viewer .type-boolean { 27 | color: #1A01CC; 28 | font-weight: bold; 29 | } 30 | 31 | .json-viewer .type-number { 32 | color: #1A01CC; 33 | } 34 | 35 | .json-viewer .type-null, .json-viewer .type-undefined { 36 | color: #90a; 37 | } 38 | 39 | .json-viewer a.list-link { 40 | color: #000; 41 | text-decoration: none; 42 | position: relative; 43 | } 44 | 45 | .json-viewer a.list-link:before { 46 | color: #aaa; 47 | content: "\25BC"; 48 | position: absolute; 49 | display: inline-block; 50 | width: 1em; 51 | left: -1em; 52 | } 53 | 54 | .json-viewer a.list-link.collapsed:before { 55 | content: "\25B6"; 56 | } 57 | 58 | .json-viewer a.list-link.empty:before { 59 | content: ""; 60 | } 61 | 62 | .json-viewer .items-ph { 63 | color: #aaa; 64 | padding: 0 1em; 65 | } 66 | 67 | .json-viewer .items-ph:hover { 68 | text-decoration: underline; 69 | } 70 | -------------------------------------------------------------------------------- /application/tests/unit_tests/test_row_level_security.py: 
import unittest

from nlq.business.datasource.mysql import MySQLDataSource
from nlq.business.login_user import LoginUser


class TestRLS(unittest.TestCase):
    """Tests for row-level security (RLS) SQL rewriting in MySQLDataSource.

    The datasource enforces RLS by prepending a WITH clause that defines one
    filtered CTE per protected table; the original query then reads from the
    CTEs instead of the raw tables.
    """

    def setUp(self):
        self.datasource = MySQLDataSource()

        # Two-table join, tables referenced without a schema qualifier.
        self.two_table_join_sql = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
FROM customer c
JOIN orders o ON c.`id` = o.`customer_id`
LIMIT 100'''

        # Same join, with an explicit schema qualifier on each table.
        self.two_table_join_sql_with_schema = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
FROM someschema.customer c
JOIN someschema.orders o ON c.`id` = o.`customer_id`
LIMIT 100'''

        # After rewriting, schema-qualified names are flattened to
        # `schema__table` so the CTE names can shadow them.
        self.two_table_join_sql_with_schema_output = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
FROM someschema__customer c
JOIN someschema__orders o ON c.`id` = o.`customer_id`
LIMIT 100'''

        self.expected_rls_enabled_sql = (
            "WITH\n"
            "/* rls applied */ customer AS (SELECT * FROM customer WHERE created_by = 'admin'),\n"
            "/* rls applied */ orders AS (SELECT * FROM orders WHERE territory = 'Asia')\n"
            f"{self.two_table_join_sql}")

        self.expected_rls_enabled_sql_with_schema = (
            "WITH\n"
            "/* rls applied */ someschema__customer AS (SELECT * FROM someschema.customer WHERE created_by = 'admin'),\n"
            "/* rls applied */ someschema__orders AS (SELECT * FROM someschema.orders WHERE territory = 'Asia')\n"
            f"{self.two_table_join_sql_with_schema_output}")

    def test_row_level_security_control(self):
        # One filter per table: customer is restricted to rows created by the
        # logged-in user, orders to a fixed territory.
        rls_config = '''tables:
  - table_name: customer
    columns:
      - column_name: created_by
        column_value: $login_user.username
  - table_name: orders
    columns:
      - column_name: territory
        column_value: Asia'''

        rewritten = self.datasource.row_level_security_control(
            self.two_table_join_sql, rls_config, LoginUser('admin'))
        self.assertEqual(self.expected_rls_enabled_sql, rewritten)

    def test_row_level_security_control_with_schema(self):
        rls_config = '''tables:
  - table_name: someschema.customer
    columns:
      - column_name: created_by
        column_value: $login_user.username
  - table_name: someschema.orders
    columns:
      - column_name: territory
        column_value: Asia'''

        rewritten = self.datasource.row_level_security_control(
            self.two_table_join_sql_with_schema, rls_config, LoginUser('admin'))
        self.assertEqual(self.expected_rls_enabled_sql_with_schema, rewritten)

        # Compatibility: unqualified table names in the SQL should still match
        # schema-qualified entries in the RLS config.
        rewritten_unqualified = self.datasource.row_level_security_control(
            self.two_table_join_sql, rls_config, LoginUser('admin'))
        self.assertEqual(self.expected_rls_enabled_sql_with_schema, rewritten_unqualified)

    def test_cte_replace1(self):
        # A table referenced inside a derived-table subquery must still be
        # redirected to the injected CTE.
        original_sql = """SELECT
    offer_id,
    slot,
    total_revenue
FROM
    (
        SELECT
            offer_id,
            slot,
            SUM(revenue_valid) AS total_revenue,
            ROW_NUMBER() OVER (PARTITION BY offer_id ORDER BY SUM(revenue_valid) DESC) AS rn
        FROM
            buzz_base_report
        GROUP BY
            offer_id, slot
    ) AS subquery
WHERE
    rn <= 5
ORDER BY
    offer_id, total_revenue DESC"""

        modified_sql = self.datasource.replace_table_with_cte(original_sql, {
            'buzz_base_report': "(select * from buzz_base_report where username = 'admin')"
        })

        expected = ("WITH\n/* rls applied */ buzz_base_report AS "
                    f"(select * from buzz_base_report where username = 'admin')\n{original_sql}")
        self.assertEqual(expected, modified_sql)

    def test_cte_replace2(self):
        modified_sql = self.datasource.replace_table_with_cte(self.two_table_join_sql, {
            'customer': "(SELECT * FROM customer WHERE created_by = 'admin')",
            'orders': "(SELECT * FROM orders WHERE territory = 'Asia')"
        })

        self.assertEqual(self.expected_rls_enabled_sql, modified_sql)

    def test_cte_replace3(self):
        # When the query already has a WITH clause, the RLS CTEs must be
        # merged in front of the existing CTE, not nested or duplicated.
        original_sql = """WITH mycte as (
SELECT c.`name`, o.`product`, o.`quantity`, o.`territory` FROM customer c JOIN orders o ON c.`id` = o.`customer_id`
)
select * from mycte LIMIT 100"""

        modified_sql = self.datasource.replace_table_with_cte(original_sql, {
            'customer': "(select * from customer where created_by = 'admin')",
            'orders': "(select * from orders where territory = 'Asia')"
        })

        expected = (
            "WITH\n"
            "/* rls applied */ customer AS (select * from customer where created_by = 'admin'),\n"
            "/* rls applied */ orders AS (select * from orders where territory = 'Asia'),\n"
            " mycte as (\n"
            "SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`"
            " FROM customer c"
            " JOIN orders o ON c.`id` = o.`customer_id`\n"
            ")\n"
            "select * from mycte LIMIT 100")
        self.assertEqual(expected, modified_sql)
def query_from_database(p_db_url: str, query, schema=None):
    """Execute a read-only query and return rows, columns and the sanitized SQL.

    Comments are stripped and the statement type is checked against
    ALLOWED_QUERY_TYPES before execution, so only SELECT statements ever
    reach the database. Returns {"status": "ok"|"error", ...}.
    """
    try:
        if '{RDS_MYSQL_USERNAME}' in p_db_url:
            # Sample-database URL template: substitute credentials from env vars.
            engine = db.create_engine(p_db_url.format(
                RDS_MYSQL_HOST=RDS_MYSQL_HOST,
                RDS_MYSQL_PORT=RDS_MYSQL_PORT,
                RDS_MYSQL_USERNAME=RDS_MYSQL_USERNAME,
                RDS_MYSQL_PASSWORD=RDS_MYSQL_PASSWORD,
                RDS_MYSQL_DBNAME=RDS_MYSQL_DBNAME,
            ))
        else:
            engine = db.create_engine(p_db_url)
        with engine.connect() as connection:
            logger.info(f'{query=}')
            sanitized_query = sqlparse.format(query, strip_comments=True)
            statement_type = sqlparse.parse(sanitized_query)[0].get_type()
            if statement_type not in ALLOWED_QUERY_TYPES:
                return {"status": "error", "message": f"Query type '{statement_type}' is not allowed."}
            cursor = connection.execute(text(sanitized_query))
            rows = cursor.fetchall()
            column_names = list(cursor.keys())
    except ValueError as e:
        logger.exception(e)
        return {"status": "error", "message": str(e)}
    return {
        "status": "ok",
        "data": str(rows),
        "query": sanitized_query,
        "columns": column_names,
    }


def query_from_sql_pd(p_db_url: str, query, schema=None):
    """Run a query and return the result as a pandas DataFrame.

    Errors are logged and an empty DataFrame is returned instead of raising.
    """
    if '{RDS_MYSQL_USERNAME}' in p_db_url:
        engine = db.create_engine(p_db_url.format(
            RDS_MYSQL_HOST=RDS_MYSQL_HOST,
            RDS_MYSQL_PORT=RDS_MYSQL_PORT,
            RDS_MYSQL_USERNAME=RDS_MYSQL_USERNAME,
            RDS_MYSQL_PASSWORD=RDS_MYSQL_PASSWORD,
            RDS_MYSQL_DBNAME=RDS_MYSQL_DBNAME,
        ))
    else:
        engine = db.create_engine(p_db_url)

    with engine.connect() as connection:
        logger.info(f'{query=}')
        frame = pd.DataFrame()
        try:
            frame = pd.read_sql_query(text(query), connection)
        except Exception as e:
            logger.error("query_from_sql_pd is error")
            logger.error(e)
        return frame


def get_sql_result_tool(profile, sql):
    """Execute `sql` against the connection described by `profile`.

    Returns {"data": DataFrame, "sql": str, "status_code": int, "error_info": str}.
    On failure status_code is 500, error_info carries the message and data
    is set to an empty list.
    """
    result_dict = {"data": pd.DataFrame(), "sql": sql, "status_code": 200, "error_info": ""}
    try:
        p_db_url = profile['db_url']
        if not p_db_url:
            # No explicit URL on the profile: resolve it from the named connection.
            conn_name = profile['conn_name']
            p_db_url = ConnectionManagement.get_db_url_by_name(conn_name)

        if '{RDS_MYSQL_USERNAME}' in p_db_url:
            engine = db.create_engine(p_db_url.format(
                RDS_MYSQL_HOST=RDS_MYSQL_HOST,
                RDS_MYSQL_PORT=RDS_MYSQL_PORT,
                RDS_MYSQL_USERNAME=RDS_MYSQL_USERNAME,
                RDS_MYSQL_PASSWORD=RDS_MYSQL_PASSWORD,
                RDS_MYSQL_DBNAME=RDS_MYSQL_DBNAME,
            ))
        elif profile['db_type'] == "bigquery":
            # BigQuery: credentials are stored as a JSON service-account blob.
            password, host = ConnectionManagement.get_db_password_host_by_name(profile['conn_name'])
            engine = db.create_engine(url=host, credentials_info=json.loads(password))
        else:
            engine = db.create_engine(p_db_url)

        with engine.connect() as connection:
            logger.info(f'{sql=}')
            executed_result_df = pd.read_sql_query(text(sql), connection)
            result_dict["data"] = executed_result_df.fillna("")
    except Exception as e:
        logger.error("get_sql_result is error: {}".format(e))
        result_dict["error_info"] = str(e)
        result_dict["status_code"] = 500
        result_dict["data"] = []
    return result_dict
AWS_DEFAULT_REGION = os.getenv("AWS_DEFAULT_REGION")
# Cognito is not available in China regions, so authentication is skipped
# there. BUGFIX: guard against the env var being unset (os.getenv -> None),
# which previously raised AttributeError at import time.
skipAuthentication = (AWS_DEFAULT_REGION or "").startswith("cn")

JWKS_URL = os.getenv("JWKS_URL",
                     f"https://cognito-idp.{VITE_COGNITO_REGION}.amazonaws.com/{USER_POOL_ID}/"
                     ".well-known/jwks.json")

TOKEN_URL = f"{AUTH_PATH}/oauth2/token"

logger = getLogger()


def jwt_decode(token, audience=None, access_token=None):
    """Decode and verify a Cognito JWT against the user pool's JWKS (RS256)."""
    return jwt.decode(
        token, requests.get(JWKS_URL).json(), audience=audience, access_token=access_token, algorithms=["RS256"]
    )


class RefreshTokenError(Exception):
    """Raised when the OAuth2 refresh-token flow fails."""
    ERROR_FMT = 'Refresh token error: {description}'
    description = 'Refresh token flow failed'

    def __init__(self, description=None):
        if description:
            self.description = self.ERROR_FMT.format(description=description)

    def __str__(self):
        return self.description


def refresh_tokens(refresh_token):
    """Exchange a refresh token for fresh access/id tokens at the Cognito
    token endpoint.

    Returns {'accessToken': ..., 'idToken': ...}; raises RefreshTokenError
    when Cognito rejects the request.
    """
    resp = requests.post(
        TOKEN_URL,
        data={"grant_type": 'refresh_token', "refresh_token": refresh_token, "client_id": CLIENT_ID},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )

    if resp.status_code != 200:
        raise RefreshTokenError(resp.json().get('error'))

    values = resp.json()
    return {'accessToken': values.get("access_token"), 'idToken': values.get("id_token")}


def get_cognito_identity_from_token(decoded, claims):
    """Project a decoded JWT payload into an identity dict.

    Picks up the roles claim (USER_ROLES_CLAIM), the username, and any of the
    requested `claims` into identity['attributes'].
    """
    identity = {"attributes": {}}

    if USER_ROLES_CLAIM in decoded:
        identity["user_roles"] = decoded[USER_ROLES_CLAIM]
    if "username" in decoded:
        identity["username"] = decoded["username"]

    for claim in claims:
        if claim in decoded:
            identity["attributes"][claim] = decoded[claim]

    return identity


def _strip_bearer(token):
    """Drop an optional 'Bearer ' prefix from a token header value."""
    if token and token.startswith("Bearer "):
        return token[len("Bearer "):]
    return token


def _unauthorized():
    """Response headers for a rejected request."""
    return {'X-Status-Code': status.HTTP_401_UNAUTHORIZED}


def authenticate(access_token, id_token, refresh_token):
    """Validate the Cognito tokens and return response headers.

    Returns a dict with 'X-Status-Code' (200 or 401) and, on success,
    'X-User-Name' extracted from the verified access/id tokens.
    """
    access_token = _strip_bearer(access_token)
    id_token = _strip_bearer(id_token)
    refresh_token = _strip_bearer(refresh_token)

    if access_token is None or id_token is None or refresh_token is None:
        return _unauthorized()

    # Reject obviously-empty tokens before hitting the JWKS endpoint.
    # (This also covers falsy strings, making a separate `not token` check
    # redundant — the previous third null-check was unreachable.)
    if len(access_token.strip()) < 2 or len(id_token.strip()) < 2 or len(refresh_token.strip()) < 2:
        return _unauthorized()

    try:
        decoded = jwt_decode(access_token)
    except Exception as e:
        # BUGFIX: the original passed str(e) as a %-format argument with no
        # placeholder ("logger.error('...: ', str(e))"), losing the message.
        logger.error('Token decode exception: %s', e)
        return _unauthorized()

    response = {'X-Status-Code': status.HTTP_200_OK}

    claims = ["email"]
    identity = get_cognito_identity_from_token(decoded=decoded, claims=claims)

    # Log instead of print: avoids dumping identity details to stdout.
    logger.info("Identity: %s", identity)

    if id_token:
        decoded_id = jwt_decode(id_token, audience=CLIENT_ID, access_token=access_token)
        identity_from_id_token = get_cognito_identity_from_token(decoded=decoded_id, claims=claims)
        identity.update(identity_from_id_token)

    response["X-User-Name"] = identity["username"]
    return response
def get_all_table_names(db_url: str, is_sample_db: bool, schema: str = None):
    """Return every table name reflected from the database at `db_url`.

    When `is_sample_db` is true, `db_url` is a template whose MySQL
    credentials are filled in from environment variables. Names are returned
    as SQLAlchemy reflects them (schema-qualified when `schema` is given).
    """
    if is_sample_db:
        print('checking connection...')
        db_url = db_url.format(
            RDS_MYSQL_HOST=RDS_MYSQL_HOST,
            RDS_MYSQL_PORT=RDS_MYSQL_PORT,
            RDS_MYSQL_USERNAME=RDS_MYSQL_USERNAME,
            RDS_MYSQL_PASSWORD=RDS_MYSQL_PASSWORD,
            RDS_MYSQL_DBNAME=RDS_MYSQL_DBNAME,
        )
    engine = db.create_engine(db_url)
    with engine.connect() as connection:
        print('connected to database')

        metadata = db.MetaData()
        if schema:
            metadata.reflect(bind=connection, schema=schema)
        else:
            metadata.reflect(bind=connection)
        # metadata.tables maps name -> Table; only the keys are needed
        # (the original iterated .items() and discarded every value).
        return list(metadata.tables)


def get_db_url_dialect(db_url: str) -> str:
    """Extract the dialect from a SQLAlchemy URL.

    E.g. 'mysql+pymysql://u@h/db' -> 'mysql'.
    """
    return db_url.split("://")[0].split('+')[0]


def get_dll_for_tables(db_url: str, is_sample_db: bool, schema: str = None, selected_tables: list = None):
    """Placeholder: return DDL for the selected tables (not yet implemented).

    BUGFIX: `selected_tables` previously defaulted to a mutable [] shared
    across calls; None is the safe sentinel.
    """
    pass
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SearchTextSqlResult:
    """Result of a text-to-SQL search: the query, retrieval context, the raw
    LLM response, and the extracted SQL."""
    search_query: str
    entity_slot_retrieve: list
    retrieve_result: list
    response: str
    sql: str
    # Origin SQL before post-processing; empty when no post-processing ran.
    original_sql: str = ''


@dataclass
class ModelResponse:
    """Raw LLM output plus token accounting."""
    response: str = ''
    text: str = ''
    # BUGFIX annotation: the default is None, so the type is Optional —
    # the previous `dict[str, Any] = None` annotation was misleading.
    token_info: Optional[dict[str, Any]] = None
EMBEDDING_DIMENSION = os.getenv('EMBEDDING_DIMENSION')

AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION')

OPENSEARCH_TYPE = os.getenv('OPENSEARCH_TYPE')

OPENSEARCH_SECRETS_URL_HOST = os.getenv('OPENSEARCH_SECRETS_URL_HOST', 'opensearch-host-url')

OPENSEARCH_SECRETS_USERNAME_PASSWORD = os.getenv('OPENSEARCH_SECRETS_USERNAME_PASSWORD', 'opensearch-master-user')

BEDROCK_SECRETS_AK_SK = os.getenv('BEDROCK_SECRETS_AK_SK', '')

SAGEMAKER_SQL_REGION = os.getenv('SAGEMAKER_SQL_REGION', '')


def get_opensearch_parameter():
    """Fetch OpenSearch host and master credentials from AWS Secrets Manager.

    Returns (host, port, username, password); port is always 443 for the
    managed service. ClientError is re-raised for the caller to handle
    (see the Secrets Manager GetSecretValue API reference for error codes).
    """
    try:
        session = boto3.session.Session()
        sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)

        # Cluster endpoint, e.g. my-test-domain.us-east-1.es.amazonaws.com
        host_secret = sm_client.get_secret_value(SecretId=OPENSEARCH_SECRETS_URL_HOST)['SecretString']
        host = json.loads(host_secret).get('host')

        user_secret = sm_client.get_secret_value(SecretId=OPENSEARCH_SECRETS_USERNAME_PASSWORD)['SecretString']
        credentials = json.loads(user_secret)
        return host, 443, credentials.get('username'), credentials.get('password')
    except ClientError as e:
        raise e


def get_bedrock_parameter():
    """Fetch the Bedrock access key/secret key pair from Secrets Manager.

    Returns {} when no secret is configured or when retrieval fails.
    """
    bedrock_ak_sk_info = {}
    try:
        # Truthiness covers both None and the '' default in one check.
        if BEDROCK_SECRETS_AK_SK:
            session = boto3.session.Session()
            sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
            data = json.loads(sm_client.get_secret_value(SecretId=BEDROCK_SECRETS_AK_SK)['SecretString'])
            bedrock_ak_sk_info['access_key_id'] = data.get('access_key_id')
            bedrock_ak_sk_info['secret_access_key'] = data.get('secret_access_key')
    except ClientError as e:
        logging.error(e)
    return bedrock_ak_sk_info


if OPENSEARCH_TYPE == "service":
    # Managed service: connection details come from Secrets Manager, not env.
    AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD = get_opensearch_parameter()

opensearch_info = {
    'host': AOS_HOST,
    'port': AOS_PORT,
    'username': AOS_USER,
    'password': AOS_PASSWORD,
    'domain': AOS_DOMAIN,
    'region': OPENSEARCH_REGION,
    'sql_index': AOS_INDEX,
    'ner_index': AOS_INDEX_NER,
    'agent_index': AOS_INDEX_AGENT,
    'embedding_dimension': EMBEDDING_DIMENSION
}

query_log_name = os.getenv("QUERY_LOG_INDEX", "genbi_query_logging")

bedrock_ak_sk_info = get_bedrock_parameter()

embedding_info = {
    "embedding_platform": os.getenv('EMBEDDING_PLATFORM', "bedrock"),
    "embedding_name": os.getenv('EMBEDDING_NAME', "amazon.titan-embed-text-v1"),
    "embedding_dimension": int(os.getenv('EMBEDDING_DIMENSION', 1536)),
    "embedding_region": os.getenv('EMBEDDING_REGION', AWS_DEFAULT_REGION)
}

if embedding_info["embedding_platform"] == "bedrock":
    SAGEMAKER_EMBEDDING_REGION = ""
    SAGEMAKER_ENDPOINT_EMBEDDING = ""
else:
    SAGEMAKER_EMBEDDING_REGION = embedding_info["embedding_region"]
    SAGEMAKER_ENDPOINT_EMBEDDING = embedding_info["embedding_name"]
import os
import logging

# Application-wide log level, configurable via the LOG_LEVEL env var.
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()

logger = None


def getLogger():
    """Return the lazily-created, process-wide 'application' logger.

    The logger writes through a single StreamHandler and does not propagate
    to the root logger; repeated calls return the same cached instance.
    """
    global logger
    if logger is None:
        logger = logging.getLogger('application')
        logger.propagate = False
        logger.setLevel(LOG_LEVEL)

        handler = logging.StreamHandler()
        handler.setLevel(LOG_LEVEL)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s [%(module)s L%(lineno)d] [%(levelname)s] %(message)s'))

        # Drop any stale handlers before attaching ours so repeated module
        # reloads never double-log.
        logger.handlers.clear()
        logger.addHandler(handler)
    return logger
# (page_path, label, icon) triples for the "Data Customization" section.
_DATA_PAGES = [
    ("pages/2_🪙_Data_Connection_Management.py", "Data Connection Management", "🪙"),
    ("pages/3_🪙_Data_Profile_Management.py", "Data Profile Management", "🪙"),
    ("pages/4_🪙_Schema_Description_Management.py", "Schema Description Management", "🪙"),
    ("pages/5_🪙_Prompt_Management.py", "Prompt Management", "🪙"),
]

# (page_path, label, icon) triples for the "Performance Enhancement" section.
_ENHANCEMENT_PAGES = [
    ("pages/6_📚_Index_Management.py", "Index Management", "📚"),
    ("pages/7_📚_Entity_Management.py", "Entity Management", "📚"),
    ("pages/8_📚_Agent_Cot_Management.py", "Agent Cot Management", "📚"),
    ("pages/9_🪙_SageMaker_Model_Management.py", "SageMaker Model Management", "🪙"),
    ("pages/10_📚_User_Authorization.py", "User Authorization Management", "📚"),
]


def make_sidebar():
    """Render sidebar navigation for authenticated users; otherwise redirect
    any protected page back to the login (Index) page."""
    with st.sidebar:
        if st.session_state.get('authentication_status'):
            st.page_link("pages/mainpage.py", label="Index")
            st.page_link("pages/1_🌍_Generative_BI_Playground.py", label="Generative BI Playground", icon="🌍")

            st.markdown(":gray[Data Customization Management]",
                        help='Add your own datasources and customize description for LLM to better understand them')
            for path, label, icon in _DATA_PAGES:
                st.page_link(path, label=label, icon=icon)

            st.markdown(":gray[Performance Enhancement]",
                        help='Optimize your LLM for better performance by adding RAG or agent')
            for path, label, icon in _ENHANCEMENT_PAGES:
                st.page_link(path, label=label, icon=icon)

            if st.button("Log out"):
                logout()

        elif get_current_page_name() != "Index":
            # Anyone hitting a protected page while logged out is sent back
            # to the login page.
            st.switch_page("Index.py")


def logout():
    """Clear the auth session, notify the user, and return to the login page."""
    authenticator = get_authenticator()
    authenticator.logout('Logout', 'unrendered')
    st.info("Logged out successfully!")
    sleep(0.5)
    st.switch_page("Index.py")
# Model-specific guidance snippets appended to SQL-generation prompts,
# keyed by model identifier.
guidance_prompt_dict = {
    'haiku-20240307v1-0': """
you should always keep the words from question unchanges when writing SQL. \n\n
""",
    'sonnet-20240229v1-0': """

""",
}


class GuidancePromptMapper:
    """Resolve the guidance prompt for a model key (None when unknown)."""

    def __init__(self):
        self.variable_map = guidance_prompt_dict

    def get_variable(self, name):
        return self.variable_map.get(name)


# Model-specific table prompts; currently empty placeholders kept so every
# model key resolves to a string rather than None.
table_prompt_dict = {
    'haiku-20240307v1-0': """

""",
    'sonnet-20240229v1-0': """

""",
}


class TablePromptMapper:
    """Resolve the table prompt for a model key (None when unknown)."""

    def __init__(self):
        self.variable_map = table_prompt_dict

    def get_variable(self, name):
        return self.variable_map.get(name)
def get_generated_sql(generated_sql_response):
    """Extract the SQL statement from an LLM response.

    Supports either <sql>...</sql> tags or a ```sql fenced block; returns ""
    when no SQL can be found. (BUGFIX: the angle-bracket tag literals had
    been lost, leaving split("") calls that raise ValueError on every input.)
    """
    sql = ""
    try:
        if "<sql>" in generated_sql_response:
            return generated_sql_response.split("<sql>")[1].split("</sql>")[0]
        elif "```sql" in generated_sql_response:
            return generated_sql_response.split("```sql")[1].split("```")[0]
    except IndexError:
        logger.error("No SQL found in the LLM's response")
        logger.error(generated_sql_response)
    return sql


def generate_log_id():
    """Build a unique numeric log id: microsecond timestamp + 4 random digits."""
    timestamp = int(time.time() * 1000000)
    random_part = random.randint(0, 9999)
    log_id = f"{timestamp}{random_part:04d}"
    return log_id


def get_current_time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = datetime.datetime.now()
    formatted_time = now.strftime('%Y-%m-%d %H:%M:%S')
    return formatted_time


def get_generated_sql_explain(generated_sql_response):
    """Return the explanation text that follows the generated SQL.

    The explanation is whatever follows the closing </sql> tag or the closing
    ``` fence; the full response is returned when neither marker is present.
    (BUGFIX: restores the lost <sql>/</sql> tag literals, mirroring
    get_generated_sql.)
    """
    try:
        if "<sql>" in generated_sql_response:
            return generated_sql_response.split("<sql>")[1].split("</sql>")[1]
        elif "```sql" in generated_sql_response:
            return generated_sql_response.split("```sql")[1].split("```")[1]
        else:
            return generated_sql_response
    except IndexError:
        logger.error("No generated found in the LLM's response")
        logger.error(generated_sql_response)
        return generated_sql_response


def change_class_to_str(result):
    """Serialize a pydantic-style object (has .dict()) to a JSON string.

    Returns "" and logs when serialization fails.
    """
    try:
        log_info = json.dumps(result.dict(), default=serialize_timestamp)
        return log_info
    except Exception as e:
        logger.error(f"Error in changing class to string: {e}")
        return ""


def serialize_timestamp(obj):
    """Recursively convert pandas Timestamp / datetime.date values to strings.

    Used both directly and as json.dumps(default=...); raises TypeError for
    values that are not JSON serializable, as the json protocol requires.
    """
    if isinstance(obj, pd.Timestamp):
        return obj.strftime('%Y-%m-%d %H:%M:%S')
    elif isinstance(obj, datetime.date):
        return obj.strftime('%Y-%m-%d %H:%M:%S')
    elif isinstance(obj, list):
        return [serialize_timestamp(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: serialize_timestamp(v) for k, v in obj.items()}
    raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')


def convert_timestamps_to_str(data):
    """Convert Timestamp/date cells in a row-of-lists result set to strings.

    Returns the input unchanged (and logs) if conversion fails.
    """
    try:
        converted_data = []
        for row in data:
            new_row = []
            for item in row:
                if isinstance(item, pd.Timestamp):
                    new_row.append(item.strftime('%Y-%m-%d %H:%M:%S'))
                elif isinstance(item, datetime.date):
                    new_row.append(item.strftime('%Y-%m-%d %H:%M:%S'))
                else:
                    new_row.append(item)
            converted_data.append(new_row)
        return converted_data
    except Exception as e:
        logger.error(f"Error in converting timestamps to strings: {e}")
        return data
-------------------------------------------------------------------------------- /assets/aws_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/aws_architecture.png -------------------------------------------------------------------------------- /assets/bedrock_model_access.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/bedrock_model_access.png -------------------------------------------------------------------------------- /assets/create_data_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/create_data_profile.png -------------------------------------------------------------------------------- /assets/interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/interface.png -------------------------------------------------------------------------------- /assets/logic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/logic.png -------------------------------------------------------------------------------- /assets/react_deploy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/react_deploy.png 
-------------------------------------------------------------------------------- /assets/screenshot-genbi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/screenshot-genbi.png -------------------------------------------------------------------------------- /assets/streamlit_deploy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/streamlit_deploy.png -------------------------------------------------------------------------------- /assets/streamlit_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/streamlit_front.png -------------------------------------------------------------------------------- /assets/update_data_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/update_data_profile.png -------------------------------------------------------------------------------- /assets/update_schema_management.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/update_schema_management.png -------------------------------------------------------------------------------- /assets/user_front_end_cn.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/assets/user_front_end_cn.png -------------------------------------------------------------------------------- /report-front-end/.env: -------------------------------------------------------------------------------- 1 | # General configuration 2 | VITE_APP_VERSION=1.8.0 3 | VITE_TITLE=Guidance for Generative BI on Amazon Web Services 4 | VITE_LOGO= 5 | VITE_RIGHT_LOGO= 6 | VITE_LOGO_DISPLAY_ON_LOGIN_PAGE=true 7 | # VITE_TITLE=Guidance for Generative BI on Amazon Web Services 8 | # VITE_LOGO=/log.png 9 | # VITE_RIGHT_LOGO=/logo.png 10 | 11 | 12 | # Login configuration, e.g. Cognito | None 13 | VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE 14 | 15 | # KEEP the placeholder values if using CDK to deploy the backend! 16 | 17 | # Cognito configuration 18 | VITE_COGNITO_REGION=PLACEHOLDER_VITE_COGNITO_REGION 19 | VITE_COGNITO_USER_POOL_ID=PLACEHOLDER_VITE_COGNITO_USER_POOL_ID 20 | VITE_COGNITO_USER_POOL_WEB_CLIENT_ID=PLACEHOLDER_VITE_COGNITO_USER_POOL_WEB_CLIENT_ID 21 | 22 | # Chat bot configuration 23 | # VITE_SQL_DISPLAY= 24 | VITE_SQL_DISPLAY=yes 25 | 26 | # FastAPI configuration, e.g. http://xxxxxxxx:8000/ 27 | VITE_BACKEND_URL=PLACEHOLDER_VITE_BACKEND_URL 28 | 29 | # Websocket configuration, e.g. 
ws://34.208.51.119:8000/qa/ws 30 | VITE_WEBSOCKET_URL=PLACEHOLDER_VITE_WEBSOCKET_URL 31 | 32 | # SSO CONFIGS: info for logging in with Single-Sign-On 33 | VITE_USE_SSO_LOGIN=false 34 | VITE_SSO_FED_AUTH_PROVIDER=vite.auth.provider 35 | VITE_SSO_OAUTH_DOMAIN=vite-domain.auth.region.amazoncognito.com 36 | -------------------------------------------------------------------------------- /report-front-end/.env.template: -------------------------------------------------------------------------------- 1 | # General configuration 2 | VITE_TITLE=Guidance for Generative BI on Amazon Web Services 3 | VITE_LOGO= 4 | VITE_RIGHT_LOGO= 5 | VITE_LOGO_DISPLAY_ON_LOGIN_PAGE=true 6 | # VITE_TITLE=Guidance for Generative BI on Amazon Web Services 7 | # VITE_LOGO=/log.png 8 | # VITE_RIGHT_LOGO= 9 | 10 | 11 | 12 | # Login configuration, e.g. Cognito | None 13 | VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE 14 | 15 | # KEEP the placeholder values if using CDK to deploy the backend! 16 | 17 | # Cognito configuration 18 | VITE_COGNITO_REGION=PLACEHOLDER_VITE_COGNITO_REGION 19 | VITE_COGNITO_USER_POOL_ID=PLACEHOLDER_VITE_COGNITO_USER_POOL_ID 20 | VITE_COGNITO_USER_POOL_WEB_CLIENT_ID=PLACEHOLDER_VITE_COGNITO_USER_POOL_WEB_CLIENT_ID 21 | 22 | # Chat bot configuration 23 | # VITE_SQL_DISPLAY= 24 | VITE_SQL_DISPLAY=yes 25 | 26 | # FastAPI configuration, e.g. http://xxxxxxxx:8000/ 27 | VITE_BACKEND_URL=PLACEHOLDER_VITE_BACKEND_URL 28 | 29 | # Websocket configuration, e.g. 
ws://34.208.51.119:8000/qa/ws 30 | VITE_WEBSOCKET_URL=PLACEHOLDER_VITE_WEBSOCKET_URL 31 | -------------------------------------------------------------------------------- /report-front-end/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | env: { browser: true, es2020: true }, 4 | extends: [ 5 | "eslint:recommended", 6 | "plugin:@typescript-eslint/recommended", 7 | "plugin:react-hooks/recommended", 8 | ], 9 | ignorePatterns: ["dist", ".eslintrc.cjs"], 10 | parser: "@typescript-eslint/parser", 11 | plugins: ["react-refresh"], 12 | rules: { 13 | "react-refresh/only-export-components": [ 14 | "warn", 15 | { allowConstantExport: true }, 16 | ], 17 | "@typescript-eslint/no-explicit-any": "off", 18 | "no-extra-semi": "off", 19 | }, 20 | }; 21 | -------------------------------------------------------------------------------- /report-front-end/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | /dist 14 | 15 | # misc 16 | .DS_Store 17 | .env.local 18 | .env.development.local 19 | .env.test.local 20 | .env.production.local 21 | 22 | npm-debug.log* 23 | yarn-debug.log* 24 | yarn-error.log* 25 | -------------------------------------------------------------------------------- /report-front-end/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": false, 3 | "trailingComma": "es5", 4 | "semi": true 5 | } 6 | -------------------------------------------------------------------------------- /report-front-end/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/docker/library/node:18.17.0 AS builder 2 | WORKDIR /frontend 3 | COPY package*.json ./ 4 | COPY . . 5 | COPY .env /frontend/.env 6 | 7 | 8 | ARG AWS_REGION 9 | 10 | RUN echo "Current AWS Region: $AWS_REGION" 11 | 12 | RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ 13 | sed -i "s/PLACEHOLDER_VITE_LOGIN_TYPE/None/g" .env && \ 14 | npm config set registry https://registry.npmmirror.com && \ 15 | npm install; \ 16 | else \ 17 | sed -i "s/PLACEHOLDER_VITE_LOGIN_TYPE/Cognito/g" .env && \ 18 | npm install; \ 19 | fi 20 | 21 | RUN npm run build 22 | 23 | FROM public.ecr.aws/docker/library/nginx:1.23-alpine 24 | COPY --from=builder /frontend/dist/ /usr/share/nginx/html/ 25 | COPY --from=builder /frontend/.env /.env 26 | 27 | COPY docker-entry.sh /docker-entry.sh 28 | RUN chmod +x /docker-entry.sh 29 | 30 | EXPOSE 80 31 | ENTRYPOINT ["/docker-entry.sh"] 32 | CMD ["nginx", "-g", "daemon off;"] 33 | 34 | # run on linux ec2/lambda/ecs(amd64) 35 | # docker buildx build --platform linux/amd64 --network host -t my_tag . 
36 | 37 | # run on lambda/macOS(arm64) 38 | # docker buildx build --platform linux/arm64 --network host -t my_tag . -------------------------------------------------------------------------------- /report-front-end/docker-entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Read variable names from the .env file 4 | env_file="/.env" 5 | vars="VITE_COGNITO_REGION 6 | VITE_LOGIN_TYPE 7 | VITE_COGNITO_USER_POOL_WEB_CLIENT_ID 8 | VITE_COGNITO_USER_POOL_ID 9 | VITE_BACKEND_URL 10 | VITE_WEBSOCKET_URL" 11 | 12 | # Iterate through .js files in /usr/share/nginx/html and replace variables 13 | find "/usr/share/nginx/html" -type f -name "*.js" | while read -r file; do 14 | for var in $vars; do 15 | placeholder="PLACEHOLDER_$var" 16 | value=$(eval "echo \$$var") 17 | 18 | # Escape special characters in the value for use in sed 19 | escaped_value=$(printf '%s\n' "$value" | sed -e 's/[\/&]/\\&/g') 20 | 21 | echo "Replacing $placeholder with $escaped_value in $file" 22 | sed -i "s/$placeholder/$escaped_value/g" "$file" 23 | done 24 | done 25 | 26 | exec "$@" 27 | -------------------------------------------------------------------------------- /report-front-end/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 12 | 17 | 18 | 19 | GenBI Chatbot 20 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /report-front-end/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "aws-genai-llm-chatbot", 3 | "private": true, 4 | "version": "1.8.0", 5 | "type": "module", 6 | "scripts": { 7 | "start": "vite --host 0.0.0.0", 8 | "build": "vite build" 9 | }, 10 | "dependencies": { 11 | "@aws-amplify/ui-react": "^5.3.2", 12 | "@cloudscape-design/chat-components": "^1.0.10", 13 | "@cloudscape-design/components": "^3.0.733", 14 | "@cloudscape-design/design-tokens": "^3.0.43", 15 | "@cloudscape-design/global-styles": "^1.0.32", 16 | "aws-amplify": "^5.3.12", 17 | "dotenv": "^16.4.5", 18 | "react": "^18.2.0", 19 | "react-dom": "^18.2.0", 20 | "react-hot-toast": "^2.4.1", 21 | "react-markdown": "^9.0.0", 22 | "react-redux": "^9.1.2", 23 | "react-router-dom": "^6.15.0", 24 | "react-syntax-highlighter": "^15.5.0", 25 | "react-textarea-autosize": "^8.5.3", 26 | "react-use-websocket": "^4.8.1", 27 | "redux": "^5.0.1", 28 | "regenerator-runtime": "^0.14.0", 29 | "umi-request": "^1.4.0", 30 | "uuid": "^9.0.0" 31 | }, 32 | "devDependencies": { 33 | "@types/react": "^18.2.15", 34 | "@types/react-dom": "^18.2.7", 35 | "@types/react-syntax-highlighter": "^15.5.13", 36 | "@types/uuid": "^9.0.3", 37 | "@typescript-eslint/eslint-plugin": "^6.0.0", 38 | "@typescript-eslint/parser": "^6.0.0", 39 | "@vitejs/plugin-react": "^4.0.3", 40 | "autoprefixer": "^10.4.14", 41 | "eslint": "^8.45.0", 42 | "eslint-plugin-react-hooks": "^4.6.0", 43 | "eslint-plugin-react-refresh": "^0.4.3", 44 | "postcss": "^8.4.27", 45 | "sass": "^1.65.1", 46 | "typescript": "^5.0.2", 47 | "vite": "^4.5.2" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /report-front-end/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | 
autoprefixer: {}, 4 | }, 5 | }; 6 | -------------------------------------------------------------------------------- /report-front-end/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/report-front-end/public/favicon.ico -------------------------------------------------------------------------------- /report-front-end/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "/", 3 | "start_url": "/", 4 | "short_name": "React App", 5 | "name": "React App Sample", 6 | "description": "Sample", 7 | "theme_color": "#000000", 8 | "background_color": "#ffffff", 9 | "display": "standalone", 10 | "icons": [ 11 | { 12 | "src": "favicon.ico", 13 | "sizes": "64x64 32x32 24x24 16x16", 14 | "type": "image/x-icon" 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /report-front-end/public/smile-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-bi-using-rag/19db1a855e19167008a0adde451099b0c5726559/report-front-end/public/smile-logo.png -------------------------------------------------------------------------------- /report-front-end/src/app.scss: -------------------------------------------------------------------------------- 1 | * { 2 | box-sizing: border-box; 3 | } 4 | 5 | :root { 6 | --app-color-scheme: light; 7 | color-scheme: var(--app-color-scheme); 8 | } 9 | 10 | html, 11 | body, 12 | #root, 13 | div[data-amplify-theme] { 14 | height: 100%; 15 | } 16 | body { 17 | background-color: #ffffff; 18 | overflow-y: scroll; 19 | } 20 | 21 | body.awsui-dark-mode { 22 | background-color: #0e1b2a; 23 | } 24 | 25 | .matrix-table { 26 | border: 1px solid #d6d6d6; 27 | border-radius: 2px; 28 | border-collapse: 
collapse; 29 | font-size: 1.1rem; 30 | 31 | th { 32 | border: 1px solid #d6d6d6; 33 | } 34 | 35 | td { 36 | border: 1px solid #d6d6d6; 37 | padding: 10px; 38 | } 39 | } 40 | 41 | .awsui-dark-mode { 42 | .matrix-table { 43 | border: 1px solid rgb(95, 107, 122); 44 | 45 | th { 46 | border: 1px solid rgb(95, 107, 122); 47 | } 48 | 49 | td { 50 | border: 1px solid rgb(95, 107, 122); 51 | padding: 12px; 52 | } 53 | } 54 | } 55 | 56 | .jsonContainer { 57 | font-family: "Open Sans", sans-serif; 58 | font-size: 1em; 59 | background-color: #0e1b2ac3; 60 | } 61 | 62 | .jsonStrings { 63 | color: rgb(74, 234, 167); 64 | } 65 | 66 | .jsonNumbers { 67 | color: rgb(255, 223, 60); 68 | } 69 | 70 | .jsonBool { 71 | color: rgb(252, 178, 250); 72 | font-weight: 600; 73 | } 74 | 75 | .jsonNull { 76 | color: rgb(74, 205, 234); 77 | font-weight: 600; 78 | } 79 | -------------------------------------------------------------------------------- /report-front-end/src/app.tsx: -------------------------------------------------------------------------------- 1 | import { AmplifyUser } from "@aws-amplify/ui"; 2 | import { UseAuthenticator } from "@aws-amplify/ui-react-core"; 3 | import { useEffect, useState } from "react"; 4 | import toast, { Toaster } from "react-hot-toast"; 5 | import { useDispatch } from "react-redux"; 6 | import { BrowserRouter, Route, Routes } from "react-router-dom"; 7 | import { v4 as uuid } from "uuid"; 8 | import "./app.scss"; 9 | import BaseAppLayout from "./components/BaseAppLayout"; 10 | import PanelConfigs from "./components/PanelConfigs"; 11 | import { PanelSideNav } from "./components/PanelSideNav"; 12 | import { Session } from "./components/PanelSideNav/types"; 13 | import SectionChat from "./components/SectionChat"; 14 | import TopNav from "./components/TopNav"; 15 | import { GlobalContext } from "./hooks/useGlobalContext"; 16 | import { isLoginWithCognito, LOCAL_STORAGE_KEYS } from "./utils/constants"; 17 | import { ActionType, UserInfo } from 
"./utils/helpers/types"; 18 | 19 | export type SignOut = UseAuthenticator["signOut"]; 20 | 21 | const App: React.FC<{ 22 | signOut?: SignOut; 23 | user?: AmplifyUser & { signInUserSession: any }; 24 | }> = ({ user }) => { 25 | const dispatch = useDispatch(); 26 | 27 | useEffect(() => { 28 | console.log({ user, signInUserSession: user?.signInUserSession }); 29 | if (isLoginWithCognito) { 30 | if (!user?.signInUserSession) { 31 | toast.error("User session not found"); 32 | return; 33 | } 34 | try { 35 | const { 36 | signInUserSession: { 37 | accessToken: { jwtToken: accessToken }, 38 | idToken: { jwtToken: idToken }, 39 | refreshToken: { token: refreshToken }, 40 | }, 41 | } = user; 42 | const loginUser: UserInfo = { 43 | userId: user?.attributes?.sub || "", 44 | displayName: 45 | user?.attributes?.displayName || user?.attributes?.email || "", 46 | loginExpiration: 0, 47 | isLogin: true, 48 | username: user?.username || "", 49 | }; 50 | dispatch({ type: ActionType.UpdateUserInfo, state: loginUser }); 51 | localStorage.setItem(LOCAL_STORAGE_KEYS.accessToken, accessToken); 52 | localStorage.setItem(LOCAL_STORAGE_KEYS.idToken, idToken); 53 | localStorage.setItem(LOCAL_STORAGE_KEYS.refreshToken, refreshToken); 54 | } catch (error) { 55 | console.error("Initiating cognito user state error: ", error); 56 | } 57 | } else { 58 | const loginUser: UserInfo = { 59 | userId: "none", 60 | displayName: "", 61 | loginExpiration: 0, 62 | isLogin: true, 63 | username: "anonymous", 64 | }; 65 | dispatch({ type: ActionType.UpdateUserInfo, state: loginUser }); 66 | } 67 | }, [dispatch, user]); 68 | 69 | return ( 70 |
71 | 72 | 73 | 74 |
 
75 |
76 | 77 | } /> 78 | 79 |
80 |
81 |
82 | ); 83 | }; 84 | 85 | export default App; 86 | 87 | const initSession = () => ({ 88 | session_id: uuid(), 89 | title: "New Chat", 90 | messages: [], 91 | }); 92 | 93 | function Playground() { 94 | const [toolsHide, setToolsHide] = useState(true); 95 | const [isSearching, setIsSearching] = useState(false); 96 | const [sessions, setSessions] = useState([initSession()]); 97 | const [currentSessionId, setCurrentSessionId] = useState( 98 | sessions[0].session_id 99 | ); 100 | return ( 101 | 111 | } 113 | content={} 114 | tools={} 115 | toolsHide={toolsHide} 116 | setToolsHide={setToolsHide} 117 | /> 118 | 119 | ); 120 | } 121 | -------------------------------------------------------------------------------- /report-front-end/src/components/BaseAppLayout.tsx: -------------------------------------------------------------------------------- 1 | import { AppLayout, AppLayoutProps } from "@cloudscape-design/components"; 2 | import { Dispatch, SetStateAction, useState } from "react"; 3 | import { Storage } from "../utils/helpers/storage"; 4 | 5 | export default function BaseAppLayout(props: { 6 | content: AppLayoutProps["content"]; 7 | tools: AppLayoutProps["tools"]; 8 | navigation: AppLayoutProps["navigation"]; 9 | toolsHide: boolean; 10 | setToolsHide: Dispatch>; 11 | }) { 12 | const [currentState, setCurrentState] = useState( 13 | Storage.getNavigationPanelState() 14 | ); 15 | 16 | return ( 17 | { 25 | props.setToolsHide(!detail.open); 26 | }} 27 | toolsWidth={450} 28 | navigationWidth={300} 29 | navigationHide={false} 30 | navigationOpen={!currentState.collapsed} 31 | onNavigationChange={({ detail }) => 32 | setCurrentState( 33 | Storage.setNavigationPanelState({ collapsed: !detail.open }) 34 | ) 35 | } 36 | /> 37 | ); 38 | } 39 | -------------------------------------------------------------------------------- /report-front-end/src/components/Login/CognitoLogin/aws-config.ts: -------------------------------------------------------------------------------- 1 | import { 
useSSOLogin } from "../../../utils/constants"; 2 | 3 | const extraConfigUseSSOLogin = useSSOLogin 4 | ? { 5 | mandatorySignIn: false, 6 | authenticationFlowType: "USER_SRP_AUTH", 7 | oauth: { 8 | domain: import.meta.env.VITE_SSO_OAUTH_DOMAIN, 9 | scope: ["email", "openid", "aws.cognito.signin.user.admin", "profile"], 10 | redirectSignIn: window.location.origin, 11 | redirectSignOut: window.location.origin, 12 | responseType: "code", 13 | }, 14 | } 15 | : {}; 16 | 17 | export const awsConfig = { 18 | Auth: { 19 | region: process.env.VITE_COGNITO_REGION, 20 | userPoolId: process.env.VITE_COGNITO_USER_POOL_ID, 21 | userPoolWebClientId: process.env.VITE_COGNITO_USER_POOL_WEB_CLIENT_ID, 22 | ...extraConfigUseSSOLogin, 23 | }, 24 | }; 25 | -------------------------------------------------------------------------------- /report-front-end/src/components/Login/CognitoLogin/index.tsx: -------------------------------------------------------------------------------- 1 | import { 2 | Authenticator, 3 | Button, 4 | defaultDarkModeOverride, 5 | Divider, 6 | Heading, 7 | Image, 8 | ThemeProvider, 9 | useTheme, 10 | View, 11 | } from "@aws-amplify/ui-react"; 12 | import "@aws-amplify/ui-react/styles.css"; 13 | import { Mode } from "@cloudscape-design/global-styles"; 14 | import { Amplify, Auth } from "aws-amplify"; 15 | import { useEffect, useState } from "react"; 16 | import App from "../../../app"; 17 | import { 18 | APP_LOGO, 19 | APP_LOGO_DISPLAY_ON_LOGIN_PAGE, 20 | APP_TITLE, 21 | APP_VERSION, 22 | SSO_FED_AUTH_PROVIDER, 23 | useSSOLogin, 24 | } from "../../../utils/constants"; 25 | import { Storage } from "../../../utils/helpers/storage"; 26 | import { awsConfig } from "./aws-config"; 27 | import "./layout-with-cognito.css"; 28 | 29 | export default function CognitoLogin() { 30 | const [theme, setTheme] = useState(Storage.getTheme()); 31 | 32 | useEffect(() => { 33 | console.log("Cognito configured"); 34 | try { 35 | Amplify.configure(awsConfig); 36 | } catch (e) { 37 | 
console.error(e); 38 | } 39 | }, []); 40 | 41 | useEffect(() => { 42 | const observer = new MutationObserver((mutations) => { 43 | mutations.forEach((mutation) => { 44 | if ( 45 | mutation.type === "attributes" && 46 | mutation.attributeName === "style" 47 | ) { 48 | const newValue = 49 | document.documentElement.style.getPropertyValue( 50 | "--app-color-scheme" 51 | ); 52 | 53 | const mode = newValue === "dark" ? Mode.Dark : Mode.Light; 54 | if (mode !== theme) { 55 | setTheme(mode); 56 | } 57 | } 58 | }); 59 | }); 60 | 61 | observer.observe(document.documentElement, { 62 | attributes: true, 63 | attributeFilter: ["style"], 64 | }); 65 | 66 | return () => { 67 | observer.disconnect(); 68 | }; 69 | }, [theme]); 70 | 71 | const [isLoading, setIsLoading] = useState(false); 72 | 73 | return ( 74 | 81 | 90 | ) : ( 91 | 92 | {APP_LOGO ? ( 93 | App logo 98 | ) : ( 99 | 100 | {APP_TITLE} 101 | 102 | )} 103 | 104 | ); 105 | }, 106 | } 107 | : { 108 | Header: Title, 109 | SignIn: { 110 | Header() { 111 | return ( 112 | 113 | 131 | 136 | 137 | ); 138 | }, 139 | }, 140 | } 141 | } 142 | > 143 | {({ signOut, user }) => } 144 | 145 | 146 | ); 147 | } 148 | 149 | function Title() { 150 | const { tokens } = useTheme(); 151 | return ( 152 | 157 | 158 | 159 | Generative Business Intelligence 160 | 161 | {APP_VERSION && {APP_VERSION}} 162 | 163 | 164 | Guidance on Amazon Web Services 165 | Amazon Web Services Logo 170 | 171 | ); 172 | } 173 | -------------------------------------------------------------------------------- /report-front-end/src/components/Login/CognitoLogin/layout-with-cognito.css: -------------------------------------------------------------------------------- 1 | [data-amplify-authenticator] { 2 | --amplify-components-authenticator-router-box-shadow: 0 0 16px 3 | var(--amplify-colors-overlay-10); 4 | --amplify-components-authenticator-router-border-width: 0; 5 | --amplify-components-authenticator-form-padding: var(--amplify-space-medium) 6 | var(--amplify-space-xl) 
var(--amplify-space-xl); 7 | --amplify-components-button-primary-background-color: var( 8 | --amplify-colors-neutral-100 9 | ); 10 | --amplify-components-fieldcontrol-focus-box-shadow: 0 0 0 2px 11 | var(--amplify-colors-purple-60); 12 | --amplify-components-tabs-item-active-border-color: var( 13 | --amplify-colors-neutral-100 14 | ); 15 | --amplify-components-tabs-item-color: var(--amplify-colors-neutral-80); 16 | --amplify-components-tabs-item-active-color: var(--amplify-colors-purple-100); 17 | --amplify-components-button-link-color: var(--amplify-colors-purple-80); 18 | } 19 | -------------------------------------------------------------------------------- /report-front-end/src/components/Login/CustomLogin/index.tsx: -------------------------------------------------------------------------------- 1 | import App from "../../../app"; 2 | import "./style.scss"; 3 | 4 | export default function CustomLogin() { 5 | 6 | return ( 7 |
8 | 9 |
10 | ); 11 | }; -------------------------------------------------------------------------------- /report-front-end/src/components/Login/CustomLogin/style.scss: -------------------------------------------------------------------------------- 1 | .login-page { 2 | width: 500px; 3 | height: 400px; 4 | position: relative; 5 | top: 300px; 6 | margin: 0 auto; 7 | border-radius: 16px; 8 | box-shadow: 0 1px 14px rgba(0, 7, 22, 0.14), 0 0 4px rgba(65, 77, 92, 0.2); 9 | align-content: center; 10 | } 11 | 12 | .login-container { 13 | width: 400px; 14 | margin: auto; 15 | } -------------------------------------------------------------------------------- /report-front-end/src/components/Login/index.tsx: -------------------------------------------------------------------------------- 1 | import CognitoLogin from "./CognitoLogin"; 2 | import CustomLogin from "./CustomLogin"; 3 | 4 | const Login = { Cognito: CognitoLogin, Custom: CustomLogin }; 5 | 6 | export default Login; 7 | -------------------------------------------------------------------------------- /report-front-end/src/components/PanelConfigs/style.scss: -------------------------------------------------------------------------------- 1 | .input-wrapper { 2 | position: absolute; 3 | top: -18px; 4 | right: 0; 5 | width: 80px; 6 | } 7 | -------------------------------------------------------------------------------- /report-front-end/src/components/PanelSideNav/index.tsx: -------------------------------------------------------------------------------- 1 | import { 2 | Box, 3 | Button, 4 | ContentLayout, 5 | Header, 6 | Spinner, 7 | } from "@cloudscape-design/components"; 8 | import { useEffect, useState } from "react"; 9 | import { useSelector } from "react-redux"; 10 | import { v4 as uuid } from "uuid"; 11 | import { getSessions } from "../../utils/api/API"; 12 | import { UserState } from "../../utils/helpers/types"; 13 | import "./style.scss"; 14 | import useGlobalContext from "../../hooks/useGlobalContext"; 15 | 16 | 
export const PanelSideNav = () => { 17 | const userInfo = useSelector((state: UserState) => state.userInfo); 18 | const queryConfig = useSelector((state: UserState) => state.queryConfig); 19 | const { setCurrentSessionId, setSessions, sessions, currentSessionId } = 20 | useGlobalContext(); 21 | const [loadingSessions, setLoadingSessions] = useState(false); 22 | 23 | useEffect(() => { 24 | setLoadingSessions(true); 25 | getSessions({ 26 | user_id: userInfo.userId, 27 | profile_name: queryConfig.selectedDataPro, 28 | }) 29 | .then((sessions) => { 30 | if (sessions?.length) { 31 | setCurrentSessionId(sessions[0].session_id); 32 | return setSessions(sessions); 33 | } 34 | const newSessionId = uuid(); 35 | setSessions([ 36 | { 37 | session_id: newSessionId, 38 | title: "New Chat", 39 | messages: [], 40 | }, 41 | ]); 42 | setCurrentSessionId(newSessionId); 43 | }) 44 | .finally(() => { 45 | setLoadingSessions(false); 46 | }); 47 | }, [ 48 | userInfo.userId, 49 | queryConfig.selectedDataPro, 50 | setCurrentSessionId, 51 | setSessions, 52 | ]); 53 | 54 | return ( 55 | 61 | {queryConfig.selectedDataPro || "Sessions of a profile"} 62 | 63 | } 64 | > 65 | {loadingSessions ? ( 66 | 67 | Loading sessions... 68 | 69 | ) : ( 70 | 71 | 90 |
91 | {sessions?.map((ses, idx) => ( 92 |
100 | 110 |
111 | ))} 112 |
113 |
114 | )} 115 |
116 | ); 117 | }; 118 | -------------------------------------------------------------------------------- /report-front-end/src/components/PanelSideNav/style.scss: -------------------------------------------------------------------------------- 1 | .session_container { 2 | display: flex; 3 | flex-direction: row; 4 | justify-content: space-between; 5 | border-radius: 10px !important; 6 | } 7 | 8 | .new_session_btn { 9 | border: 1px solid #ccc !important; 10 | border-radius: 10px !important; 11 | padding: 12px 12px !important; 12 | color: black !important; 13 | } 14 | 15 | .session { 16 | border: none !important; 17 | border-radius: 10px !important; 18 | font-weight: 500 !important; 19 | padding: 12px 12px !important; 20 | color: black !important; 21 | width: 100%; 22 | height: 100%; 23 | white-space: nowrap !important; 24 | text-overflow: ellipsis !important; 25 | overflow: hidden !important; 26 | background: transparent !important; 27 | &:hover { 28 | box-shadow: 0 0 4px 1px #cdcdcdb3 inset; 29 | transform: scale(0.99); 30 | } 31 | &:focus { 32 | outline: none; 33 | border-color: transparent; 34 | } 35 | } 36 | 37 | .menu { 38 | padding: 12px 12px !important; 39 | } -------------------------------------------------------------------------------- /report-front-end/src/components/PanelSideNav/types.ts: -------------------------------------------------------------------------------- 1 | import { ChatBotHistoryItem } from "../SectionChat/types"; 2 | 3 | export interface Session { 4 | session_id: string; 5 | title: string; 6 | messages: ChatBotHistoryItem[]; 7 | } 8 | -------------------------------------------------------------------------------- /report-front-end/src/components/SectionChat/ChartRenderer.tsx: -------------------------------------------------------------------------------- 1 | import { BarChart, LineChart, PieChart } from "@cloudscape-design/components"; 2 | 3 | interface ChartTypeProps { 4 | data_show_type: string; 5 | sql_data: any[][]; 6 | } 7 | 8 | 
export default function ChartPanel(props: ChartTypeProps) { 9 | const sql_data = props.sql_data; 10 | if (props.data_show_type === "bar") { 11 | // convert data to bar chart data 12 | const header = sql_data[0]; 13 | const items = sql_data.slice(1, sql_data.length); 14 | const key = ["x", "y"]; 15 | const content = items.map((item) => { 16 | const map: any = new Map( 17 | item.map((value, index) => { 18 | return [key[index], value]; 19 | }) 20 | ); 21 | return Object.fromEntries(map); 22 | }); 23 | const seriesValue: any = [ 24 | { 25 | title: header[1], 26 | type: "bar", 27 | data: content, 28 | }, 29 | ]; 30 | return ( 31 | 38 | ); 39 | } else if (props.data_show_type === "line") { 40 | // convert data to line chart data 41 | const lineHeader = sql_data[0]; 42 | const lineItems = sql_data.slice(1, sql_data.length); 43 | const lineKey = ["x", "y"]; 44 | const lineContent = lineItems.map((item) => { 45 | const map: any = new Map( 46 | item.map((value, index) => { 47 | return [lineKey[index], value]; 48 | }) 49 | ); 50 | return Object.fromEntries(map); 51 | }); 52 | const lineSeriesValue: any = [ 53 | { 54 | title: lineHeader[1], 55 | type: "line", 56 | data: lineContent, 57 | }, 58 | ]; 59 | return ( 60 | 68 | ); 69 | } else if (props.data_show_type === "pie") { 70 | // convert data to pie data 71 | const pieHeader = sql_data[0]; 72 | const pieItems = sql_data.slice(1, sql_data.length); 73 | const pieKeys = ["title", "value"]; 74 | const pieContent: any = pieItems.map((item) => { 75 | const map: any = new Map( 76 | item.map((value, index) => { 77 | return [pieKeys[index], value]; 78 | }) 79 | ); 80 | return Object.fromEntries(map); 81 | }); 82 | return ( 83 | [ 86 | { key: pieHeader[1], value: datum.value }, 87 | ]} 88 | fitHeight={true} 89 | hideFilter 90 | hideLegend 91 | /> 92 | ); 93 | } else { 94 | return null; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- 
/report-front-end/src/components/SectionChat/CustomQuestions.tsx: -------------------------------------------------------------------------------- 1 | import { Button } from "@aws-amplify/ui-react"; 2 | import { Link, SpaceBetween } from "@cloudscape-design/components"; 3 | import { useEffect, useState } from "react"; 4 | import { SendJsonMessage } from "react-use-websocket/src/lib/types"; 5 | import { getRecommendQuestions } from "../../utils/api/API"; 6 | import { useQueryWithTokens } from "../../utils/api/WebSocket"; 7 | import styles from "./chat.module.scss"; 8 | 9 | export interface RecommendQuestionsProps { 10 | sendJsonMessage: SendJsonMessage; 11 | } 12 | 13 | export default function CustomQuestions({ 14 | sendJsonMessage, 15 | }: RecommendQuestionsProps) { 16 | const [showMoreQuestions, setShowMoreQuestions] = useState(false); 17 | const [questions, setQuestions] = useState([]); 18 | const { queryWithWS, queryConfig } = useQueryWithTokens(); 19 | 20 | useEffect(() => { 21 | const data_profile = queryConfig?.selectedDataPro; 22 | if (data_profile) { 23 | getRecommendQuestions(data_profile).then((data) => { 24 | setQuestions(data); 25 | }); 26 | } 27 | }, [queryConfig?.selectedDataPro]); 28 | 29 | const queries = showMoreQuestions 30 | ? questions 31 | : questions?.slice(0, Math.min(3, questions.length)); 32 | return ( 33 |
34 | {!queries?.length ? null : ( 35 | 36 |
37 | {queries?.map((query, idx) => ( 38 | 48 | ))} 49 |
50 |
51 | setShowMoreQuestions((prev) => !prev)}> 52 |

53 | {showMoreQuestions ? "less" : "more"} sample suggestions 54 |

55 | 56 |
57 |
58 | )} 59 |
60 | ); 61 | } 62 | -------------------------------------------------------------------------------- /report-front-end/src/components/SectionChat/ExpandableSectionWithDivider.tsx: -------------------------------------------------------------------------------- 1 | import { Divider } from "@aws-amplify/ui-react"; 2 | import { 3 | ExpandableSection, 4 | ExpandableSectionProps, 5 | } from "@cloudscape-design/components"; 6 | import React from "react"; 7 | 8 | const ExpandableSectionWithDivider: React.FC< 9 | ExpandableSectionProps & { withDivider?: boolean; label?: string } 10 | > = ({ withDivider = true, label, children, ...props }) => { 11 | return ( 12 |
13 | 14 |
15 | {children} 16 | {!withDivider ? null : ( 17 |
18 | 19 |
20 | )} 21 |
22 |
23 |
24 | ); 25 | }; 26 | 27 | export default ExpandableSectionWithDivider; 28 | -------------------------------------------------------------------------------- /report-front-end/src/components/SectionChat/MessageRenderer/EntitySelect.tsx: -------------------------------------------------------------------------------- 1 | import { 2 | Box, 3 | Button, 4 | RadioGroup, 5 | SpaceBetween, 6 | } from "@cloudscape-design/components"; 7 | import { Dispatch, SetStateAction, useEffect, useState } from "react"; 8 | import { useQueryWithTokens } from "../../../utils/api/WebSocket"; 9 | import { IEntityItem, QUERY_INTENT } from "../types"; 10 | import { IPropsAiMessageRenderer } from "./AiMessage"; 11 | 12 | export default function EntitySelect({ 13 | content, 14 | sendJsonMessage, 15 | }: IPropsAiMessageRenderer) { 16 | const entities = content.ask_entity_select?.entity_select_info; 17 | 18 | const { queryWithWS } = useQueryWithTokens(); 19 | const [userSelected, setUserSelected] = useState< 20 | Record> 21 | >({}); 22 | const [selectedIdRecord, setSelectedIdRecord] = useState< 23 | Record 24 | >({}); 25 | useEffect(() => { 26 | if (!entities) return; 27 | Object.entries(selectedIdRecord).forEach(([tp, selectedId]) => { 28 | const arr = entities[tp]; 29 | setUserSelected((prev) => ({ 30 | ...prev, 31 | [tp]: arr.find(({ id }) => id === selectedId)!, 32 | })); 33 | }); 34 | }, [entities, selectedIdRecord]); 35 | 36 | if (!entities) return null; 37 | return ( 38 | 39 | 40 | 41 | Please select the correct entity you would like to query - 42 | 43 | {Object.entries(entities).map(([type, vObj], idx) => { 44 | return ( 45 | 51 | ); 52 | })} 53 | 54 | 75 | 76 | 77 | 78 | ); 79 | } 80 | 81 | function EntityRadioSelect({ 82 | type, 83 | vObj, 84 | setSelectedIdRecord, 85 | }: { 86 | type: string; 87 | vObj: Array; 88 | setSelectedIdRecord: Dispatch>>; 89 | }) { 90 | const [id, setId] = useState(""); 91 | 92 | return ( 93 | 94 | {type} 95 | { 97 | setId(detail.value); 98 | 
setSelectedIdRecord((prev) => { 99 | return { ...prev, [type]: detail.value }; 100 | }); 101 | }} 102 | value={id} 103 | items={vObj.map(({ text, id }) => ({ 104 | value: id, 105 | label: text, 106 | }))} 107 | /> 108 | 109 | ); 110 | } 111 | -------------------------------------------------------------------------------- /report-front-end/src/components/SectionChat/MessageRenderer/index.tsx: -------------------------------------------------------------------------------- 1 | import { Icon, SpaceBetween } from "@cloudscape-design/components"; 2 | import { Dispatch, SetStateAction } from "react"; 3 | import { SendJsonMessage } from "react-use-websocket/src/lib/types"; 4 | import { ChatBotHistoryItem, ChatBotMessageType } from "../types"; 5 | import styles from "../chat.module.scss"; 6 | import AiMessage from "./AiMessage"; 7 | 8 | export interface ChatMessageProps { 9 | message: T; 10 | setMessageHistory: Dispatch>; 11 | sendJsonMessage: SendJsonMessage; 12 | } 13 | 14 | export default function MessageRenderer({ 15 | message, 16 | sendJsonMessage, 17 | }: ChatMessageProps) { 18 | return ( 19 | 20 | {message.type === ChatBotMessageType.Human && ( 21 |
22 | {message?.content?.toString()} 23 |
24 | )} 25 | {message.type === ChatBotMessageType.AI && ( 26 | 27 | )} 28 |
29 | ); 30 | } 31 | -------------------------------------------------------------------------------- /report-front-end/src/components/SectionChat/types.ts: -------------------------------------------------------------------------------- 1 | export interface ChatInputState { 2 | value: string; 3 | } 4 | 5 | export enum ChatBotMessageType { 6 | AI = "AI", 7 | Human = "human", 8 | } 9 | 10 | export type ChatBotHistoryItem = 11 | | { 12 | type: ChatBotMessageType.Human; 13 | content: string; 14 | } 15 | | { 16 | type: ChatBotMessageType.AI; 17 | content: ChatBotAnswerItem; 18 | }; 19 | 20 | export interface ChatBotMessageItem { 21 | session_id: string; 22 | user_id: string; 23 | content_type: string; 24 | content: StatusMessageItem; 25 | } 26 | 27 | export interface StatusMessageItem { 28 | status: string; 29 | text: string; 30 | } 31 | 32 | export enum QUERY_INTENT { 33 | ask_in_reply = "ask_in_reply", 34 | reject_search = "reject_search", 35 | normal_search = "normal_search", 36 | agent_search = "agent_search", 37 | knowledge_search = "knowledge_search", 38 | entity_select = "entity_select", 39 | } 40 | 41 | export interface ChatBotAnswerItem { 42 | query: string; 43 | // LLM rewrites the query 44 | query_rewrite: string; 45 | query_intent: QUERY_INTENT; 46 | knowledge_search_result: KnowledgeSearchResult; 47 | ask_rewrite_result: AskRewriteResult; 48 | sql_search_result: SQLSearchResult; 49 | agent_search_result: AgentSearchResult; 50 | suggested_question: string[]; 51 | error_log: Record; 52 | ask_entity_select: { 53 | entity_retrieval: unknown[]; 54 | entity_select_info: Record>; 55 | }; 56 | } 57 | 58 | export type IEntityItem = { text: string; id: string; [key: string]: string }; 59 | 60 | export enum FeedBackType { 61 | UPVOTE = "upvote", 62 | DOWNVOTE = "downvote", 63 | } 64 | 65 | export interface FeedBackItem { 66 | feedback_type: FeedBackType; 67 | data_profiles: string; 68 | query: string; 69 | query_intent: string; 70 | query_answer: string; 71 | // 
downvote feedback only ⬇️ 72 | session_id?: string; 73 | user_id?: string; 74 | error_description?: string; 75 | error_categories?: string; 76 | correct_sql_reference?: string; 77 | } 78 | 79 | export interface SessionItem { 80 | user_id: string; 81 | profile_name: string; 82 | } 83 | 84 | export interface HistoryItem { 85 | user_id: string; 86 | session_id: string; 87 | profile_name: string; 88 | } 89 | 90 | export interface AskRewriteResult { 91 | query_rewrite: string; 92 | } 93 | 94 | export interface KnowledgeSearchResult { 95 | knowledge_response: string; 96 | } 97 | 98 | export interface SQLSearchResult { 99 | // SQL string 100 | sql: string; 101 | // table data 102 | sql_data: any[][]; 103 | // chart data 104 | sql_data_chart: SQLDataChart[]; 105 | // chart type: default - "Table" 106 | data_show_type: "bar" | "line" | "table" | "pie"; 107 | // Desc of SQL ⬇️ 108 | sql_gen_process: string; 109 | // Answer with insights ⬇️ 110 | data_analyse: string; 111 | } 112 | 113 | export interface SQLDataChart { 114 | chart_type: string; 115 | chart_data: any[][]; 116 | } 117 | 118 | export interface AgentSQLSearchResult { 119 | sub_task_query: string; 120 | sql_search_result: SQLSearchResult; 121 | 122 | // 'sub_search_task': any[], 123 | // 'agent_sql_search_result': any[], 124 | // 'agent_summary': string 125 | } 126 | 127 | export interface AgentSearchResult { 128 | agent_sql_search_result: AgentSQLSearchResult[]; 129 | agent_summary: string; 130 | } 131 | -------------------------------------------------------------------------------- /report-front-end/src/components/TopNav/index.tsx: -------------------------------------------------------------------------------- 1 | import { TopNavigation } from "@cloudscape-design/components"; 2 | // import { Mode } from '@cloudscape-design/global-styles' 3 | import { Density } from "@cloudscape-design/global-styles"; 4 | import { Auth } from "aws-amplify"; 5 | import { useState } from "react"; 6 | import { useSelector } from 
"react-redux"; 7 | import { 8 | APP_LOGO, 9 | APP_RIGHT_LOGO, 10 | APP_TITLE, 11 | APP_VERSION, 12 | CHATBOT_NAME, 13 | isLoginWithCognito, 14 | } from "../../utils/constants"; 15 | import { Storage } from "../../utils/helpers/storage"; 16 | import { UserState } from "../../utils/helpers/types"; 17 | import "./style.scss"; 18 | 19 | export default function TopNav() { 20 | // const [theme, setTheme] = useState(Storage.getTheme()) 21 | const userInfo = useSelector((state: UserState) => state.userInfo); 22 | 23 | const [isCompact, setIsCompact] = useState( 24 | Storage.getDensity() === Density.Compact 25 | ); 26 | 27 | // const onChangeThemeClick = () => { 28 | // if (theme === Mode.Dark) { 29 | // setTheme(Storage.applyTheme(Mode.Light)) 30 | // } else { 31 | // setTheme(Storage.applyTheme(Mode.Dark)) 32 | // } 33 | // } 34 | 35 | return ( 36 |
40 | {APP_RIGHT_LOGO && ( 41 | logo 42 | )} 43 | { 66 | setIsCompact((prev) => { 67 | Storage.applyDensity( 68 | !prev ? Density.Compact : Density.Comfortable 69 | ); 70 | return !prev; 71 | }); 72 | }, 73 | }, 74 | { 75 | type: "menu-dropdown", 76 | text: userInfo?.displayName || "Authenticating", 77 | // description: `username: ${userInfo?.username}`, 78 | iconName: "user-profile", 79 | onItemClick: ({ detail }) => { 80 | if (detail.id === "signout") { 81 | if (isLoginWithCognito) { 82 | Auth.signOut(); 83 | } 84 | } 85 | }, 86 | items: [ 87 | { 88 | itemType: "group", 89 | id: "user-info", 90 | text: "User Information", 91 | items: [ 92 | { 93 | id: "0", 94 | text: `username: ${userInfo?.username}`, 95 | }, 96 | { 97 | id: "1", 98 | text: `userId: ${userInfo?.userId}`, 99 | }, 100 | { 101 | id: "2", 102 | text: `loginExpiration: ${userInfo?.loginExpiration}`, 103 | disabled: true, 104 | }, 105 | ], 106 | }, 107 | { 108 | id: "signout", 109 | text: "Sign out", 110 | }, 111 | ], 112 | }, 113 | ]} 114 | /> 115 |
116 | ); 117 | } 118 | -------------------------------------------------------------------------------- /report-front-end/src/components/TopNav/style.scss: -------------------------------------------------------------------------------- 1 | .logo { 2 | height: 40px; 3 | width: 50px; 4 | right: 0; 5 | position:fixed; 6 | margin-top: 8px; 7 | margin-right: 8px; 8 | } -------------------------------------------------------------------------------- /report-front-end/src/hooks/useGlobalContext.ts: -------------------------------------------------------------------------------- 1 | import { createContext, Dispatch, SetStateAction, useContext } from "react"; 2 | import { Session } from "../components/PanelSideNav/types"; 3 | 4 | export interface IGlobalContext { 5 | sessions: Session[]; 6 | setSessions: Dispatch>; 7 | currentSessionId: string; 8 | setCurrentSessionId: Dispatch>; 9 | isSearching: boolean; 10 | setIsSearching: Dispatch>; 11 | } 12 | export const GlobalContext = createContext(null); 13 | 14 | export const useGlobalContext = () => { 15 | const context = useContext(GlobalContext); 16 | if (!context) { 17 | throw new Error("useGlobalContext must be used within a GlobalContext"); 18 | } 19 | return context; 20 | }; 21 | 22 | export default useGlobalContext; 23 | -------------------------------------------------------------------------------- /report-front-end/src/main.tsx: -------------------------------------------------------------------------------- 1 | import "@cloudscape-design/global-styles/index.css"; 2 | import React from "react"; 3 | import ReactDOM from "react-dom/client"; 4 | import { Provider } from "react-redux"; 5 | import "regenerator-runtime/runtime"; 6 | import Login from "./components/Login"; 7 | import { isLoginWithCognito, LOGIN_TYPE } from "./utils/constants"; 8 | import { Storage } from "./utils/helpers/storage"; 9 | import userReduxStore from "./utils/helpers/store"; 10 | 11 | const root = ReactDOM.createRoot( 12 | 
document.getElementById("root") as HTMLElement 13 | ); 14 | 15 | const theme = Storage.getTheme(); 16 | Storage.applyTheme(theme); 17 | const density = Storage.getDensity(); 18 | Storage.applyDensity(density); 19 | console.log("Login type: ", LOGIN_TYPE); 20 | 21 | root.render( 22 | 23 | 24 | {isLoginWithCognito ? : } 25 | 26 | 27 | ); 28 | -------------------------------------------------------------------------------- /report-front-end/src/utils/api/API.ts: -------------------------------------------------------------------------------- 1 | import toast from "react-hot-toast"; 2 | import { extend } from "umi-request"; 3 | import { 4 | FeedBackItem, 5 | HistoryItem, 6 | SessionItem, 7 | } from "../../components/SectionChat/types"; 8 | import { 9 | BACKEND_URL, 10 | isLoginWithCognito, 11 | LOCAL_STORAGE_KEYS, 12 | } from "../constants"; 13 | import { logout } from "../helpers/tools"; 14 | import { Session } from "../../components/PanelSideNav/types"; 15 | 16 | export const getLSTokens = () => { 17 | const accessToken = 18 | localStorage.getItem(LOCAL_STORAGE_KEYS.accessToken) || ""; 19 | const idToken = localStorage.getItem(LOCAL_STORAGE_KEYS.idToken) || ""; 20 | const refreshToken = 21 | localStorage.getItem(LOCAL_STORAGE_KEYS.refreshToken) || ""; 22 | 23 | return { 24 | accessToken: `Bearer ${accessToken}`, 25 | idToken: `Bearer ${idToken}`, 26 | refreshToken: `Bearer ${refreshToken}`, 27 | noToken: !accessToken || !idToken || !refreshToken, 28 | }; 29 | }; 30 | export const getBearerTokenObj = () => { 31 | const { accessToken, idToken, refreshToken } = getLSTokens(); 32 | return { 33 | // Authorization: accessToken, 34 | "X-Access-Token": accessToken, 35 | "X-Id-Token": idToken, 36 | "X-Refresh-Token": refreshToken, 37 | }; 38 | }; 39 | 40 | export const request = extend({ 41 | prefix: BACKEND_URL, 42 | timeout: 30 * 1000, 43 | headers: { "Content-Type": "application/json" }, 44 | }); 45 | 46 | request.interceptors.request.use((url, options) => { 47 | if 
(!isLoginWithCognito) return { url, options }; 48 | const headers = { ...getBearerTokenObj(), ...options.headers }; 49 | return { url, options: { ...options, headers } }; 50 | }); 51 | 52 | request.interceptors.response.use((response) => { 53 | if (response.status === 500) toast.error("Internal Server Error"); 54 | if (!isLoginWithCognito) return response; 55 | if (response.status === 401) logout(); 56 | return response; 57 | }); 58 | 59 | export async function getSelectData() { 60 | try { 61 | const data = await request.get(`qa/option`, { 62 | errorHandler: (error) => { 63 | toast.error("LLM Option Error"); 64 | console.error("LLM Option Error: ", error); 65 | }, 66 | }); 67 | if (!data || !data.data_profiles || !data.bedrock_model_ids) { 68 | toast.error("LLM Option Error: data missing"); 69 | return; 70 | } 71 | return data; 72 | } catch (error) { 73 | console.error("getSelectData Error", error); 74 | } 75 | } 76 | 77 | export const getRecommendQuestions = async (data_profile: string) => { 78 | try { 79 | const data = await request.get("qa/get_custom_question", { 80 | params: { data_profile }, 81 | errorHandler: (error) => { 82 | toast.error("getCustomQuestions response error"); 83 | console.error("getCustomQuestions response error, ", error); 84 | }, 85 | }); 86 | return data.custom_question; 87 | } catch (error) { 88 | console.error("getCustomQuestions Error", error); 89 | } 90 | }; 91 | export async function postUserFeedback(feedbackData: FeedBackItem) { 92 | try { 93 | const data = await request.post("qa/user_feedback", { 94 | data: feedbackData, 95 | errorHandler: (error) => { 96 | toast.error("AddUserFeedback"); 97 | console.error("AddUserFeedback error, ", error); 98 | }, 99 | }); 100 | toast.success("Thanks for your feedback!"); 101 | console.log("AddUserFeedback: ", data); 102 | return data; 103 | } catch (err) { 104 | console.error("Query error, ", err); 105 | } 106 | } 107 | 108 | export async function getSessions(sessionItem: SessionItem) { 109 | try 
{ 110 | const data = await request.post(`qa/get_sessions`, { 111 | data: sessionItem, 112 | errorHandler: (error) => { 113 | toast.error("getSessions error"); 114 | console.error("getSessions error: ", error); 115 | }, 116 | }); 117 | return data as Session[]; 118 | } catch (error) { 119 | console.error("getSessions, error: ", error); 120 | return []; 121 | } 122 | } 123 | 124 | export async function deleteHistoryBySession(historyItem: HistoryItem) { 125 | try { 126 | const data = await request.post(`qa/delete_history_by_session`, { 127 | data: historyItem, 128 | errorHandler: (error) => { 129 | toast.error("deleteHistoryBySession error"); 130 | console.error("deleteHistoryBySession error: ", error); 131 | }, 132 | }); 133 | return data; 134 | } catch (error) { 135 | console.error("deleteHistoryBySession, error: ", error); 136 | } 137 | } 138 | 139 | export async function getHistoryBySession(historyItem: HistoryItem) { 140 | // call api 141 | try { 142 | const data = await request.post(`qa/get_history_by_session`, { 143 | data: historyItem, 144 | errorHandler: (error) => { 145 | toast.error("getHistoryBySession error"); 146 | console.error("getHistoryBySession error: ", error); 147 | }, 148 | }); 149 | return data; 150 | } catch (error) { 151 | console.error("getHistoryBySession, error: ", error); 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /report-front-end/src/utils/constants.ts: -------------------------------------------------------------------------------- 1 | export const CHATBOT_NAME = "GenBI Chatbot"; 2 | export const COGNITO = "Cognito"; 3 | 4 | export const DEFAULT_USER_INFO = { 5 | userId: "", 6 | displayName: "", 7 | loginExpiration: +new Date() + 6000, 8 | isLogin: false, 9 | username: "anonymous", 10 | }; 11 | 12 | export const DEFAULT_QUERY_CONFIG = { 13 | selectedLLM: "", 14 | selectedDataPro: "", 15 | intentChecked: true, 16 | complexChecked: true, 17 | answerInsightChecked: false, 18 | 
contextWindow: 0, 19 | modelSuggestChecked: false, 20 | temperature: 0.1, 21 | topP: 1, 22 | topK: 250, 23 | maxLength: 2048, 24 | }; 25 | 26 | export const LOCALSTORAGE_KEY = "__GEN_BI_STORE_INFO__"; 27 | 28 | export const LOGIN_TYPE = process.env.VITE_LOGIN_TYPE; 29 | export const isLoginWithCognito = LOGIN_TYPE === COGNITO; 30 | 31 | export const useSSOLogin = 32 | import.meta.env.VITE_USE_SSO_LOGIN === "true" ? true : false; 33 | export const SSO_FED_AUTH_PROVIDER = import.meta.env.VITE_SSO_FED_AUTH_PROVIDER; 34 | 35 | export const BACKEND_URL = process.env.VITE_BACKEND_URL?.endsWith("/") 36 | ? process.env.VITE_BACKEND_URL 37 | : process.env.VITE_BACKEND_URL + "/"; 38 | 39 | export const APP_TITLE = process.env.VITE_TITLE; 40 | export const APP_VERSION = process.env.VITE_APP_VERSION; 41 | export const APP_LOGO = process.env.VITE_LOGO || ""; 42 | 43 | export const APP_RIGHT_LOGO = process.env.VITE_RIGHT_LOGO || ""; 44 | 45 | export const APP_LOGO_DISPLAY_ON_LOGIN_PAGE = 46 | process.env.VITE_LOGO_DISPLAY_ON_LOGIN_PAGE || true; 47 | 48 | export const SQL_DISPLAY = process.env.VITE_SQL_DISPLAY; 49 | 50 | // https://cloudscape.design/patterns/general/density-settings/ 51 | export const APP_STYLE_DEFAULT_COMPACT = true; 52 | 53 | export const LOCAL_STORAGE_KEYS = { 54 | accessToken: "accessToken", 55 | idToken: "idToken", 56 | refreshToken: "refreshToken", 57 | } as const; 58 | -------------------------------------------------------------------------------- /report-front-end/src/utils/helpers/storage.ts: -------------------------------------------------------------------------------- 1 | import { 2 | applyDensity, 3 | applyMode, 4 | Density, 5 | Mode, 6 | } from "@cloudscape-design/global-styles"; 7 | import { APP_STYLE_DEFAULT_COMPACT } from "../constants"; 8 | 9 | const PREFIX = "genai-chatbot"; 10 | const THEME_STORAGE_NAME = `${PREFIX}-themes`; 11 | const DENSITY_STORAGE_NAME = `${PREFIX}-density`; 12 | const NAVIGATION_PANEL_STATE_STORAGE_NAME = 
`${PREFIX}-navigation-panel-state`; 13 | 14 | interface NavigationPanelState { 15 | collapsed?: boolean; 16 | collapsedSections?: Record; 17 | } 18 | 19 | export abstract class Storage { 20 | static getTheme() { 21 | const value = localStorage.getItem(THEME_STORAGE_NAME) ?? Mode.Light; 22 | return value === Mode.Dark ? Mode.Dark : Mode.Light; 23 | } 24 | 25 | static applyTheme(theme: Mode) { 26 | localStorage.setItem(THEME_STORAGE_NAME, theme); 27 | applyMode(theme); 28 | 29 | document.documentElement.style.setProperty( 30 | "--app-color-scheme", 31 | theme === Mode.Dark ? "dark" : "light" 32 | ); 33 | 34 | return theme; 35 | } 36 | 37 | static getDensity(): Density { 38 | let density = localStorage.getItem(DENSITY_STORAGE_NAME) as Density | null; 39 | if (!density) { 40 | density = APP_STYLE_DEFAULT_COMPACT 41 | ? Density.Compact 42 | : Density.Comfortable; 43 | } 44 | return density; 45 | } 46 | static applyDensity(density: Density) { 47 | localStorage.setItem(DENSITY_STORAGE_NAME, density); 48 | applyDensity(density); 49 | } 50 | 51 | static getNavigationPanelState(): NavigationPanelState { 52 | const value = 53 | localStorage.getItem(NAVIGATION_PANEL_STATE_STORAGE_NAME) ?? 54 | JSON.stringify({ 55 | collapsed: true, 56 | }); 57 | 58 | let state: NavigationPanelState | null = null; 59 | try { 60 | state = JSON.parse(value); 61 | } catch { 62 | state = {}; 63 | } 64 | 65 | return state ?? 
{}; 66 | } 67 | 68 | static setNavigationPanelState(state: Partial) { 69 | const currentState = this.getNavigationPanelState(); 70 | const newState = { ...currentState, ...state }; 71 | const stateStr = JSON.stringify(newState); 72 | localStorage.setItem(NAVIGATION_PANEL_STATE_STORAGE_NAME, stateStr); 73 | 74 | return newState; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /report-front-end/src/utils/helpers/store.ts: -------------------------------------------------------------------------------- 1 | import { createStore } from "redux"; 2 | import { DEFAULT_QUERY_CONFIG, DEFAULT_USER_INFO, LOCALSTORAGE_KEY } from "../constants"; 3 | import { ActionType, UserAction, UserState } from "./types"; 4 | 5 | const defaultUserState: UserState = { 6 | userInfo : DEFAULT_USER_INFO, 7 | queryConfig: DEFAULT_QUERY_CONFIG, 8 | }; 9 | 10 | const localStorageData = localStorage.getItem(LOCALSTORAGE_KEY) 11 | ? JSON.parse(localStorage.getItem(LOCALSTORAGE_KEY) || "{}") 12 | : null; 13 | 14 | const initialState = localStorageData || defaultUserState; 15 | 16 | const userReducer = (state = initialState, action: UserAction) => { 17 | switch (action.type) { 18 | case ActionType.Delete: 19 | localStorage.setItem(LOCALSTORAGE_KEY, ""); 20 | return null; 21 | case ActionType.UpdateUserInfo: 22 | localStorage.setItem(LOCALSTORAGE_KEY, JSON.stringify({ ...state, userInfo: action.state })); 23 | return { ...state, userInfo: action.state }; 24 | case ActionType.UpdateConfig: 25 | localStorage.setItem(LOCALSTORAGE_KEY, JSON.stringify({ ...state, queryConfig: action.state })); 26 | return { ...state, queryConfig: action.state }; 27 | default: 28 | localStorage.setItem(LOCALSTORAGE_KEY, JSON.stringify({ ...state })); 29 | return { ...state }; 30 | } 31 | }; 32 | 33 | const store = createStore(userReducer as any); 34 | 35 | export default store; 36 | -------------------------------------------------------------------------------- 
/report-front-end/src/utils/helpers/tools.ts: -------------------------------------------------------------------------------- 1 | import { Auth } from "aws-amplify"; 2 | import toast from "react-hot-toast"; 3 | import { LOCAL_STORAGE_KEYS } from "../constants"; 4 | 5 | export const logout = () => { 6 | console.warn("Not authorized! Logging out"); 7 | toast.error("Please login first!"); 8 | Object.keys(LOCAL_STORAGE_KEYS).forEach((key) => 9 | localStorage.removeItem(key) 10 | ); 11 | Auth.signOut(); 12 | }; 13 | 14 | /** 15 | * @deprecated please use logout() function directly 16 | */ 17 | export const dispatchUnauthorizedEvent = () => { 18 | window.dispatchEvent(new CustomEvent("unauthorized")); 19 | }; 20 | -------------------------------------------------------------------------------- /report-front-end/src/utils/helpers/types.ts: -------------------------------------------------------------------------------- 1 | export const COMMON_ALERT_TYPE = { 2 | Success: "success", 3 | Error: "error", 4 | Warning: "warning", 5 | Info: "info", 6 | }; 7 | 8 | export enum ActionType { 9 | Delete = "Delete", 10 | UpdateUserInfo = "UpdateUserInfo", 11 | UpdateConfig = "UpdateConfig", 12 | } 13 | 14 | export type UserState = { 15 | userInfo: UserInfo; 16 | queryConfig: LLMConfigState; 17 | }; 18 | 19 | export type UserInfo = { 20 | userId: string; 21 | displayName: string; 22 | loginExpiration: number; 23 | isLogin: boolean; 24 | username: string; 25 | }; 26 | 27 | export type LLMConfigState = { 28 | selectedLLM: string; 29 | selectedDataPro: string; 30 | intentChecked: boolean; 31 | complexChecked: boolean; 32 | answerInsightChecked: boolean; 33 | contextWindow: number; 34 | modelSuggestChecked: boolean; 35 | temperature: number; 36 | topP: number; 37 | topK: number; 38 | maxLength: number; 39 | }; 40 | 41 | export type UserAction = { type: ActionType; state?: any }; 42 | -------------------------------------------------------------------------------- 
/report-front-end/src/vite-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /report-front-end/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "useDefineForClassFields": true, 5 | "lib": ["ES2020", "DOM", "DOM.Iterable"], 6 | "module": "ESNext", 7 | "skipLibCheck": true, 8 | 9 | /* Bundler mode */ 10 | "moduleResolution": "bundler", 11 | "allowImportingTsExtensions": true, 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "noEmit": true, 15 | "jsx": "react-jsx", 16 | 17 | /* Linting */ 18 | "strict": true, 19 | "noUnusedLocals": true, 20 | "noUnusedParameters": true, 21 | "noFallthroughCasesInSwitch": true 22 | }, 23 | "include": ["src"], 24 | "references": [{ "path": "./tsconfig.node.json" }] 25 | } 26 | -------------------------------------------------------------------------------- /report-front-end/tsconfig.node.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "composite": true, 4 | "skipLibCheck": true, 5 | "module": "ESNext", 6 | "moduleResolution": "bundler", 7 | "allowSyntheticDefaultImports": true 8 | }, 9 | "include": ["vite.config.ts"] 10 | } 11 | -------------------------------------------------------------------------------- /report-front-end/vite.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from "vite"; 2 | import react from "@vitejs/plugin-react"; 3 | import dotenv from "dotenv"; 4 | import fs from "fs"; 5 | import path from "path"; 6 | 7 | // Load .env file 8 | const envPath = path.resolve(__dirname, ".env"); 9 | if (fs.existsSync(envPath)) { 10 | dotenv.config({ path: `${envPath}.local` }); 11 | dotenv.config({ path: envPath }); 12 | } 13 | 14 | // 
https://vitejs.dev/config/ 15 | export default defineConfig({ 16 | define: { 17 | "process.env": { 18 | VITE_TITLE: process.env.VITE_TITLE, 19 | VITE_LOGO: process.env.VITE_LOGO, 20 | VITE_RIGHT_LOGO: process.env.VITE_RIGHT_LOGO, 21 | VITE_LOGO_DISPLAY_ON_LOGIN_PAGE: 22 | process.env.VITE_LOGO_DISPLAY_ON_LOGIN_PAGE, 23 | VITE_COGNITO_REGION: process.env.VITE_COGNITO_REGION, 24 | VITE_COGNITO_USER_POOL_ID: process.env.VITE_COGNITO_USER_POOL_ID, 25 | VITE_COGNITO_USER_POOL_WEB_CLIENT_ID: 26 | process.env.VITE_COGNITO_USER_POOL_WEB_CLIENT_ID, 27 | VITE_COGNITO_IDENTITY_POOL_ID: process.env.VITE_COGNITO_IDENTITY_POOL_ID, 28 | VITE_SQL_DISPLAY: process.env.VITE_SQL_DISPLAY, 29 | VITE_BACKEND_URL: process.env.VITE_BACKEND_URL, 30 | VITE_WEBSOCKET_URL: process.env.VITE_WEBSOCKET_URL, 31 | VITE_LOGIN_TYPE: process.env.VITE_LOGIN_TYPE, 32 | VITE_APP_VERSION: process.env.VITE_APP_VERSION, 33 | VITE_USE_SSO_LOGIN: process.env.VITE_USE_SSO_LOGIN, 34 | VITE_SSO_FED_AUTH_PROVIDER: process.env.VITE_SSO_FED_AUTH_PROVIDER, 35 | VITE_SSO_OAUTH_DOMAIN: process.env.VITE_SSO_OAUTH_DOMAIN, 36 | }, 37 | }, 38 | plugins: [react()], 39 | server: { 40 | port: 3000, 41 | }, 42 | }); 43 | -------------------------------------------------------------------------------- /source/resources/.npmignore: -------------------------------------------------------------------------------- 1 | *.ts 2 | !*.d.ts 3 | 4 | # CDK asset staging directory 5 | .cdk.staging 6 | cdk.out 7 | -------------------------------------------------------------------------------- /source/resources/bin/main.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { MainStack } from '../lib/main-stack'; 3 | import * as fs from 'fs'; 4 | import * as path from 'path'; 5 | 6 | const devEnv = { 7 | account: process.env.CDK_DEFAULT_ACCOUNT, 8 | region: process.env.CDK_DEFAULT_REGION, 9 | }; 10 | 11 | declare const __dirname: string; 12 | 13 | const 
configPath = path.join(__dirname, '..', 'cdk-config.json'); 14 | const config = JSON.parse(fs.readFileSync(configPath, 'utf8')); 15 | 16 | const app = new cdk.App(); 17 | 18 | const rds = config.rds 19 | 20 | const embedding = config.embedding 21 | 22 | const opensearch = config.opensearch 23 | 24 | const vpc = config.vpc 25 | 26 | const cdkConfig = { 27 | env: devEnv, 28 | deployRds: rds.deploy, 29 | embedding_platform: embedding.embedding_platform, 30 | embedding_region: embedding.embedding_region, 31 | embedding_name: embedding.embedding_name, 32 | embedding_dimension: embedding.embedding_dimension, 33 | sql_index : opensearch.sql_index, 34 | ner_index : opensearch.ner_index, 35 | cot_index : opensearch.cot_index, 36 | log_index : opensearch.log_index, 37 | existing_vpc_id : vpc.existing_vpc_id, 38 | bedrock_ak_sk : config.ecs.bedrock_ak_sk, 39 | bedrock_region: config.ecs.bedrock_region, 40 | cognito_sign_in_aliases_username: config.cognito.sign_in_aliases_username 41 | }; 42 | 43 | new MainStack(app, 'GenBiMainStack', cdkConfig); // Pass deployRDS flag to MainStack constructor 44 | app.synth(); 45 | -------------------------------------------------------------------------------- /source/resources/cdk-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "vpc" : { 3 | "existing_vpc_id" : "" 4 | }, 5 | "rds": { 6 | "deploy": false 7 | }, 8 | "embedding": { 9 | "embedding_platform": "bedrock", 10 | "embedding_name": "amazon.titan-embed-text-v1", 11 | "embedding_dimension": 1536, 12 | "embedding_region": "" 13 | }, 14 | "opensearch": { 15 | "sql_index" : "uba", 16 | "ner_index" : "uba_ner", 17 | "cot_index" : "uba_agent", 18 | "log_index" : "genbi_query_logging" 19 | }, 20 | "ecs": { 21 | "bedrock_region": "", 22 | "bedrock_ak_sk": "" 23 | }, 24 | "cognito": { 25 | "sign_in_aliases_username": false 26 | } 27 | } -------------------------------------------------------------------------------- 
/source/resources/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/main.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 47 | 
"@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 48 | "@aws-cdk/aws-redshift:columnId": true, 49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 52 | "@aws-cdk/aws-kms:aliasNameRef": true, 53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 55 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 56 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 57 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 58 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 59 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 60 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 61 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 62 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /source/resources/jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | testEnvironment: 'node', 3 | roots: ['/test'], 4 | testMatch: ['**/*.test.ts'], 5 | transform: { 6 | '^.+\\.tsx?$': 'ts-jest' 7 | } 8 | }; 9 | -------------------------------------------------------------------------------- /source/resources/lib/aos/aos-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import * as opensearch from 'aws-cdk-lib/aws-opensearchservice'; 5 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 6 | import { 
AnyPrincipal, Effect, PolicyStatement } from "aws-cdk-lib/aws-iam"; 7 | import * as crypto from 'crypto'; 8 | 9 | 10 | export class AOSStack extends cdk.Stack { 11 | _securityGroup; 12 | public readonly endpoint: string; 13 | public readonly OSMasterUserSecretName: string; 14 | public readonly OSHostSecretName: string; 15 | 16 | constructor(scope: Construct, id: string, props: cdk.StackProps & {vpc: ec2.IVpc} & { subnets: cdk.aws_ec2.ISubnet[] }) { 17 | super(scope, id, props); 18 | 19 | // Create a Security Group for OpenSearch 20 | this._securityGroup = new ec2.SecurityGroup(this, 'GenBIOpenSearchSG', { 21 | vpc: props.vpc, 22 | description: 'Allow access to OpenSearch', 23 | allowAllOutbound: true 24 | }); 25 | this._securityGroup.applyRemovalPolicy(cdk.RemovalPolicy.DESTROY); 26 | 27 | const OSMasterUserSecretNamePrefix = 'opensearch-master-user'; // Add the secret name here 28 | // const guid = crypto.randomBytes(3).toString('hex'); 29 | const vpcIdSuffix = props.vpc.vpcId 30 | console.log(`VPC ID Suffix: ${vpcIdSuffix}`); 31 | this.OSMasterUserSecretName = `${OSMasterUserSecretNamePrefix}-${vpcIdSuffix}`; 32 | console.log(`OSMasterUserSecretName: ${this.OSMasterUserSecretName}`); 33 | const templatedSecret = new secretsmanager.Secret(this, 'TemplatedSecret', { 34 | secretName: this.OSMasterUserSecretName, 35 | description: 'Templated secret used for OpenSearch master user password', 36 | generateSecretString: { 37 | excludePunctuation: false, 38 | includeSpace: false, 39 | generateStringKey: 'password', 40 | passwordLength: 12, 41 | requireEachIncludedType: true, 42 | secretStringTemplate: JSON.stringify({ username: 'master-user' }) 43 | }, 44 | removalPolicy: cdk.RemovalPolicy.DESTROY 45 | }); 46 | 47 | // Allow inbound HTTP and HTTPS traffic 48 | this._securityGroup.addIngressRule(ec2.Peer.anyIpv4(), ec2.Port.tcp(80), 'Allow HTTP access'); 49 | this._securityGroup.addIngressRule(ec2.Peer.anyIpv4(), ec2.Port.tcp(443), 'Allow HTTPS access'); 50 | 51 | // Find 
subnets in different availability zones 52 | const subnets = props.vpc.selectSubnets({ 53 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, 54 | }).subnets; 55 | // const subnets = this._vpc.selectSubnets().subnets; 56 | 57 | // Create the OpenSearch domain 58 | const domain = new opensearch.Domain(this, 'GenBiOpenSearchDomain', { 59 | version: opensearch.EngineVersion.OPENSEARCH_2_9, 60 | vpc: props.vpc, 61 | securityGroups: [this._securityGroup], 62 | accessPolicies: [new PolicyStatement({ 63 | effect: Effect.ALLOW, 64 | principals: [new AnyPrincipal()], 65 | actions: ["es:*"], 66 | resources: [`arn:${this.partition}:es:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:domain/*`] 67 | })] 68 | , 69 | vpcSubnets: [ {subnets: subnets.slice(0, 2)}], 70 | // vpcSubnets: [ 71 | // { subnets: [subnets[0]] }, 72 | // ], 73 | capacity: { 74 | dataNodes: 2, 75 | dataNodeInstanceType: 'm5.large.search', 76 | multiAzWithStandbyEnabled: false 77 | }, 78 | // capacity: { 79 | // dataNodes: 1, 80 | // dataNodeInstanceType: 'm5.large.search', 81 | // multiAzWithStandbyEnabled: false 82 | // }, 83 | ebs: { 84 | volumeType: ec2.EbsDeviceVolumeType.GP3, 85 | volumeSize: 20, 86 | }, 87 | zoneAwareness: { 88 | availabilityZoneCount: 2 89 | }, 90 | nodeToNodeEncryption: true, 91 | encryptionAtRest: { 92 | enabled: true 93 | }, 94 | enforceHttps: true, 95 | fineGrainedAccessControl: { 96 | masterUserName: 'master-user', 97 | masterUserPassword: cdk.SecretValue.secretsManager(templatedSecret.secretArn, { 98 | jsonField: 'password' 99 | }), 100 | }, 101 | }); 102 | domain.applyRemovalPolicy(cdk.RemovalPolicy.DESTROY); 103 | this.endpoint = domain.domainEndpoint.toString(); 104 | 105 | const OSHostSecretNamePrefix = 'opensearch-host-url'; // Add the secret name here 106 | this.OSHostSecretName = `${OSHostSecretNamePrefix}-${vpcIdSuffix}`; 107 | console.log(`OSHostSecretName: ${this.OSHostSecretName}`); 108 | const hostSecret = new secretsmanager.Secret(this, 'HostSecret', { 109 | secretName: 
this.OSHostSecretName, 110 | generateSecretString: { 111 | secretStringTemplate: JSON.stringify({host: this.endpoint}), 112 | generateStringKey: 'password', // Specify the key under which the secret will be stored 113 | }, 114 | }); 115 | 116 | new cdk.CfnOutput(this, 'AOSDomainEndpoint', { 117 | value: this.endpoint, 118 | description: 'The endpoint of the OpenSearch domain' 119 | }); 120 | } 121 | } 122 | 123 | // const app = new cdk.App(); 124 | // new AOSStack(app, 'AOSStack', { 125 | // env: { 126 | // account: process.env.CDK_DEFAULT_ACCOUNT, 127 | // region: process.env.CDK_DEFAULT_REGION 128 | // } 129 | // }); 130 | -------------------------------------------------------------------------------- /source/resources/lib/cognito/cognito-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib/core'; 2 | import * as cognito from 'aws-cdk-lib/aws-cognito'; 3 | import { Construct } from 'constructs'; 4 | 5 | interface CognitoStackProps extends cdk.StackProps { 6 | sign_in_aliases_username?: boolean; 7 | } 8 | export class CognitoStack extends cdk.Stack { 9 | public readonly userPoolId: string; 10 | public readonly userPoolClientId: string; 11 | constructor(scope: Construct, id: string, props?: CognitoStackProps) { 12 | super(scope, id, props); 13 | 14 | let userPoolProps = undefined 15 | if (props?.sign_in_aliases_username) { 16 | userPoolProps = { 17 | userPoolName: 'GenBiUserPool', 18 | selfSignUpEnabled: true, 19 | signInAliases: { email: true, username: true, preferredUsername: true }, 20 | signInCaseSensitive: false, 21 | autoVerify: { email: true }, 22 | passwordPolicy: { 23 | minLength: 8, 24 | requireUppercase: false, 25 | requireLowercase: true, 26 | requireDigits: false, 27 | requireSymbols: false 28 | } 29 | } 30 | } else { 31 | userPoolProps = { 32 | userPoolName: 'GenBiUserPool', 33 | selfSignUpEnabled: true, 34 | signInAliases: { email: true }, 35 | autoVerify: { email: true }, 
36 | passwordPolicy: { 37 | minLength: 8, 38 | requireUppercase: false, 39 | requireLowercase: true, 40 | requireDigits: false, 41 | requireSymbols: false 42 | } 43 | } 44 | } 45 | 46 | // Create a Cognito User Pool 47 | const userPool = new cognito.UserPool(this, 'GenBiUserPool', userPoolProps); 48 | 49 | 50 | // Create a User Pool Client associated with the User Pool 51 | const userPoolClient = new cognito.UserPoolClient(this, 'GenBiUserPoolClient', { 52 | userPool: userPool, 53 | userPoolClientName: 'GenBiUserPoolClient' 54 | }); 55 | 56 | this.userPoolId = userPool.userPoolId; 57 | this.userPoolClientId = userPoolClient.userPoolClientId; 58 | 59 | // Output the User Pool Id and User Pool Client Id 60 | new cdk.CfnOutput(this, 'UserPoolId', { 61 | value: userPool.userPoolId 62 | }); 63 | 64 | new cdk.CfnOutput(this, 'UserPoolClientId', { 65 | value: userPoolClient.userPoolClientId 66 | }); 67 | } 68 | } -------------------------------------------------------------------------------- /source/resources/lib/rds/rds-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import * as rds from 'aws-cdk-lib/aws-rds'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import { Construct } from 'constructs'; 5 | import { InstanceClass, InstanceSize, InstanceType, Port, SubnetType, Vpc } from 'aws-cdk-lib/aws-ec2' 6 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 7 | 8 | interface RDSStackProps extends cdk.StackProps { 9 | subnets?: ec2.SubnetSelection; 10 | vpc:ec2.IVpc; 11 | } 12 | // add rds stack 13 | export class RDSStack extends cdk.Stack { 14 | public readonly endpoint: string; 15 | public readonly rdsSecurityGroup: ec2.SecurityGroup; 16 | constructor(scope: Construct, id: string, props: RDSStackProps) { 17 | super(scope, id, props); 18 | 19 | const templatedSecret = new secretsmanager.Secret(this, 'GenBIRDSTemplatedSecret', { 20 | description: 'Templated secret used 
for RDS password', 21 | generateSecretString: { 22 | excludePunctuation: true, 23 | includeSpace: false, 24 | generateStringKey: 'password', 25 | passwordLength: 12, 26 | secretStringTemplate: JSON.stringify({ username: 'user' }) 27 | }, 28 | removalPolicy: cdk.RemovalPolicy.DESTROY 29 | }); 30 | 31 | // Create an RDS instance 32 | const database = new rds.DatabaseInstance(this, 'Database', { 33 | engine: rds.DatabaseInstanceEngine.mysql({ version: rds.MysqlEngineVersion.VER_8_0 }), 34 | instanceType: ec2.InstanceType.of(InstanceClass.T3, InstanceSize.MICRO), 35 | vpc: props.vpc, 36 | vpcSubnets: props.subnets || { subnetType: SubnetType.PRIVATE_WITH_EGRESS }, 37 | publiclyAccessible: false, 38 | databaseName: 'GenBIDB', 39 | credentials: rds.Credentials.fromSecret(templatedSecret), 40 | }); 41 | this.endpoint = database.instanceEndpoint.hostname; 42 | // Output the database endpoint 43 | new cdk.CfnOutput(this, 'RDSEndpoint', { 44 | value: database.instanceEndpoint.hostname, 45 | description: 'The endpoint of the RDS instance', 46 | }); 47 | } 48 | } -------------------------------------------------------------------------------- /source/resources/lib/redshift/redshfit-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import * as redshift from 'aws-cdk-lib/aws-redshift'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import { Construct } from 'constructs'; 5 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 6 | 7 | interface RedshiftStackProps extends cdk.StackProps { 8 | vpc: ec2.IVpc; 9 | subnets?: ec2.SubnetSelection; 10 | } 11 | 12 | 13 | export class RedshiftStack extends cdk.Stack { 14 | public readonly endpoint: string; 15 | 16 | constructor(scope: Construct, id: string, props: RedshiftStackProps) { 17 | super(scope, id, props); 18 | 19 | // Create a secret for Redshift credentials 20 | const redshiftSecret = new secretsmanager.Secret(this, 
'RedshiftSecret', { 21 | description: 'Secret for Redshift cluster credentials', 22 | generateSecretString: { 23 | secretStringTemplate: JSON.stringify({ username: 'admin' }), 24 | generateStringKey: 'password', 25 | excludePunctuation: true, 26 | includeSpace: false, 27 | passwordLength: 16, 28 | }, 29 | removalPolicy: cdk.RemovalPolicy.DESTROY, 30 | }); 31 | 32 | // Create a security group for Redshift 33 | const redshiftSecurityGroup = new ec2.SecurityGroup(this, 'RedshiftSecurityGroup', { 34 | vpc: props.vpc, 35 | description: 'Security group for Redshift cluster', 36 | allowAllOutbound: true, 37 | }); 38 | 39 | // Allow inbound traffic on port 5439 (default Redshift port) 40 | redshiftSecurityGroup.addIngressRule(ec2.Peer.anyIpv4(), ec2.Port.tcp(5439), 'Allow Redshift access'); 41 | 42 | // Create the Redshift cluster 43 | const redshiftCluster = new redshift.Cluster(this, 'RedshiftCluster', { 44 | masterUser: { 45 | masterUsername: 'admin', 46 | masterPassword: redshiftSecret.secretValueFromJson('password'), 47 | }, 48 | vpc: props.vpc, 49 | vpcSubnets: props.subnets || { subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }, 50 | securityGroups: [redshiftSecurityGroup], 51 | clusterType: redshift.ClusterType.SINGLE_NODE, 52 | nodeType: redshift.NodeType.DC2_LARGE, 53 | defaultDatabaseName: 'default_db', 54 | removalPolicy: cdk.RemovalPolicy.DESTROY, 55 | }); 56 | 57 | this.endpoint = redshiftCluster.clusterEndpoint.hostname; 58 | 59 | // Output the Redshift cluster endpoint 60 | new cdk.CfnOutput(this, 'RedshiftEndpoint', { 61 | value: redshiftCluster.clusterEndpoint.hostname, 62 | description: 'The endpoint of the Redshift cluster', 63 | }); 64 | } 65 | } -------------------------------------------------------------------------------- /source/resources/lib/vpc/vpc-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import {Construct} from 'constructs'; 3 | import * as ec2 from 
'aws-cdk-lib/aws-ec2'; 4 | interface VPCStackProps extends cdk.StackProps { 5 | existing_vpc_id?: string; 6 | } 7 | 8 | export class VPCStack extends cdk.Stack { 9 | public readonly vpc: ec2.IVpc; 10 | public readonly publicSubnets: ec2.ISubnet[]; 11 | 12 | constructor(scope: Construct, id: string, props: VPCStackProps) { 13 | super(scope, id, props); 14 | // Create a VPC 15 | if (props.existing_vpc_id) { 16 | this.vpc = ec2.Vpc.fromLookup(this, 'GenBIVpc', { 17 | vpcId: props.existing_vpc_id, 18 | }); 19 | } else{ 20 | this.vpc = new ec2.Vpc(this, 'GenBIVpc', { 21 | maxAzs: 3, // Default is all AZs in the region 22 | natGateways: 1, 23 | subnetConfiguration: [ 24 | { 25 | cidrMask: 24, 26 | name: 'public-subnet', 27 | subnetType: ec2.SubnetType.PUBLIC, 28 | }, 29 | { 30 | cidrMask: 24, 31 | name: 'private-subnet', 32 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, 33 | }, 34 | ], 35 | }); 36 | } 37 | 38 | // Output the VPC ID 39 | new cdk.CfnOutput(this, 'VpcId', { 40 | value: this.vpc.vpcId, 41 | }); 42 | 43 | // Output the Subnet IDs 44 | this.vpc.publicSubnets.forEach((subnet, index) => { 45 | new cdk.CfnOutput(this, `PublicSubnet${index}Id`, { 46 | value: subnet.subnetId, 47 | }); 48 | }); 49 | 50 | this.vpc.privateSubnets.forEach((subnet, index) => { 51 | new cdk.CfnOutput(this, `PrivateSubnet${index}Id`, { 52 | value: subnet.subnetId, 53 | }); 54 | }); 55 | } 56 | } -------------------------------------------------------------------------------- /source/resources/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "resources", 3 | "version": "0.1.0", 4 | "bin": { 5 | "resources": "bin/resources.js" 6 | }, 7 | "scripts": { 8 | "build": "tsc", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk" 12 | }, 13 | "devDependencies": { 14 | "@types/jest": "^29.5.8", 15 | "@types/node": "20.9.0", 16 | "aws-cdk": "2.108.0", 17 | "jest": "^29.7.0", 18 | "ts-jest": "^29.1.1", 19 | "ts-node": 
"^10.9.1", 20 | "typescript": "~5.2.2" 21 | }, 22 | "dependencies": { 23 | "aws-cdk-lib": "2.108.0", 24 | "constructs": "^10.0.0" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /source/resources/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "module": "commonjs", 5 | "lib": [ 6 | "es2020", 7 | "dom" 8 | ], 9 | "declaration": true, 10 | "strict": true, 11 | "noImplicitAny": true, 12 | "strictNullChecks": true, 13 | "noImplicitThis": true, 14 | "alwaysStrict": true, 15 | "noUnusedLocals": false, 16 | "noUnusedParameters": false, 17 | "noImplicitReturns": true, 18 | "noFallthroughCasesInSwitch": false, 19 | "inlineSourceMap": true, 20 | "inlineSources": true, 21 | "experimentalDecorators": true, 22 | "strictPropertyInitialization": false, 23 | "esModuleInterop": true, 24 | "typeRoots": [ 25 | "./node_modules/@types" 26 | ], 27 | "types": ["node"] 28 | }, 29 | "exclude": [ 30 | "node_modules", 31 | "cdk.out" 32 | ] 33 | } 34 | --------------------------------------------------------------------------------