├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cdk ├── .gitignore ├── .npmignore ├── bin │ └── cdk.ts ├── cdk.json ├── fargate │ ├── embeddingWorker │ │ ├── Dockerfile │ │ └── app.py │ ├── qaWorker │ │ ├── Dockerfile │ │ ├── app.py │ │ └── requirements.txt │ └── summarizationWorker │ │ ├── Dockerfile │ │ └── app.py ├── jest.config.js ├── lambda │ ├── asyncprocessor │ │ └── lambda_function.py │ ├── embeddingprocessor │ │ └── lambda_function.py │ ├── embeddingworker │ │ └── lambda_function.py │ ├── helper │ │ └── python │ │ │ ├── datastore.py │ │ │ └── helper.py │ ├── jobresultprocessor │ │ └── lambda_function.py │ ├── summarizationprocessor │ │ └── lambda_function.py │ ├── taskprocessor │ │ └── lambda_function.py │ └── textractor │ │ └── python │ │ ├── og.py │ │ └── trp.py ├── lib │ └── cdk-stack.ts ├── package-lock.json ├── package.json ├── test │ └── cdk.test.ts └── tsconfig.json ├── diagrams └── fsi-qa.png ├── frontend ├── .gitignore ├── package-lock.json ├── package.json ├── public │ ├── favicon.ico │ ├── index.html │ ├── logo192.png │ ├── logo512.png │ ├── manifest.json │ └── robots.txt └── src │ ├── App.css │ ├── App.js │ ├── App.test.js │ ├── config.js │ ├── index.css │ ├── index.js │ ├── logo.png │ └── setupTests.js ├── screenshots └── summarization.png └── scripts └── create-user.sh /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. 
Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Summarization and question answering for financial documents 2 | 3 | This example shows you how to perform summarization and question answering for lengthy financial documents like annual reports to shareholders. 4 | 5 | For summarization, we split the document into smaller segments (five pages by default) and summarize each segment. We use AI21's summarization model, which can handle input sequences of up to about 10,000 words. 6 | 7 | For question answering, we use a technique called retrieval-augmented generation, where we provide new information (the contents of the financial document) to a large language model. 8 | 9 | ## Attribution 10 | 11 | Parts of the CDK code are adapted from [this repository](https://github.com/aws-samples/amazon-textract-serverless-large-scale-document-processing).
12 | 13 | Parts of this solution are inspired by: 14 | 15 | * https://medium.com/@shankar.arunp/augmenting-large-language-models-with-verified-information-sources-leveraging-aws-sagemaker-and-f6be17fb10a8 16 | * https://github.com/arunprsh/knowledge-augmented-LLMs 17 | 18 | 19 | ## Architecture 20 | 21 | The solution starts with a React JavaScript application hosted in an S3 bucket fronted by CloudFront. 22 | 23 | ![Architecture](diagrams/fsi-qa.png) 24 | 25 | When users upload a PDF to S3, they can start a Textract job to extract text information. When that job completes, they can start a summarization job. 26 | 27 | The front-end application calls methods on an API Gateway, which invokes Lambda functions for processing. The Lambda functions use SQS queues for asynchronous handling. The summarization job delegates to an ECS Fargate task, as it may take several minutes to run. Job state is captured in DynamoDB tables. 28 | 29 | ## Clone repository 30 | 31 | Clone this GitHub repository into a working directory. 32 | 33 | ## Deploy SageMaker endpoints 34 | 35 | You must have access to SageMaker JumpStart Foundation Models for this step. 36 | 37 | We will use the AI21 Summarize model available through [SageMaker JumpStart Foundation Models](https://aws.amazon.com/sagemaker/jumpstart/?sagemaker-data-wrangler-whats-new.sort-by=item.additionalFields.postDateTime&sagemaker-data-wrangler-whats-new.sort-order=desc). To begin, deploy the AI21 Summarize model in SageMaker JumpStart. You will need to subscribe to the model and then follow the example notebook to deploy a SageMaker inference endpoint. See this [previous blog post](https://medium.com/@shankar.arunp/augmenting-large-language-models-with-verified-information-sources-leveraging-aws-sagemaker-and-f6be17fb10a8) for more detailed instructions. 38 | 39 | Once you have deployed the endpoint, create a file called `cdk/cdk.context.json` and add the endpoint names there.
For example, if you deployed the AI21 summarization endpoint with the name `summarize`, the contents of `cdk.context.json` would be: 40 | 41 | { 42 | "sumEndpoint": "summarize" 43 | } 44 | 45 | Next, follow this [notebook](https://github.com/arunprsh/knowledge-augmented-LLMs/blob/main/01-deploy-text-embedding-model.ipynb) to deploy a text embedding model. Add the endpoint name to `cdk.context.json` as `embedEndpoint`. 46 | 47 | Finally, deploy a Cohere Medium model from SageMaker JumpStart Foundation Models. Add the endpoint name to `cdk.context.json` as `qaEndpoint`. 48 | 49 | _Note_: if you alter the endpoint names after deployment, you may need to recycle the containers in ECS so they pick up the latest values. 50 | 51 | ## CDK 52 | 53 | The application relies on a CDK stack for the required infrastructure. 54 | 55 | First, see the [CDK getting started guide](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html) to install and configure the CDK on your workstation. 56 | 57 | Then go into the `cdk` directory and install the required packages. 58 | 59 | npm i @aws-cdk/aws-cognito-identitypool-alpha 60 | npm i cdk-nag 61 | 62 | Next, set the region in `bin/cdk.ts` (the `env: { region: ... }` argument passed to `CdkStack`). Normally we would not hard-code the region, but it's a [necessary step](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_elasticloadbalancingv2.NetworkLoadBalancer.html#logwbraccesswbrlogsbucket-prefix) to enable ELB access logging. 63 | 64 | If you are running on an ARM CPU architecture, uncomment the three sections in the `cdk/lib/cdk-stack.ts` file that look like the snippet below. If you aren't sure what architecture you are running on, you can run the following commands on either platform: 65 | 1. Unix 66 | ```uname -m``` 67 | 2.
Windows 68 | ```wmic OS get OSArchitecture``` 69 | 70 | ``` 71 | // Uncomment this section if running on ARM 72 | // runtimePlatform: { 73 | // cpuArchitecture: ecs.CpuArchitecture.ARM64, 74 | // } 75 | ``` 76 | 77 | Now deploy the stack: 78 | 79 | cdk synth 80 | cdk deploy 81 | 82 | ## Create Cognito user 83 | 84 | Run this script to create a Cognito user. 85 | 86 | ./scripts/create-user.sh 87 | 88 | ## Deploy front end 89 | 90 | Finally, build and load the React app. Adjust any necessary values in `frontend/src/config.js`. 91 | 92 | cd frontend 93 | npm install # only needed once 94 | yarn build 95 | aws s3 sync build/ s3:// 96 | 97 | Now you can access the application at: 98 | 99 | https://`CdkStack.AppUrl`/index.html 100 | 101 | ## Security notes 102 | 103 | ### CloudFront certificate 104 | 105 | For the purposes of this example, we use the CloudFront default viewer certificate. Distributions that use the default CloudFront viewer certificate have a security policy set to TLSv1 regardless of the specified 'MinimumProtocolVersion'. Vulnerabilities have been and continue to be discovered in the deprecated SSL and TLS protocols. For production deployments, we recommend specifying a viewer certificate that enforces a minimum of TLSv1.1 or TLSv1.2 in the security policy. 106 | 107 | ## Security 108 | 109 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 110 | 111 | ## License 112 | 113 | This library is licensed under the MIT-0 License. See the LICENSE file.
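
## Appendix: smoke-testing an endpoint

If you want to sanity-check the SageMaker endpoints before running a full pipeline, you can call one directly with the same request shape the workers use (see `SMEndpointEmbeddings` in `cdk/fargate`). The sketch below is illustrative only and is not part of the deployed solution; the `embed_remote` helper and the endpoint name you pass to it are hypothetical, and the name should be whatever you registered as `embedEndpoint` in `cdk.context.json`.

```python
import json


def build_embedding_payload(text: str) -> bytes:
    # Same request body the workers send: {"text_inputs": [...]}
    return json.dumps({"text_inputs": [text]}).encode("utf-8")


def parse_embedding_response(body: bytes) -> list:
    # The endpoint returns {"embedding": [[...]]}; take the first vector
    return json.loads(body)["embedding"][0]


def embed_remote(endpoint_name: str, text: str) -> list:
    # Requires AWS credentials and a live endpoint; endpoint_name is the
    # value you registered as embedEndpoint in cdk.context.json
    import boto3

    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=build_embedding_payload(text),
    )
    return parse_embedding_response(response["Body"].read())
```

For example, `embed_remote("embed", "Net sales increased year over year.")` should return a list of floats whose length matches the embedding dimension of the model you deployed.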
114 | 115 | -------------------------------------------------------------------------------- /cdk/.gitignore: -------------------------------------------------------------------------------- 1 | *.js 2 | !jest.config.js 3 | *.d.ts 4 | node_modules 5 | 6 | # CDK asset staging directory 7 | .cdk.staging 8 | cdk.out 9 | 10 | cdk.context.json 11 | -------------------------------------------------------------------------------- /cdk/.npmignore: -------------------------------------------------------------------------------- 1 | *.ts 2 | !*.d.ts 3 | 4 | # CDK asset staging directory 5 | .cdk.staging 6 | cdk.out 7 | -------------------------------------------------------------------------------- /cdk/bin/cdk.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | // SPDX-License-Identifier: MIT-0 5 | 6 | 7 | import 'source-map-support/register'; 8 | import * as cdk from 'aws-cdk-lib'; 9 | import { CdkStack } from '../lib/cdk-stack'; 10 | import { AwsSolutionsChecks, NagSuppressions } from 'cdk-nag' 11 | import { Aspects } from 'aws-cdk-lib'; 12 | 13 | const app = new cdk.App(); 14 | Aspects.of(app).add(new AwsSolutionsChecks({ verbose: true })) 15 | const stack = new CdkStack(app, 'CdkStack', { env: {region: 'us-east-1'} }); 16 | NagSuppressions.addStackSuppressions(stack, [ 17 | { id: 'AwsSolutions-COG4', reason: 'All API Gateway methods are protected by IAM authentication' }, 18 | { id: 'AwsSolutions-CFR4', reason: 'I will document that a production deployment should use a non-default certificate' }, 19 | { id: 'AwsSolutions-IAM4', reason: 'Lambda default role only grants normal access to create Cloudwatch log groups' }, 20 | { id: 'AwsSolutions-IAM5', reason: 'Lambda default role access to S3 is properly scoped' }, 21 | { id: 'AwsSolutions-APIG2', reason: 'Request validation is handled by the Lambda functions' }, 22 | { id: 
'AwsSolutions-APIG4', reason: 'All API Gateway methods are protected by IAM authentication' }, 23 | ]); 24 | -------------------------------------------------------------------------------- /cdk/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/cdk.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false 43 | } 44 | } 45 | -------------------------------------------------------------------------------- 
/cdk/fargate/embeddingWorker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lts/ubuntu:20.04 2 | 3 | RUN apt-get update 4 | RUN apt-get -y install python3-pip 5 | RUN pip3 install boto3 langchain transformers chromadb 6 | 7 | COPY app.py /opt/app.py 8 | 9 | CMD ["/usr/bin/python3", "/opt/app.py"] -------------------------------------------------------------------------------- /cdk/fargate/embeddingWorker/app.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import os 5 | import traceback 6 | from typing import Optional, List 7 | import json 8 | import boto3 9 | from langchain.document_loaders import TextLoader 10 | from langchain.text_splitter import RecursiveCharacterTextSplitter 11 | from langchain.vectorstores import Chroma 12 | from langchain.embeddings.base import Embeddings 13 | from pydantic import BaseModel 14 | 15 | class SMEndpointEmbeddings(BaseModel, Embeddings): 16 | endpoint_name: str 17 | 18 | def embed_documents( 19 | self, texts: List[str], chunk_size: int = 64 20 | ) -> List[List[float]]: 21 | results = [] 22 | for t in texts: 23 | response = self.embed_query(t) 24 | results.append(response) 25 | return results 26 | 27 | def embed_query(self, text: str) -> List[float]: 28 | payload = {'text_inputs': [text]} 29 | payload = json.dumps(payload).encode('utf-8') 30 | client = boto3.client("runtime.sagemaker") 31 | response = client.invoke_endpoint(EndpointName=self.endpoint_name, 32 | ContentType='application/json', 33 | Body=payload) 34 | 35 | model_predictions = json.loads(response['Body'].read()) 36 | embedding = model_predictions['embedding'][0] 37 | return embedding 38 | 39 | # Inputs: document id and s3 location of summary 40 | def main(): 41 | 42 | print("Task starting") 43 | endpoint_name = os.environ['endpoint'] 44 | 
print(f"Endpoint: {endpoint_name}") 45 | table_name = os.environ['table'] 46 | print(f"Table: {table_name}") 47 | region = os.environ['region'] 48 | print(f"Region: {region}") 49 | docId = os.environ['docId'] 50 | print(f"docId: {docId}") 51 | jobId = os.environ['jobId'] 52 | print(f"jobId: {jobId}") 53 | bucket = os.environ['bucket'] 54 | print(f"bucket: {bucket}") 55 | name = os.environ['name'] 56 | print(f"name: {name}") 57 | mntpnt = os.environ['mountpoint'] 58 | print(f"mountpoint: {mntpnt}") 59 | 60 | try: 61 | s3 = boto3.client('s3') 62 | doc_dir = os.path.join(mntpnt, docId) 63 | if not os.path.exists(doc_dir): 64 | os.mkdir(doc_dir) 65 | sum_dir = os.path.join(doc_dir, 'summary') 66 | if not os.path.exists(sum_dir): 67 | os.mkdir(sum_dir) 68 | sum_path = os.path.join(sum_dir, 'summary.txt') 69 | print(f"Downloading s3://{bucket}/{name} to {sum_path}") 70 | s3.download_file(bucket, name, sum_path) 71 | 72 | persist_directory = os.path.join(doc_dir, 'db') 73 | if not os.path.exists(persist_directory): 74 | os.mkdir(persist_directory) 75 | loader = TextLoader(sum_path) 76 | documents = loader.load() 77 | text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, 78 | chunk_overlap = 0) 79 | texts = text_splitter.split_documents(documents) 80 | print(f"Number of splits: {len(texts)}") 81 | 82 | embeddings = SMEndpointEmbeddings( 83 | endpoint_name=endpoint_name, 84 | ) 85 | vectordb = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory) 86 | vectordb.persist() 87 | 88 | ddb = boto3.resource('dynamodb', region_name=region) 89 | table = ddb.Table(table_name) 90 | table.update_item( 91 | Key = { "documentId": docId, "jobId": jobId }, 92 | UpdateExpression = 'SET jobStatus = :jobstatusValue', 93 | ExpressionAttributeValues = { 94 | ':jobstatusValue': "Complete", 95 | } 96 | ) 97 | 98 | except Exception as e: 99 | trc = traceback.format_exc() 100 | print(trc) 101 | print(str(e)) 102 | 103 | if __name__ == "__main__": 104 | main() 105 |
-------------------------------------------------------------------------------- /cdk/fargate/qaWorker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lts/ubuntu:20.04 2 | 3 | RUN apt-get update 4 | RUN apt-get -y install python3-pip 5 | RUN pip3 install Flask Flask-Cors boto3 langchain transformers chromadb cohere-sagemaker numpy 6 | 7 | COPY ./* ./app/ 8 | WORKDIR /app/ 9 | 10 | EXPOSE 5000 11 | 12 | CMD ["python3", "app.py"] -------------------------------------------------------------------------------- /cdk/fargate/qaWorker/app.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import os 5 | import traceback 6 | from typing import List 7 | import json 8 | import boto3 9 | from flask import Flask, jsonify, request 10 | from flask_cors import CORS 11 | from langchain.vectorstores import Chroma 12 | from langchain.embeddings.base import Embeddings 13 | from pydantic import BaseModel 14 | 15 | 16 | from cohere_sagemaker import Client 17 | import numpy as np 18 | 19 | app = Flask(__name__) 20 | CORS(app) 21 | 22 | def query_endpoint_with_json_payload(encoded_json, endpoint_name): 23 | client = boto3.client("runtime.sagemaker") 24 | response = client.invoke_endpoint( 25 | EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json 26 | ) 27 | return response 28 | 29 | def parse_response_multiple_texts(query_response): 30 | 31 | model_predictions = json.loads(query_response["Body"].read()) 32 | return model_predictions[0] 33 | 34 | class SMEndpointEmbeddings(BaseModel, Embeddings): 35 | endpoint_name: str 36 | 37 | def embed_documents( 38 | self, texts: List[str], chunk_size: int = 64 39 | ) -> List[List[float]]: 40 | results = [] 41
| for t in texts: 42 | response = self.embed_query(t) 43 | results.append(response) 44 | return results 45 | 46 | def embed_query(self, text: str) -> List[float]: 47 | payload = {'text_inputs': [text]} 48 | payload = json.dumps(payload).encode('utf-8') 49 | client = boto3.client("runtime.sagemaker") 50 | response = client.invoke_endpoint(EndpointName=self.endpoint_name, 51 | ContentType='application/json', 52 | Body=payload) 53 | 54 | model_predictions = json.loads(response['Body'].read()) 55 | embedding = model_predictions['embedding'][0] 56 | return embedding 57 | 58 | @app.route("/health") 59 | def health(): 60 | resp = jsonify(health="healthy") 61 | resp.status_code = 200 62 | return resp 63 | 64 | @app.route("/", methods=['POST']) 65 | def answerquestion(): 66 | content_type = request.headers.get('Content-Type') 67 | if content_type == 'application/json': 68 | body_data = request.json 69 | else: 70 | return { 71 | 'error': "Content type not supported", 72 | 'code': 400 73 | } 74 | 75 | print("Task starting") 76 | endpoint_embed = os.environ['endpoint_embed'] 77 | print(f"Endpoint: {endpoint_embed}") 78 | endpoint_qa = os.environ['endpoint_qa'] 79 | print(f"Endpoint: {endpoint_qa}") 80 | mntpnt = os.environ['mountpoint'] 81 | print(f"mountpoint: {mntpnt}") 82 | docId = body_data['docId'] 83 | print(f"docId: {docId}") 84 | question = body_data['question'] 85 | print(f"question: {question}") 86 | 87 | try: 88 | # Create LLM chain 89 | doc_dir = os.path.join(mntpnt, docId) 90 | persist_directory = os.path.join(doc_dir, 'db') 91 | if not os.path.exists(persist_directory): 92 | return { 93 | 'error': f"Could not find Chroma database for {docId}", 94 | 'code': 400 95 | } 96 | 97 | embeddings = SMEndpointEmbeddings( 98 | endpoint_name=endpoint_embed 99 | ) 100 | vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings) 101 | 102 | cohere_client = Client(endpoint_name=endpoint_qa) 103 | docs = vectordb.similarity_search_with_score(question) 104
| 105 | scores = [] 106 | for t in docs: 107 | scores.append(t[1]) 108 | 109 | score_array = np.asarray(scores) 110 | best_idx = score_array.argmin()  # Chroma returns distances; lower means more similar 111 | print(f"Best match distance {score_array[best_idx]}") 112 | context = docs[best_idx][0].page_content.replace("\n", "") 113 | qa_prompt = f'Context={context}\nQuestion={question}\nAnswer=' 114 | response = cohere_client.generate(prompt=qa_prompt, 115 | max_tokens=512, 116 | temperature=0.25, 117 | return_likelihoods='GENERATION') 118 | answer = response.generations[0].text.strip().replace('\n', '') 119 | 120 | return { 121 | 'answer': answer, 122 | 'code': 200 123 | } 124 | 125 | except Exception as e: 126 | trc = traceback.format_exc() 127 | print(trc) 128 | print(str(e)) 129 | return { 130 | 'error': str(e), 131 | 'code': 400 132 | } 133 | 134 | if __name__ == "__main__": 135 | port = int(os.environ.get('PORT', 5000)) 136 | app.run(host='0.0.0.0', port=port) 137 | -------------------------------------------------------------------------------- /cdk/fargate/qaWorker/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask 2 | Flask-Cors 3 | boto3 4 | langchain 5 | transformers 6 | chromadb 7 | cohere-sagemaker 8 | numpy -------------------------------------------------------------------------------- /cdk/fargate/summarizationWorker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lts/ubuntu:20.04 2 | 3 | RUN apt-get update 4 | RUN apt-get -y install python3-pip 5 | RUN pip3 install boto3 langchain transformers ai21[SM] 6 | 7 | COPY app.py /opt/app.py 8 | 9 | CMD ["/usr/bin/python3", "/opt/app.py"] 10 | 11 | -------------------------------------------------------------------------------- /cdk/fargate/summarizationWorker/app.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import os 5 | import traceback 6 | from typing import Optional, List 7 | import json 8 | import boto3 9 | from langchain.docstore.document import Document 10 | from langchain.llms.base import LLM 11 | from langchain.chains.summarize import load_summarize_chain 12 | from langchain.text_splitter import RecursiveCharacterTextSplitter 13 | import ai21 14 | 15 | def query_endpoint_with_json_payload(encoded_json, endpoint_name): 16 | client = boto3.client("runtime.sagemaker") 17 | response = client.invoke_endpoint( 18 | EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json 19 | ) 20 | return response 21 | 22 | def parse_response_multiple_texts(query_response): 23 | model_predictions = json.loads(query_response["Body"].read()) 24 | generated_text = model_predictions["generated_texts"] 25 | return generated_text 26 | 27 | def query_endpoint(encoded_text, endpoint_name): 28 | client = boto3.client("runtime.sagemaker") 29 | response = client.invoke_endpoint( 30 | EndpointName=endpoint_name, ContentType="application/x-text", Body=encoded_text 31 | ) 32 | return response 33 | 34 | 35 | def parse_response(query_response): 36 | model_predictions = json.loads(query_response["Body"].read()) 37 | generated_text = model_predictions["generated_text"] 38 | return generated_text 39 | 40 | class SageMakerLLMAI21(LLM): 41 | 42 | endpoint_name: str 43 | 44 | @property 45 | def _llm_type(self) -> str: 46 | return "summarize" 47 | 48 | def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: 49 | response = ai21.Summarize.execute( 50 | source=prompt, 51 | sourceType="TEXT", 52 | sm_endpoint=self.endpoint_name 53 | ) 54 | return response.summary 55 | 56 | class SageMakerLLMFlanT5(LLM): 57 | 58 | endpoint_name: str 59 | max_length: int 60 | num_beams: int 61 | top_k: int 62 | top_p: float 63 | temperature: float 64 | 65 | @property 66 | def _llm_type(self) -> str: 67 | return "summarize" 68 | 69 | def 
_call(self, prompt: str, stop: Optional[List[str]] = None) -> str: 70 | parameters = { 71 | "max_length": self.max_length, 72 | "num_return_sequences": 1, 73 | #"num_beams": self.num_beams, 74 | "top_k": self.top_k, 75 | "top_p": self.top_p, 76 | "temperature": self.temperature, 77 | "do_sample": True, 78 | } 79 | payload = {"text_inputs": f"Summarize this article:\n\n{prompt}", **parameters} 80 | query_response = query_endpoint_with_json_payload( 81 | json.dumps(payload).encode("utf-8"), endpoint_name=self.endpoint_name 82 | ) 83 | generated_texts = parse_response_multiple_texts(query_response) 84 | 85 | return generated_texts[0] 86 | 87 | # Inputs: document id and s3 location of output 88 | def main(): 89 | 90 | print("Task starting") 91 | endpoint_name = os.environ['endpoint'] 92 | print(f"Endpoint: {endpoint_name}") 93 | table_name = os.environ['table'] 94 | print(f"Table: {table_name}") 95 | region = os.environ['region'] 96 | print(f"Region: {region}") 97 | docId = os.environ['docId'] 98 | print(f"docId: {docId}") 99 | jobId = os.environ['jobId'] 100 | print(f"jobId: {jobId}") 101 | bucket = os.environ['bucket'] 102 | print(f"bucket: {bucket}") 103 | name = os.environ['name'] 104 | print(f"name: {name}") 105 | 106 | if "chunk_size" in os.environ: 107 | chunk_size = os.environ['chunk_size'] 108 | else: 109 | chunk_size = 2000 110 | if "chunk_overlap" in os.environ: 111 | chunk_overlap = os.environ['chunk_overlap'] 112 | else: 113 | chunk_overlap = 500 114 | if "max_length" in os.environ: 115 | max_length = os.environ['max_length'] 116 | else: 117 | max_length = 10000 118 | if "num_beams" in os.environ: 119 | num_beams = os.environ['num_beams'] 120 | else: 121 | num_beams = 2 122 | if "top_k" in os.environ: 123 | top_k = os.environ['top_k'] 124 | else: 125 | top_k = 100 126 | if "top_p" in os.environ: 127 | top_p = os.environ['top_p'] 128 | else: 129 | top_p = 0.9 130 | if "temperature" in os.environ: 131 | temperature =
os.environ['temperature'] 132 | else: 133 | temperature = 0.5 134 | 135 | name_parts = name.split('/') 136 | local_path = os.path.join ('/tmp', name_parts[-1]) 137 | 138 | try: 139 | s3 = boto3.client('s3') 140 | print(f"Downloading s3://{bucket}/{name} to {local_path}") 141 | s3.download_file(bucket, name, local_path) 142 | 143 | text_splitter = RecursiveCharacterTextSplitter(separators = ["", "", "\n"], 144 | chunk_size = int(chunk_size), 145 | chunk_overlap = int(chunk_overlap)) 146 | 147 | with open(local_path) as f: 148 | doc = f.read() 149 | texts = text_splitter.split_text(doc) 150 | print(f"Number of splits: {len(texts)}") 151 | 152 | #docs = [Document(page_content=t) for t in texts] 153 | 154 | """ llm = SageMakerLLMFlanT5(endpoint_name = endpoint_name, 155 | top_k = int(top_k), 156 | top_p = float(top_p), 157 | max_length = int(max_length), 158 | num_beams = int(num_beams), 159 | temperature = float(temperature)) """ 160 | llm = SageMakerLLMAI21(endpoint_name = endpoint_name) 161 | 162 | #chain = load_summarize_chain(llm, chain_type="map_reduce", verbose=False) 163 | #summary = chain({"input_documents": docs}, return_only_outputs=True) 164 | responses = [] 165 | for t in texts: 166 | r = llm(t) 167 | responses.append(r) 168 | summary = "\n".join(responses) 169 | 170 | ddb = boto3.resource('dynamodb', region_name=region) 171 | table = ddb.Table(table_name) 172 | table.update_item( 173 | Key = { "documentId": docId, "jobId": jobId }, 174 | UpdateExpression = 'SET jobStatus = :jobstatusValue, summaryText = :outputValue', 175 | ExpressionAttributeValues = { 176 | ':jobstatusValue': "Complete", 177 | ':outputValue': summary 178 | } 179 | ) 180 | 181 | except Exception as e: 182 | trc = traceback.format_exc() 183 | print(trc) 184 | print(str(e)) 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /cdk/jest.config.js: 
-------------------------------------------------------------------------------- 1 | module.exports = { 2 | testEnvironment: 'node', 3 | roots: ['<rootDir>/test'], 4 | testMatch: ['**/*.test.ts'], 5 | transform: { 6 | '^.+\\.tsx?$': 'ts-jest' 7 | } 8 | }; 9 | -------------------------------------------------------------------------------- /cdk/lambda/asyncprocessor/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import os 6 | from helper import AwsHelper 7 | import traceback 8 | import datastore 9 | 10 | def processRequest(documentId, bucketName, objectName, snsRole, snsTopic): 11 | 12 | print("Starting job with documentId: {}, bucketName: {}, objectName: {}".format(documentId, bucketName, objectName)) 13 | 14 | response = None 15 | client = AwsHelper().getClient('textract') 16 | response = client.start_document_text_detection( 17 | ClientRequestToken = documentId, 18 | DocumentLocation={ 19 | 'S3Object': { 20 | 'Bucket': bucketName, 21 | 'Name': objectName 22 | } 23 | }, 24 | NotificationChannel= { 25 | "RoleArn": snsRole, 26 | "SNSTopicArn": snsTopic 27 | }, 28 | JobTag = documentId 29 | ) 30 | 31 | return response["JobId"] 32 | 33 | 34 | 35 | def respond(err, res=None): 36 | return { 37 | 'statusCode': '400' if err else '200', 38 | 'body': str(err) if err else json.dumps(res), 39 | 'headers': { 40 | 'Content-Type': 'application/json', 41 | "Access-Control-Allow-Origin": "*", 42 | "Access-Control-Allow-Credentials": 'true' 43 | }, 44 | } 45 | 46 | # Inputs: document id and s3 location 47 | def lambda_handler(event, context): 48 | 49 | print("Received event: " + json.dumps(event, indent=2)) 50 | operation = event['httpMethod'] 51 | 52 | if operation != "POST": 53 | return respond(ValueError('Unsupported method "{}"'.format(operation))) 54 | else: 55 | payload = 
event['queryStringParameters'] if operation == 'GET' else json.loads(event['body']) 56 | docId = payload['docId'] 57 | bucket = payload['bucket'] 58 | name = payload['name'] 59 | snsTopic = os.environ['SNS_TOPIC_ARN'] 60 | snsRole = os.environ['SNS_ROLE_ARN'] 61 | outputTable = os.environ['OUTPUT_TABLE'] 62 | documentsTable = os.environ['DOCUMENTS_TABLE'] 63 | 64 | try: 65 | jobId = processRequest(docId, bucket, name, snsRole, snsTopic) 66 | print(f"Started textract job {jobId}") 67 | ds = datastore.DocumentStore(documentsTable, outputTable) 68 | ds.createDocument(docId, bucket, name, "Started", jobId) 69 | return respond(None, {'msg': "Job started", 'jobId': jobId}) 70 | except Exception as e: 71 | trc = traceback.format_exc() 72 | print(f"Error starting textract job: {str(e)} - {trc}") 73 | return respond(ValueError(f"Could not start job: {str(e)}")); 74 | -------------------------------------------------------------------------------- /cdk/lambda/embeddingprocessor/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import os 5 | import traceback 6 | import json 7 | from helper import AwsHelper 8 | import datastore 9 | 10 | def postMessage(client, qUrl, jsonMessage): 11 | 12 | message = json.dumps(jsonMessage) 13 | 14 | client.send_message( 15 | QueueUrl=qUrl, 16 | MessageBody=message 17 | ) 18 | 19 | print("Submitted message to queue: {}".format(message)) 20 | 21 | def respond(err, res=None): 22 | return { 23 | 'statusCode': '400' if err else '200', 24 | 'body': str(err) if err else json.dumps(res), 25 | 'headers': { 26 | 'Content-Type': 'application/json', 27 | "Access-Control-Allow-Origin": "*", 28 | "Access-Control-Allow-Credentials": 'true' 29 | }, 30 | } 31 | 32 | # Inputs: document id and s3 location of text extract 33 | def lambda_handler(event, context): 34 | 35 | print("Received event: " + json.dumps(event, indent=2)) 36 | operation = event['httpMethod'] 37 | 38 | if operation != "POST": 39 | return respond(ValueError('Unsupported method "{}"'.format(operation))) 40 | else: 41 | payload = event['queryStringParameters'] if operation == 'GET' else json.loads(event['body']) 42 | docId = payload['docId'] 43 | bucket = payload['bucket'] 44 | name = payload['name'] 45 | 46 | queueUrl = os.environ['QUEUE_URL'] 47 | jobTable = os.environ['JOB_TABLE'] 48 | 49 | jsonMessage = { 'documentId' : docId, 50 | 'bucketName': bucket, 51 | 'objectName' : name, 52 | 'jobId': docId, 53 | 'jobTable': jobTable} 54 | 55 | try: 56 | client = AwsHelper().getClient('sqs') 57 | postMessage(client, queueUrl, jsonMessage) 58 | ds = datastore.DocumentStore("", "", embeddingTableName = jobTable) 59 | ds.createEmbeddingJob(docId, "Started", docId) 60 | 61 | return respond(None, {'msg': "Embedding started", 'job': docId}) 62 | except Exception as e: 63 | trc = traceback.format_exc() 64 | print(trc) 65 | return respond(ValueError(f"Could not start embeddings job for doc: {str(e)}")); 66 | 
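The async, embedding, and summarization processors above all follow the same API Gateway proxy-integration pattern: accept only POST, parse the JSON body, start or enqueue the work, and wrap the result in a `respond()` envelope with CORS headers. A minimal standalone sketch of that envelope and the event shape it consumes (the `doc-1`/`my-bucket` values are made-up illustrations, not from the deployment):

```python
import json

def respond(err, res=None):
    # Mirrors the respond() helper used by the processors above:
    # API Gateway proxy integration expects statusCode/body/headers.
    return {
        'statusCode': '400' if err else '200',
        'body': str(err) if err else json.dumps(res),
        'headers': {
            'Content-Type': 'application/json',
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Credentials': 'true',
        },
    }

# A POST event roughly as API Gateway delivers it to lambda_handler
# (illustrative payload values):
event = {
    'httpMethod': 'POST',
    'body': json.dumps({'docId': 'doc-1', 'bucket': 'my-bucket', 'name': 'text/doc-1.txt'}),
}
payload = json.loads(event['body'])

# Success and error envelopes, as the handlers above would return them:
ok = respond(None, {'msg': 'Embedding started', 'job': payload['docId']})
bad = respond(ValueError('Unsupported method "GET"'))
```

Because the body is a JSON string rather than a dict, clients (and the React frontend) must `JSON.parse` the `body` field of the proxy response.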
-------------------------------------------------------------------------------- /cdk/lambda/embeddingworker/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import os 6 | import boto3 7 | import traceback 8 | 9 | def lambda_handler(event, context): 10 | 11 | print("event: {}".format(event)) 12 | 13 | message = json.loads(event['Records'][0]['body']) 14 | 15 | print("Message: {}".format(message)) 16 | docId = message['documentId'] 17 | bucket = message['bucketName'] 18 | name = message['objectName'] 19 | jobId = message['jobId'] 20 | 21 | clusterArn = os.environ['target'] 22 | taskDefinitionArn = os.environ['taskDefinitionArn'] 23 | subnets = os.environ['subnets'] 24 | subnet_list = subnets.split(',') 25 | 26 | try: 27 | ecs = boto3.client('ecs') 28 | response = ecs.run_task( 29 | cluster=clusterArn, 30 | count=1, 31 | launchType='FARGATE', 32 | networkConfiguration={ 33 | 'awsvpcConfiguration': { 34 | 'subnets': subnet_list 35 | } 36 | }, 37 | overrides={ 38 | 'containerOverrides': [ 39 | { 40 | 'name': 'worker', 41 | 'environment': [ 42 | { 43 | 'name': 'docId', 44 | 'value': docId 45 | }, 46 | { 47 | 'name': 'jobId', 48 | 'value': jobId 49 | }, 50 | { 51 | 'name': 'bucket', 52 | 'value': bucket 53 | }, 54 | { 55 | 'name': 'name', 56 | 'value': name 57 | }, 58 | ], 59 | } 60 | ] 61 | }, 62 | taskDefinition=taskDefinitionArn 63 | ) 64 | output = f"Launched task" 65 | except Exception as e: 66 | trc = traceback.format_exc() 67 | print(trc) 68 | output = str(e) 69 | 70 | return { 71 | 'statusCode': 200, 72 | 'body': output 73 | } 74 | -------------------------------------------------------------------------------- /cdk/lambda/helper/python/datastore.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import boto3 5 | from botocore.exceptions import ClientError 6 | from helper import AwsHelper 7 | import datetime 8 | 9 | class DocumentStore: 10 | 11 | def __init__(self, documentsTableName, outputTableName, jobTableName = None, embeddingTableName = None): 12 | self._documentsTableName = documentsTableName 13 | self._outputTableName = outputTableName 14 | self._jobTableName = jobTableName 15 | self._embedTableName = embeddingTableName 16 | 17 | def createSummaryJob(self, documentId, jobStatus, jobId): 18 | 19 | err = None 20 | 21 | dynamodb = AwsHelper().getResource("dynamodb") 22 | table = dynamodb.Table(self._jobTableName) 23 | 24 | try: 25 | table.update_item( 26 | Key = { "documentId": documentId, "jobId": jobId }, 27 | UpdateExpression = 'SET jobStatus = :jobstatusValue', 28 | ConditionExpression = 'attribute_not_exists(jobId)', 29 | ExpressionAttributeValues = { 30 | ':jobstatusValue': jobStatus 31 | } 32 | ) 33 | except ClientError as e: 34 | print(e) 35 | if e.response['Error']['Code'] == "ConditionalCheckFailedException": 36 | print(e.response['Error']['Message']) 37 | err = {'Error' : 'Document job already exists.'} 38 | else: 39 | raise 40 | 41 | return err 42 | 43 | def createEmbeddingJob(self, documentId, jobStatus, jobId): 44 | 45 | err = None 46 | 47 | dynamodb = AwsHelper().getResource("dynamodb") 48 | table = dynamodb.Table(self._embedTableName) 49 | 50 | try: 51 | table.update_item( 52 | Key = { "documentId": documentId, "jobId": jobId }, 53 | UpdateExpression = 'SET jobStatus = :jobstatusValue', 54 | ConditionExpression = 'attribute_not_exists(jobId)', 55 | ExpressionAttributeValues = { 56 | ':jobstatusValue': jobStatus 57 | } 58 | ) 59 | except ClientError as e: 60 | print(e) 61 | if e.response['Error']['Code'] == "ConditionalCheckFailedException": 62 | print(e.response['Error']['Message']) 63 | err = {'Error' : 'Document job already exists.'} 64 | else: 65 | raise 66 | 67 | return 
err 68 | 69 | def createDocument(self, documentId, bucketName, objectName, jobStatus, jobId): 70 | 71 | err = None 72 | 73 | dynamodb = AwsHelper().getResource("dynamodb") 74 | table = dynamodb.Table(self._documentsTableName) 75 | 76 | try: 77 | table.update_item( 78 | Key = { "documentId": documentId }, 79 | UpdateExpression = 'SET bucketName = :bucketNameValue, objectName = :objectNameValue, jobStatus = :jobstatusValue, jobId = :jobIdValue', 80 | ConditionExpression = 'attribute_not_exists(documentId)', 81 | ExpressionAttributeValues = { 82 | ':bucketNameValue': bucketName, 83 | ':objectNameValue': objectName, 84 | ':jobstatusValue': jobStatus, 85 | ':jobIdValue': jobId 86 | } 87 | ) 88 | except ClientError as e: 89 | print(e) 90 | if e.response['Error']['Code'] == "ConditionalCheckFailedException": 91 | print(e.response['Error']['Message']) 92 | err = {'Error' : 'Document already exists.'} 93 | else: 94 | raise 95 | 96 | return err 97 | 98 | def updateDocumentStatus(self, documentId, documentStatus): 99 | 100 | err = None 101 | 102 | dynamodb = AwsHelper().getResource("dynamodb") 103 | table = dynamodb.Table(self._documentsTableName) 104 | 105 | try: 106 | table.update_item( 107 | Key = { 'documentId': documentId }, 108 | UpdateExpression = 'SET jobStatus = :jobstatusValue', 109 | ConditionExpression = 'attribute_exists(documentId)', 110 | ExpressionAttributeValues = { 111 | ':jobstatusValue': documentStatus 112 | } 113 | ) 114 | except ClientError as e: 115 | if e.response['Error']['Code'] == "ConditionalCheckFailedException": 116 | print(e.response['Error']['Message']) 117 | err = {'Error' : 'Document does not exist.'} 118 | else: 119 | raise 120 | 121 | return err 122 | 123 | def getDocument(self, documentId): 124 | 125 | dynamodb = AwsHelper().getClient("dynamodb") 126 | 127 | ddbGetItemResponse = dynamodb.get_item( 128 | Key={'documentId': {'S': documentId} }, 129 | TableName=self._documentsTableName 130 | ) 131 | 132 | itemToReturn = None 133 | 134 | 
if('Item' in ddbGetItemResponse): 135 | itemToReturn = { 'documentId' : ddbGetItemResponse['Item']['documentId']['S'], 136 | 'bucketName' : ddbGetItemResponse['Item']['bucketName']['S'], 137 | 'objectName' : ddbGetItemResponse['Item']['objectName']['S'], 138 | 'jobId' : ddbGetItemResponse['Item']['jobId']['S'], 139 | 'jobStatus' : ddbGetItemResponse['Item']['jobStatus']['S'] } 140 | 141 | return itemToReturn 142 | 143 | def deleteDocument(self, documentId): 144 | 145 | dynamodb = AwsHelper().getResource("dynamodb") 146 | table = dynamodb.Table(self._documentsTableName) 147 | 148 | table.delete_item( 149 | Key={ 150 | 'documentId': documentId 151 | } 152 | ) 153 | 154 | def getDocuments(self, nextToken=None): 155 | 156 | dynamodb = AwsHelper().getResource("dynamodb") 157 | table = dynamodb.Table(self._documentsTableName) 158 | 159 | pageSize = 25 160 | 161 | if(nextToken): 162 | response = table.scan(ExclusiveStartKey={ "documentId" : nextToken}, Limit=pageSize) 163 | else: 164 | response = table.scan(Limit=pageSize) 165 | 166 | print("response: {}".format(response)) 167 | 168 | data = [] 169 | 170 | if('Items' in response): 171 | data = response['Items'] 172 | 173 | documents = { 174 | "documents" : data 175 | } 176 | 177 | if 'LastEvaluatedKey' in response: 178 | nextToken = response['LastEvaluatedKey']['documentId'] 179 | print("nextToken: {}".format(nextToken)) 180 | documents["nextToken"] = nextToken 181 | 182 | return documents 183 | -------------------------------------------------------------------------------- /cdk/lambda/helper/python/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import boto3 5 | from botocore.client import Config 6 | import os 7 | import csv 8 | import io 9 | from boto3.dynamodb.conditions import Key 10 | 11 | class DynamoDBHelper: 12 | 13 | @staticmethod 14 | def getItems(tableName, key, value): 15 | items = None 16 | 17 | ddb = AwsHelper().getResource("dynamodb") 18 | table = ddb.Table(tableName) 19 | 20 | if key is not None and value is not None: 21 | filter = Key(key).eq(value) 22 | queryResult = table.query(KeyConditionExpression=filter) 23 | if(queryResult and "Items" in queryResult): 24 | items = queryResult["Items"] 25 | 26 | return items 27 | 28 | @staticmethod 29 | def insertItem(tableName, itemData): 30 | 31 | ddb = AwsHelper().getResource("dynamodb") 32 | table = ddb.Table(tableName) 33 | 34 | ddbResponse = table.put_item(Item=itemData) 35 | 36 | return ddbResponse 37 | 38 | @staticmethod 39 | def deleteItems(tableName, key, value, sk): 40 | items = DynamoDBHelper.getItems(tableName, key, value) 41 | if(items): 42 | ddb = AwsHelper().getResource("dynamodb") 43 | table = ddb.Table(tableName) 44 | for item in items: 45 | print("Deleting...") 46 | print("{} : {}".format(key, item[key])) 47 | print("{} : {}".format(sk, item[sk])) 48 | table.delete_item( 49 | Key={ 50 | key: value, 51 | sk : item[sk] 52 | }) 53 | print("Deleted...") 54 | 55 | class AwsHelper: 56 | def getClient(self, name, awsRegion=None): 57 | config = Config( 58 | retries = dict( 59 | max_attempts = 30 60 | ) 61 | ) 62 | if(awsRegion): 63 | return boto3.client(name, region_name=awsRegion, config=config) 64 | else: 65 | return boto3.client(name, config=config) 66 | 67 | def getResource(self, name, awsRegion=None): 68 | config = Config( 69 | retries = dict( 70 | max_attempts = 30 71 | ) 72 | ) 73 | 74 | if(awsRegion): 75 | return boto3.resource(name, region_name=awsRegion, config=config) 76 | else: 77 | return boto3.resource(name, config=config) 78 | 79 | class S3Helper: 80 | @staticmethod 81 | def 
getS3BucketRegion(bucketName): 82 | client = boto3.client('s3') 83 | response = client.get_bucket_location(Bucket=bucketName) 84 | awsRegion = response['LocationConstraint'] 85 | return awsRegion 86 | 87 | @staticmethod 88 | def getFileNames(bucketName, prefix, maxPages, allowedFileTypes, awsRegion=None): 89 | 90 | files = [] 91 | 92 | currentPage = 1 93 | hasMoreContent = True 94 | continuationToken = None 95 | 96 | s3client = AwsHelper().getClient('s3', awsRegion) 97 | 98 | while(hasMoreContent and currentPage <= maxPages): 99 | if(continuationToken): 100 | listObjectsResponse = s3client.list_objects_v2( 101 | Bucket=bucketName, 102 | Prefix=prefix, 103 | ContinuationToken=continuationToken) 104 | else: 105 | listObjectsResponse = s3client.list_objects_v2( 106 | Bucket=bucketName, 107 | Prefix=prefix) 108 | 109 | if(listObjectsResponse['IsTruncated']): 110 | continuationToken = listObjectsResponse['NextContinuationToken'] 111 | else: 112 | hasMoreContent = False 113 | 114 | for doc in listObjectsResponse['Contents']: 115 | docName = doc['Key'] 116 | docExt = FileHelper.getFileExtenstion(docName) 117 | docExtLower = docExt.lower() 118 | if(docExtLower in allowedFileTypes): 119 | files.append(docName) 120 | 121 | return files 122 | 123 | @staticmethod 124 | def writeToS3(content, bucketName, s3FileName, awsRegion=None): 125 | s3 = AwsHelper().getResource('s3', awsRegion) 126 | object = s3.Object(bucketName, s3FileName) 127 | object.put(Body=content) 128 | 129 | @staticmethod 130 | def readFromS3(bucketName, s3FileName, awsRegion=None): 131 | s3 = AwsHelper().getResource('s3', awsRegion) 132 | obj = s3.Object(bucketName, s3FileName) 133 | return obj.get()['Body'].read().decode('utf-8') 134 | 135 | @staticmethod 136 | def writeCSV(fieldNames, csvData, bucketName, s3FileName, awsRegion=None): 137 | csv_file = io.StringIO() 138 | #with open(fileName, 'w') as csv_file: 139 | writer = csv.DictWriter(csv_file, fieldnames=fieldNames) 140 | writer.writeheader() 141 | 142 | 
for item in csvData: 143 | i = 0 144 | row = {} 145 | for value in item: 146 | row[fieldNames[i]] = value 147 | i = i + 1 148 | writer.writerow(row) 149 | S3Helper.writeToS3(csv_file.getvalue(), bucketName, s3FileName) 150 | 151 | @staticmethod 152 | def writeCSVRaw(csvData, bucketName, s3FileName): 153 | csv_file = io.StringIO() 154 | #with open(fileName, 'w') as csv_file: 155 | writer = csv.writer(csv_file) 156 | for item in csvData: 157 | writer.writerow(item) 158 | S3Helper.writeToS3(csv_file.getvalue(), bucketName, s3FileName) 159 | 160 | 161 | class FileHelper: 162 | @staticmethod 163 | def getFileNameAndExtension(filePath): 164 | basename = os.path.basename(filePath) 165 | dn, dext = os.path.splitext(basename) 166 | return (dn, dext[1:]) 167 | 168 | @staticmethod 169 | def getFileName(fileName): 170 | basename = os.path.basename(fileName) 171 | dn, dext = os.path.splitext(basename) 172 | return dn 173 | 174 | @staticmethod 175 | def getFileExtenstion(fileName): 176 | basename = os.path.basename(fileName) 177 | dn, dext = os.path.splitext(basename) 178 | return dext[1:] 179 | 180 | 181 | @staticmethod 182 | def readFile(fileName): 183 | with open(fileName, 'r') as document: 184 | return document.read() 185 | 186 | @staticmethod 187 | def writeToFile(fileName, content): 188 | with open(fileName, 'w') as document: 189 | document.write(content) 190 | 191 | @staticmethod 192 | def writeToFileWithMode(fileName, content, mode): 193 | with open(fileName, mode) as document: 194 | document.write(content) 195 | @staticmethod 196 | def getFilesInFolder(path, fileTypes): 197 | for file in os.listdir(path): 198 | if os.path.isfile(os.path.join(path, file)): 199 | ext = FileHelper.getFileExtenstion(file) 200 | if(ext.lower() in fileTypes): 201 | yield file 202 | 203 | @staticmethod 204 | def getFileNames(path, allowedLocalFileTypes): 205 | files = [] 206 | 207 | for file in FileHelper.getFilesInFolder(path, allowedLocalFileTypes): 208 | files.append(path + file) 209 | 210 
| return files 211 | 212 | @staticmethod 213 | def writeCSV(fileName, fieldNames, csvData): 214 | with open(fileName, 'w') as csv_file: 215 | writer = csv.DictWriter(csv_file, fieldnames=fieldNames) 216 | writer.writeheader() 217 | 218 | for item in csvData: 219 | i = 0 220 | row = {} 221 | for value in item: 222 | row[fieldNames[i]] = value 223 | i = i + 1 224 | writer.writerow(row) 225 | 226 | @staticmethod 227 | def writeCSVRaw(fileName, csvData): 228 | with open(fileName, 'w') as csv_file: 229 | writer = csv.writer(csv_file) 230 | for item in csvData: 231 | writer.writerow(item) 232 | -------------------------------------------------------------------------------- /cdk/lambda/jobresultprocessor/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import os 6 | import boto3 7 | import time 8 | from helper import AwsHelper 9 | from og import OutputGenerator 10 | import datastore 11 | 12 | def getJobResults(api, jobId): 13 | 14 | pages = [] 15 | 16 | time.sleep(5) 17 | 18 | client = AwsHelper().getClient('textract') 19 | if(api == "StartDocumentTextDetection"): 20 | response = client.get_document_text_detection(JobId=jobId) 21 | else: 22 | response = client.get_document_analysis(JobId=jobId) 23 | pages.append(response) 24 | print("Resultset page received: {}".format(len(pages))) 25 | nextToken = None 26 | if('NextToken' in response): 27 | nextToken = response['NextToken'] 28 | print("Next token: {}".format(nextToken)) 29 | 30 | while(nextToken): 31 | time.sleep(2) 32 | 33 | if(api == "StartDocumentTextDetection"): 34 | response = client.get_document_text_detection(JobId=jobId, NextToken=nextToken) 35 | else: 36 | response = client.get_document_analysis(JobId=jobId, NextToken=nextToken) 37 | 38 | pages.append(response) 39 | print("Resultset page received: {}".format(len(pages))) 40 | 
nextToken = None 41 | if('NextToken' in response): 42 | nextToken = response['NextToken'] 43 | print("Next token: {}".format(nextToken)) 44 | 45 | return pages 46 | 47 | def processRequest(request): 48 | 49 | output = "" 50 | 51 | print(request) 52 | 53 | jobId = request['jobId'] 54 | jobTag = request['jobTag'] 55 | jobStatus = request['jobStatus'] 56 | jobAPI = request['jobAPI'] 57 | bucketName = request['bucketName'] 58 | objectName = request['objectName'] 59 | outputTable = request["outputTable"] 60 | documentsTable = request["documentsTable"] 61 | 62 | pages = getJobResults(jobAPI, jobId) 63 | 64 | print("Result pages received: {}".format(len(pages))) 65 | 66 | dynamodb = AwsHelper().getResource("dynamodb") 67 | ddb = dynamodb.Table(outputTable) 68 | 69 | detectForms = False 70 | detectTables = False 71 | if(jobAPI == "StartDocumentAnalysis"): 72 | detectForms = True 73 | detectTables = True 74 | 75 | dynamodb = AwsHelper().getResource('dynamodb') 76 | ddb = dynamodb.Table(outputTable) 77 | 78 | opg = OutputGenerator(jobTag, pages, bucketName, objectName, detectForms, detectTables, ddb) 79 | opg.run() 80 | 81 | print("DocumentId: {}".format(jobTag)) 82 | 83 | ds = datastore.DocumentStore(documentsTable, outputTable) 84 | ds.updateDocumentStatus(jobTag, jobStatus) 85 | 86 | output = "Processed -> Document: {}, Object: {}/{} processed.".format(jobTag, bucketName, objectName) 87 | 88 | print(output) 89 | 90 | return { 91 | 'statusCode': 200, 92 | 'body': output 93 | } 94 | 95 | def lambda_handler(event, context): 96 | 97 | print("event: {}".format(event)) 98 | 99 | body = json.loads(event['Records'][0]['body']) 100 | message = json.loads(body['Message']) 101 | 102 | print("Message: {}".format(message)) 103 | 104 | request = {} 105 | 106 | request["jobId"] = message['JobId'] 107 | request["jobTag"] = message['JobTag'] 108 | request["jobStatus"] = message['Status'] 109 | request["jobAPI"] = message['API'] 110 | request["bucketName"] = 
message['DocumentLocation']['S3Bucket'] 111 | request["objectName"] = message['DocumentLocation']['S3ObjectName'] 112 | 113 | request["outputTable"] = os.environ['OUTPUT_TABLE'] 114 | request["documentsTable"] = os.environ['DOCUMENTS_TABLE'] 115 | 116 | return processRequest(request) 117 | -------------------------------------------------------------------------------- /cdk/lambda/summarizationprocessor/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import os 5 | import traceback 6 | import json 7 | from helper import AwsHelper 8 | import datastore 9 | 10 | def postMessage(client, qUrl, jsonMessage): 11 | 12 | message = json.dumps(jsonMessage) 13 | 14 | client.send_message( 15 | QueueUrl=qUrl, 16 | MessageBody=message 17 | ) 18 | 19 | print("Submitted message to queue: {}".format(message)) 20 | 21 | def respond(err, res=None): 22 | return { 23 | 'statusCode': '400' if err else '200', 24 | 'body': str(err) if err else json.dumps(res), 25 | 'headers': { 26 | 'Content-Type': 'application/json', 27 | "Access-Control-Allow-Origin": "*", 28 | "Access-Control-Allow-Credentials": 'true' 29 | }, 30 | } 31 | 32 | # Inputs: document id and s3 location of output 33 | def lambda_handler(event, context): 34 | 35 | print("Received event: " + json.dumps(event, indent=2)) 36 | operation = event['httpMethod'] 37 | 38 | if operation != "POST": 39 | return respond(ValueError('Unsupported method "{}"'.format(operation))) 40 | else: 41 | payload = event['queryStringParameters'] if operation == 'GET' else json.loads(event['body']) 42 | docId = payload['docId'] 43 | bucket = payload['bucket'] 44 | name = payload['name'] 45 | chunkSize = payload['chunkSize'] 46 | chunkOverlap = payload['chunkOverlap'] 47 | max_length = payload['max_length'] 48 | top_p = payload['top_p'] 49 | top_k = payload['top_k'] 50 | num_beams = 
payload['num_beams'] 51 | temperature = payload['temperature'] 52 | 53 | queueUrl = os.environ['QUEUE_URL'] 54 | jobTable = os.environ['JOB_TABLE'] 55 | 56 | jsonMessage = { 'documentId' : docId, 57 | 'bucketName': bucket, 58 | 'objectName' : name, 59 | 'jobId': docId, 60 | 'chunkSize': chunkSize, 61 | 'chunkOverlap': chunkOverlap, 62 | 'max_length': max_length, 63 | 'top_p': top_p, 64 | 'top_k': top_k, 65 | 'num_beams': num_beams, 66 | 'temperature': temperature} 67 | 68 | try: 69 | client = AwsHelper().getClient('sqs') 70 | postMessage(client, queueUrl, jsonMessage) 71 | ds = datastore.DocumentStore("", "", jobTable) 72 | ds.createSummaryJob(docId, "Started", docId) 73 | 74 | return respond(None, {'msg': "Summarization started", 'job': docId}) 75 | except Exception as e: 76 | trc = traceback.format_exc() 77 | print(trc) 78 | return respond(ValueError(f"Could not summarize doc: {str(e)}")); 79 | -------------------------------------------------------------------------------- /cdk/lambda/taskprocessor/lambda_function.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import os 6 | import boto3 7 | import traceback 8 | 9 | def lambda_handler(event, context): 10 | 11 | print("event: {}".format(event)) 12 | 13 | message = json.loads(event['Records'][0]['body']) 14 | 15 | print("Message: {}".format(message)) 16 | docId = message['documentId'] 17 | bucket = message['bucketName'] 18 | name = message['objectName'] 19 | jobId = message['jobId'] 20 | chunk_size = message['chunkSize'] 21 | chunk_overlap = message['chunkOverlap'] 22 | max_length = message['max_length'] 23 | top_p = message['top_p'] 24 | top_k = message['top_k'] 25 | num_beams = message['num_beams'] 26 | temperature = message['temperature'] 27 | 28 | clusterArn = os.environ['target'] 29 | taskDefinitionArn = os.environ['taskDefinitionArn'] 30 | subnets = os.environ['subnets'] 31 | subnet_list = subnets.split(',') 32 | 33 | try: 34 | ecs = boto3.client('ecs') 35 | response = ecs.run_task( 36 | cluster=clusterArn, 37 | count=1, 38 | launchType='FARGATE', 39 | networkConfiguration={ 40 | 'awsvpcConfiguration': { 41 | 'subnets': subnet_list 42 | } 43 | }, 44 | overrides={ 45 | 'containerOverrides': [ 46 | { 47 | 'name': 'worker', 48 | 'environment': [ 49 | { 50 | 'name': 'docId', 51 | 'value': docId 52 | }, 53 | { 54 | 'name': 'jobId', 55 | 'value': jobId 56 | }, 57 | { 58 | 'name': 'bucket', 59 | 'value': bucket 60 | }, 61 | { 62 | 'name': 'name', 63 | 'value': name 64 | }, 65 | { 66 | 'name': 'chunk_size', 67 | 'value': str(chunk_size) 68 | }, 69 | { 70 | 'name': 'chunk_overlap', 71 | 'value': str(chunk_overlap) 72 | }, 73 | ], 74 | } 75 | ] 76 | }, 77 | taskDefinition=taskDefinitionArn 78 | ) 79 | output = f"Launched task" 80 | except Exception as e: 81 | trc = traceback.format_exc() 82 | print(trc) 83 | output = str(e) 84 | 85 | return { 86 | 'statusCode': 200, 87 | 'body': output 88 | } 89 | -------------------------------------------------------------------------------- /cdk/lambda/textractor/python/og.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | from helper import FileHelper, S3Helper 6 | from trp import Document 7 | import boto3 8 | 9 | class OutputGenerator: 10 | def __init__(self, documentId, response, bucketName, objectName, forms, tables, ddb): 11 | self.documentId = documentId 12 | self.response = response 13 | self.bucketName = bucketName 14 | self.objectName = objectName 15 | self.forms = forms 16 | self.tables = tables 17 | self.ddb = ddb 18 | 19 | self.outputPath = "{}-analysis/{}/".format(objectName, documentId) 20 | 21 | self.document = Document(self.response) 22 | 23 | def saveItem(self, pk, sk, output): 24 | 25 | jsonItem = {} 26 | jsonItem['documentId'] = pk 27 | jsonItem['outputType'] = sk 28 | jsonItem['outputPath'] = output 29 | 30 | self.ddb.put_item(Item=jsonItem) 31 | 32 | def _outputText(self, page, p): 33 | text = page.text 34 | opath = "{}page-{}-text.txt".format(self.outputPath, p) 35 | S3Helper.writeToS3(text, self.bucketName, opath) 36 | self.saveItem(self.documentId, "page-{}-Text".format(p), opath) 37 | 38 | textInReadingOrder = page.getTextInReadingOrder() 39 | opath = "{}page-{}-text-inreadingorder.txt".format(self.outputPath, p) 40 | S3Helper.writeToS3(textInReadingOrder, self.bucketName, opath) 41 | self.saveItem(self.documentId, "page-{}-TextInReadingOrder".format(p), opath) 42 | 43 | def _outputForm(self, page, p): 44 | csvData = [] 45 | for field in page.form.fields: 46 | csvItem = [] 47 | if(field.key): 48 | csvItem.append(field.key.text) 49 | else: 50 | csvItem.append("") 51 | if(field.value): 52 | csvItem.append(field.value.text) 53 | else: 54 | csvItem.append("") 55 | csvData.append(csvItem) 56 | csvFieldNames = ['Key', 'Value'] 57 | opath = "{}page-{}-forms.csv".format(self.outputPath, p) 58 | S3Helper.writeCSV(csvFieldNames, csvData, self.bucketName, 
opath) 59 | self.saveItem(self.documentId, "page-{}-Forms".format(p), opath) 60 | 61 | def _outputTable(self, page, p): 62 | 63 | csvData = [] 64 | for table in page.tables: 65 | csvRow = [] 66 | csvRow.append("Table") 67 | csvData.append(csvRow) 68 | for row in table.rows: 69 | csvRow = [] 70 | for cell in row.cells: 71 | csvRow.append(cell.text) 72 | csvData.append(csvRow) 73 | csvData.append([]) 74 | csvData.append([]) 75 | 76 | opath = "{}page-{}-tables.csv".format(self.outputPath, p) 77 | S3Helper.writeCSVRaw(csvData, self.bucketName, opath) 78 | self.saveItem(self.documentId, "page-{}-Tables".format(p), opath) 79 | 80 | def run(self): 81 | 82 | if(not self.document.pages): 83 | return 84 | 85 | opath = "{}response.json".format(self.outputPath) 86 | S3Helper.writeToS3(json.dumps(self.response), self.bucketName, opath) 87 | self.saveItem(self.documentId, 'Response', opath) 88 | 89 | print("Total Pages in Document: {}".format(len(self.document.pages))) 90 | 91 | docText = "" 92 | 93 | p = 1 94 | for page in self.document.pages: 95 | 96 | opath = "{}page-{}-response.json".format(self.outputPath, p) 97 | S3Helper.writeToS3(json.dumps(page.blocks), self.bucketName, opath) 98 | self.saveItem(self.documentId, "page-{}-Response".format(p), opath) 99 | 100 | self._outputText(page, p) 101 | 102 | docText = docText + page.text + "\n" 103 | 104 | if(self.forms): 105 | self._outputForm(page, p) 106 | 107 | if(self.tables): 108 | self._outputTable(page, p) 109 | 110 | p = p + 1 111 | 112 | orderedDocText = "" 113 | cnt = 0 114 | chunkSize = 5 115 | for page in self.document.getPagesInReadingOrder(): 116 | orderedDocText = orderedDocText + page.getTextInReadingOrder() + "\n\n" 117 | cnt = cnt + 1 118 | if cnt % chunkSize == 0: 119 | orderedDocText = orderedDocText + "\n\n" 120 | opath = "{}response.txt".format(self.outputPath) 121 | S3Helper.writeToS3(orderedDocText, self.bucketName, opath) 122 | self.saveItem(self.documentId, 'ResponseOrderedText', opath) 123 | 
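`OutputGenerator` above derives every S3 key from a fixed per-document prefix, `{objectName}-analysis/{documentId}/`, with per-page artifacts keyed `page-{p}-...` and the combined reading-order text written last. A small standalone sketch of that key scheme, mirroring the format strings in `og.py` (the object and document names are made-up examples):

```python
def output_keys(object_name, document_id, pages, forms=False, tables=False):
    # Mirrors the format strings used by OutputGenerator in og.py.
    prefix = "{}-analysis/{}/".format(object_name, document_id)
    keys = [prefix + "response.json"]          # full Textract response
    for p in range(1, pages + 1):
        keys.append("{}page-{}-response.json".format(prefix, p))
        keys.append("{}page-{}-text.txt".format(prefix, p))
        keys.append("{}page-{}-text-inreadingorder.txt".format(prefix, p))
        if forms:                              # only when StartDocumentAnalysis ran
            keys.append("{}page-{}-forms.csv".format(prefix, p))
        if tables:
            keys.append("{}page-{}-tables.csv".format(prefix, p))
    keys.append(prefix + "response.txt")       # ordered text for the whole document
    return keys

keys = output_keys("docs/report.pdf", "doc-1", pages=2, forms=True)
```

Each key also becomes the `outputPath` attribute of a DynamoDB item whose sort key (`page-1-Text`, `Response`, ...) identifies the artifact type, so downstream workers can locate any page's text without listing the bucket.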
-------------------------------------------------------------------------------- /cdk/lambda/textractor/python/trp.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | from collections import OrderedDict 6 | 7 | class BoundingBox: 8 | def __init__(self, width, height, left, top): 9 | self._width = width 10 | self._height = height 11 | self._left = left 12 | self._top = top 13 | 14 | def __str__(self): 15 | return "width: {}, height: {}, left: {}, top: {}".format(self._width, self._height, self._left, self._top) 16 | 17 | @property 18 | def width(self): 19 | return self._width 20 | 21 | @property 22 | def height(self): 23 | return self._height 24 | 25 | @property 26 | def left(self): 27 | return self._left 28 | 29 | @property 30 | def top(self): 31 | return self._top 32 | 33 | class Polygon: 34 | def __init__(self, x, y): 35 | self._x = x 36 | self._y = y 37 | 38 | def __str__(self): 39 | return "x: {}, y: {}".format(self._x, self._y) 40 | 41 | @property 42 | def x(self): 43 | return self._x 44 | 45 | @property 46 | def y(self): 47 | return self._y 48 | 49 | class Geometry: 50 | def __init__(self, geometry): 51 | boundingBox = geometry["BoundingBox"] 52 | polygon = geometry["Polygon"] 53 | bb = BoundingBox(boundingBox["Width"], boundingBox["Height"], boundingBox["Left"], boundingBox["Top"]) 54 | pgs = [] 55 | for pg in polygon: 56 | pgs.append(Polygon(pg["X"], pg["Y"])) 57 | 58 | self._boundingBox = bb 59 | self._polygon = pgs 60 | 61 | def __str__(self): 62 | s = "BoundingBox: {}\n".format(str(self._boundingBox)) 63 | return s 64 | 65 | @property 66 | def boundingBox(self): 67 | return self._boundingBox 68 | 69 | @property 70 | def polygon(self): 71 | return self._polygon 72 | 73 | class Word: 74 | def __init__(self, block, blockMap): 75 | self._block = block 76 | self._confidence = block['Confidence'] 
77 | self._geometry = Geometry(block['Geometry']) 78 | self._id = block['Id'] 79 | self._text = "" 80 | if(block['Text']): 81 | self._text = block['Text'] 82 | 83 | def __str__(self): 84 | return self._text 85 | 86 | @property 87 | def confidence(self): 88 | return self._confidence 89 | 90 | @property 91 | def geometry(self): 92 | return self._geometry 93 | 94 | @property 95 | def id(self): 96 | return self._id 97 | 98 | @property 99 | def text(self): 100 | return self._text 101 | 102 | @property 103 | def block(self): 104 | return self._block 105 | 106 | class Line: 107 | def __init__(self, block, blockMap): 108 | 109 | self._block = block 110 | self._confidence = block['Confidence'] 111 | self._geometry = Geometry(block['Geometry']) 112 | self._id = block['Id'] 113 | 114 | self._text = "" 115 | if(block['Text']): 116 | self._text = block['Text'] 117 | 118 | self._words = [] 119 | if('Relationships' in block and block['Relationships']): 120 | for rs in block['Relationships']: 121 | if(rs['Type'] == 'CHILD'): 122 | for cid in rs['Ids']: 123 | if(blockMap[cid]["BlockType"] == "WORD"): 124 | self._words.append(Word(blockMap[cid], blockMap)) 125 | def __str__(self): 126 | s = "Line\n==========\n" 127 | s = s + self._text + "\n" 128 | s = s + "Words\n----------\n" 129 | for word in self._words: 130 | s = s + "[{}]".format(str(word)) 131 | return s 132 | 133 | @property 134 | def confidence(self): 135 | return self._confidence 136 | 137 | @property 138 | def geometry(self): 139 | return self._geometry 140 | 141 | @property 142 | def id(self): 143 | return self._id 144 | 145 | @property 146 | def words(self): 147 | return self._words 148 | 149 | @property 150 | def text(self): 151 | return self._text 152 | 153 | @property 154 | def block(self): 155 | return self._block 156 | 157 | class SelectionElement: 158 | def __init__(self, block, blockMap): 159 | self._confidence = block['Confidence'] 160 | self._geometry = Geometry(block['Geometry']) 161 | self._id = block['Id'] 
162 | self._selectionStatus = block['SelectionStatus'] 163 | 164 | @property 165 | def confidence(self): 166 | return self._confidence 167 | 168 | @property 169 | def geometry(self): 170 | return self._geometry 171 | 172 | @property 173 | def id(self): 174 | return self._id 175 | 176 | @property 177 | def selectionStatus(self): 178 | return self._selectionStatus 179 | 180 | class FieldKey: 181 | def __init__(self, block, children, blockMap): 182 | self._block = block 183 | self._confidence = block['Confidence'] 184 | self._geometry = Geometry(block['Geometry']) 185 | self._id = block['Id'] 186 | self._text = "" 187 | self._content = [] 188 | 189 | t = [] 190 | 191 | for eid in children: 192 | wb = blockMap[eid] 193 | if(wb['BlockType'] == "WORD"): 194 | w = Word(wb, blockMap) 195 | self._content.append(w) 196 | t.append(w.text) 197 | 198 | if(t): 199 | self._text = ' '.join(t) 200 | 201 | def __str__(self): 202 | return self._text 203 | 204 | @property 205 | def confidence(self): 206 | return self._confidence 207 | 208 | @property 209 | def geometry(self): 210 | return self._geometry 211 | 212 | @property 213 | def id(self): 214 | return self._id 215 | 216 | @property 217 | def content(self): 218 | return self._content 219 | 220 | @property 221 | def text(self): 222 | return self._text 223 | 224 | @property 225 | def block(self): 226 | return self._block 227 | 228 | class FieldValue: 229 | def __init__(self, block, children, blockMap): 230 | self._block = block 231 | self._confidence = block['Confidence'] 232 | self._geometry = Geometry(block['Geometry']) 233 | self._id = block['Id'] 234 | self._text = "" 235 | self._content = [] 236 | 237 | t = [] 238 | 239 | for eid in children: 240 | wb = blockMap[eid] 241 | if(wb['BlockType'] == "WORD"): 242 | w = Word(wb, blockMap) 243 | self._content.append(w) 244 | t.append(w.text) 245 | elif(wb['BlockType'] == "SELECTION_ELEMENT"): 246 | se = SelectionElement(wb, blockMap) 247 | self._content.append(se) 248 | self._text = 
se.selectionStatus 249 | 250 | if(t): 251 | self._text = ' '.join(t) 252 | 253 | def __str__(self): 254 | return self._text 255 | 256 | @property 257 | def confidence(self): 258 | return self._confidence 259 | 260 | @property 261 | def geometry(self): 262 | return self._geometry 263 | 264 | @property 265 | def id(self): 266 | return self._id 267 | 268 | @property 269 | def content(self): 270 | return self._content 271 | 272 | @property 273 | def text(self): 274 | return self._text 275 | 276 | @property 277 | def block(self): 278 | return self._block 279 | 280 | class Field: 281 | def __init__(self, block, blockMap): 282 | self._key = None 283 | self._value = None 284 | 285 | for item in block['Relationships']: 286 | if(item["Type"] == "CHILD"): 287 | self._key = FieldKey(block, item['Ids'], blockMap) 288 | elif(item["Type"] == "VALUE"): 289 | for eid in item['Ids']: 290 | vkvs = blockMap[eid] 291 | if 'VALUE' in vkvs['EntityTypes']: 292 | if('Relationships' in vkvs): 293 | for vitem in vkvs['Relationships']: 294 | if(vitem["Type"] == "CHILD"): 295 | self._value = FieldValue(vkvs, vitem['Ids'], blockMap) 296 | def __str__(self): 297 | s = "\nField\n==========\n" 298 | k = "" 299 | v = "" 300 | if(self._key): 301 | k = str(self._key) 302 | if(self._value): 303 | v = str(self._value) 304 | s = s + "Key: {}\nValue: {}".format(k, v) 305 | return s 306 | 307 | @property 308 | def key(self): 309 | return self._key 310 | 311 | @property 312 | def value(self): 313 | return self._value 314 | 315 | class Form: 316 | def __init__(self): 317 | self._fields = [] 318 | self._fieldsMap = {} 319 | 320 | def addField(self, field): 321 | self._fields.append(field) 322 | self._fieldsMap[field.key.text] = field 323 | 324 | def __str__(self): 325 | s = "" 326 | for field in self._fields: 327 | s = s + str(field) + "\n" 328 | return s 329 | 330 | @property 331 | def fields(self): 332 | return self._fields 333 | 334 | def getFieldByKey(self, key): 335 | field = None 336 | if(key in 
self._fieldsMap): 337 | field = self._fieldsMap[key] 338 | return field 339 | 340 | def searchFieldsByKey(self, key): 341 | searchKey = key.lower() 342 | results = [] 343 | for field in self._fields: 344 | if(field.key and searchKey in field.key.text.lower()): 345 | results.append(field) 346 | return results 347 | 348 | class Cell: 349 | 350 | def __init__(self, block, blockMap): 351 | self._block = block 352 | self._confidence = block['Confidence'] 353 | self._rowIndex = block['RowIndex'] 354 | self._columnIndex = block['ColumnIndex'] 355 | self._rowSpan = block['RowSpan'] 356 | self._columnSpan = block['ColumnSpan'] 357 | self._geometry = Geometry(block['Geometry']) 358 | self._id = block['Id'] 359 | self._content = [] 360 | self._text = "" 361 | if('Relationships' in block and block['Relationships']): 362 | for rs in block['Relationships']: 363 | if(rs['Type'] == 'CHILD'): 364 | for cid in rs['Ids']: 365 | blockType = blockMap[cid]["BlockType"] 366 | if(blockType == "WORD"): 367 | w = Word(blockMap[cid], blockMap) 368 | self._content.append(w) 369 | self._text = self._text + w.text + ' ' 370 | elif(blockType == "SELECTION_ELEMENT"): 371 | se = SelectionElement(blockMap[cid], blockMap) 372 | self._content.append(se) 373 | self._text = self._text + se.selectionStatus + ', ' 374 | 375 | def __str__(self): 376 | return self._text 377 | 378 | @property 379 | def confidence(self): 380 | return self._confidence 381 | 382 | @property 383 | def rowIndex(self): 384 | return self._rowIndex 385 | 386 | @property 387 | def columnIndex(self): 388 | return self._columnIndex 389 | 390 | @property 391 | def rowSpan(self): 392 | return self._rowSpan 393 | 394 | @property 395 | def columnSpan(self): 396 | return self._columnSpan 397 | 398 | @property 399 | def geometry(self): 400 | return self._geometry 401 | 402 | @property 403 | def id(self): 404 | return self._id 405 | 406 | @property 407 | def content(self): 408 | return self._content 409 | 410 | @property 411 | def 
text(self): 412 | return self._text 413 | 414 | @property 415 | def block(self): 416 | return self._block 417 | 418 | class Row: 419 | def __init__(self): 420 | self._cells = [] 421 | 422 | def __str__(self): 423 | s = "" 424 | for cell in self._cells: 425 | s = s + "[{}]".format(str(cell)) 426 | return s 427 | 428 | @property 429 | def cells(self): 430 | return self._cells 431 | 432 | class Table: 433 | 434 | def __init__(self, block, blockMap): 435 | 436 | self._block = block 437 | 438 | self._confidence = block['Confidence'] 439 | self._geometry = Geometry(block['Geometry']) 440 | 441 | self._id = block['Id'] 442 | self._rows = [] 443 | 444 | ri = 1 445 | row = Row() 446 | cell = None 447 | if('Relationships' in block and block['Relationships']): 448 | for rs in block['Relationships']: 449 | if(rs['Type'] == 'CHILD'): 450 | for cid in rs['Ids']: 451 | cell = Cell(blockMap[cid], blockMap) 452 | if(cell.rowIndex > ri): 453 | self._rows.append(row) 454 | row = Row() 455 | ri = cell.rowIndex 456 | row.cells.append(cell) 457 | if(row and row.cells): 458 | self._rows.append(row) 459 | 460 | def __str__(self): 461 | s = "Table\n==========\n" 462 | for row in self._rows: 463 | s = s + "Row\n==========\n" 464 | s = s + str(row) + "\n" 465 | return s 466 | 467 | @property 468 | def confidence(self): 469 | return self._confidence 470 | 471 | @property 472 | def geometry(self): 473 | return self._geometry 474 | 475 | @property 476 | def id(self): 477 | return self._id 478 | 479 | @property 480 | def rows(self): 481 | return self._rows 482 | 483 | @property 484 | def block(self): 485 | return self._block 486 | 487 | class Page: 488 | 489 | def __init__(self, blocks, blockMap): 490 | self._blocks = blocks 491 | self._text = "" 492 | self._lines = [] 493 | self._form = Form() 494 | self._tables = [] 495 | self._content = [] 496 | self._pageNumber = -1 497 | 498 | self._parse(blockMap) 499 | 500 | def __str__(self): 501 | s = "Page\n==========\n" 502 | for item in 
self._content: 503 | s = s + str(item) + "\n" 504 | return s 505 | 506 | def _parse(self, blockMap): 507 | for item in self._blocks: 508 | if item["BlockType"] == "PAGE": 509 | self._geometry = Geometry(item['Geometry']) 510 | self._id = item['Id'] 511 | self._pageNumber = item.get('Page', 1) 512 | elif item["BlockType"] == "LINE": 513 | l = Line(item, blockMap) 514 | self._lines.append(l) 515 | self._content.append(l) 516 | self._text = self._text + l.text + '\n' 517 | elif item["BlockType"] == "TABLE": 518 | t = Table(item, blockMap) 519 | self._tables.append(t) 520 | self._content.append(t) 521 | elif item["BlockType"] == "KEY_VALUE_SET": 522 | if 'KEY' in item['EntityTypes']: 523 | f = Field(item, blockMap) 524 | if(f.key): 525 | self._form.addField(f) 526 | self._content.append(f) 527 | else: 528 | print("WARNING: Detected K/V where key does not have content. Excluding key from output.") 529 | print(f) 530 | print(item) 531 | 532 | def getLinesInReadingOrder(self): 533 | columns = [] 534 | lines = [] 535 | for item in self._lines: 536 | column_found = False 537 | for index, column in enumerate(columns): 538 | bbox_left = item.geometry.boundingBox.left 539 | bbox_right = item.geometry.boundingBox.left + item.geometry.boundingBox.width 540 | bbox_centre = item.geometry.boundingBox.left + item.geometry.boundingBox.width/2 541 | column_centre = (column['left'] + column['right'])/2 542 | if (bbox_centre > column['left'] and bbox_centre < column['right']) or (column_centre > bbox_left and column_centre < bbox_right): 543 | #Bbox appears inside the column 544 | lines.append([index, item.text]) 545 | column_found = True 546 | break 547 | if not column_found: 548 | columns.append({'left':item.geometry.boundingBox.left, 'right':item.geometry.boundingBox.left + item.geometry.boundingBox.width}) 549 | lines.append([len(columns)-1, item.text]) 550 | 551 | lines.sort(key=lambda x: x[0]) 552 | return lines 553 | 554 | def getTextInReadingOrder(self): 555 | lines = 
self.getLinesInReadingOrder() 556 | text = "" 557 | for line in lines: 558 | text = text + line[1] + '\n' 559 | return text 560 | 561 | def getLineHeights(self): 562 | heights = [] 563 | for item in self._lines: 564 | heights.append(item.geometry.boundingBox.height) 565 | return list(set(heights)) 566 | 567 | @property 568 | def blocks(self): 569 | return self._blocks 570 | 571 | @property 572 | def text(self): 573 | return self._text 574 | 575 | @property 576 | def lines(self): 577 | return self._lines 578 | 579 | @property 580 | def form(self): 581 | return self._form 582 | 583 | @property 584 | def tables(self): 585 | return self._tables 586 | 587 | @property 588 | def content(self): 589 | return self._content 590 | 591 | @property 592 | def geometry(self): 593 | return self._geometry 594 | 595 | @property 596 | def id(self): 597 | return self._id 598 | 599 | class Document: 600 | 601 | def __init__(self, responsePages): 602 | 603 | if(not isinstance(responsePages, list)): 604 | rps = [] 605 | rps.append(responsePages) 606 | responsePages = rps 607 | 608 | self._responsePages = responsePages 609 | self._pages = [] 610 | 611 | self._parse() 612 | 613 | def __str__(self): 614 | s = "\nDocument\n==========\n" 615 | for p in self._pages: 616 | s = s + str(p) + "\n\n" 617 | return s 618 | 619 | def _parseDocumentPagesAndBlockMap(self): 620 | 621 | blockMap = {} 622 | 623 | documentPages = [] 624 | documentPage = None 625 | for page in self._responsePages: 626 | for block in page['Blocks']: 627 | if('BlockType' in block and 'Id' in block): 628 | blockMap[block['Id']] = block 629 | 630 | if(block['BlockType'] == 'PAGE'): 631 | if(documentPage): 632 | documentPages.append({"Blocks" : documentPage}) 633 | documentPage = [] 634 | documentPage.append(block) 635 | else: 636 | documentPage.append(block) 637 | if(documentPage): 638 | documentPages.append({"Blocks" : documentPage}) 639 | return documentPages, blockMap 640 | 641 | def _parse(self): 642 | 643 | 
self._responseDocumentPages, self._blockMap = self._parseDocumentPagesAndBlockMap() 644 | for documentPage in self._responseDocumentPages: 645 | page = Page(documentPage["Blocks"], self._blockMap) 646 | self._pages.append(page) 647 | 648 | @property 649 | def blocks(self): 650 | return self._responsePages 651 | 652 | @property 653 | def pageBlocks(self): 654 | return self._responseDocumentPages 655 | 656 | @property 657 | def pages(self): 658 | return self._pages 659 | 660 | def getBlockById(self, blockId): 661 | block = None 662 | if(self._blockMap and blockId in self._blockMap): 663 | block = self._blockMap[blockId] 664 | return block 665 | 666 | def getPagesInReadingOrder(self): 667 | pageDict = {} 668 | for p in self._pages: 669 | pageDict[p._pageNumber] = p 670 | sortedDict = OrderedDict(sorted(pageDict.items())) 671 | return list(sortedDict.values()) 672 | 673 | def getLineHeights(self): 674 | heights = [] 675 | for p in self._pages: 676 | heights = heights + p.getLineHeights() 677 | heights = list(set(heights)) 678 | return heights 679 | 680 | 681 | 682 | -------------------------------------------------------------------------------- /cdk/lib/cdk-stack.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | import * as cdk from 'aws-cdk-lib'; 5 | import { Construct } from 'constructs'; 6 | import iam = require('aws-cdk-lib/aws-iam'); 7 | import {ObjectOwnership} from "aws-cdk-lib/aws-s3"; 8 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 9 | import sns = require('aws-cdk-lib/aws-sns'); 10 | import snsSubscriptions = require("aws-cdk-lib/aws-sns-subscriptions"); 11 | import sqs = require('aws-cdk-lib/aws-sqs'); 12 | import dynamodb = require('aws-cdk-lib/aws-dynamodb'); 13 | import lambda = require('aws-cdk-lib/aws-lambda'); 14 | import s3 = require('aws-cdk-lib/aws-s3'); 15 | import apigw = require('aws-cdk-lib/aws-apigateway'); 16 | import ecs = require('aws-cdk-lib/aws-ecs'); 17 | import ec2 = require('aws-cdk-lib/aws-ec2'); 18 | import cognito = require('aws-cdk-lib/aws-cognito'); 19 | import cloudfront = require('aws-cdk-lib/aws-cloudfront'); 20 | import origins = require('aws-cdk-lib/aws-cloudfront-origins'); 21 | import cognitoIdp = require('@aws-cdk/aws-cognito-identitypool-alpha'); 22 | import logs = require('aws-cdk-lib/aws-logs'); 23 | import efs = require('aws-cdk-lib/aws-efs'); 24 | import elb = require('aws-cdk-lib/aws-elasticloadbalancingv2'); 25 | import kms = require('aws-cdk-lib/aws-kms'); 26 | import ssm = require('aws-cdk-lib/aws-ssm'); 27 | import wafv2 = require ('aws-cdk-lib/aws-wafv2'); 28 | 29 | export class CdkStack extends cdk.Stack { 30 | constructor(scope: Construct, id: string, props?: cdk.StackProps) { 31 | super(scope, id, props); 32 | 33 | const key = new kms.Key(this, 'KmsKey', { 34 | enableKeyRotation: true, 35 | }); 36 | 37 | //**********SNS Topics****************************** 38 | const jobCompletionTopic = new sns.Topic(this, 'JobCompletion', { 39 | masterKey: key 40 | }); 41 | 42 | //**********IAM Roles****************************** 43 | const textractServiceRole = new iam.Role(this, 'TextractServiceRole', { 44 | assumedBy: new 
iam.ServicePrincipal('textract.amazonaws.com') 45 | }); 46 | textractServiceRole.addToPolicy( 47 | new iam.PolicyStatement({ 48 | effect: iam.Effect.ALLOW, 49 | resources: [jobCompletionTopic.topicArn], 50 | actions: ["sns:Publish"] 51 | }) 52 | ); 53 | textractServiceRole.addToPolicy( 54 | new iam.PolicyStatement({ 55 | effect: iam.Effect.ALLOW, 56 | resources: [key.keyArn], 57 | actions: [ 58 | "kms:GenerateDataKey", 59 | "kms:Decrypt" 60 | ] 61 | }) 62 | ); 63 | 64 | //**********S3 Bucket****************************** 65 | const corsRule: s3.CorsRule = { 66 | allowedMethods: [ 67 | s3.HttpMethods.GET, 68 | s3.HttpMethods.HEAD, 69 | s3.HttpMethods.PUT, 70 | s3.HttpMethods.POST, 71 | s3.HttpMethods.DELETE, 72 | ], 73 | allowedOrigins: ['*'], 74 | 75 | // the properties below are optional 76 | allowedHeaders: ['*'], 77 | exposedHeaders: [ 78 | "x-amz-server-side-encryption", 79 | "x-amz-request-id", 80 | "x-amz-id-2", 81 | "ETag" 82 | ], 83 | maxAge: 3000, 84 | }; 85 | const contentBucket = new s3.Bucket(this, 'DocumentsBucket', { 86 | blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL, 87 | versioned: false, 88 | encryption: s3.BucketEncryption.S3_MANAGED, 89 | cors: [corsRule], 90 | serverAccessLogsPrefix: 'accesslogs', 91 | enforceSSL: true, 92 | objectOwnership: ObjectOwnership.BUCKET_OWNER_PREFERRED, 93 | 94 | }); 95 | const appBucket = new s3.Bucket(this, 'AppBucket', { 96 | blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL, 97 | encryption: s3.BucketEncryption.S3_MANAGED, 98 | serverAccessLogsPrefix: 'accesslogs', 99 | publicReadAccess: false, 100 | enforceSSL: true, 101 | }); 102 | 103 | //**********DynamoDB Table************************* 104 | //DynamoDB table with textract output link 105 | // Fields = document, output type, s3 location 106 | const outputTable = new dynamodb.Table(this, 'OutputTable', { 107 | partitionKey: { name: 'documentId', type: dynamodb.AttributeType.STRING }, 108 | sortKey: { name: 'outputType', type: dynamodb.AttributeType.STRING 
}, 109 | billingMode: dynamodb.BillingMode.PAY_PER_REQUEST, 110 | pointInTimeRecovery: true 111 | }); 112 | 113 | //DynamoDB table with job status info. 114 | // Fields = document id, job id, status, s3 location 115 | const documentsTable = new dynamodb.Table(this, 'JobTable', { 116 | partitionKey: { name: 'documentId', type: dynamodb.AttributeType.STRING }, 117 | pointInTimeRecovery: true, 118 | billingMode: dynamodb.BillingMode.PAY_PER_REQUEST 119 | }); 120 | 121 | // DynamoDB table with summarization job info. 122 | // Fields = document id, job id, status, summary text 123 | const summarizationTable = new dynamodb.Table(this, 'SummarizationTable', { 124 | partitionKey: { name: 'documentId', type: dynamodb.AttributeType.STRING }, 125 | sortKey: { name: 'jobId', type: dynamodb.AttributeType.STRING }, 126 | pointInTimeRecovery: true, 127 | billingMode: dynamodb.BillingMode.PAY_PER_REQUEST 128 | }); 129 | 130 | // Table with embedding job info 131 | // Fields = document id, job id, status 132 | const embeddingTable = new dynamodb.Table(this, 'EmbeddingTable', { 133 | partitionKey: { name: 'documentId', type: dynamodb.AttributeType.STRING }, 134 | sortKey: { name: 'jobId', type: dynamodb.AttributeType.STRING }, 135 | pointInTimeRecovery: true, 136 | billingMode: dynamodb.BillingMode.PAY_PER_REQUEST 137 | }); 138 | 139 | //**********SQS Queues***************************** 140 | //DLQ 141 | const dlq = new sqs.Queue(this, 'DLQ', { 142 | visibilityTimeout: cdk.Duration.seconds(30), retentionPeriod: cdk.Duration.seconds(1209600), 143 | enforceSSL: true, 144 | }) 145 | 146 | //Queues 147 | const jobResultsQueue = new sqs.Queue(this, 'JobResults', { 148 | visibilityTimeout: cdk.Duration.seconds(900), retentionPeriod: cdk.Duration.seconds(1209600), deadLetterQueue: { queue: dlq, maxReceiveCount: 50 }, 149 | enforceSSL: true, 150 | }); 151 | //Trigger 152 | jobCompletionTopic.addSubscription( 153 | new snsSubscriptions.SqsSubscription(jobResultsQueue) 154 | ); 155 | const 
summarizationResultsQueue = new sqs.Queue(this, 'SummarizationResults', { 156 | visibilityTimeout: cdk.Duration.seconds(900), enforceSSL: true, 157 | retentionPeriod: cdk.Duration.seconds(1209600), deadLetterQueue: { queue: dlq, maxReceiveCount: 50 } 158 | }); 159 | const embeddingQueue = new sqs.Queue(this, 'EmbeddingQueue', { 160 | visibilityTimeout: cdk.Duration.seconds(900), enforceSSL: true, 161 | retentionPeriod: cdk.Duration.seconds(1209600), deadLetterQueue: { queue: dlq, maxReceiveCount: 50 } 162 | }); 163 | 164 | //**********Lambda Functions****************************** 165 | 166 | // Helper Layer with helper functions 167 | const helperLayer = new lambda.LayerVersion(this, 'HelperLayer', { 168 | code: lambda.Code.fromAsset('lambda/helper'), 169 | compatibleRuntimes: [lambda.Runtime.PYTHON_3_9], 170 | license: 'Apache-2.0', 171 | description: 'Helper layer.', 172 | }); 173 | 174 | // Textractor helper layer 175 | const textractorLayer = new lambda.LayerVersion(this, 'Textractor', { 176 | code: lambda.Code.fromAsset('lambda/textractor'), 177 | compatibleRuntimes: [lambda.Runtime.PYTHON_3_9], 178 | license: 'Apache-2.0', 179 | description: 'Textractor layer.', 180 | }); 181 | 182 | //------------------------------------------------------------ 183 | // Async Job Processor (Start jobs using Async APIs) 184 | const asyncProcessor = new lambda.Function(this, 'ASyncProcessor', { 185 | runtime: lambda.Runtime.PYTHON_3_9, 186 | code: lambda.Code.fromAsset('lambda/asyncprocessor'), 187 | handler: 'lambda_function.lambda_handler', 188 | reservedConcurrentExecutions: 1, 189 | tracing: lambda.Tracing.ACTIVE, 190 | timeout: cdk.Duration.seconds(60), 191 | environment: { 192 | SNS_TOPIC_ARN: jobCompletionTopic.topicArn, 193 | SNS_ROLE_ARN: textractServiceRole.roleArn, 194 | OUTPUT_TABLE: outputTable.tableName, 195 | DOCUMENTS_TABLE: documentsTable.tableName, 196 | } 197 | }); 198 | 199 | //Layer 200 | asyncProcessor.addLayers(helperLayer) 201 | 202 | //Permissions 203 
| contentBucket.grantRead(asyncProcessor) 204 | outputTable.grantReadWriteData(asyncProcessor) 205 | documentsTable.grantReadWriteData(asyncProcessor) 206 | asyncProcessor.addToRolePolicy( 207 | new iam.PolicyStatement({ 208 | actions: ["iam:PassRole"], 209 | resources: [textractServiceRole.roleArn] 210 | }) 211 | ); 212 | asyncProcessor.addToRolePolicy( 213 | new iam.PolicyStatement({ 214 | actions: ["textract:StartDocumentTextDetection"], 215 | resources: ["*"] 216 | }) 217 | ); 218 | //------------------------------------------------------------ 219 | 220 | // Async Jobs Results Processor 221 | const jobResultProcessor = new lambda.Function(this, 'JobResultProcessor', { 222 | runtime: lambda.Runtime.PYTHON_3_9, 223 | code: lambda.Code.fromAsset('lambda/jobresultprocessor'), 224 | handler: 'lambda_function.lambda_handler', 225 | memorySize: 2000, 226 | tracing: lambda.Tracing.ACTIVE, 227 | reservedConcurrentExecutions: 50, 228 | timeout: cdk.Duration.seconds(900), 229 | environment: { 230 | OUTPUT_TABLE: outputTable.tableName, 231 | DOCUMENTS_TABLE: documentsTable.tableName, 232 | } 233 | }); 234 | //Layer 235 | jobResultProcessor.addLayers(helperLayer) 236 | jobResultProcessor.addLayers(textractorLayer) 237 | //Triggers 238 | jobResultProcessor.addEventSource(new SqsEventSource(jobResultsQueue, { 239 | batchSize: 1 240 | })); 241 | //Permissions 242 | outputTable.grantReadWriteData(jobResultProcessor) 243 | documentsTable.grantReadWriteData(jobResultProcessor) 244 | contentBucket.grantReadWrite(jobResultProcessor) 245 | jobResultProcessor.addToRolePolicy( 246 | new iam.PolicyStatement({ 247 | actions: ["textract:GetDocumentTextDetection", "textract:GetDocumentAnalysis"], 248 | resources: ["*"] 249 | }) 250 | ); 251 | 252 | //------------------------------------------------------------ 253 | 254 | // Summarization handler 255 | const summarizationProcessor = new lambda.Function(this, 'SummarizationProcessor', { 256 | runtime: lambda.Runtime.PYTHON_3_9, 257 | 
code: lambda.Code.fromAsset('lambda/summarizationprocessor'), 258 | handler: 'lambda_function.lambda_handler', 259 | tracing: lambda.Tracing.ACTIVE, 260 | timeout: cdk.Duration.seconds(60), 261 | environment: { 262 | QUEUE_URL: summarizationResultsQueue.queueUrl, 263 | JOB_TABLE: summarizationTable.tableName, 264 | } 265 | }); 266 | summarizationProcessor.addLayers(helperLayer) 267 | summarizationResultsQueue.grantSendMessages(summarizationProcessor) 268 | summarizationTable.grantReadWriteData(summarizationProcessor) 269 | 270 | // Embedding handler 271 | const embeddingProcessor = new lambda.Function(this, 'EmbeddingProcessor', { 272 | runtime: lambda.Runtime.PYTHON_3_9, 273 | code: lambda.Code.fromAsset('lambda/embeddingprocessor'), 274 | handler: 'lambda_function.lambda_handler', 275 | tracing: lambda.Tracing.ACTIVE, 276 | timeout: cdk.Duration.seconds(60), 277 | environment: { 278 | QUEUE_URL: embeddingQueue.queueUrl, 279 | JOB_TABLE: embeddingTable.tableName, 280 | } 281 | }); 282 | embeddingProcessor.addLayers(helperLayer) 283 | embeddingQueue.grantSendMessages(embeddingProcessor) 284 | embeddingTable.grantReadWriteData(embeddingProcessor) 285 | 286 | //**********API Gateway****************************** 287 | const prdLogGroup = new logs.LogGroup(this, "PrdLogs"); 288 | const cfnWebACLApi = new wafv2.CfnWebACL(this, 'WebAclApi', { 289 | defaultAction: { 290 | allow: {} 291 | }, 292 | scope: 'REGIONAL', 293 | visibilityConfig: { 294 | cloudWatchMetricsEnabled: true, 295 | metricName:'MetricForWebACLCDKApi', 296 | sampledRequestsEnabled: true, 297 | }, 298 | name:'CdkWebAclApi', 299 | rules: [{ 300 | name: 'CRSRule', 301 | priority: 0, 302 | statement: { 303 | managedRuleGroupStatement: { 304 | name:'AWSManagedRulesCommonRuleSet', 305 | vendorName:'AWS' 306 | } 307 | }, 308 | visibilityConfig: { 309 | cloudWatchMetricsEnabled: true, 310 | metricName:'MetricForWebACLCDK-CRS-Api', 311 | sampledRequestsEnabled: true, 312 | }, 313 | overrideAction: { 314 | none: 
{} 315 | }, 316 | }] 317 | }); 318 | const api = new apigw.RestApi(this, 'sum-qa-api', { 319 | defaultCorsPreflightOptions: { 320 | allowOrigins: apigw.Cors.ALL_ORIGINS, 321 | allowMethods: apigw.Cors.ALL_METHODS, 322 | allowHeaders: apigw.Cors.DEFAULT_HEADERS, 323 | allowCredentials: true, 324 | statusCode: 200 325 | }, 326 | deployOptions: { 327 | accessLogDestination: new apigw.LogGroupLogDestination(prdLogGroup), 328 | accessLogFormat: apigw.AccessLogFormat.jsonWithStandardFields(), 329 | loggingLevel: apigw.MethodLoggingLevel.INFO, 330 | dataTraceEnabled: true 331 | }, 332 | cloudWatchRole: true, 333 | }); 334 | const cfnWebACLAssociation = new wafv2.CfnWebACLAssociation(this,'ApiCDKWebACLAssociation', { 335 | resourceArn: api.deploymentStage.stageArn, 336 | webAclArn: cfnWebACLApi.attrArn, 337 | }); 338 | const documentResource = api.root.addResource('doctopdf'); 339 | const pdfToText = new apigw.LambdaIntegration(asyncProcessor); 340 | documentResource.addMethod('POST', pdfToText, { 341 | authorizationType: apigw.AuthorizationType.IAM 342 | }) 343 | const summarizeResource = api.root.addResource('summarize'); 344 | const summarizeIntegration = new apigw.LambdaIntegration(summarizationProcessor); 345 | summarizeResource.addMethod('POST', summarizeIntegration, { 346 | authorizationType: apigw.AuthorizationType.IAM 347 | }) 348 | const embeddingResource = api.root.addResource('embed'); 349 | const embeddingIntegration = new apigw.LambdaIntegration(embeddingProcessor); 350 | embeddingResource.addMethod('POST', embeddingIntegration, { 351 | authorizationType: apigw.AuthorizationType.IAM 352 | }) 353 | const qaResource = api.root.addResource('qa'); 354 | 355 | //**********Fargate tasks****************************** 356 | 357 | const vpc = new ec2.Vpc(this, 'VPC', { 358 | gatewayEndpoints: { 359 | S3: { 360 | service: ec2.GatewayVpcEndpointAwsService.S3, 361 | }, 362 | }, 363 | }); 364 | vpc.addFlowLog('FlowLogS3', { 365 | destination: 
ec2.FlowLogDestination.toS3(contentBucket, 'flowlogs/') 366 | }); 367 | vpc.addInterfaceEndpoint('EcrDockerEndpoint', { 368 | service: ec2.InterfaceVpcEndpointAwsService.ECR_DOCKER, 369 | }); 370 | vpc.addInterfaceEndpoint('KmsEndpoint', { 371 | service: ec2.InterfaceVpcEndpointAwsService.KMS, 372 | }); 373 | const endpointSum = this.node.tryGetContext('sumEndpoint'); 374 | const endpointEmbed = this.node.tryGetContext('embedEndpoint'); 375 | const cluster = new ecs.Cluster(this, 'Cluster', { 376 | vpc, 377 | enableFargateCapacityProviders: true, 378 | containerInsights: true 379 | }); 380 | const fargateTaskDefinition = new ecs.FargateTaskDefinition(this, 'SummarizationWorkerTask', { 381 | memoryLimitMiB: 61440, 382 | cpu: 8192, 383 | ephemeralStorageGiB: 200, 384 | // Uncomment this section if running on ARM 385 | // runtimePlatform: { 386 | // cpuArchitecture: ecs.CpuArchitecture.ARM64, 387 | // } 388 | }); 389 | fargateTaskDefinition.grantRun(summarizationProcessor) 390 | contentBucket.grantRead(fargateTaskDefinition.taskRole) 391 | summarizationTable.grantReadWriteData(fargateTaskDefinition.taskRole) 392 | fargateTaskDefinition.taskRole.addToPrincipalPolicy( 393 | new iam.PolicyStatement({ 394 | actions: ["sagemaker:InvokeEndpoint"], 395 | resources: [ 396 | "arn:aws:sagemaker:" + this.region + ":" + this.account + ":endpoint/" + endpointSum 397 | ] 398 | }) 399 | ); 400 | const regionParam = new ssm.StringParameter(this, 'RegionParameter', { 401 | parameterName: 'RegionParameter', 402 | stringValue: this.region, 403 | tier: ssm.ParameterTier.ADVANCED, 404 | }); 405 | const endpointSumParam = new ssm.StringParameter(this, 'EndpointSumParameter', { 406 | parameterName: 'EndpointSumParameter', 407 | stringValue: endpointSum, 408 | tier: ssm.ParameterTier.ADVANCED, 409 | }); 410 | const sumTableParam = new ssm.StringParameter(this, 'SumTableParameter', { 411 | parameterName: 'SumTableParameter', 412 | stringValue: summarizationTable.tableName, 413 | tier: 
ssm.ParameterTier.ADVANCED, 414 | }); 415 | fargateTaskDefinition.addContainer('worker', { 416 | image: ecs.ContainerImage.fromAsset('fargate/summarizationWorker'), 417 | logging: ecs.LogDrivers.awsLogs({ streamPrefix: 'summarization-log-group', logRetention: 30 }), 418 | secrets: { 419 | endpoint: ecs.Secret.fromSsmParameter(endpointSumParam), 420 | table: ecs.Secret.fromSsmParameter(sumTableParam), 421 | region: ecs.Secret.fromSsmParameter(regionParam) 422 | } 423 | }); 424 | 425 | //**********ECS task launcher****************************** 426 | const subnetIds: string[] = []; 427 | vpc.privateSubnets.forEach(subnet => { 428 | subnetIds.push(subnet.subnetId); 429 | }); 430 | 431 | // Summarization worker - fires ECS task 432 | const taskProcessor = new lambda.Function(this, 'TaskProcessor', { 433 | runtime: lambda.Runtime.PYTHON_3_9, 434 | code: lambda.Code.fromAsset('lambda/taskprocessor'), 435 | handler: 'lambda_function.lambda_handler', 436 | tracing: lambda.Tracing.ACTIVE, 437 | reservedConcurrentExecutions: 50, 438 | timeout: cdk.Duration.seconds(30), 439 | environment: { 440 | target: cluster.clusterArn, 441 | taskDefinitionArn: fargateTaskDefinition.taskDefinitionArn, 442 | subnets: subnetIds.join(",") 443 | } 444 | }); 445 | //Triggers 446 | taskProcessor.addEventSource(new SqsEventSource(summarizationResultsQueue, { 447 | batchSize: 1 448 | })); 449 | //Permissions 450 | taskProcessor.addToRolePolicy( 451 | new iam.PolicyStatement({ 452 | actions: ["ecs:RunTask"], 453 | resources: [fargateTaskDefinition.taskDefinitionArn] 454 | }) 455 | ); 456 | taskProcessor.addToRolePolicy( 457 | new iam.PolicyStatement({ 458 | actions: ["iam:PassRole"], 459 | resources: ["*"] 460 | }) 461 | ); 462 | 463 | //**********EFS************************* 464 | const fileSystem = new efs.FileSystem(this, 'ChromaFileSystem', { 465 | vpc: vpc, 466 | encrypted: true, 467 | enableAutomaticBackups: true, 468 | performanceMode: efs.PerformanceMode.GENERAL_PURPOSE, // default 469 | 
vpcSubnets: { 470 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, 471 | }, 472 | }); 473 | const accessPoint = fileSystem.addAccessPoint('LambdaAccessPoint',{ 474 | createAcl: { 475 | ownerGid: '1001', 476 | ownerUid: '1001', 477 | permissions: '750' 478 | }, 479 | path:'/export/lambda', 480 | posixUser: { 481 | gid: '1001', 482 | uid: '1001' 483 | } 484 | }); 485 | 486 | //**********Function that uses EFS************************* 487 | const endpointQa = this.node.tryGetContext('qaEndpoint'); 488 | const fargateTaskDefinitionEmbed = new ecs.FargateTaskDefinition(this, 'EmbedWorkerTask', { 489 | memoryLimitMiB: 8192, 490 | cpu: 4096, 491 | ephemeralStorageGiB: 100, 492 | // Uncomment this section if running on ARM 493 | // runtimePlatform: { 494 | // cpuArchitecture: ecs.CpuArchitecture.ARM64, 495 | // } 496 | }); 497 | const embedVolume = { 498 | name: "datavolume", 499 | efsVolumeConfiguration: { 500 | fileSystemId: fileSystem.fileSystemId, 501 | transitEncryption: 'ENABLED', 502 | authorizationConfig:{ 503 | accessPointId: accessPoint.accessPointId, 504 | iam: 'ENABLED' 505 | } 506 | }, 507 | }; 508 | fargateTaskDefinitionEmbed.addVolume(embedVolume); 509 | contentBucket.grantRead(fargateTaskDefinitionEmbed.taskRole) 510 | embeddingTable.grantReadWriteData(fargateTaskDefinitionEmbed.taskRole) 511 | fargateTaskDefinitionEmbed.taskRole.addToPrincipalPolicy( 512 | new iam.PolicyStatement({ 513 | actions: ["sagemaker:InvokeEndpoint"], 514 | resources: [ 515 | "arn:aws:sagemaker:" + this.region + ":" + this.account + ":endpoint/" + endpointEmbed 516 | ] 517 | }) 518 | ); 519 | const endpointEmbedParam = new ssm.StringParameter(this, 'EndpointEmbedParameter', { 520 | parameterName: 'EndpointEmbedParameter', 521 | stringValue: endpointEmbed, 522 | tier: ssm.ParameterTier.ADVANCED, 523 | }); 524 | const embedTableParam = new ssm.StringParameter(this, 'EmbedTableParameter', { 525 | parameterName: 'EmbedTableParameter', 526 | stringValue: embeddingTable.tableName, 527 | 
tier: ssm.ParameterTier.ADVANCED, 528 | }); 529 | const mountParam = new ssm.StringParameter(this, 'MountParameter', { 530 | parameterName: 'MountParameter', 531 | stringValue: '/efs/data', 532 | tier: ssm.ParameterTier.ADVANCED, 533 | }); 534 | const embedContainer = fargateTaskDefinitionEmbed.addContainer('worker', { 535 | image: ecs.ContainerImage.fromAsset('fargate/embeddingWorker'), 536 | logging: ecs.LogDrivers.awsLogs({ streamPrefix: 'embed-log-group', logRetention: 30 }), 537 | secrets: { 538 | endpoint: ecs.Secret.fromSsmParameter(endpointEmbedParam), 539 | table: ecs.Secret.fromSsmParameter(embedTableParam), 540 | region: ecs.Secret.fromSsmParameter(regionParam), 541 | mountpoint: ecs.Secret.fromSsmParameter(mountParam) 542 | } 543 | }); 544 | embedContainer.addMountPoints( 545 | { 546 | containerPath: '/efs/data', 547 | readOnly: false, 548 | sourceVolume: 'datavolume', 549 | } 550 | ); 551 | fargateTaskDefinitionEmbed.addToTaskRolePolicy( 552 | new iam.PolicyStatement({ 553 | actions: [ 554 | 'elasticfilesystem:ClientRootAccess', 555 | 'elasticfilesystem:ClientWrite', 556 | 'elasticfilesystem:ClientMount', 557 | 'elasticfilesystem:DescribeMountTargets' 558 | ], 559 | resources: [fileSystem.fileSystemArn] 560 | }) 561 | ); 562 | fargateTaskDefinitionEmbed.addToTaskRolePolicy( 563 | new iam.PolicyStatement({ 564 | actions: ['ec2:DescribeAvailabilityZones'], 565 | resources: ['*'] 566 | }) 567 | ); 568 | fileSystem.connections.allowDefaultPortFrom(ec2.Peer.ipv4(vpc.vpcCidrBlock)); 569 | const embeddingWorker = new lambda.Function(this, 'EmbeddingWorker', { 570 | runtime: lambda.Runtime.PYTHON_3_9, 571 | code: lambda.Code.fromAsset('lambda/embeddingworker'), 572 | handler: 'lambda_function.lambda_handler', 573 | tracing: lambda.Tracing.ACTIVE, 574 | memorySize: 1024, 575 | reservedConcurrentExecutions: 50, 576 | timeout: cdk.Duration.seconds(30), 577 | environment: { 578 | target: cluster.clusterArn, 579 | taskDefinitionArn: 
fargateTaskDefinitionEmbed.taskDefinitionArn, 580 | subnets: subnetIds.join(",") 581 | } 582 | }); 583 | //Triggers 584 | embeddingWorker.addEventSource(new SqsEventSource(embeddingQueue, { 585 | batchSize: 1 586 | })); 587 | embeddingWorker.addToRolePolicy( 588 | new iam.PolicyStatement({ 589 | actions: ["ecs:RunTask"], 590 | resources: [fargateTaskDefinitionEmbed.taskDefinitionArn] 591 | }) 592 | ); 593 | embeddingWorker.addToRolePolicy( 594 | new iam.PolicyStatement({ 595 | actions: ["iam:PassRole"], 596 | resources: ["*"] 597 | }) 598 | ); 599 | 600 | //***********Cognito ************************/ 601 | const userPool = new cognito.UserPool(this, 'userpool', { 602 | userPoolName: 'fsiqasumuserpool', 603 | selfSignUpEnabled: false, 604 | signInCaseSensitive: false, // case insensitive is preferred in most situations 605 | signInAliases: { 606 | username: true, 607 | email: true, 608 | }, 609 | passwordPolicy: { 610 | minLength: 8, 611 | requireLowercase: true, 612 | requireUppercase: true, 613 | requireDigits: true, 614 | requireSymbols: true, 615 | }, 616 | advancedSecurityMode: cognito.AdvancedSecurityMode.ENFORCED 617 | }); 618 | 619 | const idPool = new cognitoIdp.IdentityPool(this, 'fsiIdentityPool', 620 | { 621 | allowUnauthenticatedIdentities: false, 622 | authenticationProviders: { 623 | userPools: [new cognitoIdp.UserPoolAuthenticationProvider({ userPool })], 624 | }, 625 | }, 626 | ); 627 | documentsTable.grantReadData(idPool.authenticatedRole) 628 | summarizationTable.grantReadData(idPool.authenticatedRole) 629 | outputTable.grantReadData(idPool.authenticatedRole) 630 | embeddingTable.grantReadData(idPool.authenticatedRole) 631 | contentBucket.grantReadWrite(idPool.authenticatedRole) 632 | idPool.authenticatedRole.addToPrincipalPolicy(new iam.PolicyStatement({ 633 | effect: iam.Effect.ALLOW, 634 | actions: ['execute-api:Invoke'], 635 | resources: [api.arnForExecuteApi('*')], 636 | })); 637 | const cfnUserPoolGroup = new cognito.CfnUserPoolGroup(this, 
'MyCfnUserPoolGroup', { 638 | userPoolId: userPool.userPoolId, 639 | groupName: 'fsigroup', 640 | precedence: 1, 641 | roleArn: idPool.authenticatedRole.roleArn 642 | }); 643 | 644 | //**********ECS Service for QA****************************** 645 | const fargateTaskDefinitionQa = new ecs.FargateTaskDefinition(this, 'QaWorkerTask', { 646 | memoryLimitMiB: 8192, 647 | cpu: 4096, 648 | ephemeralStorageGiB: 100, 649 | // Uncomment this section if running on ARM 650 | // runtimePlatform: { 651 | // cpuArchitecture: ecs.CpuArchitecture.ARM64, 652 | // } 653 | }); 654 | const qaVolume = { 655 | name: "datavolume", 656 | efsVolumeConfiguration: { 657 | fileSystemId: fileSystem.fileSystemId, 658 | transitEncryption: 'ENABLED', 659 | authorizationConfig:{ 660 | accessPointId: accessPoint.accessPointId, 661 | iam: 'ENABLED' 662 | } 663 | }, 664 | }; 665 | fargateTaskDefinitionQa.addVolume(qaVolume); 666 | fargateTaskDefinitionQa.taskRole.addToPrincipalPolicy( 667 | new iam.PolicyStatement({ 668 | actions: ["sagemaker:InvokeEndpoint"], 669 | resources: [ 670 | "arn:aws:sagemaker:" + this.region + ":" + this.account + ":endpoint/" + endpointEmbed, 671 | "arn:aws:sagemaker:" + this.region + ":" + this.account + ":endpoint/" + endpointQa 672 | ] 673 | }) 674 | ); 675 | const endpointQaParam = new ssm.StringParameter(this, 'EndpointQaParameter', { 676 | parameterName: 'EndpointQaParameter', 677 | stringValue: endpointQa, 678 | tier: ssm.ParameterTier.ADVANCED, 679 | }); 680 | const qaContainer = fargateTaskDefinitionQa.addContainer('qaworker', { 681 | image: ecs.ContainerImage.fromAsset('fargate/qaWorker'), 682 | containerName: 'qaworker', 683 | logging: ecs.LogDrivers.awsLogs({ streamPrefix: 'qa-log-group', logRetention: 30 }), 684 | portMappings: [ 685 | { 686 | containerPort: 5000, 687 | hostPort: 5000 688 | } 689 | ], 690 | secrets: { 691 | endpoint_embed: ecs.Secret.fromSsmParameter(endpointEmbedParam), 692 | endpoint_qa: ecs.Secret.fromSsmParameter(endpointQaParam), 693 | 
mountpoint: ecs.Secret.fromSsmParameter(mountParam) 694 | }, 695 | essential: true 696 | }); 697 | qaContainer.addMountPoints( 698 | { 699 | containerPath: '/efs/data', 700 | readOnly: false, 701 | sourceVolume: 'datavolume', 702 | } 703 | ); 704 | fargateTaskDefinitionQa.addToTaskRolePolicy( 705 | new iam.PolicyStatement({ 706 | actions: [ 707 | 'elasticfilesystem:ClientRootAccess', 708 | 'elasticfilesystem:ClientWrite', 709 | 'elasticfilesystem:ClientMount', 710 | 'elasticfilesystem:DescribeMountTargets' 711 | ], 712 | resources: [fileSystem.fileSystemArn] 713 | }) 714 | ); 715 | fargateTaskDefinitionQa.addToTaskRolePolicy( 716 | new iam.PolicyStatement({ 717 | actions: ['ec2:DescribeAvailabilityZones'], 718 | resources: ['*'] 719 | }) 720 | ); 721 | const serviceSecurityGroup = new ec2.SecurityGroup(this, 'svcSecurityGroup', { 722 | vpc: vpc, 723 | securityGroupName: 'svcSecurityGroup' 724 | }) 725 | serviceSecurityGroup.addIngressRule( 726 | ec2.Peer.ipv4(vpc.vpcCidrBlock), 727 | ec2.Port.tcp(5000), 728 | 'Allow inbound traffic from resources in vpc' 729 | ) 730 | const qaService = new ecs.FargateService(this, 'qaService', { 731 | serviceName: 'qaService', 732 | cluster: cluster, 733 | desiredCount: 1, 734 | securityGroups: [serviceSecurityGroup], 735 | taskDefinition: fargateTaskDefinitionQa, 736 | healthCheckGracePeriod: cdk.Duration.seconds(300) 737 | }) 738 | const qaNLB = new elb.NetworkLoadBalancer(this, 'qaNLB', { 739 | loadBalancerName: 'qaNLB', 740 | vpc: vpc, 741 | crossZoneEnabled: true, 742 | internetFacing: false, 743 | }) 744 | qaNLB.logAccessLogs(contentBucket, "nlblog") 745 | const qaTargetGroup = new elb.NetworkTargetGroup(this, 'qaTargetGroup', { 746 | targetGroupName: 'qaTargetGroup', 747 | vpc: vpc, 748 | port: 5000, 749 | targets: [qaService] 750 | }) 751 | qaTargetGroup.configureHealthCheck({ 752 | path: "/health", 753 | protocol: elb.Protocol.HTTP, 754 | port: "5000", 755 | }); 756 | qaNLB.addListener('qaTargetGroupListener', { 757 | 
port: 80, 758 | defaultTargetGroups: [qaTargetGroup] 759 | }) 760 | 761 | const link = new apigw.VpcLink(this, 'link', { 762 | targets: [qaNLB], 763 | }); 764 | const qaIntegration = new apigw.Integration({ 765 | type: apigw.IntegrationType.HTTP_PROXY, 766 | integrationHttpMethod: "POST", 767 | options: { 768 | connectionType: apigw.ConnectionType.VPC_LINK, 769 | vpcLink: link, 770 | }, 771 | }); 772 | qaResource.addMethod('POST', qaIntegration, { 773 | authorizationType: apigw.AuthorizationType.IAM 774 | }) 775 | 776 | //**********CloudFront****************************** 777 | const cfnWebACL = new wafv2.CfnWebACL(this, 'WebAcl', { 778 | defaultAction: { 779 | allow: {} 780 | }, 781 | scope: 'CLOUDFRONT', 782 | visibilityConfig: { 783 | cloudWatchMetricsEnabled: true, 784 | metricName:'MetricForWebACLCDK', 785 | sampledRequestsEnabled: true, 786 | }, 787 | name:'CdkWebAcl', 788 | rules: [{ 789 | name: 'CRSRule', 790 | priority: 0, 791 | statement: { 792 | managedRuleGroupStatement: { 793 | name:'AWSManagedRulesCommonRuleSet', 794 | vendorName:'AWS' 795 | } 796 | }, 797 | visibilityConfig: { 798 | cloudWatchMetricsEnabled: true, 799 | metricName:'MetricForWebACLCDK-CRS', 800 | sampledRequestsEnabled: true, 801 | }, 802 | overrideAction: { 803 | none: {} 804 | }, 805 | }] 806 | }); 807 | const distribution = new cloudfront.Distribution(this, 'appdist', { 808 | defaultBehavior: { origin: new origins.S3Origin(appBucket) }, 809 | enableLogging: true, 810 | logBucket: contentBucket, 811 | logFilePrefix: 'distribution-access-logs/', 812 | logIncludesCookies: true, 813 | geoRestriction: cloudfront.GeoRestriction.allowlist('US'), 814 | minimumProtocolVersion: cloudfront.SecurityPolicyProtocol.TLS_V1_2_2021, 815 | webAclId: cfnWebACL.attrArn 816 | }); 817 | 818 | //**********Outputs****************************** 819 | new cdk.CfnOutput(this, 'DocToPdfApiUrl', { 820 | value: `${api.url}doctopdf`, 821 | }); 822 | new cdk.CfnOutput(this, 'UserPoolId', { 823 | value: 
`${userPool.userPoolId}`, 824 | }); 825 | new cdk.CfnOutput(this, 'IdentityPoolId', { 826 | value: `${idPool.identityPoolId}`, 827 | }); 828 | new cdk.CfnOutput(this, 'UserPoolGroupName', { 829 | value: `${cfnUserPoolGroup.groupName}`, 830 | }); 831 | new cdk.CfnOutput(this, 'SummarizeUrl', { 832 | value: `${api.url}summarize`, 833 | }); 834 | new cdk.CfnOutput(this, 'BucketName', { 835 | value: `${contentBucket.bucketName}`, 836 | }); 837 | new cdk.CfnOutput(this, 'OutputTableName', { 838 | value: `${outputTable.tableName}`, 839 | }); 840 | new cdk.CfnOutput(this, 'DocumentTableName', { 841 | value: `${documentsTable.tableName}`, 842 | }); 843 | new cdk.CfnOutput(this, 'AppBucketName', { 844 | value: `${appBucket.bucketName}`, 845 | }); 846 | new cdk.CfnOutput(this, 'AppUrl', { 847 | value: `${distribution.domainName}`, 848 | }); 849 | } 850 | } 851 | -------------------------------------------------------------------------------- /cdk/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "cdk", 3 | "version": "0.1.0", 4 | "bin": { 5 | "cdk": "bin/cdk.js" 6 | }, 7 | "scripts": { 8 | "build": "tsc", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk" 12 | }, 13 | "devDependencies": { 14 | "@types/jest": "^29.4.0", 15 | "@types/node": "18.11.18", 16 | "aws-cdk": "2.74.0", 17 | "jest": "^29.4.1", 18 | "ts-jest": "^29.0.5", 19 | "ts-node": "^10.9.1", 20 | "typescript": "~4.9.5" 21 | }, 22 | "dependencies": { 23 | "@aws-cdk/aws-cognito-identitypool-alpha": "^2.74.0-alpha.0", 24 | "@aws-cdk/lambda-layer-kubectl-v25": "^2.0.3", 25 | "aws-cdk-lib": "2.80.0", 26 | "cdk-nag": "^2.26.13", 27 | "constructs": "^10.0.0", 28 | "source-map-support": "^0.5.21" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /cdk/test/cdk.test.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | // import * as cdk from 'aws-cdk-lib'; 5 | // import { Template } from 'aws-cdk-lib/assertions'; 6 | // import * as Cdk from '../lib/cdk-stack'; 7 | 8 | // example test. To run these tests, uncomment this file along with the 9 | // example resource in lib/cdk-stack.ts 10 | test('SQS Queue Created', () => { 11 | // const app = new cdk.App(); 12 | // // WHEN 13 | // const stack = new Cdk.CdkStack(app, 'MyTestStack'); 14 | // // THEN 15 | // const template = Template.fromStack(stack); 16 | 17 | // template.hasResourceProperties('AWS::SQS::Queue', { 18 | // VisibilityTimeout: 300 19 | // }); 20 | }); 21 | -------------------------------------------------------------------------------- /cdk/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "module": "commonjs", 5 | "lib": [ 6 | "es2020" 7 | ], 8 | "declaration": true, 9 | "strict": true, 10 | "noImplicitAny": true, 11 | "strictNullChecks": true, 12 | "noImplicitThis": true, 13 | "alwaysStrict": true, 14 | "noUnusedLocals": false, 15 | "noUnusedParameters": false, 16 | "noImplicitReturns": true, 17 | "noFallthroughCasesInSwitch": false, 18 | "inlineSourceMap": true, 19 | "inlineSources": true, 20 | "experimentalDecorators": true, 21 | "strictPropertyInitialization": false, 22 | "typeRoots": [ 23 | "./node_modules/@types" 24 | ] 25 | }, 26 | "exclude": [ 27 | "node_modules", 28 | "cdk.out" 29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /diagrams/fsi-qa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/diagrams/fsi-qa.png -------------------------------------------------------------------------------- /frontend/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Node modules 2 | node_modules/ 3 | 4 | # Logs 5 | logs 6 | *.log 7 | npm-debug.log* 8 | yarn-debug.log* 9 | yarn-error.log* 10 | lerna-debug.log* 11 | .pnpm-debug.log* 12 | 13 | # Optional npm cache directory 14 | .npm 15 | 16 | # Yarn Integrity file 17 | .yarn-integrity 18 | .yarn/cache 19 | .yarn/unplugged 20 | .yarn/build-state.yml 21 | .yarn/install-state.gz 22 | 23 | # Build folders 24 | build/ 25 | 26 | # Local History for Visual Studio Code 27 | .history/ 28 | 29 | # macOS hidden files 30 | .DS_Store -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fsi-summarize-qa", 3 | "version": "1.0.0", 4 | "private": false, 5 | "email": "rddefauw@amazon.com", 6 | "license": "MIT", 7 | "devDependencies": { 8 | "@testing-library/jest-dom": "^5.11.4", 9 | "@testing-library/react": "^11.1.0", 10 | "@testing-library/user-event": "^12.1.10", 11 | "react-scripts": "^5.0.1" 12 | }, 13 | "dependencies": { 14 | "@aws-amplify/ui-react": "^4.6.0", 15 | "aws-amplify": "^6.0.21", 16 | "aws-sdk": "^2.1360.0", 17 | "bootstrap": "^5.2.3", 18 | "react": "^17.0.2", 19 | "react-async": "^10.0.1", 20 | "react-bootstrap": "^2.7.4", 21 | "react-dom": "^17.0.2", 22 | "react-router-dom": "^5.3.0", 23 | "react-toastify": "^9.1.2", 24 | "uuid": "^9.0.0" 25 | }, 26 | "scripts": { 27 | "start": "react-scripts start", 28 | "build": "react-scripts build", 29 | "test": "react-scripts test", 30 | "eject": "react-scripts eject" 31 | }, 32 | "eslintConfig": { 33 | "extends": [ 34 | "react-app", 35 | "react-app/jest" 36 | ] 37 | }, 38 | "browserslist": { 39 | "production": [ 40 | ">0.2%", 41 | "not dead", 42 | "not op_mini all" 43 | ], 44 | "development": [ 45 | "last 1 chrome version", 46 | "last 1 firefox version", 47 | "last 1 safari version" 48 | ] 49 | } 50 | 
} 51 | -------------------------------------------------------------------------------- /frontend/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/frontend/public/favicon.ico -------------------------------------------------------------------------------- /frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 16 | 17 | 21 | 22 | 31 | Summarization and Question Answering 32 | 33 | 34 | 35 |
36 | 37 | 38 | -------------------------------------------------------------------------------- /frontend/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/frontend/public/logo192.png -------------------------------------------------------------------------------- /frontend/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/frontend/public/logo512.png -------------------------------------------------------------------------------- /frontend/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /frontend/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /frontend/src/App.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | 6 | .App { 7 | text-align: center; 8 | font-size: calc(10px + 2vmin); 9 | color: white; 10 | } 11 | 12 | .App-MainLogo { 13 | height: 15vmin; 14 | } 15 | 16 | .sidebar { 17 | color: white; 18 | background-color: darkblue; 19 | } 20 | .mainpanel { 21 | color:darkblue; 22 | background-color:antiquewhite; 23 | } 24 | .mainbtn { 25 | color:white; 26 | background-color: red; 27 | } 28 | .mainbtn[disabled] { 29 | color:grey; 30 | background-color:darkgrey; 31 | } 32 | .summaryOutput { 33 | overflow-y: scroll; 34 | height: 500px; 35 | width: 100%; 36 | } 37 | .qaOutput { 38 | overflow-y: scroll; 39 | height: 200px; 40 | width: 100%; 41 | } 42 | 43 | .App-logoL2R { 44 | height: 20vmin; 45 | pointer-events: none; 46 | } 47 | 48 | .App-logoR2L { 49 | height: 20vmin; 50 | pointer-events: none; 51 | } 52 | 53 | @media (prefers-reduced-motion: no-preference) { 54 | .App-MainLogo { 55 | } 56 | } 57 | 58 | @media (prefers-reduced-motion: no-preference) { 59 | .App-logoL2R { 60 | animation: Logo-spinL2R infinite 5s ease-in-out; 61 | } 62 | } 63 | 64 | @media (prefers-reduced-motion: no-preference) { 65 | .App-logoR2L { 66 | animation: Logo-spinR2L infinite 5s ease-in-out; 67 | } 68 | } 69 | .App-header { 70 | min-height: 21vmin; 71 | display: flex; 72 | flex-direction: column; 73 | align-items: center; 74 | justify-content: center; 75 | } 76 | 77 | .logos { 78 | display: flex; 79 | justify-content: space-evenly; 80 | padding-top: 50px; 81 | } 82 | 83 | .App-link { 84 | color: #f1a20d; 85 | } 86 | 87 | @keyframes Logo-spinL2R { 88 | from { 89 | transform: rotate(0deg); 90 | } 91 | to { 92 | transform: rotate(360deg); 93 | } 94 | } 95 | 96 | @keyframes Logo-spinR2L { 97 | from { 98 | transform: rotate(360deg); 99 | } 100 | to { 101 | transform: rotate(0deg); 102 | } 103 | } 104 | 105 | @keyframes zoom-in-zoom-out { 106 | 0% { 107 | transform: scale(1, 1); 108 | } 109 | 50% { 110 | transform: scale(1.5, 1.5); 111 | } 112 | 100% { 113 | 
transform: scale(1, 1); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /frontend/src/App.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | 6 | import logo from './logo.png'; 7 | import 'bootstrap/dist/css/bootstrap.min.css'; 8 | import './App.css'; 9 | import { useFetch } from "react-async" 10 | import { withAuthenticator } from '@aws-amplify/ui-react'; 11 | import '@aws-amplify/ui-react/styles.css'; 12 | import { Amplify, Auth, Storage, API } from "aws-amplify"; 13 | import 'react-toastify/dist/ReactToastify.css'; 14 | import { ToastContainer, toast } from 'react-toastify'; 15 | import { useState } from 'react'; 16 | import Container from 'react-bootstrap/Container'; 17 | import Row from 'react-bootstrap/Row'; 18 | import Col from 'react-bootstrap/Col'; 19 | import Stack from 'react-bootstrap/Stack'; 20 | import config from "./config"; 21 | import { v4 as uuidv4 } from 'uuid'; 22 | import DynamoDB from 'aws-sdk/clients/dynamodb'; 23 | import Collapse from 'react-bootstrap/Collapse'; 24 | import Button from 'react-bootstrap/Button'; 25 | 26 | function App({ signOut, user }) { 27 | // Document ID - UUID created after file upload 28 | const [docid, setDocid] = useState(''); 29 | 30 | // PDF extraction job id - returned from Textract 31 | const [jobid, setJobid] = useState(''); 32 | 33 | // Embedding job id - 34 | const [ejobid, setEjobid] = useState(''); 35 | 36 | // File name - will be stored in S3 under uploads prefix 37 | const [fname, setFname] = useState(''); 38 | 39 | // Extract file name - will be stored under uploads prefix 40 | const [ename, setEname] = useState(''); 41 | 42 | // Download link signed URL for extracted results 43 | const [dname, setDname] = useState(''); 44 | 45 | // Summarization job id - returned by api 46 | const [sjobid, 
setSjobid] = useState(''); 47 | 48 | // Indicates if extraction job is done 49 | const [isJobdone, setIsJobdone] = useState(false); 50 | 51 | // Indicates if summarization job is done 52 | const [isSjobdone, setIsSjobdone] = useState(false); 53 | 54 | // Indicates if embedding job is done 55 | const [isEjobdone, setIsEjobdone] = useState(false); 56 | 57 | // Summarization text 58 | const [summaryText, setSummaryText] = useState(''); 59 | 60 | // Question 61 | const [qaQ, setQaQ] = useState(''); 62 | 63 | // Answer 64 | const [qaA, setQaA] = useState(''); 65 | 66 | // Chunk size 67 | const [chunkSize, setChunkSize] = useState(10000); 68 | 69 | // Chunk overlap 70 | const [chunkOverlap, setChunkOverlap] = useState(1000); 71 | 72 | // Temperature 73 | const [temperature, setTemperature] = useState(0.5); 74 | 75 | // Top-k 76 | const [topK, setTopK] = useState(100); 77 | 78 | // Top-p 79 | const [topP, setTopP] = useState(0.9); 80 | 81 | // Max sequence length 82 | const [maxLength, setMaxLength] = useState(10000); 83 | 84 | // Number of beams 85 | const [numBeams, setNumBeams] = useState(2); 86 | 87 | // S3 bucket name 88 | const bucket = config.content.bucket; 89 | 90 | // collapse toggles 91 | const [open, setOpen] = useState(false); 92 | const [opensum, setOpensum] = useState(false); 93 | const [openqa, setOpenqa] = useState(false); 94 | 95 | async function uploadFile(e) { 96 | const file = e.target.files[0]; 97 | try { 98 | const result = await Storage.put(file.name, file, { 99 | progressCallback(progress) { 100 | console.log(`Uploaded: ${progress.loaded}/${progress.total}`); 101 | }, 102 | }); 103 | setDocid(uuidv4()); 104 | setFname(config.content.prefix + file.name); 105 | toast.success(`Uploaded file ${result.key}`) 106 | } catch (error) { 107 | console.log("Error uploading file: ", error); 108 | toast.warning("Failed to upload file"); 109 | } 110 | } 111 | 112 | async function startOver() { 113 | console.log("Clearing state"); 114 | setDname(''); 115 | 
setEname(''); 116 | setFname(''); 117 | setDocid(''); 118 | setJobid(''); 119 | setSjobid(''); 120 | setSummaryText(''); 121 | setIsJobdone(false); 122 | setIsSjobdone(false); 123 | document.getElementById('docpicker').value = '' 124 | setChunkOverlap(1000); 125 | setChunkSize(10000); 126 | setTemperature(0.5); 127 | setTopK(100); 128 | setTopP(0.9); 129 | setMaxLength(10000); 130 | setNumBeams(2); 131 | } 132 | 133 | async function pdf2txt() { 134 | console.log("Starting PDF extraction: " + docid); 135 | try { 136 | const result = await API.post("docs", "/doctopdf", { 137 | body: { 138 | 'docId': docid, 139 | 'bucket': bucket, 140 | 'name': fname 141 | } 142 | }); 143 | setJobid(result.jobId) 144 | toast.success("PDF extraction started"); 145 | setTimeout(() => { checkJobStatus(); }, 30000); 146 | } 147 | catch(error) { 148 | console.log("Error starting PDF extraction: ", error); 149 | toast.warning("Failed to start PDF extraction"); 150 | } 151 | } 152 | 153 | async function genembed() { 154 | console.log("Starting embedding generation: " + docid); 155 | try { 156 | const result = await API.post("docs", "/embed", { 157 | body: { 158 | 'docId': docid, 159 | 'bucket': bucket, 160 | 'name': ename 161 | } 162 | }); 163 | setEjobid(result.job) 164 | toast.success("Embedding generation started"); 165 | setTimeout(() => { checkEJobStatus(); }, 30000); 166 | } 167 | catch(error) { 168 | console.log("Error starting embeddings: ", error); 169 | toast.warning("Failed to start embedding generation"); 170 | } 171 | } 172 | 173 | async function summarize() { 174 | console.log("Starting summarization: " + docid); 175 | try { 176 | const result = await API.post("docs", "/summarize", { 177 | body: { 178 | 'docId': docid, 179 | 'bucket': bucket, 180 | 'name': ename, 181 | 'chunkSize': chunkSize, 182 | 'chunkOverlap': chunkOverlap, 183 | 'max_length': maxLength, 184 | 'top_p': topP, 185 | 'top_k': topK, 186 | 'num_beams': numBeams, 187 | 'temperature': temperature, 188 | } 189 | 
}); 190 | console.log("Summarization job ID: " + result.job) 191 | setSjobid(result.job) 192 | toast.success("Summarization started"); 193 | setTimeout(() => { checkSummarizationStatus(result.job); }, 30000); 194 | } 195 | catch(error) { 196 | console.log("Error starting summarization: ", error); 197 | toast.warning("Failed to start summarization"); 198 | } 199 | } 200 | 201 | async function getanswer() { 202 | console.log("Starting answer: " + docid); 203 | try { 204 | const result = await API.post("docs", "/qa", { 205 | body: { 206 | 'docId': docid, 207 | 'question': qaQ 208 | }, 209 | headers: { 210 | 'Content-Type': "application/json" 211 | } 212 | }); 213 | if (result.code == 200) { 214 | setQaA(result.answer) 215 | console.log("Got answer: ", result.answer) 216 | } 217 | else { 218 | console.log("Error getting answer: ", result.error); 219 | toast.warning("Failed to get answer"); 220 | } 221 | } 222 | catch(error) { 223 | console.log("Error getting answer: ", error); 224 | toast.warning("Failed to get answer"); 225 | } 226 | } 227 | 228 | async function downloadExtract(opath) { 229 | var ekey = opath.replace(config.content.prefix, '') 230 | console.log("Getting signed url for key " + ekey); 231 | const signedURL = await Storage.get(ekey); 232 | console.log("Got signed URL: " + signedURL) 233 | setDname(signedURL); 234 | } 235 | 236 | function checkJobStatus() { 237 | Auth.currentCredentials() 238 | .then(credentials => { 239 | const db= new DynamoDB({ 240 | region: config.content.REGION, 241 | credentials: Auth.essentialCredentials(credentials) 242 | }); 243 | var params = { 244 | TableName: config.tables.jobtable, 245 | KeyConditionExpression: '#documentid = :docid', 246 | ExpressionAttributeNames: { 247 | "#documentid": "documentId" 248 | }, 249 | ExpressionAttributeValues: { 250 | ":docid": { "S" : docid}, 251 | } 252 | }; 253 | db.query(params, function(err, data) { 254 | if (err) { 255 | console.log(err); 256 | return null; 257 | } else { 258 | 259 | 
console.log('Got data'); 260 | console.log(data); 261 | 262 | var jobStatus = ''; 263 | for (var i in data['Items']) { 264 | // read the values from the dynamodb JSON packet 265 | jobStatus = data['Items'][i]['jobStatus']['S']; 266 | console.log(jobStatus); 267 | if(jobStatus.includes("SUCCEEDED")) { 268 | setIsJobdone(true); 269 | } 270 | } 271 | if(jobStatus.includes("SUCCEEDED")) { 272 | console.log("PDF extraction done") 273 | toast.success("PDF extraction done") 274 | getOutputPath(); 275 | } 276 | else { 277 | toast.info("Checking job status every 30 seconds...") 278 | setTimeout(() => { checkJobStatus(); }, 30000); 279 | } 280 | } 281 | }) 282 | }); 283 | } 284 | 285 | function checkEJobStatus() { 286 | Auth.currentCredentials() 287 | .then(credentials => { 288 | const db= new DynamoDB({ 289 | region: config.content.REGION, 290 | credentials: Auth.essentialCredentials(credentials) 291 | }); 292 | var params = { 293 | TableName: config.tables.ejobtable, 294 | KeyConditionExpression: '#documentid = :docid', 295 | ExpressionAttributeNames: { 296 | "#documentid": "documentId" 297 | }, 298 | ExpressionAttributeValues: { 299 | ":docid": { "S" : docid}, 300 | } 301 | }; 302 | db.query(params, function(err, data) { 303 | if (err) { 304 | console.log(err); 305 | return null; 306 | } else { 307 | 308 | console.log('Got data'); 309 | console.log(data); 310 | 311 | var jobStatus = ''; 312 | for (var i in data['Items']) { 313 | // read the values from the dynamodb JSON packet 314 | jobStatus = data['Items'][i]['jobStatus']['S']; 315 | console.log(jobStatus); 316 | if(jobStatus.includes("Complete")) { 317 | setIsEjobdone(true); 318 | } 319 | } 320 | if(jobStatus.includes("Complete")) { 321 | console.log("Embeddings done") 322 | toast.success("Embedding generation done") 323 | } 324 | else { 325 | toast.info("Checking job status every 30 seconds...") 326 | setTimeout(() => { checkEJobStatus(); }, 30000); 327 | } 328 | } 329 | }) 330 | }); 331 | } 332 | 333 | function 
getOutputPath() { 334 | Auth.currentCredentials() 335 | .then(credentials => { 336 | const db= new DynamoDB({ 337 | region: config.content.REGION, 338 | credentials: Auth.essentialCredentials(credentials) 339 | }); 340 | var params = { 341 | TableName: config.tables.outputtable, 342 | KeyConditionExpression: '#documentid = :docid AND #outputtype = :otype', 343 | ExpressionAttributeNames: { 344 | "#documentid": "documentId", 345 | "#outputtype": "outputType" 346 | }, 347 | ExpressionAttributeValues: { 348 | ":docid": { "S" : docid}, 349 | ":otype": { "S" : "ResponseOrderedText"} 350 | } 351 | }; 352 | db.query(params, function(err, data) { 353 | if (err) { 354 | console.log(err); 355 | return null; 356 | } else { 357 | 358 | console.log('Got data'); 359 | console.log(data); 360 | 361 | for (var i in data['Items']) { 362 | // read the values from the dynamodb JSON packet 363 | var opath = data['Items'][i]['outputPath']['S']; 364 | console.log("Output path: " + opath); 365 | setEname(opath); 366 | downloadExtract(opath); 367 | } 368 | } 369 | }) 370 | }); 371 | } 372 | 373 | function checkSummarizationStatus(sumjobid) { 374 | Auth.currentCredentials() 375 | .then(credentials => { 376 | const db= new DynamoDB({ 377 | region: config.content.REGION, 378 | credentials: Auth.essentialCredentials(credentials) 379 | }); 380 | var params = { 381 | TableName: config.tables.sumtable, 382 | KeyConditionExpression: '#documentid = :docid AND #jobid = :jobidvalue', 383 | ExpressionAttributeNames: { 384 | "#documentid": "documentId", 385 | "#jobid": "jobId" 386 | }, 387 | ExpressionAttributeValues: { 388 | ":docid": { "S" : docid}, 389 | ":jobidvalue": { "S" : sumjobid}, 390 | } 391 | }; 392 | db.query(params, function(err, data) { 393 | if (err) { 394 | console.log(err); 395 | return null; 396 | } else { 397 | 398 | console.log('Got data'); 399 | console.log(data); 400 | 401 | var jobStatus = ''; 402 | for (var i in data['Items']) { 403 | // read the values from the dynamodb JSON 
packet 404 | jobStatus = data['Items'][i]['jobStatus']['S']; 405 | console.log(jobStatus); 406 | if (jobStatus.includes("Complete")) { 407 | var stext = data['Items'][i]['summaryText']['S']; 408 | console.log("Summary: " + stext) 409 | setSummaryText(stext); 410 | setIsSjobdone(true); 411 | } 412 | } 413 | if (jobStatus.includes("Complete")) { 414 | console.log("Summarization done") 415 | toast.success("Summarization done") 416 | } 417 | else { 418 | toast.info("Checking job status every 30 seconds...") 419 | setTimeout(() => { checkSummarizationStatus(sumjobid); }, 30000); 420 | } 421 | } 422 | }) 423 | }); 424 | } 425 | 426 | function changeQaq(e) { 427 | setQaQ(e.target.value); 428 | } 429 | function changeChunkSize(e) { 430 | setChunkSize(e.target.value); 431 | } 432 | function changeChunkOverlap(e) { 433 | setChunkOverlap(e.target.value); 434 | } 435 | function changeTopP(e) { 436 | setTopP(e.target.value); 437 | } 438 | function changeTopK(e) { 439 | setTopK(e.target.value); 440 | } 441 | function changeNumBeams(e) { 442 | setNumBeams(e.target.value); 443 | } 444 | function changeTemperature(e) { 445 | setTemperature(e.target.value); 446 | } 447 | function changeMaxLength(e) { 448 | setMaxLength(e.target.value); 449 | } 450 | 451 | return ( 452 | 453 | 454 | 455 | 456 |
457 | logo 458 |
459 |
460 |

This application lets you upload a PDF, convert it to text, summarize it, and ask questions about it.

461 |
462 |
463 | Upload file: 464 |
465 |
466 | Document id: {docid} 467 |
468 |
469 | 470 | 471 | 472 |
473 |
474 | 475 | 476 | 477 |
478 | 479 |

Extraction job id: {jobid}

480 | {dname !== '' && 481 | Download extracted text 482 | }

484 | 485 |

Embedding job id: {ejobid}

486 |
487 |
488 | 495 |

496 | 497 |
498 | 501 |

502 | 505 |

506 | 509 |

510 | 513 |

514 | 517 |

518 | 521 |

522 | 525 |
526 |
527 |
528 |
529 |

Question answering

530 | 537 |

538 | 539 |
540 | 541 |

542 |

543 | 544 |
545 |
546 |
547 |
548 |

Summary

549 | 556 |

557 | 558 |
559 |

560 |

Summarization job id: {sjobid}

561 | 562 |
563 |
564 |
565 |
566 | 567 |
568 |
569 | ); 570 | } 571 | 572 | export default withAuthenticator(App); 573 | -------------------------------------------------------------------------------- /frontend/src/App.test.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | 6 | import { render, screen } from '@testing-library/react'; 7 | import App from './App'; 8 | 9 | test('renders application description', () => { 10 | render(<App />); 11 | const linkElement = screen.getByText(/react-based application/i); 12 | expect(linkElement).toBeInTheDocument(); 13 | }); 14 | -------------------------------------------------------------------------------- /frontend/src/config.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | export default { 6 | apiGateway: { 7 | REGION: "", 8 | URL: "" 9 | }, 10 | cognito: { 11 | REGION: "", 12 | USER_POOL_ID: "", 13 | APP_CLIENT_ID: "", 14 | IDENTITY_POOL_ID: "", 15 | }, 16 | content: { 17 | bucket: "cdkstack-documentsbucket9ec9deb9-t7lgk3l18gaa", 18 | REGION: "", 19 | prefix: "uploads/" 20 | }, 21 | tables: { 22 | jobtable: "", 23 | ejobtable: "", 24 | outputtable: "", 25 | sumtable: "" 26 | } 27 | }; 28 | 29 | -------------------------------------------------------------------------------- /frontend/src/index.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | body { 6 | margin: 0; 7 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 8 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 9 | sans-serif; 10 | -webkit-font-smoothing: antialiased; 11 | -moz-osx-font-smoothing: grayscale; 12 | } 13 | 14 | code { 15 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 16 | monospace; 17 | } 18 | -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | import React from 'react'; 6 | import ReactDOM from 'react-dom'; 7 | import './index.css'; 8 | import App from './App'; 9 | import { Amplify, Auth, Storage, API } from "aws-amplify"; 10 | import config from "./config"; 11 | import { ToastContainer, toast } from 'react-toastify'; 12 | 13 | //Amplify.Logger.LOG_LEVEL = 'DEBUG'; 14 | Amplify.configure({ 15 | Auth: { 16 | mandatorySignIn: true, 17 | region: config.cognito.REGION, 18 | userPoolId: config.cognito.USER_POOL_ID, 19 | identityPoolId: config.cognito.IDENTITY_POOL_ID, 20 | userPoolWebClientId: config.cognito.APP_CLIENT_ID 21 | }, 22 | Storage: { 23 | AWSS3: { 24 | bucket: config.content.bucket, 25 | region: config.content.REGION, 26 | }, 27 | customPrefix: { 28 | public: config.content.prefix, 29 | protected: config.content.prefix, 30 | private: config.content.prefix, 31 | }, 32 | }, 33 | API: { 34 | endpoints: [ 35 | { 36 | name: "docs", 37 | endpoint: config.apiGateway.URL, 38 | region: config.apiGateway.REGION 39 | }, 40 | ] 41 | } 42 | }); 43 | 44 | ReactDOM.render( 45 | 46 | 47 | , 48 | document.getElementById('root') 49 | ); 50 | 51 | -------------------------------------------------------------------------------- /frontend/src/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/frontend/src/logo.png -------------------------------------------------------------------------------- /frontend/src/setupTests.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * SPDX-License-Identifier: MIT-0 4 | */ 5 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 6 | // allows you to do things like: 7 | // expect(element).toHaveTextContent(/react/i) 8 | // learn more: https://github.com/testing-library/jest-dom 9 | import '@testing-library/jest-dom'; 10 | -------------------------------------------------------------------------------- /screenshots/summarization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/question-answering-large-documents/e683efb9186682fcf8e4c8e76465c2792efb834d/screenshots/summarization.png -------------------------------------------------------------------------------- /scripts/create-user.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | # SPDX-License-Identifier: MIT-0 4 | 5 | USERNAME=$1 6 | EMAIL=$2 7 | PASSWORD=$3 8 | USERPOOLID=$4 9 | CLIENTID=$5 10 | GROUPNAME=$6 11 | 12 | 13 | if [ "$USERNAME" == "" ] 14 | then 15 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 16 | exit 1 17 | fi 18 | if [ "$EMAIL" == "" ] 19 | then 20 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 21 | exit 1 22 | fi 23 | if [ "$PASSWORD" == "" ] 24 | then 25 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 26 | exit 1 27 | fi 28 | if [ "$USERPOOLID" == "" ] 29 | then 30 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 31 | exit 1 32 | fi 33 | if [ "$CLIENTID" == "" ] 34 | then 35 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 36 | exit 1 37 | fi 38 | if [ "$GROUPNAME" == "" ] 39 | then 40 | echo "Usage: $0 <username> <email> <password> <userpoolid> <clientid> <groupname>" 41 | exit 1 42 | fi 43 | 44 | aws cognito-idp update-user-pool-client --user-pool-id ${USERPOOLID} --client-id ${CLIENTID} --explicit-auth-flows ADMIN_NO_SRP_AUTH 45 | 46 | aws cognito-idp sign-up --client-id ${CLIENTID} --username ${USERNAME} --password ${PASSWORD} --user-attributes "[ { \"Name\": \"email\", \"Value\": \"$EMAIL\" }, { \"Name\": \"phone_number\", \"Value\": \"+12485551212\" }]" 47 | 48 | aws cognito-idp admin-confirm-sign-up --user-pool-id ${USERPOOLID} --username ${USERNAME} 49 | 50 | cat << EOF > /tmp/authflow.json 51 | { "AuthFlow": "ADMIN_NO_SRP_AUTH", "AuthParameters": { "USERNAME": "${USERNAME}", "PASSWORD": "${PASSWORD}" } } 52 | EOF 53 | 54 | JWT_ID_TOKEN=$(aws cognito-idp admin-initiate-auth --user-pool-id ${USERPOOLID} --client-id ${CLIENTID} --cli-input-json file:///tmp/authflow.json --query AuthenticationResult.IdToken --output text) 55 | 56 | aws cognito-idp admin-add-user-to-group --user-pool-id ${USERPOOLID} --username $USERNAME --group-name $GROUPNAME
--------------------------------------------------------------------------------