├── test
│   ├── integ
│   │   ├── __init__.py
│   │   ├── test_access-control.py
│   │   ├── test_redaction.py
│   │   └── integ_base.py
│   ├── load
│   │   ├── __init__.py
│   │   ├── access-control_load_test.py
│   │   ├── redaction_load_test.py
│   │   └── load_test_base.py
│   ├── data
│   │   ├── integ
│   │   │   ├── clean.txt
│   │   │   ├── RandomImage.png
│   │   │   ├── pii_input.txt
│   │   │   ├── pii_output.txt
│   │   │   └── pii_bank_routing_redacted.txt
│   │   └── sample_event.json
│   └── unit
│       ├── conftest.py
│       ├── test_data_object.py
│       ├── test_exceptions.py
│       ├── test_util.py
│       ├── test_cloudwatch_client.py
│       ├── test_validators.py
│       ├── test_exception_handlers.py
│       ├── test_comprehend_client.py
│       ├── test_s3_client.py
│       └── test_processors.py
├── .pydocstyle
├── .flake8
├── images
│   ├── architecture_redaction.gif
│   └── architecture_access_control.gif
├── src
│   ├── lambdainit.py
│   ├── lambdalogging.py
│   ├── util.py
│   ├── config.py
│   ├── exceptions.py
│   ├── data_object.py
│   ├── validators.py
│   ├── clients
│   │   ├── cloudwatch_client.py
│   │   ├── comprehend_client.py
│   │   └── s3_client.py
│   ├── exception_handlers.py
│   ├── processors.py
│   ├── handler.py
│   └── constants.py
├── Pipfile
├── README.md
├── LICENSE
├── Makefile
├── .gitignore
├── access-control-template.yml
├── redaction-template.yml
├── ACCESS_CONTROL_README.md
└── REDACTION_README.md
/test/integ/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/load/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/data/integ/clean.txt:
--------------------------------------------------------------------------------
1 | This document contains no PII.
--------------------------------------------------------------------------------
/.pydocstyle:
--------------------------------------------------------------------------------
1 | [pydocstyle]
2 | ignore = D107,D203,D212,D205
3 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 140
3 | ignore = E126,W504
4 |
5 |
--------------------------------------------------------------------------------
/images/architecture_redaction.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/images/architecture_redaction.gif
--------------------------------------------------------------------------------
/test/data/integ/RandomImage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/test/data/integ/RandomImage.png
--------------------------------------------------------------------------------
/images/architecture_access_control.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/images/architecture_access_control.gif
--------------------------------------------------------------------------------
/test/unit/conftest.py:
--------------------------------------------------------------------------------
1 | """Setup unit test environment."""
2 |
3 | import sys
4 | import os
5 |
6 | # make sure tests can import the app code
7 | my_path = os.path.dirname(os.path.abspath(__file__))
8 | sys.path.insert(0, my_path + '/../../src/')
9 |
10 |
--------------------------------------------------------------------------------
/src/lambdainit.py:
--------------------------------------------------------------------------------
1 | """
2 | Special initializations for Lambda functions.
3 |
4 | This file must be imported as the first import in any file containing a Lambda function handler method.
5 | """
6 |
7 | import sys
8 |
9 | # add packaged dependencies to search path
10 | sys.path.append('lib')
11 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.python.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 |
6 | [packages]
7 | boto3 = "*"
8 | requests = "*"
9 |
10 | [dev-packages]
11 | flake8 = "*"
12 | autopep8 = "*"
13 | pydocstyle = "*"
14 | cfn-lint = "*"
15 | awscli = "*"
16 | pytest = "*"
17 | pytest-mock = "*"
18 | pytest-cov = "*"
19 |
20 | [requires]
21 | python_version = "3.8"
22 |
--------------------------------------------------------------------------------
/test/unit/test_data_object.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from data_object import PiiConfig
4 | from exceptions import InvalidConfigurationException
5 |
6 |
7 | class DataObjectTest(TestCase):
8 |
9 |     def test_pii_config_invalid_confidence_threshold(self):
10 | with self.assertRaises(InvalidConfigurationException) as e:
11 | PiiConfig(confidence_threshold=0.1)
12 | assert e.exception.message == 'CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]'
--------------------------------------------------------------------------------
/test/data/integ/pii_input.txt:
--------------------------------------------------------------------------------
1 | Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account XXXXXX1111 with the routing number XXXXX0000.
2 |
3 | Your latest statement was mailed to 100 Main Street, Anytown, WA 98121.
4 |
5 | After your payment is received, you will receive a confirmation text message at 206-555-0100.
6 |
7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at 206-555-0199 or email at support@anycompany.com.
--------------------------------------------------------------------------------
/test/data/integ/pii_output.txt:
--------------------------------------------------------------------------------
1 | Hello *********. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53 that is due by *********. Based on your autopay settings, we will withdraw your payment on the due date from your bank account ********** with the routing number *********.
2 |
3 | Your latest statement was mailed to **********************************.
4 |
5 | After your payment is received, you will receive a confirmation text message at ************.
6 |
7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at ************ or email at **********************.
--------------------------------------------------------------------------------
/test/data/integ/pii_bank_routing_redacted.txt:
--------------------------------------------------------------------------------
1 | Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account ********** with the routing number *********.
2 |
3 | Your latest statement was mailed to 100 Main Street, Anytown, WA 98121.
4 |
5 | After your payment is received, you will receive a confirmation text message at 206-555-0100.
6 |
7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at 206-555-0199 or email at **********************.
--------------------------------------------------------------------------------
/test/unit/test_exceptions.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from constants import CONTENT_LENGTH, RANGE
4 | from exceptions import UnsupportedFileException
5 |
6 |
7 | class ExceptionsTest(TestCase):
8 | def test_unsupported_file_format_exception(self):
9 | try:
10 | http_headers = {CONTENT_LENGTH: 1234, RANGE: "0-123"}
11 | raise UnsupportedFileException("Some Random blob".encode('utf-8'), http_headers, "Unsupported file")
12 | except UnsupportedFileException as exception:
13 | assert exception.file_content.decode('utf-8') == "Some Random blob"
14 | assert str(exception) == "Unsupported file"
15 | assert exception.http_headers == http_headers
16 |
--------------------------------------------------------------------------------
/src/lambdalogging.py:
--------------------------------------------------------------------------------
1 | """
2 | Lambda logging helper.
3 |
4 | Returns a Logger with log level set based on env variables.
5 | """
6 |
7 | import logging
8 |
9 | from config import LOG_LEVEL
10 |
11 | # translate log level from string to numeric value
12 | LOG_LEVEL = getattr(logging, LOG_LEVEL) if hasattr(logging, LOG_LEVEL) else logging.DEBUG
13 |
14 | # setup logging levels for botocore
15 | logging.getLogger('botocore.endpoint').setLevel(LOG_LEVEL)
16 | logging.getLogger('botocore.retryhandler').setLevel(LOG_LEVEL)
17 | logging.getLogger('botocore.parsers').setLevel(LOG_LEVEL)
18 |
19 |
20 | def getLogger(name):
21 | """Return a logger configured based on env variables."""
22 | logger = logging.getLogger(name)
23 |     # in the Lambda environment, logging config has already been set up, so we can't use logging.basicConfig to change the log level
24 | logger.setLevel(LOG_LEVEL)
25 | return logger
26 |
--------------------------------------------------------------------------------
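
The pattern for consuming this helper, as the modules under src/ do, is a single module-level logger; a minimal sketch (the log message is illustrative):

import lambdalogging

LOG = lambdalogging.getLogger(__name__)
LOG.info('handler initialized')  # emitted only when LOG_LEVEL is INFO or lower
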
/src/util.py:
--------------------------------------------------------------------------------
1 | """Utility Class."""
2 | from concurrent.futures import ThreadPoolExecutor
3 | from concurrent.futures import TimeoutError
4 |
5 | import lambdalogging
6 | from exceptions import TimeoutException
7 |
8 | LOG = lambdalogging.getLogger(__name__)
9 |
10 |
11 | def execute_task_with_timeout(timeout_in_millis, task):
12 | """
13 | Execute a given task within a given time limit.
14 | :param timeout_in_millis: milliseconds to timeout
15 | :param task: task to execute
16 | :raise: TimeoutException
17 | """
18 | timeout_in_sec = int(timeout_in_millis / 1000)
19 | try:
20 |         executor = ThreadPoolExecutor(max_workers=1)
21 | future_result = executor.submit(task)
22 | return future_result.result(timeout=timeout_in_sec)
23 | except TimeoutError:
24 | # Free up the resources
25 | executor.shutdown(wait=False)
26 | raise TimeoutException()
27 |
--------------------------------------------------------------------------------
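
A usage sketch for execute_task_with_timeout (the task below is illustrative):

from time import sleep

from exceptions import TimeoutException
from util import execute_task_with_timeout


def slow_task():
    sleep(5)  # stands in for a long-running call


try:
    execute_task_with_timeout(1000, slow_task)  # budget of 1000 ms
except TimeoutException:
    print('task exceeded its time budget')

Note that shutdown(wait=False) only stops the executor from accepting new work; Python threads cannot be killed, so a timed-out task keeps running in the background until it finishes on its own.
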
/README.md:
--------------------------------------------------------------------------------
1 | # Comprehend S3 Object Lambda functions
2 |
3 | This package contains the source code for Lambda functions published by AWS Comprehend that can be used with [S3 Object Lambda Access Points](https://docs.aws.amazon.com/AmazonS3/latest/userguide/transforming-objects.html).
4 | These Lambda functions can be deployed from the AWS Serverless Application Repository.
5 |
6 | Refer to the README file for each Lambda for more details on its functionality.
7 | * [PII Access Control Lambda](ACCESS_CONTROL_README.md) - This Lambda function helps you control access to files in S3 that contain PII (Personally Identifiable Information).
8 | * [PII Redaction Lambda](REDACTION_README.md) - This Lambda function helps you redact PII (Personally Identifiable Information) from text files in S3.
9 |
10 |
11 | ## License Summary
12 |
13 | This code is made available under the MIT-0 license. See the LICENSE file.
14 |
--------------------------------------------------------------------------------
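
For orientation, a transformed object is retrieved with an ordinary GetObject call addressed at the S3 Object Lambda Access Point, exactly as the load tests in this package do; a minimal sketch (the ARN and key are placeholders):

import boto3

s3 = boto3.client('s3')
# pass the Object Lambda access point ARN where a bucket name normally goes;
# the configured Lambda transforms the object before it is returned
response = s3.get_object(
    Bucket='arn:aws:s3-object-lambda:us-east-1:111222333444:accesspoint/my-banner-ap',
    Key='pii_input.txt')
print(response['Body'].read().decode('utf-8'))
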
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
4 | software and associated documentation files (the "Software"), to deal in the Software
5 | without restriction, including without limitation the rights to use, copy, modify,
6 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
7 | permit persons to whom the Software is furnished to do so.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
10 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
11 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
12 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
13 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
14 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/test/unit/test_util.py:
--------------------------------------------------------------------------------
1 | import time
2 | from time import sleep
3 | from unittest import TestCase
4 |
5 | from exceptions import TimeoutException, FileSizeLimitExceededException
6 | from util import execute_task_with_timeout
7 |
8 |
9 | class UtilTest(TestCase):
10 |
11 |     def test_execute_task_with_timeout_with_time_limit_exceeded(self):
12 |         start_time = time.time()
13 |
14 |         def sleep_5():
15 |             sleep(5)
16 |
17 |         with self.assertRaises(TimeoutException):
18 |             execute_task_with_timeout(1000, sleep_5)
19 |         elapsed_time = time.time() - start_time
20 |         assert 1000 <= elapsed_time * 1000 <= 1100
21 |
22 | def test_execute_task_with_timeout_time_limit_not_exceeded(self):
23 | start_time = time.time()
24 |
25 | def sleep_1():
26 | sleep(1)
27 |
28 | execute_task_with_timeout(2000, sleep_1)
29 | elapsed_time = time.time() - start_time
30 | assert 1000 <= elapsed_time * 1000 <= 1100
31 |
32 | def test_execute_task_with_timeout_when_task_fails(self):
33 | def task():
34 | raise FileSizeLimitExceededException()
35 |
36 | with self.assertRaises(FileSizeLimitExceededException) as e:
37 | execute_task_with_timeout(2000, task)
38 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | """Contain the configurations used in the package."""
2 | import os
3 |
4 | from constants import UNSUPPORTED_FILE_HANDLING_VALID_VALUES, MASK_MODE_VALID_VALUES
5 |
6 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES = int(os.getenv('DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES', 50 * 1000)) # 50 KB
7 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES = int(os.getenv('DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES', 5 * 1000)) # 5KB
8 | PII_ENTITY_TYPES = str(os.getenv('PII_ENTITY_TYPES', 'ALL'))
9 | IS_PARTIAL_OBJECT_SUPPORTED = str(os.getenv('IS_PARTIAL_OBJECT_SUPPORTED', 'false')).lower() == 'true'
10 | MASK_CHARACTER = str(os.getenv('MASK_CHARACTER', '*'))
11 | MASK_MODE = MASK_MODE_VALID_VALUES[os.getenv('MASK_MODE', MASK_MODE_VALID_VALUES.MASK.name)]
12 | SUBSEGMENT_OVERLAPPING_TOKENS = int(os.getenv('SUBSEGMENT_OVERLAPPING_TOKENS', 20))
13 | DOCUMENT_MAX_SIZE = int(os.getenv('DOCUMENT_MAX_SIZE', 100 * 1024))
14 | CONFIDENCE_THRESHOLD = float(os.getenv('CONFIDENCE_THRESHOLD', 0.5))
15 | assert 0.5 <= CONFIDENCE_THRESHOLD <= 1.0, "CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]"
16 | MAX_CHARS_OVERLAP = int(os.getenv('MAX_CHARS_OVERLAP', 200))
17 | DEFAULT_LANGUAGE_CODE = str(os.getenv('DEFAULT_LANGUAGE_CODE', 'en'))
18 | REDACTION_API_ONLY = os.getenv('REDACTION_API_ONLY', 'false').lower() == 'true'
19 |
20 | UNSUPPORTED_FILE_HANDLING = UNSUPPORTED_FILE_HANDLING_VALID_VALUES[
21 | os.getenv('UNSUPPORTED_FILE_HANDLING', UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL.name)]
22 |
23 | DETECT_PII_ENTITIES_THREAD_COUNT = int(os.getenv('DETECT_PII_ENTITIES_THREAD_COUNT', 8))
24 | CONTAINS_PII_ENTITIES_THREAD_COUNT = int(os.getenv('CONTAINS_PII_ENTITIES_THREAD_COUNT', 20))
25 | PUBLISH_CLOUD_WATCH_METRICS = os.getenv('PUBLISH_CLOUD_WATCH_METRICS', 'true').lower() == 'true'
26 | COMPREHEND_ENDPOINT_URL = None if os.getenv('COMPREHEND_ENDPOINT_URL', '') == '' else os.getenv('COMPREHEND_ENDPOINT_URL')
27 |
28 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
29 |
--------------------------------------------------------------------------------
/test/load/access-control_load_test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import uuid
4 |
5 | from botocore.exceptions import ClientError
6 |
7 | from load.load_test_base import BaseLoadTest
8 |
9 |
10 | class PiiAccessControlLoadTest(BaseLoadTest):
11 | BUILD_DIR = '.aws-sam/build/PiiAccessControlFunction/'
12 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME']
13 |
14 | @classmethod
15 | def setUpClass(cls) -> None:
16 | super().setUpClass()
17 | test_run_id = str(uuid.uuid4())[0:8]
18 | cls.lambda_function_arn = cls._create_function("pii_access_control_load_test", 'handler.pii_access_control_handler',
19 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR)
20 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn)
21 | cls._update_lambda_env_variables(cls.lambda_function_arn, {"LOG_LEVEL": "DEBUG",
22 | "DOCUMENT_MAX_SIZE": "1500000"})
23 | logging.info(f"Created access point: {cls.s3ol_access_point_arn} for testing")
24 |
25 | @classmethod
26 | def tearDownClass(cls) -> None:
27 | cls.s3_ctrl.delete_access_point_for_object_lambda(
28 | AccountId=cls.account_id,
29 | Name=cls.s3ol_access_point_arn.split('/')[-1]
30 | )
31 |
32 | cls.lambda_client.delete_function(
33 | FunctionName=cls.lambda_function_arn.split(':')[-1]
34 | )
35 | super().tearDownClass()
36 |
37 | def tearDown(self) -> None:
38 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"})
39 |
40 | def test_access_control_lambda_with_varying_load(self):
41 | document_sizes = [1, 5, 50, 1000, 1500]
42 | for size in document_sizes:
43 |             self.find_max_tpm(self.s3ol_access_point_arn, size, True, ClientError)
44 | time.sleep(90)
45 |
--------------------------------------------------------------------------------
/src/exceptions.py:
--------------------------------------------------------------------------------
1 | """Custom defined exceptions."""
2 |
3 |
4 | class CustomException(Exception):
5 | """Exceptions which are generated because of non-compliance of business constraints of this lambda."""
6 |
7 |
8 | class UnsupportedFileException(CustomException):
9 | """Exception generated when we encounter an unsupported file format. For e.g. an image file."""
10 |
11 | def __init__(self, file_content, http_headers, *args, **kwargs):
12 | super().__init__(*args)
13 | self.file_content = file_content
14 | self.http_headers = http_headers
15 |
16 |
17 | class FileSizeLimitExceededException(CustomException):
18 | """
19 | Exception representing file size beyond the supported limits.
20 |     Files beyond this size are likely to take too long to process, causing timeouts.
21 | """
22 |
23 | pass
24 |
25 |
26 | class TimeoutException(CustomException):
27 | """Exception raised when some task is not able to complete within a certain time limit."""
28 |
29 | pass
30 |
31 |
32 | class InvalidConfigurationException(CustomException):
33 | """Exception representing an incorrect configuration of the access point such as incorrect function payload structure."""
34 |
35 | def __init__(self, message, *args, **kwargs):
36 | super().__init__(*args)
37 | self.message = message
38 |
39 |
40 | class InvalidRequestException(CustomException):
41 | """Exception representing an invalid request."""
42 |
43 | def __init__(self, message, *args, **kwargs):
44 | super().__init__(*args)
45 | self.message = message
46 |
47 |
48 | class S3DownloadException(CustomException):
49 | """Exception representing an error occurring during downloading from the presigned url."""
50 |
51 | def __init__(self, s3_error_code, s3_message, *args, **kwargs):
52 | super().__init__(*args)
53 | self.s3_error_code = s3_error_code
54 | self.s3_message = s3_message
55 |
56 |
57 | class RestrictedDocumentException(CustomException):
58 | """Exception representing a restricted document throw when it contains pii."""
59 |
--------------------------------------------------------------------------------
/test/data/sample_event.json:
--------------------------------------------------------------------------------
1 | {
2 | "xAmzRequestId": "FEDCBA0987654321",
3 | "getObjectContext": {
4 | "inputS3Url": "https://pii-document-for-banner.s3.amazonaws.com/SomeText?AWSAccessKeyId=ASIAZI7HWWFYYJMPESP7&Signature=x%2B7jhKr7N2e%2FdNAitb%2F3RrtDJ2o%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEFYaCXVzLWVhc3QtMSJGMEQCIACRXcrjeRvpIhWgaOOmBAv8FDCjdLjIQtlmG4ZvrvLAiA0h%2B%2Bkxk0tHuHNWl0cXMJ3oIpK3PNRWNk9IcstgK8ppyqnAgiv%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAEaDDYzNzc1MTcwMTg3MyIMoxQ1S3IApjAJO6UTKvsBESuzIXwqKlu64hEGDz%2Bmw5oSoPxAvjR1cmbLCf2SYIdIdEBQpuMqHaXFGjHFsW4%2BcppiYXZlxcqOG8jZn79zd%2FvhHzTzB2vGjD2ooK7bPk9D3hx4voJzsPpNdsIPIba4DwYZlNbBBpCIhHcrGdKuhZ6UlI4wbGMA%2FE646mLQXK5H6pjlrZmKbaMlS9%2BdPQagSkqYioLN%2FWmya4rMag9h06IOyhP%2BLrQqldLLblod1otzfjf1Klc4KWJPm2StzTR85MI2zN2DHQ0WezZn6CeT80rM4te2Vyc0nW%2BJaA1msH2A0Yg57JEJx0ZzTifvfCXDVjds4bXYKVNUCBMwy%2FqB%2FQU6ngF0xA7fDEkIsHRFluXzDsESrlomBiSOD%2FayvNvgWHVXA5YW4g0hODrLE%2BZtXr5RhdMu7BSe6HQvrfm4pLfEDZczu9A9%2Fwc9V27vGkaJQHeNFO%2FBALDFHn4mkGuS2OVlh9uOF78XqrSKkmRB4B4RIE7dp54ODaoGaTp9RVKnTnr8K3KF5i8s%2Fl6T5ukl%2FU29ja6EZ6607ea00YIdln81gA%3D%3D&Expires=1604364707",
5 | "outputRoute": "io-a1c1d6c7",
6 | "outputToken": "ReallyLongEncryptedToken"
7 | },
8 | "configuration": {
9 | "accessPointArn": "arn:aws:s3-object-lambda:us-east-1:111222333444:accesspoint/my-banner-ap",
10 | "supportingAccessPointArn": "arn:aws:s3:us-east-1:123456789012:accesspoint/existing-ap",
11 | "payload": "{\"pii_entity_types\" : [\"ALL\",\"CREDIT_DEBIT_NUMBER\"],\"mask_mode\":\"MASK\", \"mask_character\" : \"*\",\"confidence_threshold\":0.6,\"language_code\":\"en\"}"
12 | },
13 | "userRequest": {
14 | "url": "https://my-banner-ap-111222333444.s3-banner.us-east-1.amazonaws.com/foo?QueryParam=2",
15 | "headers": {
16 | "Content-type": "application/txt",
17 | "CustomHeader": "ClientSpecifiedValue"
18 | }
19 | },
20 | "userIdentity": {
21 | "accountId": "111222333444",
22 | "principalId": "AIDAJ45Q7YFFAREXAMPLE",
23 | "arn": "arn:aws:iam::111222333444:user/Alice",
24 | "groups": "Finance,Users"
25 | },
26 | "protocolVersion": "1.00"
27 | }
28 |
--------------------------------------------------------------------------------
/src/data_object.py:
--------------------------------------------------------------------------------
1 | """Module containing some custom data structures ."""
2 |
3 | import os
4 | from typing import Dict, List
5 |
6 | from exceptions import InvalidConfigurationException
7 |
8 |
9 | class PiiConfig:
10 | """PiiConfig class represents the base config for classification and redaction."""
11 |
12 | def __init__(self, pii_entity_types: List = None,
13 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5),
14 | **kwargs):
15 | self.pii_entity_types = pii_entity_types
16 | self.confidence_threshold = float(confidence_threshold)
17 | if not 0.5 <= self.confidence_threshold <= 1.0:
18 | raise InvalidConfigurationException('CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]')
19 | if self.pii_entity_types is None:
20 | self.pii_entity_types = os.getenv('PII_ENTITY_TYPES', 'ALL').split(',')
21 |
22 |
23 | class ClassificationConfig(PiiConfig):
24 | """ClassificationConfig class represents the config to be used for classification."""
25 |
26 | def __init__(self, pii_entity_types: List = None,
27 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5),
28 | **kwargs):
29 | super().__init__(pii_entity_types, confidence_threshold, **kwargs)
30 |
31 |
32 | class RedactionConfig(ClassificationConfig):
33 | """RedactionConfig class represents the config to be used for redaction."""
34 |
35 | def __init__(self, pii_entity_types: List = None, mask_mode: str = os.getenv('MASK_MODE', 'MASK'),
36 | mask_character: str = os.getenv('MASK_CHARACTER', '*'),
37 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5),
38 | **kwargs):
39 | super().__init__(pii_entity_types, confidence_threshold, **kwargs)
40 | self.mask_character = mask_character
41 | self.mask_mode = mask_mode
42 |
43 |
44 | class Document:
45 | """A chunk of text."""
46 |
47 |     def __init__(self, text: str, char_offset: int = 0, pii_classification: Dict = None,
48 |                  pii_entities: List = None, redacted_text: str = ''):
49 |         self.text = text
50 |         self.char_offset = char_offset
51 |         self.pii_classification = {} if pii_classification is None else pii_classification  # avoid shared mutable defaults
52 |         self.pii_entities = [] if pii_entities is None else pii_entities
53 |         self.redacted_text = redacted_text
54 |
--------------------------------------------------------------------------------
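
A sketch of how these config classes can be populated from an access point's JSON payload (see test/data/sample_event.json); whether handler.py builds them exactly this way is not shown here, but the **kwargs parameters exist precisely so that extra keys such as language_code are absorbed:

import json

from data_object import RedactionConfig

payload = json.loads('{"pii_entity_types": ["ALL"], "mask_mode": "MASK", '
                     '"mask_character": "*", "confidence_threshold": 0.6, '
                     '"language_code": "en"}')
config = RedactionConfig(**payload)  # language_code lands in **kwargs
assert config.mask_mode == 'MASK' and config.confidence_threshold == 0.6
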
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/sh
2 | PY_VERSION := 3.8
3 |
4 | export PYTHONUNBUFFERED := 1
5 |
6 | SRC_DIR := src
7 | SAM_DIR := .aws-sam
8 |
9 | # Required environment variables (user must override)
10 |
11 | # S3 bucket used for packaging SAM templates
12 | PACKAGE_BUCKET ?= comprehend-s3-object-lambdas
13 |
14 | # user can optionally override the following by setting environment variables with the same names before running make
15 |
16 | # Path to system pip
17 | PIP ?= pip
18 | # Default AWS CLI region
19 | AWS_DEFAULT_REGION ?= us-west-2
20 |
21 | PYTHON := $(shell /usr/bin/which python$(PY_VERSION))
22 |
23 | .DEFAULT_GOAL := build
24 |
25 | clean:
26 | rm -f $(SRC_DIR)/requirements.txt
27 | rm -rf $(SAM_DIR)
28 |
29 | # used once just after project creation to lock and install dependencies
30 | bootstrap:
31 | $(PYTHON) -m $(PIP) install pipenv
32 | pipenv lock
33 | pipenv sync --dev
34 |
35 | # used by CI build to install dependencies
36 | init:
37 | $(PYTHON) -m $(PIP) install aws-sam-cli
38 | $(PYTHON) -m $(PIP) install pipenv
39 | pipenv sync --dev
40 | pipenv lock --requirements > $(SRC_DIR)/requirements.txt
41 |
42 | build:
43 | pipenv run flake8 $(SRC_DIR)
44 | pipenv run pydocstyle $(SRC_DIR)
45 | pipenv run cfn-lint $(LAMBDA_NAME)-template.yml
46 | sam build --profile sar-account --template $(LAMBDA_NAME)-template.yml
47 | mv $(SAM_DIR)/build/template.yaml $(SAM_DIR)/build/$(LAMBDA_NAME)-template.yml
48 |
49 | unit-testing: build
50 | pipenv run py.test --cov=$(SRC_DIR) --cov-fail-under=97 -vv test/unit -s --cov-report html
51 |
52 | # can be triggered as `make integ-testing LAMBDA_NAME=access-control`
53 | integ-testing: unit-testing
54 | pipenv run py.test -s -vv test/integ/test_$(LAMBDA_NAME).py
55 |
56 | load-testing:
57 | pipenv run py.test -s -vv test/load/$(LAMBDA_NAME)_load_test.py --log-cli-level=INFO
58 |
59 | package:
60 | sam package --region us-east-1 --profile sar-account --template $(SAM_DIR)/build/$(LAMBDA_NAME)-template.yml --s3-bucket $(PACKAGE_BUCKET) --output-template-file $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml
61 |
62 | deploy: package
63 | sam deploy --profile sar-account --region us-east-1 --template-file $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml --capabilities CAPABILITY_IAM --stack-name $(LAMBDA_NAME)-lambda
64 |
65 | publish: package
66 | sam publish --region us-east-1 --template $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml --profile sar-account
67 |
--------------------------------------------------------------------------------
/test/unit/test_cloudwatch_client.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | from unittest.mock import patch, MagicMock
3 |
4 | from clients.cloudwatch_client import CloudWatchClient
5 |
6 | S3OL_ACCESS_POINT_TEST = "arn:aws:s3-object-lambda:us-east-1:000000000000:accesspoint/myPiiAp"
7 |
8 |
9 | class CloudWatchClientTest(TestCase):
10 | @patch('clients.cloudwatch_client.boto3')
11 | def test_cloudwatch_client_put_pii_document_processed_metric(self, mocked_boto3):
12 | mocked_client = MagicMock()
13 | mocked_boto3.client.return_value = mocked_client
14 |
15 | cloudwatch = CloudWatchClient()
16 | cloudwatch.put_pii_document_processed_metric('en', S3OL_ACCESS_POINT_TEST)
17 |
18 | mocked_client.put_metric_data.assert_called_once()
19 |
20 | @patch('clients.cloudwatch_client.boto3')
21 | def test_cloudwatch_client_put_document_processed_metric(self, mocked_boto3):
22 | mocked_client = MagicMock()
23 | mocked_boto3.client.return_value = mocked_client
24 |
25 | cloudwatch = CloudWatchClient()
26 | cloudwatch.put_document_processed_metric('en', S3OL_ACCESS_POINT_TEST)
27 |
28 | mocked_client.put_metric_data.assert_called_once()
29 |
30 | @patch('clients.cloudwatch_client.boto3')
31 | def test_cloudwatch_client_put_pii_document_types_metric(self, mocked_boto3):
32 | mocked_client = MagicMock()
33 | mocked_boto3.client.return_value = mocked_client
34 |
35 | cloudwatch = CloudWatchClient()
36 | cloudwatch.put_pii_document_types_metric(['SSN', 'PHONE'], 'en', S3OL_ACCESS_POINT_TEST)
37 |
38 | mocked_client.put_metric_data.assert_called_once()
39 |
40 | def test_segment_metrics_segmentation_required(self):
41 | cloudwatch = CloudWatchClient()
42 | metric_list = [i for i in range(0, 100)]
43 | chunks = cloudwatch.segment_metric_data(metric_list)
44 | total_metrics = 0
45 | for chunk in chunks:
46 | assert len(chunk) <= cloudwatch.MAX_METRIC_DATA
47 | total_metrics += len(chunk)
48 | assert total_metrics == 100
49 |
50 | def test_segment_metrics_segmentation_not_required(self):
51 | cloudwatch = CloudWatchClient()
52 | metric_list = [i for i in range(0, 10)]
53 | chunks = cloudwatch.segment_metric_data(metric_list)
54 | total_metrics = 0
55 | for chunk in chunks:
56 | assert len(chunk) <= cloudwatch.MAX_METRIC_DATA
57 | total_metrics += len(chunk)
58 | assert total_metrics == 10
59 |
--------------------------------------------------------------------------------
/src/validators.py:
--------------------------------------------------------------------------------
1 | """Represents the different validations we would use."""
2 | import json
3 |
4 | from config import IS_PARTIAL_OBJECT_SUPPORTED
5 | from constants import REQUEST_ID, REQUEST_TOKEN, REQUEST_ROUTE, GET_OBJECT_CONTEXT, S3OL_CONFIGURATION, INPUT_S3_URL, PAYLOAD, \
6 | PART_NUMBER, RANGE, USER_REQUEST, HEADERS
7 | from exceptions import InvalidConfigurationException, InvalidRequestException
8 |
9 |
10 | class Validator:
11 | """Generic validator class representing one container which does a particular type of validation."""
12 |
13 | @staticmethod
14 | def validate(object):
15 | """Execute the validation."""
16 | raise NotImplementedError
17 |
18 |
19 | class JsonValidator(Validator):
20 | """Simply validates that given string can be converted to Json object or not."""
21 |
22 | @staticmethod
23 | def validate(json_string: str):
24 | """Simply validates that given string can be converted to Json object or not."""
25 | try:
26 |
27 | json.loads(json_string)
28 | except ValueError:
29 | raise Exception("Invalid Json %s", json_string)
30 |
31 |
32 | class PartialObjectRequestValidator(Validator):
33 | """Validates that the GetObject request is not for a partial object."""
34 |
35 | @staticmethod
36 |     def validate(input_event: dict):
37 | """Perform the validation."""
38 | RESTRICTED_HEADERS = [RANGE, PART_NUMBER]
39 | if not IS_PARTIAL_OBJECT_SUPPORTED:
40 | if HEADERS in input_event[USER_REQUEST]:
41 | for header in input_event[USER_REQUEST][HEADERS]:
42 | if header in RESTRICTED_HEADERS:
43 | raise InvalidRequestException(f"HTTP Header {header} is not supported")
44 |
45 |
46 | class InputEventValidator(Validator):
47 | """Validate the main lambda input."""
48 |
49 | @staticmethod
50 | def validate(event):
51 | """Validate the main lambda input."""
52 |         # validations on parts of the event provided by S3 Object Lambda
53 | assert S3OL_CONFIGURATION in event
54 | assert GET_OBJECT_CONTEXT in event
55 | assert REQUEST_TOKEN in event[GET_OBJECT_CONTEXT]
56 | assert REQUEST_ROUTE in event[GET_OBJECT_CONTEXT]
57 | assert REQUEST_ID in event
58 | assert INPUT_S3_URL in event[GET_OBJECT_CONTEXT]
59 | assert PAYLOAD in event[S3OL_CONFIGURATION]
60 |
61 | # parts of the event derived from access point configuration
62 | try:
63 | if event[S3OL_CONFIGURATION][PAYLOAD]:
64 | JsonValidator.validate(event[S3OL_CONFIGURATION][PAYLOAD])
65 | except Exception:
66 | raise InvalidConfigurationException(f"Invalid function payload: {event[S3OL_CONFIGURATION][PAYLOAD]}")
67 |
--------------------------------------------------------------------------------
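
An illustrative run of the validators against the bundled sample event (path relative to the repo root; src/ must be on the import path, as test/unit/conftest.py arranges):

import json

from validators import InputEventValidator, PartialObjectRequestValidator

with open('test/data/sample_event.json') as f:
    event = json.load(f)

InputEventValidator.validate(event)            # asserts the required fields are present
PartialObjectRequestValidator.validate(event)  # rejects Range/partNumber headers
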
/test/load/redaction_load_test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import uuid
4 |
5 | from load.load_test_base import BaseLoadTest
6 |
7 |
8 | class PiiRedactionLoadTest(BaseLoadTest):
9 | BUILD_DIR = '.aws-sam/build/PiiRedactionFunction/'
10 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME']
11 |
12 | @classmethod
13 | def setUpClass(cls) -> None:
14 | super().setUpClass()
15 | test_run_id = str(uuid.uuid4())[0:8]
16 | cls.lambda_function_arn = cls._create_function("pii_redaction_load_test", 'handler.redact_pii_documents_handler',
17 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR)
18 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn)
19 | cls._update_lambda_env_variables(cls.lambda_function_arn, {"LOG_LEVEL": "DEBUG",
20 | "DOCUMENT_MAX_SIZE": "1600000"})
21 | logging.info(f"Created access point: {cls.s3ol_access_point_arn} for testing")
22 |
23 | @classmethod
24 | def tearDownClass(cls) -> None:
25 | cls.s3_ctrl.delete_access_point_for_object_lambda(
26 | AccountId=cls.account_id,
27 | Name=cls.s3ol_access_point_arn.split('/')[-1]
28 | )
29 |
30 | cls.lambda_client.delete_function(
31 | FunctionName=cls.lambda_function_arn.split(':')[-1]
32 | )
33 | super().tearDownClass()
34 |
35 | def tearDown(self) -> None:
36 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"})
37 |
38 | def test_redaction_lambda_with_varying_load(self):
39 | variations = [(1, True),
40 | (5, False),
41 | (5, True),
42 | (50, False),
43 | (50, True),
44 | (1000, False),
45 | (1000, True),
46 | (1500, False),
47 | (1500, True)
48 | ]
49 | for text_size, is_pii in variations:
50 | self.execute_load_on_get_object(self.s3ol_access_point_arn, text_size, is_pii, None)
51 | time.sleep(60)
52 |
53 | def test_redaction_lambda_find_max_conn(self):
54 | variations = [
55 | (1, True),
56 | (5, False),
57 | (5, True),
58 | (50, False),
59 | (50, True),
60 | (1000, False),
61 | (1000, True),
62 | (1500, False),
63 | (1500, True)
64 | ]
65 | for text_size, is_pii in variations:
66 | self.find_max_tpm(self.s3ol_access_point_arn, text_size, is_pii, None)
67 |
--------------------------------------------------------------------------------
/src/clients/cloudwatch_client.py:
--------------------------------------------------------------------------------
1 | """Client wrapper over aws services."""
2 |
3 | from typing import List
4 |
5 | import boto3
6 |
7 | import lambdalogging
8 | from constants import CLOUD_WATCH_NAMESPACE, LANGUAGE, COUNT, PII_DOCUMENTS_PROCESSED, DOCUMENTS_PROCESSED, NAME, \
9 | VALUE, S3OL_ACCESS_POINT, METRIC_NAME, UNIT, DIMENSIONS, PII_DOCUMENT_TYPES_PROCESSED, PII_ENTITY_TYPE, \
10 | LATENCY, API, SERVICE, ERROR_COUNT, MILLISECONDS
11 |
12 | LOG = lambdalogging.getLogger(__name__)
13 |
14 |
15 | class Metrics:
16 | """Metrics class for latency and fault counts."""
17 |
18 | def __init__(self, service_name, api, s3ol_access_point, cloudwatch_namespace=CLOUD_WATCH_NAMESPACE):
19 | self.cloudwatch_namespace = cloudwatch_namespace
20 | self.service_name = service_name
21 | self.s3ol_access_point_arn = s3ol_access_point
22 | self.api = api
23 | self.metrics = []
24 |
25 | def add_latency(self, start_time: float, end_time: float):
26 | """Add a latency metric."""
27 | self.metrics.append({METRIC_NAME: LATENCY, DIMENSIONS: [
28 | {NAME: API, VALUE: self.api},
29 | {NAME: S3OL_ACCESS_POINT, VALUE: self.s3ol_access_point_arn},
30 | {NAME: SERVICE, VALUE: self.service_name}
31 | ], UNIT: MILLISECONDS, VALUE: (end_time - start_time) * 1000})
32 |
33 | def add_fault_count(self, count: int = 1):
34 | """Add a fault count metric."""
35 | self.metrics.append({METRIC_NAME: ERROR_COUNT, DIMENSIONS: [
36 | {NAME: API, VALUE: self.api},
37 | {NAME: S3OL_ACCESS_POINT, VALUE: self.s3ol_access_point_arn},
38 | {NAME: SERVICE, VALUE: self.service_name}
39 | ], UNIT: COUNT, VALUE: count})
40 |
41 |
42 | class CloudWatchClient:
43 | """Wrapper over cloudwatch client."""
44 |
45 | MAX_METRIC_DATA = 15
46 |
47 | def __init__(self):
48 | self.cloudwatch = boto3.client('cloudwatch')
49 |
50 |     def segment_metric_data(self, metric_list: List):
51 |         """Segment a list of arbitrary length into a list of lists, each of size at most MAX_METRIC_DATA."""
52 |         list_len = len(metric_list)
53 |         if list_len <= self.MAX_METRIC_DATA:
54 |             return [metric_list]
55 |         # step through the list in MAX_METRIC_DATA-sized slices; slicing never overruns
56 |         # the list, so lengths that are an exact multiple of MAX_METRIC_DATA
57 |         # do not produce a duplicate or empty trailing chunk
58 |         return [metric_list[x:x + self.MAX_METRIC_DATA] for x in range(0, list_len, self.MAX_METRIC_DATA)]
59 |
60 | def publish_metrics(self, metric_list: List):
61 | """Publish the metrics to CloudWatch."""
62 | for metrics in self.segment_metric_data(metric_list):
63 | self.cloudwatch.put_metric_data(MetricData=metrics, Namespace=CLOUD_WATCH_NAMESPACE)
64 |
65 | def put_pii_document_processed_metric(self, language: str, s3ol_access_point: str):
66 | """Put PiiDocumentsProcessed metric."""
67 | self.publish_metrics([{METRIC_NAME: PII_DOCUMENTS_PROCESSED, DIMENSIONS: [
68 | {NAME: LANGUAGE, VALUE: language},
69 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point}
70 | ], UNIT: COUNT, VALUE: 1.0}])
71 |
72 | def put_document_processed_metric(self, language: str, s3ol_access_point: str):
73 | """Put DocumentsProcessed metric."""
74 | self.publish_metrics([{METRIC_NAME: DOCUMENTS_PROCESSED, DIMENSIONS: [
75 | {NAME: LANGUAGE, VALUE: language},
76 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point}
77 | ], UNIT: COUNT, VALUE: 1.0}])
78 |
79 | def put_pii_document_types_metric(self, pii_entity_types: List[str], language: str, s3ol_access_point: str):
80 | """Put PiiDocumentTypesProcessed metric."""
81 | self.publish_metrics([{METRIC_NAME: PII_DOCUMENT_TYPES_PROCESSED, DIMENSIONS: [
82 | {NAME: PII_ENTITY_TYPE, VALUE: pii_entity_type},
83 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point},
84 | {NAME: LANGUAGE, VALUE: language}
85 | ], UNIT: COUNT, VALUE: 1.0} for pii_entity_type in pii_entity_types])
86 |
--------------------------------------------------------------------------------
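
An illustrative wiring of Metrics and CloudWatchClient (the real call sites live in the processor and handler modules, which are not shown above; the access point ARN is a placeholder):

import time

from clients.cloudwatch_client import CloudWatchClient, Metrics

metrics = Metrics('Comprehend', 'ContainsPiiEntities',
                  'arn:aws:s3-object-lambda:us-east-1:000000000000:accesspoint/myPiiAp')
start = time.time()
# ... call the downstream API here ...
metrics.add_latency(start, time.time())

CloudWatchClient().publish_metrics(metrics.metrics)  # chunked into MAX_METRIC_DATA-sized batches
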
/test/unit/test_validators.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from copy import deepcopy
4 | from unittest import TestCase
5 | from unittest.mock import patch
6 |
7 | from config import IS_PARTIAL_OBJECT_SUPPORTED
8 | from constants import S3OL_CONFIGURATION, GET_OBJECT_CONTEXT, REQUEST_TOKEN, REQUEST_ROUTE, REQUEST_ID, INPUT_S3_URL, PAYLOAD, \
9 | USER_REQUEST, HEADERS, RANGE
10 | from exceptions import InvalidConfigurationException, InvalidRequestException
11 | from validators import JsonValidator, Validator, InputEventValidator, PartialObjectRequestValidator
12 |
13 | this_module_path = os.path.dirname(__file__)
14 |
15 |
16 | class TestValidator(TestCase):
17 | def setUp(self) -> None:
18 | with open(os.path.join(this_module_path, "..", 'data', 'sample_event.json'), 'r') as file_pointer:
19 | self.sample_event = json.load(file_pointer)
20 | self.is_partial_object_supported = IS_PARTIAL_OBJECT_SUPPORTED
21 |
22 | def test_json_validator(self):
23 | with self.assertRaises(Exception) as context:
24 | JsonValidator.validate("{\"invalid\":json_string")
25 |
26 | def test_validator_interface(self):
27 | with self.assertRaises(NotImplementedError) as context:
28 | Validator.validate(None)
29 |
30 | def test_partial_object_validator_partial_object_requested(self):
31 | with self.assertRaises(InvalidRequestException) as context:
32 | self.sample_event[USER_REQUEST][HEADERS][RANGE] = "0-100"
33 | PartialObjectRequestValidator.validate(self.sample_event)
34 | assert context.exception.message == "HTTP Header Range is not supported"
35 |
37 | def test_partial_object_validator_complete_object_requested(self):
38 | temp = deepcopy(self.sample_event)
39 | PartialObjectRequestValidator.validate(temp)
40 |
41 | @patch('validators.IS_PARTIAL_OBJECT_SUPPORTED', True)
42 | def test_partial_object_validator_when_partial_object_is_supported(self):
43 | self.sample_event[USER_REQUEST][HEADERS][RANGE] = "0-100"
44 | PartialObjectRequestValidator.validate(self.sample_event)
45 |
46 | def test_input_event_validation_empty_input(self):
47 | with self.assertRaises(AssertionError) as context:
48 | InputEventValidator.validate({})
49 |
50 | def test_input_event_validation_missing_s3ol_config(self):
51 | with self.assertRaises(AssertionError) as context:
52 | invalid_event = deepcopy(self.sample_event)
53 | del invalid_event[S3OL_CONFIGURATION]
54 | InputEventValidator.validate(invalid_event)
55 |
56 | def test_input_event_validation_missing_object_context(self):
57 | with self.assertRaises(AssertionError) as context:
58 | invalid_event = deepcopy(self.sample_event)
59 | del invalid_event[GET_OBJECT_CONTEXT]
60 | InputEventValidator.validate(invalid_event)
61 |
62 | def test_input_event_validation_missing_request_token(self):
63 | with self.assertRaises(AssertionError) as context:
64 | invalid_event = deepcopy(self.sample_event)
65 | del invalid_event[GET_OBJECT_CONTEXT][REQUEST_TOKEN]
66 | InputEventValidator.validate(invalid_event)
67 |
68 | def test_input_event_validation_missing_request_route(self):
69 | with self.assertRaises(AssertionError) as context:
70 | invalid_event = deepcopy(self.sample_event)
71 | del invalid_event[GET_OBJECT_CONTEXT][REQUEST_ROUTE]
72 | InputEventValidator.validate(invalid_event)
73 |
74 | def test_input_event_validation_missing_request_id(self):
75 | with self.assertRaises(AssertionError) as context:
76 | invalid_event = deepcopy(self.sample_event)
77 | del invalid_event[REQUEST_ID]
78 | InputEventValidator.validate(invalid_event)
79 |
80 | def test_input_event_validation_missing_input_s3_url(self):
81 | with self.assertRaises(AssertionError) as context:
82 | invalid_event = deepcopy(self.sample_event)
83 | del invalid_event[GET_OBJECT_CONTEXT][INPUT_S3_URL]
84 | InputEventValidator.validate(invalid_event)
85 |
86 | def test_input_event_validation_missing_payload(self):
87 | with self.assertRaises(AssertionError) as context:
88 | invalid_event = deepcopy(self.sample_event)
89 | del invalid_event[S3OL_CONFIGURATION][PAYLOAD]
90 | InputEventValidator.validate(invalid_event)
91 |
92 | def test_input_event_validation_invalid_payload(self):
93 | with self.assertRaises(InvalidConfigurationException) as context:
94 | invalid_event = deepcopy(self.sample_event)
95 | invalid_event[S3OL_CONFIGURATION][PAYLOAD] = "Invalid json"
96 | InputEventValidator.validate(invalid_event)
97 |
98 | def test_input_event_validation_empty_payload(self):
99 | invalid_event = deepcopy(self.sample_event)
100 | invalid_event[S3OL_CONFIGURATION][PAYLOAD] = ""
101 | InputEventValidator.validate(invalid_event)
102 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### OSX ###
20 | *.DS_Store
21 | .AppleDouble
22 | .LSOverride
23 |
24 | # Icon must end with two \r
25 | Icon
26 |
27 | # Thumbnails
28 | ._*
29 |
30 | # Files that might appear in the root of a volume
31 | .DocumentRevisions-V100
32 | .fseventsd
33 | .Spotlight-V100
34 | .TemporaryItems
35 | .Trashes
36 | .VolumeIcon.icns
37 | .com.apple.timemachine.donotpresent
38 |
39 | # Directories potentially created on remote AFP share
40 | .AppleDB
41 | .AppleDesktop
42 | Network Trash Folder
43 | Temporary Items
44 | .apdisk
45 |
46 | ### PyCharm ###
47 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
48 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
49 |
50 | # User-specific stuff:
51 | .idea/**/workspace.xml
52 | .idea/**/tasks.xml
53 | .idea/dictionaries
54 |
55 | # Sensitive or high-churn files:
56 | .idea/**/dataSources/
57 | .idea/**/dataSources.ids
58 | .idea/**/dataSources.xml
59 | .idea/**/dataSources.local.xml
60 | .idea/**/sqlDataSources.xml
61 | .idea/**/dynamic.xml
62 | .idea/**/uiDesigner.xml
63 |
64 | # Gradle:
65 | .idea/**/gradle.xml
66 | .idea/**/libraries
67 |
68 | # CMake
69 | cmake-build-debug/
70 |
71 | # Mongo Explorer plugin:
72 | .idea/**/mongoSettings.xml
73 |
74 | ## File-based project format:
75 | *.iws
76 |
77 | ## Plugin-specific files:
78 |
79 | # IntelliJ
80 | /out/
81 |
82 | # mpeltonen/sbt-idea plugin
83 | .idea_modules/
84 |
85 | # JIRA plugin
86 | atlassian-ide-plugin.xml
87 |
88 | # Cursive Clojure plugin
89 | .idea/replstate.xml
90 |
91 | # Ruby plugin and RubyMine
92 | /.rakeTasks
93 |
94 | # Crashlytics plugin (for Android Studio and IntelliJ)
95 | com_crashlytics_export_strings.xml
96 | crashlytics.properties
97 | crashlytics-build.properties
98 | fabric.properties
99 |
100 | ### PyCharm Patch ###
101 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
102 |
103 | # *.iml
104 | # modules.xml
105 | # .idea/misc.xml
106 | # *.ipr
107 |
108 | # Sonarlint plugin
109 | .idea/sonarlint
110 |
111 | ### Python ###
112 | # Byte-compiled / optimized / DLL files
113 | __pycache__/
114 | *.py[cod]
115 | *$py.class
116 |
117 | # C extensions
118 | *.so
119 |
120 | # Distribution / packaging
121 | .Python
122 | build/
123 | develop-eggs/
124 | dist/
125 | downloads/
126 | eggs/
127 | .eggs/
128 | lib/
129 | lib64/
130 | parts/
131 | sdist/
132 | var/
133 | wheels/
134 | *.egg-info/
135 | .installed.cfg
136 | *.egg
137 | /*.zip
138 |
139 | # PyInstaller
140 | # Usually these files are written by a python script from a template
141 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
142 | *.manifest
143 | *.spec
144 |
145 | # Installer logs
146 | pip-log.txt
147 | pip-delete-this-directory.txt
148 |
149 | # Unit test / coverage reports
150 | htmlcov/
151 | .tox/
152 | .coverage
153 | .coverage.*
154 | .cache
155 | .pytest_cache/
156 | nosetests.xml
157 | coverage.xml
158 | *.cover
159 | .hypothesis/
160 |
161 | # Translations
162 | *.mo
163 | *.pot
164 |
165 | # Flask stuff:
166 | instance/
167 | .webassets-cache
168 |
169 | # Scrapy stuff:
170 | .scrapy
171 |
172 | # Sphinx documentation
173 | docs/_build/
174 |
175 | # PyBuilder
176 | target/
177 |
178 | # Jupyter Notebook
179 | .ipynb_checkpoints
180 |
181 | # pyenv
182 | .python-version
183 |
184 | # celery beat schedule file
185 | celerybeat-schedule.*
186 |
187 | # SageMath parsed files
188 | *.sage.py
189 |
190 | # Environments
191 | .env
192 | .venv
193 | env/
194 | venv/
195 | ENV/
196 | env.bak/
197 | venv.bak/
198 |
199 | # Spyder project settings
200 | .spyderproject
201 | .spyproject
202 |
203 | # Rope project settings
204 | .ropeproject
205 |
206 | # mkdocs documentation
207 | /site
208 |
209 | # mypy
210 | .mypy_cache/
211 |
212 | ### VisualStudioCode ###
213 | .vscode/*
214 | !.vscode/settings.json
215 | !.vscode/tasks.json
216 | !.vscode/launch.json
217 | !.vscode/extensions.json
218 | .history
219 |
220 | ### Windows ###
221 | # Windows thumbnail cache files
222 | Thumbs.db
223 | ehthumbs.db
224 | ehthumbs_vista.db
225 |
226 | # Folder config file
227 | Desktop.ini
228 |
229 | # Recycle Bin used on file shares
230 | $RECYCLE.BIN/
231 |
232 | # Windows Installer files
233 | *.cab
234 | *.msi
235 | *.msm
236 | *.msp
237 |
238 | # Windows shortcuts
239 | *.lnk
240 |
241 | # Build folder
242 |
243 | */build/*
244 |
245 | # End of https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode
246 |
247 | # vim, mac stuff, powerpoint temp files (I use powerpoint for architecture diagrams)
248 | .*.sw[op]
249 | .DS_Store
250 | ~$*.ppt*
251 |
252 | # SAM CLI build dir
253 | .aws-sam
254 |
255 | # We're using pipenv so any requirements.txt files are auto-generated and should be ignored by git
256 | requirements.txt
257 |
--------------------------------------------------------------------------------
/test/load/load_test_base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from concurrent.futures import as_completed
5 | from concurrent.futures import ThreadPoolExecutor
6 |
7 | from integ.integ_base import BasicIntegTest
8 |
9 |
10 | class BaseLoadTest(BasicIntegTest):
11 | DEFAULT_PERIOD = 5
12 | DEFAULT_CONNECTIONS = 1
13 |
14 |     def find_max_tpm(self, s3_olap_arn, file_size, is_pii, expected_error, error_percent_threshold=0.05, load_test_period=3):
15 |         file_name = self.create_and_upload_text_file(file_size, is_pii)
16 |         total_counts, failed_counts, _ = self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error,
17 |                                                                                         period_in_minutes=load_test_period)
18 |         current_tpm = total_counts / load_test_period
19 |         previous_tpm = 0
20 |         previous_best_connection_count = self.DEFAULT_CONNECTIONS
21 |         connection_count = self.DEFAULT_CONNECTIONS
22 |         while failed_counts / total_counts < error_percent_threshold and previous_tpm <= current_tpm:
23 |             previous_tpm = current_tpm
24 |             previous_best_connection_count = connection_count
25 |             connection_count = previous_best_connection_count + 2
26 |             total_counts, failed_counts, _ = self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error,
27 |                                                                                             connection_counts=connection_count,
28 |                                                                                             period_in_minutes=load_test_period)
29 |             current_tpm = total_counts / load_test_period
30 |
31 |         logging.info(
32 |             f"Best results tpm: {previous_tpm} with error rate: {failed_counts / total_counts * 100}% for file size {file_size} "
33 |             f"are obtained with Connection count as {previous_best_connection_count}")
34 |
35 | def execute_load_on_get_object(self, s3_olap_arn, file_size, is_pii, expected_error):
36 | logging.info(f"Running Load Test for {file_size} KB file where pii is {'' if is_pii else 'not'} present")
37 | file_name = self.create_and_upload_text_file(file_size, is_pii)
38 |         return self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error)
39 |
40 | def create_and_upload_text_file(self, desired_size_in_KB: int, is_pii: bool):
41 | file_name = str("pii_" if is_pii else "non_pii") + str(desired_size_in_KB) + "_KB"
42 |
43 | repeat_text = " Some Random Text ttt"
44 | with open(self.DATA_PATH + "/pii_input.txt") as pii_file:
45 | pii_text = pii_file.read()
46 | full_text = pii_text if is_pii else ""
47 | with open(file_name, 'w') as temp:
48 | while len(full_text) <= desired_size_in_KB * 1000:
49 | full_text += repeat_text
50 | temp.write(full_text)
51 | self.s3_client.upload_file(file_name, self.bucket_name, file_name)
52 | os.remove(file_name)
53 | return file_name
54 |
55 |     def run_desired_simultaneous_get_object_calls(self, s3_olap_arn, file_name, expected_error, connection_counts=DEFAULT_CONNECTIONS,
56 |                                                   period_in_minutes=DEFAULT_PERIOD):
57 |         logging.info(
58 |             f"Running Load Test for file: {file_name} for period {period_in_minutes} with connection counts: {connection_counts}")
59 |         executor = ThreadPoolExecutor(max_workers=connection_counts)
60 |         futures = [executor.submit(self.fail_safe_fetch_s3_object, s3_olap_arn, file_name, period=period_in_minutes * 60,
61 |                                    expected_error=expected_error) for _ in range(0, connection_counts)]
62 |
63 |         total_counts = 0
64 |         successful_counts = 0
65 |         total_time = 0
66 |         for f in as_completed(futures):
67 |             failed_calls, successful_calls, elapsed = f.result()
68 |             successful_counts += successful_calls
69 |             total_counts += failed_calls + successful_calls
70 |             total_time += elapsed
71 |         average_latency = total_time / total_counts if total_counts else 0
72 |         logging.info(f"Total calls made: {total_counts}, out of which {successful_counts} calls were successful"
73 |                      f" ({successful_counts / total_counts * 100}%), Average Latency {average_latency}")
74 |         return total_counts, total_counts - successful_counts, average_latency
75 |
76 | def fail_safe_fetch_s3_object(self, s3ol_ap_arn, file, period, expected_error=None):
77 | start_time = time.time()
78 | failed_calls = 0
79 | successful_calls = 0
80 | total_time = 0
81 | while (time.time() - start_time) <= period:
82 | try:
83 | api_call_st_time = time.time()
84 | response = self.s3_client.get_object(Bucket=s3ol_ap_arn, Key=file)
85 | total_time += time.time() - api_call_st_time
86 | successful_calls += 1
87 | except Exception as e:
88 | total_time += time.time() - api_call_st_time
89 | if expected_error:
90 | if isinstance(e, expected_error):
91 | successful_calls += 1
92 | else:
93 | failed_calls += 1
94 | else:
95 | # logging.error(e)
96 | failed_calls += 1
97 | return failed_calls, successful_calls, total_time
98 |
--------------------------------------------------------------------------------
/access-control-template.yml:
--------------------------------------------------------------------------------
1 | AWSTemplateFormatVersion: '2010-09-09'
2 | Transform: AWS::Serverless-2016-10-31
3 |
4 | Metadata:
5 | AWS::ServerlessRepo::Application:
6 | Name: ComprehendPiiAccessControlS3ObjectLambda
7 |     Description: Deploys a Lambda that provides the capability to control access to text files containing PII (Personally Identifiable Information). This Lambda can be used as an S3 Object Lambda that is triggered on GetObject calls when configured with an access point.
8 | Author: AWS Comprehend
9 | # SPDX License Id, e.g., MIT, MIT-0, Apache-2.0. See https://spdx.org/licenses for more details
10 | SpdxLicenseId: MIT-0
11 | LicenseUrl: LICENSE
12 | ReadmeUrl: ACCESS_CONTROL_README.md
13 | Labels: [serverless,comprehend,nlp,pii]
14 | HomePageUrl: https://aws.amazon.com/comprehend/
15 | SemanticVersion: 1.0.2
16 | SourceCodeUrl: https://github.com/aws-samples/amazon-comprehend-s3-object-lambda-functions
17 |
18 | Parameters:
19 | LogLevel:
20 | Type: String
21 | Description: Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc.
22 | Default: INFO
23 | UnsupportedFileHandling:
24 | Type: String
25 | Description: Handling logic for Unsupported files. Valid values are PASS and FAIL.
26 | Default: FAIL
27 | IsPartialObjectSupported:
28 | Type: String
29 |     Description: Whether to support partial objects or not. Accessing partial objects through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy.
30 | Default: FALSE
31 | DocumentMaxSizeContainsPiiEntities:
32 | Type: Number
33 | Description: Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API.
34 | Default: 50000
35 | PiiEntityTypes:
36 | Type: String
37 |     Description: List of comma-separated PII entity types to be considered for access control. Refer to Comprehend's documentation for the list of supported PII entity types.
38 | Default: ALL
39 | SubsegmentOverlappingTokens:
40 | Type: Number
41 | Description: Number of tokens/words to overlap among segments of a document in case chunking is needed because of maximum document size limit.
42 | Default: 20
43 | DocumentMaxSize:
44 | Type: Number
45 |     Description: Maximum document size (in bytes) that this function can process; an exception is thrown for documents larger than this.
46 | Default: 102400
47 | ConfidenceThreshold:
48 | Type: Number
49 |     Description: The minimum prediction confidence score above which PII classification and detection results are considered final. Valid range: 0.5 to 1.0.
50 | Default: 0.5
51 | MaxCharsOverlap:
52 | Type: Number
53 | Description: Maximum characters to overlap among segments of a document in case chunking is needed because of maximum document size limit.
54 | Default: 200
55 | DefaultLanguageCode:
56 | Type: String
57 | Description: Default language of the text to be processed. This code will be used for interacting with Comprehend.
58 | Default: en
59 | ContainsPiiEntitiesThreadCount:
60 | Type: Number
61 | Description: Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda.
62 | Default: 20
63 | PublishCloudWatchMetrics:
64 | Type: String
65 |     Description: Set to true to publish metrics to CloudWatch, false otherwise. See README.md for details on CloudWatch metrics.
66 | Default: True
67 |
68 | Resources:
69 | PiiAccessControlFunction:
70 | Type: AWS::Serverless::Function
71 | Properties:
72 | CodeUri: src/
73 | Handler: handler.pii_access_control_handler
74 | Runtime: python3.8
75 | Tracing: Active
76 | Timeout: 60
77 | Policies:
78 | - Statement:
79 | - Sid: ComprehendPiiDetectionPolicy
80 | Effect: Allow
81 | Action:
82 | - comprehend:ContainsPiiEntities
83 | Resource: '*'
84 | - Sid: S3AccessPointCallbackPolicy
85 | Effect: Allow
86 | Action:
87 | - s3-object-lambda:WriteGetObjectResponse
88 | Resource: '*'
89 | - Sid: CloudWatchMetricsPolicy
90 | Effect: Allow
91 | Action:
92 | - cloudwatch:PutMetricData
93 | Resource: '*'
94 | Environment:
95 | Variables:
96 | LOG_LEVEL: !Ref LogLevel
97 | UNSUPPORTED_FILE_HANDLING: !Ref UnsupportedFileHandling
98 | IS_PARTIAL_OBJECT_SUPPORTED: !Ref IsPartialObjectSupported
99 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES: !Ref DocumentMaxSizeContainsPiiEntities
100 | PII_ENTITY_TYPES: !Ref PiiEntityTypes
101 | SUBSEGMENT_OVERLAPPING_TOKENS: !Ref SubsegmentOverlappingTokens
102 | DOCUMENT_MAX_SIZE: !Ref DocumentMaxSize
103 | CONFIDENCE_THRESHOLD: !Ref ConfidenceThreshold
104 | MAX_CHARS_OVERLAP: !Ref MaxCharsOverlap
105 | DEFAULT_LANGUAGE_CODE: !Ref DefaultLanguageCode
106 | CONTAINS_PII_ENTITIES_THREAD_COUNT: !Ref ContainsPiiEntitiesThreadCount
107 | PUBLISH_CLOUD_WATCH_METRICS: !Ref PublishCloudWatchMetrics
108 |
109 | Outputs:
110 | PiiAccessControlFunctionName:
111 | Description: "PII Access Control Function Name"
112 | Value: !Ref PiiAccessControlFunction
113 |
--------------------------------------------------------------------------------
/src/exception_handlers.py:
--------------------------------------------------------------------------------
1 | """Classes for handling exceptions."""
2 | import lambdalogging
3 | from clients.s3_client import S3Client
4 | from config import UNSUPPORTED_FILE_HANDLING
5 | from constants import UNSUPPORTED_FILE_HANDLING_VALID_VALUES, S3_STATUS_CODES, S3_ERROR_CODES, error_code_to_enums
6 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, S3DownloadException, InvalidConfigurationException, \
7 | InvalidRequestException, RestrictedDocumentException, TimeoutException
8 |
9 | LOG = lambdalogging.getLogger(__name__)
10 |
11 |
12 | class ExceptionHandler:
13 |     """Handler enclosing the action to be taken when an error occurs while processing files."""
14 |
15 | def __init__(self, s3_client: S3Client):
16 | self.s3_client = s3_client
17 |
18 | def handle_exception(self, exception: BaseException, request_route: str, request_token: str):
19 | """Handle exception and take appropriate actions."""
20 | try:
21 |             raise exception  # re-raise so the except clauses below act as a type-based dispatch
22 | except UnsupportedFileException as e:
23 | self._handle_unsupported_file_exception(e, request_route, request_token)
24 | except InvalidConfigurationException as e:
25 | LOG.error(f"Encountered an invalid configuration setup. {e}", exc_info=True)
26 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400,
27 | S3_ERROR_CODES.InvalidRequest,
28 | "Lambda function has been incorrectly setup", request_route,
29 | request_token)
30 | except FileSizeLimitExceededException:
31 | LOG.info(
32 |                 f"File size of the requested object exceeds maximum file size supported. Responding back with "
33 |                 f"error: {S3_STATUS_CODES.BAD_REQUEST_400.name}")
34 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400, S3_ERROR_CODES.EntityTooLarge,
35 | "Size of the requested object exceeds maximum file size supported", request_route,
36 | request_token)
37 | except InvalidRequestException as e:
38 | LOG.info(f"Encountered an invalid request {e}", exc_info=True)
39 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400, S3_ERROR_CODES.InvalidRequest,
40 | e.message, request_route, request_token)
41 | except S3DownloadException as e:
42 | LOG.error(f"Error downloading from presigned url. {e}", exc_info=True)
43 | status_code, error_code = error_code_to_enums(e.s3_error_code)
44 | self.s3_client.respond_back_with_error(status_code, error_code, e.s3_message,
45 | request_route, request_token)
46 | except RestrictedDocumentException as e:
47 | LOG.error(f"Document contains pii. {e}", exc_info=True)
48 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.FORBIDDEN_403,
49 | S3_ERROR_CODES.AccessDenied,
50 | "Document Contains PII",
51 | request_route, request_token)
52 | except TimeoutException as e:
53 | LOG.error(f"Couldn't complete processing within the time limit. {e}", exc_info=True)
54 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400,
55 | S3_ERROR_CODES.RequestTimeout,
56 | "Failed to complete document processing within time limit",
57 | request_route, request_token)
58 | except Exception as e:
59 | LOG.error(f"Internal error {e} occurred while processing the file", exc_info=True)
60 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, S3_ERROR_CODES.InternalError,
61 | "An internal error occurred while processing the file", request_route,
62 | request_token)
63 |
64 | def _handle_unsupported_file_exception(self, exception: UnsupportedFileException, request_route: str, request_token: str):
65 | """Handle the action to be taken in case we encounter a file which is not supported by Lambda's core functionality."""
66 | LOG.debug("File is not supported for determining and redacting pii data.")
67 |
68 | if UNSUPPORTED_FILE_HANDLING == UNSUPPORTED_FILE_HANDLING_VALID_VALUES.PASS:
69 | LOG.debug("Unsupported file handling strategy is set to PASS. Responding back with the file content to the caller")
70 | self.s3_client.respond_back_with_data(exception.file_content, exception.http_headers, request_route, request_token)
71 |
72 | elif UNSUPPORTED_FILE_HANDLING == UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL:
73 | LOG.debug(
74 | f"Unsupported file handling strategy is set to FAIL. Responding back with error: "
75 | f"{S3_ERROR_CODES.UnexpectedContent.name} to the caller")
76 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400,
77 | S3_ERROR_CODES.UnexpectedContent,
78 | "Unsupported file encountered for determining Pii", request_route, request_token)
79 | else:
80 | raise Exception("Unknown exception handling strategy found for UnsupportedFileException.")
81 |
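A minimal sketch of how this handler is wired into a Lambda entry point; the event field names follow the S3 Object Lambda GetObject event shape, and `process_object` is a placeholder for the real pipeline in processors.py:

```
# Sketch only: `process_object` is a placeholder, not the real processing pipeline.
def handler(event, context):
    ctx = event['getObjectContext']
    s3_client = S3Client(s3ol_access_point=event['configuration']['accessPointArn'])
    try:
        process_object(ctx['inputS3Url'])
    except Exception as e:
        ExceptionHandler(s3_client).handle_exception(e, ctx['outputRoute'], ctx['outputToken'])
```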
--------------------------------------------------------------------------------
/src/clients/comprehend_client.py:
--------------------------------------------------------------------------------
1 | """Client wrapper over AWS services."""
2 |
3 | import string
4 | from concurrent.futures._base import as_completed
5 | from concurrent.futures.thread import ThreadPoolExecutor
6 | from copy import deepcopy
7 | from random import choices
8 | from typing import List
9 |
10 | import boto3
11 | import botocore
12 | import time
13 |
14 | import lambdalogging
15 | from clients.cloudwatch_client import Metrics
16 | from config import CONTAINS_PII_ENTITIES_THREAD_COUNT, DETECT_PII_ENTITIES_THREAD_COUNT, DEFAULT_LANGUAGE_CODE
17 | from constants import DEFAULT_USER_AGENT, CONTAINS_PII_ENTITIES, DETECT_PII_ENTITIES, COMPREHEND, COMPREHEND_MAX_RETRIES
18 | from data_object import Document
19 |
20 | LOG = lambdalogging.getLogger(__name__)
21 |
22 |
23 | class ComprehendClient:
24 | """Wrapper over comprehend client."""
25 |
26 | def __init__(self, s3ol_access_point: str, pii_classification_thread_count: int = CONTAINS_PII_ENTITIES_THREAD_COUNT,
27 | pii_redaction_thread_count: int = DETECT_PII_ENTITIES_THREAD_COUNT,
28 | session_id: str = ''.join(choices(string.ascii_uppercase + string.digits, k=10)),
29 | user_agent=DEFAULT_USER_AGENT, endpoint_url=None):
30 | self.session_id = session_id
31 | session_config = botocore.config.Config(
32 | user_agent_extra=user_agent,
33 | retries={
34 | 'max_attempts': COMPREHEND_MAX_RETRIES,
35 | 'mode': 'standard'
36 | })
37 | if endpoint_url is None:
38 | self.comprehend = boto3.client('comprehend', config=session_config)
39 | else:
40 | self.comprehend = boto3.client('comprehend', config=session_config, endpoint_url=endpoint_url, verify=False)
41 | self.comprehend.meta.events.register('before-sign.comprehend.*', self._add_session_header)
42 | self.classification_executor_service = ThreadPoolExecutor(max_workers=pii_classification_thread_count)
43 | self.redaction_executor_service = ThreadPoolExecutor(max_workers=pii_redaction_thread_count)
44 | self.classify_metrics = Metrics(service_name=COMPREHEND, api=CONTAINS_PII_ENTITIES, s3ol_access_point=s3ol_access_point)
45 | self.detection_metrics = Metrics(service_name=COMPREHEND, api=DETECT_PII_ENTITIES, s3ol_access_point=s3ol_access_point)
46 |
47 | def _add_session_header(self, request, **kwargs):
48 | request.headers.add_header('x-amzn-session-id', self.session_id)
49 |
50 | def contains_pii_entities(self, documents: List[Document], language=DEFAULT_LANGUAGE_CODE) -> List[Document]:
51 | """Call comprehend to get pii classification of given documents."""
52 | documents_copy = deepcopy(documents)
53 | result = []
54 | with self.classification_executor_service:
55 | futures = []
56 | for doc in documents_copy:
57 | futures.append(self.classification_executor_service.submit(self._update_doc_with_pii_classification, doc, language))
58 |
59 | for future_result in as_completed(futures):
60 | try:
61 | result.append(future_result.result())
62 | except Exception as error:
63 | LOG.error("Error occurred while calling comprehend for classifying text as pii", exc_info=True)
64 | self.classify_metrics.add_fault_count()
65 | raise error
66 | return result
67 |
68 | def _update_doc_with_pii_classification(self, document: Document, language) -> Document:
69 | start_time = time.time()
70 | response = None
71 | try:
72 | response = self.comprehend.contains_pii_entities(Text=document.text, LanguageCode=language)
73 | finally:
74 | if response is not None:
75 |                 self.classify_metrics.add_fault_count(response['ResponseMetadata']['RetryAttempts'])  # each retry counts toward ErrorCount
76 | self.classify_metrics.add_latency(start_time, time.time())
77 | # updating the document itself instead of creating a new copy to save space
78 | document.pii_classification = {label['Name']: label['Score'] for label in response['Labels']}
79 | return document
80 |
81 | def detect_pii_documents(self, documents: List[Document], language=DEFAULT_LANGUAGE_CODE) -> List[Document]:
82 | """Call comprehend to get pii entities present in given documents."""
83 | documents_copy = deepcopy(documents)
84 | result = []
85 | with self.redaction_executor_service:
86 | futures = []
87 | for doc in documents_copy:
88 | futures.append(self.redaction_executor_service.submit(self._update_doc_with_pii_entities, doc, language))
89 |
90 | for future_result in as_completed(futures):
91 | try:
92 | result.append(future_result.result())
93 | except Exception as error:
94 | LOG.error("Error occurred while calling comprehend for detecting pii entities", exc_info=True)
95 | self.detection_metrics.add_fault_count()
96 | raise error
97 | return result
98 |
99 | def _update_doc_with_pii_entities(self, document: Document, language) -> Document:
100 | start_time = time.time()
101 | response = None
102 | try:
103 | response = self.comprehend.detect_pii_entities(Text=document.text, LanguageCode=language)
104 | finally:
105 | if response is not None:
106 |                 self.detection_metrics.add_fault_count(response['ResponseMetadata']['RetryAttempts'])  # each retry counts toward ErrorCount
107 | self.detection_metrics.add_latency(start_time, time.time())
108 | # updating the document itself instead of creating a new copy to save space
109 | document.pii_entities = response['Entities']
110 | document.pii_classification = {entity['Type']: max(entity['Score'], document.pii_classification[entity['Type']])
111 | if entity['Type'] in document.pii_classification else entity['Score']
112 | for entity in response['Entities']}
113 | return document
114 |
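A hedged usage sketch of the client above; the access point ARN is a placeholder, and `Document` is the data class imported at the top of this file:

```
# Requires AWS credentials with comprehend:ContainsPiiEntities permission.
client = ComprehendClient(s3ol_access_point="arn:aws:s3-object-lambda:us-east-1:111122223333:accesspoint/example")
docs = client.contains_pii_entities([Document(text="My SSN is 123-45-6789")], language="en")
print(docs[0].pii_classification)  # e.g. {'SSN': 0.99}
```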
--------------------------------------------------------------------------------
/test/unit/test_exception_handlers.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | from unittest.mock import patch, MagicMock
3 |
4 | from constants import S3_STATUS_CODES, S3_ERROR_CODES, UNSUPPORTED_FILE_HANDLING_VALID_VALUES
5 | from exception_handlers import ExceptionHandler
6 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, InvalidConfigurationException, \
7 | S3DownloadException, RestrictedDocumentException, TimeoutException
8 |
9 |
10 | class ExceptionHandlerTest(TestCase):
11 |
12 | def setUp(self) -> None:
13 | super().setUp()
14 |
15 | def tearDown(self) -> None:
16 | super().tearDown()
17 |
18 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", UNSUPPORTED_FILE_HANDLING_VALID_VALUES.PASS)
19 | def test_unsupported_file_exception_handling_do_not_fail(self):
20 | s3_client = MagicMock()
21 | ExceptionHandler(s3_client). \
22 | handle_exception(UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken")
23 | s3_client.respond_back_with_data.assert_called_once_with("SomeContent", {'h1': 'v1'}, "SomeRoute", "SomeToken")
24 |
25 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL)
26 | def test_unsupported_file_exception_handling_return_error(self):
27 | s3_client = MagicMock()
28 | ExceptionHandler(s3_client). \
29 | handle_exception(UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken")
30 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400,
31 | S3_ERROR_CODES.UnexpectedContent,
32 | "Unsupported file encountered for determining Pii",
33 | "SomeRoute", "SomeToken")
34 |
35 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", 'Unknown')
36 | def test_unsupported_file_exception_handling_return_unknown_error(self):
37 | s3_client = MagicMock()
38 | self.assertRaises(Exception, ExceptionHandler(s3_client).handle_exception,
39 | UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken")
40 |
41 | def test_file_size_limit_exceeded_handler(self):
42 | s3_client = MagicMock()
43 | ExceptionHandler(s3_client).handle_exception(FileSizeLimitExceededException(), "SomeRoute", "SomeToken")
44 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400,
45 | S3_ERROR_CODES.EntityTooLarge,
46 | "Size of the requested object exceeds maximum file size supported",
47 | "SomeRoute", "SomeToken")
48 |
49 | def test_default_exception_handler(self):
50 | s3_client = MagicMock()
51 | ExceptionHandler(s3_client).handle_exception(Exception(), "SomeRoute", "SomeToken")
52 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500,
53 | S3_ERROR_CODES.InternalError,
54 | "An internal error occurred while processing the file",
55 | "SomeRoute", "SomeToken")
56 |
57 | def test_invalid_configuration_exception_handler(self):
58 | s3_client = MagicMock()
59 |         ExceptionHandler(s3_client).handle_exception(InvalidConfigurationException("Misconfigured knob"),
60 | "SomeRoute", "SomeToken")
61 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400,
62 | S3_ERROR_CODES.InvalidRequest,
63 | "Lambda function has been incorrectly setup",
64 | "SomeRoute", "SomeToken")
65 |
66 | def test_s3_download_exception_handler(self):
67 | s3_client = MagicMock()
68 | ExceptionHandler(s3_client).handle_exception(S3DownloadException("InternalError", "Internal Server Error"),
69 | "SomeRoute", "SomeToken")
70 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500,
71 | S3_ERROR_CODES.InternalError,
72 | "Internal Server Error",
73 | "SomeRoute", "SomeToken")
74 |
75 | def test_restricted_document_exception_handler(self):
76 | s3_client = MagicMock()
77 | ExceptionHandler(s3_client).handle_exception(RestrictedDocumentException(),
78 | "SomeRoute", "SomeToken")
79 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.FORBIDDEN_403,
80 | S3_ERROR_CODES.AccessDenied,
81 | "Document Contains PII",
82 | "SomeRoute", "SomeToken")
83 |
84 | def test_timeout_exception_handler(self):
85 | s3_client = MagicMock()
86 | ExceptionHandler(s3_client).handle_exception(TimeoutException(),
87 | "SomeRoute", "SomeToken")
88 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400,
89 | S3_ERROR_CODES.RequestTimeout,
90 | "Failed to complete document processing within time limit",
91 | "SomeRoute", "SomeToken")
92 |
--------------------------------------------------------------------------------
/redaction-template.yml:
--------------------------------------------------------------------------------
1 | AWSTemplateFormatVersion: '2010-09-09'
2 | Transform: AWS::Serverless-2016-10-31
3 |
4 | Metadata:
5 | AWS::ServerlessRepo::Application:
6 | Name: ComprehendPiiRedactionS3ObjectLambda
7 |     Description: Deploys a Lambda function that redacts PII (Personally Identifiable Information) from text files in S3. The function can be used as an S3 Object Lambda that is triggered on GetObject calls when configured with an access point
8 | Author: AWS Comprehend
9 | # SPDX License Id, e.g., MIT, MIT-0, Apache-2.0. See https://spdx.org/licenses for more details
10 | SpdxLicenseId: MIT-0
11 | LicenseUrl: LICENSE
12 | ReadmeUrl: REDACTION_README.md
13 | Labels: [serverless,comprehend,pii,nlp]
14 | HomePageUrl: https://aws.amazon.com/comprehend/
15 | SemanticVersion: 1.0.2
16 | SourceCodeUrl: https://github.com/aws-samples/amazon-comprehend-s3-object-lambda-functions
17 |
18 | Parameters:
19 | LogLevel:
20 | Type: String
21 | Description: Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc.
22 | Default: INFO
23 | UnsupportedFileHandling:
24 | Type: String
25 |     Description: Handling logic for unsupported files. Valid values are PASS and FAIL.
26 | Default: FAIL
27 | IsPartialObjectSupported:
28 | Type: String
29 |     Description: Whether to support partial objects or not. Accessing partial objects through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy.
30 | Default: FALSE
31 | DocumentMaxSizeContainsPiiEntities:
32 | Type: Number
33 | Description: Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API.
34 | Default: 50000
35 | DocumentMaxSizeDetectPiiEntities:
36 | Type: Number
37 | Description: Maximum document size (in bytes) to be used for making calls to Comprehend's DetectPiiEntities API.
38 | Default: 5000
39 | PiiEntityTypes:
40 | Type: String
41 |     Description: List of comma-separated PII entity types to be considered for redaction. Refer to Comprehend's documentation page for the list of supported PII entity types.
42 | Default: ALL
43 | MaskCharacter:
44 | Type: String
45 | Description: A character that replaces each character in the redacted PII entity.
46 | Default: '*'
47 | MaskMode:
48 | Type: String
49 |     Description: Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values are REPLACE_WITH_PII_ENTITY_TYPE and MASK.
50 | Default: MASK
51 | SubsegmentOverlappingTokens:
52 | Type: Number
53 | Description: Number of tokens/words to overlap among segments of a document in case chunking is needed because of maximum document size limit.
54 | Default: 20
55 | DocumentMaxSize:
56 | Type: Number
57 |     Description: Maximum document size (in bytes) that this function can process; an exception is thrown for documents that are too large.
58 | Default: 102400
59 | ConfidenceThreshold:
60 | Type: Number
61 |     Description: The minimum prediction confidence score above which PII classification and detection results are accepted as the final answer. Valid range (0.5 to 1.0).
62 | Default: 0.5
63 | MaxCharsOverlap:
64 | Type: Number
65 | Description: Maximum characters to overlap among segments of a document in case chunking is needed because of maximum document size limit.
66 | Default: 200
67 | DefaultLanguageCode:
68 | Type: String
69 | Description: Default language of the text to be processed. This code will be used for interacting with Comprehend.
70 | Default: en
71 | DetectPiiEntitiesThreadCount:
72 | Type: Number
73 | Description: Number of threads to use for calling Comprehend's DetectPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda.
74 | Default: 8
75 | ContainsPiiEntitiesThreadCount:
76 | Type: Number
77 | Description: Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda.
78 | Default: 20
79 | PublishCloudWatchMetrics:
80 | Type: String
81 |     Description: Set to true to publish metrics to CloudWatch, false otherwise. See README.md for details on CloudWatch metrics.
82 | Default: True
83 |
84 | Resources:
85 | PiiRedactionFunction:
86 | Type: AWS::Serverless::Function
87 | Properties:
88 | CodeUri: src/
89 | Handler: handler.redact_pii_documents_handler
90 | Runtime: python3.8
91 | Tracing: Active
92 | Timeout: 60
93 | Policies:
94 | - Statement:
95 | - Sid: ComprehendPiiDetectionPolicy
96 | Effect: Allow
97 | Action:
98 | - comprehend:DetectPiiEntities
99 | - comprehend:ContainsPiiEntities
100 | Resource: '*'
101 | - Sid: S3AccessPointCallbackPolicy
102 | Effect: Allow
103 | Action:
104 | - s3-object-lambda:WriteGetObjectResponse
105 | Resource: '*'
106 | - Sid: CloudWatchMetricsPolicy
107 | Effect: Allow
108 | Action:
109 | - cloudwatch:PutMetricData
110 | Resource: '*'
111 | Environment:
112 | Variables:
113 | LOG_LEVEL: !Ref LogLevel
114 | UNSUPPORTED_FILE_HANDLING: !Ref UnsupportedFileHandling
115 |           IS_PARTIAL_OBJECT_SUPPORTED: !Ref IsPartialObjectSupported
116 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES: !Ref DocumentMaxSizeContainsPiiEntities
117 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES: !Ref DocumentMaxSizeDetectPiiEntities
118 | PII_ENTITY_TYPES: !Ref PiiEntityTypes
119 | MASK_CHARACTER: !Ref MaskCharacter
120 | MASK_MODE: !Ref MaskMode
121 | SUBSEGMENT_OVERLAPPING_TOKENS: !Ref SubsegmentOverlappingTokens
122 | DOCUMENT_MAX_SIZE: !Ref DocumentMaxSize
123 | CONFIDENCE_THRESHOLD: !Ref ConfidenceThreshold
124 | MAX_CHARS_OVERLAP: !Ref MaxCharsOverlap
125 | DEFAULT_LANGUAGE_CODE: !Ref DefaultLanguageCode
126 | DETECT_PII_ENTITIES_THREAD_COUNT: !Ref DetectPiiEntitiesThreadCount
127 | CONTAINS_PII_ENTITIES_THREAD_COUNT: !Ref ContainsPiiEntitiesThreadCount
128 | PUBLISH_CLOUD_WATCH_METRICS: !Ref PublishCloudWatchMetrics
129 |
130 | Outputs:
131 | PiiRedactionFunctionName:
132 | Description: "Redaction Function Name"
133 | Value: !Ref PiiRedactionFunction
134 |
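A minimal sketch of what the two MaskMode settings do to a detected entity span; the exact replacement format used by the deployed function is defined in src/processors.py, so the `[ENTITY_TYPE]` form below is an assumption for illustration:

```
# Sketch of MASK vs REPLACE_WITH_PII_ENTITY_TYPE (replacement format is an assumption).
def apply_mask(text: str, begin: int, end: int, entity_type: str,
               mask_mode: str = 'MASK', mask_character: str = '*') -> str:
    if mask_mode == 'REPLACE_WITH_PII_ENTITY_TYPE':
        return text[:begin] + f'[{entity_type}]' + text[end:]
    # MASK: replace every character of the entity with the mask character
    return text[:begin] + mask_character * (end - begin) + text[end:]

print(apply_mask("Call me at 555-0100", 11, 19, "PHONE"))  # Call me at ********
```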
--------------------------------------------------------------------------------
/test/integ/test_access-control.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from datetime import datetime, timedelta
3 |
4 | import boto3
5 | import botocore
6 | from botocore.exceptions import ClientError
7 | from dateutil.tz import tzutc
8 |
9 | from integ.integ_base import BasicIntegTest
10 |
11 |
12 | class PiiAccessControlIntegTest(BasicIntegTest):
13 | BUILD_DIR = '.aws-sam/build/PiiAccessControlFunction/'
14 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME']
15 |
16 | @classmethod
17 | def setUpClass(cls) -> None:
18 | super().setUpClass()
19 | test_run_id = str(uuid.uuid4())[0:8]
20 | cls.lambda_function_arn = cls._create_function("pii_access_control", 'handler.pii_access_control_handler',
21 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR)
22 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn)
23 |
24 | @classmethod
25 | def tearDownClass(cls) -> None:
26 | cls.s3_ctrl.delete_access_point_for_object_lambda(
27 | AccountId=cls.account_id,
28 | Name=cls.s3ol_access_point_arn.split('/')[-1]
29 | )
30 |
31 | cls.lambda_client.delete_function(
32 | FunctionName=cls.lambda_function_arn.split(':')[-1]
33 | )
34 | super().tearDownClass()
35 |
36 | def tearDown(self) -> None:
37 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"})
38 |
39 | def test_classification_lambda_with_default_entity_types(self):
40 | start_time = datetime.now(tz=tzutc())
41 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt')
42 | with self.assertRaises(ClientError) as e:
43 | test_pii_obj.get()
44 | assert e.exception.response['Error']['Message'] == "Document Contains PII"
45 | assert e.exception.response['Error']['Code'] == "AccessDenied"
46 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time, self.PII_ENTITY_TYPES_IN_TEST_DOC)
47 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time)
48 |
49 | def test_classification_lambda_with_pii_overridden_entity_types(self):
50 | self._update_lambda_env_variables(self.lambda_function_arn,
51 | {"LOG_LEVEL": "DEBUG", "PII_ENTITY_TYPES": "USERNAME,PASSWORD,AWS_ACCESS_KEY"})
52 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='pii_input.txt')
53 | with open(f"{self.DATA_PATH}/pii_input.txt") as expected_output_file:
54 | assert response['Body'].read().decode('utf-8') == expected_output_file.read()
55 |
56 | def test_classification_lambda_throws_invalid_request_error_when_file_size_exceeds_limit(self):
57 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "500"})
58 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt')
59 | with self.assertRaises(ClientError) as e:
60 | test_pii_obj.get()
61 | assert e.exception.response['Error']['Message'] == "Size of the requested object exceeds maximum file size supported"
62 | assert e.exception.response['Error']['Code'] == "EntityTooLarge"
63 |
64 | def test_classification_lambda_throws_access_denied_with_an_overridden_max_file_size_limit(self):
65 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "1500000",
66 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"})
67 | with self.assertRaises(ClientError) as e:
68 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='1mb_pii_text')
69 | assert e.exception.response['Error']['Message'] == "Document Contains PII"
70 | assert e.exception.response['Error']['Code'] == "AccessDenied"
71 |
72 | def test_classification_lambda_with_unsupported_file(self):
73 | test_obj = self.s3.Object(self.s3ol_access_point_arn, 'RandomImage.png')
74 | with self.assertRaises(ClientError) as e:
75 | test_obj.get()
76 | assert e.exception.response['Error']['Message'] == "Unsupported file encountered for determining Pii"
77 | assert e.exception.response['Error']['Code'] == "UnexpectedContent"
78 |
79 | def test_classification_lambda_with_unsupported_file_handling_set_to_pass(self):
80 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "UNSUPPORTED_FILE_HANDLING": 'PASS'})
81 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='RandomImage.png')
82 | assert 'Body' in response
83 |
84 | def test_classification_lambda_with_partial_object(self):
85 | with self.assertRaises(ClientError) as e:
86 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key="pii_input.txt", Range="bytes=0-100")
87 | assert e.exception.response['Error']['Message'] == "HTTP Header Range is not supported"
88 |
89 | def test_classification_lambda_with_partial_object_allowed_with_versioned(self):
90 | file_name = 'pii_output.txt'
91 | self.s3_client.upload_file(f"{self.DATA_PATH}/{file_name}", self.bucket_name, file_name)
92 | versions = set()
93 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "IS_PARTIAL_OBJECT_SUPPORTED": "TRUE"})
94 |
95 | for version in self.s3_client.list_object_versions(Bucket=self.bucket_name)['Versions']:
96 | if version['Key'] == 'pii_output.txt':
97 | versions.add(version['VersionId'])
98 | assert len(versions) >= 2, f"Expected at least 2 different versions of {file_name}"
99 | for versionId in versions:
100 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key=file_name, Range="bytes=0-100",
101 | VersionId=versionId)
102 | assert response['ContentRange'] == "bytes 0-100/611"
103 | assert response['ContentLength'] == 101
104 | assert response['ContentType'] == 'binary/octet-stream'
105 | assert response['VersionId'] == versionId
106 |
107 | def test_request_timeout(self):
108 | self._update_lambda_env_variables(self.lambda_function_arn,
109 | {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "10000000",
110 | "CONTAINS_PII_ENTITIES_THREAD_COUNT": "1",
111 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"})
112 | start_time = datetime.now(tz=tzutc())
113 | with self.assertRaises(ClientError) as e:
114 | session_config = botocore.config.Config(
115 | retries={
116 | 'max_attempts': 0
117 | })
118 | s3_client = boto3.client('s3', region_name=self.REGION_NAME, config=session_config)
119 | response = s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='5mb_pii_text')
120 | end_time = datetime.now(tz=tzutc()) + timedelta(minutes=1)
121 |
122 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, end_time)
123 | assert e.exception.response['Error']['Message'] == "Failed to complete document processing within time limit"
124 | assert e.exception.response['Error']['Code'] == "RequestTimeout"
125 |
126 |
127 | def test_classification_handler_no_pii(self):
128 | non_pii_file_name = 'clean.txt'
129 | start_time = datetime.now(tz=tzutc())
130 | test_clean_obj = self.s3.Object(self.s3ol_access_point_arn, non_pii_file_name)
131 | get_obj_response = test_clean_obj.get()
132 | get_obj_data = get_obj_response['Body'].read().decode('utf-8')
133 |
134 | with open(f"{self.DATA_PATH}/{non_pii_file_name}") as expected_output_file:
135 | expected_output = expected_output_file.read()
136 |
137 | assert expected_output == get_obj_data
138 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time)
139 |
--------------------------------------------------------------------------------
/test/integ/test_redaction.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from datetime import datetime, timedelta
3 |
4 | import boto3
5 | import botocore
6 | from botocore.exceptions import ClientError
7 | from dateutil.tz import tzutc
8 |
9 | from integ.integ_base import BasicIntegTest
10 |
11 |
12 | class PiiRedactionIntegTest(BasicIntegTest):
13 | BUILD_DIR = '.aws-sam/build/PiiRedactionFunction/'
14 |
15 | @classmethod
16 | def setUpClass(cls) -> None:
17 | super().setUpClass()
18 | test_run_id = str(uuid.uuid4())[0:8]
19 | cls.lambda_function_arn = cls._create_function("pii_redaction", 'handler.redact_pii_documents_handler',
20 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR)
21 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn)
22 |
23 | @classmethod
24 | def tearDownClass(cls) -> None:
25 | cls.s3_ctrl.delete_access_point_for_object_lambda(
26 | AccountId=cls.account_id,
27 | Name=cls.s3ol_access_point_arn.split('/')[-1]
28 | )
29 |
30 | cls.lambda_client.delete_function(
31 | FunctionName=cls.lambda_function_arn.split(':')[-1]
32 | )
33 | super().tearDownClass()
34 |
35 | def tearDown(self) -> None:
36 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"})
37 |
38 | def test_request_timeout(self):
39 | self._update_lambda_env_variables(self.lambda_function_arn,
40 | {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "10000000",
41 | "CONTAINS_PII_ENTITIES_THREAD_COUNT": "1",
42 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"})
43 | start_time = datetime.now(tz=tzutc())
44 | # To keep total integration test time low we avoid making multiple retry attempts
45 | with self.assertRaises(ClientError) as e:
46 | session_config = botocore.config.Config(
47 | retries={
48 | 'max_attempts': 0
49 | })
50 | s3_client = boto3.client('s3', region_name=self.REGION_NAME, config=session_config)
51 | response = s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='5mb_pii_text')
52 | end_time = datetime.now(tz=tzutc()) + timedelta(minutes=1)
53 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, end_time)
54 | assert e.exception.response['Error']['Message'] == "Failed to complete document processing within time limit"
55 | assert e.exception.response['Error']['Code'] == "RequestTimeout"
56 |
57 | def test_redaction_lambda_with_default_entity_types(self):
58 | start_time = datetime.now(tz=tzutc())
59 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt')
60 | test_obj_response = test_pii_obj.get()
61 | test_obj_data = test_obj_response['Body'].read().decode('utf-8')
62 |
63 | with open(f"{self.DATA_PATH}/pii_output.txt") as expected_output_file:
64 | expected_output = expected_output_file.read()
65 |
66 | assert expected_output == test_obj_data
67 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time, self.PII_ENTITY_TYPES_IN_TEST_DOC)
68 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, is_pii_detection_performed=True)
69 |
70 | def test_redaction_lambda_with_pii_overridden_entity_types(self):
71 | self._update_lambda_env_variables(self.lambda_function_arn,
72 | {"LOG_LEVEL": "DEBUG",
73 | "PII_ENTITY_TYPES": "CREDIT_DEBIT_NUMBER,BANK_ROUTING,BANK_ACCOUNT_NUMBER,EMAIL,PASSWORD"})
74 | start_time = datetime.now(tz=tzutc())
75 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='pii_input.txt')
76 | with open(f"{self.DATA_PATH}/pii_bank_routing_redacted.txt") as expected_output_file:
77 | assert response['Body'].read().decode('utf-8') == expected_output_file.read()
78 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time,
79 | ['BANK_ROUTING', 'CREDIT_DEBIT_NUMBER', 'BANK_ACCOUNT_NUMBER'])
80 |
81 | def test_redaction_lambda_fails_when_file_size_exceeds_limit(self):
82 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "500"})
83 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt')
84 | with self.assertRaises(ClientError) as e:
85 | test_pii_obj.get()
86 | assert e.exception.response['Error']['Message'] == "Size of the requested object exceeds maximum file size supported"
87 | assert e.exception.response['Error']['Code'] == "EntityTooLarge"
88 |
89 | def test_redaction_lambda_succeeds_with_an_overridden_max_file_size_limit(self):
90 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "1500000",
91 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"})
92 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='1mb_pii_text')
93 | with open(f"{self.DATA_PATH}/1mb_pii_redacted_text") as expected_output_file:
94 | assert response['Body'].read().decode('utf-8') == expected_output_file.read()
95 |
96 | def test_redaction_lambda_with_unsupported_file(self):
97 | test_obj = self.s3.Object(self.s3ol_access_point_arn, 'RandomImage.png')
98 | with self.assertRaises(ClientError) as e:
99 | test_obj.get()
100 | assert e.exception.response['Error']['Message'] == "Unsupported file encountered for determining Pii"
101 | assert e.exception.response['Error']['Code'] == "UnexpectedContent"
102 |
103 | def test_classification_lambda_with_unsupported_file_handling_set_to_pass(self):
104 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "UNSUPPORTED_FILE_HANDLING": 'PASS'})
105 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='RandomImage.png')
106 | assert 'Body' in response
107 |
108 | def test_classification_lambda_with_partial_object(self):
109 | with self.assertRaises(ClientError) as e:
110 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key="pii_input.txt", Range="bytes=0-100")
111 | assert e.exception.response['Error']['Message'] == "HTTP Header Range is not supported"
112 |
113 | def test_classification_lambda_with_partial_object_allowed_with_versioned(self):
114 | file_name = 'pii_input.txt'
115 | self.s3_client.upload_file(f"{self.DATA_PATH}/{file_name}", self.bucket_name, file_name)
116 | versions = set()
117 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "IS_PARTIAL_OBJECT_SUPPORTED": "TRUE"})
118 |
119 | for version in self.s3_client.list_object_versions(Bucket=self.bucket_name)['Versions']:
120 | if version['Key'] == 'pii_input.txt':
121 | versions.add(version['VersionId'])
122 | assert len(versions) >= 2, f"Expected at least 2 different versions of {file_name}"
123 | for versionId in versions:
124 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key=file_name, Range="bytes=0-100",
125 | VersionId=versionId)
126 | assert response['ContentRange'] == "bytes 0-100/611"
127 | assert response['ContentLength'] == 101
128 | assert response['ContentType'] == 'binary/octet-stream'
129 | assert response['VersionId'] == versionId
130 |
131 |
132 | def test_redaction_handler_no_pii(self):
133 | non_pii_file_name = 'clean.txt'
134 | test_clean_obj = self.s3.Object(self.s3ol_access_point_arn, non_pii_file_name)
135 | get_obj_response = test_clean_obj.get()
136 | get_obj_data = get_obj_response['Body'].read().decode('utf-8')
137 |
138 | with open(f"{self.DATA_PATH}/{non_pii_file_name}") as expected_output_file:
139 | expected_output = expected_output_file.read()
140 |
141 | assert expected_output == get_obj_data
142 |
--------------------------------------------------------------------------------
/test/unit/test_comprehend_client.py:
--------------------------------------------------------------------------------
1 | from time import sleep, time
2 | from unittest import TestCase
3 | from unittest.mock import patch, MagicMock, call
4 |
5 | from botocore.awsrequest import AWSRequest
6 |
7 | from clients.comprehend_client import ComprehendClient
8 | from constants import BEGIN_OFFSET, END_OFFSET, ENTITY_TYPE, SCORE
9 | from data_object import Document
10 |
11 |
12 | class ComprehendClientTest(TestCase):
13 | @patch('clients.comprehend_client.boto3')
14 |     def test_comprehend_client_constructor(self, mocked_boto3):
15 | mocked_client = MagicMock()
16 | mocked_boto3.client.return_value = mocked_client
17 | comprehend_client = ComprehendClient(s3ol_access_point="some_access_point_arn")
18 | mocked_client.meta.events.register.assert_called_with('before-sign.comprehend.*', comprehend_client._add_session_header)
19 | request = AWSRequest()
20 | comprehend_client._add_session_header(request)
21 | assert len(request.headers.get('x-amzn-session-id')) >= 10
22 | assert comprehend_client.classification_executor_service._max_workers == 20
23 | assert comprehend_client.redaction_executor_service._max_workers == 8
24 |
25 | @patch('clients.comprehend_client.boto3')
26 | def test_comprehend_detect_pii_entities(self, mocked_boto3):
27 | DUMMY_PII_ENTITY = {BEGIN_OFFSET: 12, END_OFFSET: 14, ENTITY_TYPE: 'SSN', SCORE: 0.345}
28 |
29 | def mocked_api_call(**kwargs):
30 | sleep(0.1)
31 | return {'Entities': [DUMMY_PII_ENTITY], 'ResponseMetadata': {'RetryAttempts': 0}}
32 |
33 | mocked_client = MagicMock()
34 | mocked_boto3.client.return_value = mocked_client
35 |         comprehend_client = ComprehendClient(s3ol_access_point="Some_random_access_point", pii_redaction_thread_count=5)
36 | mocked_client.detect_pii_entities.side_effect = mocked_api_call
37 | start_time = time()
38 | docs_with_pii_entity = comprehend_client.detect_pii_documents(
39 | documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(1, 20)],
40 | language='en')
41 | end_time = time()
42 | mocked_client.detect_pii_entities.assert_has_calls([call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(1, 20)])
43 |
44 | assert len(comprehend_client.detection_metrics.metrics) == 38
45 | for i in range(0, 19, 2):
46 | assert comprehend_client.detection_metrics.metrics[i]['MetricName'] == 'ErrorCount'
47 | assert comprehend_client.detection_metrics.metrics[i]['Value'] == 0
48 | assert comprehend_client.detection_metrics.metrics[i + 1]['MetricName'] == 'Latency'
49 | assert len(comprehend_client.classify_metrics.metrics) == 0
50 |
51 |         # should be around 0.4s: 19 calls across 5 worker threads, each call taking 0.1 seconds to complete
52 | assert 0.4 <= end_time - start_time < 0.5
53 | for doc in docs_with_pii_entity:
54 | assert len(doc.pii_entities) == 1
55 | assert doc.pii_entities[0] == DUMMY_PII_ENTITY
56 |
57 | @patch('clients.comprehend_client.boto3')
58 | def test_comprehend_contains_pii_entities(self, mocked_boto3):
59 | classification_result = {'Labels': [{'Name': 'SSN', 'Score': 0.1234}], 'ResponseMetadata': {'RetryAttempts': 0}}
60 |
61 | def mocked_api_call(**kwargs):
62 | sleep(0.1)
63 | return classification_result
64 |
65 | mocked_client = MagicMock()
66 | mocked_boto3.client.return_value = mocked_client
67 | comprehend_client = ComprehendClient("some_access_point_arn", pii_classification_thread_count=2, pii_redaction_thread_count=5)
68 | mocked_client.contains_pii_entities.side_effect = mocked_api_call
69 | start_time = time()
70 | docs_with_pii_classification = comprehend_client.contains_pii_entities(
71 | documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(1, 4)],
72 | language='en')
73 | end_time = time()
74 |
75 | mocked_client.contains_pii_entities.assert_has_calls(
76 | [call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(1, 4)])
77 |         # should be around 0.2s: 3 calls across 2 worker threads, each call taking 0.1 seconds to complete
78 | assert 0.2 <= end_time - start_time < 0.3
79 | assert len(comprehend_client.classify_metrics.metrics) == 6
80 | for i in range(0, 6, 2):
81 | assert comprehend_client.classify_metrics.metrics[i]['MetricName'] == 'ErrorCount'
82 | assert comprehend_client.classify_metrics.metrics[i + 1]['MetricName'] == 'Latency'
83 |
84 | assert len(comprehend_client.detection_metrics.metrics) == 0
85 | for doc in docs_with_pii_classification:
86 | assert doc.pii_classification == {'SSN': 0.1234}
87 |
88 | @patch('clients.comprehend_client.boto3')
89 | def test_comprehend_contains_pii_entities_failure(self, mocked_boto3):
90 | classification_result = {'Labels': [{'Name': 'SSN', 'Score': 0.1234}], 'ResponseMetadata': {'RetryAttempts': 0}}
91 |
92 | mocked_client = MagicMock()
93 | mocked_boto3.client.return_value = mocked_client
94 | comprehend_client = ComprehendClient(s3ol_access_point="Some_access_point_arn", pii_classification_thread_count=4)
95 | api_invocation_exception = Exception("Some unrecoverable error")
96 | mocked_client.contains_pii_entities.side_effect = [classification_result, classification_result, api_invocation_exception,
97 | classification_result]
98 | try:
99 | comprehend_client.contains_pii_entities(documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(0, 4)],
100 | language='en')
101 |
102 | assert False, "Expected an exception "
103 | except Exception as e:
104 | assert e == api_invocation_exception
105 | mocked_client.contains_pii_entities.assert_has_calls(
106 | [call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(0, 4)])
107 |         assert len(comprehend_client.classify_metrics.metrics) == 8  # 4 Latency metrics and 4 ErrorCount metrics (3 retry counts + 1 fault)
108 | assert len(comprehend_client.detection_metrics.metrics) == 0
109 | assert comprehend_client.classify_metrics.service_name == "Comprehend"
110 | assert comprehend_client.classify_metrics.s3ol_access_point_arn == "Some_access_point_arn"
111 | assert comprehend_client.classify_metrics.api == "ContainsPiiEntities"
112 | metric_count = {"ErrorCount": 0, "Latency": 0}
113 | for i in range(0, 8):
114 | metric_name = comprehend_client.classify_metrics.metrics[i]['MetricName']
115 | metric_count[metric_name] += 1
116 | assert metric_count['ErrorCount'] == 4
117 | assert metric_count['Latency'] == 4
118 |
119 | @patch('clients.comprehend_client.boto3')
120 | def test_comprehend_detect_pii_entities_failure(self, mocked_boto3):
121 | DUMMY_PII_ENTITY = {'Entities': [{BEGIN_OFFSET: 12, END_OFFSET: 14, ENTITY_TYPE: 'SSN', SCORE: 0.345}],
122 | 'ResponseMetadata': {'RetryAttempts': 0}}
123 | mocked_client = MagicMock()
124 | mocked_boto3.client.return_value = mocked_client
125 | comprehend_client = ComprehendClient(s3ol_access_point="Some_access_point_arn")
126 | api_invocation_exception = Exception("Some unrecoverable error")
127 | mocked_client.detect_pii_entities.side_effect = [DUMMY_PII_ENTITY, DUMMY_PII_ENTITY, api_invocation_exception,
128 | DUMMY_PII_ENTITY]
129 | try:
130 | comprehend_client.detect_pii_documents(documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(0, 4)],
131 | language='en')
132 |
133 | assert False, "Expected an exception "
134 | except Exception as e:
135 | assert e == api_invocation_exception
136 | mocked_client.detect_pii_entities.assert_has_calls([call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(0, 4)])
137 |         assert len(comprehend_client.detection_metrics.metrics) == 8  # 4 Latency metrics and 4 ErrorCount metrics (3 retry counts + 1 fault)
138 | assert len(comprehend_client.classify_metrics.metrics) == 0
139 | assert comprehend_client.detection_metrics.service_name == "Comprehend"
140 | assert comprehend_client.detection_metrics.s3ol_access_point_arn == "Some_access_point_arn"
141 | assert comprehend_client.detection_metrics.api == "DetectPiiEntities"
142 | metric_count = {"ErrorCount": 0, "Latency": 0}
143 | for i in range(0, 8):
144 | metric_name = comprehend_client.detection_metrics.metrics[i]['MetricName']
145 | metric_count[metric_name] += 1
146 | assert metric_count['ErrorCount'] == 4
147 | assert metric_count['Latency'] == 4
148 |
--------------------------------------------------------------------------------
/ACCESS_CONTROL_README.md:
--------------------------------------------------------------------------------
1 | # PII Access Control S3 Object Lambda function
2 |
3 | This serverless app helps you control access to text files in S3 that contain PII (Personally Identifiable Information).
4 | It deploys a Lambda function that can be attached to an S3 Object Lambda access point.
5 | The Lambda function internally uses Amazon Comprehend to detect PII entities in the text.
6 |
7 | ## App Architecture
8 | 
9 |
10 | The Lambda function is optimized to leverage Comprehend's ContainsPiiEntities API.
11 |
12 | 1. The Lambda function is invoked with a request containing information about the S3 object to get and transform.
13 | 2. The request contains an S3 presigned URL to fetch the requested object.
14 | 3. The data is split into chunks accepted by Comprehend's ContainsPiiEntities API, and the API is called with each chunk (see the sketch below).
15 | 4. The responses from all chunks are aggregated.
16 | 5. The Lambda function calls back S3 with the response: either the text data, or an error if the file contains PII.
17 | 6. If any failure happens while processing, the Lambda function returns an appropriate error response to S3, which is then returned to the original caller.
18 | 7. The Lambda function exits with code 0 (i.e., without any error) if no error occurred, and fails otherwise.
19 |
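A minimal sketch of steps 3-5, assuming boto3 and using the template defaults as illustrative chunk and overlap sizes; the deployed function's actual segmentation logic lives in src/processors.py:

```
# Sketch only, not the deployed implementation.
import boto3

CHUNK_SIZE = 50000   # mirrors DocumentMaxSizeContainsPiiEntities (bytes)
OVERLAP = 200        # mirrors MaxCharsOverlap
THRESHOLD = 0.5      # mirrors ConfidenceThreshold

comprehend = boto3.client('comprehend')

def document_contains_pii(text: str, language: str = 'en') -> bool:
    """Classify overlapping chunks and aggregate: any confident label means PII."""
    for start in range(0, len(text), CHUNK_SIZE - OVERLAP):
        chunk = text[start:start + CHUNK_SIZE]
        labels = comprehend.contains_pii_entities(Text=chunk, LanguageCode=language)['Labels']
        if any(label['Score'] >= THRESHOLD for label in labels):
            return True
    return False
```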
20 | ## Installation Instructions
21 |
22 | 1. [Create an AWS account](https://portal.aws.amazon.com/gp/aws/developer/registration/index.html) if you do not already have one and login
23 | 1. Go to the app's page on the [Serverless Application Repository](https://console.aws.amazon.com/lambda/home#/create/app?applicationId=arn:aws:serverlessrepo:us-east-1:839782855223:applications/ComprehendPiiAccessControlS3ObjectLambda)
24 | 1. Provide the required app parameters (see parameter details below) and click "Deploy".
25 |
26 | ## Parameters
27 | The following parameters can be tuned to get the desired behavior.
28 | #### Environment variables
29 | The following environment variables can be set on the Lambda function:
30 | 1. `LOG_LEVEL` - Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc. Default: `INFO`.
31 | 1. `UNSUPPORTED_FILE_HANDLING` : Handling logic for unsupported files. Valid values are `PASS` and `FAIL` (Default: `FAIL`). If set to `FAIL`, an UnsupportedFileException is thrown when the requested object is of an unsupported type.
32 | 1. `IS_PARTIAL_OBJECT_SUPPORTED` : Whether to support partial objects or not. Accessing partial objects through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy. Valid values are `TRUE` and `FALSE`. Default: `FALSE`.
33 | 1. `DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES` : Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API for classifying the PII entity types present in the document. Default: 50000.
34 | 1. `PII_ENTITY_TYPES` : List of comma-separated PII entity types to be considered for access control. Refer to [Comprehend's documentation page](https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types) for the list of supported PII entity types. Default: `ALL`, which signifies all entity types that Comprehend supports.
35 | 1. `SUBSEGMENT_OVERLAPPING_TOKENS` : Number of tokens/words to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 20.
36 | 1. `DOCUMENT_MAX_SIZE` : Maximum document size (in bytes) that this function can process; an exception is thrown for documents that are too large. Default: 102400.
37 | 1. `CONFIDENCE_THRESHOLD` : The minimum prediction confidence score above which PII classification and detection results are accepted as the final answer. Valid range (0.5 to 1.0). Default: 0.5.
38 | 1. `MAX_CHARS_OVERLAP` : Maximum characters to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 200.
39 | 1. `DEFAULT_LANGUAGE_CODE` : Default language of the text to be processed. This code will be used for interacting with Comprehend. Default: `en`.
40 | 1. `CONTAINS_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 20.
41 | 1. `PUBLISH_CLOUD_WATCH_METRICS` : Whether or not to publish metrics to CloudWatch. Default: true.
42 |
43 | #### Runtime variables
44 | You can add the following arguments to the S3 Object Lambda access point configuration payload to override the default values used by the Lambda function. These values take precedence over environment variables. Provide these variables as a JSON string, like the following example:
45 | ```
46 | ...
47 | "payload": "{\"pii_entity_types\" : [\"CREDIT_DEBIT_NUMBER\"],\"confidence_threshold\":0.6,\"language_code\":\"en\"}"
48 | ...
49 | ```
50 | Use these parameters to get different behaviors from different access point configurations attached to the same Lambda function. A sketch of setting the payload follows the list.
51 | 1. `pii_entity_types` : List of PII entity types to be considered for access control, e.g. `["SSN","CREDIT_DEBIT_NUMBER"]`.
52 | 1. `confidence_threshold` : The minimum prediction confidence score above which PII classification and detection results are accepted as the final answer.
53 | 1. `language_code` : Language of the text. This will be used to interact with Comprehend.
54 |
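The payload is supplied through the `FunctionPayload` field of the Object Lambda access point configuration. A hedged boto3 sketch (the account ID, names, and ARNs below are placeholders):

```
import json
import boto3

s3ctrl = boto3.client('s3control')
payload = {"pii_entity_types": ["CREDIT_DEBIT_NUMBER"], "confidence_threshold": 0.6, "language_code": "en"}
s3ctrl.create_access_point_for_object_lambda(
    AccountId='111122223333',       # placeholder account ID
    Name='pii-access-control-ap',   # placeholder access point name
    Configuration={
        'SupportingAccessPoint': 'arn:aws:s3:us-east-1:111122223333:accesspoint/my-supporting-ap',
        'TransformationConfigurations': [{
            'Actions': ['GetObject'],
            'ContentTransformation': {
                'AwsLambda': {
                    'FunctionArn': 'arn:aws:lambda:us-east-1:111122223333:function:PiiAccessControlFunction',
                    'FunctionPayload': json.dumps(payload)
                }
            }
        }]
    }
)
```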
55 | ## App Outputs
56 |
57 | #### Successful response
58 | If the text file is not inferred to contain PII for the configured entity types, its content is returned as the GetObject API output.
59 | #### Error responses
60 | The Lambda function forwards the standard [S3 error responses](https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html) it receives while downloading the file from S3.
61 |
62 | In addition, the following error responses may be returned by the Lambda function:
63 |
64 | |Status Code|Error Code|Error Message|Description|
65 | |---|---|---|---|
66 | | BAD_REQUEST_400 | InvalidRequest | Lambda function has been incorrectly setup | An incorrect configuration that prevents the Lambda function from handling the incoming events|
67 | | BAD_REQUEST_400 | UnexpectedContent | Unsupported file encountered for determining PII | Thrown when the caller tries to get a file that is not valid UTF-8 (e.g., an image) and the UNSUPPORTED_FILE_HANDLING variable is set to FAIL|
68 | | BAD_REQUEST_400 | EntityTooLarge | Size of the requested object exceeds maximum file size supported | Thrown when the caller requests an object larger than the maximum supported file size|
69 | | BAD_REQUEST_400 | RequestTimeout | Failed to complete document processing within time limit | Thrown when the Lambda function cannot complete processing of the document within the time limit. This could be because the file is too big or because calls are throttled by either S3 or Comprehend.|
70 | | INTERNAL_SERVER_ERROR_500 | InternalError | An internal error occurred while processing the file | Any other error that occurred while processing the object |
71 | | FORBIDDEN_403 | AccessDenied | Document Contains PII | The requested document has been inferred to contain PII|
72 |
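For instance, a caller can distinguish the PII denial from other errors by inspecting the error code. A sketch, where the `Bucket` value stands in for the Object Lambda access point ARN:

```
import boto3
from botocore.exceptions import ClientError

s3 = boto3.client('s3')
try:
    body = s3.get_object(Bucket='arn:aws:s3-object-lambda:us-east-1:111122223333:accesspoint/example',
                         Key='pii_input.txt')['Body'].read()
except ClientError as error:
    if error.response['Error']['Code'] == 'AccessDenied':
        print('Blocked: the document was inferred to contain PII')
    else:
        raise
```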
73 | ## Metrics
74 | Metrics are published after each invocation of the Lambda function on a best-effort basis (failures in CloudWatch metric publishing are ignored).
75 |
76 | All metrics will be under the Namespace: ComprehendS3ObjectLambda
77 |
78 | ### Metrics for processed document
79 | |MetricName|Description|Unit|Dimensions|
80 | |---|---|---|---|
81 |PiiDocumentsProcessed|Emitted after processing a document that contains PII|Count|S3ObjectLambdaAccessPoint, Language|
82 | |DocumentsProcessed|Emitted after processing any document|Count|S3ObjectLambdaAccessPoint, Language|
83 | |PiiDocumentTypesProcessed|Emitted after processing a document that contains PII for each type of PII of interest|Count|S3ObjectLambdaAccessPoint, Language, PiiEntityType|
84 |
85 | ### Metrics for Comprehend operations
86 | |MetricName|Description|Unit|Dimensions|
87 | |---|---|---|---|
88 | |Latency|The latency of Comprehend DetectPiiEntities API|Milliseconds|Comprehend, DetectPiiEntities|
89 | |Latency|The latency of Comprehend ContainsPiiEntities API|Milliseconds|Comprehend, ContainsPiiEntities|
90 | |ErrorCount|The error count of Comprehend DetectPiiEntities API|Count|Comprehend, DetectPiiEntities|
91 | |ErrorCount|The error count of Comprehend ContainsPiiEntities API|Count|Comprehend, ContainsPiiEntities|
92 |
93 | ### Metrics for S3 operations
94 | |MetricName|Description|Unit|Dimensions|
95 | |---|---|---|---|
96 | |Latency|The latency of S3 WriteGetObjectResponse API|Milliseconds|S3, WriteGetObjectResponse|
97 | |Latency|The latency of downloading a file from a presigned S3 url|Milliseconds|S3, DownloadPresignedUrl|
98 | |ErrorCount|The error count of S3 WriteGetObjectResponse API|Count|S3, WriteGetObjectResponse|
99 | |ErrorCount|The error count of downloading a file from a presigned S3 url|Count|S3, DownloadPresignedUrl|
100 |
101 | ## License Summary
102 |
103 | This code is made available under the MIT-0 license. See the LICENSE file.
--------------------------------------------------------------------------------
/test/unit/test_s3_client.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | from unittest.mock import patch, MagicMock
3 |
4 | from clients.s3_client import S3Client
5 | from constants import BEGIN_OFFSET, END_OFFSET, ENTITY_TYPE, SCORE, S3_STATUS_CODES, S3_ERROR_CODES, RANGE
6 | from exceptions import S3DownloadException, FileSizeLimitExceededException, UnsupportedFileException
7 |
8 | PRESIGNED_URL_TEST = "https://s3ol-classifier.s3.amazonaws.com/test.txt"
9 |
10 |
11 | class MockResponse:
12 | def __init__(self, content, status_code, headers):
13 | self.status_code = status_code
14 | self.content = content
15 | self.headers = headers
16 |
17 |
18 | def get_s3_xml_response(code: str, message: str = '') -> str:
19 |     return f'<?xml version="1.0" encoding="UTF-8"?>\n<Error><Code>{code}</Code><Message>{message}</Message></Error>'
20 |
21 |
22 | class S3ClientTest(TestCase):
23 | @patch('clients.s3_client.boto3')
24 | def test_s3_client_respond_back_with_error(self, mocked_boto3):
25 | mocked_client = MagicMock()
26 | mocked_boto3.client.return_value = mocked_client
27 | s3_client = S3Client(s3ol_access_point="Random_access_point")
28 | s3_client.respond_back_with_error(status_code=S3_STATUS_CODES.PRECONDITION_FAILED_412,
29 | error_code=S3_ERROR_CODES.PreconditionFailed, error_message="Some Error",
30 | request_route="Route", request_token="q2334")
31 |
32 | mocked_client.write_get_object_response.assert_called_once_with(StatusCode=412,
33 | ErrorCode='PreconditionFailed',
34 | ErrorMessage="Some Error",
35 | RequestRoute='Route', RequestToken="q2334")
36 |
37 | @patch('clients.s3_client.boto3')
38 | def test_s3_client_respond_back_with_data_default_status_code(self, mocked_boto3):
39 | mocked_client = MagicMock()
40 | mocked_boto3.client.return_value = mocked_client
41 | s3_client = S3Client(s3ol_access_point="Random_access_point")
42 | s3_client.respond_back_with_data(data='SomeData',
43 | headers={"ContentRange": "0-100", "SomeRandomHeader": '0123', "Content-Length": "101"},
44 | request_route="Route", request_token="q2334")
45 |
46 | mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentLength=101,
47 | RequestRoute='Route', RequestToken="q2334",
48 | StatusCode=200)
49 |
50 | @patch('clients.s3_client.boto3')
51 | def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3):
52 | mocked_client = MagicMock()
53 | mocked_boto3.client.return_value = mocked_client
54 | s3_client = S3Client(s3ol_access_point="Random_access_point")
55 | s3_client.respond_back_with_data(data='SomeData', headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'},
56 | request_route="Route", request_token="q2334", status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206)
57 |
58 | mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentRange="0-1200",
59 | RequestRoute='Route', RequestToken="q2334",
60 | StatusCode=206)
61 |
62 | @patch('clients.s3_client.requests.Session.get',
63 | side_effect=lambda *args, **kwargs: MockResponse(b'Test', 200, {'Content-Length': '4'}))
64 | def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get):
65 | s3_client = S3Client(s3ol_access_point="Random_access_point")
66 | http_header = {'some-header': 'header-value'}
67 | text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header)
68 | assert text == 'Test'
69 | assert response_http_headers == {'Content-Length': '4'}
70 | assert status_code == S3_STATUS_CODES.OK_200
71 | mocked_get.assert_called_with(PRESIGNED_URL_TEST, timeout=10, headers=http_header)
72 |
73 | @patch('clients.s3_client.requests.Session.get',
74 | side_effect=lambda *args, **kwargs: MockResponse(b'Test', 206, {'Content-Length': '100'}))
75 | def test_s3_client_download_partial_file_from_presigned_url(self, mocked_get):
76 | s3_client = S3Client(s3ol_access_point="Random_access_point")
77 | http_header = {'some-header': 'header-value'}
78 | text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header)
79 | assert text == 'Test'
80 | assert response_http_headers == {'Content-Length': '100'}
81 | assert status_code == S3_STATUS_CODES.PARTIAL_CONTENT_206
82 | mocked_get.assert_called_with(PRESIGNED_URL_TEST, timeout=10, headers=http_header)
83 |
84 | @patch('clients.s3_client.requests.Session.get',
85 | side_effect=lambda *args, **kwargs: MockResponse(b'Test', 400, {'Content-Length': '4'}))
86 | def test_s3_client_download_file_from_presigned_url_400_from_get(self, mocked_get):
87 | s3_client = S3Client(s3ol_access_point="Random_access_point")
88 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
89 |
90 | assert mocked_get.call_count == 5
91 |
92 | @patch('clients.s3_client.requests.Session.get',
93 | side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200, {'Content-Length': str(11 * 1024 * 1024)}))
94 | def test_s3_client_download_file_from_presigned_url_file_size_limit_exceeded(self, mocked_get):
95 | s3_client = S3Client(s3ol_access_point="Random_access_point")
96 | self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
97 |
98 | mocked_get.assert_called_once()
99 |
100 | @patch('clients.s3_client.requests.Session.get',
101 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('AccessDenied').encode('utf-8'), 200,
102 | {'Content-Length': '4'}))
103 | def test_s3_client_download_file_from_presigned_url_access_denied(self, mocked_get):
104 | s3_client = S3Client(s3ol_access_point="Random_access_point")
105 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
106 |
107 | mocked_get.assert_called_once()
108 |
109 | @patch('clients.s3_client.requests.Session.get',
110 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('UnknownError').encode('utf-8'), 200,
111 | {'Content-Length': '4'}))
112 | def test_s3_client_download_file_from_presigned_url_unknown_error(self, mocked_get):
113 | s3_client = S3Client(s3ol_access_point="Random_access_point")
114 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
115 |
116 | assert mocked_get.call_count == 5
117 |
118 | @patch('clients.s3_client.requests.Session.get',
119 | side_effect=lambda *args, **kwargs: MockResponse(bytearray.fromhex('ff'), 200, {'Content-Length': '4'}))
120 | def test_s3_client_download_file_from_presigned_unicode_decode_error(self, mocked_get):
121 | s3_client = S3Client(s3ol_access_point="Random_access_point")
122 | self.assertRaises(UnsupportedFileException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
123 |
124 | mocked_get.assert_called_once()
125 |
126 | @patch('clients.s3_client.requests.Session.get',
127 | side_effect=lambda *args, **kwargs: MockResponse(bytearray.fromhex('ff'), 200, {'Content-Length': '4'}))
128 | def test_s3_client_download_file_from_presigned_unicode_decode_error_error(self, mocked_get):
129 | s3_client = S3Client(s3ol_access_point="Random_access_point")
130 | self.assertRaises(UnsupportedFileException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
131 |
132 | mocked_get.assert_called_once()
133 |
134 | @patch('clients.s3_client.requests.Session.get',
135 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('InternalError').encode('utf-8'), 200,
136 | {'Content-Length': '4'}))
137 | def test_s3_client_download_file_from_presigned_retry(self, mocked_get):
138 | s3_client = S3Client(s3ol_access_point="Random_access_point")
139 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
140 |
141 | assert mocked_get.call_count == 5
142 |
--------------------------------------------------------------------------------
/test/unit/test_processors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import timeit
3 | from random import shuffle
4 | from unittest import TestCase
5 |
6 | from constants import REPLACE_WITH_PII_ENTITY_TYPE
7 | from data_object import Document, RedactionConfig
8 | from exceptions import InvalidConfigurationException
9 | from processors import Redactor, Segmenter
10 |
11 | this_module_path = os.path.dirname(__file__)
12 |
13 |
14 | class ProcessorsTest(TestCase):
15 | def test_segmenter_basic_text(self):
16 | segmentor = Segmenter(50, overlap_tokens=3)
17 | original_text = "Barack Hussein Obama II is an American politician and attorney who served as the " \
18 | "44th president of the United States from 2009 to 2017."
19 | segments = segmentor.segment(original_text)
20 | expected_segments = [
21 | "Barack Hussein Obama II is an American politician ",
22 | "an American politician and attorney who served as ",
23 | "who served as the 44th president of the United ",
24 | "of the United States from 2009 to 2017."]
25 | for expected_segment, actual_segment in zip(expected_segments, segments):
26 | assert expected_segment == actual_segment.text
27 | shuffle(segments)
28 | assert segmentor.de_segment(segments).text == original_text
29 |
30 |
31 | def test_segmenter_no_segmentation_needed(self):
32 | segmentor = Segmenter(5000, overlap_tokens=3)
33 | original_text = "Barack Hussein Obama II is an American politician and attorney who served as the " \
34 | "44th president of the United States from 2009 to 2017."
35 | segments = segmentor.segment(original_text)
36 | assert len(segments) == 1
37 | assert segments[0].text == original_text
38 | assert segmentor.de_segment(segments).text == original_text
39 |
40 | def test_segmenter_max_chars_limit(self):
41 | segmentor = Segmenter(50, overlap_tokens=3, max_overlapping_chars=20)
42 | original_text = "BarackHusseinObamaIIisanAmerican politicianandattorneywhoservedasthe " \
43 | "44th president of the United Statesfrom2009to2017."
44 | segments = segmentor.segment(original_text)
45 | expected_segments = [
46 | "BarackHusseinObamaIIisanAmerican ",
47 | "nObamaIIisanAmerican politician",
48 | "nAmerican politicianandattorneywhoservedasthe ",
49 | "torneywhoservedasthe 44th president of the United ",
50 | "of the United Statesfrom2009to2017.", ]
51 | for expected_segment, actual_segment in zip(expected_segments, segments):
52 | assert expected_segment == actual_segment.text
53 | shuffle(segments)
54 | assert segmentor.de_segment(segments).text == original_text
55 |
56 | def test_segmenter_unicode_chars(self):
57 | segmentor = Segmenter(100, overlap_tokens=3)
58 | original_text = "ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ᗷᙓ ò¥¥¥¥¥¥¥ᗢᖇᓮᘐᓰﬡᗩᒪ ℬ℮ ¢◎øł Bᴇ ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी विषयों पर प्रामाणिक और उपयोग, " \
59 | "परिवर्तन व पुनर्वितरण के लिए स्वतन्त्र ज्ञानकोश बनाने hànbǎobāo, hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 – hamburger"
60 | segments = segmentor.segment(original_text)
61 | expected_segments = [
62 | "ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ᗷᙓ ",
63 | "Emoticons 😜 ᗷᙓ ò¥¥¥¥¥¥¥ᗢᖇᓮᘐᓰﬡᗩᒪ ℬ℮ ¢◎øł Bᴇ ",
64 | "ℬ℮ ¢◎øł Bᴇ ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी ",
65 | "ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी विषयों पर ",
66 | "सभी विषयों पर प्रामाणिक और उपयोग, ",
67 | "प्रामाणिक और उपयोग, परिवर्तन व ",
68 | "उपयोग, परिवर्तन व पुनर्वितरण के लिए ",
69 | "पुनर्वितरण के लिए स्वतन्त्र ",
70 | "के लिए स्वतन्त्र ज्ञानकोश बनाने hànbǎobāo, ",
71 | "ज्ञानकोश बनाने hànbǎobāo, hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 ",
72 | "hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 – hamburger"]
73 | assert len(expected_segments) == len(segments)
74 | for expected_segment, actual_segment in zip(expected_segments, segments):
75 | assert expected_segment == actual_segment.text
76 | assert segmentor.de_segment(segments).text == original_text
77 |
78 | def test_desegment_overlapping_results(self):
79 | segments = [
80 | Document(text="Some Random SSN Some Random email-id Some Random name and address and some credit card number", char_offset=0,
81 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'NAME': 0.124, 'ADDRESS': 0.976},
82 | pii_entities=[{'Score': 0.234, 'Type': 'SSN', 'BeginOffset': 12, 'EndOffset': 36},
83 | {'Score': 0.765, 'Type': 'EMAIL', 'BeginOffset': 28, 'EndOffset': 36},
84 | {'Score': 0.534, 'Type': 'NAME', 'BeginOffset': 49, 'EndOffset': 53},
85 | {'Score': 0.234, 'Type': 'ADDRESS', 'BeginOffset': 58, 'EndOffset': 65}]),
86 | Document(text="Some Random name and address and some credit card number", char_offset=37,
87 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'USERNAME': 0.424, 'ADDRESS': 0.976},
88 | pii_entities=[{'Score': 0.234, 'Type': 'USERNAME', 'BeginOffset': 12, 'EndOffset': 16},
89 | {'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 17, 'EndOffset': 28},
90 | {'Score': 0.234, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 38, 'EndOffset': 56}])]
91 | segmentor = Segmenter(5000)
92 | expected_merged_document = Document(
93 | text="Some Random SSN Some Random email-id Some Random name and address and some credit card number", char_offset=37,
94 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'NAME': 0.124, 'USERNAME': 0.424, 'ADDRESS': 0.976},
95 | pii_entities=[{'Score': 0.234, 'Type': 'SSN', 'BeginOffset': 12, 'EndOffset': 36},
96 | {'Score': 0.765, 'Type': 'EMAIL', 'BeginOffset': 28, 'EndOffset': 36},
97 | {'Score': 0.534, 'Type': 'NAME', 'BeginOffset': 49, 'EndOffset': 53},
98 | {'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 54, 'EndOffset': 65},
99 | {'Score': 0.234, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 75, 'EndOffset': 93}])
100 | actual_merged_doc = segmentor.de_segment(segments)
101 | assert expected_merged_document.text == actual_merged_doc.text
102 | assert expected_merged_document.pii_classification == actual_merged_doc.pii_classification
103 | assert expected_merged_document.pii_entities == actual_merged_doc.pii_entities
104 |
105 |
106 | def test_is_overlapping_annotations(self):
107 | segmentor = Segmenter(5000)
108 | assert segmentor._is_overlapping_annotations({'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 54, 'EndOffset': 65},
109 | {'Score': 0.234, 'Type': 'ADDRESS', 'BeginOffset': 58, 'EndOffset': 65}) == 0
110 |
111 |     def test_segmenter_scalability_test(self):
112 |         # 1 MB of text should segment quickly; the 100 runs below must finish within 15 seconds
113 | setup = """
114 | import os
115 | from processors import Segmenter
116 |
117 | text=" Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
118 | one_mb_text=" "
119 | for i in range(7751):
120 | one_mb_text += text
121 | segmenter = Segmenter(overlap_tokens=20, max_doc_size=5000)
122 | """.format(this_module_path)
123 | segmentation_time = timeit.timeit("segmenter.segment(one_mb_text)", setup=setup, number=100)
124 | assert segmentation_time < 15
125 |
126 | def test_redaction_with_no_entities(self):
127 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
128 | redactor = Redactor(RedactionConfig())
129 | redacted_text = redactor.redact(text, [])
130 | assert text == redacted_text
131 |
132 | def test_redaction_default_redaction_config(self):
133 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
134 | redactor = Redactor(RedactionConfig())
135 | redacted_text = redactor.redact(text, [{'Score': 0.234, 'Type': 'NAME', 'BeginOffset': 6, 'EndOffset': 16},
136 | {'Score': 0.765, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 77, 'EndOffset': 96}])
137 | expected_redaction = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53"
138 | assert expected_redaction == redacted_text
139 |
140 | def test_redaction_with_replace_entity_type(self):
141 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
142 | redactor = Redactor(RedactionConfig(pii_entity_types=['NAME'], mask_mode=REPLACE_WITH_PII_ENTITY_TYPE, confidence_threshold=0.6))
143 | redacted_text = redactor.redact(text, [{'Score': 0.634, 'Type': 'NAME', 'BeginOffset': 6, 'EndOffset': 15},
144 | {'Score': 0.765, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 77, 'EndOffset': 96}])
145 | expected_redaction = "Hello [NAME]. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
146 | assert expected_redaction == redacted_text
147 |
148 | def test_segmenter_constructor_invalid_args(self):
149 | try:
150 | Segmenter(3)
151 | assert False, "Expected an InvalidConfigurationException"
152 | except InvalidConfigurationException:
153 | return
154 |
--------------------------------------------------------------------------------
/REDACTION_README.md:
--------------------------------------------------------------------------------
1 | # PII Redaction S3 Object Lambda function
2 |
3 | This serverless app helps you redact PII (Personally Identifiable Information) from valid text files in S3.
4 | The app deploys a Lambda function that can be attached to an S3 Object Lambda access point.
5 | The Lambda function internally uses Amazon Comprehend to detect PII entities in the text.
6 |
7 | ## App Architecture
8 | 
9 |
10 | The Lambda function is optimized to use Comprehend's ContainsPiiEntities and DetectPiiEntities APIs efficiently, saving time and cost on large documents.
11 | ContainsPiiEntities is a high-throughput API that acts as a filter, so only documents that contain PII are sent to the compute-heavy DetectPiiEntities API for redaction.
12 |
13 | 1. The Lambda function is invoked with a request containing information about the S3 object to get and transform.
14 | 2. The request contains an S3 presigned url for fetching the requested object.
15 | 3. The data is split into chunks of the maximum size accepted by Comprehend’s ContainsPiiEntities API, and the API is called with each chunk.
16 | 4. Each chunk that contains PII is split further into smaller chunks of the maximum size supported by Comprehend’s DetectPiiEntities API.
17 | 5. For each of these smaller PII chunks, Comprehend’s DetectPiiEntities API is invoked to detect the text spans containing the PII entities of interest.
18 | 6. The responses from all chunks are aggregated.
19 | 7. The Lambda function calls back S3 with the response, i.e., the redacted document. A simplified sketch of steps 3-7 follows this list.
20 | 8. If any failure happens while processing, the Lambda function returns an appropriate error response to S3, which is returned to the original caller.
21 | 9. The Lambda function exits with code 0 (i.e., without error) if no error occurred; otherwise it fails.
22 |
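The following is a minimal, hedged sketch of steps 3-7 above, using this repo's Segmenter and Redactor together with boto3 (assuming `src/` is on the PYTHONPATH). Unlike the deployed handler, it ignores confidence thresholds, partial-object handling, and the overlap resolution that `Segmenter.de_segment` performs.

```
import boto3

from data_object import RedactionConfig
from processors import Redactor, Segmenter

comprehend = boto3.client("comprehend")
classify_segmenter = Segmenter(50000)  # ContainsPiiEntities-sized chunks (step 3)
detect_segmenter = Segmenter(5120)     # DetectPiiEntities-sized chunks (step 4)
redactor = Redactor(RedactionConfig())


def redact_text(text: str, language_code: str = "en") -> str:
    entities = []
    for chunk in classify_segmenter.segment(text):
        labels = comprehend.contains_pii_entities(Text=chunk.text, LanguageCode=language_code)["Labels"]
        if not labels:  # no PII in this chunk, so skip the costlier detection call
            continue
        for sub in detect_segmenter.segment(chunk.text, chunk.char_offset):
            found = comprehend.detect_pii_entities(Text=sub.text, LanguageCode=language_code)["Entities"]
            for entity in found:  # shift offsets back onto the full document (steps 5-6)
                entity["BeginOffset"] += sub.char_offset
                entity["EndOffset"] += sub.char_offset
            entities.extend(found)
    return redactor.redact(text, sorted(entities, key=lambda e: e["BeginOffset"]))
```
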
23 | ## Installation Instructions
24 |
25 | 1. [Create an AWS account](https://portal.aws.amazon.com/gp/aws/developer/registration/index.html) if you do not already have one and login
26 | 1. Go to the app's page on the [Serverless Application Repository](https://console.aws.amazon.com/lambda/home#/create/app?applicationId=arn:aws:serverlessrepo:us-east-1:839782855223:applications/ComprehendPiiRedactionS3ObjectLambda)
27 | 1. Provide the required app parameters (see parameter details below) and click "Deploy"
28 |
29 | ## Parameters
30 | The following parameters can be tuned to get the desired behavior.
31 | #### Environment variables
32 | The following environment variables can be set on the Lambda function to get the desired behavior; a boto3 sketch for updating them on a deployed function follows this list.
33 | 1. `LOG_LEVEL` - Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG. Default: `INFO`.
34 | 1. `UNSUPPORTED_FILE_HANDLING` Handling logic for unsupported files. Valid values are `PASS` and `FAIL` (Default: `FAIL`). If set to `FAIL`, an UnsupportedFileException is thrown when the requested object is of an unsupported type.
35 | 1. `IS_PARTIAL_OBJECT_SUPPORTED` Whether to support partial objects. Accessing a partial object through http headers such as byte-range can corrupt the object and/or affect PII detection accuracy. Valid values are `TRUE` and `FALSE`. Default: `FALSE`.
36 | 1. `DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES` Maximum document size (in bytes) used for calls to Comprehend's ContainsPiiEntities API for classifying the PII entity types present in the doc. Default: 50000.
37 | 1. `DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES`: Maximum document size (in bytes) to be used for making calls to Comprehend's DetectPiiEntities API. Default: 5120 i.e. 5KB.
38 | 1. `PII_ENTITY_TYPES` : Comma-separated list of PII entity types to consider for redaction. Refer to [Comprehend's documentation page](https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types) for the list of supported PII entity types. Default: `ALL`, which signifies all entity types that Comprehend supports.
39 | 1. `MASK_CHARACTER` : A character that replaces each character in the redacted PII entity. Default: `*`.
40 | 1. `MASK_MODE` : Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values: `MASK` and `REPLACE_WITH_PII_ENTITY_TYPE`. Default: `MASK`.
41 | 1. `SUBSEGMENT_OVERLAPPING_TOKENS` : Number of tokens/words to overlap among segments of a document in case chunking is needed because of maximum document size limit. Default: 20.
42 | 1. `DOCUMENT_MAX_SIZE` : Maximum document size (in bytes) that this function can process; a larger document results in an exception.
43 | 1. `CONFIDENCE_THRESHOLD` : The minimum prediction confidence score above which PII classification and detection results are considered final. Valid range: 0.5 to 1.0. Default: 0.5.
44 | 1. `MAX_CHARS_OVERLAP` : Maximum characters to overlap among segments of a document in case chunking is needed because of maximum document size limit. Default: 2.
45 | 1. `DEFAULT_LANGUAGE_CODE` : Default language of the text to be processed. This code will be used for interacting with Comprehend. Default: en.
46 | 1. `DETECT_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's DetectPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 8.
47 | 1. `CONTAINS_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 20.
48 | 1. `PUBLISH_CLOUD_WATCH_METRICS` : Determines whether or not to publish metrics to CloudWatch. Default: true.
49 |
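A hedged boto3 sketch of overriding two of these variables on a deployed function; the function name is a placeholder. Note that `update_function_configuration` replaces the function's entire environment, so include every variable you want to keep.

```
import boto3

lambda_client = boto3.client("lambda")
lambda_client.update_function_configuration(
    FunctionName="ComprehendPiiRedactionFunction",  # placeholder name for the deployed function
    Environment={
        # This replaces the existing environment; merge in any variables you still need
        "Variables": {
            "MASK_MODE": "REPLACE_WITH_PII_ENTITY_TYPE",
            "PII_ENTITY_TYPES": "SSN,CREDIT_DEBIT_NUMBER",
        }
    },
)
```
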
50 | #### Runtime variables
51 | You can add the following arguments to the S3 Object Lambda access point configuration payload to override the default values used by the Lambda function. These values take precedence over environment variables. Provide them as a JSON string, as in the following example.
52 | ```
53 | ...
54 | "payload": "{\"pii_entity_types\" : [\"CREDIT_DEBIT_NUMBER\"],\"mask_mode\":\"MASK\", \"mask_character\" : \"*\",\"confidence_threshold\":0.6,\"language_code\":\"en\"}"
55 | ...
56 | ```
57 | Use these parameters to get different behaviors from different access point configurations attached to the same Lambda function; a boto3 sketch showing where this payload goes appears after this list.
58 | 1. `pii_entity_types` : List of PII entity types to be considered for redaction. e.g. `["SSN","CREDIT_DEBIT_NUMBER"]`
59 | 1. `mask_mode` : Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values: `MASK` and `REPLACE_WITH_PII_ENTITY_TYPE`.
60 | 1. `mask_character` : A character that replaces each character in the redacted PII entity.
61 | 1. `confidence_threshold` : The minimum prediction confidence score above which PII classification and detection results are considered final.
62 | 1. `language_code`: Language of the text. This will be used to interact with Comprehend.
63 |
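The sketch below shows where this payload sits when creating an S3 Object Lambda access point with boto3. The account ID, names, and ARNs are placeholders for your own resources.

```
import json

import boto3

s3control = boto3.client("s3control")
s3control.create_access_point_for_object_lambda(
    AccountId="111122223333",  # placeholder account ID
    Name="pii-redaction-olap",  # placeholder access point name
    Configuration={
        "SupportingAccessPoint": "arn:aws:s3:us-east-1:111122223333:accesspoint/my-access-point",
        "TransformationConfigurations": [
            {
                "Actions": ["GetObject"],
                "ContentTransformation": {
                    "AwsLambda": {
                        "FunctionArn": "arn:aws:lambda:us-east-1:111122223333:function:PiiRedactionFunction",
                        # The runtime variables documented above, serialized as a JSON string
                        "FunctionPayload": json.dumps({
                            "pii_entity_types": ["CREDIT_DEBIT_NUMBER"],
                            "mask_mode": "MASK",
                            "mask_character": "*",
                            "confidence_threshold": 0.6,
                            "language_code": "en",
                        }),
                    }
                },
            }
        ],
    },
)
```
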
64 | ## App Outputs
65 |
66 | #### Successful response
67 | If the text file contains PII, it is redacted and the redacted text is returned as the GetObject output, as in the sketch below.
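
A minimal sketch of retrieving a redacted object with boto3; the access point ARN and object key are placeholders.

```
import boto3

s3 = boto3.client("s3")
response = s3.get_object(
    # Placeholder: an S3 Object Lambda access point ARN is passed as the Bucket parameter
    Bucket="arn:aws:s3-object-lambda:us-east-1:111122223333:accesspoint/pii-redaction-olap",
    Key="documents/statement.txt",  # placeholder object key
)
print(response["Body"].read().decode("utf-8"))  # the redacted text
```
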
68 | #### Error responses
69 | The Lambda function forwards the standard [S3 error responses](https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html) it receives while downloading the file from S3.
70 |
71 | In addition, the following error responses may be returned by the Lambda function:
72 |
73 | |Status Code|Error Code|Error Message|Description|
74 | |---|---|---|---|
75 | | BAD_REQUEST_400 | InvalidRequest | Lambda function has been incorrectly setup | The Lambda function is configured incorrectly and cannot begin handling incoming events|
76 | | BAD_REQUEST_400 | UnexpectedContent | Unsupported file encountered for determining PII | This error is thrown when the caller tries to get an invalid UTF-8 file (e.g., an image) and the UNSUPPORTED_FILE_HANDLING variable is set to FAIL|
77 | | BAD_REQUEST_400 | EntityTooLarge | Size of the requested object exceeds maximum file size supported | This error is thrown when the caller requests an object larger than the maximum supported file size|
78 | | BAD_REQUEST_400 | RequestTimeout | Failed to complete document processing within time limit | This error is thrown when the Lambda function cannot finish processing the document within the time limit, either because the file is too large or because S3 or Comprehend is throttling the requests.|
79 | | INTERNAL_SERVER_ERROR_500 | InternalError | An internal error occurred while processing the file | Any other error that occurred while processing the object |
80 |
81 | ## Metrics
82 | Metrics are published after each invocation of the Lambda function on a best-effort basis (failures in CloudWatch metric publishing are ignored).
83 |
84 | All metrics are published under the namespace `ComprehendS3ObjectLambda`.
85 |
86 | ### Metrics for processed documents
87 | |MetricName|Description|Unit|Dimensions|
88 | |---|---|---|---|
89 | |PiiDocumentsProcessed|Emitted after processing a document that contains PII|Count|S3ObjectLambdaAccessPoint, Language|
90 | |DocumentsProcessed|Emitted after processing any document|Count|S3ObjectLambdaAccessPoint, Language|
91 | |PiiDocumentTypesProcessed|Emitted after processing a document that contains PII for each type of PII of interest|Count|S3ObjectLambdaAccessPoint, Language, PiiEntityType|
92 |
93 | ### Metrics for Comprehend operations
94 | |MetricName|Description|Unit|Dimensions|
95 | |---|---|---|---|
96 | |Latency|The latency of Comprehend DetectPiiEntities API|Milliseconds|Comprehend, DetectPiiEntities|
97 | |Latency|The latency of Comprehend ContainsPiiEntities API|Milliseconds|Comprehend, ContainsPiiEntities|
98 | |ErrorCount|The error count of Comprehend DetectPiiEntities API|Count|Comprehend, DetectPiiEntities|
99 | |ErrorCount|The error count of Comprehend ContainsPiiEntities API|Count|Comprehend, ContainsPiiEntities|
100 |
101 | ### Metrics for S3 operations
102 | |MetricName|Description|Unit|Dimensions|
103 | |---|---|---|---|
104 | |Latency|The latency of S3 WriteGetObjectResponse API|Milliseconds|S3, WriteGetObjectResponse|
105 | |Latency|The latency of downloading a file from a presigned S3 url|Milliseconds|S3, DownloadPresignedUrl|
106 | |ErrorCount|The error count of S3 WriteGetObjectResponse API|Count|S3, WriteGetObjectResponse|
107 | |ErrorCount|The error count of downloading a file from a presigned S3 url|Count|S3, DownloadPresignedUrl|
108 |
109 | ## License Summary
110 |
111 | This code is made available under the MIT-0 license. See the LICENSE file.
--------------------------------------------------------------------------------
/src/clients/s3_client.py:
--------------------------------------------------------------------------------
1 | """Client wrapper over aws services."""
2 | import re
3 | import time
4 | import urllib
5 | from typing import Dict, Tuple
6 |
7 | import boto3
8 | import botocore
9 | import requests
10 | from requests.adapters import HTTPAdapter
11 | from urllib3.util.retry import Retry
12 |
13 | import lambdalogging
14 | from clients.cloudwatch_client import Metrics
15 | from config import DOCUMENT_MAX_SIZE
16 | from constants import CONTENT_LENGTH, S3_STATUS_CODES, S3_ERROR_CODES, error_code_to_enums, WRITE_GET_OBJECT_RESPONSE, \
17 | DOWNLOAD_PRESIGNED_URL, S3_MAX_RETRIES, S3, http_status_code_to_s3_status_code
18 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, S3DownloadException
19 |
20 | LOG = lambdalogging.getLogger(__name__)
21 |
22 |
23 | class S3Client:
24 | """Wrapper over s3 client."""
25 |
26 |     ERROR_RE = r'(?<=<Error>).*(?=<\/Error>)'
27 |     CODE_RE = r'(?<=<Code>).*(?=<\/Code>)'
28 |     MESSAGE_RE = r'(?<=<Message>).*(?=<\/Message>)'
29 |     XML_HEADER = '<?xml version="1.0" encoding="UTF-8"?>'
30 | S3_DOWNLOAD_MAX_RETRIES = 5
31 | S3_RETRY_STATUS_CODES = [429, 500, 502, 503, 504]
32 | BACKOFF_FACTOR = 1.5
33 | MAX_GET_TIMEOUT = 10
34 | # Translation map from response headers of s3's getObject response to S3OL's WriteGetObjectResponse's request headers
35 | S3GET_TO_WGOR_HEADER_TRANSLATION_MAP = {
36 | "accept-ranges": ("AcceptRanges", str),
37 | "Cache-Control": ("CacheControl", str),
38 | "Content-Disposition": ("ContentDisposition", str),
39 | "Content-Encoding": ("ContentEncoding", str),
40 | "Content-Language": ("ContentLanguage", str),
41 | "Content-Length": ("ContentLength", int),
42 | "Content-Range": ("ContentRange", str),
43 | "Content-Type": ("ContentType", str),
44 | "x-amz-delete-marker": ("DeleteMarker", bool),
45 | "ETag": ("ETag", str),
46 | "Expires": ("Expires", str),
47 | "x-amz-expiration": ("Expiration", str),
48 | "Last-Modified": ("LastModified", str),
49 | "x-amz-missing-meta": ("MissingMeta", str),
50 | "x-amz-meta-": ("Metadata", str),
51 | "x-amz-object-lock-mode": ("ObjectLockMode", str),
52 | "x-amz-object-lock-legal-hold": ("ObjectLockLegalHoldStatus", str),
53 | "x-amz-object-lock-retain-until-date": ("ObjectLockRetainUntilDate", str),
54 | "x-amz-mp-parts-count": ("PartsCount", int),
55 | "x-amz-replication-status": ("ReplicationStatus", str),
56 | "x-amz-request-charged": ("RequestCharged", str),
57 | "x-amz-restore": ("Restore", str),
58 | "x-amz-server-side-encryption": ("ServerSideEncryption", str),
59 | "x-amz-server-side-encryption-customer-algorithm": ("SSECustomerAlgorithm", str),
60 | "x-amz-server-side-encryption-aws-kms-key-id": ("SSEKMSKeyId", str),
61 | "x-amz-server-side-encryption-customer-key-MD5": ("SSECustomerKeyMD5", str),
62 | "x-amz-storage-class": ("StorageClass", str),
63 | "x-amz-tagging-count": ("TagCount", int),
64 | "x-amz-version-id": ("VersionId", str),
65 | }
66 |     # Restricted http headers that can't be sent to s3 as part of downloading an object using a presigned url.
67 |     # Adding these headers can cause a mismatch with the SigV4 signature.
68 |     BLOCKED_REQUEST_HEADERS = ("Host",)
69 |
70 | def __init__(self, s3ol_access_point: str, max_file_supported=DOCUMENT_MAX_SIZE):
71 | self.max_file_supported = max_file_supported
72 | session_config = botocore.config.Config(
73 | retries={
74 | 'max_attempts': S3_MAX_RETRIES,
75 | 'mode': 'standard'
76 | })
77 | self.s3 = boto3.client('s3', config=session_config)
78 |
79 | self.session = requests.Session()
80 | self.session.mount("https://", adapter=HTTPAdapter(max_retries=Retry(
81 | total=self.S3_DOWNLOAD_MAX_RETRIES,
82 | status_forcelist=self.S3_RETRY_STATUS_CODES,
83 | method_whitelist=["GET"],
84 | backoff_factor=self.BACKOFF_FACTOR
85 | )))
86 |
87 | self.download_metrics = Metrics(service_name=S3, api=DOWNLOAD_PRESIGNED_URL, s3ol_access_point=s3ol_access_point)
88 | self.write_get_object_metrics = Metrics(service_name=S3, api=WRITE_GET_OBJECT_RESPONSE, s3ol_access_point=s3ol_access_point)
89 |
90 | def _contains_error(self, response) -> Tuple[bool, Tuple[str, str, S3_STATUS_CODES]]:
91 | text = response.content.decode('utf-8')
92 | lines = text.split('\n')
93 |         # All 200-299 status codes are successful responses; 206 indicates partial content.
94 | if response.status_code >= 300 or (len(lines) > 0 and lines[0] == self.XML_HEADER):
95 | xml = ''.join(lines[1:])
96 | LOG.info('Response status code >=300 or text contains xml. ')
97 | error_match = re.search(self.ERROR_RE, xml)
98 | code_match = re.search(self.CODE_RE, xml)
99 | message_match = re.search(self.MESSAGE_RE, xml)
100 | if error_match and code_match and message_match:
101 | error_code = code_match[0]
102 | error_message = message_match[0]
103 | return True, (error_code, error_message, http_status_code_to_s3_status_code(response.status_code))
104 | elif response.status_code >= 300:
105 | return True, (
106 | S3_ERROR_CODES.InternalError.name, "Internal Server Error", http_status_code_to_s3_status_code(response.status_code))
107 | return False, ('', '', http_status_code_to_s3_status_code(response.status_code))
108 |
109 | def _parse_response_headers(self, headers):
110 | """
111 | Convert response headers received from s3 presigned download call to the format similar to arguments of WriteGetObjectResponse API.
112 | :param headers: http headers received as part of response from downloading the object from s3
113 | """
114 | transformed_headers = {}
115 | for header_name in headers:
116 | if header_name in self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP:
117 | header_value = self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP[header_name][1](headers[header_name])
118 | transformed_headers[self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP[header_name][0]] = header_value
119 |
120 | return transformed_headers
121 |
122 | def _filter_request_headers(self, presigned_url, headers={}):
123 | """
124 | Filter some restricted headers that shouldn't be passed along to s3 when downloading the object.
125 | :param headers: http header from the incoming request
126 | :return: a filtered list of headers
127 | """
128 | filtered_headers = {}
129 | parsed_url = urllib.parse.urlparse(presigned_url)
130 | parsed_query_params = urllib.parse.parse_qs(parsed_url.query)
131 | signed_headers = set(parsed_query_params.get('X-Amz-SignedHeaders', []))
132 |
133 | for header in headers:
134 | if header in self.BLOCKED_REQUEST_HEADERS:
135 | continue
136 | if str(header).lower().startswith('x-amz-') and header not in signed_headers:
137 | continue
138 | filtered_headers[header] = headers[header]
139 | return filtered_headers
140 |
141 |     def download_file_from_presigned_url(self, presigned_url, headers=None) -> Tuple[str, Dict, S3_STATUS_CODES]:
142 |         """
143 |         Download the file from an S3 presigned url.
144 |         The Python AWS SDK doesn't provide a method to download from a presigned url directly, so we make a plain GET HTTP call.
145 |         """
146 |         parsed_headers = self._filter_request_headers(presigned_url, headers or {})
147 | for i in range(self.S3_DOWNLOAD_MAX_RETRIES):
148 | start_time = time.time()
149 | LOG.debug(f"Downloading object with presigned url {presigned_url} and headers: {parsed_headers}")
150 | response = self.session.get(presigned_url, timeout=self.MAX_GET_TIMEOUT, headers=parsed_headers)
151 | end_time = time.time()
152 | try:
153 | # Since presigned urls do not return correct status codes when there is an error,
154 | # the xml must be parsed to find the error code and status
155 | error_detected, (error_code, error_message, response_status_code) = self._contains_error(response)
156 | if error_detected:
157 | status_code_enum, error_code_enum = error_code_to_enums(error_code)
158 | LOG.error(f"Error downloading file from presigned url. ({error_code}: {error_message})")
159 | status_code = int(status_code_enum.name[-3:])
160 | if status_code not in self.S3_RETRY_STATUS_CODES or i == self.S3_DOWNLOAD_MAX_RETRIES - 1:
161 | LOG.error("Client error or max retries reached for downloading file from presigned url.")
162 | self.download_metrics.add_fault_count()
163 | raise S3DownloadException(error_code, error_message)
164 | else:
165 | text_content = response.content.decode('utf-8')
166 | if CONTENT_LENGTH in response.headers and int(response.headers.get(CONTENT_LENGTH)) > self.max_file_supported:
167 | raise FileSizeLimitExceededException("File too large to process")
168 | self.download_metrics.add_latency(start_time, end_time)
169 | return text_content, response.headers, response_status_code,
170 | time.sleep(max(1.0, i ** self.BACKOFF_FACTOR))
171 | except UnicodeDecodeError:
172 | raise UnsupportedFileException(response.content, response.headers, "Not a valid utf-8 file")
173 |
174 |     def respond_back_with_data(self, data, headers: Dict, request_route: str, request_token: str,
175 | status_code: S3_STATUS_CODES = S3_STATUS_CODES.OK_200):
176 | """Call S3's WriteGetObjectResponse API to return the processed object back to the original caller of get_object API."""
177 | start_time = time.time()
178 | try:
179 | parsed_headers = self._parse_response_headers(headers)
180 | LOG.debug(f"Calling s3 WriteGetObjectResponse with RequestRoute:{request_route} , headers: {parsed_headers},"
181 | f" RequestToken: {request_token}")
182 | self.s3.write_get_object_response(StatusCode=status_code.get_http_status_code(), Body=data, RequestRoute=request_route,
183 | RequestToken=request_token, **parsed_headers)
184 | except Exception as error:
185 | LOG.error("Error occurred while calling s3 write get object response with data.", exc_info=True)
186 | self.write_get_object_metrics.add_fault_count()
187 | raise error
188 | finally:
189 | self.write_get_object_metrics.add_latency(start_time, time.time())
190 |
191 | def respond_back_with_error(self, status_code: S3_STATUS_CODES, error_code: S3_ERROR_CODES, error_message: str,
192 | request_route: str, request_token: str):
193 | """Call S3's WriteGetObjectResponse API to return an error to the original caller of get_object API."""
194 | start_time = time.time()
195 | try:
196 | self.s3.write_get_object_response(StatusCode=status_code.get_http_status_code(), ErrorCode=error_code.name,
197 | ErrorMessage=error_message,
198 | RequestRoute=request_route, RequestToken=request_token)
199 | except Exception as error:
200 | LOG.error("Error occurred while calling s3 write get object response with error.", exc_info=True)
201 | self.write_get_object_metrics.add_fault_count()
202 | raise error
203 | finally:
204 | self.write_get_object_metrics.add_latency(start_time, time.time())
205 |
--------------------------------------------------------------------------------
/src/processors.py:
--------------------------------------------------------------------------------
1 | """Text processors."""
2 |
3 |
4 | from copy import deepcopy
5 | from typing import List
6 |
7 | import lambdalogging
8 | from config import SUBSEGMENT_OVERLAPPING_TOKENS, MAX_CHARS_OVERLAP
9 | from constants import ENTITY_TYPE, BEGIN_OFFSET, END_OFFSET, ALL, REPLACE_WITH_PII_ENTITY_TYPE, SCORE
10 | from data_object import Document
11 | from data_object import RedactionConfig
12 | from exceptions import InvalidConfigurationException
13 |
14 | LOG = lambdalogging.getLogger(__name__)
15 |
16 |
17 | class Segmenter:
18 | """Offer functionality to segment and desegment."""
19 |
20 | def __init__(self, max_doc_size: int, overlap_tokens: int = SUBSEGMENT_OVERLAPPING_TOKENS,
21 | max_overlapping_chars: int = MAX_CHARS_OVERLAP, **kwargs):
22 | self.max_overlapping_chars = int(max_overlapping_chars)
23 | self.overlap_tokens = int(overlap_tokens)
24 | self.max_doc_size = int(max_doc_size)
25 |         # A UTF-8 character can be up to 4 bytes
26 |         if max_doc_size < 4:
27 |             raise InvalidConfigurationException(
28 |                 f"Maximum text size limit ({self.max_doc_size} bytes) is too small to perform segmentation")
29 |
30 | def _trim_to_max_bytes(self, s, max_bytes):
31 | """
32 |         Ensure that the UTF-8 encoding of a string has no more than max_bytes bytes.
33 |
34 | The table below summarizes the format of these different octet types.
35 | Char. number range | UTF-8 octet sequence
36 | (hexadecimal) | (binary)
37 | --------------------+---------------------------------------------
38 | 0000 0000-0000 007F | 0xxxxxxx
39 | 0000 0080-0000 07FF | 110xxxxx 10xxxxxx
40 | 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx
41 | 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
42 | """
43 |
44 | def safe_b_of_i(b, i):
45 | try:
46 | return b[i]
47 | except IndexError:
48 | return 0
49 |
50 | # Edge cases
51 | if s == '' or max_bytes < 1:
52 | return ''
53 |
54 |         # slice first and then trim the encoded bytes, to avoid encoding a potentially huge string just to keep e.g. 10 bytes
55 | bytes_array = s[:max_bytes].encode('utf-8')[:max_bytes]
56 |
57 | # find the first byte from end which contains the starting byte of a utf8 character which is this format 11xxxxxx for
58 | # multi byte character. For single byte character the format is 0xxxxxxx as described above
59 | if bytes_array[-1] & 0b10000000:
60 | last_11xxxxxx_index = [
61 | i
62 | for i in range(-1, -5, -1)
63 | if safe_b_of_i(bytes_array, i) & 0b11000000 == 0b11000000
64 | ][0]
65 | # As described above in the table , we can determine the total size(in bytes) of char from the first byte itself
66 | starting_byte = bytes_array[last_11xxxxxx_index]
67 | if not starting_byte & 0b00100000:
68 | last_char_length = 2
69 | elif not starting_byte & 0b00010000:
70 | last_char_length = 3
71 | elif not starting_byte & 0b00001000:
72 | last_char_length = 4
73 | else:
74 | raise Exception(f"Unexpected utf-8 {starting_byte} byte encountered")
75 |
76 | if last_char_length > -last_11xxxxxx_index:
77 | # remove the incomplete character
78 | bytes_array = bytes_array[:last_11xxxxxx_index]
79 |
80 | return bytes_array.decode('utf-8')
81 |
82 | def _trim_partial_trailing_word(self, text):
83 | # find the first space moving backwards
84 | original_length = len(text)
85 | k = original_length - 1
86 |         # ensure a hard limit on how far back we travel; we don't want to walk back through the whole
87 |         # sentence if there are no spaces in it. max_overlapping_chars serves as a proxy for this limit
88 | while text[k] != ' ' and k > 0 and original_length - k < self.max_overlapping_chars:
89 | k -= 1
90 | trimmed_text = text[:k + 1]
91 | return trimmed_text
92 |
93 | def _find_trailing_overlapping_tokens_start_index(self, text):
94 | word_count = 0
95 | original_length = len(text)
96 | k = original_length - 1
97 | while word_count < self.overlap_tokens:
98 | k -= 1
99 | # Moving backwards: find the beginning of word (next character is space and current character is not space)
100 | while not (text[k + 1] != ' ' and text[k] == ' ') and k > 0 and original_length - k < self.max_overlapping_chars:
101 | k -= 1
102 | word_count += 1
103 | if k == 0:
104 | LOG.debug("Overlapping tokens for the next sentence starts beyond the current sentence")
105 | break
106 | return k
107 |
108 |     def _merge_classification_results(self, segment: Document, existing_results: dict = {}):
109 | for name, score in segment.pii_classification.items():
110 | if name not in existing_results or (
111 | name in existing_results and score > existing_results[name]):
112 | existing_results[name] = score
113 | return existing_results
114 |
115 | def _is_overlapping_annotations(self, entity_a, entity_b) -> int:
116 | """
117 | Determine if one entity overlaps with another.
118 |         It will return:
119 |         1 if entity_b lies to the right of entity_a
120 |         0 if entity_b overlaps with entity_a
121 |         -1 if entity_b lies to the left of entity_a
122 | """
123 | if entity_a[END_OFFSET] < entity_b[BEGIN_OFFSET]:
124 | return 1
125 | if entity_a[BEGIN_OFFSET] > entity_b[END_OFFSET]:
126 | return -1
127 | else:
128 | return 0
129 |
130 | def _resolve_overlapped_annotation(self, entity_a, entity_b) -> List:
131 | """Merge two overlapping entity annotations."""
132 | if entity_a[SCORE] >= entity_b[SCORE]:
133 | return [entity_a]
134 | else:
135 | return [entity_b]
136 |
137 | def _merge_pii_annotation_results(self, segment: Document, existing_annotations: List = []):
138 |
139 | if not existing_annotations:
140 | existing_annotations.extend(segment.pii_entities)
141 | return
142 |
143 | for pii_entity in segment.pii_entities:
144 | k = len(existing_annotations) - 1
145 | while k > 0:
146 | overlap_result = self._is_overlapping_annotations(existing_annotations[k], pii_entity)
147 | if overlap_result > 0:
148 | existing_annotations.append(pii_entity)
149 | break
150 | elif overlap_result == 0:
151 | LOG.debug("Annotation: " + str(existing_annotations[k]) + " conflicts with: " + str(pii_entity))
152 | resolved_annotation = self._resolve_overlapped_annotation(existing_annotations[k], pii_entity)
153 | LOG.debug("Deleting annotation:" + str(existing_annotations[k]))
154 | del existing_annotations[k]
155 | for i, annotation in enumerate(resolved_annotation):
156 | LOG.debug("Adding annotation:" + str(annotation))
157 | existing_annotations.insert(k + i, annotation)
158 | break
159 | else:
160 | k -= 1
161 |
162 | return existing_annotations
163 |
164 | def _relocate_annotation(self, annotations: List, offset: int):
165 | """Shift the annotated entities by given offset."""
166 | annotations_copy = deepcopy(annotations)
167 | for annotation in annotations_copy:
168 | annotation[END_OFFSET] += offset
169 | annotation[BEGIN_OFFSET] += offset
170 | return annotations_copy
171 |
172 | def segment(self, text: str, char_offset=0) -> List[Document]:
173 | """Segment the text into segments of max_doc_length with overlap_tokens."""
174 | segments = []
175 | starting_index = 0
176 | while len(text[starting_index:].encode()) > self.max_doc_size:
177 | trimmed_text = self._trim_to_max_bytes(text[starting_index:], self.max_doc_size)
178 | trimmed_text = self._trim_partial_trailing_word(trimmed_text)
179 | segments.append(Document(text=trimmed_text, char_offset=char_offset + starting_index))
180 | starting_index = starting_index + self._find_trailing_overlapping_tokens_start_index(trimmed_text) + 1
181 | # Add the remaining segment
182 | if starting_index < len(text) - 1:
183 | segments.append(Document(text=text[starting_index:], char_offset=char_offset + starting_index))
184 | return segments
185 |
186 | def de_segment(self, segments: List[Document]) -> Document:
187 | """
188 |         Merge the segments back into one big text. It also merges back the pii classification results.
189 |         Handles conflicting results on overlapping text between two text segments in the following ways:
190 |         1. For pii classification, the maximum score for an entity amongst the segments is
191 |         used as the score for that entity in the merged document
192 |         2. For pii entity annotations, a conflicting annotation span is resolved in favor
193 |         of the annotation with the higher confidence score
194 | """
195 | merged_text = ""
196 | pii_classification = {}
197 | pii_entities = []
198 | segments.sort(key=lambda x: x.char_offset)
199 | for segment in segments:
200 | offset_adjusted_segment = Document(text=segment.text, char_offset=segment.char_offset,
201 | pii_entities=self._relocate_annotation(segment.pii_entities, segment.char_offset),
202 | pii_classification=segment.pii_classification)
203 |             self._merge_classification_results(segment, pii_classification)
204 | self._merge_pii_annotation_results(offset_adjusted_segment, pii_entities)
205 | merged_text = merged_text + segment.text[len(merged_text) - segment.char_offset:]
206 | return Document(text=merged_text, char_offset=0, pii_classification=pii_classification, pii_entities=pii_entities)
207 |
208 |
209 | class Redactor:
210 | """Handle the logic of redacting discovered pii entities from the given text."""
211 |
212 | def __init__(self, redaction_config: RedactionConfig):
213 | self.redaction_config = redaction_config
214 |
215 | def redact(self, input_text, entities_list):
216 | """Redact the pii entities from given text."""
217 | doc_parts_list = []
218 | prev_entity = None
219 | for entity in entities_list:
220 | if entity[SCORE] < self.redaction_config.confidence_threshold:
221 | continue
222 | entity_type = entity[ENTITY_TYPE]
223 | begin_offset = entity[BEGIN_OFFSET]
224 | end_offset = entity[END_OFFSET]
225 | if prev_entity is None:
226 | doc_parts_list.append(input_text[:begin_offset])
227 | else:
228 | doc_parts_list.append(input_text[prev_entity[END_OFFSET]:begin_offset])
229 |
230 | if ALL in self.redaction_config.pii_entity_types or entity_type in self.redaction_config.pii_entity_types:
231 | # Redact this entity type
232 | if self.redaction_config.mask_mode == REPLACE_WITH_PII_ENTITY_TYPE:
233 | # Replace with PII Entity Type
234 | doc_parts_list.append(f"[{entity_type}]")
235 | else:
236 | # Replace with MaskCharacter
237 | entity_length = end_offset - begin_offset
238 | doc_parts_list.append(self.redaction_config.mask_character * entity_length)
239 | else:
240 | # Don't redact this entity type
241 | doc_parts_list.append(input_text[begin_offset:end_offset])
242 |
243 | prev_entity = entity
244 | if prev_entity is not None:
245 | doc_parts_list.append(input_text[prev_entity[END_OFFSET]:])
246 | else:
247 | doc_parts_list.append(input_text)
248 |         return ''.join(doc_parts_list)
249 |
--------------------------------------------------------------------------------
/src/handler.py:
--------------------------------------------------------------------------------
1 | """Lambda function handler."""
2 |
3 | # must be the first import in files with lambda function handlers
4 | import time
5 | import traceback
6 |
7 | import lambdainit # noqa: F401
8 | import json
9 | from typing import List
10 | import lambdalogging
11 | from clients.comprehend_client import ComprehendClient
12 | from clients.s3_client import S3Client
13 | from clients.cloudwatch_client import CloudWatchClient
14 | from config import DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES, DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES, DEFAULT_LANGUAGE_CODE, \
15 | PUBLISH_CLOUD_WATCH_METRICS, REDACTION_API_ONLY, COMPREHEND_ENDPOINT_URL
16 | from constants import ALL, REQUEST_ID, GET_OBJECT_CONTEXT, S3OL_ACCESS_POINT_ARN, \
17 | INPUT_S3_URL, S3OL_CONFIGURATION, REQUEST_ROUTE, REQUEST_TOKEN, PAYLOAD, DEFAULT_USER_AGENT, LANGUAGE_CODE, USER_REQUEST, \
18 | HEADERS, CONTENT_LENGTH, RESERVED_TIME_FOR_CLEANUP
19 | from data_object import Document, PiiConfig, RedactionConfig, ClassificationConfig
20 | from exception_handlers import ExceptionHandler
21 | from exceptions import RestrictedDocumentException
22 | from processors import Segmenter, Redactor
23 | from util import execute_task_with_timeout
24 | from validators import InputEventValidator, PartialObjectRequestValidator
25 |
26 | LOG = lambdalogging.getLogger(__name__)
27 |
28 |
29 | def get_interested_pii(document: Document, classification_config: PiiConfig):
30 | """
31 | Get a list of interested pii from the document.
32 |
33 | Return a list of pii entity types of the given document with only the entities of interest
34 | and above the confidence threshold.
35 | """
36 | pii_entities = []
37 | for name, score in document.pii_classification.items():
38 | if name in classification_config.pii_entity_types or ALL in classification_config.pii_entity_types:
39 | if score >= classification_config.confidence_threshold:
40 | pii_entities.append(name)
41 | return pii_entities
42 |
43 |
44 | def publish_metrics(cloud_watch: CloudWatchClient, s3: S3Client, comprehend: ComprehendClient, processed_document: bool,
45 | processed_pii_document: bool, language_code: str, s3ol_access_point: str, pii_entities: List[str]):
46 | """Publish metrics from the function execution."""
47 | try:
48 | cloud_watch.publish_metrics(s3.download_metrics.metrics + s3.write_get_object_metrics.metrics +
49 | comprehend.classify_metrics.metrics + comprehend.detection_metrics.metrics)
50 | if processed_document:
51 | cloud_watch.put_document_processed_metric(language_code, s3ol_access_point)
52 | if processed_pii_document:
53 | cloud_watch.put_pii_document_processed_metric(language_code, s3ol_access_point)
54 | cloud_watch.put_pii_document_types_metric(pii_entities, language_code, s3ol_access_point)
55 | except Exception as e:
56 | LOG.error(f"Error publishing metrics to cloudwatch. :{e} {traceback.print_exc()}")
57 |
58 |
59 | def redact(text, classification_segmenter: Segmenter, detection_segmenter: Segmenter,
60 | redactor: Redactor, comprehend: ComprehendClient, redaction_config: RedactionConfig, language_code) -> Document:
61 | """
62 |     Redact pii data from the given text. Redaction logic:
63 | 
64 |     1. Segment the text into subsegments of reasonable size (max doc size supported by comprehend) for the initial classification
65 |     2. For each subsegment,
66 |         2.1 call comprehend's classify-pii-document api to determine if it contains any PII data
67 |         2.2 if it contains pii, split it into smaller chunks (e.g. <=5KB); otherwise skip to the next subsegment
68 |         2.3 for each chunk
69 |             2.3.1 call comprehend's detect-pii-entities to extract the pii entities
70 |             2.3.2 redact the pii entities from the chunk
71 |         2.4 merge all chunks
72 |     3. merge all subsegments
73 | """
74 | if REDACTION_API_ONLY:
75 | doc = Document(text)
76 | documents = [doc]
77 | docs_for_entity_detection = detection_segmenter.segment(doc.text, doc.char_offset)
78 | else:
79 | documents = comprehend.contains_pii_entities(classification_segmenter.segment(text), language_code)
80 | pii_docs = [doc for doc in documents if len(get_interested_pii(doc, redaction_config)) > 0]
81 | if not pii_docs:
82 | LOG.debug("Document doesn't have any pii. Nothing to redact.")
83 | text = classification_segmenter.de_segment(documents).text
84 | return Document(text, redacted_text=text)
85 | docs_for_entity_detection = []
86 | for pii_doc in pii_docs:
87 | docs_for_entity_detection.extend(detection_segmenter.segment(pii_doc.text, pii_doc.char_offset))
88 |
89 | docs_with_pii_entities = comprehend.detect_pii_documents(docs_for_entity_detection, language_code)
90 | resultant_doc = classification_segmenter.de_segment(documents + docs_with_pii_entities)
91 | assert len(resultant_doc.text) == len(text), "Not able to recover original document after segmentation and desegmentation."
92 | redacted_text = redactor.redact(text, resultant_doc.pii_entities)
93 | resultant_doc.redacted_text = redacted_text
94 | return resultant_doc
95 |
96 |
97 | def classify(text, classification_segmenter: Segmenter, comprehend: ComprehendClient,
98 | detection_config: ClassificationConfig, language_code) -> List[str]:
99 | """
100 |     Detect pii data in the given text. Detection logic:
101 |
102 | 1. Segment text into segments of reasonable sizes (max doc size supported by comprehend) for
103 | doing initial classification
104 | 2. For each segment,
105 | 2.1 call comprehend's classify-pii-document api to determine if it contains any PII data
106 | 2.2 if it contains pii that is in the detection config then return those pii, else move to the next segment
107 | 3. If no pii detected, return empty list, else list of pii types found that is also in the detection config
108 | and above the given threshold
109 | """
110 | pii_classified_documents = comprehend.contains_pii_entities(classification_segmenter.segment(text), language_code)
111 | pii_types = set()
112 | for doc in pii_classified_documents:
113 | doc_pii_types = get_interested_pii(doc, detection_config)
114 | pii_types |= set(doc_pii_types)
115 | return list(pii_types)
116 |
117 |
118 | def redact_pii_documents_handler(event, context):
119 | """Redaction Lambda function handler."""
120 | LOG.info('Received event with requestId: %s', event[REQUEST_ID])
121 | LOG.debug(f'Raw event {event}')
122 |
123 | InputEventValidator.validate(event)
124 | invoke_args = json.loads(event[S3OL_CONFIGURATION][PAYLOAD]) if event[S3OL_CONFIGURATION][PAYLOAD] else {}
125 | language_code = invoke_args.get(LANGUAGE_CODE, DEFAULT_LANGUAGE_CODE)
126 | redaction_config = RedactionConfig(**invoke_args)
127 | object_get_context = event[GET_OBJECT_CONTEXT]
128 | s3ol_access_point = event[S3OL_CONFIGURATION][S3OL_ACCESS_POINT_ARN]
129 | s3 = S3Client(s3ol_access_point)
130 | cloud_watch = CloudWatchClient()
131 | comprehend = ComprehendClient(s3ol_access_point=s3ol_access_point, session_id=event[REQUEST_ID], user_agent=DEFAULT_USER_AGENT,
132 | endpoint_url=COMPREHEND_ENDPOINT_URL)
133 |
134 | exception_handler = ExceptionHandler(s3)
135 |
136 | LOG.debug("Pii Entity Types to be redacted:" + str(redaction_config.pii_entity_types))
137 | processed_document = False
138 | document = Document('')
139 |
140 | try:
141 | def time_bound_task():
142 | nonlocal processed_document
143 | nonlocal document
144 | PartialObjectRequestValidator.validate(event)
145 | pii_classification_segmenter = Segmenter(DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES)
146 | pii_redaction_segmenter = Segmenter(DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES)
147 | redactor = Redactor(redaction_config)
148 | time1 = time.time()
149 | text, http_headers, status_code = s3.download_file_from_presigned_url(object_get_context[INPUT_S3_URL],
150 | event[USER_REQUEST][HEADERS])
151 | time2 = time.time()
152 | LOG.info(f"Downloaded the file in : {(time2 - time1)} seconds")
153 | document = redact(text, pii_classification_segmenter, pii_redaction_segmenter, redactor,
154 | comprehend, redaction_config, language_code)
155 | processed_document = True
156 | time1 = time.time()
157 | LOG.info(f"Pii redaction completed within {(time1 - time2)} seconds. Returning back the response to S3")
158 | redacted_text_bytes = document.redacted_text.encode('utf-8')
159 | http_headers[CONTENT_LENGTH] = len(redacted_text_bytes)
160 | s3.respond_back_with_data(redacted_text_bytes, http_headers, object_get_context[REQUEST_ROUTE],
161 | object_get_context[REQUEST_TOKEN], status_code)
162 |
163 | execute_task_with_timeout(context.get_remaining_time_in_millis() - RESERVED_TIME_FOR_CLEANUP, time_bound_task)
164 | except Exception as generated_exception:
165 | exception_handler.handle_exception(generated_exception, object_get_context[REQUEST_ROUTE], object_get_context[REQUEST_TOKEN])
166 | finally:
167 | if PUBLISH_CLOUD_WATCH_METRICS:
168 | pii_entities = get_interested_pii(document, redaction_config)
169 | publish_metrics(cloud_watch, s3, comprehend, processed_document, len(pii_entities) > 0, language_code,
170 | s3ol_access_point, pii_entities)
171 |
172 | LOG.info("Responded back to s3 successfully")
173 |
174 |
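# For reference, a minimal shape of the S3 Object Lambda event consumed by the
# handlers in this module (key names match src/constants.py; the values are
# placeholders):
#
#   {
#       "xAmzRequestId": "...",
#       "getObjectContext": {
#           "inputS3Url": "https://<presigned-url>",
#           "outputRoute": "...",
#           "outputToken": "..."
#       },
#       "configuration": {
#           "accessPointArn": "arn:aws:s3-object-lambda:...",
#           "payload": "{\"language_code\": \"en\"}"
#       },
#       "userRequest": {"headers": {"...": "..."}}
#   }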
175 | def pii_access_control_handler(event, context):
176 | """Detect Lambda function handler."""
177 | LOG.info(f'Received event with requestId: {event[REQUEST_ID]}')
178 | LOG.debug(f'Raw event {event}')
179 |
180 | InputEventValidator.validate(event)
181 | invoke_args = json.loads(event[S3OL_CONFIGURATION][PAYLOAD]) if event[S3OL_CONFIGURATION][PAYLOAD] else {}
182 | language_code = invoke_args.get(LANGUAGE_CODE, DEFAULT_LANGUAGE_CODE)
183 | detection_config = ClassificationConfig(**invoke_args)
184 | object_get_context = event[GET_OBJECT_CONTEXT]
185 | s3ol_access_point = event[S3OL_CONFIGURATION][S3OL_ACCESS_POINT_ARN]
186 |
187 | s3 = S3Client(s3ol_access_point)
188 | cloud_watch = CloudWatchClient()
189 | comprehend = ComprehendClient(session_id=event[REQUEST_ID], user_agent=DEFAULT_USER_AGENT, endpoint_url=COMPREHEND_ENDPOINT_URL,
190 | s3ol_access_point=s3ol_access_point)
191 | exception_handler = ExceptionHandler(s3)
192 |
193 | LOG.debug("Pii Entity Types to be detected:" + str(detection_config.pii_entity_types))
194 |
195 | pii_classification_segmenter = Segmenter(DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES)
196 |
197 | processed_document = False
198 | processed_pii_document = False
199 | pii_entities = []
200 |
201 | try:
202 | def time_bound_task():
203 | nonlocal processed_document
204 | nonlocal processed_pii_document
205 | nonlocal pii_entities
206 | PartialObjectRequestValidator.validate(event)
207 | time1 = time.time()
208 | text, http_headers, status_code = s3.download_file_from_presigned_url(object_get_context[INPUT_S3_URL],
209 | event[USER_REQUEST][HEADERS])
210 | time2 = time.time()
211 | LOG.info(f"Downloaded the file in : {(time2 - time1)} seconds")
212 | pii_entities = classify(text, pii_classification_segmenter, comprehend, detection_config, language_code)
213 | time1 = time.time()
214 |
215 | processed_document = True
216 | LOG.info(f"Pii detection completed within {(time1 - time2)} seconds. Returning back the response to S3")
217 | if len(pii_entities) > 0:
218 | processed_pii_document = True
219 | raise RestrictedDocumentException()
220 | else:
221 | text_bytes = text.encode('utf-8')
222 | http_headers[CONTENT_LENGTH] = len(text_bytes)
223 | s3.respond_back_with_data(text_bytes, http_headers, object_get_context[REQUEST_ROUTE],
224 | object_get_context[REQUEST_TOKEN],
225 | status_code)
226 |
227 | execute_task_with_timeout(context.get_remaining_time_in_millis() - RESERVED_TIME_FOR_CLEANUP, time_bound_task)
228 | except Exception as generated_exception:
229 | exception_handler.handle_exception(generated_exception, object_get_context[REQUEST_ROUTE], object_get_context[REQUEST_TOKEN])
230 | finally:
231 | if PUBLISH_CLOUD_WATCH_METRICS:
232 | publish_metrics(cloud_watch, s3, comprehend, processed_document, processed_pii_document, language_code,
233 | s3ol_access_point, pii_entities)
234 |
235 | LOG.info("Responded back to s3 successfully")
236 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | """Collection of constants used in the code."""
2 | from enum import Enum, auto
3 |
4 | from typing import Tuple
5 |
6 | ENTITY_TYPE = "Type"
7 | BEGIN_OFFSET = "BeginOffset"
8 | END_OFFSET = "EndOffset"
9 | NAME = "Name"
10 | SCORE = "Score"
11 | ALL = "ALL"
12 | REPLACE_WITH_PII_ENTITY_TYPE = "REPLACE_WITH_PII_ENTITY_TYPE"
13 | MASK = "MASK"
14 | REQUEST_ID = "xAmzRequestId"
15 | USER_REQUEST = "userRequest"
16 | HEADERS = "headers"
17 | RANGE = "Range"
18 | PART_NUMBER = "PartNumber"
19 | REQUEST_ROUTE = "outputRoute"
20 | REQUEST_TOKEN = "outputToken"
21 | GET_OBJECT_CONTEXT = "getObjectContext"
22 | INPUT_S3_URL = "inputS3Url"
23 | S3OL_CONFIGURATION = "configuration"
24 | S3OL_ACCESS_POINT_ARN = "accessPointArn"
25 | CONTENT_LENGTH = "Content-Length"
26 | OVERLAP_TOKENS = "overlap_tokens"
27 | PAYLOAD = "payload"
28 | ONE_DOC_PER_LINE = "ONE_DOC_PER_LINE"
29 | ONE_DOC_PER_FILE = "ONE_DOC_PER_FILE"
30 | LANGUAGE_CODE = "language_code"
31 |
32 | DEFAULT_USER_AGENT = "S3ObjectLambda/1.0"
33 |
34 | RESERVED_TIME_FOR_CLEANUP = 2000 # We need at least this much time (in millis) to perform cleanup tasks like flushing the metrics
35 | COMPREHEND_MAX_RETRIES = 7
36 | S3_MAX_RETRIES = 10
37 | CLOUD_WATCH_NAMESPACE = "ComprehendS3ObjectLambda"
38 | LATENCY = "Latency"
39 | ERROR_COUNT = "ErrorCount"
40 | API = "API"
41 | CONTAINS_PII_ENTITIES = "ContainsPiiEntities"
42 | DETECT_PII_ENTITIES = "DetectPiiEntities"
43 | PII_DOCUMENTS_PROCESSED = "PiiDocumentsProcessed"
44 | DOCUMENTS_PROCESSED = "DocumentsProcessed"
45 | PII_DOCUMENT_TYPES_PROCESSED = "PiiDocumentTypesProcessed"
46 | PII_ENTITY_TYPE = "PiiEntityType"
47 | SERVICE = "Service"
48 | COMPREHEND = "Comprehend"
49 | S3 = "S3"
50 | WRITE_GET_OBJECT_RESPONSE = "WriteGetObjectResponse"
51 | DOWNLOAD_PRESIGNED_URL = "DownloadPresignedUrl"
52 | LANGUAGE = "Language"
53 | MILLISECONDS = "Milliseconds"
54 | COUNT = "Count"
55 | VALUE = "Value"
56 | S3OL_ACCESS_POINT = "S3ObjectLambdaAccessPoint"
57 | METRIC_NAME = "MetricName"
58 | UNIT = "Unit"
59 | DIMENSIONS = "Dimensions"
60 |
61 |
62 | class UNSUPPORTED_FILE_HANDLING_VALID_VALUES(Enum):
63 | """Valid values for handling logic for Unsupported files."""
64 |
65 | PASS = auto()
66 | FAIL = auto()
67 |
68 |
69 | class MASK_MODE_VALID_VALUES(Enum):
70 | """Valid values for MASK_MODE variable."""
71 |
72 | MASK = auto()
73 | REPLACE_WITH_PII_ENTITY_TYPE = auto()
74 |
75 |
76 | class S3_STATUS_CODES(Enum):
77 | """
78 | Valid http status codes for S3.
79 | Refer to https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#ErrorCodeList for more details on the status codes.
80 | """
81 |
82 | OK_200 = auto()
83 | PARTIAL_CONTENT_206 = auto()
84 | NOT_MODIFIED_304 = auto()
85 | BAD_REQUEST_400 = auto()
86 | UNAUTHORIZED_401 = auto()
87 | FORBIDDEN_403 = auto()
88 | NOT_FOUND_404 = auto()
89 | METHOD_NOT_ALLOWED_405 = auto()
90 | CONFLICT_409 = auto()
91 | LENGTH_REQUIRED_411 = auto()
92 | PRECONDITION_FAILED_412 = auto()
93 | RANGE_NOT_SATISFIABLE_416 = auto()
94 | INTERNAL_SERVER_ERROR_500 = auto()
95 | SERVICE_UNAVAILABLE_503 = auto()
96 |
97 | def get_http_status_code(self) -> int:
98 | """Convert s3 status codes to integer http status codes."""
99 | return int(self.name.split('_')[-1])
100 |
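# Example of the conversion above: the integer status code is parsed from the
# trailing digits of the enum member name, e.g.
#
#   S3_STATUS_CODES.NOT_FOUND_404.get_http_status_code()  # -> 404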
101 |
102 | class S3_ERROR_CODES(Enum):
103 | """
104 | Valid error codes for S3.
105 | Refer to https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#ErrorCodeList for more details on the error codes.
106 | """
107 |
108 | AccessDenied = auto()
109 | AccountProblem = auto()
110 | AllAccessDisabled = auto()
111 | AmbiguousGrantByEmailAddress = auto()
112 | AuthorizationHeaderMalformed = auto()
113 | BadDigest = auto()
114 | BucketAlreadyExists = auto()
115 | BucketAlreadyOwnedByYou = auto()
116 | BucketNotEmpty = auto()
117 | CredentialsNotSupported = auto()
118 | CrossLocationLoggingProhibited = auto()
119 | EntityTooSmall = auto()
120 | EntityTooLarge = auto()
121 | ExpiredToken = auto()
122 | IllegalLocationConstraintException = auto()
123 | IllegalVersioningConfigurationException = auto()
124 | IncompleteBody = auto()
125 | IncorrectNumberOfFilesInPostRequest = auto()
126 | InlineDataTooLarge = auto()
127 | InternalError = auto()
128 | InvalidAccessKeyId = auto()
129 | InvalidAccessPoint = auto()
130 | InvalidAddressingHeader = auto()
131 | InvalidArgument = auto()
132 | InvalidBucketName = auto()
133 | InvalidBucketState = auto()
134 | InvalidDigest = auto()
135 | InvalidEncryptionAlgorithmError = auto()
136 | InvalidLocationConstraint = auto()
137 | InvalidObjectState = auto()
138 | InvalidPart = auto()
139 | InvalidPartOrder = auto()
140 | InvalidPayer = auto()
141 | InvalidPolicyDocument = auto()
142 | InvalidRange = auto()
143 | InvalidRequest = auto()
144 | InvalidSecurity = auto()
145 | InvalidSOAPRequest = auto()
146 | InvalidStorageClass = auto()
147 | InvalidTargetBucketForLogging = auto()
148 | InvalidToken = auto()
149 | InvalidURI = auto()
150 | KeyTooLongError = auto()
151 | MalformedACLError = auto()
152 | MalformedPOSTRequest = auto()
153 | MalformedXML = auto()
154 | MaxMessageLengthExceeded = auto()
155 | MaxPostPreDataLengthExceededError = auto()
156 | MetadataTooLarge = auto()
157 | MethodNotAllowed = auto()
158 | MissingAttachment = auto()
159 | MissingContentLength = auto()
160 | MissingRequestBodyError = auto()
161 | MissingSecurityElement = auto()
162 | MissingSecurityHeader = auto()
163 | NoLoggingStatusForKey = auto()
164 | NoSuchBucket = auto()
165 | NoSuchBucketPolicy = auto()
166 | NoSuchKey = auto()
167 | NoSuchLifecycleConfiguration = auto()
168 | NoSuchUpload = auto()
169 | NoSuchVersion = auto()
170 | NotImplemented = auto()
171 | NotSignedUp = auto()
172 | OperationAborted = auto()
173 | PermanentRedirect = auto()
174 | PreconditionFailed = auto()
175 | Redirect = auto()
176 | RestoreAlreadyInProgress = auto()
177 | RequestIsNotMultiPartContent = auto()
178 | RequestTimeout = auto()
179 | RequestTimeTooSkewed = auto()
180 | RequestTorrentOfBucketError = auto()
181 | ServerSideEncryptionConfigurationNotFoundError = auto()
182 | ServiceUnavailable = auto()
183 | SignatureDoesNotMatch = auto()
184 | SlowDown = auto()
185 | TemporaryRedirect = auto()
186 | TokenRefreshRequired = auto()
187 | TooManyAccessPoints = auto()
188 | TooManyBuckets = auto()
189 | UnexpectedContent = auto()
190 | UnresolvableGrantByEmailAddress = auto()
191 | UserKeyMustBeSpecified = auto()
192 | NoSuchAccessPoint = auto()
193 | InvalidTag = auto()
194 | MalformedPolicy = auto()
195 |
196 |
197 | ERROR_CODE_STATUS_MAP = {
198 | S3_ERROR_CODES.AccessDenied: S3_STATUS_CODES.FORBIDDEN_403,
199 | S3_ERROR_CODES.AccountProblem: S3_STATUS_CODES.FORBIDDEN_403,
200 | S3_ERROR_CODES.AllAccessDisabled: S3_STATUS_CODES.FORBIDDEN_403,
201 | S3_ERROR_CODES.AmbiguousGrantByEmailAddress: S3_STATUS_CODES.BAD_REQUEST_400,
202 | S3_ERROR_CODES.AuthorizationHeaderMalformed: S3_STATUS_CODES.BAD_REQUEST_400,
203 | S3_ERROR_CODES.BadDigest: S3_STATUS_CODES.BAD_REQUEST_400,
204 | S3_ERROR_CODES.BucketAlreadyExists: S3_STATUS_CODES.CONFLICT_409,
205 | S3_ERROR_CODES.BucketAlreadyOwnedByYou: S3_STATUS_CODES.CONFLICT_409,
206 | S3_ERROR_CODES.BucketNotEmpty: S3_STATUS_CODES.CONFLICT_409,
207 | S3_ERROR_CODES.CredentialsNotSupported: S3_STATUS_CODES.BAD_REQUEST_400,
208 | S3_ERROR_CODES.CrossLocationLoggingProhibited: S3_STATUS_CODES.FORBIDDEN_403,
209 | S3_ERROR_CODES.EntityTooSmall: S3_STATUS_CODES.BAD_REQUEST_400,
210 | S3_ERROR_CODES.EntityTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
211 | S3_ERROR_CODES.ExpiredToken: S3_STATUS_CODES.BAD_REQUEST_400,
212 | S3_ERROR_CODES.IllegalLocationConstraintException: S3_STATUS_CODES.BAD_REQUEST_400,
213 | S3_ERROR_CODES.IllegalVersioningConfigurationException: S3_STATUS_CODES.BAD_REQUEST_400,
214 | S3_ERROR_CODES.IncompleteBody: S3_STATUS_CODES.BAD_REQUEST_400,
215 | S3_ERROR_CODES.IncorrectNumberOfFilesInPostRequest: S3_STATUS_CODES.BAD_REQUEST_400,
216 | S3_ERROR_CODES.InlineDataTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
217 | S3_ERROR_CODES.InternalError: S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500,
218 | S3_ERROR_CODES.InvalidAccessKeyId: S3_STATUS_CODES.FORBIDDEN_403,
219 | S3_ERROR_CODES.InvalidAccessPoint: S3_STATUS_CODES.BAD_REQUEST_400,
220 | S3_ERROR_CODES.InvalidArgument: S3_STATUS_CODES.BAD_REQUEST_400,
221 | S3_ERROR_CODES.InvalidBucketName: S3_STATUS_CODES.BAD_REQUEST_400,
222 | S3_ERROR_CODES.InvalidBucketState: S3_STATUS_CODES.CONFLICT_409,
223 | S3_ERROR_CODES.InvalidDigest: S3_STATUS_CODES.BAD_REQUEST_400,
224 | S3_ERROR_CODES.InvalidEncryptionAlgorithmError: S3_STATUS_CODES.BAD_REQUEST_400,
225 | S3_ERROR_CODES.InvalidLocationConstraint: S3_STATUS_CODES.BAD_REQUEST_400,
226 | S3_ERROR_CODES.InvalidObjectState: S3_STATUS_CODES.FORBIDDEN_403,
227 | S3_ERROR_CODES.InvalidPart: S3_STATUS_CODES.BAD_REQUEST_400,
228 | S3_ERROR_CODES.InvalidPartOrder: S3_STATUS_CODES.BAD_REQUEST_400,
229 | S3_ERROR_CODES.InvalidPayer: S3_STATUS_CODES.FORBIDDEN_403,
230 | S3_ERROR_CODES.InvalidPolicyDocument: S3_STATUS_CODES.BAD_REQUEST_400,
231 | S3_ERROR_CODES.InvalidRange: S3_STATUS_CODES.RANGE_NOT_SATISFIABLE_416,
232 | S3_ERROR_CODES.InvalidRequest: S3_STATUS_CODES.BAD_REQUEST_400,
233 | S3_ERROR_CODES.InvalidSecurity: S3_STATUS_CODES.FORBIDDEN_403,
234 | S3_ERROR_CODES.InvalidSOAPRequest: S3_STATUS_CODES.BAD_REQUEST_400,
235 | S3_ERROR_CODES.InvalidStorageClass: S3_STATUS_CODES.BAD_REQUEST_400,
236 | S3_ERROR_CODES.InvalidTargetBucketForLogging: S3_STATUS_CODES.BAD_REQUEST_400,
237 | S3_ERROR_CODES.InvalidToken: S3_STATUS_CODES.BAD_REQUEST_400,
238 | S3_ERROR_CODES.InvalidURI: S3_STATUS_CODES.BAD_REQUEST_400,
239 | S3_ERROR_CODES.KeyTooLongError: S3_STATUS_CODES.BAD_REQUEST_400,
240 | S3_ERROR_CODES.MalformedACLError: S3_STATUS_CODES.BAD_REQUEST_400,
241 | S3_ERROR_CODES.MalformedPOSTRequest: S3_STATUS_CODES.BAD_REQUEST_400,
242 | S3_ERROR_CODES.MalformedXML: S3_STATUS_CODES.BAD_REQUEST_400,
243 | S3_ERROR_CODES.MaxMessageLengthExceeded: S3_STATUS_CODES.BAD_REQUEST_400,
244 | S3_ERROR_CODES.MaxPostPreDataLengthExceededError: S3_STATUS_CODES.BAD_REQUEST_400,
245 | S3_ERROR_CODES.MetadataTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
246 | S3_ERROR_CODES.MethodNotAllowed: S3_STATUS_CODES.METHOD_NOT_ALLOWED_405,
247 | S3_ERROR_CODES.MissingContentLength: S3_STATUS_CODES.LENGTH_REQUIRED_411,
248 | S3_ERROR_CODES.MissingRequestBodyError: S3_STATUS_CODES.BAD_REQUEST_400,
249 | S3_ERROR_CODES.MissingSecurityElement: S3_STATUS_CODES.BAD_REQUEST_400,
250 | S3_ERROR_CODES.MissingSecurityHeader: S3_STATUS_CODES.BAD_REQUEST_400,
251 | S3_ERROR_CODES.NoLoggingStatusForKey: S3_STATUS_CODES.BAD_REQUEST_400,
252 | S3_ERROR_CODES.NoSuchBucket: S3_STATUS_CODES.NOT_FOUND_404,
253 | S3_ERROR_CODES.NoSuchBucketPolicy: S3_STATUS_CODES.NOT_FOUND_404,
254 | S3_ERROR_CODES.NoSuchKey: S3_STATUS_CODES.NOT_FOUND_404,
255 | S3_ERROR_CODES.NoSuchLifecycleConfiguration: S3_STATUS_CODES.NOT_FOUND_404,
256 | S3_ERROR_CODES.NoSuchUpload: S3_STATUS_CODES.NOT_FOUND_404,
257 | S3_ERROR_CODES.NoSuchVersion: S3_STATUS_CODES.NOT_FOUND_404,
258 | S3_ERROR_CODES.NotSignedUp: S3_STATUS_CODES.FORBIDDEN_403,
259 | S3_ERROR_CODES.OperationAborted: S3_STATUS_CODES.CONFLICT_409,
260 | S3_ERROR_CODES.PreconditionFailed: S3_STATUS_CODES.PRECONDITION_FAILED_412,
261 | S3_ERROR_CODES.RestoreAlreadyInProgress: S3_STATUS_CODES.CONFLICT_409,
262 | S3_ERROR_CODES.RequestIsNotMultiPartContent: S3_STATUS_CODES.BAD_REQUEST_400,
263 | S3_ERROR_CODES.RequestTimeout: S3_STATUS_CODES.BAD_REQUEST_400,
264 | S3_ERROR_CODES.RequestTimeTooSkewed: S3_STATUS_CODES.FORBIDDEN_403,
265 | S3_ERROR_CODES.RequestTorrentOfBucketError: S3_STATUS_CODES.BAD_REQUEST_400,
266 | S3_ERROR_CODES.ServerSideEncryptionConfigurationNotFoundError: S3_STATUS_CODES.BAD_REQUEST_400,
267 | S3_ERROR_CODES.ServiceUnavailable: S3_STATUS_CODES.SERVICE_UNAVAILABLE_503,
268 | S3_ERROR_CODES.SignatureDoesNotMatch: S3_STATUS_CODES.FORBIDDEN_403,
269 | S3_ERROR_CODES.SlowDown: S3_STATUS_CODES.SERVICE_UNAVAILABLE_503,
270 | S3_ERROR_CODES.TokenRefreshRequired: S3_STATUS_CODES.BAD_REQUEST_400,
271 | S3_ERROR_CODES.TooManyAccessPoints: S3_STATUS_CODES.BAD_REQUEST_400,
272 | S3_ERROR_CODES.TooManyBuckets: S3_STATUS_CODES.BAD_REQUEST_400,
273 | S3_ERROR_CODES.UnexpectedContent: S3_STATUS_CODES.BAD_REQUEST_400,
274 | S3_ERROR_CODES.UnresolvableGrantByEmailAddress: S3_STATUS_CODES.BAD_REQUEST_400,
275 | S3_ERROR_CODES.UserKeyMustBeSpecified: S3_STATUS_CODES.BAD_REQUEST_400,
276 | S3_ERROR_CODES.NoSuchAccessPoint: S3_STATUS_CODES.NOT_FOUND_404,
277 | S3_ERROR_CODES.InvalidTag: S3_STATUS_CODES.BAD_REQUEST_400,
278 | S3_ERROR_CODES.MalformedPolicy: S3_STATUS_CODES.BAD_REQUEST_400
279 | }
280 |
281 |
282 | def error_code_to_enums(error_code: str) -> Tuple[S3_STATUS_CODES, S3_ERROR_CODES]:
283 | """Error code to enums."""
284 | for code, status in ERROR_CODE_STATUS_MAP.items():
285 | if error_code == code.name:
286 | return status, code
287 | return S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, S3_ERROR_CODES.InternalError
288 |
289 |
290 | def http_status_code_to_s3_status_code(http_status_code: int) -> S3_STATUS_CODES:
291 | """Convert http status codes to s3 status codes."""
292 | for status_codes in S3_STATUS_CODES:
293 | if str(http_status_code) == status_codes.name[-3:]:
294 | return status_codes
295 | return S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500
296 |
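# Behavior of the two lookup helpers above, as implied by the code in this file:
#
#   error_code_to_enums("NoSuchKey")
#       -> (S3_STATUS_CODES.NOT_FOUND_404, S3_ERROR_CODES.NoSuchKey)
#   error_code_to_enums("UnknownCode")  # unmapped codes fall back to 500
#       -> (S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, S3_ERROR_CODES.InternalError)
#   http_status_code_to_s3_status_code(416)
#       -> S3_STATUS_CODES.RANGE_NOT_SATISFIABLE_416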
--------------------------------------------------------------------------------
/test/integ/integ_base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import datetime, timedelta
3 | from time import sleep
4 | from unittest import TestCase
5 | import os
6 | import boto3
7 | import uuid
8 | import zipfile
9 | import json
10 | import time
11 |
12 | from dateutil.tz import tzutc
13 |
14 | LOG_LEVEL = "LOG_LEVEL"
15 | UNSUPPORTED_FILE_HANDLING = "UNSUPPORTED_FILE_HANDLING"
16 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES = "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES"
17 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES = "DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES"
18 | PII_ENTITY_TYPES = "PII_ENTITY_TYPES"
19 | MASK_CHARACTER = "MASK_CHARACTER"
20 | MASK_MODE = "MASK_MODE"
21 | SUBSEGMENT_OVERLAPPING_TOKENS = "SUBSEGMENT_OVERLAPPING_TOKENS"
22 | DOCUMENT_MAX_SIZE = "DOCUMENT_MAX_SIZE"
23 | CONFIDENCE_THRESHOLD = "CONFIDENCE_THRESHOLD"
24 | MAX_CHARS_OVERLAP = "MAX_CHARS_OVERLAP"
25 | DEFAULT_LANGUAGE_CODE = "DEFAULT_LANGUAGE_CODE"
26 | DETECT_PII_ENTITIES_THREAD_COUNT = "DETECT_PII_ENTITIES_THREAD_COUNT"
27 | CONTAINS_PII_ENTITIES_THREAD_COUNT = "CONTAINS_PII_ENTITIES_THREAD_COUNT"
28 | REDACTION_API_ONLY = "REDACTION_API_ONLY"
29 | PUBLISH_CLOUD_WATCH_METRICS = "PUBLISH_CLOUD_WATCH_METRICS"
30 | CW_METRIC_PUBLISH_CHECK_ATTEMPT = 5
31 |
32 |
33 | class BasicIntegTest(TestCase):
34 | REGION_NAME = 'us-east-1'
35 | DATA_PATH = "test/data/integ"
36 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'PHONE', 'BANK_ROUTING', 'BANK_ACCOUNT_NUMBER', 'ADDRESS', 'DATE_TIME']
37 |
38 | @classmethod
39 | def _zip_function_code(cls, build_dir):
40 | source_dir = build_dir
41 | output_filename = "s3ol_function.zip"
42 | relroot = os.path.abspath(source_dir)
43 | with zipfile.ZipFile(output_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
44 | logging.info(f"Build Dir: {source_dir}")
45 | for root, dirs, files in os.walk(source_dir):
46 | # add directory (needed for empty dirs)
47 | zipf.write(root, os.path.relpath(root, relroot))
48 | for file in files:
49 | filename = os.path.join(root, file)
50 | if os.path.isfile(filename): # regular files only
51 | arcname = os.path.join(os.path.relpath(root, relroot), file)
52 | zipf.write(filename, arcname)
53 |
54 | return output_filename
55 |
56 | @classmethod
57 | def _update_lambda_env_variables(cls, function_arn, env_variable_dict):
58 | function_name = function_arn.split(':')[-1]
59 | cls.lambda_client.update_function_configuration(FunctionName=function_name, Environment={'Variables': env_variable_dict})
60 | waiter = cls.lambda_client.get_waiter('function_updated')
61 | waiter.wait(FunctionName=function_name)
62 |
63 | @classmethod
64 | def _create_function(cls, function_name, handler_name, test_id, env_vars_dict, lambda_execution_role, build_dir):
65 | function_unique_name = f"{function_name}-{test_id}"
66 | zip_file_name = cls._zip_function_code(build_dir)
67 | with open(zip_file_name, 'rb') as file_data:
68 | bytes_content = file_data.read()
69 |
70 | create_function_response = cls.lambda_client.create_function(
71 | FunctionName=function_unique_name,
72 | Runtime='python3.8',
73 | Role=lambda_execution_role,
74 | Handler=handler_name,
75 | Code={
76 | 'ZipFile': bytes_content
77 | },
78 | Timeout=60,
79 | MemorySize=128,
80 | Publish=True,
81 | Environment={
82 | 'Variables': env_vars_dict
83 | }
84 | )
85 | return create_function_response['FunctionArn']
86 |
87 | @classmethod
88 | def _update_function_handler(cls, handler):
89 | response = cls.lambda_client.update_function_configuration(
90 | FunctionName=cls.function_name,
91 | Handler=handler
92 | )
93 |
94 | @classmethod
95 | def _create_role(cls, test_id):
96 |
97 | cls.role_name = f"s3ol-integ-test-role-{test_id}"
98 |
99 | assume_role_policy_document = {
100 | "Version": "2012-10-17",
101 | "Statement": [
102 | {
103 | "Effect": "Allow",
104 | "Principal": {
105 | "Service": "lambda.amazonaws.com"
106 | },
107 | "Action": "sts:AssumeRole"
108 | }
109 | ]
110 | }
111 |
112 | create_iam_role_response = cls.iam_client.create_role(
113 | Path='/s3ol-integ-test/',
114 | RoleName=cls.role_name,
115 | AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
116 | )
117 |
118 | role_arn = create_iam_role_response['Role']['Arn']
119 |
120 | policy_document = {
121 | "Statement": [
122 | {
123 | "Action": [
124 | "comprehend:DetectPiiEntities",
125 | "comprehend:ContainsPiiEntities"
126 | ],
127 | "Resource": "*",
128 | "Effect": "Allow",
129 | "Sid": "ComprehendPiiDetectionPolicy"
130 | },
131 | {
132 | "Action": [
133 | "s3-object-lambda:WriteGetObjectResponse"
134 | ],
135 | "Resource": "*",
136 | "Effect": "Allow",
137 | "Sid": "S3AccessPointCallbackPolicy"
138 | },
139 | {
140 | "Action": [
141 | "cloudwatch:PutMetricData"
142 | ],
143 | "Resource": "*",
144 | "Effect": "Allow",
145 | "Sid": "CloudWatchMetricsPolicy"
146 | }
147 | ]
148 | }
149 |
150 | put_role_policy_response = cls.iam_client.put_role_policy(
151 | RoleName=cls.role_name,
152 | PolicyName='S3OLFunctionPolicy',
153 | PolicyDocument=json.dumps(policy_document)
154 | )
155 |
156 | attach_role_policy_response = cls.iam_client.attach_role_policy(
157 | RoleName=cls.role_name,
158 | PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
159 | )
160 | # wait briefly for the IAM role to propagate before it is used
161 | sleep(10)
162 | return role_arn
163 |
164 | @classmethod
165 | def _create_bucket(cls, test_id):
166 | cls.bucket_name = f"s3ol-integ-test-{test_id}"
167 | bucket = cls.s3.Bucket(cls.bucket_name)
168 | create_bucket_response = bucket.create()
169 | cls.s3.BucketVersioning(cls.bucket_name).enable()
170 |
171 | @classmethod
172 | def _create_access_point(cls, test_id):
173 | cls.access_point_name = f"s3ol-integ-test-ac-{test_id}"
174 |
175 | create_access_point_response = cls.s3_ctrl.create_access_point(
176 | AccountId=cls.account_id,
177 | Name=cls.access_point_name,
178 | Bucket=cls.bucket_name,
179 | )
180 | cls.access_point_arn = f"arn:aws:s3:{cls.REGION_NAME}:{cls.account_id}:accesspoint/{cls.access_point_name}"
181 |
182 | @classmethod
183 | def _create_s3ol_access_point(cls, test_id, lambda_function_arn):
184 | cls.s3ol_access_point_name = f"s3ol-integ-test-bac-{test_id}"
185 |
186 | create_s3ol_access_point_response = cls.s3_ctrl.create_access_point_for_object_lambda(
187 | AccountId=cls.account_id,
188 | Name=cls.s3ol_access_point_name,
189 | Configuration={
190 | 'SupportingAccessPoint': cls.access_point_arn,
191 | 'TransformationConfigurations': [
192 | {
193 | 'Actions': ['GetObject'],
194 | 'ContentTransformation': {
195 | 'AwsLambda': {
196 | 'FunctionArn': lambda_function_arn
197 | }
198 | }
199 | }
200 | ],
201 | "AllowedFeatures": ["GetObject-Range"]
202 | }
203 | )
204 | return f"arn:aws:s3-object-lambda:{cls.REGION_NAME}:{cls.account_id}:accesspoint/{cls.s3ol_access_point_name}"
205 |
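# Sketch of how a test can then read an object through the Object Lambda
# access point: boto3 accepts the access point ARN in place of a bucket name,
# and the key below is a placeholder for one of the uploaded test files:
#
#   response = cls.s3_client.get_object(Bucket=s3ol_access_point_arn, Key='pii_input.txt')
#   transformed_body = response['Body'].read()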
206 | @classmethod
207 | def _create_temporary_data_files(cls):
208 | non_pii_text = "\nentities identified in the input text. For each entity, the response provides the entity type, where the entity text begins and ends, and the level of confidence entities identified in the input text. For each entity, the response provides the entity type, where the entity text begins and ends, and the level of confidence that Amazon Comprehend has in the detectio"
209 |
210 | def _create_file(input_file_name, output_file_name, repeats=4077):
211 | with open(f"{cls.DATA_PATH}/{input_file_name}") as existing_file:
212 | modified_file_content = existing_file.read()
213 | for i in range(0, repeats):
214 | modified_file_content += non_pii_text
215 | with open(f"{cls.DATA_PATH}/{output_file_name}", 'w') as new_file:
216 | new_file.seek(0)
217 | new_file.write(modified_file_content)
218 |
219 | _create_file('pii_input.txt', '1mb_pii_text')
220 | _create_file('pii_input.txt', '5mb_pii_text', repeats=15000)
221 | _create_file('pii_output.txt', '1mb_pii_redacted_text')
222 |
223 | @classmethod
224 | def _clear_temporary_files(cls):
225 | os.remove(f"{cls.DATA_PATH}/1mb_pii_text")
226 | os.remove(f"{cls.DATA_PATH}/5mb_pii_text")
227 | os.remove(f"{cls.DATA_PATH}/1mb_pii_redacted_text")
228 |
229 | @classmethod
230 | def _upload_data(cls):
231 | for filename in os.listdir(cls.DATA_PATH):
232 | cls.s3_client.upload_file(f"{cls.DATA_PATH}/{filename}", cls.bucket_name, filename)
233 |
234 | @classmethod
235 | def setUpClass(cls):
236 | cls.s3 = boto3.resource('s3', region_name=cls.REGION_NAME)
237 | cls.s3_client = boto3.client('s3', region_name=cls.REGION_NAME)
238 | cls.s3_ctrl = boto3.client('s3control', region_name=cls.REGION_NAME)
239 | cls.lambda_client = boto3.client('lambda', region_name=cls.REGION_NAME)
240 | cls.cloudwatch_client = boto3.client('cloudwatch', region_name=cls.REGION_NAME)
241 | cls.iam = boto3.resource('iam')
242 | cls.account_id = cls.iam.CurrentUser().arn.split(':')[4]
243 | cls.iam_client = boto3.client('iam')
244 | test_run_id = str(uuid.uuid4())[0:8]
245 |
246 | cls.lambda_role_arn = cls._create_role(test_run_id)
247 | cls._create_bucket(test_run_id)
248 | cls._create_access_point(test_run_id)
249 | cls._create_temporary_data_files()
250 | cls._upload_data()
251 |
252 | @classmethod
253 | def tearDownClass(cls):
254 | try:
255 | delete_access_point_response = cls.s3_ctrl.delete_access_point(
256 | AccountId=cls.account_id,
257 | Name=cls.access_point_name
258 | )
259 | except Exception as e:
260 | logging.error(e)
261 | cls._clear_temporary_files()
262 | bucket = cls.s3.Bucket(cls.bucket_name)
263 | delete_object_response = bucket.object_versions.all().delete()
264 | delete_bucket_response = bucket.delete()
265 |
266 | delete_role_policy_response = cls.iam_client.delete_role_policy(
267 | RoleName=cls.role_name,
268 | PolicyName='S3OLFunctionPolicy'
269 | )
270 |
271 | detach_role_policy_response = cls.iam_client.detach_role_policy(
272 | RoleName=cls.role_name,
273 | PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
274 | )
275 |
276 | delete_role_response = cls.iam_client.delete_role(
277 | RoleName=cls.role_name
278 | )
279 |
280 | def _validate_pii_count_metric_published(self, s3ol_access_point_arn, start_time: datetime, entity_types):
281 | end_time = start_time + timedelta(minutes=1)
282 | st_time = start_time - timedelta(minutes=1)
283 |
284 | def _is_metric_published() -> bool:
285 | for e_type in entity_types:
286 | pii_doc_processed_count_metric = \
287 | self.cloudwatch_client.get_metric_data(MetricDataQueries=[{'Id': 'a1232454353',
288 | 'MetricStat': {'Metric': {
289 | 'MetricName': 'PiiDocumentTypesProcessed',
290 | 'Namespace': 'ComprehendS3ObjectLambda',
291 | 'Dimensions': [
292 | {'Name': 'PiiEntityType',
293 | 'Value': e_type},
294 | {'Name': 'Language',
295 | 'Value': 'en'},
296 | {'Name': 'S3ObjectLambdaAccessPoint',
297 | 'Value': s3ol_access_point_arn}]},
298 | 'Period': 60,
299 | 'Stat': 'Sum'}}],
300 | StartTime=st_time,
301 | EndTime=end_time)
302 | for result in pii_doc_processed_count_metric['MetricDataResults']:
303 | if result['Id'] == "a1232454353":
304 | if len(result['Values']) == 0:
305 | return False
306 | return True
307 |
308 | attempts = CW_METRIC_PUBLISH_CHECK_ATTEMPT
309 | while attempts > 0:
310 | if _is_metric_published():
311 | return None
312 | sleep(10)
313 | attempts -= 1
314 | assert False, f"No metrics published for s3ol arn {s3ol_access_point_arn} for one of the entity types: {entity_types}," \
315 | f"StartTime: {st_time}, EndTime: {end_time}"
316 |
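# The check above polls CloudWatch up to CW_METRIC_PUBLISH_CHECK_ATTEMPT (5)
# times with a 10-second sleep between attempts, so it waits at most ~50
# seconds for the metric to become visible before failing the test.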
317 | def _validate_api_call_latency_published(self, s3ol_access_point_arn, start_time: datetime, end_time=None,
318 | is_pii_classification_performed: bool = True,
319 | is_pii_detection_performed: bool = False,
320 | is_s3ol_callback_done: bool = True):
321 | end_time = start_time + timedelta(minutes=1) if not end_time else end_time
322 | st_time = start_time - timedelta(minutes=1)
323 |
324 | def _is_metric_published(api_name, service) -> bool:
325 | pii_doc_processed_count_metric = \
326 | self.cloudwatch_client.get_metric_data(MetricDataQueries=[{'Id': 'a1232454353',
327 | 'MetricStat': {'Metric': {
328 | 'MetricName': 'Latency',
329 | 'Namespace': 'ComprehendS3ObjectLambda',
330 | 'Dimensions': [
331 | {'Name': 'API',
332 | 'Value': api_name},
333 | {'Name': 'S3ObjectLambdaAccessPoint',
334 | 'Value': s3ol_access_point_arn},
335 | {'Name': 'Service',
336 | 'Value': service}
337 | ]},
338 | 'Period': 60,
339 | 'Stat': 'Average'}}],
340 | StartTime=st_time,
341 | EndTime=end_time)
342 | for result in pii_doc_processed_count_metric['MetricDataResults']:
343 | if result['Id'] == "a1232454353":
344 | if len(result['Values']) == 0:
345 | return False
346 | return True
347 |
348 | attempts = CW_METRIC_PUBLISH_CHECK_ATTEMPT
349 | metric_checked = {
350 | 'ContainsPiiEntities': False,
351 | 'DetectPiiEntities': False,
352 | 'WriteGetObjectResponse': False
353 | }
354 | while attempts > 0:
355 | if is_pii_classification_performed and not metric_checked['ContainsPiiEntities']:
356 | if _is_metric_published(api_name="ContainsPiiEntities", service="Comprehend"):
357 | metric_checked['ContainsPiiEntities'] = True
358 | if is_pii_detection_performed and not metric_checked['DetectPiiEntities']:
359 | if _is_metric_published(api_name="DetectPiiEntities", service="Comprehend"):
360 | metric_checked['DetectPiiEntities'] = True
361 | if is_s3ol_callback_done and not metric_checked['WriteGetObjectResponse']:
362 | if _is_metric_published(api_name="WriteGetObjectResponse", service="S3"):
363 | metric_checked['WriteGetObjectResponse'] = True
364 | sleep(10)
365 | attempts -= 1
366 | if is_pii_classification_performed:
367 | assert metric_checked['ContainsPiiEntities'], \
368 | f"Could not verify that metrics were published for various API call made between StartTime: {st_time}, EndTime: {end_time}"
369 | if is_pii_detection_performed:
370 | assert metric_checked['DetectPiiEntities'], \
371 | f"Could not verify that metrics were published for various API call made between StartTime: {st_time}, EndTime: {end_time}"
372 | if is_s3ol_callback_done:
373 | assert metric_checked['WriteGetObjectResponse'], \
374 | f"Could not verify that metrics were published for API call made between StartTime: {st_time}, EndTime: {end_time}"
375 |
--------------------------------------------------------------------------------