├── .editorconfig ├── .github └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .infrastructure └── attp-bootstrap.cfn.yaml ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── CUSTOMIZATION_GUIDE.md ├── LICENSE ├── LICENSE-SUMMARY ├── README.md ├── annotation ├── __init__.py ├── fn-SMGT-Post │ ├── data_model.py │ ├── main.py │ └── smgt.py └── fn-SMGT-Pre │ └── main.py ├── cdk.json ├── cdk_app.py ├── cdk_demo_stack.py ├── img ├── annotation-example-trimmed.png ├── architecture-overview.png ├── human-review-sample.png └── sfn-execution-screenshot.png ├── notebooks ├── 1. Data Preparation.ipynb ├── 2. Model Training.ipynb ├── 3. Human Review.ipynb ├── Optional Extras.ipynb ├── Workshop.ipynb ├── annotation │ └── ocr-bbox-and-validation.liquid.tpl.html ├── custom-containers │ ├── preproc │ │ └── Dockerfile │ └── train-inf │ │ └── Dockerfile ├── data │ └── annotations │ │ ├── LICENSE │ │ ├── augmentation-1 │ │ └── manifests │ │ │ └── output │ │ │ └── output.manifest │ │ └── augmentation-2 │ │ └── manifests │ │ └── output │ │ └── output.manifest ├── img │ ├── a2i-custom-template-demo.png │ ├── cfn-stack-outputs-a2i.png │ ├── sfn-execution-status-screenshot.png │ ├── sfn-statemachine-screenshot.png │ ├── sfn-statemachine-success.png │ ├── smgt-custom-template-demo.png │ ├── smgt-find-workforce-url.png │ ├── smgt-private-workforce.png │ ├── smgt-task-pending.png │ ├── ssm-a2i-param-detail.png │ └── ssm-param-detail-screenshot.png ├── preproc │ ├── __init__.py │ ├── ocr.py │ ├── preproc.py │ └── textract_transformers │ │ ├── __init__.py │ │ ├── file_utils.py │ │ ├── image_utils.py │ │ ├── ocr.py │ │ ├── ocr_engines │ │ ├── __init__.py │ │ ├── base.py │ │ └── eng_tesseract.py │ │ └── preproc.py ├── review │ ├── .eslintrc.cjs │ ├── .gitignore │ ├── .prettierrc.yml │ ├── README.md │ ├── env.d.ts │ ├── fields-validation-legacy.liquid.html │ ├── index-noliquid.html │ ├── index.html │ ├── package-lock.json │ ├── package.json │ ├── public │ │ └── .gitkeep │ ├── src │ │ ├── App.vue │ │ ├── assets │ │ │ └── base.scss │ │ ├── components │ │ │ ├── FieldMultiValue.ce.vue │ │ │ ├── FieldSingleValue.ce │ │ │ │ ├── FieldSingleValue.ce.vue │ │ │ │ └── index.ts │ │ │ ├── HelloWorld.vue │ │ │ ├── MultiFieldValue.ce.vue │ │ │ ├── ObjectValueInput.ts │ │ │ ├── PdfPageAnnotationLayer.ce.vue │ │ │ └── Viewer.ce.vue │ │ ├── main.ts │ │ └── util │ │ │ ├── colors.ts │ │ │ ├── model.d.ts │ │ │ └── store.ts │ ├── task-input.example.json │ ├── tsconfig.json │ └── vite.config.js ├── src │ ├── __init__.py │ ├── code │ │ ├── __init__.py │ │ ├── config.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── geometry.py │ │ │ ├── mlm.py │ │ │ ├── ner.py │ │ │ ├── seq2seq │ │ │ │ ├── __init__.py │ │ │ │ ├── date_normalization.py │ │ │ │ ├── metrics.py │ │ │ │ └── task_builder.py │ │ │ ├── smgt.py │ │ │ └── splitting.py │ │ ├── inference.py │ │ ├── inference_seq2seq.py │ │ ├── logging_utils.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── layoutlmv2.py │ │ ├── smddpfix.py │ │ └── train.py │ ├── ddp_launcher.py │ ├── inference.py │ ├── inference_seq2seq.py │ ├── requirements.txt │ ├── smtc_launcher.py │ └── train.py └── util │ ├── __init__.py │ ├── deployment.py │ ├── ocr.py │ ├── postproc │ ├── preproc.py │ ├── project.py │ ├── s3.py │ ├── smgt.py │ ├── training.py │ ├── uid.py │ └── viz.py ├── package-lock.json ├── package.json ├── pipeline ├── __init__.py ├── config_utils.py ├── enrichment │ └── __init__.py ├── fn-trigger │ ├── main.py │ └── requirements.txt ├── iam_utils.py ├── ocr │ ├── 
__init__.py │ ├── fn-call-textract │ │ └── main.py │ ├── sagemaker_ocr.py │ ├── sfn_semaphore │ │ ├── __init__.py │ │ └── fn-acquire-lock │ │ │ └── main.py │ └── textract_ocr.py ├── postprocessing │ ├── __init__.py │ └── fn-postprocess │ │ ├── main.py │ │ ├── requirements.txt │ │ └── util │ │ ├── __init__.py │ │ ├── boxes.py │ │ ├── config.py │ │ ├── deser.py │ │ ├── extract.py │ │ └── normalize.py ├── review │ ├── __init__.py │ ├── fn-review-callback │ │ └── main.py │ └── fn-start-review │ │ └── main.py ├── shared │ ├── __init__.py │ └── sagemaker │ │ ├── __init__.py │ │ ├── fn-call-sagemaker │ │ ├── main.py │ │ └── requirements.txt │ │ ├── model_deployment.py │ │ └── sagemaker_sfn.py └── thumbnails │ └── __init__.py ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── setup.py └── source.bat /.editorconfig: -------------------------------------------------------------------------------- 1 | # Top-most EditorConfig file: 2 | root = true 3 | 4 | [*] 5 | charset = utf-8 6 | end_of_line = lf 7 | insert_final_newline = true 8 | trim_trailing_whitespace = true 9 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **Issue #, if available:** 2 | 3 | **Description of changes:** 4 | 5 | **Testing done:** 6 | 7 | 8 |
9 | 10 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # AWS SAM build artifacts 2 | .aws-sam/ 3 | *.tmp.* 4 | 5 | # AWS CDK stuff 6 | *.swp 7 | .cdk.staging 8 | cdk.out 9 | 10 | # SageMaker Experiments/Debugger (if try running locally): 11 | tmp_trainer/ 12 | 13 | # Working data folders and notebook-built assets: 14 | # (With some specific exclusions) 15 | notebooks/data/* 16 | !notebooks/data/annotations/ 17 | notebooks/data/annotations/* 18 | !notebooks/data/annotations/augmentation-* 19 | !notebooks/data/annotations/LICENSE 20 | notebooks/annotation/*.html 21 | !notebooks/annotation/*.tpl.html 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # Checks and reporting 128 | /*report* 129 | 130 | # Operating Systems 131 | .DS_Store 132 | 133 | # NodeJS 134 | **/node_modules/** 135 | .nvmrc 136 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": ".venv/bin/python3", 3 | "yaml.customTags": [ 4 | "!And", 5 | "!And sequence", 6 | "!If", 7 | "!If sequence", 8 | "!Not", 9 | "!Not sequence", 10 | "!Equals", 11 | "!Equals sequence", 12 | "!Or", 13 | "!Or sequence", 14 | "!FindInMap", 15 | "!FindInMap sequence", 16 | "!Base64", 17 | "!Join", 18 | "!Join sequence", 19 | "!Cidr", 20 | "!Ref", 21 | "!Sub", 22 | "!Sub sequence", 23 | "!GetAtt", 24 | "!GetAZs", 25 | "!ImportValue", 26 | "!ImportValue sequence", 27 | "!Select", 28 | "!Select sequence", 29 | "!Split", 30 | "!Split sequence" 31 | ] 32 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | 3 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 4 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 5 | opensource-codeofconduct@amazon.com with any additional questions or comments. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. 
You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /LICENSE-SUMMARY: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | 3 | This library is licensed under the MIT-0 License. See `LICENSE`. 4 | 5 | Included datasets are licensed under the Creative Commons Attribution 4.0 6 | International License. See `LICENSE` in the `notebooks/data/annotations` 7 | subfolder. 8 | -------------------------------------------------------------------------------- /annotation/fn-SMGT-Post/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Annotation consolidation Lambda for BBoxes+transcriptions in SageMaker Ground Truth 4 | """ 5 | # Python Built-Ins: 6 | import json 7 | import logging 8 | from typing import List, Optional 9 | 10 | # External Dependencies: 11 | import boto3 # AWS SDK for Python 12 | 13 | # Set up logger before local imports: 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.INFO) 16 | 17 | # Local Dependencies: 18 | from data_model import SMGTWorkerAnnotation # Custom task data model (edit if needed!) 19 | from smgt import ( # Generic SageMaker Ground Truth parsers/utilities 20 | ConsolidationRequest, 21 | ObjectAnnotationResult, 22 | PostConsolidationDatum, 23 | ) 24 | 25 | 26 | s3 = boto3.client("s3") 27 | 28 | 29 | def consolidate_object_annotations( 30 | object_data: ObjectAnnotationResult, 31 | label_attribute_name: str, 32 | label_categories: Optional[List[str]] = None, 33 | ) -> PostConsolidationDatum: 34 | """Consolidate the (potentially multiple) raw worker annotations for a dataset object 35 | 36 | TODO: Actual consolidation/reconciliation of multiple labels is not yet supported! 37 | 38 | This function just takes the "first" (not necessarily clock-first) worker's result and outputs 39 | a warning if others were found. 40 | 41 | Parameters 42 | ---------- 43 | object_data : 44 | Object describing the raw annotations and metadata for a particular task in the SMGT job 45 | label_attribute_name : 46 | Target attribute on the output object to store consolidated label results (note this may 47 | not be the *only* attribute set/updated on the output object, hence provided as a param 48 | rather than abstracted away). 49 | label_categories : 50 | Label categories specified when creating the labelling job. If provided, this is used to 51 | translate from class names to numeric class_id similarly to SMGT's built-in bounding box 52 | task result. 53 | """ 54 | warn_msgs: List[str] = [] 55 | worker_anns: List[SMGTWorkerAnnotation] = [] 56 | for worker_ann in object_data.annotations: 57 | ann_raw = worker_ann.fetch_data() 58 | worker_anns.append(SMGTWorkerAnnotation.parse(ann_raw, class_list=label_categories)) 59 | 60 | if len(worker_anns) > 1: 61 | warn_msg = ( 62 | "Reconciliation of multiple worker annotations is not currently implemented for this " 63 | "post-processor. 
Outputting annotation from worker %s and ignoring labels from %s" 64 | % ( 65 | object_data.annotations[0].worker_id, 66 | [a.worker_id for a in object_data.annotations[1:]], 67 | ) 68 | ) 69 | logger.warning(warn_msg) 70 | warn_msgs.append(warn_msg) 71 | 72 | consolidated_label = worker_anns[0].to_jsonable() 73 | if len(warn_msgs): 74 | consolidated_label["consolidationWarnings"] = warn_msgs 75 | 76 | return PostConsolidationDatum( 77 | dataset_object_id=object_data.dataset_object_id, 78 | consolidated_content={ 79 | label_attribute_name: consolidated_label, 80 | # Note: In our tests it's not possible to add a f"{label_attribute_name}-meta" field 81 | # here - it gets replaced by whatever post-processing happens, instead of merged. 82 | }, 83 | ) 84 | 85 | 86 | def handler(event: dict, context) -> List[dict]: 87 | """Main Lambda handler for consolidation of SMGT worker annotations 88 | 89 | This function receives a batched request to consolidate (multiple?) workers' annotations for 90 | multiple objects, and outputs the consolidated results per object. For more docs see: 91 | 92 | https://docs.aws.amazon.com/sagemaker/latest/dg/sms-custom-templates-step3-lambda-requirements.html 93 | """ 94 | logger.info("Received event: %s", json.dumps(event)) 95 | req = ConsolidationRequest.parse(event) 96 | if req.label_categories and len(req.label_categories) > 0: 97 | label_cats = req.label_categories 98 | else: 99 | logger.warning( 100 | "Label categories list (see CreateLabelingJob.LabelCategoryConfigS3Uri) was not " 101 | "provided when creating this job. Post-consolidation outputs will be incompatible with " 102 | "built-in Bounding Box task, because we're unable to map class names to numeric IDs." 103 | ) 104 | label_cats = None 105 | 106 | # Loop through the objects in this batch, consolidating annotations for each: 107 | return [ 108 | consolidate_object_annotations( 109 | object_data, 110 | label_attribute_name=req.label_attribute_name, 111 | label_categories=label_cats, 112 | ).to_jsonable() 113 | for object_data in req.fetch_object_annotations() 114 | ] 115 | -------------------------------------------------------------------------------- /annotation/fn-SMGT-Pre/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """A minimal Lambda function for pre-processing SageMaker Ground Truth custom annotation tasks 4 | 5 | Just passes through the event's `dataObject` unchanged. 6 | """ 7 | import logging 8 | 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | def handler(event, context): 14 | logger.debug("Got event: %s", event) 15 | result = {"taskInput": event["dataObject"]} 16 | logger.debug("Returning result: %s", result) 17 | return result 18 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 cdk_app.py", 3 | "context": { 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /cdk_app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | # SPDX-License-Identifier: MIT-0 4 | """AWS CDK app entry point for OCR pipeline sample 5 | """ 6 | # Python Built-Ins: 7 | import json 8 | import os 9 | 10 | # External Dependencies: 11 | from aws_cdk import App 12 | 13 | # Local Dependencies: 14 | from cdk_demo_stack import PipelineDemoStack 15 | from pipeline.config_utils import bool_env_var, list_env_var 16 | 17 | 18 | # Top-level configurations are loaded from environment variables at the point `cdk synth` or 19 | # `cdk deploy` is run (or you can override here): 20 | config = { 21 | # Used as a prefix for some cloud resources e.g. SSM parameters: 22 | "default_project_id": os.environ.get("DEFAULT_PROJECT_ID", default="ocr-transformers-demo"), 23 | 24 | # Set False to skip deploying the page thumbnail image generator, if you're only using models 25 | # (like LayoutLMv1) that don't take page image as input features: 26 | "use_thumbnails": bool_env_var("USE_THUMBNAILS", default=True), 27 | 28 | # Set True to enable auto-scale-to-zero on auto-deployed SageMaker endpoints (including the 29 | # thumbnail generator and any custom OCR engines). This saves costs for low-volume workloads, 30 | # but introduces a few minutes' extra cold start for requests when all instances are released: 31 | "enable_sagemaker_autoscaling": bool_env_var("ENABLE_SM_AUTOSCALING", default=False), 32 | 33 | # To use alternative Tesseract OCR instead of Amazon Textract, before running `cdk deploy` run: 34 | # export BUILD_SM_OCRS=tesseract 35 | # export DEPLOY_SM_OCRS=tesseract 36 | # export USE_SM_OCR=tesseract 37 | # ...Or edit the defaults below to `["tesseract"]` and `"tesseract"` 38 | "build_sagemaker_ocrs": list_env_var("BUILD_SM_OCRS", default=[]), 39 | "deploy_sagemaker_ocrs": list_env_var("DEPLOY_SM_OCRS", default=[]), 40 | "use_sagemaker_ocr": os.environ.get("USE_SM_OCR", default=None), 41 | } 42 | 43 | app = App() 44 | print(f"Deploying stack with configuration:\n{json.dumps(config, indent=2)}") 45 | demo_stack = PipelineDemoStack( 46 | app, 47 | "OCRPipelineDemo", 48 | **config, 49 | ) 50 | app.synth() 51 | -------------------------------------------------------------------------------- /img/annotation-example-trimmed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/annotation-example-trimmed.png -------------------------------------------------------------------------------- /img/architecture-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/architecture-overview.png -------------------------------------------------------------------------------- /img/human-review-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/human-review-sample.png -------------------------------------------------------------------------------- /img/sfn-execution-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/sfn-execution-screenshot.png 
-------------------------------------------------------------------------------- /notebooks/custom-containers/preproc/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | #### Container image with document/image processing (and optionally OCR) tools added. 5 | 6 | ARG BASE_IMAGE 7 | FROM ${BASE_IMAGE} 8 | 9 | # Common/base doc & image processing tools: 10 | RUN conda install -c conda-forge poppler -y \ 11 | && pip install amazon-textract-response-parser pdf2image "Pillow>=8,<9" 12 | 13 | # Optional OCR engine: Tesseract+PyTesseract 14 | # conda tesseract already includes Leptonica dependency and multi-language tessdata files by default 15 | # (but didn't set the required TESSDATA_PREFIX variable at time of writing) 16 | ARG INCLUDE_OCR_TESSERACT 17 | RUN if test -z "$INCLUDE_OCR_TESSERACT" ; \ 18 | then \ 19 | echo Skipping OCR engine Tesseract \ 20 | ; else \ 21 | conda install -y -c conda-forge tesseract && \ 22 | pip install pytesseract && \ 23 | export TESSDATA_PREFIX='/opt/conda/share/tessdata' \ 24 | ; fi 25 | -------------------------------------------------------------------------------- /notebooks/custom-containers/train-inf/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | # Container definition for Layout+language model training & inference on SageMaker 5 | 6 | ARG BASE_IMAGE 7 | FROM ${BASE_IMAGE} 8 | 9 | # Core dependencies: 10 | # - Pin PyTorch to prevent pip accidentally re-installing/upgrading it via detectron 11 | # - Pin setuptools per https://github.com/pytorch/pytorch/issues/69894#issuecomment-1080635462 12 | # - Pin protobuf < 3.21 due to an error like https://stackoverflow.com/q/72441758 as of 2023-02 13 | # (which seems to originate from somewhere in SM DDP package when unconstrained install results 14 | # in downloading protobuf@4.x) 15 | RUN PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` \ 16 | && pip install git+https://github.com/facebookresearch/detectron2.git setuptools==59.5.0 \ 17 | "amazon-textract-response-parser>=0.1,<0.2" "datasets[vision]>=2.14,<3" "Pillow>=9.4" \ 18 | "protobuf<3.21" torch==$PT_VER "torchvision>=0.15,<0.17" "transformers>=4.28,<4.29" 19 | 20 | # Could also consider installing detectron2 via pre-built Linux wheel, depending on the PyTorch and 21 | # CUDA versions of your base container: 22 | # https://github.com/aws/deep-learning-containers/tree/master/huggingface/pytorch 23 | # https://detectron2.readthedocs.io/en/latest/tutorials/install.html 24 | # 25 | # For example: 26 | # && pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html 27 | 28 | 29 | # Additional dependencies: 30 | # - pytesseract shouldn't be necessary after Transformers v4.18 (because we don't use Tesseract 31 | # OCR), but older versions have a bug: https://github.com/huggingface/transformers/issues/16845 32 | # - datasets 1.18 and torchvision 0.11 are installed in the HF training container but missing from 33 | # the inference container, and we need them for inference. 
Upgraded datasets to use some new 34 | # logging controls and debug multi-worker .map() pre-processing: 35 | RUN PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` \ 36 | && pip install pytesseract torch==$PT_VER 37 | 38 | 39 | # If you'd like to enable this container as a Custom Image for notebook kernels, for debugging in 40 | # SageMaker Studio, build it with INCLUDE_NOTEBOOK_KERNEL=1 arg to include IPython kernel and also 41 | # some other PDF processing + OCR utilities: 42 | ARG INCLUDE_NOTEBOOK_KERNEL 43 | RUN if test -z "$INCLUDE_NOTEBOOK_KERNEL" ; \ 44 | then \ 45 | echo Skipping notebook kernel dependencies \ 46 | ; else \ 47 | conda install -y -c conda-forge poppler tesseract && \ 48 | PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` && \ 49 | pip install easyocr ipykernel "ipywidgets>=8.1,<9" pdf2image pytesseract sagemaker \ 50 | torch==$PT_VER && \ 51 | export TESSDATA_PREFIX='/opt/conda/share/tessdata' && \ 52 | python -m ipykernel install --sys-prefix \ 53 | ; fi 54 | 55 | # We would like to disable SMDEBUG when running as a notebook kernel, because it can cause some 56 | # unwanted side-effects... But at the time of writing Dockerfile doesn't have full support for a 57 | # conditional env statement - so: 58 | # if --build-arg INCLUDE_NOTEBOOK_KERNEL=1, set USE_SMDEBUG to 'false', else set null. 59 | ENV USE_SMDEBUG=${INCLUDE_NOTEBOOK_KERNEL:+false} 60 | # ...But '' will cause problems in SM Training, default empty value to 'true' instead (which should 61 | # be the default per: 62 | # https://github.com/awslabs/sagemaker-debugger/blob/56fabe531692403e77ce9b5879d55211adec238e/smdebug/core/config_validator.py#L21 63 | ENV USE_SMDEBUG=${USE_SMDEBUG:-true} 64 | 65 | # See below guidance for adding an image built with INCLUDE_NOTEBOOK_KERNEL to SMStudio: 66 | # https://docs.aws.amazon.com/sagemaker/latest/dg/studio-byoi.html 67 | # https://github.com/aws-samples/sagemaker-studio-custom-image-samples 68 | # 69 | # An image config something like the following should work: 70 | # { 71 | # "KernelSpecs": [ 72 | # { 73 | # "Name": "python3", 74 | # "DisplayName": "Textract Transformers" 75 | # }, 76 | # ], 77 | # "FileSystemConfig": { 78 | # "MountPath": "/root/data", 79 | # "DefaultUid": 0, 80 | # "DefaultGid": 0 81 | # } 82 | # } 83 | -------------------------------------------------------------------------------- /notebooks/img/a2i-custom-template-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/a2i-custom-template-demo.png -------------------------------------------------------------------------------- /notebooks/img/cfn-stack-outputs-a2i.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/cfn-stack-outputs-a2i.png -------------------------------------------------------------------------------- /notebooks/img/sfn-execution-status-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-execution-status-screenshot.png -------------------------------------------------------------------------------- 
/notebooks/img/sfn-statemachine-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-statemachine-screenshot.png -------------------------------------------------------------------------------- /notebooks/img/sfn-statemachine-success.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-statemachine-success.png -------------------------------------------------------------------------------- /notebooks/img/smgt-custom-template-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-custom-template-demo.png -------------------------------------------------------------------------------- /notebooks/img/smgt-find-workforce-url.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-find-workforce-url.png -------------------------------------------------------------------------------- /notebooks/img/smgt-private-workforce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-private-workforce.png -------------------------------------------------------------------------------- /notebooks/img/smgt-task-pending.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-task-pending.png -------------------------------------------------------------------------------- /notebooks/img/ssm-a2i-param-detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/ssm-a2i-param-detail.png -------------------------------------------------------------------------------- /notebooks/img/ssm-param-detail-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/ssm-param-detail-screenshot.png -------------------------------------------------------------------------------- /notebooks/preproc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """OCR and image pre-processing utilities for document understanding projects with SageMaker 4 | 5 | This top-level __init__.py is not necessary for SageMaker processing/training/etc jobs, but provided 6 | in case you want to `from preproc import textract_transformers` from the notebooks to experiment. 
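For example, from a notebook running in the `notebooks/` folder you could try something like the following (a minimal sketch - `data/raw/example.pdf` is a hypothetical path, and only the submodules you explicitly import get loaded):

    from preproc.textract_transformers import file_utils, image_utils

    print(file_utils.ls_relpaths("data/raw"))  # List candidate input documents
    doc = image_utils.Document("data/raw/example.pdf")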
7 | """ 8 | -------------------------------------------------------------------------------- /notebooks/preproc/ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Top-level entrypoint script for alternative OCR engines 4 | """ 5 | # Python Built-Ins: 6 | import logging 7 | import os 8 | import sys 9 | 10 | 11 | def run_main(): 12 | """Configure logging, import local modules and run the job""" 13 | consolehandler = logging.StreamHandler(sys.stdout) 14 | consolehandler.setFormatter( 15 | logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s") 16 | ) 17 | logging.basicConfig(handlers=[consolehandler], level=os.environ.get("LOG_LEVEL", logging.INFO)) 18 | 19 | from textract_transformers.ocr import main 20 | 21 | return main() 22 | 23 | 24 | if __name__ == "__main__": 25 | # If the file is running as a script, we're in batch processing mode and should run the batch 26 | # routine (with a little logging setup before any imports, to make sure output shows up ok): 27 | run_main() 28 | else: 29 | # If the file is imported as a module, we're in inference mode and should pass through the 30 | # override functions defined in the inference module. 31 | from textract_transformers.ocr import * 32 | -------------------------------------------------------------------------------- /notebooks/preproc/preproc.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Top-level entrypoint script for image/document pre-processing 4 | """ 5 | # Python Built-Ins: 6 | import logging 7 | import os 8 | import sys 9 | 10 | 11 | def run_main(): 12 | """Configure logging, import local modules and run the job""" 13 | consolehandler = logging.StreamHandler(sys.stdout) 14 | consolehandler.setFormatter( 15 | logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s") 16 | ) 17 | logging.basicConfig(handlers=[consolehandler], level=os.environ.get("LOG_LEVEL", logging.INFO)) 18 | 19 | from textract_transformers.preproc import main 20 | 21 | return main() 22 | 23 | 24 | if __name__ == "__main__": 25 | # If the file is running as a script, we're in batch processing mode and should run the batch 26 | # routine (with a little logging setup before any imports, to make sure output shows up ok): 27 | run_main() 28 | else: 29 | # If the file is imported as a module, we're in inference mode and should pass through the 30 | # override functions defined in the inference module. 31 | from textract_transformers.preproc import * 32 | -------------------------------------------------------------------------------- /notebooks/preproc/textract_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """OCR and image pre-processing utilities for document understanding projects with SageMaker 4 | """ 5 | -------------------------------------------------------------------------------- /notebooks/preproc/textract_transformers/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Utilities for working with filesystem names and paths 4 | """ 5 | # Python Built-Ins: 6 | import os 7 | from typing import List, Tuple 8 | 9 | 10 | def split_filename(filename: str) -> Tuple[str, str]: 11 | """Split a filename into base name and extension 12 | 13 | This basic method does NOT currently account for 2-part extensions e.g. '.tar.gz' 14 | """ 15 | basename, _, ext = filename.rpartition(".") 16 | return basename, ext 17 | 18 | 19 | def ls_relpaths(path: str, exclude_hidden: bool = True, sort: bool = True) -> List[str]: 20 | """Recursively list folder contents, sorting and excluding hidden files by default 21 | 22 | Parameters 23 | ---------- 24 | path : 25 | Folder to be walked 26 | exclude_hidden : 27 | By default (True), exclude any files beginning with '.' or folders beginning with '.' 28 | sort : 29 | By default (True), sort result paths in alphabetical order. If False, results will be 30 | randomly ordered as per os.walk(). 31 | 32 | Returns 33 | ------- 34 | results : 35 | *Relative* file paths under the provided folder 36 | """ 37 | if path.endswith("/"): 38 | path = path[:-1] 39 | result = [ 40 | os.path.join(currpath, name)[len(path) + 1 :] # +1 for trailing '/' 41 | for currpath, dirs, files in os.walk(path) 42 | for name in files 43 | ] 44 | if exclude_hidden: 45 | result = filter( 46 | lambda f: not (f.startswith(".") or "/." in f), result # (Exclude hidden dot-files) 47 | ) 48 | if sort: 49 | return sorted(result) 50 | else: 51 | return list(result) 52 | -------------------------------------------------------------------------------- /notebooks/preproc/textract_transformers/ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Script to run open-source OCR engines in Amazon SageMaker 4 | """ 5 | 6 | # Python Built-Ins: 7 | import argparse 8 | from base64 import b64decode 9 | import json 10 | from logging import getLogger 11 | from multiprocessing import cpu_count, Pool 12 | import os 13 | import time 14 | from typing import Iterable, Optional, Tuple 15 | 16 | # External Dependencies: 17 | import boto3 18 | 19 | # Local Dependencies 20 | from . import ocr_engines 21 | from .file_utils import ls_relpaths 22 | from .image_utils import Document 23 | 24 | 25 | logger = getLogger("ocr") 26 | s3client = boto3.client("s3") 27 | 28 | # Environment variable configurations: 29 | # When running in SM Endpoint, we can't use the usual processing job command line argument pattern 30 | # to configure these extra parameters - so instead configure via environment variables for both. 31 | OCR_ENGINE = os.environ.get("OCR_ENGINE", "tesseract").lower() 32 | OCR_DEFAULT_LANGUAGES = os.environ.get("OCR_DEFAULT_LANGUAGES", "eng").lower().split(",") 33 | OCR_DEFAULT_DPI = int(os.environ.get("OCR_DEFAULT_DPI", "300")) 34 | 35 | 36 | def model_fn(model_dir: str): 37 | """OCR Engine loader: Load the configured engine into memory ready to use""" 38 | return ocr_engines.get(OCR_ENGINE, OCR_DEFAULT_LANGUAGES) 39 | 40 | 41 | def input_fn(input_bytes: bytes, content_type: str) -> Tuple[Document, Optional[Iterable[str]]]: 42 | """Deserialize real-time processing requests 43 | 44 | For binary data requests (image or document bytes), default settings will be used e.g. 45 | `OCR_DEFAULT_LANGUAGES`. 
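For example, invoking the endpoint with the raw bytes of a PNG image or PDF document (and the corresponding content type) runs OCR with the engine's default configuration.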
46 | 47 | For JSON requests, supported fields are: 48 | 49 | ``` 50 | { 51 | "Document": { 52 | "Bytes": Base64-encoded inline document/image, OR: 53 | "S3Object": { 54 | "Bucket": S3 bucket name for raw document/image 55 | "Name": S3 object key 56 | "VersionId": Optional S3 object version ID 57 | } 58 | }, 59 | "Languages": Optional List[str] override for OCR_DEFAULT_LANGUAGES language codes 60 | } 61 | ``` 62 | 63 | Returns 64 | ------- 65 | doc : 66 | Loaded `Document` from which page images may be accessed 67 | languages : 68 | Optional override list of language codes for OCR, otherwise None. 69 | """ 70 | logger.debug("Deserializing request of content_type %s", content_type) 71 | if content_type == "application/json": 72 | # Indirected request with metadata (e.g. language codes and S3 pointer): 73 | req = json.loads(input_bytes) 74 | doc_spec = req.get("Document", {}) 75 | if "Bytes" in doc_spec: 76 | doc_bytes = b64decode(doc_spec["Bytes"]) 77 | elif "S3Object" in doc_spec: 78 | s3_spec = doc_spec["S3Object"] 79 | if not ("Bucket" in s3_spec and "Name" in s3_spec): 80 | raise ValueError( 81 | "Document.S3Object must be an object with keys 'Bucket' and 'Name'. Got: %s" 82 | % s3_spec 83 | ) 84 | logger.info("Fetching s3://%s/%s ...", s3_spec["Bucket"], s3_spec["Name"]) 85 | version_id = s3_spec.get("Version") 86 | resp = s3client.get_object( 87 | Bucket=s3_spec["Bucket"], 88 | Key=s3_spec["Name"], 89 | **({} if version_id is None else {"VersionId": version_id}), 90 | ) 91 | doc_bytes = resp["Body"].read() 92 | content_type = resp["ContentType"] 93 | else: 94 | raise ValueError( 95 | "JSON requests must include 'Document' object containing either 'Bytes' or " 96 | "'S3Object'. Got %s" % req 97 | ) 98 | 99 | languages = req.get("Languages") 100 | else: 101 | # Direct image/document request: 102 | doc_bytes = input_bytes 103 | languages = None 104 | 105 | return ( 106 | Document(doc_bytes, ext_or_media_type=content_type, default_doc_dpi=OCR_DEFAULT_DPI), 107 | languages, 108 | ) 109 | 110 | 111 | def predict_fn( 112 | inputs: Tuple[Document, Optional[Iterable[str]]], engine: ocr_engines.BaseOCREngine 113 | ) -> dict: 114 | """Get OCR results for a single input document/image, for the requested language codes 115 | 116 | Returns 117 | ------- 118 | result : 119 | JSON-serializable OCR result dictionary, of format roughly compatible with Amazon Textract 120 | DetectDocumentText result payload. 
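For illustration, an Amazon Textract `DetectDocumentText` response (which this output approximates - the exact fields returned depend on the engine's serializer) looks roughly like:

    {
        "DocumentMetadata": {"Pages": 1},
        "Blocks": [
            {"BlockType": "PAGE", ...},
            {"BlockType": "LINE", "Text": "...", "Confidence": 99.0, ...},
            {"BlockType": "WORD", "Text": "...", "Confidence": 99.0, ...}
        ]
    }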
121 | """ 122 | return engine.process(inputs[0], languages=inputs[1]) 123 | 124 | 125 | # No output_fn required as we will always use JSON which the default serializer supports 126 | # def output_fn(prediction_output: Dict, accept: str) -> bytes: 127 | 128 | 129 | def parse_args() -> argparse.Namespace: 130 | """Parse SageMaker OCR Processing Job (batch) CLI arguments to job parameters""" 131 | parser = argparse.ArgumentParser( 132 | description="OCR documents in batch using an alternative (non-Amazon-Textract) engine" 133 | ) 134 | parser.add_argument( 135 | "--input", 136 | type=str, 137 | default="/opt/ml/processing/input/raw", 138 | help="Folder where raw input images/documents are stored", 139 | ) 140 | parser.add_argument( 141 | "--output", 142 | type=str, 143 | default="/opt/ml/processing/output/ocr", 144 | help="Folder where Amazon Textract-compatible OCR results should be saved", 145 | ) 146 | parser.add_argument( 147 | "--n-workers", 148 | type=int, 149 | default=cpu_count(), 150 | help="Number of worker processes to use for extraction (default #CPU cores)", 151 | ) 152 | args = parser.parse_args() 153 | return args 154 | 155 | 156 | def process_doc_in_worker(inputs: dict) -> None: 157 | """Batch job worker function to extract a document (used in a multiprocessing pool) 158 | 159 | File paths are mapped similar to this sample's Amazon Textract pipeline. For example: 160 | `{in_folder}/some/fld/filename.pdf` to `{out_folder}/some/fld/filename.pdf/consolidated.json` 161 | 162 | Parameters 163 | ---------- 164 | inputs : 165 | Dictionary containing fields: 166 | - in_folder (str): Mandatory path to input documents folder 167 | - rel_filepath (str): Mandatory path relative to `in_folder`, to input document 168 | - out_folder (str): Mandatory path to OCR results output folder 169 | - wait (float): Optional number of seconds to wait before starting processing, to ensure 170 | system resources are not *fully* exhausted when running as many threads as CPU cores. 171 | (Which could cause health check problems) - Default 0.5. 
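Example `inputs` value (illustrative only - the folder values below match the `parse_args()` defaults, and the relative path is hypothetical):

    {
        "in_folder": "/opt/ml/processing/input/raw",
        "out_folder": "/opt/ml/processing/output/ocr",
        "rel_filepath": "some/fld/filename.pdf",
    }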
172 | """ 173 | time.sleep(inputs.get("wait", 0.5)) 174 | in_path = os.path.join(inputs["in_folder"], inputs["rel_filepath"]) 175 | doc = Document( 176 | in_path, 177 | # ext_or_media_type to be inferred from file path 178 | default_doc_dpi=OCR_DEFAULT_DPI, 179 | base_file_path=inputs["in_folder"], 180 | ) 181 | engine = ocr_engines.get(OCR_ENGINE, OCR_DEFAULT_LANGUAGES) 182 | try: 183 | result = engine.process(doc, OCR_DEFAULT_LANGUAGES) 184 | except Exception as e: 185 | logger.error("Failed to process document %s", in_path) 186 | raise e 187 | out_path = os.path.join(inputs["out_folder"], inputs["rel_filepath"], "consolidated.json") 188 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 189 | with open(out_path, "w") as f: 190 | f.write(json.dumps(result, indent=2)) 191 | logger.info("Processed doc %s", in_path) 192 | 193 | 194 | def main() -> None: 195 | """Main batch processing job entrypoint: Parse CLI+envvars and process docs in multiple workers""" 196 | args = parse_args() 197 | logger.info("Parsed job args: %s", args) 198 | 199 | logger.info("Reading raw files from %s", args.input) 200 | rel_filepaths_all = ls_relpaths(args.input) 201 | 202 | n_docs = len(rel_filepaths_all) 203 | logger.info("Processing %s files across %s processes", n_docs, args.n_workers) 204 | with Pool(args.n_workers) as pool: 205 | for ix, _ in enumerate( 206 | pool.imap_unordered( 207 | process_doc_in_worker, 208 | [ 209 | { 210 | "in_folder": args.input, 211 | "out_folder": args.output, 212 | "rel_filepath": path, 213 | } 214 | for path in rel_filepaths_all 215 | ], 216 | ) 217 | ): 218 | logger.info("Processed doc %s of %s", ix + 1, n_docs) 219 | logger.info("All done!") 220 | -------------------------------------------------------------------------------- /notebooks/preproc/textract_transformers/ocr_engines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Custom/open-source OCR engine integrations 4 | 5 | Use .get() defined in this __init__.py file, to dynamically load your custom engine(s). 6 | 7 | To be discoverable by get(), your module (script or folder): 8 | 9 | - Should be placed in this folder, with a name beginning 'eng_'. 10 | - Should expose a class inheriting from `base.BaseOCREngine`, preferably with a name 11 | ending with 'Engine', and should not expose multiple such classes. 12 | """ 13 | # Python Built-Ins: 14 | from importlib import import_module 15 | import inspect 16 | from logging import getLogger 17 | import os 18 | from types import ModuleType 19 | from typing import Dict, Iterable, Type 20 | 21 | # Local Dependencies: 22 | from .base import BaseOCREngine 23 | 24 | 25 | logger = getLogger("ocr_engines") 26 | 27 | 28 | # Auto-discover all eng_*** modules in this folder as [EngineName->ModuleName]: 29 | ENGINES: Dict[str, str] = {} 30 | for item in os.listdir(os.path.dirname(__file__)): 31 | if not item.startswith("eng_"): 32 | continue 33 | if item.endswith(".py"): 34 | # (Assuming everything starting with eng_ is a folder or a .py file), strip ext if present: 35 | item = item[: -len(".py")] 36 | name = item[len("eng_") :] # ID/Name of the engine strips leading 'eng_' 37 | ENGINES[name] = "." 
+ item # Relative importable module name 38 | 39 | 40 | def _find_ocr_engine_class(module: ModuleType) -> Type[BaseOCREngine]: 41 | """Find the OCREngine class from an imported module""" 42 | class_names = [name for name in dir(module) if inspect.isclass(module.__dict__[name])] 43 | names_ending_engine = [n for n in dir(module) if n.endswith("Engine")] 44 | engine_child_classes = [ 45 | name 46 | for name in class_names 47 | if issubclass(module.__dict__[name], BaseOCREngine) 48 | and module.__dict__[name] is not BaseOCREngine 49 | ] 50 | preferred_names = [n for n in engine_child_classes if n in names_ending_engine] 51 | 52 | if len(preferred_names) == 1: 53 | name = preferred_names[0] 54 | elif len(engine_child_classes) == 1: 55 | name = engine_child_classes[0] 56 | elif len(names_ending_engine) == 1: 57 | name = names_ending_engine[0] 58 | elif len(class_names) == 1: 59 | name = class_names[0] 60 | else: 61 | raise ImportError( 62 | "Failed to find unique BaseOCREngine child class from OCR engine module '%s'. Classes " 63 | "inheriting from BaseOCREngine: %s. Class names defined by module: %s" 64 | % (module.__name__, engine_child_classes, class_names) 65 | ) 66 | return module.__dict__[name] 67 | 68 | 69 | def get(engine_name: str, default_languages: Iterable[str]) -> BaseOCREngine: 70 | """Initialize a supported custom OCR engine by name 71 | 72 | Engines are dynamically imported, so that ImportErrors aren't raised for missing dependencies 73 | unless there's an actual attempt to create/use the engine. 74 | 75 | Parameters 76 | ---------- 77 | engine_name : 78 | Name of a supported custom OCR engine to fetch. 79 | default_languages : 80 | Language codes to configure the engine to detect by default. 81 | """ 82 | if engine_name in ENGINES: 83 | # Load the module: 84 | logger.info("Loading OCR engine '%s' from module '%s'", engine_name, ENGINES[engine_name]) 85 | module = import_module(ENGINES[engine_name], package=__name__) 86 | # Locate the target class in the module: 87 | cls = _find_ocr_engine_class(module) 88 | logger.info("Loading engine class: %s", cls) 89 | # Load the engine: 90 | return cls(default_languages) 91 | else: 92 | raise ValueError( 93 | "Couldn't find engine '%s' in ocr_engines module. Not in set: %s" 94 | % (engine_name, ENGINES) 95 | ) 96 | -------------------------------------------------------------------------------- /notebooks/preproc/textract_transformers/ocr_engines/eng_tesseract.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Example integration for (Py)Tesseract as a custom OCR engine 4 | """ 5 | # Python Built-Ins: 6 | from logging import getLogger 7 | import os 8 | from statistics import mean 9 | from tempfile import TemporaryDirectory 10 | from typing import Iterable, List, Optional 11 | 12 | # External Dependencies: 13 | import pandas as pd 14 | import pytesseract 15 | 16 | # Local Dependencies: 17 | from .base import BaseOCREngine, generate_response_json, OCRGeometry, OCRLine, OCRPage, OCRWord 18 | from ..image_utils import Document 19 | 20 | 21 | logger = getLogger("eng_tesseract") 22 | 23 | 24 | if os.environ.get("TESSDATA_PREFIX") is None: 25 | os.environ["TESSDATA_PREFIX"] = "/opt/conda/share/tessdata" 26 | 27 | 28 | class TesseractEngine(BaseOCREngine): 29 | """Tesseract-based engine for custom SageMaker OCR endpoint option""" 30 | 31 | engine_name = "tesseract" 32 | 33 | def process(self, raw_doc: Document, languages: Optional[Iterable[str]] = None) -> dict: 34 | ocr_pages = [] 35 | 36 | with TemporaryDirectory() as tmpdir: 37 | raw_doc.set_workspace(tmpdir) 38 | for ixpage, page in enumerate(raw_doc.get_pages()): 39 | logger.debug(f"Serializing page {ixpage + 1}") 40 | page_ocr = pytesseract.image_to_data( 41 | page.file_path, 42 | output_type=pytesseract.Output.DATAFRAME, 43 | lang="+".join(self.default_languages if languages is None else languages), 44 | pandas_config={ 45 | # Need this explicit override or else pages containing only a single number 46 | # can sometimes have text column interpreted as numeric type: 47 | "dtype": {"text": str}, 48 | }, 49 | ) 50 | ocr_pages += self.dataframe_to_ocrpages(page_ocr) 51 | return generate_response_json(ocr_pages, self.engine_name) 52 | 53 | @classmethod 54 | def dataframe_to_ocrpages(cls, ocr_df: pd.DataFrame) -> List[OCRPage]: 55 | """Convert a Tesseract DataFrame to a list of OCRPage ready for Textract-like serialization 56 | 57 | Tesseract TSVs / PyTesseract DataFrames group detections by multiple levels: Page, block, 58 | paragraph, line, word. Columns are: level, page_num, block_num, par_num, line_num, word_num, 59 | left, top, width, height, conf, text. 60 | 61 | Each level is introduced by a record, so for example there will be an initial record with 62 | (level=1, page_num=1, block_num=0, par_num=0, line_num=0, word_num=0)... And then several 63 | others before finally getting down to the first WORD record (level=5, page_num=1, 64 | block_num=1, par_num=1, line_num=1, word_num=1). Records are assumed to be sorted in order, 65 | as indeed they are direct from Tesseract. 66 | """ 67 | # First construct an indexable list of page geometries, as we'll need these to normalize 68 | # other entity coordinates from absolute pixel values to 0-1 range: 69 | # (Note: In fact this function will often be called with only one page_num at a time) 70 | page_dims = ( 71 | ocr_df[ocr_df["level"] == 1] 72 | .groupby("page_num") 73 | .agg( 74 | { 75 | "left": "min", 76 | "top": "min", 77 | "width": "max", 78 | "height": "max", 79 | "page_num": "count", 80 | } 81 | ) 82 | ) 83 | # There should be exactly one level=1 record per page in the dataframe. After checking 84 | # this, we can dispose the "page_num" count column. 
85 | if (page_dims["page_num"] > 1).sum() > 0: 86 | raise ValueError( 87 | "Tesseract DataFrame had duplicate entries for these page_nums at level 1: %s" 88 | % page_dims.index[page_dims["page_num"] > 0].values[:20] 89 | ) 90 | page_dims.drop(columns="page_num", inplace=True) 91 | 92 | # We need to collapse the {block, paragraph} levels of Tesseract hierarchy to preserve only 93 | # PAGE, LINE and WORD for consistency with Textract. Here we'll assume the DataFrame is in 94 | # its original Tesseract sort order, allowing iteration through the records to correctly 95 | # roll the entities up. Although iterating through large DataFrames isn't generally a 96 | # performant practice, this could always be balanced with specific parallelism if wanted: 97 | # E.g. processing multiple pages at once. 98 | pages = { 99 | num: OCRPage([]) # Initialise all pages first with no text 100 | for num in sorted(ocr_df[ocr_df["level"] == 1]["page_num"].unique()) 101 | } 102 | cur_page_num = None 103 | page_lines = [] 104 | cur_line_id = None 105 | line_words = [] 106 | 107 | # Tesseract LINE records (level 4) don't have a confidence (equals -1), so we'll use the 108 | # average over the included WORDs as a heuristic. They *do* have T/L/H/W geometry info, but 109 | # we'll ignore that for the sake of code simplicity and let OCRLine infer it from the union 110 | # of all WORD bounding boxes. 111 | add_line = lambda words: ( 112 | page_lines.append(OCRLine(mean(w.confidence for w in words), words)) 113 | ) 114 | 115 | # Loop through all WORD records, ignoring whitespace-only ones that Tesseract likes to yield 116 | words_df = ocr_df[ocr_df["level"] == 5].copy() 117 | words_df["text"] = words_df["text"].str.strip() 118 | words_df = words_df[words_df["text"].str.len() > 0] 119 | for _, rec in words_df.iterrows(): 120 | line_id = (rec.block_num, rec.par_num, rec.line_num) 121 | page_num = rec.page_num 122 | if cur_line_id != line_id: 123 | # Start of new LINE - add previous one to result: 124 | if cur_line_id is not None: 125 | add_line(line_words) 126 | cur_line_id = line_id 127 | line_words = [] 128 | if cur_page_num != page_num: 129 | # Start of new PAGE - add previous one to result: 130 | if cur_page_num is not None: 131 | pages[cur_page_num].add_lines(page_lines) 132 | cur_page_num = page_num 133 | page_lines = [] 134 | # Parse this record into a WORD: 135 | page_dim_rec = page_dims.loc[page_num] 136 | line_words.append( 137 | OCRWord( 138 | rec.text, 139 | rec.conf, 140 | OCRGeometry.from_bbox( 141 | # Word geometries, too, need normalizing by page dimensions. 142 | (rec.top - page_dim_rec.top) / page_dim_rec.height, 143 | (rec.left - page_dim_rec.left) / page_dim_rec.width, 144 | rec.height / page_dim_rec.height, 145 | rec.width / page_dim_rec.width, 146 | ), 147 | ) 148 | ) 149 | # End of last line and last page: Add any remaining content. 
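# (The final LINE and PAGE never hit the "new line"/"new page" branches above, so flush them here.)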
150 | if len(line_words): 151 | add_line(line_words) 152 | if len(page_lines): 153 | pages[cur_page_num].add_lines(page_lines) 154 | return [page for page in pages.values()] 155 | -------------------------------------------------------------------------------- /notebooks/review/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | /* eslint-env node */ 2 | 3 | module.exports = { 4 | root: true, 5 | extends: [ 6 | "plugin:vue/vue3-essential", 7 | "eslint:recommended", 8 | "@vue/eslint-config-typescript/recommended", 9 | "@vue/eslint-config-prettier", 10 | ], 11 | env: { 12 | "vue/setup-compiler-macros": true, 13 | }, 14 | rules: { 15 | "vue/multi-word-component-names": [ 16 | "error", 17 | { 18 | ignores: ["Viewer.ce"], 19 | }, 20 | ], 21 | }, 22 | }; 23 | -------------------------------------------------------------------------------- /notebooks/review/.gitignore: -------------------------------------------------------------------------------- 1 | # Local testing files 2 | public/*.pdf 3 | NOTES.md 4 | **.tmp 5 | **.tmp.* 6 | 7 | # Logs 8 | logs 9 | *.log 10 | npm-debug.log* 11 | yarn-debug.log* 12 | yarn-error.log* 13 | pnpm-debug.log* 14 | lerna-debug.log* 15 | 16 | node_modules 17 | .DS_Store 18 | dist 19 | dist-ssr 20 | coverage 21 | *.local 22 | 23 | # Editor directories and files 24 | .vscode/* 25 | !.vscode/extensions.json 26 | .idea 27 | *.suo 28 | *.ntvs* 29 | *.njsproj 30 | *.sln 31 | *.sw? 32 | -------------------------------------------------------------------------------- /notebooks/review/.prettierrc.yml: -------------------------------------------------------------------------------- 1 | printWidth: 100 2 | -------------------------------------------------------------------------------- /notebooks/review/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /notebooks/review/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "amazon-a2i-pdf-field-review-ui", 3 | "private": true, 4 | "version": "0.0.0", 5 | "description": "Document field detection review UI for Amazon Augmented AI (A2I)", 6 | "author": "Amazon Web Services", 7 | "license": "MIT-0", 8 | "repository": { 9 | "type": "git", 10 | "url": "git+https://github.com/aws-samples/amazon-textract-transformer-pipeline.git" 11 | }, 12 | "bugs": { 13 | "url": "https://github.com/aws-samples/amazon-textract-transformer-pipeline/issues" 14 | }, 15 | "homepage": "https://github.com/aws-samples/amazon-textract-transformer-pipeline#readme", 16 | "scripts": { 17 | "dev": "vite", 18 | "build": "vue-tsc --noEmit && vite build", 19 | "preview": "vite preview --port 5050", 20 | "typecheck": "vue-tsc --noEmit", 21 | "lint": "eslint . 
--ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore" 22 | }, 23 | "dependencies": { 24 | "amazon-textract-response-parser": "0.4.2", 25 | "element-internals-polyfill": "1.2.6", 26 | "pdfjs-dist": "4.5.136", 27 | "vue": "^3.2.47" 28 | }, 29 | "devDependencies": { 30 | "@types/node": "^18.15.0", 31 | "@vitejs/plugin-vue": "^4.1.0", 32 | "@vue/eslint-config-prettier": "^9.0.0", 33 | "@vue/eslint-config-typescript": "^13.0.0", 34 | "@vue/tsconfig": "^0.5.1", 35 | "eslint": "^8.57.0", 36 | "eslint-plugin-vue": "^9.24.1", 37 | "prettier": "^3.3.2", 38 | "sass": "^1.60.0", 39 | "typescript": "^5.5.0", 40 | "vite": "^4.5.2", 41 | "vite-plugin-singlefile": "^0.13.5", 42 | "vue-tsc": "^2.0.29" 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /notebooks/review/public/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/review/public/.gitkeep -------------------------------------------------------------------------------- /notebooks/review/src/App.vue: -------------------------------------------------------------------------------- 1 | 3 | 13 | 23 | 24 | 27 | 28 | 44 | -------------------------------------------------------------------------------- /notebooks/review/src/assets/base.scss: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * Base/common styling affecting multiple components 5 | */ 6 | 7 | // Semantic color variables: 8 | :root { 9 | --color-accent: rgb(63, 81, 181); 10 | --color-error: rgb(221, 44, 0); 11 | } 12 | 13 | // Some basic cross-document layout tweaks: 14 | *, 15 | *::before, 16 | *::after { 17 | box-sizing: border-box; 18 | margin: 0; 19 | position: relative; 20 | font-weight: normal; 21 | } 22 | 23 | body { 24 | text-rendering: optimizeLegibility; 25 | -webkit-font-smoothing: antialiased; 26 | -moz-osx-font-smoothing: grayscale; 27 | } 28 | 29 | .col-taskobject { 30 | max-height: 100%; 31 | min-height: 25%; 32 | text-align: center; 33 | } 34 | 35 | .col-fields { 36 | height: 100%; 37 | overflow: auto; 38 | } 39 | 40 | // The following styles are shared between a couple of components: 41 | .field-detections { 42 | color: var(--color-accent); 43 | background-color: #ddd; 44 | font-size: 0.8rem; 45 | overflow: hidden; // Force child margins inside container 46 | padding-left: 26px; 47 | padding-top: 4px; 48 | 49 | p,ul,ol { 50 | margin-bottom: 0.2rem; 51 | } 52 | } 53 | 54 | @mixin confidence-bar { 55 | color: white; 56 | padding-left: 5px; 57 | font-size: 7px; 58 | line-height: 9px; 59 | } 60 | -------------------------------------------------------------------------------- /notebooks/review/src/components/FieldSingleValue.ce/index.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * Wrapper around FieldSingleValue Vue CE component to add ElementInternals DOM functionality 5 | */ 6 | // External Dependencies: 7 | import { defineCustomElement } from "vue"; 8 | 9 | // Local Dependencies: 10 | import FieldSingleValueBase from "./FieldSingleValue.ce.vue"; 11 | 12 | // Start with the vanilla Vue CE element: 13 | const FieldSingleValueBaseElement = defineCustomElement(FieldSingleValueBase); 14 | 15 | /** 16 | * Extend the vanilla Vue CE to implement a form-associated Custom Element 17 | * 18 | * The ElementInternals standard allows custom elements to register their 19 | * participation in
s as discussed at: 20 | * 21 | * https://html.spec.whatwg.org/multipage/custom-elements.html#the-elementinternals-interface 22 | * https://developer.mozilla.org/en-US/docs/Web/API/ElementInternals 23 | * 24 | * By defining the required static formAssociated property (here) and attaching ElementInternals 25 | * and setting Form Value (in the Vue component onMounted), we can demonstrate how this approach 26 | * might be used to send data to the SageMaker without the central state store and 27 | * pattern. However, at the time of writing this appeared to work end-to-end 28 | * in Firefox (v91 ESR, using ElementInternals polyfill) but not Chrome (v98, native EInternals). 29 | */ 30 | export class FieldSingleValue extends FieldSingleValueBaseElement { 31 | _internals?: ElementInternals; 32 | static get formAssociated() { 33 | return true; 34 | } 35 | constructor(initialProps?: Record | undefined) { 36 | super(initialProps); 37 | this._internals = this.attachInternals() as unknown as ElementInternals; 38 | } 39 | } 40 | 41 | export default FieldSingleValue; 42 | -------------------------------------------------------------------------------- /notebooks/review/src/components/HelloWorld.vue: -------------------------------------------------------------------------------- 1 | 3 | 8 | 13 | 14 | 20 | 21 | 39 | -------------------------------------------------------------------------------- /notebooks/review/src/components/MultiFieldValue.ce.vue: -------------------------------------------------------------------------------- 1 | 3 | 9 | 30 | 31 | 184 | 185 | 198 | 199 | 213 | -------------------------------------------------------------------------------- /notebooks/review/src/components/ObjectValueInput.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * Vue Custom Element to proxy all data from global state store through to SageMaker 5 | * 6 | * This CE must be registered with the special name to work correctly with the 7 | * SageMaker element. 8 | * 9 | * Since this element doesn't have any UI template or styling, there's no need to use a single-file 10 | * component '.vue': Plain TypeScript is fine. 11 | */ 12 | 13 | // External Dependencies: 14 | import { defineCustomElement } from "vue"; 15 | 16 | // Local Dependencies: 17 | import { emitValidate, store } from "../util/store"; 18 | 19 | // Base Vue component configuration (just binds the 'name' attribute) 20 | const ObjectValueInputVueComponentBase = { 21 | props: { 22 | name: String, 23 | }, 24 | /** 25 | * No UI/DOM on this component 26 | */ 27 | render() { 28 | return; 29 | }, 30 | }; 31 | 32 | // The base component is not sufficient because it doesn't register itself as being form-associated 33 | // or implement the validate() API for . 
34 | const ObjectValueInputElementBase = defineCustomElement(ObjectValueInputVueComponentBase); 35 | 36 | /** 37 | * Form-associated Custom Element class for the data proxy 38 | */ 39 | export class ObjectValueInputElement extends ObjectValueInputElementBase { 40 | _internals: ElementInternals; 41 | value: Record; // Form data value property 42 | 43 | /** 44 | * Required property to mark the element as associated to forms 45 | * 46 | * https://html.spec.whatwg.org/multipage/custom-elements.html#custom-elements-face-example 47 | */ 48 | static get formAssociated() { 49 | return true; 50 | } 51 | 52 | constructor(initialProps?: Record | undefined) { 53 | super(initialProps); 54 | this._internals = this.attachInternals() as unknown as ElementInternals; 55 | // No need to continuously watch() for changes here: Setting once is sufficient 56 | this.value = store; 57 | } 58 | 59 | /** 60 | * Validate data before submission (return true for go, false for stop) 61 | * 62 | * Only allows submission to proceed if *all* registered listeners in the state store return true 63 | */ 64 | validate(): boolean { 65 | return emitValidate().every((r) => r); 66 | } 67 | } 68 | 69 | export default ObjectValueInputElement; 70 | -------------------------------------------------------------------------------- /notebooks/review/src/components/PdfPageAnnotationLayer.ce.vue: -------------------------------------------------------------------------------- 1 | 3 | 4 | 22 | 23 | 109 | 110 | 134 | 135 | 153 | -------------------------------------------------------------------------------- /notebooks/review/src/main.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * Main script entrypoint for Vue.js example A2I UI template. 
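 *
 * Registering the components below as Custom Elements means they can be used as plain HTML
 * tags in the task templates. For example (illustrative markup only), a <custom-viewer> can
 * wrap <custom-field> and <custom-field-multivalue> inputs inside the SageMaker <crowd-form>.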
5 | */ 6 | // External Dependencies: 7 | import { createApp, defineCustomElement } from "vue"; 8 | import type { App as VueApp } from "vue"; 9 | // If needed, you could also use TRP.js here or in any of the components - as: 10 | // import { TextractDocument } from "amazon-textract-response-parser"; 11 | 12 | // Local Dependencies: 13 | import App from "./App.vue"; 14 | import FieldMultiValue from "./components/FieldMultiValue.ce.vue"; 15 | import FieldSingleValue from "./components/FieldSingleValue.ce"; 16 | import MultiFieldValue from "./components/MultiFieldValue.ce.vue"; 17 | import ObjectValueInputElement from "./components/ObjectValueInput"; 18 | import PdfPageAnnotationLayer from "./components/PdfPageAnnotationLayer.ce.vue"; 19 | import Viewer from "./components/Viewer.ce.vue"; 20 | import type { ModelResult } from "./util/model"; 21 | 22 | declare global { 23 | interface Window { 24 | app: VueApp; 25 | // As per the setup script in index.html / index-noliquid.html: 26 | taskData: { 27 | taskObject: string; 28 | taskInput: { ModelResult: ModelResult }; 29 | }; 30 | } 31 | } 32 | 33 | // Register our Custom Element components (not needed for normal Vue components): 34 | customElements.define("object-value-input", ObjectValueInputElement); 35 | customElements.define("custom-field", FieldSingleValue); 36 | const FieldMultiValueElement = defineCustomElement(FieldMultiValue); 37 | customElements.define("custom-field-multivalue", FieldMultiValueElement); 38 | const MultiFieldValueElement = defineCustomElement(MultiFieldValue); 39 | customElements.define("custom-multifield-value", MultiFieldValueElement); 40 | const PdfPageAnnotationLayerElement = defineCustomElement(PdfPageAnnotationLayer); 41 | customElements.define("custom-page-annotation-layer", PdfPageAnnotationLayerElement); 42 | const ViewerElement = defineCustomElement(Viewer); 43 | customElements.define("custom-viewer", ViewerElement); 44 | 45 | // Mount the Vue app (not that the app itself does very much): 46 | createApp(App).mount("#app"); 47 | -------------------------------------------------------------------------------- /notebooks/review/src/util/colors.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | /** 5 | * Colours by class ID to match those used by SM Ground Truth bounding box task UI. 
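 *
 * Typical usage (illustrative sketch) is to cycle through the palette by class ID, wrapping
 * around if there are more classes than colours:
 *
 *   const color = LABEL_CLASS_COLORS[detection.ClassId % LABEL_CLASS_COLORS.length];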
6 | */ 7 | export const LABEL_CLASS_COLORS = [ 8 | "#2ca02c", 9 | "#1f77b4", 10 | "#ff7f0e", 11 | "#d62728", 12 | "#9467bd", 13 | "#8c564b", 14 | "#e377c2", 15 | "#7f7f7f", 16 | "#bcbd22", 17 | "#ff9896", 18 | "#17becf", 19 | "#aec7e8", 20 | "#ffbb78", 21 | "#98df8a", 22 | "#c5b0d5", 23 | "#c49c94", 24 | "#f7b6d2", 25 | "#c7c7c7", 26 | "#dbdb8d", 27 | "#9edae5", 28 | "#393b79", 29 | "#5254a3", 30 | "#6b6ecf", 31 | "#9c9ede", 32 | "#637939", 33 | "#8ca252", 34 | "#b5cf6b", 35 | "#cedb9c", 36 | "#8c6d31", 37 | "#bd9e39", 38 | "#e7ba52", 39 | "#e7cb94", 40 | "#843c39", 41 | "#ad494a", 42 | "#d6616b", 43 | "#e7969c", 44 | "#7b4173", 45 | "#a55194", 46 | "#ce6dbd", 47 | "#de9ed6", 48 | "#3182bd", 49 | "#6baed6", 50 | "#9ecae1", 51 | "#c6dbef", 52 | "#e6550d", 53 | "#fd8d3c", 54 | "#fdae6b", 55 | "#fdd0a2", 56 | "#31a354", 57 | "#74c476", 58 | "#a1d99b", 59 | "#c7e9c0", 60 | "#756bb1", 61 | "#9e9ac8", 62 | "#bcbddc", 63 | "#dadaeb", 64 | "#636363", 65 | "#969696", 66 | "#bdbdbd", 67 | "#d9d9d9", 68 | ]; 69 | -------------------------------------------------------------------------------- /notebooks/review/src/util/model.d.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * Type declarations for expected A2I task input data for this task 5 | */ 6 | 7 | /** 8 | * Detected instance of a given entity/field type in document text 9 | */ 10 | export interface Detection { 11 | Blocks: string[]; 12 | BoundingBox: { Height: number; Left: number; Top: number; Width: number }; 13 | ClassId: number; 14 | ClassName: string; 15 | Confidence: number; 16 | PageNum: number; 17 | Text: string; 18 | } 19 | 20 | /** 21 | * Common interface for single- or multi-value entity/field results. 22 | */ 23 | interface ModelResultFieldBase { 24 | ClassId: number; 25 | Confidence: number; 26 | NumDetectedValues: number; 27 | NumDetections: number; 28 | Optional?: boolean; 29 | SortOrder: number; 30 | } 31 | 32 | /** 33 | * Overall detection result for a given entity/field type. 34 | */ 35 | export interface ModelResultSingleField extends ModelResultFieldBase { 36 | Detections: Detection[]; 37 | Value: string; 38 | } 39 | 40 | /** 41 | * Overall detection result for a given entity/field type. 42 | */ 43 | export interface ModelResultMultiField extends ModelResultFieldBase { 44 | Values: Array<{ 45 | Confidence: number; 46 | Detections: Detection[]; 47 | Value: string; 48 | }>; 49 | } 50 | 51 | /** 52 | * Overall model result across entity/field types. 53 | */ 54 | export interface ModelResult { 55 | Confidence: number; 56 | Fields: { 57 | [FieldName: string]: ModelResultSingleField | ModelResultMultiField; 58 | }; 59 | } 60 | -------------------------------------------------------------------------------- /notebooks/review/src/util/store.ts: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | /** 4 | * A reactive state store singleton and validation event bus for . 5 | * 6 | * This state store, used with the custom element, allows input components 7 | * under shadow DOM to still register output data to be picked up by the SageMaker . 8 | * Pub/sub validation event handling is also provided, so components can handle and respond to 9 | * validation requests when the user clicks 'Submit'. 
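 *
 * Minimal usage sketch from a component (`isMyFieldValid` is an illustrative placeholder):
 *
 *   import { addValidateHandler, store } from "../util/store";
 *   store["my-field"] = { Value: "" };                                 // register output data
 *   const removeHandler = addValidateHandler(() => isMyFieldValid()); // veto submit if invalid
 *   // ...and call removeHandler() when the component is unmounted.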
10 | */ 11 | import { reactive } from "vue"; 12 | 13 | /** 14 | * A simple event bus on which publish()ing collects the results from all registered listeners. 15 | */ 16 | class ResultCollectingEventBus { 17 | private nextId = 0; 18 | private subscriptions: Record TResponse>> = {}; 19 | 20 | generateId() { 21 | return (this.nextId++).toString(); 22 | } 23 | 24 | /** 25 | * Subscribe a new event listener/handler 26 | * @param eventType Name of event type to listen for 27 | * @param listener Function to be called when an event is published 28 | * @returns A function to call to de-register/remove the listener 29 | */ 30 | subscribe(eventType: string, listener: (event: TEvent) => TResponse): () => void { 31 | if (!this.subscriptions[eventType]) this.subscriptions[eventType] = {}; 32 | const id = this.generateId(); 33 | this.subscriptions[eventType][id] = listener; 34 | return () => { 35 | delete this.subscriptions[eventType][id]; 36 | if (!Object.keys(this.subscriptions[eventType]).length) { 37 | delete this.subscriptions[eventType]; 38 | } 39 | }; 40 | } 41 | 42 | /** 43 | * Publish an event and synchronously collect and return the results from all listeners. 44 | * @param eventType Name of event type to publish 45 | * @param event Event data to publish 46 | * @returns Array of listener responses (in the order they were called) 47 | */ 48 | publish(eventType: string, event: TEvent): TResponse[] { 49 | const subs = this.subscriptions[eventType]; 50 | if (!subs) return []; 51 | return Object.keys(subs).map((k) => subs[k](event)); 52 | } 53 | } 54 | 55 | const storeBus = new ResultCollectingEventBus(); 56 | 57 | /** 58 | * Register a listener/handler for form validation events when user clicks submit. 59 | * @param validateHandler A callback returning true to allow form submission, false to prevent it 60 | * @returns A callback to remove the event listener (e.g. when your component is deleted) 61 | */ 62 | export const addValidateHandler = (validateHandler: () => boolean) => { 63 | return storeBus.subscribe("validate", validateHandler); 64 | }; 65 | 66 | /** 67 | * Publish a form validation event and collect the response from all registered components 68 | * @returns Boolean array in which any 'false' entry should prevent accepting the form. 69 | */ 70 | export const emitValidate = () => storeBus.publish("validate"); 71 | 72 | /** 73 | * Reactive state store to which components can save data inputs. 74 | * 75 | * The contents of this store will be continuously watched and synchronised to the 76 | * element, to be included in output data when the SageMaker task is submitted 77 | */ 78 | export const store = reactive>({}); 79 | export default store; 80 | -------------------------------------------------------------------------------- /notebooks/review/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "@vue/tsconfig/tsconfig.json", 3 | "include": ["env.d.ts", "src/**/*", "src/**/*.vue"], 4 | "compilerOptions": { 5 | "allowJs": true, 6 | "baseUrl": ".", 7 | "ignoreDeprecations": "5.0", // TODO: https://github.com/vuejs/tsconfig/issues/6 8 | "paths": { 9 | "@/*": ["./src/*"] 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /notebooks/review/vite.config.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | /* global __dirname */ 4 | 5 | // Node Built-Ins: 6 | import { resolve } from "path"; 7 | import { fileURLToPath, URL } from "url"; 8 | 9 | // External Dependencies: 10 | import { defineConfig } from "vite"; 11 | import { viteSingleFile } from "vite-plugin-singlefile"; 12 | import vue from "@vitejs/plugin-vue"; 13 | 14 | // See reference at https://vitejs.dev/config/ 15 | export default defineConfig({ 16 | build: { 17 | // Single file adjustments as per https://www.npmjs.com/package/vite-plugin-singlefile 18 | assetsInlineLimit: 100000000, // for vite-plugin-singlefile 19 | chunkSizeWarningLimit: 100000000, // for vite-plugin-singlefile 20 | cssCodeSplit: false, // for vite-plugin-singlefile 21 | reportCompressedSize: false, // Not really relevant for single-file outputs 22 | rollupOptions: { 23 | external: [ 24 | // ---- Dependencies to exclude from build (will be fetched from CDN): 25 | // TRP.js is run as IIFE by a script tag in the HTML which produces a `trp` global (below): 26 | "amazon-textract-response-parser", 27 | // PDF.js entrypoints will be treated as external `paths` (below): 28 | "pdfjs-dist/legacy/build/pdf.mjs", 29 | "pdfjs-dist/legacy/web/pdf_viewer.mjs", 30 | ], 31 | // You could point the build to a different input HTML template if needed: 32 | input: resolve(__dirname, "index.html"), 33 | output: { 34 | format: "es", // Need to use ESM for pdf.js 35 | paths: { 36 | "pdfjs-dist/legacy/build/pdf.mjs": 37 | "https://cdn.jsdelivr.net/npm/pdfjs-dist@4.5.136/legacy/build/pdf.mjs", 38 | "pdfjs-dist/legacy/web/pdf_viewer.mjs": 39 | "https://cdn.jsdelivr.net/npm/pdfjs-dist@4.5.136/legacy/web/pdf_viewer.mjs", 40 | }, 41 | globals: { 42 | "amazon-textract-response-parser": "trp", 43 | }, 44 | inlineDynamicImports: true, // for vite-plugin-singlefile 45 | }, 46 | }, 47 | }, 48 | plugins: [ 49 | vue({ 50 | template: { 51 | compilerOptions: { 52 | // Avoid default {{ }} delimiters because these will conflict with Liquid template lang. 53 | delimiters: ["${", "}"], 54 | // Declare the SageMaker Crowd HTML Elements so Vue doesn't fuss about missing them: 55 | isCustomElement: (tag) => tag.startsWith("crowd-") || tag.startsWith("iron-"), 56 | }, 57 | }, 58 | }), 59 | // Package all outputs together so we don't have to find a way to host many JS/CSS/etc assets: 60 | viteSingleFile(), 61 | ], 62 | resolve: { 63 | alias: { 64 | "@": fileURLToPath(new URL("./src", import.meta.url)), 65 | }, 66 | }, 67 | }); 68 | -------------------------------------------------------------------------------- /notebooks/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Amazon Textract + LayoutLM model training and inference code for SageMaker 4 | 5 | This __init__ file is not necessary for training or inference in sagemaker (which runs your 6 | nominated entry point file, rather than importing this whole folder as a module) - but it can help 7 | with local debugging, by enabling you to `from src import XYZ` from the notebooks. 8 | """ 9 | -------------------------------------------------------------------------------- /notebooks/src/code/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Amazon Textract + LayoutLM model training and inference code package for SageMaker 4 | 5 | Why the extra level of nesting? Because the src folder (even if __init__ is present) is not loaded 6 | as a Python module during training, but rather as the working directory. This requires a different 7 | import syntax for top-level files/folders (`import config`, not `from . import config`) than you 8 | would see if your working directory was different (for example when you `from src import code` to 9 | use it from one of the notebooks). 10 | 11 | Wrapping this code in an extra package folder ensures that - regardless of whether you use it from 12 | notebook, in SM training job, or in some other app - relative imports *within* this code/ folder 13 | work correctly. 14 | """ 15 | -------------------------------------------------------------------------------- /notebooks/src/code/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Data loading utilities for Amazon Textract with Hugging Face Transformers 4 | 5 | Call get_datasets() from the training script to load datasets/collators for the current task. 6 | """ 7 | # Python Built-Ins: 8 | from typing import Iterable, Optional 9 | 10 | # External Dependencies: 11 | from transformers.processing_utils import ProcessorMixin 12 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 13 | 14 | # Local Dependencies: 15 | from ..config import DataTrainingArguments 16 | from .base import TaskData 17 | from .mlm import get_task as get_mlm_task 18 | from .ner import get_task as get_ner_task 19 | from .seq2seq import get_task as get_seq2seq_task 20 | 21 | 22 | def get_datasets( 23 | data_args: DataTrainingArguments, 24 | tokenizer: PreTrainedTokenizerBase, 25 | processor: Optional[ProcessorMixin] = None, 26 | model_param_names: Optional[Iterable[str]] = None, 27 | n_workers: Optional[int] = None, 28 | cache_dir: Optional[str] = None, 29 | ) -> TaskData: 30 | """Load datasets and data collators for model pre/training""" 31 | if data_args.task_name == "mlm": 32 | return get_mlm_task( 33 | data_args, 34 | tokenizer, 35 | processor, 36 | model_param_names=model_param_names, 37 | n_workers=n_workers, 38 | cache_dir=cache_dir, 39 | ) 40 | elif data_args.task_name == "ner": 41 | return get_ner_task( 42 | data_args, tokenizer, processor, n_workers=n_workers, cache_dir=cache_dir 43 | ) 44 | elif data_args.task_name == "seq2seq": 45 | return get_seq2seq_task( 46 | data_args, tokenizer, processor, n_workers=n_workers, cache_dir=cache_dir 47 | ) 48 | else: 49 | raise ValueError( 50 | "Unknown task '%s' is not in 'mlm', 'ner', 'seq2seq'" % data_args.task_name 51 | ) 52 | -------------------------------------------------------------------------------- /notebooks/src/code/data/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Geometry utilities for working with LayoutLM, Amazon Textract, and SageMaker Ground Truth 4 | """ 5 | # Python Built-Ins: 6 | from typing import Iterable, Optional, Union 7 | 8 | # External Dependencies: 9 | import numpy as np 10 | import torch 11 | import trp 12 | 13 | 14 | def layoutlm_boxes_from_trp_blocks( 15 | textract_blocks: Iterable[ 16 | Union[ 17 | trp.Word, 18 | trp.Line, 19 | trp.SelectionElement, 20 | trp.FieldKey, 21 | trp.FieldValue, 22 | trp.Cell, 23 | trp.Table, 24 | trp.Page, 25 | ] 26 | ], 27 | return_tensors: Optional[str] = None, 28 | ): 29 | """List of TRP 'blocks' to array of 0-1000 normalized x0,y0,x1,y1 for LayoutLM 30 | 31 | Per https://docs.aws.amazon.com/textract/latest/dg/API_BoundingBox.html, Textract bounding box 32 | coords are 0-1 relative to page size already: So we just need to multiply by 1000. Note this 33 | means there's no information encoded about the overall aspect ratio of the page. 34 | 35 | Parameters 36 | ---------- 37 | textract_blocks : 38 | Iterable of any Textract TRP objects including a 'geometry' property e.g. Word, Line, Cell, 39 | etc. 40 | return_tensors : 41 | None (default) to return plain nested lists of ints. "np" to return a numpy array or "pt" 42 | to return a torch.LongTensor. 43 | 44 | Returns 45 | ------- 46 | boxes : 47 | Array or tensor shape (n_examples, 4) of bounding boxes: left, top, right, bottom scaled 48 | 0-1000. 49 | """ 50 | raw_zero_to_one_list = [ 51 | [ 52 | block.geometry.boundingBox.left, 53 | block.geometry.boundingBox.top, 54 | block.geometry.boundingBox.left + block.geometry.boundingBox.width, 55 | block.geometry.boundingBox.top + block.geometry.boundingBox.height, 56 | ] 57 | for block in textract_blocks 58 | ] 59 | if return_tensors == "np" or not return_tensors: 60 | if len(raw_zero_to_one_list) == 0: 61 | npresult = np.zeros((0, 4), dtype="long") 62 | else: 63 | npresult = (np.array(raw_zero_to_one_list) * 1000).astype("long") 64 | return npresult if return_tensors else npresult.tolist() 65 | elif return_tensors == "pt": 66 | if len(raw_zero_to_one_list) == 0: 67 | return torch.zeros((0, 4), dtype=torch.long) 68 | else: 69 | return (torch.FloatTensor(raw_zero_to_one_list) * 1000).long() 70 | else: 71 | raise ValueError( 72 | "return_tensors must be 'np' or 'pt' for layoutlm_boxes_from_trp_blocks(). Got: %s" 73 | % return_tensors 74 | ) 75 | -------------------------------------------------------------------------------- /notebooks/src/code/data/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Data utilities for generative, sequence-to-sequence tasks 4 | 5 | This task is experimental, and does not currently support layout-aware models. As shown in the 6 | 'Optional Extras' notebook, you can use it to train separate post-processing models to normalize 7 | extracted fields: For example converting the format of dates. 8 | """ 9 | from .task_builder import get_task 10 | -------------------------------------------------------------------------------- /notebooks/src/code/data/seq2seq/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Validation/accuracy metric callbacks for seq2seq modelling tasks""" 4 | # Python Built-Ins: 5 | from numbers import Real 6 | from typing import Callable, Dict 7 | 8 | # External Dependencies: 9 | import numpy as np 10 | from transformers import EvalPrediction, PreTrainedTokenizerBase 11 | 12 | 13 | def get_metric_computer( 14 | tokenizer: PreTrainedTokenizerBase, 15 | ) -> Callable[[EvalPrediction], Dict[str, Real]]: 16 | """An 'accuracy' computer for seq2seq tasks that ignores outer whitespace and case. 17 | 18 | For our example task, it's reasonable to measure exact-match accuracy (since we're normalising 19 | small text spans - not e.g. summarizing long texts to shorter paragraphs). Therefore this metric 20 | computer checks exact accuracy, while allowing for variations in case and leading/trailing 21 | whitespace. 22 | """ 23 | 24 | def compute_metrics(p: EvalPrediction) -> Dict[str, Real]: 25 | # Convert model output probs/logits to predicted token IDs: 26 | predicted_token_ids = np.argmax(p.predictions[0], axis=2) 27 | # Replace everything from the first token onward with padding (as eos 28 | # would terminate generation in a normal generate() call) 29 | for ix_batch, seq in enumerate(predicted_token_ids): 30 | eos_token_matches = np.where(seq == tokenizer.eos_token_id) 31 | if len(eos_token_matches) and len(eos_token_matches[0]): 32 | first_eos_posn = eos_token_matches[0][0] 33 | predicted_token_ids[ix_batch, first_eos_posn:] = tokenizer.pad_token_id 34 | 35 | gen_texts = [ 36 | s.strip().lower() 37 | for s in tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True) 38 | ] 39 | 40 | target_texts = [ 41 | s.strip().lower() 42 | for s in tokenizer.batch_decode( 43 | # Replace label '-100' tokens (ignore index for BinaryCrossEntropy) with '0' ( 44 | # token), to avoid an OverflowError when trying to decode the target text: 45 | np.maximum(0, p.label_ids), 46 | skip_special_tokens=True, 47 | ) 48 | ] 49 | 50 | n_examples = len(gen_texts) 51 | n_correct = sum(1 for gen, target in zip(gen_texts, target_texts) if gen == target) 52 | return { 53 | "n_examples": len(gen_texts), 54 | "acc": n_correct / n_examples, 55 | } 56 | 57 | return compute_metrics 58 | -------------------------------------------------------------------------------- /notebooks/src/code/inference_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Alternative SageMaker inference wrapper for text-only (non-multimodal) seq2seq models 4 | 5 | These models are optionally deployed alongside the core layout-aware NER model, to normalize 6 | detected entity mentions. 7 | 8 | API Usage 9 | --------- 10 | 11 | All requests and responses in 'application/json'. The model takes a dict with key `inputs` which 12 | may be a text string or a list of strings. It will return a dict with key `generated_text` 13 | containing either a text string or a list of strings (as per the input). 14 | """ 15 | 16 | # Python Built-Ins: 17 | import json 18 | import os 19 | from typing import Dict, List, Union 20 | 21 | # External Dependencies: 22 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline 23 | import torch 24 | 25 | # Local Dependencies: 26 | from . 
import logging_utils 27 | 28 | logger = logging_utils.getLogger("infcustom") 29 | logger.info("Loading custom inference handlers") 30 | # If you need to debug this script and aren't seeing any logging in CloudWatch, try setting the 31 | # following on the Model to force flushing log calls through: env={ "PYTHONUNBUFFERED": "1" } 32 | 33 | # Configurations: 34 | INFERENCE_BATCH_SIZE = int(os.environ.get("INFERENCE_BATCH_SIZE", "4")) 35 | PAD_TO_MULTIPLE_OF = os.environ.get("PAD_TO_MULTIPLE_OF", "8") 36 | PAD_TO_MULTIPLE_OF = None if PAD_TO_MULTIPLE_OF in ("None", "") else int(PAD_TO_MULTIPLE_OF) 37 | 38 | 39 | def input_fn(input_bytes, content_type: str): 40 | """Deserialize and pre-process model request JSON 41 | 42 | Requests must be of type application/json. See module-level docstring for API details. 43 | """ 44 | logger.info(f"Received request of type:{content_type}") 45 | if content_type != "application/json": 46 | raise ValueError("Content type must be application/json") 47 | 48 | req_json = json.loads(input_bytes) 49 | if "inputs" not in req_json: 50 | raise ValueError( 51 | "Request JSON must contain field 'inputs' with either a text string or a list of text " 52 | "strings" 53 | ) 54 | return req_json["inputs"] 55 | 56 | 57 | # No custom output_fn needed as result is plain JSON fully prepared by predict_fn 58 | 59 | 60 | def model_fn(model_dir) -> dict: 61 | """Load model artifacts from model_dir into a dict 62 | 63 | Returns 64 | ------- 65 | pipeline : transformers.pipeline 66 | HF Pipeline for text generation inference 67 | """ 68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 69 | tokenizer = AutoTokenizer.from_pretrained( 70 | model_dir, 71 | pad_to_multiple_of=PAD_TO_MULTIPLE_OF, 72 | # TODO: Is it helpful to use_fast=True? 73 | ) 74 | model = AutoModelForSeq2SeqLM.from_pretrained(model_dir) 75 | model.eval() 76 | model.to(device) 77 | 78 | pl = pipeline( 79 | "text2text-generation", 80 | model=model, 81 | tokenizer=tokenizer, 82 | batch_size=INFERENCE_BATCH_SIZE, 83 | # num_workers as per default 84 | device=model.device, 85 | ) 86 | 87 | logger.info("Model loaded") 88 | return { 89 | # Could return other objects e.g. `model` and `tokenizer`` for debugging 90 | "pipeline": pl, 91 | } 92 | 93 | 94 | def predict_fn( 95 | input_data: Union[str, List[str]], 96 | model_data: dict, 97 | ) -> Dict[str, Union[str, List[str]]]: 98 | """Generate text outputs from an input or list of inputs 99 | 100 | Parameters 101 | ---------- 102 | input_data : 103 | Input text string or list of input text strings (including prompts if needed) 104 | model_data : { pipeline } 105 | Trained model data loaded by model_fn, including a `pipeline`. 106 | 107 | Returns 108 | ------- 109 | result : 110 | Dict including key `generated_text`, which will either be a text string (if `input_data` was 111 | a single string) or a list of strings (if `input_data` was a list). 112 | """ 113 | pl = model_data["pipeline"] 114 | 115 | # Use transformers Pipelines to simplify the inference process and handle e.g. 
batching and 116 | # tokenization for us: 117 | result = pl(input_data, clean_up_tokenization_spaces=True) 118 | 119 | # Convert output from list of dicts to dict of lists: 120 | result = {k: [r[k] for r in result] for k in result[0].keys()} 121 | # Strip any leading/trailing whitespace from results: 122 | result["generated_text"] = [t.strip() for t in result["generated_text"]] 123 | 124 | # If input was a plain string (instead of a list of strings), remove the batch dimension from 125 | # outputs too: 126 | if isinstance(input_data, str): 127 | for k in result: 128 | result[k] = result[k][0] 129 | 130 | return result 131 | -------------------------------------------------------------------------------- /notebooks/src/code/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Logging utilities for the SageMaker code package 4 | 5 | Provides a centralized place for re-configuring existing loggers after config load - which allows 6 | our files to still getLogger() on import, rather than having to pass dynamic Logger objects around 7 | everywhere between functions at call time. 8 | """ 9 | # Python Built-Ins: 10 | import logging 11 | from typing import Union 12 | 13 | # External Dependencies: 14 | from transformers.utils import logging as transformers_logging 15 | 16 | 17 | transformers_logging.enable_default_handler() 18 | transformers_logging.enable_explicit_format() 19 | 20 | LEVEL = logging.root.level 21 | 22 | 23 | def _create_logger(name: str): 24 | logger = logging.getLogger(name) 25 | logger.setLevel(LEVEL) 26 | return logger 27 | 28 | 29 | LOGGER_MAP = {} 30 | 31 | 32 | def getLogger(name: str) -> logging.Logger: 33 | """Retrieve or create a Logger by name""" 34 | if name not in LOGGER_MAP: 35 | LOGGER_MAP[name] = _create_logger(name) 36 | 37 | return LOGGER_MAP[name] 38 | 39 | 40 | def setLevel(level: Union[int, str]): 41 | """Set the level of all active loggers created through this util (and HF's loggerrs)""" 42 | LEVEL = level 43 | transformers_logging.set_verbosity(LEVEL) 44 | for name in LOGGER_MAP: 45 | LOGGER_MAP[name].setLevel(LEVEL) 46 | -------------------------------------------------------------------------------- /notebooks/src/code/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Some model implementation customizations""" 4 | -------------------------------------------------------------------------------- /notebooks/src/code/smddpfix.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Patched HF Trainer to enable using ddp_find_unused_parameters with SageMaker Distributed 4 | """ 5 | # Python Built-Ins: 6 | from logging import getLogger 7 | from unittest.mock import patch, MagicMock 8 | 9 | # External Dependencies: 10 | from transformers.trainer import Trainer as TrainerBase 11 | 12 | try: 13 | # v4.18+ 14 | from transformers.utils import is_sagemaker_dp_enabled 15 | except ImportError: 16 | # v4.17 17 | from transformers.file_utils import is_sagemaker_dp_enabled 18 | from torch.nn.parallel import DistributedDataParallel as PTDDP 19 | from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as SMDDP 20 | 21 | 22 | logger = getLogger("smddpfix") 23 | 24 | 25 | class Trainer(TrainerBase): 26 | """transformers.Trainer with a fix to enable ddp_find_unused_parameters on SageMaker DDP 27 | 28 | In at least versions 4.17.0 to 4.19.2 (probably others), HF transformers.Trainer ignores the 29 | ddp_find_unused_parameters argument when training with SageMaker Distributed Data Parallel. 30 | 31 | This customized class tries to detect and correct that behavior. 32 | """ 33 | 34 | def _wrap_model(self, model, **kwargs): 35 | """Modified _wrap_model implementation with SageMaker ddp_find_unused_parameters fix""" 36 | # If the conditions for the problem don't apply, just call the original method: 37 | if not (is_sagemaker_dp_enabled() and self.args.ddp_find_unused_parameters): 38 | return super()._wrap_model(model, **kwargs) 39 | 40 | # In v4.18+, Trainer uses nn.parallel.DistributedDataParallel() (SM DDP is configured as a 41 | # backend for the vanilla PyTorch class): 42 | with patch( 43 | "transformers.trainer.nn.parallel.DistributedDataParallel", 44 | create=True, 45 | ) as ptmock: 46 | # In v4.17, Trainer instantiates "DDP" (as per our SMDDP above) 47 | with patch("transformers.trainer.DDP", create=True) as smmock: 48 | # (The mock/patching approach assumes that nothing in the parent function actually 49 | # uses the model after creating it, but that's true in the checked HF versions) 50 | model = super()._wrap_model(model, **kwargs) 51 | 52 | if len(ptmock.call_args_list): 53 | # Native PyTorch DDP mock was called: 54 | if len(ptmock.call_args_list) > 1: 55 | raise ValueError( 56 | "Error in custom fix for SageMaker Distributed Data Parallel: Native " 57 | f"PyTorch DDP mock called multiple times. {ptmock.call_args_list}" 58 | ) 59 | params = ptmock.call_args_list[0] 60 | logger.info( 61 | "Intercepting PyTorch DistributedDataParallel call to add " 62 | "find_unused_parameters=%s", 63 | self.args.ddp_find_unused_parameters, 64 | ) 65 | params.kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters 66 | model = PTDDP(*params.args, **params.kwargs) 67 | 68 | elif len(smmock.call_args_list): 69 | # SageMaker DDP mock was called: 70 | if len(smmock.call_args_list) > 1: 71 | raise ValueError( 72 | "Error in custom fix for SageMaker Distributed Data Parallel: " 73 | f"SageMaker DDP mock called multiple times. 
{smmock.call_args_list}" 74 | ) 75 | params = smmock.call_args_list[0] 76 | logger.info( 77 | "Intercepting SageMaker DistributedDataParallel call to add " 78 | "find_unused_parameters=%s", 79 | self.args.ddp_find_unused_parameters, 80 | ) 81 | params.kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters 82 | model = SMDDP(*params.args, **params.kwargs) 83 | 84 | # If model is still a mock after the updates, something's gone wrong: 85 | if isinstance(model, MagicMock): 86 | raise ValueError( 87 | "Error in custom fix for SageMaker Distributed Data Parallel: " 88 | "DDP model is still mocked after checking expected cases." 89 | ) 90 | return model 91 | -------------------------------------------------------------------------------- /notebooks/src/ddp_launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """(Native PyTorch) DistributedDataParallel launcher 4 | 5 | Use this entrypoint script to launch training with native PyTorch DDP on SageMaker. You don't need 6 | it if using SageMaker DDP - in which case directly set 'train.py' as your entrypoint. 7 | """ 8 | # Python Built-Ins: 9 | import json 10 | import os 11 | import socket 12 | import subprocess 13 | import sys 14 | 15 | 16 | # Path to resource config file IF running on SageMaker: 17 | SM_CONFIG_PATH = "/opt/ml/input/config/resourceconfig.json" 18 | 19 | if __name__ != "__main__": 20 | # If the file is imported as a module, we're in inference mode and should pass through the 21 | # override functions defined in the inference module. This is to support directly deploying the 22 | # model via SageMaker SDK's Estimator.deploy(), which will carry over the environment variable 23 | # SAGEMAKER_PROGRAM=ddp_launcher.py from training - causing the server to try and load handlers 24 | # from here rather than inference.py. 25 | from code.inference import * 26 | else: 27 | if os.path.exists(SM_CONFIG_PATH): 28 | # Running on SageMaker: Load distribution configs from the resourceconfig file 29 | with open(SM_CONFIG_PATH) as file: 30 | cluster_config = json.load(file) 31 | 32 | host_names = cluster_config["hosts"] 33 | default_n_nodes = len(host_names) 34 | default_node_rank = host_names.index(os.environ.get("SM_CURRENT_HOST")) 35 | 36 | # Elect first listed host as the leader for PyTorch DDP 37 | print("CLUSTER HOSTS:") 38 | host_ips = [socket.gethostbyname(host) for host in host_names] 39 | for ix, host in enumerate(host_names): 40 | print( 41 | " - {}host: {}, IP: {}".format( 42 | "(leader) " if ix == 0 else "", 43 | host, 44 | host_ips[ix], 45 | ) 46 | ) 47 | leader = host_ips[0] 48 | 49 | # Set the network interface for inter node communication 50 | os.environ["NCCL_SOCKET_IFNAME"] = cluster_config["network_interface_name"] 51 | 52 | else: 53 | # Seems not to be a SageMaker training job (could be e.g. testing on notebook, local). 
54 | # Default to single-machine setup: 55 | default_n_nodes = 1 56 | default_node_rank = 0 57 | leader = "127.0.0.1" 58 | 59 | # Set up DDP & NCCL environment variables: 60 | # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/index.html#ncclknobs 61 | # https://github.com/aws/sagemaker-pytorch-training-toolkit/blob/88ca48a831bf4f099d4c57f3c18e0ff92fa2b48c/src/sagemaker_pytorch_container/training.py#L103 62 | # 63 | # Disable IB transport and force to use IP sockets by default: 64 | os.environ["NCCL_IB_DISABLE"] = "1" 65 | # Set NCCL log level (could be INFO for more debugging information): 66 | if not os.environ.get("NCCL_DEBUG"): 67 | os.environ["NCCL_DEBUG"] = "WARN" 68 | 69 | # Launch PyTorch DDP: 70 | ddp_cmd = ( 71 | [ 72 | "python", 73 | "-m", 74 | "torch.distributed.launch", 75 | "--nproc_per_node", 76 | os.environ["SM_NUM_GPUS"], 77 | "--nnodes", 78 | str(default_n_nodes), 79 | "--node_rank", 80 | str(default_node_rank), 81 | "--master_addr", 82 | leader, 83 | "--master_port", 84 | "7777", 85 | ] 86 | # ...And pass through arguments for the actual train script: 87 | + ["train.py"] 88 | + [arg for arg in sys.argv[1:]] 89 | ) 90 | print("LAUNCHING: " + " ".join(ddp_cmd)) 91 | subprocess.check_call(ddp_cmd) 92 | -------------------------------------------------------------------------------- /notebooks/src/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Load custom inference handlers for model deployment 4 | """ 5 | from code.inference import * 6 | -------------------------------------------------------------------------------- /notebooks/src/inference_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Load alternative inference handlers for seq2seq model deployment 4 | """ 5 | from code.inference_seq2seq import * 6 | -------------------------------------------------------------------------------- /notebooks/src/requirements.txt: -------------------------------------------------------------------------------- 1 | # Dummy requirements.txt file not needed with customized containers 2 | # To use the vanilla Hugging Face DLCs for training & inference, uncomment the below: 3 | 4 | ## Common: 5 | # amazon-textract-response-parser>=0.1,<0.2 6 | # Pillow>=8,<9 7 | 8 | ## For LayoutLMv2+: 9 | # git+https://github.com/facebookresearch/detectron2.git 10 | # pytesseract 11 | 12 | ## Libraries present by default in training container but missing from inference: 13 | # datasets>=1.18.4,<2 14 | # torchvision==0.11.3 15 | -------------------------------------------------------------------------------- /notebooks/src/smtc_launcher.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # SPDX-License-Identifier: MIT-0 4 | """Alternative train launcher script for SageMaker Training Compiler 5 | 6 | More info at: https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html 7 | 8 | To use SMTC, you'll need to specify a `compiler_config` and set the "GPU_NUM_DEVICES" environment 9 | variable on your Estimator to the number of GPUs per instance for the type you have selected. 
For 10 | example: 11 | 12 | ```python 13 | from sagemaker.huggingface import TrainingCompilerConfig 14 | 15 | pre_estimator = HuggingFaceEstimator( 16 | ..., 17 | compiler_config=TrainingCompilerConfig(), 18 | env={ 19 | ..., 20 | "GPU_NUM_DEVICES": "4", # for ml.p3.8xlarge 21 | }, 22 | ) 23 | ``` 24 | 25 | For single-GPU training, you can use the train.py entry_point as usual. However for multi-GPU 26 | training, you'll need to instead set this `entry_point="smtc_launcher.py"` and add an additional 27 | hyperparameter `"training_script": "train.py"`. 28 | 29 | This training script has been tested to *functionally* work with SMTC (on Hugging Face v4.11 DLC), 30 | but whether you'll see a useful speed-up may be quite hyperparameter- and use-case-dependent. Note 31 | that a substantial portion of the optimization opportunity with SMTC comes from memory efficiency 32 | allowing larger batch sizes. 33 | 34 | Remember also that on p3.16xl and larger where it's supported, enabling SageMaker Distributed Data 35 | Parallel can provide a useful speed boost. When *neither* SMTC nor SMDistributed are enabled, the 36 | HF Trainer API will use PyTorch DataParallel by default (rather than DistributedDataParallel) which 37 | can limit scaling to many GPUs - partly because memory consumption is higher on the "lead" GPU and 38 | so CUDAOutOfMemory will be encountered at lower maximum batch sizes. 39 | 40 | Notes from pre-training experiments 41 | ----------------------------------- 42 | 43 | 2,500 document training set (set N_DOCS_KEPT = 2500 in notebook 1) on `ml.p3.8xlarge`, pre-training 44 | with: 45 | 46 | - num_train_epochs = 25 47 | - early_stopping_patience = 10 48 | - per_device_eval_batch_size = per_device_train_batch_size 49 | - seed = 42 50 | - warmup_steps = 200 51 | 52 | | SMTC | per_device_train_batch_size | learning_rate | Execution Time | min val loss | 53 | |:----:|----------------------------:|--------------:|---------------------:|-------------:| 54 | | No | 4 | 5e-05 | 5h28m16s (25 epochs) | 0.149301 | 55 | | No | 8 | 2e-05 | 4h13m46s (25 epochs) | 0.154481 | 56 | | Yes | 20 | 2e-05 | N/A (GPU OOM) | N/A (OOM) | 57 | | Yes | 16 | 1e-04 | 5h03m03s (25 epochs) | 0.147910 | 58 | | Yes | 16 | 5e-05 | 5h05m03s (25 epochs) | 0.141911 | 59 | | Yes | 16 | 2e-05 | 5h02m52s (25 epochs) | 0.159771 | 60 | | Yes | 16 | 1e-05 | 5h01m09s (25 epochs) | 0.191195 | 61 | | Yes | 16 | 5e-06 | 5h01m48s (25 epochs) | 0.249820 | 62 | | Yes | 12 | 1e-05 | 5h10m35s (25 epochs) | 0.165622 | 63 | | Yes | 8 | 2e-05 | 2h50m02s (12 epochs) | * 0.301963 | 64 | | Yes | 8 | 1e-05 | 2h37m52s (11 epochs) | * 0.627447 | 65 | 66 | (*): Training unstable and stopped early after reaching `nan` loss. Best epoch reported. 67 | """ 68 | # Python Built-Ins: 69 | import subprocess 70 | import sys 71 | 72 | if __name__ == "__main__": 73 | arguments_command = " ".join([arg for arg in sys.argv[1:]]) 74 | subprocess.check_call( 75 | "python -m torch_xla.distributed.sm_dist " + arguments_command, shell=True 76 | ) 77 | -------------------------------------------------------------------------------- /notebooks/src/train.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Top-level entrypoint script for model training (also supports loading for inference) 4 | """ 5 | # Python Built-Ins: 6 | import logging 7 | import os 8 | import sys 9 | 10 | 11 | def run_training(): 12 | """Configure logging, import local modules and run the training job""" 13 | consolehandler = logging.StreamHandler(sys.stdout) 14 | consolehandler.setFormatter( 15 | logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s") 16 | ) 17 | logging.basicConfig(handlers=[consolehandler], level=os.environ.get("LOG_LEVEL", logging.INFO)) 18 | 19 | from code.train import main 20 | 21 | return main() 22 | 23 | 24 | if __name__ == "__main__": 25 | # If the file is running as a script, we're in training mode and should run the actual training 26 | # routine (with a little logging setup before any imports, to make sure output shows up ok): 27 | run_training() 28 | else: 29 | # If the file is imported as a module, we're in inference mode and should pass through the 30 | # override functions defined in the inference module. This is to support directly deploying the 31 | # model via SageMaker SDK's Estimator.deploy(), which will carry over the environment variable 32 | # SAGEMAKER_PROGRAM=train.py from training - causing the server to try and load handlers from 33 | # here rather than inference.py. 34 | from code.inference import * 35 | 36 | 37 | def _mp_fn(index): 38 | """For torch_xla / SageMaker Training Compiler 39 | 40 | (See smtc_launcher.py in this folder for configuration tips) 41 | """ 42 | return run_training() 43 | -------------------------------------------------------------------------------- /notebooks/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Utility functions to help keep the notebooks tidy - Amazon Textract + Transformers sample 4 | """ 5 | 6 | # Before importing any submodules, we'll configure Python `logging` nicely for notebooks. 
7 | # 8 | # By "nicely", we mean: 9 | # - Setting up the formatting to display timestamp, logger name and level 10 | # - Sending messages >= WARN to stderr (so Jupyter renders them with pink/red background) 11 | # - Sending messages < WARN to stdout (so Jupyter renders them plain, like a print()) 12 | import logging 13 | from logging.config import dictConfig 14 | 15 | 16 | class MaxLevelFilter(logging.Filter): 17 | """A custom Python logging Filter to reject messages *above* a certain `max_level`""" 18 | 19 | def __init__(self, max_level): 20 | self._max_level = max_level 21 | super().__init__() 22 | 23 | def filter(self, record): 24 | return record.levelno <= self._max_level 25 | 26 | @classmethod 27 | def qualname(cls): 28 | return ".".join([cls.__module__, cls.__qualname__]) 29 | 30 | 31 | dictConfig( 32 | { 33 | "formatters": { 34 | "basefmt": {"format": "%(asctime)s %(name)s [%(levelname)s] %(message)s"}, 35 | }, 36 | "filters": { 37 | "maxinfo": {"()": MaxLevelFilter.qualname(), "max_level": logging.INFO}, 38 | }, 39 | "handlers": { 40 | "stdout": { 41 | "class": "logging.StreamHandler", 42 | "filters": ["maxinfo"], 43 | "formatter": "basefmt", 44 | "level": logging.DEBUG, 45 | "stream": "ext://sys.stdout", 46 | }, 47 | "stderr": { 48 | "class": "logging.StreamHandler", 49 | "formatter": "basefmt", 50 | "level": logging.WARN, 51 | "stream": "ext://sys.stderr", 52 | }, 53 | }, 54 | "loggers": { 55 | "": {"handlers": ["stderr", "stdout"], "level": logging.INFO}, 56 | }, 57 | "version": 1, 58 | } 59 | ) 60 | 61 | # Now import the submodules: 62 | from . import deployment 63 | from . import ocr 64 | from . import preproc 65 | from . import project 66 | from . import s3 67 | from . import smgt 68 | from . import training 69 | from . import uid 70 | from . import viz 71 | -------------------------------------------------------------------------------- /notebooks/util/deployment.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Utilities to simplify model/endpoint deployment 4 | """ 5 | 6 | # Python Built-Ins: 7 | import errno 8 | import io 9 | from logging import getLogger 10 | import os 11 | import tarfile 12 | 13 | # External Dependencies: 14 | import numpy as np 15 | import sagemaker 16 | 17 | 18 | logger = getLogger("deploy") 19 | 20 | 21 | def tar_as_inference_code(folder: str, outfile: str = "model.tar.gz") -> str: 22 | """Package a folder of code (without model artifacts) to run in SageMaker endpoint 23 | 24 | SageMaker framework endpoints expect a .tar.gz archive, and PyTorch/HuggingFace frameworks in 25 | particular look for a 'code/' folder within this archive with an 'inference.py' entrypoint 26 | script. 27 | 28 | Given a local folder, this function will produce a .tar.gz file with the folder's contents 29 | archived under 'code/'. It will warn if the folder does not contain an 'inference.py'. 30 | 31 | Parameters 32 | ---------- 33 | folder : 34 | Local folder of code to package 35 | outfile : 36 | Local path to write output archive to (default "model.tar.gz") 37 | 38 | Returns 39 | ------- 40 | outfile : 41 | (Unchanged) local path to the saved tarball. 
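
    Example
    -------
    A minimal sketch (paths are illustrative only)::

        bundle_path = tar_as_inference_code("src/", outfile="build/model.tar.gz")
        # bundle_path can then be uploaded to S3 and referenced by a SageMaker Model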
42 | """ 43 | 44 | if "inference.py" not in os.listdir(folder): 45 | logger.warning( 46 | "Folder '%s' does not contain an 'inference.py' and so won't work as a SM endpoint " 47 | "bundle unless you make extra configurations on your Model", 48 | folder, 49 | ) 50 | os.makedirs(os.path.dirname(outfile) or ".", exist_ok=True) 51 | try: # Remove existing file if present 52 | os.remove(outfile) 53 | except OSError as e: 54 | if e.errno != errno.ENOENT: 55 | raise e 56 | with tarfile.open(outfile, mode="w:gz") as archive: 57 | archive.add( 58 | folder, 59 | # Name folder explicitly as 'code', as required for modern PyTorchModel versions: 60 | arcname="code", 61 | # Exclude hidden files like .ipynb_checkpoints: 62 | filter=lambda info: None if "/." in info.name else info, 63 | ) 64 | return outfile 65 | 66 | 67 | class FileSerializer(sagemaker.serializers.SimpleBaseSerializer): 68 | """Serializer to simply send contents of a file: predictor.predict(filepath) 69 | 70 | You should set content_type to match your intended files when constructing this serializer. For 71 | example 'application/pdf', 'image/png', etc. 72 | """ 73 | 74 | EXTENSION_TO_MIME_TYPE = { 75 | "jpg": "image/jpg", 76 | "jpeg": "image/jpeg", 77 | "pdf": "application/pdf", 78 | "png": "image/png", 79 | } 80 | 81 | def serialize(self, data: str): 82 | with open(data, "rb") as f: 83 | return f.read() 84 | 85 | @classmethod 86 | def content_type_from_filename(cls, filename: str): 87 | ext = filename.rpartition(".")[2] 88 | try: 89 | return cls.EXTENSION_TO_MIME_TYPE[ext] 90 | except KeyError as ke: 91 | pass 92 | raise ValueError(f"Unknown content type for filename extension '.{ext}'") 93 | 94 | @classmethod 95 | def from_filename(cls, filename: str, **kwargs): 96 | return cls(content_type=cls.content_type_from_filename(filename), **kwargs) 97 | 98 | 99 | class CompressedNumpyDeserializer(sagemaker.deserializers.NumpyDeserializer): 100 | """Like SageMaker's NumpyDeserializer, but also supports (and defaults to) .npz archive 101 | 102 | While .npy files save an individual array, .npz archives store multiple named variables and 103 | can be saved with compression to further reduce payload size. 104 | """ 105 | 106 | def __init__(self, dtype=None, accept="application/x-npz", allow_pickle=True): 107 | super(CompressedNumpyDeserializer, self).__init__( 108 | dtype=dtype, accept=accept, allow_pickle=allow_pickle 109 | ) 110 | 111 | def deserialize(self, stream, content_type): 112 | if content_type == "application/x-npz": 113 | try: 114 | return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) 115 | finally: 116 | stream.close() 117 | else: 118 | return super(CompressedNumpyDeserializer, self).deserialize(stream, content_type) 119 | -------------------------------------------------------------------------------- /notebooks/util/postproc: -------------------------------------------------------------------------------- 1 | ../../pipeline/postprocessing/fn-postprocess/util -------------------------------------------------------------------------------- /notebooks/util/project.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
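# Illustrative use of the FileSerializer / CompressedNumpyDeserializer helpers defined in
# deployment.py above. The endpoint name, file path and import style are assumptions for this
# sketch (and the endpoint is assumed to return application/x-npz responses):
#
#   from sagemaker.predictor import Predictor
#   from util.deployment import CompressedNumpyDeserializer, FileSerializer
#
#   predictor = Predictor("my-document-endpoint")
#   predictor.serializer = FileSerializer.from_filename("sample.pdf")  # -> application/pdf
#   predictor.deserializer = CompressedNumpyDeserializer()
#   arrays = predictor.predict("sample.pdf")  # NpzFile of named numpy arrays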
2 | # SPDX-License-Identifier: MIT-0 3 | """ML project infrastructure utilities (reading stack params from AWS SSM) 4 | 5 | init() with a valid *project ID* (or provide the PROJECT_ID environment variable), and this module 6 | will read the project's configuration from AWS SSM (stored there by the CloudFormation stack): 7 | Allowing us to reference from the SageMaker notebook to resources created by the stack. 8 | 9 | A "session" in the context of this module is a project config loaded from SSM. This way we can 10 | choose either to init-and-forget (standard data science project sandbox use case) or to call 11 | individual functions on separate sessions (interact with multiple projects). 12 | """ 13 | 14 | # Python Built-Ins: 15 | import logging 16 | import os 17 | from types import SimpleNamespace 18 | from typing import Dict, Union 19 | 20 | # External Dependencies: 21 | import boto3 22 | 23 | logger = logging.getLogger("project") 24 | ssm = boto3.client("ssm") 25 | 26 | 27 | defaults = SimpleNamespace() 28 | defaults.project_id = None 29 | defaults.session = None 30 | 31 | 32 | if "PROJECT_ID" not in os.environ: 33 | logger.info( 34 | "No PROJECT_ID variable found in environment: You'll need to call init('myprojectid')" 35 | ) 36 | else: 37 | defaults.project_id = os.environ["PROJECT_ID"] 38 | 39 | 40 | class ProjectSession: 41 | """Class defining the parameters for a project and how they get loaded (from AWS SSSM)""" 42 | 43 | SSM_PREFIX: str = "" 44 | STATIC_PARAMS: Dict[str, str] = { 45 | "static/A2IExecutionRoleArn": "a2i_execution_role_arn", 46 | "static/InputBucket": "pipeline_input_bucket_name", 47 | "static/ReviewsBucket": "pipeline_reviews_bucket_name", 48 | "static/PipelineStateMachine": "pipeline_sfn_arn", 49 | "static/PlainTextractStateMachine": "plain_textract_sfn_arn", 50 | "static/PreprocImageURI": "preproc_image_uri", 51 | "static/ThumbnailsCallbackTopicArn": "thumbnails_callback_topic_arn", 52 | "static/ModelCallbackTopicArn": "model_callback_topic_arn", 53 | "static/ModelResultsBucket": "model_results_bucket", 54 | "static/SMDockerBuildRole": "sm_image_build_role", 55 | } 56 | DYNAMIC_PARAM_NAMES: Dict[str, str] = { 57 | "config/HumanReviewFlowArn": "a2i_review_flow_arn_param", 58 | "config/EntityConfiguration": "entity_config_param", 59 | "config/SageMakerEndpointName": "sagemaker_endpoint_name_param", 60 | "config/ThumbnailEndpointName": "thumbnail_endpoint_name_param", 61 | } 62 | 63 | # Static values (from SSM): 64 | a2i_execution_role_arn: str 65 | pipeline_input_bucket_name: str 66 | pipeline_reviews_bucket_name: str 67 | pipeline_sfn_arn: str 68 | plain_textract_sfn_arn: str 69 | preproc_image_uri: str 70 | model_callback_topic_arn: str 71 | model_results_bucket: str 72 | sm_image_build_role: str 73 | thumbnails_callback_topic_arn: str 74 | # Configurable parameters (names in SSM): 75 | a2i_review_flow_arn_param: str 76 | entity_config_param: str 77 | sagemaker_endpoint_name_param: str 78 | thumbnail_endpoint_name_param: str 79 | 80 | def __init__(self, project_id: str): 81 | """Create a ProjectSession 82 | 83 | Parameters 84 | ---------- 85 | project_id : 86 | ProjectId from the provisioned OCR pipeline stack 87 | """ 88 | self.project_id = project_id 89 | 90 | # Load SSM names for dynamic/configuration parameters: 91 | for param_suffix, session_attr in self.DYNAMIC_PARAM_NAMES.items(): 92 | setattr(self, session_attr, f"{self.SSM_PREFIX}/{project_id}/{param_suffix}") 93 | 94 | # Load SSM *values* for static project attributes: 95 | param_names_to_config_attrs 
= { 96 | f"{self.SSM_PREFIX}/{project_id}/{s}": self.STATIC_PARAMS[s] for s in self.STATIC_PARAMS 97 | } 98 | response = ssm.get_parameters(Names=[s for s in param_names_to_config_attrs]) 99 | n_invalid = len(response.get("InvalidParameters", [])) 100 | if n_invalid == len(param_names_to_config_attrs): 101 | raise ValueError(f"Found no valid SSM parameters for /{project_id}: Invalid project ID") 102 | elif n_invalid > 0: 103 | logger.warning( 104 | " ".join( 105 | [ 106 | f"{n_invalid} Project parameters missing from SSM: Some functionality ", 107 | f"may not work as expected. Missing: {response['InvalidParameters']}", 108 | ] 109 | ) 110 | ) 111 | 112 | for param in response["Parameters"]: 113 | param_name = param["Name"] 114 | setattr(self, param_names_to_config_attrs[param_name], param["Value"]) 115 | 116 | def __repr__(self): 117 | """Produce a meaningful representation when this class is print()ed""" 118 | typ = type(self) 119 | mod = typ.__module__ 120 | qualname = typ.__qualname__ 121 | propdict = self.__dict__ 122 | proprepr = ",\n ".join([f"{k}={propdict[k]}" for k in propdict]) 123 | return f"<{mod}.{qualname}(\n {proprepr}\n) at {hex(id(self))}>" 124 | 125 | 126 | def init(project_id: str) -> ProjectSession: 127 | """Initialise the project util library (and the default session) to project_id""" 128 | # Check that we can create the session straight away, for nice error behaviour: 129 | session = ProjectSession(project_id) 130 | if defaults.project_id and defaults.project_id != project_id and defaults.session: 131 | logger.info(f"Clearing previous default session for project '{defaults.project_id}'") 132 | defaults.project_id = project_id 133 | defaults.session = session 134 | logger.info(f"Working in project '{project_id}'") 135 | return session 136 | 137 | 138 | def session_or_default(sess: Union[ProjectSession, None] = None): 139 | """Mostly-internal utility to return either the provided session or else a default""" 140 | if sess: 141 | return sess 142 | elif defaults.session: 143 | return defaults.session 144 | elif defaults.project_id: 145 | defaults.session = ProjectSession(defaults.project_id) 146 | return defaults.session 147 | else: 148 | raise ValueError( 149 | "Must provide a project session or init() the project library with a valid project ID" 150 | ) 151 | -------------------------------------------------------------------------------- /notebooks/util/s3.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Notebook-simplifying utilities for working with Amazon S3 4 | """ 5 | # Python Built-Ins: 6 | from typing import Optional, Tuple 7 | 8 | # External Dependencies: 9 | import boto3 10 | 11 | 12 | s3client = boto3.client("s3") 13 | 14 | 15 | def s3uri_to_bucket_and_key(s3uri: str) -> Tuple[str, str]: 16 | """Convert an s3://... URI string to a (bucket, key) tuple""" 17 | if not s3uri.lower().startswith("s3://"): 18 | raise ValueError(f"Expected S3 object URI to start with s3://. Got: {s3uri}") 19 | bucket, _, key = s3uri[len("s3://") :].partition("/") 20 | return bucket, key 21 | 22 | 23 | def s3uri_to_relative_path(s3uri: str, key_base: str) -> str: 24 | """Extract e.g. 'subfolders/file' from 's3://bucket/.../{key_base}subfolders/file' 25 | 26 | If `key_base` is a folder, it should typically include a trailing slash. 
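    Example (illustrative values)::

        s3uri_to_relative_path("s3://doc-bucket/textracted/batch1/doc.json", "textracted/")
        # -> "batch1/doc.json"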
27 | """ 28 | return s3uri[len("s3://") :].partition("/")[2].partition(key_base)[2] 29 | 30 | 31 | def s3_object_exists(bucket_name_or_s3uri: str, key: Optional[str] = None) -> bool: 32 | """Check if an object exists in Amazon S3 33 | 34 | Parameters 35 | ---------- 36 | bucket_name_or_s3uri : 37 | Either an 's3://.../...' object URI, or an S3 bucket name. 38 | key : 39 | Ignored if `bucket_name_or_s3uri` is a full URI, otherwise mandatory: Key of the object to 40 | check. 41 | """ 42 | if bucket_name_or_s3uri.lower().startswith("s3://"): 43 | bucket_name, key = s3uri_to_bucket_and_key(bucket_name_or_s3uri) 44 | elif not key: 45 | raise ValueError( 46 | "key is mandatory when bucket_name_or_s3uri is not an s3:// URI. Got: %s" 47 | % bucket_name_or_s3uri 48 | ) 49 | else: 50 | bucket_name = bucket_name_or_s3uri 51 | try: 52 | s3client.head_object(Bucket=bucket_name, Key=key) 53 | return True 54 | except s3client.exceptions.ClientError as e: 55 | if e.response["Error"]["Code"] == "404": 56 | return False 57 | else: 58 | raise e 59 | -------------------------------------------------------------------------------- /notebooks/util/training.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Utilities to support the model training on SageMaker""" 4 | 5 | 6 | def get_hf_metric_regex(metric_name: str) -> str: 7 | """Build RegEx string to extract a numeric HuggingFace Transformers metric from SageMaker logs 8 | 9 | HF metric log lines look like a Python dict print e.g: 10 | {'eval_loss': 0.3940396010875702, ..., 'epoch': 1.0} 11 | """ 12 | scientific_number_exp = r"(-?[0-9]+(\.[0-9]+)?(e[+\-][0-9]+)?)" 13 | return "".join( 14 | ( 15 | "'", 16 | metric_name, 17 | "': ", 18 | scientific_number_exp, 19 | "[,}]", 20 | ) 21 | ) 22 | -------------------------------------------------------------------------------- /notebooks/util/uid.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
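# Illustrative use of training.get_hf_metric_regex() above when configuring a SageMaker training
# job. The estimator and the metric names are assumptions for this sketch - any metric your
# training script prints in the HF dict format could be captured the same way:
#
#   estimator = sagemaker.huggingface.HuggingFace(
#       ...,  # entry_point, instance type, etc.
#       metric_definitions=[
#           {"Name": "eval_loss", "Regex": get_hf_metric_regex("eval_loss")},
#           {"Name": "eval_f1", "Regex": get_hf_metric_regex("eval_f1")},
#       ],
#   )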
2 | # SPDX-License-Identifier: MIT-0 3 | """Unique ID utilities for SageMaker""" 4 | 5 | from datetime import datetime 6 | 7 | 8 | def append_timestamp(s: str, sep: str = "-", include_millis=True) -> str: 9 | """Append current datetime to `s` in a format suitable for SageMaker job names""" 10 | now = datetime.now() 11 | if include_millis: 12 | # strftime only supports microseconds, so we trim by 3: 13 | datetime_str = now.strftime(f"%Y{sep}%m{sep}%d{sep}%H{sep}%M{sep}%S{sep}%f")[:-3] 14 | else: 15 | datetime_str = now.strftime(f"%Y{sep}%m{sep}%d{sep}%H{sep}%M{sep}%S") 16 | return sep.join((s, datetime_str)) 17 | -------------------------------------------------------------------------------- /package-lock.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "amazon-textract-transformer-pipeline", 3 | "version": "0.2.1", 4 | "lockfileVersion": 3, 5 | "requires": true, 6 | "packages": { 7 | "": { 8 | "name": "amazon-textract-transformer-pipeline", 9 | "version": "0.2.1", 10 | "license": "MIT-0", 11 | "dependencies": { 12 | "aws-cdk": "2.126.0" 13 | } 14 | }, 15 | "node_modules/aws-cdk": { 16 | "version": "2.126.0", 17 | "resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.126.0.tgz", 18 | "integrity": "sha512-hEyy8UCEEUnkieH6JbJBN8XAbvuVZNdBmVQ8wHCqo8RSNqmpwM1qvLiyXV/2JvCqJJ0bl9uBiZ98Ytd5i3wW7g==", 19 | "bin": { 20 | "cdk": "bin/cdk" 21 | }, 22 | "engines": { 23 | "node": ">= 14.15.0" 24 | }, 25 | "optionalDependencies": { 26 | "fsevents": "2.3.2" 27 | } 28 | }, 29 | "node_modules/fsevents": { 30 | "version": "2.3.2", 31 | "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", 32 | "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", 33 | "hasInstallScript": true, 34 | "optional": true, 35 | "os": [ 36 | "darwin" 37 | ], 38 | "engines": { 39 | "node": "^8.16.0 || ^10.6.0 || >=11.0.0" 40 | } 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "amazon-textract-transformer-pipeline", 3 | "version": "0.2.1", 4 | "description": "Post-processing Amazon Textract with Transformer-Based Models on Amazon SageMaker", 5 | "scripts": { 6 | "lint": "black ./cdk", 7 | "login:ecrpublic": "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws" 8 | }, 9 | "keywords": [ 10 | "Amazon-Textract", 11 | "AWS", 12 | "Intelligent-Document-Processing", 13 | "Machine-Learning", 14 | "Transformers" 15 | ], 16 | "author": "Amazon Web Services", 17 | "license": "MIT-0", 18 | "private": true, 19 | "dependencies": { 20 | "aws-cdk": "2.126.0" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /pipeline/config_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # SPDX-License-Identifier: MIT-0 4 | """Utilities for configuring the stack (e.g. 
environment variable parsing) 5 | """ 6 | # Python Built-Ins: 7 | import os 8 | from typing import List, Optional 9 | 10 | 11 | def bool_env_var(env_var_name: str, default: Optional[bool] = None) -> bool: 12 | """Parse a boolean environment variable 13 | 14 | Raises 15 | ------ 16 | ValueError : 17 | If environment variable `env_var_name` is not found and no `default` is specified, or if the 18 | raw value string could not be interpreted as a boolean. 19 | 20 | Returns 21 | ------- 22 | parsed : 23 | True if the env var has values such as `1`, `true`, `y`, `yes` (case-insensitive). False if 24 | opposite values `0`, `false`, `n`, `no` or empty string. 25 | """ 26 | raw = os.environ.get(env_var_name) 27 | if raw is None: 28 | if default is None: 29 | raise ValueError(f"Mandatory boolean env var '{env_var_name}' not found") 30 | return default 31 | raw = raw.lower() 32 | if raw in ("1", "true", "y", "yes"): 33 | return True 34 | elif raw in ("", "0", "false", "n", "no"): 35 | return False 36 | else: 37 | raise ValueError( 38 | "Couldn't interpret env var '%s' as boolean. Got: '%s'" % (env_var_name, raw) 39 | ) 40 | 41 | 42 | def list_env_var(env_var_name: str, default: Optional[List[str]] = None) -> List[str]: 43 | """Parse a comma-separated string list from an environment variable 44 | 45 | Raises 46 | ------ 47 | ValueError : 48 | If environment variable `env_var_name` is not found and no `default` is specified. 49 | 50 | Returns 51 | ------- 52 | parsed : 53 | List of strings: Split by commas in the raw input, each stripped of any leading/trailing 54 | whitespace, and filtered to remove any empty values. For example: `dog, , cat` returns 55 | `["dog", "cat"]`. Empty environment variable returns `[]`. Whitespace stripping and 56 | filtering is not applied to the `default` value, if used. 57 | """ 58 | raw = os.environ.get(env_var_name) 59 | if raw is None: 60 | if default is None: 61 | raise ValueError(f"Mandatory string-list env var {env_var_name} not found") 62 | return default[:] 63 | whitespace_stripped_values = [s.strip() for s in raw.split(",")] 64 | return [s for s in whitespace_stripped_values if s] 65 | -------------------------------------------------------------------------------- /pipeline/enrichment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """CDK for NLP/ML model enrichment stage of the OCR pipeline 4 | """ 5 | # Python Built-Ins: 6 | from typing import Dict, List, Optional, Union 7 | 8 | # External Dependencies: 9 | from aws_cdk import Duration, Token 10 | from aws_cdk.aws_iam import Effect, PolicyStatement, Role 11 | from aws_cdk.aws_s3 import Bucket 12 | import aws_cdk.aws_ssm as ssm 13 | import aws_cdk.aws_stepfunctions as sfn 14 | from constructs import Construct 15 | 16 | # Local Dependencies: 17 | from ..shared.sagemaker import SageMakerCallerFunction, SageMakerSSMStep 18 | 19 | 20 | class SageMakerEnrichmentStep(Construct): 21 | """CDK construct for an OCR pipeline step to enrich Textract JSON on S3 using SageMaker 22 | 23 | This construct's `.sfn_task` takes input from JSONPath locations as specified by init params 24 | `textracted_input_jsonpath` (mandatory) and `thumbnail_input_jsonpath` (optional). The first 25 | links to a consolidated Textract JSON result in S3 as {Bucket, Key}. The second (if present), 26 | links to a consolidated page thumbnails file for the document: Again as S3 {Bucket, Key}. 
The 27 | task will set $.Textract on the output, with a similar { Bucket, Key } structure pointing to 28 | the enriched output JSON file. 29 | 30 | This step is implemented via AWS Lambda (rather than direct Step Function service call) to 31 | support looking up the configured SageMaker endpoint name from SSM within the same SFn step. 32 | 33 | When `support_async_endpoints` is enabled, the construct uses an asynchronous/TaskToken Lambda 34 | integration and checks at run-time whether the configured endpoint is sync or async. For async 35 | invocations, the same Lambda processes SageMaker callback events via SNS to notify SFn. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | scope: Construct, 41 | id: str, 42 | lambda_role: Role, 43 | output_bucket: Bucket, 44 | ssm_param_prefix: Union[Token, str], 45 | textracted_input_jsonpath: Dict[str, str], 46 | thumbnail_input_jsonpath: Optional[Dict[str, str]] = None, 47 | support_async_endpoints: bool = True, 48 | shared_sagemaker_caller_lambda: Optional[SageMakerCallerFunction] = None, 49 | **kwargs, 50 | ): 51 | """Create a SageMakerEnrichmentStep 52 | 53 | Parameters 54 | ---------- 55 | lambda_role : 56 | IAM Execution Role for AWS Lambda, which will be used for the function to invoke the 57 | SageMaker endpoint. 58 | output_bucket : 59 | S3 Bucket where inference results should be stored. 60 | ssm_param_prefix : 61 | Name prefix under which the SSM SageMakerEndpointName parameter will be generated. 62 | textracted_input_jsonpath : 63 | Dict of `{ Bucket, Key }` locating the input document Textract result (should typically 64 | each be an `aws_stepfunctions.JsonPath` pointing to strings in the SFn state). 65 | thumbnail_input_jsonpath : 66 | Optional Dict of `{ Bucket, Key }` locating the thumbnail images archive for the input 67 | document (if thumbnailing is enabled). Structure as `textracted_input_jsonpath`. 68 | support_async_endpoints : 69 | As per `..shared.sagemaker.SageMakerSSMStep` 70 | shared_sagemaker_caller_lambda : 71 | Optional SageMakerCallerFunction Lambda for calling the SageMaker endpoint, if an 72 | already-created one is to be used (for sharing with other constructs). 73 | **kwargs : 74 | As per Construct parent 75 | """ 76 | super().__init__(scope, id, **kwargs) 77 | 78 | self.endpoint_param = ssm.StringParameter( 79 | self, 80 | "EnrichmentSageMakerEndpointParam", 81 | description="Name of the SageMaker Endpoint to call for OCR result enrichment", 82 | parameter_name=f"{ssm_param_prefix}SageMakerEndpointName", 83 | simple_name=False, 84 | string_value="undefined", 85 | ) 86 | lambda_role.add_to_policy( 87 | PolicyStatement( 88 | sid="ReadSageMakerEndpointParam", 89 | actions=["ssm:GetParameter"], 90 | effect=Effect.ALLOW, 91 | resources=[self.endpoint_param.parameter_arn], 92 | ) 93 | ) 94 | 95 | output_bucket.grant_read_write(lambda_role, "enriched/*") 96 | 97 | # Prepare the "Body" param for the Lambda function: 98 | inf_req_body = { 99 | "S3Input": textracted_input_jsonpath, 100 | "S3Output": { 101 | "Bucket": output_bucket.bucket_name, 102 | "Key": sfn.JsonPath.format( 103 | "enriched/{}", 104 | textracted_input_jsonpath["Key"], 105 | ), 106 | }, 107 | } 108 | if thumbnail_input_jsonpath is None: 109 | # No need to upload this payload to S3: Lambda can directly invoke SageMaker on S3Input 110 | # if async, or pass the object if sync. 
111 | body_upload = None 112 | else: 113 | inf_req_body["S3Thumbnails"] = thumbnail_input_jsonpath 114 | # Since our Body.S3Input doesn't contain the *entire* endpoint input (as the raw 115 | # Textract JSON is missing the S3Thumbnails link), to call an async SM endpoint Lambda 116 | # will need to upload the above `inf_req_body` JSON to S3 first: 117 | body_upload = { 118 | "Bucket": output_bucket.bucket_name, 119 | "Key": sfn.JsonPath.format( 120 | "requests/{}", 121 | textracted_input_jsonpath["Key"], 122 | ), 123 | } 124 | 125 | self.sfn_task = SageMakerSSMStep( 126 | self, 127 | "NLPEnrichmentModel", 128 | comment="Post-Process the Textract result with Amazon SageMaker", 129 | lambda_function=shared_sagemaker_caller_lambda, 130 | lambda_role=lambda_role, 131 | support_async_endpoints=support_async_endpoints, 132 | payload=sfn.TaskInput.from_object( 133 | { 134 | # Because the caller lambda can be shared, we need to specify the param on req: 135 | "EndpointNameParam": self.endpoint_param.parameter_name, 136 | "Body": inf_req_body, 137 | **({"BodyUpload": body_upload} if body_upload else {}), 138 | "ContentType": "application/json", 139 | **({"TaskToken": sfn.JsonPath.task_token} if support_async_endpoints else {}), 140 | } 141 | ), 142 | # We call the output variable 'Textract' here because it's an augmented Textract JSON - 143 | # so downstream components can treat it broadly as a Textract result: 144 | result_path="$.Textract", 145 | timeout=Duration.minutes(30), 146 | ) 147 | 148 | def sagemaker_sns_statements(self, sid_prefix: Union[str, None] = "") -> List[PolicyStatement]: 149 | """Create PolicyStatements to grant SageMaker permission to use the SNS callback topic 150 | 151 | Arguments 152 | --------- 153 | sid_prefix : str | None 154 | Prefix to add to generated statement IDs for uniqueness, or "", or None to suppress 155 | SIDs. 156 | """ 157 | return self.sfn_task.sagemaker_sns_statements(sid_prefix=sid_prefix) 158 | -------------------------------------------------------------------------------- /pipeline/fn-trigger/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Lambda function to trigger the OCR state machine from an S3 object upload notification 4 | """ 5 | # Python Built-Ins: 6 | from datetime import datetime 7 | import json 8 | import logging 9 | import os 10 | from packaging import version 11 | import re 12 | from typing import List 13 | from urllib.parse import unquote_plus 14 | 15 | # External Dependencies: 16 | import boto3 # AWS SDK for Python 17 | 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | sfn = boto3.client("stepfunctions") 21 | 22 | STATE_MACHINE_ARN = os.environ.get("STATE_MACHINE_ARN") 23 | S3_EVENT_STRUCTURE_MAJOR = 2 24 | 25 | 26 | class MalformedRequest(ValueError): 27 | pass 28 | 29 | 30 | class S3Event: 31 | """Model for an individual record within an S3 notification""" 32 | 33 | bucket: str 34 | id: str 35 | key: str 36 | 37 | def __init__(self, record: dict): 38 | if not record: 39 | raise MalformedRequest(f"Empty record in S3 notification: {record}") 40 | 41 | self.event_version = record.get("eventVersion") 42 | if version.parse(self.event_version).major != S3_EVENT_STRUCTURE_MAJOR: 43 | raise NotImplementedError( 44 | f"S3 event version {self.event_version} is not supported by this solution." 
45 | ) 46 | 47 | self.bucket = record.get("s3", {}).get("bucket", {}).get("name") 48 | if not self.bucket: 49 | raise MalformedRequest(f"s3.bucket.name not found in notification: {record}") 50 | 51 | # S3 object notifications quote object key spaces with '+'. Can undo as follows: 52 | self.key = unquote_plus(record.get("s3", {}).get("object", {}).get("key")) 53 | if not self.key: 54 | raise MalformedRequest(f"s3.object.key not found in notification: {record}") 55 | 56 | # A Step-Functions-compatible event ID with timestamp and filename: 57 | self.id = re.sub( 58 | r"[\-]{2,}", 59 | "-", # Reduce consecutive hyphens 60 | re.sub( 61 | r"[\s<>\{\}\[\]\?\*\"#%\\\^\|\~`\$&,;:/\u0000-\u001F\u007F-\u009F]+", 62 | "-", # Replace special chars in filename with hyphens 63 | "-".join( 64 | ( 65 | # ISO timestamp (millisecond precision) of event: 66 | record.get("eventTime", datetime.now().isoformat()), 67 | # Filename of document: 68 | self.key.rpartition("/")[2], 69 | ) 70 | ), 71 | ), 72 | )[:80] 73 | 74 | 75 | class S3Notification: 76 | """Model for an S3 event notification comprising multiple records""" 77 | 78 | events: List[S3Event] 79 | parse_errors: List[Exception] 80 | 81 | def __init__(self, event: dict): 82 | records = event.get("Records") 83 | if not records: 84 | raise MalformedRequest("Couldn't find 'Records' array in input event") 85 | elif not len(records): 86 | raise MalformedRequest("No Records to process in input event") 87 | self.events = [] 88 | self.parse_errors = [] 89 | for record in records: 90 | try: 91 | self.events.append(S3Event(record)) 92 | except Exception as err: 93 | logger.exception("Failed to parse S3 notification record") 94 | self.parse_errors.append(err) 95 | 96 | 97 | def handler(event: dict, context): 98 | """Trigger the OCR pipeline state machine in response to an S3 event notification""" 99 | s3notification = S3Notification(event) 100 | 101 | for record in s3notification.events: 102 | sfn_input = { 103 | "Input": { 104 | "Bucket": record.bucket, 105 | "Key": record.key, 106 | }, 107 | } 108 | 109 | sfn.start_execution( 110 | stateMachineArn=STATE_MACHINE_ARN, 111 | name=record.id, 112 | input=json.dumps(sfn_input), 113 | ) 114 | logger.info(f"Started SFn execution {record.id} from s3://{record.bucket}/{record.key}") 115 | -------------------------------------------------------------------------------- /pipeline/fn-trigger/requirements.txt: -------------------------------------------------------------------------------- 1 | packaging>=21 2 | -------------------------------------------------------------------------------- /pipeline/iam_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """CDK IAM convenience utilities for the OCR pipeline 4 | 5 | Since we want to output a ManagedPolicy users can attach to their existing SageMaker execution 6 | roles, but CDK ManagedPolicy objects do not implement IGrantable (see discussion at 7 | https://github.com/aws/aws-cdk/issues/7448): The typical high-level access grant functions like 8 | Bucket.grant_read_write() won't work for this use case and we'll instead define these utility 9 | classes to simplify directly setting up useful IAM policies.
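For example, a notebook user's policy might combine these statements (construct IDs and
resources below are illustrative only):

    ManagedPolicy(self, "PipelineDataSciencePolicy", statements=[
        S3Statement(grant_write=True, resources=[input_bucket], sid="InputBucketReadWrite"),
        SsmParameterReadStatement(resources=[endpoint_param], sid="ReadEndpointParam"),
        StateMachineExecuteStatement(resources=[pipeline_sfn], sid="RunPipeline"),
    ])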
10 | """ 11 | # Python Built-Ins: 12 | from itertools import zip_longest 13 | from typing import Iterable, Union 14 | 15 | # External Dependencies: 16 | from aws_cdk import Stack 17 | from aws_cdk.aws_iam import PolicyStatement 18 | from aws_cdk.aws_s3 import Bucket 19 | from aws_cdk.aws_ssm import IParameter 20 | from aws_cdk.aws_stepfunctions import StateMachine 21 | 22 | 23 | class S3Statement(PolicyStatement): 24 | """Utility class for creating PolicyStatement granting S3 read/write permissions""" 25 | 26 | def __init__( 27 | self, 28 | actions: Union[str, None] = None, 29 | grant_read: bool = True, 30 | grant_write: bool = False, 31 | resources: Iterable[Bucket] = [], 32 | resource_key_patterns: Iterable[str] = [], 33 | **kwargs, 34 | ): 35 | """Create a SsmParameterReadStatement 36 | 37 | Arguments 38 | --------- 39 | actions : Sequence[str] or None 40 | Appended to built-in list if provided 41 | grant_read : bool 42 | Whether to include built-in IAM actions for read permissions 43 | grant_write : bool 44 | Whether to include built-in IAM actions for write permissions 45 | resources : Iterable[s3.Bucket] 46 | S3 Buckets to grant read/write access to 47 | resource_key_patterns : Iterable[str] 48 | Key patterns to restrict access to, corresponding to the list in `resources`. 49 | **kwargs : Any 50 | Passed through to PolicyStatement 51 | """ 52 | super().__init__( 53 | actions=( 54 | (["s3:GetBucket*", "s3:GetObject*", "s3:List*"] if grant_read else []) 55 | + (["s3:Abort*", "s3:DeleteObject*", "s3:PutObject*"] if grant_write else []) 56 | + (actions if actions else []) 57 | ), 58 | resources=[b.bucket_arn for b in resources] 59 | + [ 60 | b.arn_for_objects(k or "*") 61 | for b, k in zip_longest(resources, resource_key_patterns) 62 | ], 63 | **kwargs, 64 | ) 65 | 66 | 67 | class SsmParameterReadStatement(PolicyStatement): 68 | """Utility class for creating PolicyStatement granting SSM parameter read permissions""" 69 | 70 | def __init__( 71 | self, 72 | actions: Union[str, None] = None, 73 | resources: Iterable[IParameter] = [], 74 | **kwargs, 75 | ): 76 | """Create a SsmParameterReadStatement 77 | 78 | Arguments 79 | --------- 80 | actions : Sequence[str] or None 81 | Appended to built-in list if provided 82 | resources : Iterable[ssm.IParameter] 83 | SSM parameters to grant read access to 84 | **kwargs : Any 85 | Passed through to PolicyStatement 86 | """ 87 | super().__init__( 88 | actions=[ 89 | "ssm:DescribeParameters", 90 | "ssm:GetParameter", 91 | "ssm:GetParameterHistory", 92 | "ssm:GetParameters", 93 | ] 94 | + (actions if actions else []), 95 | resources=[p.parameter_arn for p in resources], 96 | **kwargs, 97 | ) 98 | 99 | 100 | class SsmParameterWriteStatement(PolicyStatement): 101 | """Utility class for creating PolicyStatement granting SSM parameter read permissions""" 102 | 103 | def __init__( 104 | self, 105 | actions: Union[str, None] = None, 106 | resources: Iterable[IParameter] = [], 107 | **kwargs, 108 | ): 109 | """Create a SsmParameterWriteStatement 110 | 111 | Arguments 112 | --------- 113 | actions : Sequence[str] or None 114 | Appended to built-in list if provided 115 | resources : Iterable[ssm.IParameter] 116 | SSM parameters to grant read access to 117 | **kwargs : Any 118 | Passed through to PolicyStatement 119 | """ 120 | super().__init__( 121 | actions=["ssm:PutParameter"] + (actions if actions else []), 122 | resources=[p.parameter_arn for p in resources], 123 | **kwargs, 124 | ) 125 | 126 | 127 | class StateMachineExecuteStatement(PolicyStatement): 128 | 
"""Utility class for creating PolicyStatement granting execution of a SFn state machine""" 129 | 130 | def __init__( 131 | self, 132 | actions: Union[str, None] = None, 133 | resources: Iterable[StateMachine] = [], 134 | **kwargs, 135 | ): 136 | """Create a SsmParameterReadStatement 137 | 138 | Arguments 139 | --------- 140 | actions : Sequence[str] or None 141 | Appended to built-in list if provided 142 | resources : Iterable[sfn.StateMachine] 143 | SFn state machines to grant permissions on 144 | **kwargs : Any 145 | Passed through to PolicyStatement 146 | """ 147 | super(StateMachineExecuteStatement, self).__init__( 148 | actions=[ 149 | "states:DescribeExecution", 150 | "states:DescribeStateMachine", 151 | "states:DescribeStateMachineForExecution", 152 | "states:GetExecutionHistory", 153 | "states:ListExecutions", 154 | "states:StartExecution", 155 | "states:StartSyncExecution", 156 | "states:StopExecution", 157 | ] 158 | + (actions if actions else []), 159 | resources=[m.state_machine_arn for m in resources] 160 | + [ 161 | "arn:{}:states:{}:{}:execution:{}:*".format( 162 | Stack.of(m).partition, 163 | Stack.of(m).region, 164 | Stack.of(m).account, 165 | m.state_machine_name, 166 | ) 167 | for m in resources 168 | ], 169 | **kwargs, 170 | ) 171 | -------------------------------------------------------------------------------- /pipeline/ocr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """CDK for OCR stage of the document processing pipeline 4 | """ 5 | # Python Built-Ins: 6 | from typing import List, Optional, Union 7 | 8 | # External Dependencies: 9 | from aws_cdk import Token 10 | import aws_cdk.aws_iam as iam 11 | from aws_cdk.aws_s3 import Bucket 12 | from constructs import Construct 13 | 14 | # Local Dependencies: 15 | from .sagemaker_ocr import SageMakerOCRStep 16 | from .textract_ocr import TextractOCRStep 17 | from ..shared.sagemaker import SageMakerCallerFunction 18 | 19 | 20 | class OCRStep(Construct): 21 | """CDK construct for a document pipeline step to OCR incoming documents/images 22 | 23 | This construct's `.sfn_task` expects inputs with $.Input.Bucket and $.Input.Key properties 24 | specifying the location of the raw input document, and will return an object with Bucket and 25 | Key pointing to a consolidated JSON OCR output in Amazon Textract-compatible format. 26 | 27 | In addition to the standard (Amazon Textract-based) option, this construct supports building 28 | and deploying alternative, custom OCR options. Multiple engines may be built and/or deployed (to 29 | support experimentation), but the pipeline must be pointed to exactly one custom SageMaker or 30 | Amazon Textract OCR provider. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | scope: Construct, 36 | id: str, 37 | lambda_role: iam.Role, 38 | ssm_param_prefix: Union[Token, str], 39 | input_bucket: Bucket, 40 | output_bucket: Bucket, 41 | output_prefix: str, 42 | input_prefix: Optional[str] = None, 43 | build_sagemaker_ocrs: List[str] = [], 44 | deploy_sagemaker_ocrs: List[str] = [], 45 | use_sagemaker_ocr: Optional[str] = None, 46 | enable_sagemaker_autoscaling: bool = False, 47 | shared_sagemaker_caller_lambda: Optional[SageMakerCallerFunction] = None, 48 | ): 49 | """Create an OCRStep 50 | 51 | Parameters 52 | ---------- 53 | scope : 54 | CDK construct scope 55 | id : 56 | CDK construct ID 57 | lambda_role : 58 | IAM Role that the Amazon Textract-invoking Lambda function will run with 59 | ssm_param_prefix : 60 | Prefix to be applied to generated SSM pipeline configuration parameter names (including 61 | the parameter to configure SageMaker endpoint name for thumbnail generation). 62 | input_bucket : 63 | Bucket from which input documents will be fetched. If auto-deployment of a thumbnailer 64 | endpoint is enabled, the model execution role will be granted access to this bucket 65 | (limited to `input_prefix`). 66 | output_bucket : 67 | (Pre-existing) S3 bucket where Textract result files should be stored 68 | output_prefix : 69 | Prefix under which Textract result files should be stored in S3 (under this prefix, 70 | the original input document keys will be mapped). 71 | input_prefix : 72 | Prefix under `input_bucket` from which input documents will be fetched. Used to 73 | configure SageMaker model execution role permissions when auto-deployment of thumbnailer 74 | endpoint is enabled. 75 | build_sagemaker_ocrs : 76 | List of alternative (SageMaker-based) OCR engine names to build container images and 77 | SageMaker Models for in the deployed stack. By default ([]), none will be included. See 78 | `CUSTOM_OCR_ENGINES` in pipeline/ocr/sagemaker_ocr.py for supported engines. 79 | deploy_sagemaker_ocrs : 80 | List of alternative OCR engine names to deploy SageMaker endpoints for in the stack. Any 81 | names in here must also be included in `build_sagemaker_ocrs`. Default []: Support 82 | Amazon Textract OCR only. 83 | use_sagemaker_ocr : 84 | Optional alternative OCR engine name to use in the deployed document pipeline. If set 85 | and not empty, this must also be present in `build_sagemaker_ocrs` and 86 | `deploy_sagemaker_ocrs`. Default None: Use Amazon Textract for initial document OCR. 87 | enable_sagemaker_autoscaling : 88 | Set True to enable auto-scaling on SageMaker OCR endpoints (if any are deployed), to 89 | optimize resource usage (recommended for production use). Set False to disable it and 90 | avoid cold-starts (good for development). 91 | shared_sagemaker_caller_lambda : 92 | Optional pre-existing SageMaker caller Lambda function, to share this between multiple 93 | SageMakerSSMSteps in the app if required. 
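        Illustrative construction (engine names, buckets and prefixes below are hypothetical):

            OCRStep(
                self,
                "OCR",
                lambda_role=shared_lambda_role,
                ssm_param_prefix=param_prefix,
                input_bucket=input_bucket,
                output_bucket=results_bucket,
                output_prefix="textracted/",
                build_sagemaker_ocrs=["tesseract"],
                deploy_sagemaker_ocrs=["tesseract"],
                use_sagemaker_ocr=None,  # Keep Amazon Textract as the active OCR engine
            )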
94 | """ 95 | super().__init__(scope, id) 96 | 97 | if len(build_sagemaker_ocrs) > 0: 98 | self.sagemaker_step = SageMakerOCRStep( 99 | self, 100 | "SageMakerStep", 101 | lambda_role=lambda_role, 102 | ssm_param_prefix=ssm_param_prefix, 103 | input_bucket=input_bucket, 104 | ocr_results_bucket=output_bucket, 105 | input_prefix=input_prefix, 106 | ocr_results_prefix=output_prefix, 107 | build_engine_names=build_sagemaker_ocrs, 108 | deploy_engine_names=deploy_sagemaker_ocrs, 109 | use_engine_name=use_sagemaker_ocr, 110 | enable_autoscaling=enable_sagemaker_autoscaling, 111 | shared_sagemaker_caller_lambda=shared_sagemaker_caller_lambda, 112 | ) 113 | else: 114 | self.sagemaker_step = None 115 | 116 | self.textract_step = TextractOCRStep( 117 | self, 118 | "TextractStep", 119 | lambda_role=lambda_role, 120 | output_bucket=output_bucket, 121 | output_prefix=output_prefix, 122 | ) 123 | 124 | if use_sagemaker_ocr and self.sagemaker_step: 125 | self.sfn_task = self.sagemaker_step.sfn_task 126 | else: 127 | self.sfn_task = self.textract_step.sfn_task 128 | 129 | @property 130 | def textract_state_machine(self): 131 | return self.textract_step.textract_state_machine 132 | -------------------------------------------------------------------------------- /pipeline/postprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """CDK for rule-based post-processing stage of the OCR pipeline 4 | """ 5 | # Python Built-Ins: 6 | import json 7 | from typing import Union 8 | 9 | # External Dependencies: 10 | from aws_cdk import Duration, Token 11 | from aws_cdk.aws_iam import Effect, PolicyStatement, Role 12 | from aws_cdk.aws_lambda import Runtime as LambdaRuntime 13 | from aws_cdk.aws_lambda_python_alpha import PythonFunction 14 | import aws_cdk.aws_ssm as ssm 15 | import aws_cdk.aws_stepfunctions as sfn 16 | import aws_cdk.aws_stepfunctions_tasks as sfn_tasks 17 | from constructs import Construct 18 | 19 | # Local Dependencies: 20 | from ..shared import abs_path 21 | 22 | 23 | POSTPROC_LAMBDA_PATH = abs_path("fn-postprocess", __file__) 24 | 25 | # Not technically necessary as the notebook guides users to configure this through AWS SSM, but 26 | # useful to set the defaults per the notebook for speedy setup: 27 | DEFAULT_ENTITY_CONFIG = [ 28 | { 29 | "ClassId": 0, 30 | "Name": "Agreement Effective Date", 31 | "Optional": True, 32 | "Select": "first", 33 | }, 34 | { 35 | "ClassId": 1, 36 | "Name": "APR - Introductory", 37 | "Optional": True, 38 | "Select": "confidence", 39 | }, 40 | { 41 | "ClassId": 2, 42 | "Name": "APR - Balance Transfers", 43 | "Optional": True, 44 | "Select": "confidence", 45 | }, 46 | { 47 | "ClassId": 3, 48 | "Name": "APR - Cash Advances", 49 | "Optional": True, 50 | "Select": "confidence", 51 | }, 52 | { 53 | "ClassId": 4, 54 | "Name": "APR - Purchases", 55 | "Optional": True, 56 | "Select": "confidence", 57 | }, 58 | { 59 | "ClassId": 5, 60 | "Name": "APR - Penalty", 61 | "Optional": True, 62 | "Select": "confidence", 63 | }, 64 | { 65 | "ClassId": 6, 66 | "Name": "APR - General", 67 | "Optional": True, 68 | "Select": "confidence", 69 | }, 70 | { 71 | "ClassId": 7, 72 | "Name": "APR - Other", 73 | "Optional": True, 74 | "Select": "confidence", 75 | }, 76 | { 77 | "ClassId": 8, 78 | "Name": "Fee - Annual", 79 | "Optional": True, 80 | "Select": "confidence", 81 | }, 82 | { 83 | "ClassId": 9, 84 | "Name": "Fee - Balance 
Transfer", 85 | "Optional": True, 86 | "Select": "confidence", 87 | }, 88 | { 89 | "ClassId": 10, 90 | "Name": "Fee - Late Payment", 91 | "Optional": True, 92 | "Select": "confidence", 93 | }, 94 | { 95 | "ClassId": 11, 96 | "Name": "Fee - Returned Payment", 97 | "Optional": True, 98 | "Select": "confidence", 99 | }, 100 | { 101 | "ClassId": 12, 102 | "Name": "Fee - Foreign Transaction", 103 | "Optional": True, 104 | "Select": "shortest", 105 | }, 106 | { 107 | "ClassId": 13, 108 | "Name": "Fee - Other", 109 | "Ignore": True, 110 | }, 111 | { 112 | "ClassId": 14, 113 | "Name": "Card Name", 114 | }, 115 | { 116 | "ClassId": 15, 117 | "Name": "Provider Address", 118 | "Optional": True, 119 | "Select": "confidence", 120 | }, 121 | { 122 | "ClassId": 16, 123 | "Name": "Provider Name", 124 | "Select": "longest", 125 | }, 126 | { 127 | "ClassId": 17, 128 | "Name": "Min Payment Calculation", 129 | "Ignore": True, 130 | }, 131 | { 132 | "ClassId": 18, 133 | "Name": "Local Terms", 134 | "Ignore": True, 135 | }, 136 | ] 137 | 138 | 139 | class LambdaPostprocStep(Construct): 140 | """CDK construct for an OCR pipeline step consolidate document fields from enriched OCR JSON 141 | 142 | This construct's `.sfn_task` expects inputs with $.Textract.Bucket and $.Textract.Key 143 | properties, and will process this object with a Lambda function to add a $.ModelResult object 144 | to the output state: Consolidating detections of the different fields as defined by the 145 | field/entity configuration JSON in AWS SSM. 146 | """ 147 | 148 | def __init__( 149 | self, 150 | scope: Construct, 151 | id: str, 152 | lambda_role: Role, 153 | ssm_param_prefix: Union[Token, str], 154 | **kwargs, 155 | ): 156 | super().__init__(scope, id, **kwargs) 157 | self.entity_config_param = ssm.StringParameter( 158 | self, 159 | "EntityConfigParam", 160 | description=( 161 | "JSON configuration describing the field types to be extracted by the pipeline" 162 | ), 163 | parameter_name=f"{ssm_param_prefix}EntityConfiguration", 164 | simple_name=False, 165 | string_value=json.dumps(DEFAULT_ENTITY_CONFIG, indent=2), 166 | ) 167 | lambda_role.add_to_policy( 168 | PolicyStatement( 169 | sid="ReadSSMEntityConfigParam", 170 | actions=["ssm:GetParameter"], 171 | effect=Effect.ALLOW, 172 | resources=[self.entity_config_param.parameter_arn], 173 | ) 174 | ) 175 | self.caller_lambda = PythonFunction( 176 | self, 177 | "PostProcessFn", 178 | description="Post-process SageMaker-enriched Textract JSON to extract business fields", 179 | entry=POSTPROC_LAMBDA_PATH, 180 | environment={ 181 | "DEFAULT_ENTITY_CONFIG_PARAM": self.entity_config_param.parameter_name, 182 | }, 183 | index="main.py", 184 | handler="handler", 185 | memory_size=1024, 186 | role=lambda_role, 187 | runtime=LambdaRuntime.PYTHON_3_9, 188 | timeout=Duration.seconds(120), 189 | ) 190 | 191 | self.sfn_task = sfn_tasks.LambdaInvoke( 192 | self, 193 | "PostProcess", 194 | comment="Post-Process the enriched Textract data to your business-level fields", 195 | lambda_function=self.caller_lambda, 196 | payload=sfn.TaskInput.from_object( 197 | { 198 | "Input": { 199 | "Bucket": sfn.JsonPath.string_at("$.Textract.Bucket"), 200 | "Key": sfn.JsonPath.string_at("$.Textract.Key"), 201 | }, 202 | } 203 | ), 204 | payload_response_only=True, 205 | result_path="$.ModelResult", 206 | ) 207 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/main.py: -------------------------------------------------------------------------------- 
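# Illustrative only: after deployment, the entity/field configuration read by this function can
# be changed by overwriting the SSM parameter created by LambdaPostprocStep above (the parameter
# name below is a made-up example):
#
#   import boto3, json
#   boto3.client("ssm").put_parameter(
#       Name="/my-project/config/EntityConfiguration",
#       Value=json.dumps(my_entity_config),
#       Type="String",
#       Overwrite=True,
#   )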
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Lambda to extract business fields from SageMaker-enriched Textract result 4 | 5 | Env var DEFAULT_ENTITY_CONFIG (or the contents of SSM parameter whose name is given by 6 | DEFAULT_ENTITY_CONFIG_PARAM) should be a JSON list of objects roughly as: 7 | 8 | ```python 9 | [ 10 | { 11 | "ClassId": 0, # (int, required) ID of the class per SageMaker model 12 | "Name": "...", # (str, required) Human-readable name of the entity/field 13 | "Ignore": True, # (bool, optional) Set true to ignore detections of this field 14 | "Optional": True, # (bool, optional) Set true to indicate param is optional 15 | "Select": "..." # (str, optional) name util.config.FieldSelectors 16 | } 17 | ] 18 | ``` 19 | 20 | If "Select" is not specified, the field is determined to be multi-value. 21 | 22 | For full entity configuration details, refer to util.config.FieldConfiguration 23 | """ 24 | 25 | # Python Built-Ins: 26 | from functools import reduce 27 | import json 28 | import logging 29 | import os 30 | 31 | # External Dependencies: 32 | import boto3 # General-purpose AWS SDK for Python 33 | import trp # Amazon Textract Response Parser 34 | 35 | # Set up logging before local imports: 36 | logger = logging.getLogger() 37 | logger.setLevel(logging.INFO) 38 | 39 | # Local Dependencies 40 | from util.config import FieldConfiguration 41 | from util.extract import extract_entities 42 | from util.normalize import normalize_detections 43 | 44 | 45 | s3 = boto3.resource("s3") 46 | ssm = boto3.client("ssm") 47 | 48 | DEFAULT_ENTITY_CONFIG = os.environ.get("DEFAULT_ENTITY_CONFIG") 49 | if DEFAULT_ENTITY_CONFIG is not None: 50 | DEFAULT_ENTITY_CONFIG = json.loads(DEFAULT_ENTITY_CONFIG) 51 | DEFAULT_ENTITY_CONFIG_PARAM = os.environ.get("DEFAULT_ENTITY_CONFIG_PARAM") 52 | 53 | 54 | class MalformedRequest(ValueError): 55 | pass 56 | 57 | 58 | def handler(event, context): 59 | try: 60 | srcbucket = event["Input"]["Bucket"] 61 | srckey = event["Input"]["Key"] 62 | entity_config = event.get("EntityConfig", DEFAULT_ENTITY_CONFIG) 63 | if entity_config is None and DEFAULT_ENTITY_CONFIG_PARAM: 64 | entity_config = json.loads( 65 | ssm.get_parameter(Name=DEFAULT_ENTITY_CONFIG_PARAM)["Parameter"]["Value"] 66 | ) 67 | except KeyError as ke: 68 | raise MalformedRequest(f"Missing field {ke}, please check your input payload") from ke 69 | if entity_config is None: 70 | raise MalformedRequest( 71 | "Request did not specify EntityConfig, and neither env var DEFAULT_ENTITY_CONFIG (for " 72 | "inline json) nor DEFAULT_ENTITY_CONFIG_PARAM (for SSM parameter) are set" 73 | ) 74 | entity_config = [FieldConfiguration.from_dict(cfg) for cfg in entity_config] 75 | 76 | doc = json.loads(s3.Bucket(srcbucket).Object(srckey).get()["Body"].read()) 77 | doc = trp.Document(doc) 78 | 79 | # Pull out the entities from the Amazon Textract-format doc: 80 | entities = extract_entities(doc, entity_config) 81 | # Normalize entity values, if any per-type normalizations are configured: 82 | normalize_detections(entities, entity_config) 83 | 84 | result_fields = {} 85 | for ixtype, cfg in enumerate(cfg for cfg in entity_config if not cfg.ignore): 86 | # Filter the list of detected entity mentions for this class only: 87 | field_entities = list(filter(lambda e: e.cls_id == cfg.class_id, entities)) 88 | 89 | # Consolidate multiple detections of exactly the same value (text): 90 | field_values = {} 91 | for ixe, e in enumerate(field_entities): 92 | if e.text 
in field_values: 93 | field_values[e.text]["Detections"].append(e) 94 | field_values[e.text]["IxLastDetection"] = ixe 95 | else: 96 | field_values[e.text] = { 97 | "Text": e.text, 98 | "Detections": [e], 99 | "IxFirstDetection": ixe, 100 | "IxLastDetection": ixe, 101 | } 102 | field_values_list = [v for v in field_values.values()] 103 | # To approximate confidence for values detected multiple times, model each detection as an 104 | # uncorrelated observation of that value (naive, probably biased to over-estimate): 105 | for v in field_values_list: 106 | # e.g. {0.84, 0.86, 0.90} -> 1 - (0.16 * 0.14 * 0.1) = 0.998 107 | v["Confidence"] = 1 - reduce( 108 | lambda acc, next: acc * (1 - next.confidence), 109 | v["Detections"], 110 | 1.0, 111 | ) 112 | # TODO: Adjust for other (disagreeing) confidences better 113 | value_conf_norm = reduce(lambda acc, next: acc + next["Confidence"], field_values_list, 0.0) 114 | for v in field_values_list: 115 | v["Confidence"] = v["Confidence"] / max(1.0, value_conf_norm) 116 | 117 | field_result = { 118 | "ClassId": cfg.class_id, 119 | "Confidence": 0.0, 120 | "NumDetections": len(field_entities), 121 | "NumDetectedValues": len(field_values), 122 | "SortOrder": ixtype, 123 | } 124 | result_fields[cfg.name] = field_result 125 | if cfg.optional is not None: 126 | field_result["Optional"] = cfg.optional 127 | 128 | if cfg.select is not None: 129 | # Single-valued field: Select 'best' matched values: 130 | selector = cfg.select 131 | field_values_sorted = sorted( 132 | field_values_list, 133 | key=selector.sort, 134 | reverse=selector.desc, 135 | ) 136 | if len(field_values_sorted): 137 | field_result["Value"] = field_values_sorted[0]["Text"] 138 | field_result["Confidence"] = field_values_sorted[0]["Confidence"] 139 | field_result["Detections"] = list( 140 | map( 141 | lambda e: e.to_dict(), 142 | field_values_sorted[0]["Detections"], 143 | ) 144 | ) 145 | else: 146 | field_result["Value"] = "" 147 | field_result["Detections"] = [] 148 | else: 149 | # Multi-valued field: Pass through all matched values 150 | field_result["Values"] = list( 151 | map( 152 | lambda v: { 153 | "Confidence": v["Confidence"], 154 | "Value": v["Text"], 155 | "Detections": list( 156 | map( 157 | lambda e: e.to_dict(), 158 | v["Detections"], 159 | ) 160 | ), 161 | }, 162 | sorted(field_values_list, key=lambda v: v["Confidence"], reverse=True), 163 | ) 164 | ) 165 | if len(field_result["Values"]): 166 | # For multi value, take field confidence = average value confidence 167 | field_result["Confidence"] = reduce( 168 | lambda acc, next: acc + next["Confidence"], 169 | field_result["Values"], 170 | 0.0, 171 | ) / len(field_result["Values"]) 172 | 173 | return { 174 | "Confidence": min( 175 | r["Confidence"] 176 | for r in result_fields.values() 177 | if not (r["Confidence"] == 0 and r.get("Optional")) 178 | ), 179 | "Fields": result_fields, 180 | } 181 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/requirements.txt: -------------------------------------------------------------------------------- 1 | amazon-textract-response-parser~=0.1 2 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/util/__init__.py: -------------------------------------------------------------------------------- 
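# Illustrative shape of the handler() result above (field names follow the default entity
# configuration; all numbers and values are made up):
#
#   {
#     "Confidence": 0.91,
#     "Fields": {
#       "Provider Name": {
#         "ClassId": 16,
#         "Confidence": 0.97,
#         "NumDetections": 3,
#         "NumDetectedValues": 1,
#         "SortOrder": 15,
#         "Value": "AnyCompany Bank",
#         "Detections": [...]
#       },
#       ...
#     }
#   }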
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/pipeline/postprocessing/fn-postprocess/util/__init__.py -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/util/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Document field/entity configuration definition utilities 4 | """ 5 | # Python Built-Ins: 6 | from enum import Enum 7 | from typing import Callable, Optional 8 | 9 | # Local Dependencies: 10 | from .deser import PascalJsonableDataClass 11 | 12 | 13 | class FieldSelectionMethod: 14 | def __init__(self, name: str, sort: Callable, desc: bool = False): 15 | self.name = name 16 | self.sort = sort 17 | self.desc = desc 18 | 19 | def to_dict(self): 20 | return self.name 21 | 22 | 23 | class FieldSelectionMethods(Enum): 24 | CONFIDENCE = FieldSelectionMethod("confidence", lambda v: v["Confidence"], desc=True) 25 | FIRST = FieldSelectionMethod("first", lambda v: v["IxFirstDetection"]) 26 | LAST = FieldSelectionMethod("last", lambda v: v["IxLastDetection"], desc=True) 27 | LONGEST = FieldSelectionMethod("longest", lambda v: len(v["Text"]), desc=True) 28 | SHORTEST = FieldSelectionMethod("shortest", lambda v: len(v["Text"])) 29 | 30 | 31 | class FieldConfiguration(PascalJsonableDataClass): 32 | """A JSON-serializable configuration for a field/entity type""" 33 | 34 | def __init__( 35 | self, 36 | class_id: int, 37 | name: str, 38 | ignore: Optional[bool] = None, 39 | optional: Optional[bool] = None, 40 | select: Optional[str] = None, 41 | annotation_guidance: Optional[str] = None, 42 | normalizer_endpoint: Optional[str] = None, 43 | normalizer_prompt: Optional[str] = None, 44 | ): 45 | """Create a FieldConfiguration 46 | 47 | Parameters 48 | ---------- 49 | class_id : int 50 | The ID number (ordinal) of the class per the machine learning model 51 | name : str 52 | The human-readable name of the class / entity type 53 | ignore : Optional[bool] 54 | Set True to exclude this field from post-processing in the OCR pipeline (the ML model 55 | will still be trained on it). Useful if for e.g. testing a new field type with unknown 56 | detection quality. 57 | optional : Optional[bool] 58 | Set True to explicitly indicate the field is optional (default None) 59 | select : Optional[str] 60 | A (case insensitive) name from the FieldSelectionMethods enum (e.g. 'confidence') to 61 | indicate how the "winning" detected value of a field should be selected. If omitted, 62 | the field is treated as multi-value and all detected values passed through. 63 | annotation_guidance : Optional[str] 64 | HTML-tagged guidance detailing the specific scope for this entity: I.e. what should 65 | and should not be included for consistent labelling. 66 | normalizer_endpoint : Optional[str] 67 | An optional deployed SageMaker seq2seq endpoint for field value normalization, if one 68 | should be used (You'll have to train and deploy this endpoint separately). 69 | normalizer_prompt : Optional[str] 70 | The prompting prefix for the seq2seq field value normalization requests on this field, 71 | if enabled. 
For example, "Convert dates to YYYY-MM-DD: " 72 | """ 73 | self.class_id = class_id 74 | self.name = name 75 | self.ignore = ignore 76 | self.optional = optional 77 | self.annotation_guidance = annotation_guidance 78 | self.normalizer_endpoint = normalizer_endpoint 79 | self.normalizer_prompt = normalizer_prompt 80 | try: 81 | self.select = FieldSelectionMethods[select.upper()].value if select else None 82 | except KeyError as e: 83 | raise ValueError( 84 | "Selection method '{}' configured for field '{}' not in the known list {}".format( 85 | select, 86 | name, 87 | [fsm.name for fsm in FieldSelectionMethods], 88 | ) 89 | ) from e 90 | if bool(self.normalizer_endpoint) ^ bool(self.normalizer_prompt): 91 | raise ValueError( 92 | "Cannot provide only one of `normalizer_endpoint` and `normalizer_prompt` without " 93 | "setting both. Got: '%s' and '%s'" 94 | % (self.normalizer_endpoint, self.normalizer_prompt) 95 | ) 96 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/util/deser.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Utilities for de/serializing typed Python objects to JSON-able dicts and actual JSON strings. 4 | """ 5 | # Python Built-Ins: 6 | import json 7 | import re 8 | from typing import Iterable, Optional 9 | 10 | 11 | def pascal_to_snake_case(s: str) -> str: 12 | """Convert a string from PascalCase to snake_case""" 13 | if not s: 14 | return s 15 | # Interpret sequences of 2+ uppercase chars as acronyms to be wordified. e.g. HTML -> Html 16 | result = re.sub(r"([A-Z])([A-Z]+)", lambda m: m.group(1) + m.group(2).lower(), s) 17 | # Find any uppercase character(s) following a lowercase character, insert an underscore and 18 | # convert. E.g. 
aA -> a_a, MyHtmlThing -> my_html_thing 19 | # Replace any aA combo with a_a: 20 | result = re.sub( 21 | r"([^A-Z])([A-Z]+)", 22 | lambda m: "_".join((m.group(1), m.group(2).lower())), 23 | result, 24 | ) 25 | # Force lowercase first char: 26 | return result[0].lower() + result[1:] 27 | 28 | 29 | def snake_to_pascal_case(s: str) -> str: 30 | """Convert a string from snake_case to PascalCase""" 31 | if not s: 32 | return s 33 | return "".join( 34 | map( 35 | lambda segment: (segment[0].upper() + segment[1:]) if segment else segment, 36 | s.split("_"), 37 | ), 38 | ) 39 | 40 | 41 | class PascalJsonableDataClass: 42 | """Mixin to make a class with snake_case attrs interop with JSON/dicts with PascalCase attrs 43 | 44 | from_dict maps dict keys to constructor args { "MyProp": 1 } -> __init__(my_prop=1) 45 | 46 | to_dict maps data properties (as enumerated by __dict__) to dict keys 47 | 48 | from_json/to_json methods simply wrap the above with json.loads() / json.dumps() 49 | """ 50 | 51 | @classmethod 52 | def from_dict(cls, d: dict): 53 | kwargs = {pascal_to_snake_case(k): v for k, v in d.items()} 54 | return cls(**kwargs) 55 | 56 | @classmethod 57 | def from_json(cls, s: str): 58 | return cls.from_dict(json.loads(s)) 59 | 60 | def to_dict(self, omit: Optional[Iterable[str]] = None): 61 | if not omit: 62 | omit = [] 63 | return { 64 | snake_to_pascal_case(attr): value.to_dict() if hasattr(value, "to_dict") else value 65 | for attr, value in filter( 66 | lambda kv: not (kv[1] is None or kv[0].startswith("_") or kv[0] in omit), 67 | self.__dict__.items(), 68 | ) 69 | } 70 | 71 | def to_json(self, omit: Optional[Iterable[str]] = None): 72 | return json.dumps(self.to_dict(omit=omit)) 73 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/util/extract.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Utils to extract entity mentions from SageMaker Textract WORD-tagging model results 4 | 5 | As a simple heuristic, consecutive WORD blocks of the same tagged entity class are tagged as 6 | belonging to the same mention. This means that in cases where the normal human reading order 7 | diverges from the Amazon Textract block output order, mentions may get split up. 8 | """ 9 | # Python Built-Ins: 10 | import json 11 | from typing import List, Optional, Sequence 12 | 13 | # External Dependencies: 14 | import trp # Amazon Textract Response Parser 15 | 16 | # Local Dependencies: 17 | from .boxes import UniversalBox 18 | from .config import FieldConfiguration 19 | 20 | 21 | class EntityDetection: 22 | """Object describing an entity mention in a document 23 | 24 | If property `raw_text` (or 'RawText' in the JSON-ified equivalent) is set, this mention has 25 | been normalized. Otherwise, `text` is as per the original document. 
26 | """ 27 | 28 | raw_text: Optional[str] 29 | 30 | def __init__(self, trp_words: Sequence[trp.Word], cls_id: int, cls_name: str, page_num: int): 31 | self.cls_id = cls_id 32 | self.cls_name = cls_name 33 | self.page_num = page_num 34 | 35 | if len(trp_words) and not hasattr(trp_words[0], "id"): 36 | trp_words_by_line = trp_words 37 | trp_words_flat = [w for ws in trp_words for w in ws] 38 | 39 | else: 40 | trp_words_by_line = [trp_words] 41 | trp_words_flat = trp_words 42 | self.bbox = UniversalBox.aggregate( 43 | boxes=[UniversalBox(box=w.geometry.boundingBox) for w in trp_words_flat], 44 | ) 45 | self.blocks = list(map(lambda w: w.id, trp_words_flat)) 46 | self.confidence = min( 47 | map( 48 | lambda w: min( 49 | w._block.get("PredictedClassConfidence", 1.0), 50 | w.confidence, 51 | ), 52 | trp_words_flat, 53 | ) 54 | ) 55 | self.text = "\n".join( 56 | map( 57 | lambda words: " ".join([w.text for w in words]), 58 | trp_words_by_line, 59 | ) 60 | ) 61 | self.raw_text = None 62 | 63 | def normalize(self, normalized_text: str) -> None: 64 | """Update the detection with a new normalized text value 65 | 66 | Only the original raw_text value will be preserved, so if you normalize() multiple times no 67 | record of the intermediate normalized_text values will be kept. 68 | """ 69 | if self.raw_text is None: 70 | self.raw_text = self.text 71 | # Otherwise keep original 'raw' text (normalize called multiple times) 72 | self.text = normalized_text 73 | 74 | def to_dict(self) -> dict: 75 | """Represent this mention as a PascalCase JSON-able object""" 76 | result = { 77 | "ClassId": self.cls_id, 78 | "ClassName": self.cls_name, 79 | "Confidence": self.confidence, 80 | "Blocks": self.blocks, 81 | "BoundingBox": self.bbox.to_dict(), 82 | "PageNum": self.page_num, 83 | "Text": self.text, 84 | } 85 | if self.raw_text is not None: 86 | result["RawText"] = self.raw_text 87 | return result 88 | 89 | def __repr__(self) -> str: 90 | return json.dumps(self.to_dict()) 91 | 92 | 93 | def extract_entities( 94 | doc: trp.Document, 95 | entity_config: List[FieldConfiguration], 96 | ) -> List[EntityDetection]: 97 | """Collect EntityDetections from an NER-enriched Textract JSON doc into a flat list""" 98 | entity_classes = {c.class_id: c.name for c in entity_config if not c.ignore} 99 | detections = [] 100 | 101 | current_cls = None 102 | current_entity = [] 103 | for ixpage, page in enumerate(doc.pages): 104 | for line in page.lines: # TODO: Lines InReadingOrder? 105 | current_entity.append([]) 106 | for word in line.words: 107 | pred_cls = word._block.get("PredictedClass") 108 | if pred_cls not in entity_classes: 109 | pred_cls = None # Treat all non-config'd entities as "other" 110 | 111 | if pred_cls != current_cls: 112 | if current_cls is not None: 113 | detections.append( 114 | EntityDetection( 115 | trp_words=list( 116 | filter( 117 | lambda ws: len(ws), 118 | current_entity, 119 | ) 120 | ), 121 | cls_id=current_cls, 122 | cls_name=entity_classes[current_cls], 123 | page_num=ixpage + 1, 124 | ) 125 | ) 126 | current_cls = pred_cls 127 | current_entity = [[]] if pred_cls is None else [[word]] 128 | elif pred_cls is not None: 129 | current_entity[-1].append(word) 130 | 131 | return detections 132 | -------------------------------------------------------------------------------- /pipeline/postprocessing/fn-postprocess/util/normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Utils to normalize detected entity text by calling SageMaker sequence-to-sequence model endpoints 4 | 5 | `normalizer_endpoint` on a FieldConfiguration is assumed to be a deployed real-time SageMaker 6 | endpoint that accepts batched 'application/json' requests of structure: 7 | `{"inputs": ["list", "of", "strings"]}`, and returns 'application/json' responses of structure: 8 | `{"generated_text": ["corresponding", "result", "strings"]}` 9 | """ 10 | # Python Built-Ins: 11 | import json 12 | from logging import getLogger 13 | from typing import Dict, List, Sequence 14 | 15 | # External Dependencies: 16 | import boto3 # General-purpose AWS SDK for Python 17 | 18 | # Local Dependencies: 19 | from .config import FieldConfiguration 20 | from .extract import EntityDetection 21 | 22 | logger = getLogger("postproc") 23 | smruntime = boto3.client("sagemaker-runtime") 24 | 25 | 26 | def normalize_detections( 27 | detections: Sequence[EntityDetection], 28 | entity_config: Sequence[FieldConfiguration], 29 | ) -> None: 30 | """Normalize detected entities in-place via batched requests to SageMaker normalizer endpoint(s) 31 | 32 | Due to the high likelihood of one document featuring multiple matches of the same text for the 33 | same entity class, we de-duplicate requests by target endpoint and input text - and duplicate 34 | the result across all linked detections. 35 | """ 36 | entity_config_by_clsid = {c.class_id: c for c in entity_config if not c.ignore} 37 | 38 | # Batched normalization requests: 39 | # - By target endpoint name 40 | # - By input text (after adding prompt prefix) 41 | # - List of which detections (indexes) correspond to the request 42 | norm_requests: Dict[str, Dict[str, List[int]]] = {} 43 | 44 | # Collect required normalization requests from the detections: 45 | for ixdet, detection in enumerate(detections): 46 | config = entity_config_by_clsid.get(detection.cls_id) 47 | if not config: 48 | continue # Ignore any detections in non-configured classes 49 | if not config.normalizer_endpoint: 50 | continue # This entity class configuration has no normalizer 51 | if config.normalizer_endpoint not in norm_requests: 52 | norm_requests[config.normalizer_endpoint] = {} 53 | 54 | norm_input_text = config.normalizer_prompt + detection.text 55 | if norm_input_text in norm_requests[config.normalizer_endpoint]: 56 | norm_requests[config.normalizer_endpoint][norm_input_text].append(ixdet) 57 | else: 58 | norm_requests[config.normalizer_endpoint][norm_input_text] = [ixdet] 59 | 60 | # Call out to the SageMaker endpoints and update the detections with the results: 61 | for endpoint_name in norm_requests: 62 | req_dict = norm_requests[endpoint_name] 63 | input_texts = [k for k in req_dict] 64 | try: 65 | norm_resp = smruntime.invoke_endpoint( 66 | EndpointName=endpoint_name, 67 | Body=json.dumps( 68 | { 69 | "inputs": input_texts, 70 | } 71 | ), 72 | ContentType="application/json", 73 | Accept="application/json", 74 | ) 75 | # Response should be JSON dict containing list 'generated_text' of outputs: 76 | output_texts = json.loads(norm_resp["Body"].read())["generated_text"] 77 | except Exception: 78 | # Log the failure, but continue on: 79 | logger.exception( 80 | "Entity normalization call failed: %s texts to endpoint '%s'", 81 | len(input_texts), 82 | endpoint_name, 83 | ) 84 | continue 85 | 86 | for ixtext, output in enumerate(output_texts): 87 | for ixdetection in req_dict[input_texts[ixtext]]: 88 | detections[ixdetection].normalize(output) 89 | 90 | # 
Return nothing to explicitly indicate that detections are modified in-place 91 | return 92 | -------------------------------------------------------------------------------- /pipeline/review/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """CDK for human review stage of the OCR pipeline 4 | """ 5 | # Python Built-Ins: 6 | from typing import Union 7 | 8 | # External Dependencies: 9 | from aws_cdk import Duration, Token 10 | from aws_cdk.aws_iam import Effect, PolicyStatement, Role, ServicePrincipal 11 | from aws_cdk.aws_lambda import Runtime as LambdaRuntime 12 | from aws_cdk.aws_lambda_python_alpha import PythonFunction 13 | from aws_cdk.aws_s3 import Bucket, EventType 14 | import aws_cdk.aws_s3_notifications as s3n 15 | import aws_cdk.aws_ssm as ssm 16 | import aws_cdk.aws_stepfunctions as sfn 17 | import aws_cdk.aws_stepfunctions_tasks as sfn_tasks 18 | from constructs import Construct 19 | 20 | # Local Dependencies: 21 | from ..shared import abs_path 22 | from ..shared.sagemaker import get_sagemaker_default_bucket 23 | 24 | 25 | START_REVIEW_LAMBDA_PATH = abs_path("fn-start-review", __file__) 26 | REVIEW_CALLBACK_LAMBDA_PATH = abs_path("fn-review-callback", __file__) 27 | 28 | 29 | class A2IReviewStep(Construct): 30 | """CDK construct for an OCR pipeline step to have humans review extracted document fields 31 | 32 | This construct's `.sfn_task` expects inputs with a $.ModelResult object and an input document 33 | specified by $.Input.Bucket and $.Input.Key. An Amazon A2I Human Review Loop will be triggered 34 | to manually review the model's prediction, and the Step Function execution will be resumed when 35 | the review is complete, with an updated $.ModelResult in the output state. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | scope: Construct, 41 | id: str, 42 | lambda_role: Role, 43 | input_bucket: Bucket, 44 | reviews_bucket: Bucket, 45 | ssm_param_prefix: Union[Token, str], 46 | **kwargs, 47 | ): 48 | super().__init__(scope, id, **kwargs) 49 | 50 | self.workflow_param = ssm.StringParameter( 51 | self, 52 | "A2IHumanReviewFlowParam", 53 | description="ARN of the Amazon A2I workflow definition to call for human reviews", 54 | parameter_name=f"{ssm_param_prefix}HumanReviewFlowArn", 55 | simple_name=False, 56 | string_value="undefined", 57 | ) 58 | lambda_role.add_to_policy( 59 | PolicyStatement( 60 | sid="ReadSSMWorkflowParam", 61 | actions=["ssm:GetParameter"], 62 | effect=Effect.ALLOW, 63 | resources=[self.workflow_param.parameter_arn], 64 | ) 65 | ) 66 | lambda_role.add_to_policy( 67 | PolicyStatement( 68 | sid="StartAnyA2IHumanLoop", 69 | actions=["sagemaker:StartHumanLoop"], 70 | effect=Effect.ALLOW, 71 | resources=["*"], 72 | ) 73 | ) 74 | 75 | self.start_lambda = PythonFunction( 76 | self, 77 | "StartHumanReview", 78 | description="Kick off A2I human review for OCR processing pipeline", 79 | entry=START_REVIEW_LAMBDA_PATH, 80 | environment={ 81 | "DEFAULT_FLOW_DEFINITION_ARN_PARAM": self.workflow_param.parameter_name, 82 | }, 83 | index="main.py", 84 | handler="handler", 85 | memory_size=128, 86 | role=lambda_role, 87 | runtime=LambdaRuntime.PYTHON_3_9, 88 | timeout=Duration.seconds(10), 89 | ) 90 | self.callback_lambda = PythonFunction( 91 | self, 92 | "HumanReviewCallback", 93 | description="Return A2I human review result to OCR processing pipeline Step Function", 94 | entry=REVIEW_CALLBACK_LAMBDA_PATH, 95 | index="main.py", 96 | handler="handler", 97 | memory_size=128, 98 | role=lambda_role, 99 | runtime=LambdaRuntime.PYTHON_3_9, 100 | timeout=Duration.seconds(60), 101 | ) 102 | self.a2i_role = Role( 103 | self, 104 | "ProcessingPipelineA2IRole", 105 | assumed_by=ServicePrincipal("sagemaker.amazonaws.com"), 106 | description="Execution Role for Amazon A2I human review workflows", 107 | ) 108 | input_bucket.grant_read(self.a2i_role) 109 | reviews_bucket.grant_read_write(self.a2i_role) 110 | get_sagemaker_default_bucket(self).grant_read_write(self.a2i_role) 111 | 112 | reviews_bucket.add_event_notification( 113 | dest=s3n.LambdaDestination(self.callback_lambda), 114 | event=EventType.OBJECT_CREATED, 115 | ) 116 | 117 | self.sfn_task = sfn_tasks.LambdaInvoke( 118 | self, 119 | "HumanReview", 120 | comment="Run an Amazon A2I human loop to review the annotations manually", 121 | lambda_function=self.start_lambda, 122 | integration_pattern=sfn.IntegrationPattern.WAIT_FOR_TASK_TOKEN, 123 | payload=sfn.TaskInput.from_object( 124 | { 125 | "ModelResult.$": "$.ModelResult", 126 | # TODO: Can we add a pass state with Parameters to filter the inputs? 127 | "TaskObject": { 128 | "Bucket": sfn.JsonPath.string_at("$.Input.Bucket"), 129 | "Key": sfn.JsonPath.string_at("$.Input.Key"), 130 | }, 131 | "TaskToken": sfn.JsonPath.task_token, 132 | } 133 | ), 134 | result_path="$.ModelResult", 135 | timeout=Duration.minutes(20), 136 | ) 137 | -------------------------------------------------------------------------------- /pipeline/review/fn-start-review/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Lambda to start an A2I human loop to review a non-confident model output 4 | 5 | Should be called as an *asynchronous* task from Step Functions (using lambda:invoke.waitForToken) 6 | 7 | By passing the Step Functions task token to the A2I task as input, we ensure it gets included in 8 | the output JSON generated by the task and therefore enable our S3-triggered callback function to 9 | retrieve the task token and signal to Step Functions that the review is complete. 10 | """ 11 | 12 | # Python Built-Ins: 13 | from datetime import datetime 14 | import json 15 | import logging 16 | import os 17 | import re 18 | import uuid 19 | 20 | # External Dependencies: 21 | import boto3 22 | 23 | 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | a2i = boto3.client("sagemaker-a2i-runtime") 27 | ssm = boto3.client("ssm") 28 | 29 | default_flow_definition_arn_param = os.environ.get("DEFAULT_FLOW_DEFINITION_ARN_PARAM") 30 | 31 | 32 | class MalformedRequest(ValueError): 33 | """Returned to SFN when input event structure is invalid""" 34 | 35 | pass 36 | 37 | 38 | def generate_human_loop_name(s3_object_key: str, max_len: int = 63) -> str: 39 | """Create a random-but-a-bit-meaningful unique name for human loop job 40 | 41 | Generated names combine timestamp, object filename, and a random element. 42 | """ 43 | filename = s3_object_key.rpartition("/")[2] 44 | filename_component = re.sub( 45 | # Condense double-hyphens: 46 | r"--", 47 | "-", 48 | re.sub( 49 | # Cut out any remaining disallowed characters: 50 | r"[^a-zA-Z0-9\-]", 51 | "", 52 | re.sub( 53 | # Turn significant punctuation to hyphens: 54 | r"[ _.,!?]", 55 | "-", 56 | filename, 57 | ), 58 | ), 59 | ) 60 | 61 | # Millis is enough, no need for microseconds: 62 | datetime_component = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3] 63 | # Most significant bits section of a GUID: 64 | random_component = str(uuid.uuid4()).partition("-")[0] 65 | 66 | clipped_filename_component = filename_component[ 67 | : max_len - len(datetime_component) - len(random_component) 68 | ] 69 | 70 | return f"{datetime_component}-{clipped_filename_component}-{random_component}"[:max_len] 71 | 72 | 73 | def handler(event, context): 74 | try: 75 | task_token = event["TaskToken"] 76 | model_result = event["ModelResult"] 77 | task_object = event["TaskObject"] 78 | if isinstance(task_object, dict): 79 | if "S3Uri" in task_object and task_object["S3Uri"]: 80 | task_object = task_object["S3Uri"] 81 | elif "Bucket" in task_object and "Key" in task_object: 82 | task_object = f"s3://{task_object['Bucket']}/{task_object['Key']}" 83 | else: 84 | raise MalformedRequest( 85 | "TaskObject must be an s3://... URI string OR an object with 'S3Uri' key or " 86 | f"both 'Bucket' and 'Key' keys. Got {task_object}" 87 | ) 88 | 89 | task_input = { 90 | "TaskObject": task_object, 91 | "TaskToken": task_token, # Not used within A2I, but for feed-through to callback fn 92 | "ModelResult": model_result, 93 | } 94 | 95 | if "FlowDefinitionArn" in event: 96 | flow_definition_arn = event["FlowDefinitionArn"] 97 | elif default_flow_definition_arn_param: 98 | flow_definition_arn = ssm.get_parameter( 99 | Name=default_flow_definition_arn_param, 100 | )[ 101 | "Parameter" 102 | ]["Value"] 103 | if (not flow_definition_arn) or flow_definition_arn.lower() in ("undefined", "null"): 104 | raise MalformedRequest( 105 | "Neither request FlowDefinitionArn nor expected SSM parameter are set. 
Got: " 106 | f"{default_flow_definition_arn_param} = '{flow_definition_arn}'" 107 | ) 108 | else: 109 | raise MalformedRequest( 110 | "FlowDefinitionArn not specified in request and DEFAULT_FLOW_DEFINITION_ARN_PARAM " 111 | "env var not set" 112 | ) 113 | except KeyError as ke: 114 | raise MalformedRequest(f"Missing field {ke}, please check your input payload") 115 | 116 | logger.info(f"Starting A2I human loop with input {task_input}") 117 | a2i_response = a2i.start_human_loop( 118 | HumanLoopName=generate_human_loop_name(task_input["TaskObject"]), 119 | FlowDefinitionArn=flow_definition_arn, 120 | HumanLoopInput={"InputContent": json.dumps(task_input)}, 121 | # If adapting this code for use with A2I public workforce, you may need to add additional 122 | # content classifiers as described here: 123 | # https://docs.aws.amazon.com/sagemaker/latest/dg/sms-workforce-management-public.html 124 | # https://docs.aws.amazon.com/augmented-ai/2019-11-07/APIReference/API_HumanLoopDataAttributes.html 125 | # DataAttributes={ 126 | # "ContentClassifiers": ["FreeOfPersonallyIdentifiableInformation"] 127 | # } 128 | ) 129 | logger.info(f"Human loop started: {a2i_response}") 130 | 131 | # Doesn't really matter what we return because Step Functions will wait for the callback with 132 | # the token! 133 | return a2i_response["HumanLoopArn"] 134 | -------------------------------------------------------------------------------- /pipeline/shared/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """Shared utilities for the pipeline CDK app""" 4 | # Python Built-Ins: 5 | import os 6 | from typing import Union 7 | 8 | 9 | def abs_path(rel_path: Union[str, os.PathLike], from__file__: str) -> str: 10 | """Construct an absolute path from a relative path and current __file__ location""" 11 | return os.path.normpath( 12 | os.path.join( 13 | os.path.dirname(os.path.realpath(from__file__)), 14 | rel_path, 15 | ) 16 | ) 17 | -------------------------------------------------------------------------------- /pipeline/shared/sagemaker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """CDK constructs and utilities for working with Amazon SageMaker 4 | """ 5 | # External Dependencies: 6 | from aws_cdk import Stack 7 | from aws_cdk.aws_s3 import Bucket 8 | from constructs import Construct 9 | 10 | # Local Dependencies: 11 | from .model_deployment import ( 12 | EndpointAutoscaler, 13 | SageMakerAsyncInferenceConfig, 14 | SageMakerAutoscalingRole, 15 | SageMakerCustomizedDLCModel, 16 | SageMakerDLCBasedImage, 17 | SageMakerDLCSpec, 18 | SageMakerEndpointExecutionRole, 19 | SageMakerModelDeployment, 20 | ) 21 | from .sagemaker_sfn import SageMakerCallerFunction, SageMakerSSMStep 22 | 23 | 24 | def get_sagemaker_default_bucket(scope: Construct) -> Bucket: 25 | """Generate a CDK S3.Bucket construct for the (assumed pre-existing) SageMaker Default Bucket""" 26 | stack = Stack.of(scope) 27 | return Bucket.from_bucket_arn( 28 | scope, 29 | "SageMakerDefaultBucket", 30 | f"arn:{stack.partition}:s3:::sagemaker-{stack.region}-{stack.account}", 31 | ) 32 | -------------------------------------------------------------------------------- /pipeline/shared/sagemaker/fn-call-sagemaker/requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools>=5.0.0,<6.0.0 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | extend-exclude = "^/(cdk\\.out|setup\\.py)" 3 | line-length = 100 4 | 5 | [tool.poetry] 6 | name = "amazon-textract-transformer-pipeline" 7 | version = "0.2.1" 8 | description = "Post-processing Amazon Textract with Transformer-Based Models on Amazon SageMaker" 9 | authors = ["Amazon Web Services"] 10 | license = "MIT-0" 11 | 12 | [tool.poetry.dependencies] 13 | # numpy 1.25 requires Python <3.12 as per https://stackoverflow.com/a/77935901/13352657 14 | python = "^3.9,<3.12" 15 | aws-cdk-lib = "^2.126.0" 16 | "aws-cdk.aws-lambda-python-alpha" = "^2.126.0-alpha.0" 17 | boto3 = "^1.34.33" 18 | botocore = "^1.34.33" 19 | cdk-ecr-deployment = "^3.0.13" 20 | constructs = "^10.3.0" 21 | sagemaker = ">=2.214.3,<3" 22 | semver = "^3.0.0" 23 | 24 | [tool.poetry.group.dev.dependencies] 25 | black = "^24.3.0" 26 | black-nb = "^0.7.0" 27 | 28 | [build-system] 29 | requires = ["poetry-core>=1.0.0"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
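# --- Editor-added illustrative sketch (not part of the repository source) ---------------------
# Typical use of get_sagemaker_default_bucket() from pipeline/shared/sagemaker/__init__.py, as
# seen in the A2IReviewStep construct earlier; `some_execution_role` is a hypothetical iam.Role.
#
#   bucket = get_sagemaker_default_bucket(self)   # callable from any Construct scope in the stack
#   bucket.grant_read_write(some_execution_role)  # e.g. an A2I or SageMaker execution role
# -----------------------------------------------------------------------------------------------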
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import setuptools 5 | 6 | with open("README.md") as fp: 7 | long_description = fp.read() 8 | 9 | setuptools.setup( 10 | name="amazon-textract-transformer-pipeline", 11 | version="0.2.1", 12 | 13 | description="Post-processing Amazon Textract with Transformer-Based Models on Amazon SageMaker", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | 17 | author="Amazon Web Services", 18 | 19 | packages=["annotation", "pipeline"], 20 | 21 | install_requires=[ 22 | "aws-cdk-lib>=2.126.0,<3", 23 | "aws-cdk.aws-lambda-python-alpha>=2.126.0a0,<3", 24 | "boto3>=1.34.33,<2", 25 | "cdk-ecr-deployment>=3.0.13,<4", 26 | "constructs>=10.3.0,<11", 27 | "sagemaker>=2.205,<3", 28 | ], 29 | extras_require={ 30 | "dev": [ 31 | "black>=22.3.0,<23", 32 | "black-nb>=0.7.0,<0.8", 33 | ] 34 | }, 35 | 36 | python_requires=">=3.9,<3.12", 37 | 38 | classifiers=[ 39 | "Development Status :: 4 - Beta", 40 | 41 | "Intended Audience :: Developers", 42 | 43 | "License :: OSI Approved :: MIT No Attribution License (MIT-0)", 44 | 45 | "Programming Language :: JavaScript", 46 | "Programming Language :: Python :: 3 :: Only", 47 | "Programming Language :: Python :: 3.9", 48 | "Programming Language :: Python :: 3.10", 49 | "Programming Language :: Python :: 3.11", 50 | 51 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 52 | "Topic :: Software Development :: Code Generators", 53 | "Topic :: Utilities", 54 | 55 | "Typing :: Typed", 56 | ], 57 | ) 58 | -------------------------------------------------------------------------------- /source.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | rem The sole purpose of this script is to make the command 4 | rem 5 | rem source .venv/bin/activate 6 | rem 7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows. 8 | rem On Windows, this command just runs this batch file (the argument is ignored). 9 | rem 10 | rem Now we don't need to document a Windows command for activating a virtualenv. 11 | 12 | echo Executing .venv\Scripts\activate.bat for you 13 | .venv\Scripts\activate.bat 14 | --------------------------------------------------------------------------------
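Appendix (editor-added): pipeline/postprocessing/fn-postprocess/util/normalize.py above assumes a deployed SageMaker seq2seq "normalizer" endpoint with a simple JSON contract. The sketch below shows that contract from the caller's side; the endpoint name is a hypothetical placeholder, and the prompt prefix follows the `normalizer_prompt` example from util/config.py.

import json

import boto3

smruntime = boto3.client("sagemaker-runtime")

# Request body: {"inputs": [list of prompt-prefixed texts]}
response = smruntime.invoke_endpoint(
    EndpointName="my-field-normalizer",  # hypothetical normalizer endpoint name
    Body=json.dumps({"inputs": ["Convert dates to YYYY-MM-DD: 4th July 2022"]}),
    ContentType="application/json",
    Accept="application/json",
)

# Response body: {"generated_text": [corresponding output strings]}
outputs = json.loads(response["Body"].read())["generated_text"]
print(outputs)  # e.g. ["2022-07-04"], depending on the trained normalizer model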