├── .editorconfig
├── .github
└── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── .infrastructure
└── attp-bootstrap.cfn.yaml
├── .vscode
└── settings.json
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── CUSTOMIZATION_GUIDE.md
├── LICENSE
├── LICENSE-SUMMARY
├── README.md
├── annotation
├── __init__.py
├── fn-SMGT-Post
│ ├── data_model.py
│ ├── main.py
│ └── smgt.py
└── fn-SMGT-Pre
│ └── main.py
├── cdk.json
├── cdk_app.py
├── cdk_demo_stack.py
├── img
├── annotation-example-trimmed.png
├── architecture-overview.png
├── human-review-sample.png
└── sfn-execution-screenshot.png
├── notebooks
├── 1. Data Preparation.ipynb
├── 2. Model Training.ipynb
├── 3. Human Review.ipynb
├── Optional Extras.ipynb
├── Workshop.ipynb
├── annotation
│ └── ocr-bbox-and-validation.liquid.tpl.html
├── custom-containers
│ ├── preproc
│ │ └── Dockerfile
│ └── train-inf
│ │ └── Dockerfile
├── data
│ └── annotations
│ │ ├── LICENSE
│ │ ├── augmentation-1
│ │ └── manifests
│ │ │ └── output
│ │ │ └── output.manifest
│ │ └── augmentation-2
│ │ └── manifests
│ │ └── output
│ │ └── output.manifest
├── img
│ ├── a2i-custom-template-demo.png
│ ├── cfn-stack-outputs-a2i.png
│ ├── sfn-execution-status-screenshot.png
│ ├── sfn-statemachine-screenshot.png
│ ├── sfn-statemachine-success.png
│ ├── smgt-custom-template-demo.png
│ ├── smgt-find-workforce-url.png
│ ├── smgt-private-workforce.png
│ ├── smgt-task-pending.png
│ ├── ssm-a2i-param-detail.png
│ └── ssm-param-detail-screenshot.png
├── preproc
│ ├── __init__.py
│ ├── ocr.py
│ ├── preproc.py
│ └── textract_transformers
│ │ ├── __init__.py
│ │ ├── file_utils.py
│ │ ├── image_utils.py
│ │ ├── ocr.py
│ │ ├── ocr_engines
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── eng_tesseract.py
│ │ └── preproc.py
├── review
│ ├── .eslintrc.cjs
│ ├── .gitignore
│ ├── .prettierrc.yml
│ ├── README.md
│ ├── env.d.ts
│ ├── fields-validation-legacy.liquid.html
│ ├── index-noliquid.html
│ ├── index.html
│ ├── package-lock.json
│ ├── package.json
│ ├── public
│ │ └── .gitkeep
│ ├── src
│ │ ├── App.vue
│ │ ├── assets
│ │ │ └── base.scss
│ │ ├── components
│ │ │ ├── FieldMultiValue.ce.vue
│ │ │ ├── FieldSingleValue.ce
│ │ │ │ ├── FieldSingleValue.ce.vue
│ │ │ │ └── index.ts
│ │ │ ├── HelloWorld.vue
│ │ │ ├── MultiFieldValue.ce.vue
│ │ │ ├── ObjectValueInput.ts
│ │ │ ├── PdfPageAnnotationLayer.ce.vue
│ │ │ └── Viewer.ce.vue
│ │ ├── main.ts
│ │ └── util
│ │ │ ├── colors.ts
│ │ │ ├── model.d.ts
│ │ │ └── store.ts
│ ├── task-input.example.json
│ ├── tsconfig.json
│ └── vite.config.js
├── src
│ ├── __init__.py
│ ├── code
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── geometry.py
│ │ │ ├── mlm.py
│ │ │ ├── ner.py
│ │ │ ├── seq2seq
│ │ │ │ ├── __init__.py
│ │ │ │ ├── date_normalization.py
│ │ │ │ ├── metrics.py
│ │ │ │ └── task_builder.py
│ │ │ ├── smgt.py
│ │ │ └── splitting.py
│ │ ├── inference.py
│ │ ├── inference_seq2seq.py
│ │ ├── logging_utils.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── layoutlmv2.py
│ │ ├── smddpfix.py
│ │ └── train.py
│ ├── ddp_launcher.py
│ ├── inference.py
│ ├── inference_seq2seq.py
│ ├── requirements.txt
│ ├── smtc_launcher.py
│ └── train.py
└── util
│ ├── __init__.py
│ ├── deployment.py
│ ├── ocr.py
│ ├── postproc
│ ├── preproc.py
│ ├── project.py
│ ├── s3.py
│ ├── smgt.py
│ ├── training.py
│ ├── uid.py
│ └── viz.py
├── package-lock.json
├── package.json
├── pipeline
├── __init__.py
├── config_utils.py
├── enrichment
│ └── __init__.py
├── fn-trigger
│ ├── main.py
│ └── requirements.txt
├── iam_utils.py
├── ocr
│ ├── __init__.py
│ ├── fn-call-textract
│ │ └── main.py
│ ├── sagemaker_ocr.py
│ ├── sfn_semaphore
│ │ ├── __init__.py
│ │ └── fn-acquire-lock
│ │ │ └── main.py
│ └── textract_ocr.py
├── postprocessing
│ ├── __init__.py
│ └── fn-postprocess
│ │ ├── main.py
│ │ ├── requirements.txt
│ │ └── util
│ │ ├── __init__.py
│ │ ├── boxes.py
│ │ ├── config.py
│ │ ├── deser.py
│ │ ├── extract.py
│ │ └── normalize.py
├── review
│ ├── __init__.py
│ ├── fn-review-callback
│ │ └── main.py
│ └── fn-start-review
│ │ └── main.py
├── shared
│ ├── __init__.py
│ └── sagemaker
│ │ ├── __init__.py
│ │ ├── fn-call-sagemaker
│ │ ├── main.py
│ │ └── requirements.txt
│ │ ├── model_deployment.py
│ │ └── sagemaker_sfn.py
└── thumbnails
│ └── __init__.py
├── poetry.lock
├── pyproject.toml
├── requirements.txt
├── setup.py
└── source.bat
/.editorconfig:
--------------------------------------------------------------------------------
1 | # Top-most EditorConfig file:
2 | root = true
3 |
4 | [*]
5 | charset = utf-8
6 | end_of_line = lf
7 | insert_final_newline = true
8 | trim_trailing_whitespace = true
9 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | **Issue #, if available:**
2 |
3 | **Description of changes:**
4 |
5 | **Testing done:**
6 |
7 |
8 |
9 |
10 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # AWS SAM build artifacts
2 | .aws-sam/
3 | *.tmp.*
4 |
5 | # AWS CDK stuff
6 | *.swp
7 | .cdk.staging
8 | cdk.out
9 |
10 | # SageMaker Experiments/Debugger (if try running locally):
11 | tmp_trainer/
12 |
13 | # Working data folders and notebook-built assets:
14 | # (With some specific exclusions)
15 | notebooks/data/*
16 | !notebooks/data/annotations/
17 | notebooks/data/annotations/*
18 | !notebooks/data/annotations/augmentation-*
19 | !notebooks/data/annotations/LICENSE
20 | notebooks/annotation/*.html
21 | !notebooks/annotation/*.tpl.html
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | pip-wheel-metadata/
46 | share/python-wheels/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 | MANIFEST
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .nox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | *.py,cover
67 | .hypothesis/
68 | .pytest_cache/
69 |
70 | # Translations
71 | *.mo
72 | *.pot
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | # Checks and reporting
128 | /*report*
129 |
130 | # Operating Systems
131 | .DS_Store
132 |
133 | # NodeJS
134 | **/node_modules/**
135 | .nvmrc
136 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": ".venv/bin/python3",
3 | "yaml.customTags": [
4 | "!And",
5 | "!And sequence",
6 | "!If",
7 | "!If sequence",
8 | "!Not",
9 | "!Not sequence",
10 | "!Equals",
11 | "!Equals sequence",
12 | "!Or",
13 | "!Or sequence",
14 | "!FindInMap",
15 | "!FindInMap sequence",
16 | "!Base64",
17 | "!Join",
18 | "!Join sequence",
19 | "!Cidr",
20 | "!Ref",
21 | "!Sub",
22 | "!Sub sequence",
23 | "!GetAtt",
24 | "!GetAZs",
25 | "!ImportValue",
26 | "!ImportValue sequence",
27 | "!Select",
28 | "!Select sequence",
29 | "!Split",
30 | "!Split sequence"
31 | ]
32 | }
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 |
3 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
4 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
5 | opensource-codeofconduct@amazon.com with any additional questions or comments.
6 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software is furnished to do so.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15 |
16 |
--------------------------------------------------------------------------------
/LICENSE-SUMMARY:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | This library is licensed under the MIT-0 License. See `LICENSE`.
4 |
5 | Included datasets are licensed under the Creative Commons Attribution 4.0
6 | International License. See `LICENSE` in the `notebooks/data/annotations`
7 | subfolder.
8 |
--------------------------------------------------------------------------------
/annotation/fn-SMGT-Post/main.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Annotation consolidation Lambda for BBoxes+transcriptions in SageMaker Ground Truth
4 | """
5 | # Python Built-Ins:
6 | import json
7 | import logging
8 | from typing import List, Optional
9 |
10 | # External Dependencies:
11 | import boto3 # AWS SDK for Python
12 |
13 | # Set up logger before local imports:
14 | logger = logging.getLogger()
15 | logger.setLevel(logging.INFO)
16 |
17 | # Local Dependencies:
18 | from data_model import SMGTWorkerAnnotation # Custom task data model (edit if needed!)
19 | from smgt import ( # Generic SageMaker Ground Truth parsers/utilities
20 | ConsolidationRequest,
21 | ObjectAnnotationResult,
22 | PostConsolidationDatum,
23 | )
24 |
25 |
26 | s3 = boto3.client("s3")
27 |
28 |
def consolidate_object_annotations(
    object_data: ObjectAnnotationResult,
    label_attribute_name: str,
    label_categories: Optional[List[str]] = None,
) -> PostConsolidationDatum:
    """Consolidate the (potentially multiple) raw worker annotations for a dataset object

    TODO: Actual consolidation/reconciliation of multiple labels is not yet supported!

    Only the "first" (not necessarily clock-first) worker's result is kept; if other workers'
    annotations were present, a warning is logged and recorded on the output label.

    Parameters
    ----------
    object_data :
        Object describing the raw annotations and metadata for a particular task in the SMGT job
    label_attribute_name :
        Target attribute on the output object to store consolidated label results (note this may
        not be the *only* attribute set/updated on the output object, hence provided as a param
        rather than abstracted away).
    label_categories :
        Label categories specified when creating the labelling job. If provided, this is used to
        translate from class names to numeric class_id similarly to SMGT's built-in bounding box
        task result.
    """
    # Fetch and parse every worker's raw annotation payload up-front:
    parsed_anns: List[SMGTWorkerAnnotation] = [
        SMGTWorkerAnnotation.parse(worker_ann.fetch_data(), class_list=label_categories)
        for worker_ann in object_data.annotations
    ]

    warn_msgs: List[str] = []
    if len(parsed_anns) > 1:
        # No reconciliation logic yet: keep worker [0]'s labels but surface the discard loudly.
        warn_msg = (
            "Reconciliation of multiple worker annotations is not currently implemented for this "
            "post-processor. Outputting annotation from worker %s and ignoring labels from %s"
            % (
                object_data.annotations[0].worker_id,
                [a.worker_id for a in object_data.annotations[1:]],
            )
        )
        logger.warning(warn_msg)
        warn_msgs.append(warn_msg)

    consolidated_label = parsed_anns[0].to_jsonable()
    if warn_msgs:
        consolidated_label["consolidationWarnings"] = warn_msgs

    return PostConsolidationDatum(
        dataset_object_id=object_data.dataset_object_id,
        consolidated_content={
            label_attribute_name: consolidated_label,
            # Note: In our tests it's not possible to add a f"{label_attribute_name}-meta" field
            # here - it gets replaced by whatever post-processing happens, instead of merged.
        },
    )
84 |
85 |
def handler(event: dict, context) -> List[dict]:
    """Main Lambda handler for consolidation of SMGT worker annotations

    This function receives a batched request to consolidate (multiple?) workers' annotations for
    multiple objects, and outputs the consolidated results per object. For more docs see:

    https://docs.aws.amazon.com/sagemaker/latest/dg/sms-custom-templates-step3-lambda-requirements.html
    """
    logger.info("Received event: %s", json.dumps(event))
    req = ConsolidationRequest.parse(event)

    label_cats: Optional[List[str]] = None
    if req.label_categories and len(req.label_categories) > 0:
        label_cats = req.label_categories
    else:
        logger.warning(
            "Label categories list (see CreateLabelingJob.LabelCategoryConfigS3Uri) was not "
            "provided when creating this job. Post-consolidation outputs will be incompatible with "
            "built-in Bounding Box task, because we're unable to map class names to numeric IDs."
        )

    # Consolidate annotations object-by-object and emit one JSON-able result per object:
    results: List[dict] = []
    for object_data in req.fetch_object_annotations():
        consolidated = consolidate_object_annotations(
            object_data,
            label_attribute_name=req.label_attribute_name,
            label_categories=label_cats,
        )
        results.append(consolidated.to_jsonable())
    return results
115 |
--------------------------------------------------------------------------------
/annotation/fn-SMGT-Pre/main.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """A minimal Lambda function for pre-processing SageMaker Ground Truth custom annotation tasks
4 |
5 | Just passes through the event's `dataObject` unchanged.
6 | """
import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def handler(event, context):
    """Forward the dataset object unchanged as the worker task input

    SMGT pre-processing Lambdas receive an event containing a `dataObject` and must return a
    `taskInput` for the worker UI template; this minimal implementation is a pure pass-through.
    """
    logger.debug("Got event: %s", event)
    task_input = event["dataObject"]
    response = {"taskInput": task_input}
    logger.debug("Returning result: %s", response)
    return response
18 |
--------------------------------------------------------------------------------
/cdk.json:
--------------------------------------------------------------------------------
1 | {
2 | "app": "python3 cdk_app.py",
3 | "context": {
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/cdk_app.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""AWS CDK app entry point for OCR pipeline sample
"""
# Python Built-Ins:
import json
import os

# External Dependencies:
from aws_cdk import App

# Local Dependencies:
from cdk_demo_stack import PipelineDemoStack
from pipeline.config_utils import bool_env_var, list_env_var


# Top-level configuration is read from environment variables at the time `cdk synth` or
# `cdk deploy` runs (or you can hard-code overrides here instead):
config = {
    # Used as a prefix for some cloud resources e.g. SSM parameters:
    "default_project_id": os.environ.get("DEFAULT_PROJECT_ID", default="ocr-transformers-demo"),

    # Set False to skip deploying the page thumbnail image generator, if you're only using models
    # (like LayoutLMv1) that don't take page image as input features:
    "use_thumbnails": bool_env_var("USE_THUMBNAILS", default=True),

    # Set True to enable auto-scale-to-zero on auto-deployed SageMaker endpoints (including the
    # thumbnail generator and any custom OCR engines). This saves costs for low-volume workloads,
    # but introduces a few minutes' extra cold start for requests when all instances are released:
    "enable_sagemaker_autoscaling": bool_env_var("ENABLE_SM_AUTOSCALING", default=False),

    # To use alternative Tesseract OCR instead of Amazon Textract, before running `cdk deploy` run:
    #   export BUILD_SM_OCRS=tesseract
    #   export DEPLOY_SM_OCRS=tesseract
    #   export USE_SM_OCR=tesseract
    # ...Or edit the defaults below to `["tesseract"]` and `"tesseract"`
    "build_sagemaker_ocrs": list_env_var("BUILD_SM_OCRS", default=[]),
    "deploy_sagemaker_ocrs": list_env_var("DEPLOY_SM_OCRS", default=[]),
    "use_sagemaker_ocr": os.environ.get("USE_SM_OCR", default=None),
}

app = App()
print(f"Deploying stack with configuration:\n{json.dumps(config, indent=2)}")
demo_stack = PipelineDemoStack(app, "OCRPipelineDemo", **config)
app.synth()
51 |
--------------------------------------------------------------------------------
/img/annotation-example-trimmed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/annotation-example-trimmed.png
--------------------------------------------------------------------------------
/img/architecture-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/architecture-overview.png
--------------------------------------------------------------------------------
/img/human-review-sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/human-review-sample.png
--------------------------------------------------------------------------------
/img/sfn-execution-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/img/sfn-execution-screenshot.png
--------------------------------------------------------------------------------
/notebooks/custom-containers/preproc/Dockerfile:
--------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#### Container image with document/image processing (and optionally OCR) tools added.

# Base image is injected at build time (e.g. a SageMaker framework image):
ARG BASE_IMAGE
FROM ${BASE_IMAGE}

# Common/base doc & image processing tools:
# (poppler provides the PDF rendering backend that pdf2image requires)
RUN conda install -c conda-forge poppler -y \
    && pip install amazon-textract-response-parser pdf2image "Pillow>=8,<9"

# Optional OCR engine: Tesseract+PyTesseract
# conda tesseract already includes Leptonica dependency and multi-language tessdata files by default
# (but didn't set the required TESSDATA_PREFIX variable at time of writing)
# Only installed when --build-arg INCLUDE_OCR_TESSERACT is set to a non-empty value:
ARG INCLUDE_OCR_TESSERACT
# NOTE(review): the `export TESSDATA_PREFIX=...` below only affects this single RUN layer's shell,
# not the final image's environment - consumers may still need to set TESSDATA_PREFIX themselves
# (e.g. via ENV). Confirm whether this is intentional.
RUN if test -z "$INCLUDE_OCR_TESSERACT" ; \
    then \
        echo Skipping OCR engine Tesseract \
    ; else \
        conda install -y -c conda-forge tesseract && \
        pip install pytesseract && \
        export TESSDATA_PREFIX='/opt/conda/share/tessdata' \
    ; fi
--------------------------------------------------------------------------------
/notebooks/custom-containers/train-inf/Dockerfile:
--------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# Container definition for Layout+language model training & inference on SageMaker

# Base image is injected at build time (e.g. a SageMaker HF/PyTorch DLC):
ARG BASE_IMAGE
FROM ${BASE_IMAGE}

# Core dependencies:
# - Pin PyTorch to prevent pip accidentally re-installing/upgrading it via detectron
#   (PT_VER captures whatever torch version the base image already ships)
# - Pin setuptools per https://github.com/pytorch/pytorch/issues/69894#issuecomment-1080635462
# - Pin protobuf < 3.21 due to an error like https://stackoverflow.com/q/72441758 as of 2023-02
#   (which seems to originate from somewhere in SM DDP package when unconstrained install results
#   in downloading protobuf@4.x)
RUN PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` \
    && pip install git+https://github.com/facebookresearch/detectron2.git setuptools==59.5.0 \
        "amazon-textract-response-parser>=0.1,<0.2" "datasets[vision]>=2.14,<3" "Pillow>=9.4" \
        "protobuf<3.21" torch==$PT_VER "torchvision>=0.15,<0.17" "transformers>=4.28,<4.29"

# Could also consider installing detectron2 via pre-built Linux wheel, depending on the PyTorch and
# CUDA versions of your base container:
# https://github.com/aws/deep-learning-containers/tree/master/huggingface/pytorch
# https://detectron2.readthedocs.io/en/latest/tutorials/install.html
#
# For example:
# && pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html


# Additional dependencies:
# - pytesseract shouldn't be necessary after Transformers v4.18 (because we don't use Tesseract
#   OCR), but older versions have a bug: https://github.com/huggingface/transformers/issues/16845
# - datasets 1.18 and torchvision 0.11 are installed in the HF training container but missing from
#   the inference container, and we need them for inference. Upgraded datasets to use some new
#   logging controls and debug multi-worker .map() pre-processing:
# (torch is re-pinned to the same PT_VER again so this layer cannot bump it either)
RUN PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` \
    && pip install pytesseract torch==$PT_VER


# If you'd like to enable this container as a Custom Image for notebook kernels, for debugging in
# SageMaker Studio, build it with INCLUDE_NOTEBOOK_KERNEL=1 arg to include IPython kernel and also
# some other PDF processing + OCR utilities:
ARG INCLUDE_NOTEBOOK_KERNEL
# NOTE(review): as in the preproc image, the `export TESSDATA_PREFIX=...` below only lasts for
# this RUN layer's shell and does not persist into the final image - confirm whether kernels need
# it set via ENV instead.
RUN if test -z "$INCLUDE_NOTEBOOK_KERNEL" ; \
    then \
        echo Skipping notebook kernel dependencies \
    ; else \
        conda install -y -c conda-forge poppler tesseract && \
        PT_VER=`pip show torch | grep 'Version:' | sed 's/Version: //'` && \
        pip install easyocr ipykernel "ipywidgets>=8.1,<9" pdf2image pytesseract sagemaker \
            torch==$PT_VER && \
        export TESSDATA_PREFIX='/opt/conda/share/tessdata' && \
        python -m ipykernel install --sys-prefix \
    ; fi

# We would like to disable SMDEBUG when running as a notebook kernel, because it can cause some
# unwanted side-effects... But at the time of writing Dockerfile doesn't have full support for a
# conditional env statement - so:
# if --build-arg INCLUDE_NOTEBOOK_KERNEL=1, set USE_SMDEBUG to 'false', else set null.
ENV USE_SMDEBUG=${INCLUDE_NOTEBOOK_KERNEL:+false}
# ...But '' will cause problems in SM Training, default empty value to 'true' instead (which should
# be the default per:
# https://github.com/awslabs/sagemaker-debugger/blob/56fabe531692403e77ce9b5879d55211adec238e/smdebug/core/config_validator.py#L21
ENV USE_SMDEBUG=${USE_SMDEBUG:-true}

# See below guidance for adding an image built with INCLUDE_NOTEBOOK_KERNEL to SMStudio:
# https://docs.aws.amazon.com/sagemaker/latest/dg/studio-byoi.html
# https://github.com/aws-samples/sagemaker-studio-custom-image-samples
#
# An image config something like the following should work:
# {
#   "KernelSpecs": [
#     {
#       "Name": "python3",
#       "DisplayName": "Textract Transformers"
#     },
#   ],
#   "FileSystemConfig": {
#     "MountPath": "/root/data",
#     "DefaultUid": 0,
#     "DefaultGid": 0
#   }
# }
--------------------------------------------------------------------------------
/notebooks/img/a2i-custom-template-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/a2i-custom-template-demo.png
--------------------------------------------------------------------------------
/notebooks/img/cfn-stack-outputs-a2i.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/cfn-stack-outputs-a2i.png
--------------------------------------------------------------------------------
/notebooks/img/sfn-execution-status-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-execution-status-screenshot.png
--------------------------------------------------------------------------------
/notebooks/img/sfn-statemachine-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-statemachine-screenshot.png
--------------------------------------------------------------------------------
/notebooks/img/sfn-statemachine-success.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/sfn-statemachine-success.png
--------------------------------------------------------------------------------
/notebooks/img/smgt-custom-template-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-custom-template-demo.png
--------------------------------------------------------------------------------
/notebooks/img/smgt-find-workforce-url.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-find-workforce-url.png
--------------------------------------------------------------------------------
/notebooks/img/smgt-private-workforce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-private-workforce.png
--------------------------------------------------------------------------------
/notebooks/img/smgt-task-pending.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/smgt-task-pending.png
--------------------------------------------------------------------------------
/notebooks/img/ssm-a2i-param-detail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/ssm-a2i-param-detail.png
--------------------------------------------------------------------------------
/notebooks/img/ssm-param-detail-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-textract-transformer-pipeline/06b39d69023e584da8daaf2b0ef44d31465b05e8/notebooks/img/ssm-param-detail-screenshot.png
--------------------------------------------------------------------------------
/notebooks/preproc/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """OCR and image pre-processing utilities for document understanding projects with SageMaker
4 |
5 | This top-level __init__.py is not necessary for SageMaker processing/training/etc jobs, but provided
6 | in case you want to `from preproc import textract_transformers` from the notebooks to experiment.
7 | """
8 |
--------------------------------------------------------------------------------
/notebooks/preproc/ocr.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Top-level entrypoint script for alternative OCR engines
4 | """
5 | # Python Built-Ins:
6 | import logging
7 | import os
8 | import sys
9 |
10 |
def run_main():
    """Set up stdout logging, then import the worker module and run the batch job

    Logging is configured *before* importing ``textract_transformers.ocr`` so that any log
    output emitted during that import is formatted and shown correctly in the job logs.
    """
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s"))
    logging.basicConfig(handlers=[handler], level=os.environ.get("LOG_LEVEL", logging.INFO))

    from textract_transformers.ocr import main

    return main()
22 |
23 |
if __name__ == "__main__":
    # Script mode = batch processing: run the batch routine (run_main configures logging
    # before any heavy imports, to make sure output shows up ok):
    run_main()
else:
    # If the file is imported as a module, we're in inference mode and should pass through the
    # override functions defined in the inference module.
    # NOTE(review): this star-import re-exports the SageMaker serving handlers defined in
    # textract_transformers/ocr.py (model_fn, input_fn, predict_fn).
    from textract_transformers.ocr import *
32 |
--------------------------------------------------------------------------------
/notebooks/preproc/preproc.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Top-level entrypoint script for image/document pre-processing
4 | """
5 | # Python Built-Ins:
6 | import logging
7 | import os
8 | import sys
9 |
10 |
def run_main():
    """Set up stdout logging, then import the worker module and run the batch job

    Logging is configured *before* importing ``textract_transformers.preproc`` so that any
    log output emitted during that import is formatted and shown correctly in the job logs.
    """
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s"))
    logging.basicConfig(handlers=[handler], level=os.environ.get("LOG_LEVEL", logging.INFO))

    from textract_transformers.preproc import main

    return main()
22 |
23 |
if __name__ == "__main__":
    # Script mode = batch processing: run the batch routine (run_main configures logging
    # before any heavy imports, to make sure output shows up ok):
    run_main()
else:
    # If the file is imported as a module, we're in inference mode and should pass through the
    # override functions defined in the inference module.
    # NOTE(review): presumably this re-exports SageMaker serving handlers from
    # textract_transformers/preproc.py, mirroring ocr.py — confirm against that module.
    from textract_transformers.preproc import *
32 |
--------------------------------------------------------------------------------
/notebooks/preproc/textract_transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """OCR and image pre-processing utilities for document understanding projects with SageMaker
4 | """
5 |
--------------------------------------------------------------------------------
/notebooks/preproc/textract_transformers/file_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Utilities for working with filesystem names and paths
4 | """
5 | # Python Built-Ins:
6 | import os
7 | from typing import List, Tuple
8 |
9 |
def split_filename(filename: str) -> Tuple[str, str]:
    """Split a filename into base name and extension (extension without the leading dot)

    This basic method does NOT currently account for 2-part extensions e.g. '.tar.gz'

    Parameters
    ----------
    filename :
        A file name (not expected to be a full path), e.g. 'report.pdf'.

    Returns
    -------
    basename :
        Name with the final extension removed, e.g. 'report'.
    ext :
        Final extension without the dot, e.g. 'pdf'. Empty string if there's no dot in the
        name.
    """
    basename, dot, ext = filename.rpartition(".")
    if not dot:
        # Bug fix: rpartition with no match returns ('', '', filename) — previously this
        # reported the whole name as the "extension". A dot-less name is all basename:
        return filename, ""
    return basename, ext
17 |
18 |
def ls_relpaths(path: str, exclude_hidden: bool = True, sort: bool = True) -> List[str]:
    """Recursively list files under a folder as relative paths

    Parameters
    ----------
    path :
        Folder to be walked (a single trailing '/' is tolerated).
    exclude_hidden :
        By default (True), skip dot-files and anything inside dot-folders.
    sort :
        By default (True), return paths in alphabetical order. If False, order is whatever
        os.walk() yields (no guaranteed ordering).

    Returns
    -------
    results :
        *Relative* file paths under the provided folder
    """
    if path.endswith("/"):
        path = path[:-1]
    prefix_len = len(path) + 1  # +1 to also strip the '/' after the root folder name
    rel_files = []
    for curr_dir, _, filenames in os.walk(path):
        for filename in filenames:
            rel_files.append(os.path.join(curr_dir, filename)[prefix_len:])
    if exclude_hidden:
        # Drop hidden dot-files and files under hidden dot-folders:
        rel_files = [f for f in rel_files if not f.startswith(".") and "/." not in f]
    return sorted(rel_files) if sort else rel_files
52 |
--------------------------------------------------------------------------------
/notebooks/preproc/textract_transformers/ocr.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Script to run open-source OCR engines in Amazon SageMaker
4 | """
5 |
6 | # Python Built-Ins:
7 | import argparse
8 | from base64 import b64decode
9 | import json
10 | from logging import getLogger
11 | from multiprocessing import cpu_count, Pool
12 | import os
13 | import time
14 | from typing import Iterable, Optional, Tuple
15 |
16 | # External Dependencies:
17 | import boto3
18 |
19 | # Local Dependencies
20 | from . import ocr_engines
21 | from .file_utils import ls_relpaths
22 | from .image_utils import Document
23 |
24 |
logger = getLogger("ocr")
s3client = boto3.client("s3")

# Environment variable configurations:
# When running in SM Endpoint, we can't use the usual processing job command line argument pattern
# to configure these extra parameters - so instead configure via environment variables for both.
# Engine name resolved via ocr_engines.get(), e.g. 'tesseract':
OCR_ENGINE = os.environ.get("OCR_ENGINE", "tesseract").lower()
# Comma-separated language codes, e.g. 'eng,fra' -> ['eng', 'fra']:
OCR_DEFAULT_LANGUAGES = os.environ.get("OCR_DEFAULT_LANGUAGES", "eng").lower().split(",")
# Passed to Document as default_doc_dpi — presumably the render resolution for paginated
# document formats (e.g. PDF); confirm against image_utils.Document:
OCR_DEFAULT_DPI = int(os.environ.get("OCR_DEFAULT_DPI", "300"))
34 |
35 |
def model_fn(model_dir: str):
    """OCR Engine loader: Load the configured engine into memory ready to use

    `model_dir` (the model artifact folder SageMaker provides) is not used here: the engine
    and its default languages are selected by the OCR_ENGINE and OCR_DEFAULT_LANGUAGES
    environment variables instead.
    """
    return ocr_engines.get(OCR_ENGINE, OCR_DEFAULT_LANGUAGES)
39 |
40 |
def input_fn(input_bytes: bytes, content_type: str) -> Tuple[Document, Optional[Iterable[str]]]:
    """Deserialize real-time processing requests

    For binary data requests (image or document bytes), default settings will be used e.g.
    `OCR_DEFAULT_LANGUAGES`.

    For JSON requests, supported fields are:

    ```
    {
        "Document": {
            "Bytes": Base64-encoded inline document/image, OR:
            "S3Object": {
                "Bucket": S3 bucket name for raw document/image
                "Name": S3 object key
                "VersionId": Optional S3 object version ID ("Version" is also accepted, for
                    consistency with the Amazon Textract API's S3Object shape)
            }
        },
        "Languages": Optional List[str] override for OCR_DEFAULT_LANGUAGES language codes
    }
    ```

    Returns
    -------
    doc :
        Loaded `Document` from which page images may be accessed
    languages :
        Optional override list of language codes for OCR, otherwise None.

    Raises
    ------
    ValueError
        If a JSON request is missing 'Document', or its 'S3Object' lacks 'Bucket'/'Name'.
    """
    logger.debug("Deserializing request of content_type %s", content_type)
    if content_type == "application/json":
        # Indirected request with metadata (e.g. language codes and S3 pointer):
        req = json.loads(input_bytes)
        doc_spec = req.get("Document", {})
        if "Bytes" in doc_spec:
            doc_bytes = b64decode(doc_spec["Bytes"])
        elif "S3Object" in doc_spec:
            s3_spec = doc_spec["S3Object"]
            if not ("Bucket" in s3_spec and "Name" in s3_spec):
                raise ValueError(
                    "Document.S3Object must be an object with keys 'Bucket' and 'Name'. Got: %s"
                    % s3_spec
                )
            logger.info("Fetching s3://%s/%s ...", s3_spec["Bucket"], s3_spec["Name"])
            # Fix: the documented field name is 'VersionId', but this previously read only
            # 'Version'. Accept both, preferring the documented 'VersionId':
            version_id = s3_spec.get("VersionId", s3_spec.get("Version"))
            resp = s3client.get_object(
                Bucket=s3_spec["Bucket"],
                Key=s3_spec["Name"],
                **({} if version_id is None else {"VersionId": version_id}),
            )
            doc_bytes = resp["Body"].read()
            # Use the fetched S3 object's content type to interpret the document bytes:
            content_type = resp["ContentType"]
        else:
            raise ValueError(
                "JSON requests must include 'Document' object containing either 'Bytes' or "
                "'S3Object'. Got %s" % req
            )

        languages = req.get("Languages")
    else:
        # Direct image/document request:
        doc_bytes = input_bytes
        languages = None

    return (
        Document(doc_bytes, ext_or_media_type=content_type, default_doc_dpi=OCR_DEFAULT_DPI),
        languages,
    )
109 |
110 |
def predict_fn(
    inputs: Tuple[Document, Optional[Iterable[str]]], engine: ocr_engines.BaseOCREngine
) -> dict:
    """Run OCR on one deserialized request

    Parameters
    ----------
    inputs :
        (document, languages) pair as produced by `input_fn`; `languages` may be None to use
        the engine's defaults.
    engine :
        Loaded OCR engine, as produced by `model_fn`.

    Returns
    -------
    result :
        JSON-serializable OCR result dictionary, of format roughly compatible with Amazon
        Textract DetectDocumentText result payload.
    """
    doc, languages = inputs
    return engine.process(doc, languages=languages)
123 |
124 |
125 | # No output_fn required as we will always use JSON which the default serializer supports
126 | # def output_fn(prediction_output: Dict, accept: str) -> bytes:
127 |
128 |
def parse_args() -> argparse.Namespace:
    """Build the batch OCR job CLI and parse sys.argv into job parameters"""
    arg_parser = argparse.ArgumentParser(
        description="OCR documents in batch using an alternative (non-Amazon-Textract) engine"
    )
    # Register arguments data-driven, in display order:
    for flag, arg_conf in (
        (
            "--input",
            dict(
                type=str,
                default="/opt/ml/processing/input/raw",
                help="Folder where raw input images/documents are stored",
            ),
        ),
        (
            "--output",
            dict(
                type=str,
                default="/opt/ml/processing/output/ocr",
                help="Folder where Amazon Textract-compatible OCR results should be saved",
            ),
        ),
        (
            "--n-workers",
            dict(
                type=int,
                default=cpu_count(),
                help="Number of worker processes to use for extraction (default #CPU cores)",
            ),
        ),
    ):
        arg_parser.add_argument(flag, **arg_conf)
    return arg_parser.parse_args()
154 |
155 |
def process_doc_in_worker(inputs: dict) -> None:
    """Batch job worker function to extract a document (used in a multiprocessing pool)

    File paths are mapped similar to this sample's Amazon Textract pipeline. For example:
    `{in_folder}/some/fld/filename.pdf` to `{out_folder}/some/fld/filename.pdf/consolidated.json`

    Parameters
    ----------
    inputs :
        Dictionary containing fields:
        - in_folder (str): Mandatory path to input documents folder
        - rel_filepath (str): Mandatory path relative to `in_folder`, to input document
        - out_folder (str): Mandatory path to OCR results output folder
        - wait (float): Optional number of seconds to wait before starting processing, to ensure
          system resources are not *fully* exhausted when running as many threads as CPU cores.
          (Which could cause health check problems) - Default 0.5.

    Raises
    ------
    Exception
        Whatever the OCR engine raised, re-raised after logging the failing document path.
    """
    time.sleep(inputs.get("wait", 0.5))
    in_path = os.path.join(inputs["in_folder"], inputs["rel_filepath"])
    doc = Document(
        in_path,
        # ext_or_media_type to be inferred from file path
        default_doc_dpi=OCR_DEFAULT_DPI,
        base_file_path=inputs["in_folder"],
    )
    engine = ocr_engines.get(OCR_ENGINE, OCR_DEFAULT_LANGUAGES)
    try:
        result = engine.process(doc, OCR_DEFAULT_LANGUAGES)
    except Exception:
        # logger.exception records the traceback; bare `raise` preserves the original stack
        # (previously `raise e`, which loses no information but is non-idiomatic):
        logger.exception("Failed to process document %s", in_path)
        raise
    out_path = os.path.join(inputs["out_folder"], inputs["rel_filepath"], "consolidated.json")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as f:
        # Stream-serialize instead of building the full JSON string in memory first:
        json.dump(result, f, indent=2)
    logger.info("Processed doc %s", in_path)
192 |
193 |
def main() -> None:
    """Batch entrypoint: parse CLI + env vars, then OCR all input docs via a process pool"""
    args = parse_args()
    logger.info("Parsed job args: %s", args)

    logger.info("Reading raw files from %s", args.input)
    rel_filepaths_all = ls_relpaths(args.input)
    n_docs = len(rel_filepaths_all)

    logger.info("Processing %s files across %s processes", n_docs, args.n_workers)
    worker_inputs = [
        {"in_folder": args.input, "out_folder": args.output, "rel_filepath": path}
        for path in rel_filepaths_all
    ]
    with Pool(args.n_workers) as pool:
        # imap_unordered yields as docs finish, so we can log incremental progress:
        n_done = 0
        for _ in pool.imap_unordered(process_doc_in_worker, worker_inputs):
            n_done += 1
            logger.info("Processed doc %s of %s", n_done, n_docs)
    logger.info("All done!")
220 |
--------------------------------------------------------------------------------
/notebooks/preproc/textract_transformers/ocr_engines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Custom/open-source OCR engine integrations
4 |
5 | Use .get() defined in this __init__.py file, to dynamically load your custom engine(s).
6 |
7 | To be discoverable by get(), your module (script or folder):
8 |
9 | - Should be placed in this folder, with a name beginning 'eng_'.
10 | - Should expose a class inheriting from `base.BaseOCREngine`, preferably with a name
11 | ending with 'Engine', and should not expose multiple such classes.
12 | """
13 | # Python Built-Ins:
14 | from importlib import import_module
15 | import inspect
16 | from logging import getLogger
17 | import os
18 | from types import ModuleType
19 | from typing import Dict, Iterable, Type
20 |
21 | # Local Dependencies:
22 | from .base import BaseOCREngine
23 |
24 |
25 | logger = getLogger("ocr_engines")
26 |
27 |
# Auto-discover all eng_*** modules in this folder as [EngineName->ModuleName]:
# (Runs once at import time; the engine modules themselves are only imported later, by get(),
# so missing optional dependencies don't break this package's import.)
ENGINES: Dict[str, str] = {}
for item in os.listdir(os.path.dirname(__file__)):
    if not item.startswith("eng_"):
        continue
    if item.endswith(".py"):
        # (Assuming everything starting with eng_ is a folder or a .py file), strip ext if present:
        # NOTE(review): stray artifacts like 'eng_x.pyc' would register a bogus engine — confirm
        # none are present in deployed images.
        item = item[: -len(".py")]
    name = item[len("eng_") :]  # ID/Name of the engine strips leading 'eng_'
    ENGINES[name] = "." + item  # Relative importable module name
38 |
39 |
def _find_ocr_engine_class(module: ModuleType) -> Type[BaseOCREngine]:
    """Locate the OCR engine class exposed by an imported engine module

    Preference order for a unique candidate: (1) a BaseOCREngine subclass whose name ends with
    'Engine'; (2) any BaseOCREngine subclass; (3) any name ending 'Engine'; (4) any class.

    Raises
    ------
    ImportError
        If no preference level yields exactly one candidate.
    """
    attrs = module.__dict__
    class_names = [n for n in dir(module) if inspect.isclass(attrs[n])]
    names_ending_engine = [n for n in dir(module) if n.endswith("Engine")]
    engine_child_classes = [
        n
        for n in class_names
        if issubclass(attrs[n], BaseOCREngine) and attrs[n] is not BaseOCREngine
    ]
    preferred_names = [n for n in engine_child_classes if n in names_ending_engine]

    # Walk the preference levels from strongest to weakest, taking the first unique match:
    for candidates in (preferred_names, engine_child_classes, names_ending_engine, class_names):
        if len(candidates) == 1:
            return attrs[candidates[0]]
    raise ImportError(
        "Failed to find unique BaseOCREngine child class from OCR engine module '%s'. Classes "
        "inheriting from BaseOCREngine: %s. Class names defined by module: %s"
        % (module.__name__, engine_child_classes, class_names)
    )
67 |
68 |
def get(engine_name: str, default_languages: Iterable[str]) -> BaseOCREngine:
    """Initialize a supported custom OCR engine by name

    Engines are dynamically imported, so that ImportErrors aren't raised for missing
    dependencies unless there's an actual attempt to create/use the engine.

    Parameters
    ----------
    engine_name :
        Name of a supported custom OCR engine to fetch (a key of `ENGINES`).
    default_languages :
        Language codes to configure the engine to detect by default.

    Raises
    ------
    ValueError
        If `engine_name` is not a discovered engine.
    """
    # Guard clause: unknown engine names fail fast with the available options listed:
    if engine_name not in ENGINES:
        raise ValueError(
            "Couldn't find engine '%s' in ocr_engines module. Not in set: %s"
            % (engine_name, ENGINES)
        )
    module_name = ENGINES[engine_name]
    logger.info("Loading OCR engine '%s' from module '%s'", engine_name, module_name)
    engine_module = import_module(module_name, package=__name__)
    engine_cls = _find_ocr_engine_class(engine_module)
    logger.info("Loading engine class: %s", engine_cls)
    return engine_cls(default_languages)
96 |
--------------------------------------------------------------------------------
/notebooks/preproc/textract_transformers/ocr_engines/eng_tesseract.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Example integration for (Py)Tesseract as a custom OCR engine
4 | """
5 | # Python Built-Ins:
6 | from logging import getLogger
7 | import os
8 | from statistics import mean
9 | from tempfile import TemporaryDirectory
10 | from typing import Iterable, List, Optional
11 |
12 | # External Dependencies:
13 | import pandas as pd
14 | import pytesseract
15 |
16 | # Local Dependencies:
17 | from .base import BaseOCREngine, generate_response_json, OCRGeometry, OCRLine, OCRPage, OCRWord
18 | from ..image_utils import Document
19 |
20 |
21 | logger = getLogger("eng_tesseract")
22 |
23 |
# Tesseract locates its language data via TESSDATA_PREFIX. Default to the conda install
# location (as used by this sample's container images — confirm for other environments)
# unless the caller already configured it:
if os.environ.get("TESSDATA_PREFIX") is None:
    os.environ["TESSDATA_PREFIX"] = "/opt/conda/share/tessdata"
26 |
27 |
class TesseractEngine(BaseOCREngine):
    """Tesseract-based engine for custom SageMaker OCR endpoint option"""

    engine_name = "tesseract"

    def process(self, raw_doc: Document, languages: Optional[Iterable[str]] = None) -> dict:
        """OCR a document page-by-page with (Py)Tesseract

        Parameters
        ----------
        raw_doc :
            Document whose rendered page images will be fed to Tesseract one at a time.
        languages :
            Optional override of language codes for this request; `self.default_languages` is
            used when None.

        Returns
        -------
        result :
            JSON-serializable dict (via `generate_response_json`), roughly compatible with the
            Amazon Textract response format.
        """
        ocr_pages = []

        with TemporaryDirectory() as tmpdir:
            raw_doc.set_workspace(tmpdir)
            for ixpage, page in enumerate(raw_doc.get_pages()):
                logger.debug(f"Serializing page {ixpage + 1}")
                page_ocr = pytesseract.image_to_data(
                    page.file_path,
                    output_type=pytesseract.Output.DATAFRAME,
                    # Tesseract takes multi-language specs joined by '+', e.g. 'eng+fra':
                    lang="+".join(self.default_languages if languages is None else languages),
                    pandas_config={
                        # Need this explicit override or else pages containing only a single number
                        # can sometimes have text column interpreted as numeric type:
                        "dtype": {"text": str},
                    },
                )
                ocr_pages += self.dataframe_to_ocrpages(page_ocr)
        return generate_response_json(ocr_pages, self.engine_name)

    @classmethod
    def dataframe_to_ocrpages(cls, ocr_df: pd.DataFrame) -> List[OCRPage]:
        """Convert a Tesseract DataFrame to a list of OCRPage ready for Textract-like serialization

        Tesseract TSVs / PyTesseract DataFrames group detections by multiple levels: Page, block,
        paragraph, line, word. Columns are: level, page_num, block_num, par_num, line_num, word_num,
        left, top, width, height, conf, text.

        Each level is introduced by a record, so for example there will be an initial record with
        (level=1, page_num=1, block_num=0, par_num=0, line_num=0, word_num=0)... And then several
        others before finally getting down to the first WORD record (level=5, page_num=1,
        block_num=1, par_num=1, line_num=1, word_num=1). Records are assumed to be sorted in order,
        as indeed they are direct from Tesseract.

        Raises
        ------
        ValueError
            If any page_num has more than one level=1 (page) record in the DataFrame.
        """
        # First construct an indexable list of page geometries, as we'll need these to normalize
        # other entity coordinates from absolute pixel values to 0-1 range:
        # (Note: In fact this function will often be called with only one page_num at a time)
        page_dims = (
            ocr_df[ocr_df["level"] == 1]
            .groupby("page_num")
            .agg(
                {
                    "left": "min",
                    "top": "min",
                    "width": "max",
                    "height": "max",
                    "page_num": "count",
                }
            )
        )
        # There should be exactly one level=1 record per page in the dataframe. After checking
        # this, we can dispose the "page_num" count column.
        dup_pages_mask = page_dims["page_num"] > 1
        if dup_pages_mask.sum() > 0:
            raise ValueError(
                # Bug fix: report only page_nums that are actually duplicated (count > 1).
                # Previously this indexed on count > 0, which listed *every* page:
                "Tesseract DataFrame had duplicate entries for these page_nums at level 1: %s"
                % page_dims.index[dup_pages_mask].values[:20]
            )
        page_dims.drop(columns="page_num", inplace=True)

        # We need to collapse the {block, paragraph} levels of Tesseract hierarchy to preserve only
        # PAGE, LINE and WORD for consistency with Textract. Here we'll assume the DataFrame is in
        # its original Tesseract sort order, allowing iteration through the records to correctly
        # roll the entities up. Although iterating through large DataFrames isn't generally a
        # performant practice, this could always be balanced with specific parallelism if wanted:
        # E.g. processing multiple pages at once.
        pages = {
            num: OCRPage([])  # Initialise all pages first with no text
            for num in sorted(ocr_df[ocr_df["level"] == 1]["page_num"].unique())
        }
        cur_page_num = None
        page_lines = []
        cur_line_id = None
        line_words = []

        # Tesseract LINE records (level 4) don't have a confidence (equals -1), so we'll use the
        # average over the included WORDs as a heuristic. They *do* have T/L/H/W geometry info, but
        # we'll ignore that for the sake of code simplicity and let OCRLine infer it from the union
        # of all WORD bounding boxes.
        def add_line(words: List[OCRWord]) -> None:
            # (Closes over `page_lines`, which is re-bound as each new page starts)
            page_lines.append(OCRLine(mean(w.confidence for w in words), words))

        # Loop through all WORD records, ignoring whitespace-only ones that Tesseract likes to yield
        words_df = ocr_df[ocr_df["level"] == 5].copy()
        words_df["text"] = words_df["text"].str.strip()
        words_df = words_df[words_df["text"].str.len() > 0]
        for _, rec in words_df.iterrows():
            # Bug fix: include page_num in the line identity. block/par/line numbers restart on
            # each page, so without it a line could silently merge across a page boundary when
            # the (block, par, line) triple happened to repeat:
            line_id = (rec.page_num, rec.block_num, rec.par_num, rec.line_num)
            page_num = rec.page_num
            if cur_line_id != line_id:
                # Start of new LINE - add previous one to result:
                if cur_line_id is not None:
                    add_line(line_words)
                cur_line_id = line_id
                line_words = []
            if cur_page_num != page_num:
                # Start of new PAGE - add previous one to result:
                # (Line flush above runs first, so the previous page's last line is already in
                # page_lines before it's attached to that page here.)
                if cur_page_num is not None:
                    pages[cur_page_num].add_lines(page_lines)
                cur_page_num = page_num
                page_lines = []
            # Parse this record into a WORD:
            page_dim_rec = page_dims.loc[page_num]
            line_words.append(
                OCRWord(
                    rec.text,
                    rec.conf,
                    OCRGeometry.from_bbox(
                        # Word geometries, too, need normalizing by page dimensions.
                        (rec.top - page_dim_rec.top) / page_dim_rec.height,
                        (rec.left - page_dim_rec.left) / page_dim_rec.width,
                        rec.height / page_dim_rec.height,
                        rec.width / page_dim_rec.width,
                    ),
                )
            )
        # End of last line and last page: Add any remaining content.
        if len(line_words):
            add_line(line_words)
        if len(page_lines):
            pages[cur_page_num].add_lines(page_lines)
        return list(pages.values())
155 |
--------------------------------------------------------------------------------
/notebooks/review/.eslintrc.cjs:
--------------------------------------------------------------------------------
/* eslint-env node */

// ESLint configuration for the human-review web UI (Vue 3 + TypeScript + Prettier).
module.exports = {
  root: true, // Don't search parent folders for additional ESLint configs
  extends: [
    "plugin:vue/vue3-essential",
    "eslint:recommended",
    "@vue/eslint-config-typescript/recommended",
    "@vue/eslint-config-prettier",
  ],
  env: {
    // Treat <script setup> compiler macros (defineProps, defineEmits, ...) as known globals:
    "vue/setup-compiler-macros": true,
  },
  rules: {
    // Vue recommends multi-word component names; exempt the custom-element entry component:
    "vue/multi-word-component-names": [
      "error",
      {
        ignores: ["Viewer.ce"],
      },
    ],
  },
};
23 |
--------------------------------------------------------------------------------
/notebooks/review/.gitignore:
--------------------------------------------------------------------------------
1 | # Local testing files
2 | public/*.pdf
3 | NOTES.md
4 | **.tmp
5 | **.tmp.*
6 |
7 | # Logs
8 | logs
9 | *.log
10 | npm-debug.log*
11 | yarn-debug.log*
12 | yarn-error.log*
13 | pnpm-debug.log*
14 | lerna-debug.log*
15 |
16 | node_modules
17 | .DS_Store
18 | dist
19 | dist-ssr
20 | coverage
21 | *.local
22 |
23 | # Editor directories and files
24 | .vscode/*
25 | !.vscode/extensions.json
26 | .idea
27 | *.suo
28 | *.ntvs*
29 | *.njsproj
30 | *.sln
31 | *.sw?
32 |
--------------------------------------------------------------------------------
/notebooks/review/.prettierrc.yml:
--------------------------------------------------------------------------------
1 | printWidth: 100
2 |
--------------------------------------------------------------------------------
/notebooks/review/env.d.ts:
--------------------------------------------------------------------------------
1 | ///