├── test ├── integ │ ├── __init__.py │ ├── test_access-control.py │ ├── test_redaction.py │ └── integ_base.py ├── load │ ├── __init__.py │ ├── access-control_load_test.py │ ├── redaction_load_test.py │ └── load_test_base.py ├── data │ ├── integ │ │ ├── clean.txt │ │ ├── RandomImage.png │ │ ├── pii_input.txt │ │ ├── pii_output.txt │ │ └── pii_bank_routing_redacted.txt │ └── sample_event.json └── unit │ ├── conftest.py │ ├── test_data_object.py │ ├── test_exceptions.py │ ├── test_util.py │ ├── test_cloudwatch_client.py │ ├── test_validators.py │ ├── test_exception_handlers.py │ ├── test_comprehend_client.py │ ├── test_s3_client.py │ └── test_processors.py ├── .pydocstyle ├── .flake8 ├── images ├── architecture_redaction.gif └── architecture_access_control.gif ├── src ├── lambdainit.py ├── lambdalogging.py ├── util.py ├── config.py ├── exceptions.py ├── data_object.py ├── validators.py ├── clients │ ├── cloudwatch_client.py │ ├── comprehend_client.py │ └── s3_client.py ├── exception_handlers.py ├── processors.py ├── handler.py └── constants.py ├── Pipfile ├── README.md ├── LICENSE ├── Makefile ├── .gitignore ├── access-control-template.yml ├── redaction-template.yml ├── ACCESS_CONTROL_README.md └── REDACTION_README.md /test/integ/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/load/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/data/integ/clean.txt: -------------------------------------------------------------------------------- 1 | This document contains no PII. -------------------------------------------------------------------------------- /.pydocstyle: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | ignore= D107,D203,D212,D205 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 140 3 | ignore = E126,W504 4 | 5 | -------------------------------------------------------------------------------- /images/architecture_redaction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/images/architecture_redaction.gif -------------------------------------------------------------------------------- /test/data/integ/RandomImage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/test/data/integ/RandomImage.png -------------------------------------------------------------------------------- /images/architecture_access_control.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-comprehend-s3-object-lambda-functions/HEAD/images/architecture_access_control.gif -------------------------------------------------------------------------------- /test/unit/conftest.py: -------------------------------------------------------------------------------- 1 | """Setup unit test environment.""" 2 | 3 | import sys 4 | import os 5 | 6 | # make sure tests can import the app code 7 | my_path = 
os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.insert(0, my_path + '/../../src/') 9 | 10 | -------------------------------------------------------------------------------- /src/lambdainit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Special initializations for Lambda functions. 3 | 4 | This file must be imported as the first import in any file containing a Lambda function handler method. 5 | """ 6 | 7 | import sys 8 | 9 | # add packaged dependencies to search path 10 | sys.path.append('lib') 11 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.python.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | boto3 = "*" 8 | requests = "*" 9 | 10 | [dev-packages] 11 | flake8 = "*" 12 | autopep8 = "*" 13 | pydocstyle = "*" 14 | cfn-lint = "*" 15 | awscli = "*" 16 | pytest = "*" 17 | pytest-mock = "*" 18 | pytest-cov = "*" 19 | 20 | [requires] 21 | python_version = "3.8" 22 | -------------------------------------------------------------------------------- /test/unit/test_data_object.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from data_object import PiiConfig 4 | from exceptions import InvalidConfigurationException 5 | 6 | 7 | class DataObjectTest(TestCase): 8 | 9 | def test_pii_config_invalid_confidence_threshold(self): 10 | with self.assertRaises(InvalidConfigurationException) as e: 11 | PiiConfig(confidence_threshold=0.1) 12 | assert e.exception.message == 'CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]' -------------------------------------------------------------------------------- /test/data/integ/pii_input.txt: -------------------------------------------------------------------------------- 1 | Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account XXXXXX1111 with the routing number XXXXX0000. 2 | 3 | Your latest statement was mailed to 100 Main Street, Anytown, WA 98121. 4 | 5 | After your payment is received, you will receive a confirmation text message at 206-555-0100. 6 | 7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at 206-555-0199 or email at support@anycompany.com. -------------------------------------------------------------------------------- /test/data/integ/pii_output.txt: -------------------------------------------------------------------------------- 1 | Hello *********. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53 that is due by *********. Based on your autopay settings, we will withdraw your payment on the due date from your bank account ********** with the routing number *********. 2 | 3 | Your latest statement was mailed to **********************************. 4 | 5 | After your payment is received, you will receive a confirmation text message at ************. 6 | 7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at ************ or email at **********************. 
-------------------------------------------------------------------------------- /test/data/integ/pii_bank_routing_redacted.txt: -------------------------------------------------------------------------------- 1 | Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account ********** with the routing number *********. 2 | 3 | Your latest statement was mailed to 100 Main Street, Anytown, WA 98121. 4 | 5 | After your payment is received, you will receive a confirmation text message at 206-555-0100. 6 | 7 | If you have questions about your bill, AnyCompany Customer Service is available by phone at 206-555-0199 or email at **********************. -------------------------------------------------------------------------------- /test/unit/test_exceptions.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from constants import CONTENT_LENGTH, RANGE 4 | from exceptions import UnsupportedFileException 5 | 6 | 7 | class ExceptionsTest(TestCase): 8 | def test_unsupported_file_format_exception(self): 9 | try: 10 | http_headers = {CONTENT_LENGTH: 1234, RANGE: "0-123"} 11 | raise UnsupportedFileException("Some Random blob".encode('utf-8'), http_headers, "Unsupported file") 12 | except UnsupportedFileException as exception: 13 | assert exception.file_content.decode('utf-8') == "Some Random blob" 14 | assert str(exception) == "Unsupported file" 15 | assert exception.http_headers == http_headers 16 | -------------------------------------------------------------------------------- /src/lambdalogging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lambda logging helper. 3 | 4 | Returns a Logger with log level set based on env variables. 5 | """ 6 | 7 | import logging 8 | 9 | from config import LOG_LEVEL 10 | 11 | # translate log level from string to numeric value 12 | LOG_LEVEL = getattr(logging, LOG_LEVEL) if hasattr(logging, LOG_LEVEL) else logging.DEBUG 13 | 14 | # setup logging levels for botocore 15 | logging.getLogger('botocore.endpoint').setLevel(LOG_LEVEL) 16 | logging.getLogger('botocore.retryhandler').setLevel(LOG_LEVEL) 17 | logging.getLogger('botocore.parsers').setLevel(LOG_LEVEL) 18 | 19 | 20 | def getLogger(name): 21 | """Return a logger configured based on env variables.""" 22 | logger = logging.getLogger(name) 23 | # in lambda environment, logging config has already been setup so can't use logging.basicConfig to change log level 24 | logger.setLevel(LOG_LEVEL) 25 | return logger 26 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | """Utility Class.""" 2 | import concurrent 3 | from concurrent.futures._base import TimeoutError 4 | 5 | import lambdalogging 6 | from exceptions import TimeoutException 7 | 8 | LOG = lambdalogging.getLogger(__name__) 9 | 10 | 11 | def execute_task_with_timeout(timeout_in_millis, task): 12 | """ 13 | Execute a given task within a given time limit. 
14 | :param timeout_in_millis: timeout in milliseconds 15 | :param task: task to execute 16 | :raise: TimeoutException 17 | """ 18 | timeout_in_sec = int(timeout_in_millis / 1000) 19 | try: 20 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) 21 | future_result = executor.submit(task) 22 | return future_result.result(timeout=timeout_in_sec) 23 | except TimeoutError: 24 | # Free up the resources 25 | executor.shutdown(wait=False) 26 | raise TimeoutException() 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comprehend S3 Object Lambda functions 2 | 3 | This package contains the source code for Lambda functions published by AWS Comprehend that can be used with [S3 Object Lambda Access Points](https://docs.aws.amazon.com/AmazonS3/latest/userguide/transforming-objects.html). 4 | These Lambda functions can be deployed from the AWS Serverless Application Repository. 5 | 6 | Refer to the README for each specific Lambda for more details on its functionality. 7 | * [PII Access Control Lambda](ACCESS_CONTROL_README.md) - This Lambda function helps you control access to files in S3 that contain PII (Personally Identifiable Information). 8 | * [PII Redaction Lambda](REDACTION_README.md) - This Lambda function helps you redact PII (Personally Identifiable Information) from text files stored in S3. 9 | 10 | 11 | ## License Summary 12 | 13 | This code is made available under the MIT-0 license. See the LICENSE file. 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 4 | software and associated documentation files (the "Software"), to deal in the Software 5 | without restriction, including without limitation the rights to use, copy, modify, 6 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 7 | permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 10 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 11 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 12 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 13 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 14 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /test/unit/test_util.py: -------------------------------------------------------------------------------- 1 | import time 2 | from time import sleep 3 | from unittest import TestCase 4 | 5 | from exceptions import TimeoutException, FileSizeLimitExceededException 6 | from util import execute_task_with_timeout 7 | 8 | 9 | class UtilTest(TestCase): 10 | 11 | def test_execute_task_with_timeout_with_time_limit_exceeded(self): 12 | with self.assertRaises(TimeoutException) as e: 13 | start_time = time.time() 14 | 15 | def sleep_5(): 16 | sleep(5) 17 | 18 | execute_task_with_timeout(1000, sleep_5) 19 | elapsed_time = time.time() - start_time 20 | assert 1000 <= elapsed_time * 1000 <= 1100 21 | 22 | def test_execute_task_with_timeout_time_limit_not_exceeded(self): 23 | start_time = time.time() 24 | 25 | def sleep_1(): 26 | sleep(1) 27 | 28 | execute_task_with_timeout(2000, sleep_1) 29 | elapsed_time = time.time() - start_time 30 | assert 1000 <= elapsed_time * 1000 <= 1100 31 | 32 | def test_execute_task_with_timeout_when_task_fails(self): 33 | def task(): 34 | raise FileSizeLimitExceededException() 35 | 36 | with self.assertRaises(FileSizeLimitExceededException) as e: 37 | execute_task_with_timeout(2000, task) 38 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """Contain the configurations used in the package.""" 2 | import os 3 | 4 | from constants import UNSUPPORTED_FILE_HANDLING_VALID_VALUES, MASK_MODE_VALID_VALUES 5 | 6 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES = int(os.getenv('DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES', 50 * 1000)) # 50 KB 7 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES = int(os.getenv('DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES', 5 * 1000)) # 5 KB 8 | PII_ENTITY_TYPES = str(os.getenv('PII_ENTITY_TYPES', 'ALL')) 9 | IS_PARTIAL_OBJECT_SUPPORTED = str(os.getenv('IS_PARTIAL_OBJECT_SUPPORTED', 'false')).lower() == 'true' 10 | MASK_CHARACTER = str(os.getenv('MASK_CHARACTER', '*')) 11 | MASK_MODE = MASK_MODE_VALID_VALUES[os.getenv('MASK_MODE', MASK_MODE_VALID_VALUES.MASK.name)] 12 | SUBSEGMENT_OVERLAPPING_TOKENS = int(os.getenv('SUBSEGMENT_OVERLAPPING_TOKENS', 20)) 13 | DOCUMENT_MAX_SIZE = int(os.getenv('DOCUMENT_MAX_SIZE', 100 * 1024)) 14 | CONFIDENCE_THRESHOLD = float(os.getenv('CONFIDENCE_THRESHOLD', 0.5)) 15 | assert 0.5 <= CONFIDENCE_THRESHOLD <= 1.0, "CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]" 16 | MAX_CHARS_OVERLAP = int(os.getenv('MAX_CHARS_OVERLAP', 200)) 17 | DEFAULT_LANGUAGE_CODE = str(os.getenv('DEFAULT_LANGUAGE_CODE', 'en')) 18 | REDACTION_API_ONLY = os.getenv('REDACTION_API_ONLY', 'false').lower() == 'true' 19 | 20 | UNSUPPORTED_FILE_HANDLING = UNSUPPORTED_FILE_HANDLING_VALID_VALUES[ 21 | os.getenv('UNSUPPORTED_FILE_HANDLING', UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL.name)] 22 | 23 | DETECT_PII_ENTITIES_THREAD_COUNT = int(os.getenv('DETECT_PII_ENTITIES_THREAD_COUNT', 8)) 24 | CONTAINS_PII_ENTITIES_THREAD_COUNT = int(os.getenv('CONTAINS_PII_ENTITIES_THREAD_COUNT', 20)) 25 | PUBLISH_CLOUD_WATCH_METRICS = os.getenv('PUBLISH_CLOUD_WATCH_METRICS', 'true').lower() == 'true' 26 | COMPREHEND_ENDPOINT_URL = None if os.getenv('COMPREHEND_ENDPOINT_URL', '') == '' else os.getenv('COMPREHEND_ENDPOINT_URL') 27 | 28 | LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO') 29 | -------------------------------------------------------------------------------- 
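Note: src/config.py above resolves every setting from Lambda environment variables at import time, so behavior can be tuned per deployment without code changes. A minimal sketch of overriding these settings for local testing, assuming it runs from the repo root with src/ on the import path (the override values are illustrative, not recommended defaults):

    import os
    import sys

    sys.path.append('src')

    # config.py reads the environment when it is first imported, so set overrides before importing it.
    os.environ['CONFIDENCE_THRESHOLD'] = '0.8'  # must stay within [0.5, 1.0] or the assert in config.py fails
    os.environ['MASK_CHARACTER'] = '#'
    os.environ['DOCUMENT_MAX_SIZE'] = str(200 * 1024)

    import config

    assert config.CONFIDENCE_THRESHOLD == 0.8
    assert config.MASK_CHARACTER == '#'
    assert config.DOCUMENT_MAX_SIZE == 204800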
/test/load/access-control_load_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import uuid 4 | 5 | from botocore.exceptions import ClientError 6 | 7 | from load.load_test_base import BaseLoadTest 8 | 9 | 10 | class PiiAccessControlLoadTest(BaseLoadTest): 11 | BUILD_DIR = '.aws-sam/build/PiiAccessControlFunction/' 12 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME'] 13 | 14 | @classmethod 15 | def setUpClass(cls) -> None: 16 | super().setUpClass() 17 | test_run_id = str(uuid.uuid4())[0:8] 18 | cls.lambda_function_arn = cls._create_function("pii_access_control_load_test", 'handler.pii_access_control_handler', 19 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR) 20 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn) 21 | cls._update_lambda_env_variables(cls.lambda_function_arn, {"LOG_LEVEL": "DEBUG", 22 | "DOCUMENT_MAX_SIZE": "1500000"}) 23 | logging.info(f"Created access point: {cls.s3ol_access_point_arn} for testing") 24 | 25 | @classmethod 26 | def tearDownClass(cls) -> None: 27 | cls.s3_ctrl.delete_access_point_for_object_lambda( 28 | AccountId=cls.account_id, 29 | Name=cls.s3ol_access_point_arn.split('/')[-1] 30 | ) 31 | 32 | cls.lambda_client.delete_function( 33 | FunctionName=cls.lambda_function_arn.split(':')[-1] 34 | ) 35 | super().tearDownClass() 36 | 37 | def tearDown(self) -> None: 38 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"}) 39 | 40 | def test_access_control_lambda_with_varying_load(self): 41 | document_sizes = [1, 5, 50, 1000, 1500] 42 | for size in document_sizes: 43 | self.find_max_tpm(self.s3ol_access_point_arn, size, True, ClientError) 44 | time.sleep(90) 45 | -------------------------------------------------------------------------------- /src/exceptions.py: -------------------------------------------------------------------------------- 1 | """Custom defined exceptions.""" 2 | 3 | 4 | class CustomException(Exception): 5 | """Exceptions generated when the business constraints of this Lambda are not met.""" 6 | 7 | 8 | class UnsupportedFileException(CustomException): 9 | """Exception generated when we encounter an unsupported file format, e.g. an image file.""" 10 | 11 | def __init__(self, file_content, http_headers, *args, **kwargs): 12 | super().__init__(*args) 13 | self.file_content = file_content 14 | self.http_headers = http_headers 15 | 16 | 17 | class FileSizeLimitExceededException(CustomException): 18 | """ 19 | Exception representing file size beyond the supported limits. 20 | Files beyond this size are likely to take too long to process, causing timeouts. 
21 | """ 22 | 23 | pass 24 | 25 | 26 | class TimeoutException(CustomException): 27 | """Exception raised when some task is not able to complete within a certain time limit.""" 28 | 29 | pass 30 | 31 | 32 | class InvalidConfigurationException(CustomException): 33 | """Exception representing an incorrect configuration of the access point such as incorrect function payload structure.""" 34 | 35 | def __init__(self, message, *args, **kwargs): 36 | super().__init__(*args) 37 | self.message = message 38 | 39 | 40 | class InvalidRequestException(CustomException): 41 | """Exception representing an invalid request.""" 42 | 43 | def __init__(self, message, *args, **kwargs): 44 | super().__init__(*args) 45 | self.message = message 46 | 47 | 48 | class S3DownloadException(CustomException): 49 | """Exception representing an error occurring during downloading from the presigned url.""" 50 | 51 | def __init__(self, s3_error_code, s3_message, *args, **kwargs): 52 | super().__init__(*args) 53 | self.s3_error_code = s3_error_code 54 | self.s3_message = s3_message 55 | 56 | 57 | class RestrictedDocumentException(CustomException): 58 | """Exception representing a restricted document throw when it contains pii.""" 59 | -------------------------------------------------------------------------------- /test/data/sample_event.json: -------------------------------------------------------------------------------- 1 | { 2 | "xAmzRequestId": "FEDCBA0987654321", 3 | "getObjectContext": { 4 | "inputS3Url": "https://pii-document-for-banner.s3.amazonaws.com/SomeText?AWSAccessKeyId=ASIAZI7HWWFYYJMPESP7&Signature=x%2B7jhKr7N2e%2FdNAitb%2F3RrtDJ2o%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEFYaCXVzLWVhc3QtMSJGMEQCIACRXcrjeRvpIhWgaOOmBAv8FDCjdLjIQtlmG4ZvrvLAiA0h%2B%2Bkxk0tHuHNWl0cXMJ3oIpK3PNRWNk9IcstgK8ppyqnAgiv%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAEaDDYzNzc1MTcwMTg3MyIMoxQ1S3IApjAJO6UTKvsBESuzIXwqKlu64hEGDz%2Bmw5oSoPxAvjR1cmbLCf2SYIdIdEBQpuMqHaXFGjHFsW4%2BcppiYXZlxcqOG8jZn79zd%2FvhHzTzB2vGjD2ooK7bPk9D3hx4voJzsPpNdsIPIba4DwYZlNbBBpCIhHcrGdKuhZ6UlI4wbGMA%2FE646mLQXK5H6pjlrZmKbaMlS9%2BdPQagSkqYioLN%2FWmya4rMag9h06IOyhP%2BLrQqldLLblod1otzfjf1Klc4KWJPm2StzTR85MI2zN2DHQ0WezZn6CeT80rM4te2Vyc0nW%2BJaA1msH2A0Yg57JEJx0ZzTifvfCXDVjds4bXYKVNUCBMwy%2FqB%2FQU6ngF0xA7fDEkIsHRFluXzDsESrlomBiSOD%2FayvNvgWHVXA5YW4g0hODrLE%2BZtXr5RhdMu7BSe6HQvrfm4pLfEDZczu9A9%2Fwc9V27vGkaJQHeNFO%2FBALDFHn4mkGuS2OVlh9uOF78XqrSKkmRB4B4RIE7dp54ODaoGaTp9RVKnTnr8K3KF5i8s%2Fl6T5ukl%2FU29ja6EZ6607ea00YIdln81gA%3D%3D&Expires=1604364707", 5 | "outputRoute": "io-a1c1d6c7", 6 | "outputToken": "ReallyLongEncryptedToken" 7 | }, 8 | "configuration": { 9 | "accessPointArn": "arn:aws:s3-object-lambda:us-east-1:111222333444:accesspoint/my-banner-ap", 10 | "supportingAccessPointArn": "arn:aws:s3:us-east-1:123456789012:accesspoint/existing-ap", 11 | "payload": "{\"pii_entity_types\" : [\"ALL\",\"CREDIT_DEBIT_NUMBER\"],\"mask_mode\":\"MASK\", \"mask_character\" : \"*\",\"confidence_threshold\":0.6,\"language_code\":\"en\"}" 12 | }, 13 | "userRequest": { 14 | "url": "https://my-banner-ap-111222333444.s3-banner.us-east-1.amazonaws.com/foo?QueryParam=2", 15 | "headers": { 16 | "Content-type": "application/txt", 17 | "CustomHeader": "ClientSpecifiedValue" 18 | } 19 | }, 20 | "userIdentity": { 21 | "accountId": "111222333444", 22 | "principalId": "AIDAJ45Q7YFFAREXAMPLE", 23 | "arn": "arn:aws:iam::111222333444:user/Alice", 24 | "groups": "Finance,Users" 25 | }, 26 | "protocolVersion": "1.00" 27 | } 28 | -------------------------------------------------------------------------------- 
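Note: the configuration.payload field in test/data/sample_event.json above is itself a JSON-encoded string set on the S3 Object Lambda access point, so a handler has to decode it before building a config object. A rough sketch of that decoding step, assuming it runs from the repo root and uses the RedactionConfig class defined in src/data_object.py just below:

    import json
    import sys

    sys.path.append('src')  # make src/ importable, mirroring test/unit/conftest.py

    from data_object import RedactionConfig

    # Load the sample event fixture shown above.
    with open('test/data/sample_event.json') as f:
        event = json.load(f)

    # The payload arrives as a JSON string, not a parsed object.
    payload = json.loads(event['configuration']['payload'])

    # RedactionConfig's keyword arguments line up with the payload keys; extra keys
    # such as language_code are absorbed by **kwargs.
    config = RedactionConfig(**payload)
    print(config.mask_mode, config.mask_character, config.confidence_threshold)
    # -> MASK * 0.6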
/src/data_object.py: -------------------------------------------------------------------------------- 1 | """Module containing some custom data structures.""" 2 | 3 | import os 4 | from typing import List 5 | 6 | from exceptions import InvalidConfigurationException 7 | 8 | 9 | class PiiConfig: 10 | """PiiConfig class represents the base config for classification and redaction.""" 11 | 12 | def __init__(self, pii_entity_types: List = None, 13 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5), 14 | **kwargs): 15 | self.pii_entity_types = pii_entity_types 16 | self.confidence_threshold = float(confidence_threshold) 17 | if not 0.5 <= self.confidence_threshold <= 1.0: 18 | raise InvalidConfigurationException('CONFIDENCE_THRESHOLD is not within allowed range [0.5,1]') 19 | if self.pii_entity_types is None: 20 | self.pii_entity_types = os.getenv('PII_ENTITY_TYPES', 'ALL').split(',') 21 | 22 | 23 | class ClassificationConfig(PiiConfig): 24 | """ClassificationConfig class represents the config to be used for classification.""" 25 | 26 | def __init__(self, pii_entity_types: List = None, 27 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5), 28 | **kwargs): 29 | super().__init__(pii_entity_types, confidence_threshold, **kwargs) 30 | 31 | 32 | class RedactionConfig(ClassificationConfig): 33 | """RedactionConfig class represents the config to be used for redaction.""" 34 | 35 | def __init__(self, pii_entity_types: List = None, mask_mode: str = os.getenv('MASK_MODE', 'MASK'), 36 | mask_character: str = os.getenv('MASK_CHARACTER', '*'), 37 | confidence_threshold: float = os.getenv('CONFIDENCE_THRESHOLD', 0.5), 38 | **kwargs): 39 | super().__init__(pii_entity_types, confidence_threshold, **kwargs) 40 | self.mask_character = mask_character 41 | self.mask_mode = mask_mode 42 | 43 | 44 | class Document: 45 | """A chunk of text.""" 46 | 47 | def __init__(self, text: str, char_offset: int = 0, pii_classification: dict = None, 48 | pii_entities: List = None, redacted_text: str = ''): 49 | self.text = text 50 | self.char_offset = char_offset 51 | self.pii_classification = {} if pii_classification is None else pii_classification # avoid mutable default arguments shared across instances 52 | self.pii_entities = [] if pii_entities is None else pii_entities 53 | self.redacted_text = redacted_text 54 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/sh 2 | PY_VERSION := 3.8 3 | 4 | export PYTHONUNBUFFERED := 1 5 | 6 | SRC_DIR := src 7 | SAM_DIR := .aws-sam 8 | 9 | # Required environment variables (user must override) 10 | 11 | # S3 bucket used for packaging SAM templates 12 | PACKAGE_BUCKET ?= comprehend-s3-object-lambdas 13 | 14 | # user can optionally override the following by setting environment variables with the same names before running make 15 | 16 | # Path to system pip 17 | PIP ?= pip 18 | # Default AWS CLI region 19 | AWS_DEFAULT_REGION ?= us-west-2 20 | 21 | PYTHON := $(shell /usr/bin/which python$(PY_VERSION)) 22 | 23 | .DEFAULT_GOAL := build 24 | 25 | clean: 26 | rm -f $(SRC_DIR)/requirements.txt 27 | rm -rf $(SAM_DIR) 28 | 29 | # used once just after project creation to lock and install dependencies 30 | bootstrap: 31 | $(PYTHON) -m $(PIP) install pipenv 32 | pipenv lock 33 | pipenv sync --dev 34 | 35 | # used by CI build to install dependencies 36 | init: 37 | $(PYTHON) -m $(PIP) install aws-sam-cli 38 | $(PYTHON) -m $(PIP) install pipenv 39 | pipenv sync --dev 40 | pipenv lock --requirements > $(SRC_DIR)/requirements.txt 41 | 42 | build: 43 | pipenv 
run flake8 $(SRC_DIR) 44 | pipenv run pydocstyle $(SRC_DIR) 45 | pipenv run cfn-lint $(LAMBDA_NAME)-template.yml 46 | sam build --profile sar-account --template $(LAMBDA_NAME)-template.yml 47 | mv $(SAM_DIR)/build/template.yaml $(SAM_DIR)/build/$(LAMBDA_NAME)-template.yml 48 | 49 | unit-testing: build 50 | pipenv run py.test --cov=$(SRC_DIR) --cov-fail-under=97 -vv test/unit -s --cov-report html 51 | 52 | # can be triggered as `make integ-testing LAMBDA_NAME=access-control` 53 | integ-testing: unit-testing 54 | pipenv run py.test -s -vv test/integ/test_$(LAMBDA_NAME).py 55 | 56 | load-testing: 57 | pipenv run py.test -s -vv test/load/$(LAMBDA_NAME)_load_test.py --log-cli-level=INFO 58 | 59 | package: 60 | sam package --region us-east-1 --profile sar-account --template $(SAM_DIR)/build/$(LAMBDA_NAME)-template.yml --s3-bucket $(PACKAGE_BUCKET) --output-template-file $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml 61 | 62 | deploy: package 63 | sam deploy --profile sar-account --region us-east-1 --template-file $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml --capabilities CAPABILITY_IAM --stack-name $(LAMBDA_NAME)-lambda 64 | 65 | publish: package 66 | sam publish --region us-east-1 --template $(SAM_DIR)/packaged-$(LAMBDA_NAME)-template.yml --profile sar-account 67 | -------------------------------------------------------------------------------- /test/unit/test_cloudwatch_client.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from unittest.mock import patch, MagicMock 3 | 4 | from clients.cloudwatch_client import CloudWatchClient 5 | 6 | S3OL_ACCESS_POINT_TEST = "arn:aws:s3-object-lambda:us-east-1:000000000000:accesspoint/myPiiAp" 7 | 8 | 9 | class CloudWatchClientTest(TestCase): 10 | @patch('clients.cloudwatch_client.boto3') 11 | def test_cloudwatch_client_put_pii_document_processed_metric(self, mocked_boto3): 12 | mocked_client = MagicMock() 13 | mocked_boto3.client.return_value = mocked_client 14 | 15 | cloudwatch = CloudWatchClient() 16 | cloudwatch.put_pii_document_processed_metric('en', S3OL_ACCESS_POINT_TEST) 17 | 18 | mocked_client.put_metric_data.assert_called_once() 19 | 20 | @patch('clients.cloudwatch_client.boto3') 21 | def test_cloudwatch_client_put_document_processed_metric(self, mocked_boto3): 22 | mocked_client = MagicMock() 23 | mocked_boto3.client.return_value = mocked_client 24 | 25 | cloudwatch = CloudWatchClient() 26 | cloudwatch.put_document_processed_metric('en', S3OL_ACCESS_POINT_TEST) 27 | 28 | mocked_client.put_metric_data.assert_called_once() 29 | 30 | @patch('clients.cloudwatch_client.boto3') 31 | def test_cloudwatch_client_put_pii_document_types_metric(self, mocked_boto3): 32 | mocked_client = MagicMock() 33 | mocked_boto3.client.return_value = mocked_client 34 | 35 | cloudwatch = CloudWatchClient() 36 | cloudwatch.put_pii_document_types_metric(['SSN', 'PHONE'], 'en', S3OL_ACCESS_POINT_TEST) 37 | 38 | mocked_client.put_metric_data.assert_called_once() 39 | 40 | def test_segment_metrics_segmentation_required(self): 41 | cloudwatch = CloudWatchClient() 42 | metric_list = [i for i in range(0, 100)] 43 | chunks = cloudwatch.segment_metric_data(metric_list) 44 | total_metrics = 0 45 | for chunk in chunks: 46 | assert len(chunk) <= cloudwatch.MAX_METRIC_DATA 47 | total_metrics += len(chunk) 48 | assert total_metrics == 100 49 | 50 | def test_segment_metrics_segmentation_not_required(self): 51 | cloudwatch = CloudWatchClient() 52 | metric_list = [i for i in range(0, 10)] 53 | chunks = 
cloudwatch.segment_metric_data(metric_list) 54 | total_metrics = 0 55 | for chunk in chunks: 56 | assert len(chunk) <= cloudwatch.MAX_METRIC_DATA 57 | total_metrics += len(chunk) 58 | assert total_metrics == 10 59 | -------------------------------------------------------------------------------- /src/validators.py: -------------------------------------------------------------------------------- 1 | """Represents the different validations we would use.""" 2 | import json 3 | 4 | from config import IS_PARTIAL_OBJECT_SUPPORTED 5 | from constants import REQUEST_ID, REQUEST_TOKEN, REQUEST_ROUTE, GET_OBJECT_CONTEXT, S3OL_CONFIGURATION, INPUT_S3_URL, PAYLOAD, \ 6 | PART_NUMBER, RANGE, USER_REQUEST, HEADERS 7 | from exceptions import InvalidConfigurationException, InvalidRequestException 8 | 9 | 10 | class Validator: 11 | """Generic validator class representing one container which does a particular type of validation.""" 12 | 13 | @staticmethod 14 | def validate(obj): 15 | """Execute the validation.""" 16 | raise NotImplementedError 17 | 18 | 19 | class JsonValidator(Validator): 20 | """Validates that the given string can be parsed as JSON.""" 21 | 22 | @staticmethod 23 | def validate(json_string: str): 24 | """Validate that the given string can be parsed as JSON.""" 25 | try: 26 | 27 | json.loads(json_string) 28 | except ValueError: 29 | raise Exception(f"Invalid Json: {json_string}") 30 | 31 | 32 | class PartialObjectRequestValidator(Validator): 33 | """Validates that the GetObject request is not for a partial object.""" 34 | 35 | @staticmethod 36 | def validate(input_event: str): 37 | """Perform the validation.""" 38 | RESTRICTED_HEADERS = [RANGE, PART_NUMBER] 39 | if not IS_PARTIAL_OBJECT_SUPPORTED: 40 | if HEADERS in input_event[USER_REQUEST]: 41 | for header in input_event[USER_REQUEST][HEADERS]: 42 | if header in RESTRICTED_HEADERS: 43 | raise InvalidRequestException(f"HTTP Header {header} is not supported") 44 | 45 | 46 | class InputEventValidator(Validator): 47 | """Validate the main lambda input.""" 48 | 49 | @staticmethod 50 | def validate(event): 51 | """Validate the main lambda input.""" 52 | # validations on parts of the event set by S3 control 53 | assert S3OL_CONFIGURATION in event 54 | assert GET_OBJECT_CONTEXT in event 55 | assert REQUEST_TOKEN in event[GET_OBJECT_CONTEXT] 56 | assert REQUEST_ROUTE in event[GET_OBJECT_CONTEXT] 57 | assert REQUEST_ID in event 58 | assert INPUT_S3_URL in event[GET_OBJECT_CONTEXT] 59 | assert PAYLOAD in event[S3OL_CONFIGURATION] 60 | 61 | # parts of the event derived from access point configuration 62 | try: 63 | if event[S3OL_CONFIGURATION][PAYLOAD]: 64 | JsonValidator.validate(event[S3OL_CONFIGURATION][PAYLOAD]) 65 | except Exception: 66 | raise InvalidConfigurationException(f"Invalid function payload: {event[S3OL_CONFIGURATION][PAYLOAD]}") 67 | -------------------------------------------------------------------------------- /test/load/redaction_load_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import uuid 4 | 5 | from load.load_test_base import BaseLoadTest 6 | 7 | 8 | class PiiRedactionLoadTest(BaseLoadTest): 9 | BUILD_DIR = '.aws-sam/build/PiiRedactionFunction/' 10 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME'] 11 | 12 | @classmethod 13 | def setUpClass(cls) -> None: 14 | super().setUpClass() 15 | test_run_id = str(uuid.uuid4())[0:8] 16 | cls.lambda_function_arn = 
cls._create_function("pii_redaction_load_test", 'handler.redact_pii_documents_handler', 17 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR) 18 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn) 19 | cls._update_lambda_env_variables(cls.lambda_function_arn, {"LOG_LEVEL": "DEBUG", 20 | "DOCUMENT_MAX_SIZE": "1600000"}) 21 | logging.info(f"Created access point: {cls.s3ol_access_point_arn} for testing") 22 | 23 | @classmethod 24 | def tearDownClass(cls) -> None: 25 | cls.s3_ctrl.delete_access_point_for_object_lambda( 26 | AccountId=cls.account_id, 27 | Name=cls.s3ol_access_point_arn.split('/')[-1] 28 | ) 29 | 30 | cls.lambda_client.delete_function( 31 | FunctionName=cls.lambda_function_arn.split(':')[-1] 32 | ) 33 | super().tearDownClass() 34 | 35 | def tearDown(self) -> None: 36 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"}) 37 | 38 | def test_redaction_lambda_with_varying_load(self): 39 | variations = [(1, True), 40 | (5, False), 41 | (5, True), 42 | (50, False), 43 | (50, True), 44 | (1000, False), 45 | (1000, True), 46 | (1500, False), 47 | (1500, True) 48 | ] 49 | for text_size, is_pii in variations: 50 | self.execute_load_on_get_object(self.s3ol_access_point_arn, text_size, is_pii, None) 51 | time.sleep(60) 52 | 53 | def test_redaction_lambda_find_max_conn(self): 54 | variations = [ 55 | (1, True), 56 | (5, False), 57 | (5, True), 58 | (50, False), 59 | (50, True), 60 | (1000, False), 61 | (1000, True), 62 | (1500, False), 63 | (1500, True) 64 | ] 65 | for text_size, is_pii in variations: 66 | self.find_max_tpm(self.s3ol_access_point_arn, text_size, is_pii, None) 67 | -------------------------------------------------------------------------------- /src/clients/cloudwatch_client.py: -------------------------------------------------------------------------------- 1 | """Client wrapper over aws services.""" 2 | 3 | from typing import List 4 | 5 | import boto3 6 | 7 | import lambdalogging 8 | from constants import CLOUD_WATCH_NAMESPACE, LANGUAGE, COUNT, PII_DOCUMENTS_PROCESSED, DOCUMENTS_PROCESSED, NAME, \ 9 | VALUE, S3OL_ACCESS_POINT, METRIC_NAME, UNIT, DIMENSIONS, PII_DOCUMENT_TYPES_PROCESSED, PII_ENTITY_TYPE, \ 10 | LATENCY, API, SERVICE, ERROR_COUNT, MILLISECONDS 11 | 12 | LOG = lambdalogging.getLogger(__name__) 13 | 14 | 15 | class Metrics: 16 | """Metrics class for latency and fault counts.""" 17 | 18 | def __init__(self, service_name, api, s3ol_access_point, cloudwatch_namespace=CLOUD_WATCH_NAMESPACE): 19 | self.cloudwatch_namespace = cloudwatch_namespace 20 | self.service_name = service_name 21 | self.s3ol_access_point_arn = s3ol_access_point 22 | self.api = api 23 | self.metrics = [] 24 | 25 | def add_latency(self, start_time: float, end_time: float): 26 | """Add a latency metric.""" 27 | self.metrics.append({METRIC_NAME: LATENCY, DIMENSIONS: [ 28 | {NAME: API, VALUE: self.api}, 29 | {NAME: S3OL_ACCESS_POINT, VALUE: self.s3ol_access_point_arn}, 30 | {NAME: SERVICE, VALUE: self.service_name} 31 | ], UNIT: MILLISECONDS, VALUE: (end_time - start_time) * 1000}) 32 | 33 | def add_fault_count(self, count: int = 1): 34 | """Add a fault count metric.""" 35 | self.metrics.append({METRIC_NAME: ERROR_COUNT, DIMENSIONS: [ 36 | {NAME: API, VALUE: self.api}, 37 | {NAME: S3OL_ACCESS_POINT, VALUE: self.s3ol_access_point_arn}, 38 | {NAME: SERVICE, VALUE: self.service_name} 39 | ], UNIT: COUNT, VALUE: count}) 40 | 41 | 42 | class CloudWatchClient: 43 | """Wrapper over cloudwatch client.""" 
44 | 45 | MAX_METRIC_DATA = 15 46 | 47 | def __init__(self): 48 | self.cloudwatch = boto3.client('cloudwatch') 49 | 50 | def segment_metric_data(self, metric_list: List): 51 | """Segment a list of arbitrary length into chunks of at most MAX_METRIC_DATA metrics.""" 52 | list_len = len(metric_list) 53 | if list_len <= self.MAX_METRIC_DATA: 54 | return [metric_list] 55 | # Slicing in fixed steps leaves any remainder in the final chunk; the previous remainder 56 | # arithmetic appended the whole list again when list_len was an exact multiple of MAX_METRIC_DATA. 57 | chunks = [metric_list[x:x + self.MAX_METRIC_DATA] for x in range(0, list_len, self.MAX_METRIC_DATA)] 58 | return chunks 59 | 60 | def publish_metrics(self, metric_list: List): 61 | """Publish the metrics to CloudWatch.""" 62 | for metrics in self.segment_metric_data(metric_list): 63 | self.cloudwatch.put_metric_data(MetricData=metrics, Namespace=CLOUD_WATCH_NAMESPACE) 64 | 65 | def put_pii_document_processed_metric(self, language: str, s3ol_access_point: str): 66 | """Put PiiDocumentsProcessed metric.""" 67 | self.publish_metrics([{METRIC_NAME: PII_DOCUMENTS_PROCESSED, DIMENSIONS: [ 68 | {NAME: LANGUAGE, VALUE: language}, 69 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point} 70 | ], UNIT: COUNT, VALUE: 1.0}]) 71 | 72 | def put_document_processed_metric(self, language: str, s3ol_access_point: str): 73 | """Put DocumentsProcessed metric.""" 74 | self.publish_metrics([{METRIC_NAME: DOCUMENTS_PROCESSED, DIMENSIONS: [ 75 | {NAME: LANGUAGE, VALUE: language}, 76 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point} 77 | ], UNIT: COUNT, VALUE: 1.0}]) 78 | 79 | def put_pii_document_types_metric(self, pii_entity_types: List[str], language: str, s3ol_access_point: str): 80 | """Put PiiDocumentTypesProcessed metric.""" 81 | self.publish_metrics([{METRIC_NAME: PII_DOCUMENT_TYPES_PROCESSED, DIMENSIONS: [ 82 | {NAME: PII_ENTITY_TYPE, VALUE: pii_entity_type}, 83 | {NAME: S3OL_ACCESS_POINT, VALUE: s3ol_access_point}, 84 | {NAME: LANGUAGE, VALUE: language} 85 | ], UNIT: COUNT, VALUE: 1.0} for pii_entity_type in pii_entity_types]) 86 | -------------------------------------------------------------------------------- /test/unit/test_validators.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from copy import deepcopy 4 | from unittest import TestCase 5 | from unittest.mock import patch 6 | 7 | from config import IS_PARTIAL_OBJECT_SUPPORTED 8 | from constants import S3OL_CONFIGURATION, GET_OBJECT_CONTEXT, REQUEST_TOKEN, REQUEST_ROUTE, REQUEST_ID, INPUT_S3_URL, PAYLOAD, \ 9 | USER_REQUEST, HEADERS, RANGE 10 | from exceptions import InvalidConfigurationException, InvalidRequestException 11 | from validators import JsonValidator, Validator, InputEventValidator, PartialObjectRequestValidator 12 | 13 | this_module_path = os.path.dirname(__file__) 14 | 15 | 16 | class TestValidator(TestCase): 17 | def setUp(self) -> None: 18 | with open(os.path.join(this_module_path, "..", 'data', 'sample_event.json'), 'r') as file_pointer: 19 | self.sample_event = json.load(file_pointer) 20 | self.is_partial_object_supported = IS_PARTIAL_OBJECT_SUPPORTED 21 | 22 | def test_json_validator(self): 23 | with self.assertRaises(Exception) as context: 24 | JsonValidator.validate("{\"invalid\":json_string") 25 | 26 | def test_validator_interface(self): 27 | with self.assertRaises(NotImplementedError) as context: 28 | Validator.validate(None) 29 | 30 | def test_partial_object_validator_partial_object_requested(self): 31 | with self.assertRaises(InvalidRequestException) as 
context: 32 | self.sample_event[USER_REQUEST][HEADERS][RANGE] = "0-100" 33 | PartialObjectRequestValidator.validate(self.sample_event) 34 | assert context.exception.message == "HTTP Header Range is not supported" 35 | 36 | 37 | def test_partial_object_validator_complete_object_requested(self): 38 | temp = deepcopy(self.sample_event) 39 | PartialObjectRequestValidator.validate(temp) 40 | 41 | @patch('validators.IS_PARTIAL_OBJECT_SUPPORTED', True) 42 | def test_partial_object_validator_when_partial_object_is_supported(self): 43 | self.sample_event[USER_REQUEST][HEADERS][RANGE] = "0-100" 44 | PartialObjectRequestValidator.validate(self.sample_event) 45 | 46 | def test_input_event_validation_empty_input(self): 47 | with self.assertRaises(AssertionError) as context: 48 | InputEventValidator.validate({}) 49 | 50 | def test_input_event_validation_missing_s3ol_config(self): 51 | with self.assertRaises(AssertionError) as context: 52 | invalid_event = deepcopy(self.sample_event) 53 | del invalid_event[S3OL_CONFIGURATION] 54 | InputEventValidator.validate(invalid_event) 55 | 56 | def test_input_event_validation_missing_object_context(self): 57 | with self.assertRaises(AssertionError) as context: 58 | invalid_event = deepcopy(self.sample_event) 59 | del invalid_event[GET_OBJECT_CONTEXT] 60 | InputEventValidator.validate(invalid_event) 61 | 62 | def test_input_event_validation_missing_request_token(self): 63 | with self.assertRaises(AssertionError) as context: 64 | invalid_event = deepcopy(self.sample_event) 65 | del invalid_event[GET_OBJECT_CONTEXT][REQUEST_TOKEN] 66 | InputEventValidator.validate(invalid_event) 67 | 68 | def test_input_event_validation_missing_request_route(self): 69 | with self.assertRaises(AssertionError) as context: 70 | invalid_event = deepcopy(self.sample_event) 71 | del invalid_event[GET_OBJECT_CONTEXT][REQUEST_ROUTE] 72 | InputEventValidator.validate(invalid_event) 73 | 74 | def test_input_event_validation_missing_request_id(self): 75 | with self.assertRaises(AssertionError) as context: 76 | invalid_event = deepcopy(self.sample_event) 77 | del invalid_event[REQUEST_ID] 78 | InputEventValidator.validate(invalid_event) 79 | 80 | def test_input_event_validation_missing_input_s3_url(self): 81 | with self.assertRaises(AssertionError) as context: 82 | invalid_event = deepcopy(self.sample_event) 83 | del invalid_event[GET_OBJECT_CONTEXT][INPUT_S3_URL] 84 | InputEventValidator.validate(invalid_event) 85 | 86 | def test_input_event_validation_missing_payload(self): 87 | with self.assertRaises(AssertionError) as context: 88 | invalid_event = deepcopy(self.sample_event) 89 | del invalid_event[S3OL_CONFIGURATION][PAYLOAD] 90 | InputEventValidator.validate(invalid_event) 91 | 92 | def test_input_event_validation_invalid_payload(self): 93 | with self.assertRaises(InvalidConfigurationException) as context: 94 | invalid_event = deepcopy(self.sample_event) 95 | invalid_event[S3OL_CONFIGURATION][PAYLOAD] = "Invalid json" 96 | InputEventValidator.validate(invalid_event) 97 | 98 | def test_input_event_validation_empty_payload(self): 99 | invalid_event = deepcopy(self.sample_event) 100 | invalid_event[S3OL_CONFIGURATION][PAYLOAD] = "" 101 | InputEventValidator.validate(invalid_event) 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # 
temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### OSX ### 20 | *.DS_Store 21 | .AppleDouble 22 | .LSOverride 23 | 24 | # Icon must end with two \r 25 | Icon 26 | 27 | # Thumbnails 28 | ._* 29 | 30 | # Files that might appear in the root of a volume 31 | .DocumentRevisions-V100 32 | .fseventsd 33 | .Spotlight-V100 34 | .TemporaryItems 35 | .Trashes 36 | .VolumeIcon.icns 37 | .com.apple.timemachine.donotpresent 38 | 39 | # Directories potentially created on remote AFP share 40 | .AppleDB 41 | .AppleDesktop 42 | Network Trash Folder 43 | Temporary Items 44 | .apdisk 45 | 46 | ### PyCharm ### 47 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 48 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 49 | 50 | # User-specific stuff: 51 | .idea/**/workspace.xml 52 | .idea/**/tasks.xml 53 | .idea/dictionaries 54 | 55 | # Sensitive or high-churn files: 56 | .idea/**/dataSources/ 57 | .idea/**/dataSources.ids 58 | .idea/**/dataSources.xml 59 | .idea/**/dataSources.local.xml 60 | .idea/**/sqlDataSources.xml 61 | .idea/**/dynamic.xml 62 | .idea/**/uiDesigner.xml 63 | 64 | # Gradle: 65 | .idea/**/gradle.xml 66 | .idea/**/libraries 67 | 68 | # CMake 69 | cmake-build-debug/ 70 | 71 | # Mongo Explorer plugin: 72 | .idea/**/mongoSettings.xml 73 | 74 | ## File-based project format: 75 | *.iws 76 | 77 | ## Plugin-specific files: 78 | 79 | # IntelliJ 80 | /out/ 81 | 82 | # mpeltonen/sbt-idea plugin 83 | .idea_modules/ 84 | 85 | # JIRA plugin 86 | atlassian-ide-plugin.xml 87 | 88 | # Cursive Clojure plugin 89 | .idea/replstate.xml 90 | 91 | # Ruby plugin and RubyMine 92 | /.rakeTasks 93 | 94 | # Crashlytics plugin (for Android Studio and IntelliJ) 95 | com_crashlytics_export_strings.xml 96 | crashlytics.properties 97 | crashlytics-build.properties 98 | fabric.properties 99 | 100 | ### PyCharm Patch ### 101 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 102 | 103 | # *.iml 104 | # modules.xml 105 | # .idea/misc.xml 106 | # *.ipr 107 | 108 | # Sonarlint plugin 109 | .idea/sonarlint 110 | 111 | ### Python ### 112 | # Byte-compiled / optimized / DLL files 113 | __pycache__/ 114 | *.py[cod] 115 | *$py.class 116 | 117 | # C extensions 118 | *.so 119 | 120 | # Distribution / packaging 121 | .Python 122 | build/ 123 | develop-eggs/ 124 | dist/ 125 | downloads/ 126 | eggs/ 127 | .eggs/ 128 | lib/ 129 | lib64/ 130 | parts/ 131 | sdist/ 132 | var/ 133 | wheels/ 134 | *.egg-info/ 135 | .installed.cfg 136 | *.egg 137 | /*.zip 138 | 139 | # PyInstaller 140 | # Usually these files are written by a python script from a template 141 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
142 | *.manifest 143 | *.spec 144 | 145 | # Installer logs 146 | pip-log.txt 147 | pip-delete-this-directory.txt 148 | 149 | # Unit test / coverage reports 150 | htmlcov/ 151 | .tox/ 152 | .coverage 153 | .coverage.* 154 | .cache 155 | .pytest_cache/ 156 | nosetests.xml 157 | coverage.xml 158 | *.cover 159 | .hypothesis/ 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Flask stuff: 166 | instance/ 167 | .webassets-cache 168 | 169 | # Scrapy stuff: 170 | .scrapy 171 | 172 | # Sphinx documentation 173 | docs/_build/ 174 | 175 | # PyBuilder 176 | target/ 177 | 178 | # Jupyter Notebook 179 | .ipynb_checkpoints 180 | 181 | # pyenv 182 | .python-version 183 | 184 | # celery beat schedule file 185 | celerybeat-schedule.* 186 | 187 | # SageMath parsed files 188 | *.sage.py 189 | 190 | # Environments 191 | .env 192 | .venv 193 | env/ 194 | venv/ 195 | ENV/ 196 | env.bak/ 197 | venv.bak/ 198 | 199 | # Spyder project settings 200 | .spyderproject 201 | .spyproject 202 | 203 | # Rope project settings 204 | .ropeproject 205 | 206 | # mkdocs documentation 207 | /site 208 | 209 | # mypy 210 | .mypy_cache/ 211 | 212 | ### VisualStudioCode ### 213 | .vscode/* 214 | !.vscode/settings.json 215 | !.vscode/tasks.json 216 | !.vscode/launch.json 217 | !.vscode/extensions.json 218 | .history 219 | 220 | ### Windows ### 221 | # Windows thumbnail cache files 222 | Thumbs.db 223 | ehthumbs.db 224 | ehthumbs_vista.db 225 | 226 | # Folder config file 227 | Desktop.ini 228 | 229 | # Recycle Bin used on file shares 230 | $RECYCLE.BIN/ 231 | 232 | # Windows Installer files 233 | *.cab 234 | *.msi 235 | *.msm 236 | *.msp 237 | 238 | # Windows shortcuts 239 | *.lnk 240 | 241 | # Build folder 242 | 243 | */build/* 244 | 245 | # End of https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode 246 | 247 | # vim, mac stuff, powerpoint temp files (I use powerpoint for architecture diagrams) 248 | .*.sw[op] 249 | .DS_Store 250 | ~$*.ppt* 251 | 252 | # SAM CLI build dir 253 | .aws-sam 254 | 255 | # We're using pipenv so any requirements.txt files are auto-generated and should be ignored by git 256 | requirements.txt 257 | -------------------------------------------------------------------------------- /test/load/load_test_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from concurrent.futures._base import as_completed 5 | from concurrent.futures.thread import ThreadPoolExecutor 6 | 7 | from integ.integ_base import BasicIntegTest 8 | 9 | 10 | class BaseLoadTest(BasicIntegTest): 11 | DEFAULT_PERIOD = 5 12 | DEFAULT_CONNECTIONS = 1 13 | 14 | def find_max_tpm(self, s3_olap_arn, file_size, is_pii, expected_error, error_percent_threshold=0.05, load_test_period=3): 15 | file_name = self.create_and_upload_text_file(file_size, is_pii) 16 | total_counts, failed_counts, _ = self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error, 17 | period_in_minutes=load_test_period) 18 | current_tpm = total_counts / load_test_period 19 | previous_tpm = 0 20 | previous_best_connection_count = self.DEFAULT_CONNECTIONS 21 | connection_count = self.DEFAULT_CONNECTIONS 22 | while failed_counts / total_counts < error_percent_threshold and previous_tpm <= current_tpm: 23 | previous_tpm = current_tpm 24 | previous_best_connection_count = connection_count 25 | connection_count = previous_best_connection_count + 2 26 | total_counts, failed_counts, _ = 
self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error, 27 | connection_counts=connection_count, 28 | period_in_minutes=load_test_period) 29 | current_tpm = total_counts / load_test_period 30 | 31 | logging.info( 32 | f"Best results tpm: {previous_tpm} with error rate: {failed_counts / total_counts * 100}% for file size {file_size} " 33 | f"are obtained with connection count {previous_best_connection_count}") 34 | 35 | def execute_load_on_get_object(self, s3_olap_arn, file_size, is_pii, expected_error): 36 | logging.info(f"Running Load Test for {file_size} KB file where PII is {'present' if is_pii else 'not present'}") 37 | file_name = self.create_and_upload_text_file(file_size, is_pii) 38 | return self.run_desired_simultaneous_get_object_calls(s3_olap_arn, file_name, expected_error) 39 | 40 | def create_and_upload_text_file(self, desired_size_in_KB: int, is_pii: bool): 41 | file_name = str("pii_" if is_pii else "non_pii") + str(desired_size_in_KB) + "_KB" 42 | 43 | repeat_text = " Some Random Text ttt" 44 | with open(self.DATA_PATH + "/pii_input.txt") as pii_file: 45 | pii_text = pii_file.read() 46 | full_text = pii_text if is_pii else "" 47 | with open(file_name, 'w') as temp: 48 | while len(full_text) <= desired_size_in_KB * 1000: 49 | full_text += repeat_text 50 | temp.write(full_text) 51 | self.s3_client.upload_file(file_name, self.bucket_name, file_name) 52 | os.remove(file_name) 53 | return file_name 54 | 55 | def run_desired_simultaneous_get_object_calls(self, s3_olap_arn, file_name, expected_error, connection_counts=DEFAULT_CONNECTIONS, 56 | period_in_minutes=DEFAULT_PERIOD): 57 | logging.info( 58 | f"Running Load Test for file: {file_name} for {period_in_minutes} minutes with connection counts: {connection_counts}") 59 | executor = ThreadPoolExecutor(max_workers=connection_counts) 60 | futures = [executor.submit(self.fail_safe_fetch_s3_object, s3_olap_arn, file_name, period=period_in_minutes * 60, 61 | expected_error=expected_error) for i in 62 | range(0, connection_counts)] 63 | 64 | total_counts = 0 65 | successful_counts = 0 66 | total_time = 0 67 | for f in as_completed(futures): 68 | successful_counts += f.result()[1] 69 | total_counts += f.result()[0] + f.result()[1] 70 | total_time += f.result()[2] 71 | average_latency = total_time / total_counts # average over every call from every worker, computed once after the loop 72 | logging.info(f" Total calls made: {total_counts}, out of which {successful_counts} calls were successful."
f" ({successful_counts / total_counts * 100}%), Average Latency: {average_latency}") 74 | return total_counts, total_counts - successful_counts, average_latency 75 | 76 | def fail_safe_fetch_s3_object(self, s3ol_ap_arn, file, period, expected_error=None): 77 | start_time = time.time() 78 | failed_calls = 0 79 | successful_calls = 0 80 | total_time = 0 81 | while (time.time() - start_time) <= period: 82 | try: 83 | api_call_st_time = time.time() 84 | self.s3_client.get_object(Bucket=s3ol_ap_arn, Key=file) 85 | total_time += time.time() - api_call_st_time 86 | successful_calls += 1 87 | except Exception as e: 88 | total_time += time.time() - api_call_st_time 89 | if expected_error: 90 | if isinstance(e, expected_error): 91 | successful_calls += 1 92 | else: 93 | failed_calls += 1 94 | else: 95 | # logging.error(e) 96 | failed_calls += 1 97 | return failed_calls, successful_calls, total_time 98 | -------------------------------------------------------------------------------- /access-control-template.yml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: '2010-09-09' 2 | Transform: AWS::Serverless-2016-10-31 3 | 4 | Metadata: 5 | AWS::ServerlessRepo::Application: 6 | Name: ComprehendPiiAccessControlS3ObjectLambda 7 | Description: Deploys a Lambda function that controls access to text files containing PII (Personally Identifiable Information). It can be used as an S3 Object Lambda that is triggered on GetObject calls when configured with an access point 8 | Author: AWS Comprehend 9 | # SPDX License Id, e.g., MIT, MIT-0, Apache-2.0. See https://spdx.org/licenses for more details 10 | SpdxLicenseId: MIT-0 11 | LicenseUrl: LICENSE 12 | ReadmeUrl: ACCESS_CONTROL_README.md 13 | Labels: [serverless,comprehend,nlp,pii] 14 | HomePageUrl: https://aws.amazon.com/comprehend/ 15 | SemanticVersion: 1.0.2 16 | SourceCodeUrl: https://github.com/aws-samples/amazon-comprehend-s3-object-lambda-functions 17 | 18 | Parameters: 19 | LogLevel: 20 | Type: String 21 | Description: Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc. 22 | Default: INFO 23 | UnsupportedFileHandling: 24 | Type: String 25 | Description: Handling logic for unsupported files. Valid values are PASS and FAIL. 26 | Default: FAIL 27 | IsPartialObjectSupported: 28 | Type: String 29 | Description: Whether to support partial objects. Accessing a partial object through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy. 30 | Default: FALSE 31 | DocumentMaxSizeContainsPiiEntities: 32 | Type: Number 33 | Description: Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API. 34 | Default: 50000 35 | PiiEntityTypes: 36 | Type: String 37 | Description: List of comma separated PII entity types to be considered for access control. Refer to Comprehend's documentation page for the list of supported PII entity types. 38 | Default: ALL 39 | SubsegmentOverlappingTokens: 40 | Type: Number 41 | Description: Number of tokens/words to overlap among segments of a document in case chunking is needed because of the maximum document size limit. 42 | Default: 20 43 | DocumentMaxSize: 44 | Type: Number 45 | Description: Maximum document size (in bytes) that this function can process; larger documents raise an exception. 
46 | Default: 102400 47 | ConfidenceThreshold: 48 | Type: Number 49 | Description: The minimum prediction confidence score above which PII classification and detection results are considered the final answer. Valid range is 0.5 to 1.0. 50 | Default: 0.5 51 | MaxCharsOverlap: 52 | Type: Number 53 | Description: Maximum characters to overlap among segments of a document in case chunking is needed because of the maximum document size limit. 54 | Default: 200 55 | DefaultLanguageCode: 56 | Type: String 57 | Description: Default language of the text to be processed. This code will be used for interacting with Comprehend. 58 | Default: en 59 | ContainsPiiEntitiesThreadCount: 60 | Type: Number 61 | Description: Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda. 62 | Default: 20 63 | PublishCloudWatchMetrics: 64 | Type: String 65 | Description: Set to true to publish metrics to CloudWatch, false otherwise. See README.md for details on CloudWatch metrics. 66 | Default: True 67 | 68 | Resources: 69 | PiiAccessControlFunction: 70 | Type: AWS::Serverless::Function 71 | Properties: 72 | CodeUri: src/ 73 | Handler: handler.pii_access_control_handler 74 | Runtime: python3.8 75 | Tracing: Active 76 | Timeout: 60 77 | Policies: 78 | - Statement: 79 | - Sid: ComprehendPiiDetectionPolicy 80 | Effect: Allow 81 | Action: 82 | - comprehend:ContainsPiiEntities 83 | Resource: '*' 84 | - Sid: S3AccessPointCallbackPolicy 85 | Effect: Allow 86 | Action: 87 | - s3-object-lambda:WriteGetObjectResponse 88 | Resource: '*' 89 | - Sid: CloudWatchMetricsPolicy 90 | Effect: Allow 91 | Action: 92 | - cloudwatch:PutMetricData 93 | Resource: '*' 94 | Environment: 95 | Variables: 96 | LOG_LEVEL: !Ref LogLevel 97 | UNSUPPORTED_FILE_HANDLING: !Ref UnsupportedFileHandling 98 | IS_PARTIAL_OBJECT_SUPPORTED: !Ref IsPartialObjectSupported 99 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES: !Ref DocumentMaxSizeContainsPiiEntities 100 | PII_ENTITY_TYPES: !Ref PiiEntityTypes 101 | SUBSEGMENT_OVERLAPPING_TOKENS: !Ref SubsegmentOverlappingTokens 102 | DOCUMENT_MAX_SIZE: !Ref DocumentMaxSize 103 | CONFIDENCE_THRESHOLD: !Ref ConfidenceThreshold 104 | MAX_CHARS_OVERLAP: !Ref MaxCharsOverlap 105 | DEFAULT_LANGUAGE_CODE: !Ref DefaultLanguageCode 106 | CONTAINS_PII_ENTITIES_THREAD_COUNT: !Ref ContainsPiiEntitiesThreadCount 107 | PUBLISH_CLOUD_WATCH_METRICS: !Ref PublishCloudWatchMetrics 108 | 109 | Outputs: 110 | PiiAccessControlFunctionName: 111 | Description: "PII Access Control Function Name" 112 | Value: !Ref PiiAccessControlFunction 113 | -------------------------------------------------------------------------------- /src/exception_handlers.py: -------------------------------------------------------------------------------- 1 | """Classes for handling exceptions.""" 2 | import lambdalogging 3 | from clients.s3_client import S3Client 4 | from config import UNSUPPORTED_FILE_HANDLING 5 | from constants import UNSUPPORTED_FILE_HANDLING_VALID_VALUES, S3_STATUS_CODES, S3_ERROR_CODES, error_code_to_enums 6 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, S3DownloadException, InvalidConfigurationException, \ 7 | InvalidRequestException, RestrictedDocumentException, TimeoutException 8 | 9 | LOG = lambdalogging.getLogger(__name__) 10 | 11 | 12 | class ExceptionHandler: 13 | """Handler enclosing the action to be taken when an error occurs while processing files.""" 14 | 15 | def __init__(self, s3_client: S3Client): 
--------------------------------------------------------------------------------
/src/exception_handlers.py:
--------------------------------------------------------------------------------
1 | """Classes for handling exceptions."""
2 | import lambdalogging
3 | from clients.s3_client import S3Client
4 | from config import UNSUPPORTED_FILE_HANDLING
5 | from constants import UNSUPPORTED_FILE_HANDLING_VALID_VALUES, S3_STATUS_CODES, S3_ERROR_CODES, error_code_to_enums
6 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, S3DownloadException, InvalidConfigurationException, \
7 |     InvalidRequestException, RestrictedDocumentException, TimeoutException
8 | 
9 | LOG = lambdalogging.getLogger(__name__)
10 | 
11 | 
12 | class ExceptionHandler:
13 |     """Handler enclosing the action to be taken in case an error occurs while processing files."""
14 | 
15 |     def __init__(self, s3_client: S3Client):
16 |         self.s3_client = s3_client
17 | 
18 |     def handle_exception(self, exception: BaseException, request_route: str, request_token: str):
19 |         """Handle exception and take appropriate actions."""
20 |         try:
21 |             raise exception
22 |         except UnsupportedFileException as e:
23 |             self._handle_unsupported_file_exception(e, request_route, request_token)
24 |         except InvalidConfigurationException as e:
25 |             LOG.error(f"Encountered an invalid configuration setup. {e}", exc_info=True)
26 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400,
27 |                                                    S3_ERROR_CODES.InvalidRequest,
28 |                                                    "Lambda function has been incorrectly setup", request_route,
29 |                                                    request_token)
30 |         except FileSizeLimitExceededException:
31 |             LOG.info(
32 |                 f"File size of the requested object exceeds maximum file size supported. Responding back with "
33 |                 f"error: {S3_STATUS_CODES.BAD_REQUEST_400.name}")
34 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400, S3_ERROR_CODES.EntityTooLarge,
35 |                                                    "Size of the requested object exceeds maximum file size supported", request_route,
36 |                                                    request_token)
37 |         except InvalidRequestException as e:
38 |             LOG.info(f"Encountered an invalid request {e}", exc_info=True)
39 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400, S3_ERROR_CODES.InvalidRequest,
40 |                                                    e.message, request_route, request_token)
41 |         except S3DownloadException as e:
42 |             LOG.error(f"Error downloading from presigned url. {e}", exc_info=True)
43 |             status_code, error_code = error_code_to_enums(e.s3_error_code)
44 |             self.s3_client.respond_back_with_error(status_code, error_code, e.s3_message,
45 |                                                    request_route, request_token)
46 |         except RestrictedDocumentException as e:
47 |             LOG.error(f"Document contains PII. {e}", exc_info=True)
48 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.FORBIDDEN_403,
49 |                                                    S3_ERROR_CODES.AccessDenied,
50 |                                                    "Document Contains PII",
51 |                                                    request_route, request_token)
52 |         except TimeoutException as e:
53 |             LOG.error(f"Couldn't complete processing within the time limit. {e}", exc_info=True)
54 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400,
55 |                                                    S3_ERROR_CODES.RequestTimeout,
56 |                                                    "Failed to complete document processing within time limit",
57 |                                                    request_route, request_token)
58 |         except Exception as e:
59 |             LOG.error(f"Internal error {e} occurred while processing the file", exc_info=True)
60 |             self.s3_client.respond_back_with_error(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, S3_ERROR_CODES.InternalError,
61 |                                                    "An internal error occurred while processing the file", request_route,
62 |                                                    request_token)
63 | 
64 |     def _handle_unsupported_file_exception(self, exception: UnsupportedFileException, request_route: str, request_token: str):
65 |         """Handle the action to be taken when we encounter a file that is not supported by the Lambda's core functionality."""
66 |         LOG.debug("File is not supported for determining and redacting PII data.")
67 | 
68 |         if UNSUPPORTED_FILE_HANDLING == UNSUPPORTED_FILE_HANDLING_VALID_VALUES.PASS:
69 |             LOG.debug("Unsupported file handling strategy is set to PASS. Responding back with the file content to the caller")
70 |             self.s3_client.respond_back_with_data(exception.file_content, exception.http_headers, request_route, request_token)
71 | 
72 |         elif UNSUPPORTED_FILE_HANDLING == UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL:
73 |             LOG.debug(
74 |                 f"Unsupported file handling strategy is set to FAIL. 
Responding back with error: " 75 | f"{S3_ERROR_CODES.UnexpectedContent.name} to the caller") 76 | self.s3_client.respond_back_with_error(S3_STATUS_CODES.BAD_REQUEST_400, 77 | S3_ERROR_CODES.UnexpectedContent, 78 | "Unsupported file encountered for determining Pii", request_route, request_token) 79 | else: 80 | raise Exception("Unknown exception handling strategy found for UnsupportedFileException.") 81 | -------------------------------------------------------------------------------- /src/clients/comprehend_client.py: -------------------------------------------------------------------------------- 1 | """Client wrapper over aws services.""" 2 | 3 | import string 4 | from concurrent.futures._base import as_completed 5 | from concurrent.futures.thread import ThreadPoolExecutor 6 | from copy import deepcopy 7 | from random import choices 8 | from typing import List 9 | 10 | import boto3 11 | import botocore 12 | import time 13 | 14 | import lambdalogging 15 | from clients.cloudwatch_client import Metrics 16 | from config import CONTAINS_PII_ENTITIES_THREAD_COUNT, DETECT_PII_ENTITIES_THREAD_COUNT, DEFAULT_LANGUAGE_CODE 17 | from constants import DEFAULT_USER_AGENT, CONTAINS_PII_ENTITIES, DETECT_PII_ENTITIES, COMPREHEND, COMPREHEND_MAX_RETRIES 18 | from data_object import Document 19 | 20 | LOG = lambdalogging.getLogger(__name__) 21 | 22 | 23 | class ComprehendClient: 24 | """Wrapper over comprehend client.""" 25 | 26 | def __init__(self, s3ol_access_point: str, pii_classification_thread_count: int = CONTAINS_PII_ENTITIES_THREAD_COUNT, 27 | pii_redaction_thread_count: int = DETECT_PII_ENTITIES_THREAD_COUNT, 28 | session_id: str = ''.join(choices(string.ascii_uppercase + string.digits, k=10)), 29 | user_agent=DEFAULT_USER_AGENT, endpoint_url=None): 30 | self.session_id = session_id 31 | session_config = botocore.config.Config( 32 | user_agent_extra=user_agent, 33 | retries={ 34 | 'max_attempts': COMPREHEND_MAX_RETRIES, 35 | 'mode': 'standard' 36 | }) 37 | if endpoint_url is None: 38 | self.comprehend = boto3.client('comprehend', config=session_config) 39 | else: 40 | self.comprehend = boto3.client('comprehend', config=session_config, endpoint_url=endpoint_url, verify=False) 41 | self.comprehend.meta.events.register('before-sign.comprehend.*', self._add_session_header) 42 | self.classification_executor_service = ThreadPoolExecutor(max_workers=pii_classification_thread_count) 43 | self.redaction_executor_service = ThreadPoolExecutor(max_workers=pii_redaction_thread_count) 44 | self.classify_metrics = Metrics(service_name=COMPREHEND, api=CONTAINS_PII_ENTITIES, s3ol_access_point=s3ol_access_point) 45 | self.detection_metrics = Metrics(service_name=COMPREHEND, api=DETECT_PII_ENTITIES, s3ol_access_point=s3ol_access_point) 46 | 47 | def _add_session_header(self, request, **kwargs): 48 | request.headers.add_header('x-amzn-session-id', self.session_id) 49 | 50 | def contains_pii_entities(self, documents: List[Document], language=DEFAULT_LANGUAGE_CODE) -> List[Document]: 51 | """Call comprehend to get pii classification of given documents.""" 52 | documents_copy = deepcopy(documents) 53 | result = [] 54 | with self.classification_executor_service: 55 | futures = [] 56 | for doc in documents_copy: 57 | futures.append(self.classification_executor_service.submit(self._update_doc_with_pii_classification, doc, language)) 58 | 59 | for future_result in as_completed(futures): 60 | try: 61 | result.append(future_result.result()) 62 | except Exception as error: 63 | LOG.error("Error occurred while calling 
comprehend for classifying text as pii", exc_info=True) 64 | self.classify_metrics.add_fault_count() 65 | raise error 66 | return result 67 | 68 | def _update_doc_with_pii_classification(self, document: Document, language) -> Document: 69 | start_time = time.time() 70 | response = None 71 | try: 72 | response = self.comprehend.contains_pii_entities(Text=document.text, LanguageCode=language) 73 | finally: 74 | if response is not None: 75 | self.classify_metrics.add_fault_count(response['ResponseMetadata']['RetryAttempts']) 76 | self.classify_metrics.add_latency(start_time, time.time()) 77 | # updating the document itself instead of creating a new copy to save space 78 | document.pii_classification = {label['Name']: label['Score'] for label in response['Labels']} 79 | return document 80 | 81 | def detect_pii_documents(self, documents: List[Document], language=DEFAULT_LANGUAGE_CODE) -> List[Document]: 82 | """Call comprehend to get pii entities present in given documents.""" 83 | documents_copy = deepcopy(documents) 84 | result = [] 85 | with self.redaction_executor_service: 86 | futures = [] 87 | for doc in documents_copy: 88 | futures.append(self.redaction_executor_service.submit(self._update_doc_with_pii_entities, doc, language)) 89 | 90 | for future_result in as_completed(futures): 91 | try: 92 | result.append(future_result.result()) 93 | except Exception as error: 94 | LOG.error("Error occurred while calling comprehend for detecting pii entities", exc_info=True) 95 | self.detection_metrics.add_fault_count() 96 | raise error 97 | return result 98 | 99 | def _update_doc_with_pii_entities(self, document: Document, language) -> Document: 100 | start_time = time.time() 101 | response = None 102 | try: 103 | response = self.comprehend.detect_pii_entities(Text=document.text, LanguageCode=language) 104 | finally: 105 | if response is not None: 106 | self.detection_metrics.add_fault_count(response['ResponseMetadata']['RetryAttempts']) 107 | self.detection_metrics.add_latency(start_time, time.time()) 108 | # updating the document itself instead of creating a new copy to save space 109 | document.pii_entities = response['Entities'] 110 | document.pii_classification = {entity['Type']: max(entity['Score'], document.pii_classification[entity['Type']]) 111 | if entity['Type'] in document.pii_classification else entity['Score'] 112 | for entity in response['Entities']} 113 | return document 114 | -------------------------------------------------------------------------------- /test/unit/test_exception_handlers.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from unittest.mock import patch, MagicMock 3 | 4 | from constants import S3_STATUS_CODES, S3_ERROR_CODES, UNSUPPORTED_FILE_HANDLING_VALID_VALUES 5 | from exception_handlers import ExceptionHandler 6 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, InvalidConfigurationException, \ 7 | S3DownloadException, RestrictedDocumentException, TimeoutException 8 | 9 | 10 | class ExceptionHandlerTest(TestCase): 11 | 12 | def setUp(self) -> None: 13 | super().setUp() 14 | 15 | def tearDown(self) -> None: 16 | super().tearDown() 17 | 18 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", UNSUPPORTED_FILE_HANDLING_VALID_VALUES.PASS) 19 | def test_unsupported_file_exception_handling_do_not_fail(self): 20 | s3_client = MagicMock() 21 | ExceptionHandler(s3_client). 
\ 22 | handle_exception(UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken") 23 | s3_client.respond_back_with_data.assert_called_once_with("SomeContent", {'h1': 'v1'}, "SomeRoute", "SomeToken") 24 | 25 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", UNSUPPORTED_FILE_HANDLING_VALID_VALUES.FAIL) 26 | def test_unsupported_file_exception_handling_return_error(self): 27 | s3_client = MagicMock() 28 | ExceptionHandler(s3_client). \ 29 | handle_exception(UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken") 30 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400, 31 | S3_ERROR_CODES.UnexpectedContent, 32 | "Unsupported file encountered for determining Pii", 33 | "SomeRoute", "SomeToken") 34 | 35 | @patch("exception_handlers.UNSUPPORTED_FILE_HANDLING", 'Unknown') 36 | def test_unsupported_file_exception_handling_return_unknown_error(self): 37 | s3_client = MagicMock() 38 | self.assertRaises(Exception, ExceptionHandler(s3_client).handle_exception, 39 | UnsupportedFileException(file_content="SomeContent", http_headers={'h1': 'v1'}), "SomeRoute", "SomeToken") 40 | 41 | def test_file_size_limit_exceeded_handler(self): 42 | s3_client = MagicMock() 43 | ExceptionHandler(s3_client).handle_exception(FileSizeLimitExceededException(), "SomeRoute", "SomeToken") 44 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400, 45 | S3_ERROR_CODES.EntityTooLarge, 46 | "Size of the requested object exceeds maximum file size supported", 47 | "SomeRoute", "SomeToken") 48 | 49 | def test_default_exception_handler(self): 50 | s3_client = MagicMock() 51 | ExceptionHandler(s3_client).handle_exception(Exception(), "SomeRoute", "SomeToken") 52 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, 53 | S3_ERROR_CODES.InternalError, 54 | "An internal error occurred while processing the file", 55 | "SomeRoute", "SomeToken") 56 | 57 | def test_invalid_configuration_exception_handler(self): 58 | s3_client = MagicMock() 59 | ExceptionHandler(s3_client).handle_exception(InvalidConfigurationException("Missconfigured knob"), 60 | "SomeRoute", "SomeToken") 61 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400, 62 | S3_ERROR_CODES.InvalidRequest, 63 | "Lambda function has been incorrectly setup", 64 | "SomeRoute", "SomeToken") 65 | 66 | def test_s3_download_exception_handler(self): 67 | s3_client = MagicMock() 68 | ExceptionHandler(s3_client).handle_exception(S3DownloadException("InternalError", "Internal Server Error"), 69 | "SomeRoute", "SomeToken") 70 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, 71 | S3_ERROR_CODES.InternalError, 72 | "Internal Server Error", 73 | "SomeRoute", "SomeToken") 74 | 75 | def test_restricted_document_exception_handler(self): 76 | s3_client = MagicMock() 77 | ExceptionHandler(s3_client).handle_exception(RestrictedDocumentException(), 78 | "SomeRoute", "SomeToken") 79 | s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.FORBIDDEN_403, 80 | S3_ERROR_CODES.AccessDenied, 81 | "Document Contains PII", 82 | "SomeRoute", "SomeToken") 83 | 84 | def test_timeout_exception_handler(self): 85 | s3_client = MagicMock() 86 | ExceptionHandler(s3_client).handle_exception(TimeoutException(), 87 | "SomeRoute", "SomeToken") 88 | 
s3_client.respond_back_with_error.assert_called_once_with(S3_STATUS_CODES.BAD_REQUEST_400,
89 |                                                                   S3_ERROR_CODES.RequestTimeout,
90 |                                                                   "Failed to complete document processing within time limit",
91 |                                                                   "SomeRoute", "SomeToken")
92 | 
--------------------------------------------------------------------------------
/redaction-template.yml:
--------------------------------------------------------------------------------
1 | AWSTemplateFormatVersion: '2010-09-09'
2 | Transform: AWS::Serverless-2016-10-31
3 | 
4 | Metadata:
5 |   AWS::ServerlessRepo::Application:
6 |     Name: ComprehendPiiRedactionS3ObjectLambda
7 |     Description: Deploys a Lambda which provides the capability to redact PII (Personally Identifiable Information) from a text file present in S3. The Lambda can be used as an S3 Object Lambda that is triggered on GetObject calls when configured with an access point
8 |     Author: AWS Comprehend
9 |     # SPDX License Id, e.g., MIT, MIT-0, Apache-2.0. See https://spdx.org/licenses for more details
10 |     SpdxLicenseId: MIT-0
11 |     LicenseUrl: LICENSE
12 |     ReadmeUrl: REDACTION_README.md
13 |     Labels: [serverless,comprehend,pii,nlp]
14 |     HomePageUrl: https://aws.amazon.com/comprehend/
15 |     SemanticVersion: 1.0.2
16 |     SourceCodeUrl: https://github.com/aws-samples/amazon-comprehend-s3-object-lambda-functions
17 | 
18 | Parameters:
19 |   LogLevel:
20 |     Type: String
21 |     Description: Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc.
22 |     Default: INFO
23 |   UnsupportedFileHandling:
24 |     Type: String
25 |     Description: Handling logic for unsupported files. Valid values are PASS and FAIL.
26 |     Default: FAIL
27 |   IsPartialObjectSupported:
28 |     Type: String
29 |     Description: Whether to support partial objects or not. Accessing a partial object through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy.
30 |     Default: FALSE
31 |   DocumentMaxSizeContainsPiiEntities:
32 |     Type: Number
33 |     Description: Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API.
34 |     Default: 50000
35 |   DocumentMaxSizeDetectPiiEntities:
36 |     Type: Number
37 |     Description: Maximum document size (in bytes) to be used for making calls to Comprehend's DetectPiiEntities API.
38 |     Default: 5000
39 |   PiiEntityTypes:
40 |     Type: String
41 |     Description: List of comma separated PII entity types to be considered for redaction. Refer to Comprehend's documentation page for the list of supported PII entity types.
42 |     Default: ALL
43 |   MaskCharacter:
44 |     Type: String
45 |     Description: A character that replaces each character in the redacted PII entity.
46 |     Default: '*'
47 |   MaskMode:
48 |     Type: String
49 |     Description: Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values - REPLACE_WITH_PII_ENTITY_TYPE and MASK.
50 |     Default: MASK
51 |   SubsegmentOverlappingTokens:
52 |     Type: Number
53 |     Description: Number of tokens/words to overlap among segments of a document in case chunking is needed because of the maximum document size limit.
54 |     Default: 20
55 |   DocumentMaxSize:
56 |     Type: Number
57 |     Description: Maximum document size (in bytes) that this function can process; an exception is thrown for larger documents.
58 |     Default: 102400
59 |   ConfidenceThreshold:
60 |     Type: Number
61 |     Description: The minimum prediction confidence score above which PII classification and detection results are considered final. Valid range (0.5 to 1.0).
62 | Default: 0.5 63 | MaxCharsOverlap: 64 | Type: Number 65 | Description: Maximum characters to overlap among segments of a document in case chunking is needed because of maximum document size limit. 66 | Default: 200 67 | DefaultLanguageCode: 68 | Type: String 69 | Description: Default language of the text to be processed. This code will be used for interacting with Comprehend. 70 | Default: en 71 | DetectPiiEntitiesThreadCount: 72 | Type: Number 73 | Description: Number of threads to use for calling Comprehend's DetectPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda. 74 | Default: 8 75 | ContainsPiiEntitiesThreadCount: 76 | Type: Number 77 | Description: Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda. 78 | Default: 20 79 | PublishCloudWatchMetrics: 80 | Type: String 81 | Description: True if publish metrics to Cloudwatch, false otherwise. See README.md for details on CloudWatch metrics. 82 | Default: True 83 | 84 | Resources: 85 | PiiRedactionFunction: 86 | Type: AWS::Serverless::Function 87 | Properties: 88 | CodeUri: src/ 89 | Handler: handler.redact_pii_documents_handler 90 | Runtime: python3.8 91 | Tracing: Active 92 | Timeout: 60 93 | Policies: 94 | - Statement: 95 | - Sid: ComprehendPiiDetectionPolicy 96 | Effect: Allow 97 | Action: 98 | - comprehend:DetectPiiEntities 99 | - comprehend:ContainsPiiEntities 100 | Resource: '*' 101 | - Sid: S3AccessPointCallbackPolicy 102 | Effect: Allow 103 | Action: 104 | - s3-object-lambda:WriteGetObjectResponse 105 | Resource: '*' 106 | - Sid: CloudWatchMetricsPolicy 107 | Effect: Allow 108 | Action: 109 | - cloudwatch:PutMetricData 110 | Resource: '*' 111 | Environment: 112 | Variables: 113 | LOG_LEVEL: !Ref LogLevel 114 | UNSUPPORTED_FILE_HANDLING: !Ref UnsupportedFileHandling 115 | PARTIAL_OBJECT_SUPPORT: !Ref IsPartialObjectSupported 116 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES: !Ref DocumentMaxSizeContainsPiiEntities 117 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES: !Ref DocumentMaxSizeDetectPiiEntities 118 | PII_ENTITY_TYPES: !Ref PiiEntityTypes 119 | MASK_CHARACTER: !Ref MaskCharacter 120 | MASK_MODE: !Ref MaskMode 121 | SUBSEGMENT_OVERLAPPING_TOKENS: !Ref SubsegmentOverlappingTokens 122 | DOCUMENT_MAX_SIZE: !Ref DocumentMaxSize 123 | CONFIDENCE_THRESHOLD: !Ref ConfidenceThreshold 124 | MAX_CHARS_OVERLAP: !Ref MaxCharsOverlap 125 | DEFAULT_LANGUAGE_CODE: !Ref DefaultLanguageCode 126 | DETECT_PII_ENTITIES_THREAD_COUNT: !Ref DetectPiiEntitiesThreadCount 127 | CONTAINS_PII_ENTITIES_THREAD_COUNT: !Ref ContainsPiiEntitiesThreadCount 128 | PUBLISH_CLOUD_WATCH_METRICS: !Ref PublishCloudWatchMetrics 129 | 130 | Outputs: 131 | PiiRedactionFunctionName: 132 | Description: "Redaction Function Name" 133 | Value: !Ref PiiRedactionFunction 134 | -------------------------------------------------------------------------------- /test/integ/test_access-control.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime, timedelta 3 | 4 | import boto3 5 | import botocore 6 | from botocore.exceptions import ClientError 7 | from dateutil.tz import tzutc 8 | 9 | from integ.integ_base import BasicIntegTest 10 | 11 | 12 | class PiiAccessControlIntegTest(BasicIntegTest): 13 | BUILD_DIR = '.aws-sam/build/PiiAccessControlFunction/' 14 | PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'ADDRESS', 'NAME', 'PHONE', 'DATE_TIME'] 15 | 16 | 
@classmethod 17 | def setUpClass(cls) -> None: 18 | super().setUpClass() 19 | test_run_id = str(uuid.uuid4())[0:8] 20 | cls.lambda_function_arn = cls._create_function("pii_access_control", 'handler.pii_access_control_handler', 21 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR) 22 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn) 23 | 24 | @classmethod 25 | def tearDownClass(cls) -> None: 26 | cls.s3_ctrl.delete_access_point_for_object_lambda( 27 | AccountId=cls.account_id, 28 | Name=cls.s3ol_access_point_arn.split('/')[-1] 29 | ) 30 | 31 | cls.lambda_client.delete_function( 32 | FunctionName=cls.lambda_function_arn.split(':')[-1] 33 | ) 34 | super().tearDownClass() 35 | 36 | def tearDown(self) -> None: 37 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"}) 38 | 39 | def test_classification_lambda_with_default_entity_types(self): 40 | start_time = datetime.now(tz=tzutc()) 41 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt') 42 | with self.assertRaises(ClientError) as e: 43 | test_pii_obj.get() 44 | assert e.exception.response['Error']['Message'] == "Document Contains PII" 45 | assert e.exception.response['Error']['Code'] == "AccessDenied" 46 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time, self.PII_ENTITY_TYPES_IN_TEST_DOC) 47 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time) 48 | 49 | def test_classification_lambda_with_pii_overridden_entity_types(self): 50 | self._update_lambda_env_variables(self.lambda_function_arn, 51 | {"LOG_LEVEL": "DEBUG", "PII_ENTITY_TYPES": "USERNAME,PASSWORD,AWS_ACCESS_KEY"}) 52 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='pii_input.txt') 53 | with open(f"{self.DATA_PATH}/pii_input.txt") as expected_output_file: 54 | assert response['Body'].read().decode('utf-8') == expected_output_file.read() 55 | 56 | def test_classification_lambda_throws_invalid_request_error_when_file_size_exceeds_limit(self): 57 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "500"}) 58 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt') 59 | with self.assertRaises(ClientError) as e: 60 | test_pii_obj.get() 61 | assert e.exception.response['Error']['Message'] == "Size of the requested object exceeds maximum file size supported" 62 | assert e.exception.response['Error']['Code'] == "EntityTooLarge" 63 | 64 | def test_classification_lambda_throws_access_denied_with_an_overridden_max_file_size_limit(self): 65 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "1500000", 66 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"}) 67 | with self.assertRaises(ClientError) as e: 68 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='1mb_pii_text') 69 | assert e.exception.response['Error']['Message'] == "Document Contains PII" 70 | assert e.exception.response['Error']['Code'] == "AccessDenied" 71 | 72 | def test_classification_lambda_with_unsupported_file(self): 73 | test_obj = self.s3.Object(self.s3ol_access_point_arn, 'RandomImage.png') 74 | with self.assertRaises(ClientError) as e: 75 | test_obj.get() 76 | assert e.exception.response['Error']['Message'] == "Unsupported file encountered for determining Pii" 77 | assert e.exception.response['Error']['Code'] == "UnexpectedContent" 78 | 79 | def 
test_classification_lambda_with_unsupported_file_handling_set_to_pass(self): 80 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "UNSUPPORTED_FILE_HANDLING": 'PASS'}) 81 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='RandomImage.png') 82 | assert 'Body' in response 83 | 84 | def test_classification_lambda_with_partial_object(self): 85 | with self.assertRaises(ClientError) as e: 86 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key="pii_input.txt", Range="bytes=0-100") 87 | assert e.exception.response['Error']['Message'] == "HTTP Header Range is not supported" 88 | 89 | def test_classification_lambda_with_partial_object_allowed_with_versioned(self): 90 | file_name = 'pii_output.txt' 91 | self.s3_client.upload_file(f"{self.DATA_PATH}/{file_name}", self.bucket_name, file_name) 92 | versions = set() 93 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "IS_PARTIAL_OBJECT_SUPPORTED": "TRUE"}) 94 | 95 | for version in self.s3_client.list_object_versions(Bucket=self.bucket_name)['Versions']: 96 | if version['Key'] == 'pii_output.txt': 97 | versions.add(version['VersionId']) 98 | assert len(versions) >= 2, f"Expected at least 2 different versions of {file_name}" 99 | for versionId in versions: 100 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key=file_name, Range="bytes=0-100", 101 | VersionId=versionId) 102 | assert response['ContentRange'] == "bytes 0-100/611" 103 | assert response['ContentLength'] == 101 104 | assert response['ContentType'] == 'binary/octet-stream' 105 | assert response['VersionId'] == versionId 106 | 107 | def test_request_timeout(self): 108 | self._update_lambda_env_variables(self.lambda_function_arn, 109 | {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "10000000", 110 | "CONTAINS_PII_ENTITIES_THREAD_COUNT": "1", 111 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"}) 112 | start_time = datetime.now(tz=tzutc()) 113 | with self.assertRaises(ClientError) as e: 114 | session_config = botocore.config.Config( 115 | retries={ 116 | 'max_attempts': 0 117 | }) 118 | s3_client = boto3.client('s3', region_name=self.REGION_NAME, config=session_config) 119 | response = s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='5mb_pii_text') 120 | end_time = datetime.now(tz=tzutc()) + timedelta(minutes=1) 121 | 122 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, end_time) 123 | assert e.exception.response['Error']['Message'] == "Failed to complete document processing within time limit" 124 | assert e.exception.response['Error']['Code'] == "RequestTimeout" 125 | 126 | 127 | def test_classification_handler_no_pii(self): 128 | non_pii_file_name = 'clean.txt' 129 | start_time = datetime.now(tz=tzutc()) 130 | test_clean_obj = self.s3.Object(self.s3ol_access_point_arn, non_pii_file_name) 131 | get_obj_response = test_clean_obj.get() 132 | get_obj_data = get_obj_response['Body'].read().decode('utf-8') 133 | 134 | with open(f"{self.DATA_PATH}/{non_pii_file_name}") as expected_output_file: 135 | expected_output = expected_output_file.read() 136 | 137 | assert expected_output == get_obj_data 138 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time) 139 | -------------------------------------------------------------------------------- /test/integ/test_redaction.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import 
datetime, timedelta 3 | 4 | import boto3 5 | import botocore 6 | from botocore.exceptions import ClientError 7 | from dateutil.tz import tzutc 8 | 9 | from integ.integ_base import BasicIntegTest 10 | 11 | 12 | class PiiRedactionIntegTest(BasicIntegTest): 13 | BUILD_DIR = '.aws-sam/build/PiiRedactionFunction/' 14 | 15 | @classmethod 16 | def setUpClass(cls) -> None: 17 | super().setUpClass() 18 | test_run_id = str(uuid.uuid4())[0:8] 19 | cls.lambda_function_arn = cls._create_function("pii_redaction", 'handler.redact_pii_documents_handler', 20 | test_run_id, {"LOG_LEVEL": "DEBUG"}, cls.lambda_role_arn, cls.BUILD_DIR) 21 | cls.s3ol_access_point_arn = cls._create_s3ol_access_point(test_run_id, cls.lambda_function_arn) 22 | 23 | @classmethod 24 | def tearDownClass(cls) -> None: 25 | cls.s3_ctrl.delete_access_point_for_object_lambda( 26 | AccountId=cls.account_id, 27 | Name=cls.s3ol_access_point_arn.split('/')[-1] 28 | ) 29 | 30 | cls.lambda_client.delete_function( 31 | FunctionName=cls.lambda_function_arn.split(':')[-1] 32 | ) 33 | super().tearDownClass() 34 | 35 | def tearDown(self) -> None: 36 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG"}) 37 | 38 | def test_request_timeout(self): 39 | self._update_lambda_env_variables(self.lambda_function_arn, 40 | {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "10000000", 41 | "CONTAINS_PII_ENTITIES_THREAD_COUNT": "1", 42 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"}) 43 | start_time = datetime.now(tz=tzutc()) 44 | # To keep total integration test time low we avoid making multiple retry attempts 45 | with self.assertRaises(ClientError) as e: 46 | session_config = botocore.config.Config( 47 | retries={ 48 | 'max_attempts': 0 49 | }) 50 | s3_client = boto3.client('s3', region_name=self.REGION_NAME, config=session_config) 51 | response = s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='5mb_pii_text') 52 | end_time = datetime.now(tz=tzutc()) + timedelta(minutes=1) 53 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, end_time) 54 | assert e.exception.response['Error']['Message'] == "Failed to complete document processing within time limit" 55 | assert e.exception.response['Error']['Code'] == "RequestTimeout" 56 | 57 | def test_redaction_lambda_with_default_entity_types(self): 58 | start_time = datetime.now(tz=tzutc()) 59 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt') 60 | test_obj_response = test_pii_obj.get() 61 | test_obj_data = test_obj_response['Body'].read().decode('utf-8') 62 | 63 | with open(f"{self.DATA_PATH}/pii_output.txt") as expected_output_file: 64 | expected_output = expected_output_file.read() 65 | 66 | assert expected_output == test_obj_data 67 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time, self.PII_ENTITY_TYPES_IN_TEST_DOC) 68 | self._validate_api_call_latency_published(self.s3ol_access_point_arn, start_time, is_pii_detection_performed=True) 69 | 70 | def test_redaction_lambda_with_pii_overridden_entity_types(self): 71 | self._update_lambda_env_variables(self.lambda_function_arn, 72 | {"LOG_LEVEL": "DEBUG", 73 | "PII_ENTITY_TYPES": "CREDIT_DEBIT_NUMBER,BANK_ROUTING,BANK_ACCOUNT_NUMBER,EMAIL,PASSWORD"}) 74 | start_time = datetime.now(tz=tzutc()) 75 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='pii_input.txt') 76 | with open(f"{self.DATA_PATH}/pii_bank_routing_redacted.txt") as expected_output_file: 77 | assert response['Body'].read().decode('utf-8') == 
expected_output_file.read() 78 | self._validate_pii_count_metric_published(self.s3ol_access_point_arn, start_time, 79 | ['BANK_ROUTING', 'CREDIT_DEBIT_NUMBER', 'BANK_ACCOUNT_NUMBER']) 80 | 81 | def test_redaction_lambda_fails_when_file_size_exceeds_limit(self): 82 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "500"}) 83 | test_pii_obj = self.s3.Object(self.s3ol_access_point_arn, 'pii_input.txt') 84 | with self.assertRaises(ClientError) as e: 85 | test_pii_obj.get() 86 | assert e.exception.response['Error']['Message'] == "Size of the requested object exceeds maximum file size supported" 87 | assert e.exception.response['Error']['Code'] == "EntityTooLarge" 88 | 89 | def test_redaction_lambda_succeeds_with_an_overridden_max_file_size_limit(self): 90 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "DOCUMENT_MAX_SIZE": "1500000", 91 | "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES": "5000"}) 92 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='1mb_pii_text') 93 | with open(f"{self.DATA_PATH}/1mb_pii_redacted_text") as expected_output_file: 94 | assert response['Body'].read().decode('utf-8') == expected_output_file.read() 95 | 96 | def test_redaction_lambda_with_unsupported_file(self): 97 | test_obj = self.s3.Object(self.s3ol_access_point_arn, 'RandomImage.png') 98 | with self.assertRaises(ClientError) as e: 99 | test_obj.get() 100 | assert e.exception.response['Error']['Message'] == "Unsupported file encountered for determining Pii" 101 | assert e.exception.response['Error']['Code'] == "UnexpectedContent" 102 | 103 | def test_classification_lambda_with_unsupported_file_handling_set_to_pass(self): 104 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "UNSUPPORTED_FILE_HANDLING": 'PASS'}) 105 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key='RandomImage.png') 106 | assert 'Body' in response 107 | 108 | def test_classification_lambda_with_partial_object(self): 109 | with self.assertRaises(ClientError) as e: 110 | self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key="pii_input.txt", Range="bytes=0-100") 111 | assert e.exception.response['Error']['Message'] == "HTTP Header Range is not supported" 112 | 113 | def test_classification_lambda_with_partial_object_allowed_with_versioned(self): 114 | file_name = 'pii_input.txt' 115 | self.s3_client.upload_file(f"{self.DATA_PATH}/{file_name}", self.bucket_name, file_name) 116 | versions = set() 117 | self._update_lambda_env_variables(self.lambda_function_arn, {"LOG_LEVEL": "DEBUG", "IS_PARTIAL_OBJECT_SUPPORTED": "TRUE"}) 118 | 119 | for version in self.s3_client.list_object_versions(Bucket=self.bucket_name)['Versions']: 120 | if version['Key'] == 'pii_input.txt': 121 | versions.add(version['VersionId']) 122 | assert len(versions) >= 2, f"Expected at least 2 different versions of {file_name}" 123 | for versionId in versions: 124 | response = self.s3_client.get_object(Bucket=self.s3ol_access_point_arn, Key=file_name, Range="bytes=0-100", 125 | VersionId=versionId) 126 | assert response['ContentRange'] == "bytes 0-100/611" 127 | assert response['ContentLength'] == 101 128 | assert response['ContentType'] == 'binary/octet-stream' 129 | assert response['VersionId'] == versionId 130 | 131 | 132 | def test_redaction_handler_no_pii(self): 133 | non_pii_file_name = 'clean.txt' 134 | test_clean_obj = self.s3.Object(self.s3ol_access_point_arn, non_pii_file_name) 135 | 
get_obj_response = test_clean_obj.get() 136 | get_obj_data = get_obj_response['Body'].read().decode('utf-8') 137 | 138 | with open(f"{self.DATA_PATH}/{non_pii_file_name}") as expected_output_file: 139 | expected_output = expected_output_file.read() 140 | 141 | assert expected_output == get_obj_data 142 | -------------------------------------------------------------------------------- /test/unit/test_comprehend_client.py: -------------------------------------------------------------------------------- 1 | from time import sleep, time 2 | from unittest import TestCase 3 | from unittest.mock import patch, MagicMock, call 4 | 5 | from botocore.awsrequest import AWSRequest 6 | 7 | from clients.comprehend_client import ComprehendClient 8 | from constants import BEGIN_OFFSET, END_OFFSET, ENTITY_TYPE, SCORE 9 | from data_object import Document 10 | 11 | 12 | class ComprehendClientTest(TestCase): 13 | @patch('clients.comprehend_client.boto3') 14 | def test_comprehend_client_constuctor(self, mocked_boto3): 15 | mocked_client = MagicMock() 16 | mocked_boto3.client.return_value = mocked_client 17 | comprehend_client = ComprehendClient(s3ol_access_point="some_access_point_arn") 18 | mocked_client.meta.events.register.assert_called_with('before-sign.comprehend.*', comprehend_client._add_session_header) 19 | request = AWSRequest() 20 | comprehend_client._add_session_header(request) 21 | assert len(request.headers.get('x-amzn-session-id')) >= 10 22 | assert comprehend_client.classification_executor_service._max_workers == 20 23 | assert comprehend_client.redaction_executor_service._max_workers == 8 24 | 25 | @patch('clients.comprehend_client.boto3') 26 | def test_comprehend_detect_pii_entities(self, mocked_boto3): 27 | DUMMY_PII_ENTITY = {BEGIN_OFFSET: 12, END_OFFSET: 14, ENTITY_TYPE: 'SSN', SCORE: 0.345} 28 | 29 | def mocked_api_call(**kwargs): 30 | sleep(0.1) 31 | return {'Entities': [DUMMY_PII_ENTITY], 'ResponseMetadata': {'RetryAttempts': 0}} 32 | 33 | mocked_client = MagicMock() 34 | mocked_boto3.client.return_value = mocked_client 35 | comprehend_client = ComprehendClient(s3ol_access_point="Some_random_access_point",pii_redaction_thread_count=5) 36 | mocked_client.detect_pii_entities.side_effect = mocked_api_call 37 | start_time = time() 38 | docs_with_pii_entity = comprehend_client.detect_pii_documents( 39 | documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(1, 20)], 40 | language='en') 41 | end_time = time() 42 | mocked_client.detect_pii_entities.assert_has_calls([call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(1, 20)]) 43 | 44 | assert len(comprehend_client.detection_metrics.metrics) == 38 45 | for i in range(0, 19, 2): 46 | assert comprehend_client.detection_metrics.metrics[i]['MetricName'] == 'ErrorCount' 47 | assert comprehend_client.detection_metrics.metrics[i]['Value'] == 0 48 | assert comprehend_client.detection_metrics.metrics[i + 1]['MetricName'] == 'Latency' 49 | assert len(comprehend_client.classify_metrics.metrics) == 0 50 | 51 | # should be around 0.4 : 20 calls with 5 thread counts , where each call taking 0.1 seconds to complete 52 | assert 0.4 <= end_time - start_time < 0.5 53 | for doc in docs_with_pii_entity: 54 | assert len(doc.pii_entities) == 1 55 | assert doc.pii_entities[0] == DUMMY_PII_ENTITY 56 | 57 | @patch('clients.comprehend_client.boto3') 58 | def test_comprehend_contains_pii_entities(self, mocked_boto3): 59 | classification_result = {'Labels': [{'Name': 'SSN', 'Score': 0.1234}], 'ResponseMetadata': {'RetryAttempts': 0}} 60 | 
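        # The mocked call below sleeps 0.1s to simulate Comprehend latency; the
        # elapsed-time assertions in this test rely on that delay to verify that
        # documents are fanned out across the classification thread pool.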
61 | def mocked_api_call(**kwargs): 62 | sleep(0.1) 63 | return classification_result 64 | 65 | mocked_client = MagicMock() 66 | mocked_boto3.client.return_value = mocked_client 67 | comprehend_client = ComprehendClient("some_access_point_arn", pii_classification_thread_count=2, pii_redaction_thread_count=5) 68 | mocked_client.contains_pii_entities.side_effect = mocked_api_call 69 | start_time = time() 70 | docs_with_pii_classification = comprehend_client.contains_pii_entities( 71 | documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(1, 4)], 72 | language='en') 73 | end_time = time() 74 | 75 | mocked_client.contains_pii_entities.assert_has_calls( 76 | [call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(1, 4)]) 77 | # should be around 0.2 : 4 calls with 2 thread counts , where each call taking 0.1 seconds to complete 78 | assert 0.2 <= end_time - start_time < 0.3 79 | assert len(comprehend_client.classify_metrics.metrics) == 6 80 | for i in range(0, 6, 2): 81 | assert comprehend_client.classify_metrics.metrics[i]['MetricName'] == 'ErrorCount' 82 | assert comprehend_client.classify_metrics.metrics[i + 1]['MetricName'] == 'Latency' 83 | 84 | assert len(comprehend_client.detection_metrics.metrics) == 0 85 | for doc in docs_with_pii_classification: 86 | assert doc.pii_classification == {'SSN': 0.1234} 87 | 88 | @patch('clients.comprehend_client.boto3') 89 | def test_comprehend_contains_pii_entities_failure(self, mocked_boto3): 90 | classification_result = {'Labels': [{'Name': 'SSN', 'Score': 0.1234}], 'ResponseMetadata': {'RetryAttempts': 0}} 91 | 92 | mocked_client = MagicMock() 93 | mocked_boto3.client.return_value = mocked_client 94 | comprehend_client = ComprehendClient(s3ol_access_point="Some_access_point_arn", pii_classification_thread_count=4) 95 | api_invocation_exception = Exception("Some unrecoverable error") 96 | mocked_client.contains_pii_entities.side_effect = [classification_result, classification_result, api_invocation_exception, 97 | classification_result] 98 | try: 99 | comprehend_client.contains_pii_entities(documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(0, 4)], 100 | language='en') 101 | 102 | assert False, "Expected an exception " 103 | except Exception as e: 104 | assert e == api_invocation_exception 105 | mocked_client.contains_pii_entities.assert_has_calls( 106 | [call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(0, 4)]) 107 | assert len(comprehend_client.classify_metrics.metrics) == 8 # 4 latency metrics and 1 fault 108 | assert len(comprehend_client.detection_metrics.metrics) == 0 109 | assert comprehend_client.classify_metrics.service_name == "Comprehend" 110 | assert comprehend_client.classify_metrics.s3ol_access_point_arn == "Some_access_point_arn" 111 | assert comprehend_client.classify_metrics.api == "ContainsPiiEntities" 112 | metric_count = {"ErrorCount": 0, "Latency": 0} 113 | for i in range(0, 8): 114 | metric_name = comprehend_client.classify_metrics.metrics[i]['MetricName'] 115 | metric_count[metric_name] += 1 116 | assert metric_count['ErrorCount'] == 4 117 | assert metric_count['Latency'] == 4 118 | 119 | @patch('clients.comprehend_client.boto3') 120 | def test_comprehend_detect_pii_entities_failure(self, mocked_boto3): 121 | DUMMY_PII_ENTITY = {'Entities': [{BEGIN_OFFSET: 12, END_OFFSET: 14, ENTITY_TYPE: 'SSN', SCORE: 0.345}], 122 | 'ResponseMetadata': {'RetryAttempts': 0}} 123 | mocked_client = MagicMock() 124 | mocked_boto3.client.return_value = mocked_client 125 | 
comprehend_client = ComprehendClient(s3ol_access_point="Some_access_point_arn")
126 |         api_invocation_exception = Exception("Some unrecoverable error")
127 |         mocked_client.detect_pii_entities.side_effect = [DUMMY_PII_ENTITY, DUMMY_PII_ENTITY, api_invocation_exception,
128 |                                                          DUMMY_PII_ENTITY]
129 |         try:
130 |             comprehend_client.detect_pii_documents(documents=[Document(text="Some Random 1mb_pii_text", ) for i in range(0, 4)],
131 |                                                    language='en')
132 | 
133 |             assert False, "Expected an exception "
134 |         except Exception as e:
135 |             assert e == api_invocation_exception
136 |         mocked_client.detect_pii_entities.assert_has_calls([call(Text="Some Random 1mb_pii_text", LanguageCode='en') for i in range(0, 4)])
137 |         assert len(comprehend_client.detection_metrics.metrics) == 8  # 4 Latency metrics and 4 ErrorCount metrics (including 1 fault)
138 |         assert len(comprehend_client.classify_metrics.metrics) == 0
139 |         assert comprehend_client.detection_metrics.service_name == "Comprehend"
140 |         assert comprehend_client.detection_metrics.s3ol_access_point_arn == "Some_access_point_arn"
141 |         assert comprehend_client.detection_metrics.api == "DetectPiiEntities"
142 |         metric_count = {"ErrorCount": 0, "Latency": 0}
143 |         for i in range(0, 8):
144 |             metric_name = comprehend_client.detection_metrics.metrics[i]['MetricName']
145 |             metric_count[metric_name] += 1
146 |         assert metric_count['ErrorCount'] == 4
147 |         assert metric_count['Latency'] == 4
148 | 
--------------------------------------------------------------------------------
/ACCESS_CONTROL_README.md:
--------------------------------------------------------------------------------
1 | # PII Access Control S3 Object Lambda function
2 | 
3 | This serverless app helps you control access to text files in S3 that contain PII (Personally Identifiable Information).
4 | This app deploys a Lambda function which can be attached to an S3 Object Lambda access point.
5 | The Lambda function internally uses AWS Comprehend to detect PII entities in the text.
6 | 
7 | ## App Architecture
8 | ![Architecture Diagram](images/architecture_access_control.gif)
9 | 
10 | The Lambda function is optimized to leverage Comprehend's ContainsPiiEntities API.
11 | 
12 | 1. The Lambda function is invoked with a request containing information about the S3 object to get and transform.
13 | 2. The request contains an S3 presigned url to fetch the requested object.
14 | 3. The data is split into chunks that are accepted by Comprehend's ContainsPiiEntities API, and the API is called with each chunk.
15 | 4. The responses from all chunks are aggregated.
16 | 5. The Lambda function calls back S3 with the response, i.e. either the text data or an error if the file contains PII.
17 | 6. If any failure happens while processing, the Lambda function returns an appropriate error response to S3, which is returned to the original caller.
18 | 7. The Lambda function exits with code 0, i.e. without any error, if no error occurred, and fails otherwise.
19 | 
20 | ## Installation Instructions
21 | 
22 | 1. [Create an AWS account](https://portal.aws.amazon.com/gp/aws/developer/registration/index.html) if you do not already have one and login
23 | 1. Go to the app's page on the [Serverless Application Repository](https://console.aws.amazon.com/lambda/home#/create/app?applicationId=arn:aws:serverlessrepo:us-east-1:839782855223:applications/ComprehendPiiAccessControlS3ObjectLambda)
24 | 1. Provide the required app parameters (see parameter details below) and click "Deploy". An example of reading an object through the deployed access point is shown below.
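Once deployed and wired to an S3 Object Lambda access point, callers read objects with a regular GetObject call against the access point ARN. The following is a minimal sketch, assuming boto3 and placeholder ARN/key values (not part of this repo):

```
import boto3
from botocore.exceptions import ClientError

s3 = boto3.client('s3')
# For S3 Object Lambda, the access point ARN is passed as the Bucket parameter.
ACCESS_POINT_ARN = 'arn:aws:s3-object-lambda:us-east-1:123456789012:accesspoint/my-pii-access-control-ap'

try:
    response = s3.get_object(Bucket=ACCESS_POINT_ARN, Key='pii_input.txt')
    print(response['Body'].read().decode('utf-8'))  # returned only when no PII is found
except ClientError as e:
    # The function responds with AccessDenied when the document contains PII.
    print(e.response['Error']['Code'], '-', e.response['Error']['Message'])
```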
25 | 
26 | ## Parameters
27 | The following parameters can be tuned to get the desired behavior
28 | #### Environment variables
29 | The following environment variables can be set on the Lambda function to get the desired behaviour
30 | 1. `LOG_LEVEL` - Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc. Default: `INFO`.
31 | 1. `UNSUPPORTED_FILE_HANDLING` Handling logic for unsupported files. Valid values are `PASS` and `FAIL` (Default: `FAIL`). If set to `FAIL`, it will throw UnsupportedFileException when the requested object is of an unsupported type.
32 | 1. `IS_PARTIAL_OBJECT_SUPPORTED` Whether to support partial objects or not. Accessing a partial object through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy. Valid values are `TRUE` and `FALSE`. Default: `FALSE`.
33 | 1. `DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES` Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API for classifying the PII entity types present in the doc. Default: 50000.
34 | 1. `PII_ENTITY_TYPES` : List of comma separated PII entity types to be considered for access control. Refer to [Comprehend's documentation page](https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types) for the list of supported PII entity types. Default: `ALL`, which signifies all entity types that Comprehend supports.
35 | 1. `SUBSEGMENT_OVERLAPPING_TOKENS` : Number of tokens/words to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 20.
36 | 1. `DOCUMENT_MAX_SIZE` : Maximum document size (in bytes) that this function can process; an exception is thrown for larger documents. Default: 102400.
37 | 1. `CONFIDENCE_THRESHOLD` : The minimum prediction confidence score above which PII classification and detection results are considered final. Valid range (0.5 to 1.0). Default: 0.5.
38 | 1. `MAX_CHARS_OVERLAP` : Maximum characters to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 200.
39 | 1. `DEFAULT_LANGUAGE_CODE` : Default language of the text to be processed. This code will be used for interacting with Comprehend. Default: en.
40 | 1. `CONTAINS_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 20.
41 | 1. `PUBLISH_CLOUD_WATCH_METRICS` : Determines whether or not to publish metrics to CloudWatch. Default: true.
42 | 
43 | #### Runtime variables
44 | You can add the following arguments in the S3 Object Lambda access point configuration payload to override the default values used by the Lambda function. These values take precedence over the environment variables. Provide these variables as a JSON string like the following example
45 | ```
46 | ...
47 | "payload": "{\"pii_entity_types\" : [\"CREDIT_DEBIT_NUMBER\"],\"confidence_threshold\":0.6,\"language_code\":\"en\"}"
48 | ...
49 | ```
50 | Use these parameters to get desired behaviors from different access point configurations attached to the same Lambda function (see the sketch after this list).
51 | 1. `pii_entity_types` : List of PII entity types to be considered for access control, e.g. `["SSN","CREDIT_DEBIT_NUMBER"]`.
52 | 1. `confidence_threshold` : The minimum prediction confidence score above which PII classification and detection results are considered final.
53 | 1. `language_code`: Language of the text. This will be used to interact with Comprehend.
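The payload above is supplied as the `FunctionPayload` of the Object Lambda access point configuration. A minimal sketch using boto3's S3Control API, with placeholder account ID, names, and ARNs:

```
import boto3

s3control = boto3.client('s3control')
s3control.create_access_point_for_object_lambda(
    AccountId='123456789012',
    Name='pii-access-control-ap',
    Configuration={
        'SupportingAccessPoint': 'arn:aws:s3:us-east-1:123456789012:accesspoint/my-supporting-ap',
        'TransformationConfigurations': [{
            'Actions': ['GetObject'],
            'ContentTransformation': {
                'AwsLambda': {
                    'FunctionArn': 'arn:aws:lambda:us-east-1:123456789012:function:PiiAccessControlFunction',
                    # Runtime variables are passed as a JSON string payload.
                    'FunctionPayload': '{"pii_entity_types": ["CREDIT_DEBIT_NUMBER"], "confidence_threshold": 0.6, "language_code": "en"}',
                }
            }
        }],
    },
)
```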
54 | 
55 | ## App Outputs
56 | 
57 | #### Successful response
58 | If the text file contains no PII, its content is returned unchanged in the GetObject response.
59 | #### Error responses
60 | The Lambda function forwards the standard [S3 error responses](https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html) it receives while downloading the file from S3.
61 | 
62 | In addition, the following error responses may be thrown by the Lambda function:
63 | 
64 | |Status Code|Error Code|Error Message|Description|
65 | |---|---|---|---|
66 | | BAD_REQUEST_400 | InvalidRequest | Lambda function has been incorrectly setup | An incorrect configuration prevented the Lambda function from even starting to handle the incoming events|
67 | | BAD_REQUEST_400 | UnexpectedContent | Unsupported file encountered for determining PII | This error is thrown when the caller tries to get a file that is not valid UTF-8 (e.g. an image) and the UNSUPPORTED_FILE_HANDLING variable is set to FAIL|
68 | | BAD_REQUEST_400 | EntityTooLarge | Size of the requested object exceeds maximum file size supported | This error is thrown when the caller tries to get an object which is beyond the max file size supported|
69 | | BAD_REQUEST_400 | RequestTimeout | Failed to complete document processing within time limit | This error is thrown when the Lambda is not able to complete the processing of the document within the time limit. This could be because your file size is too big or you are getting throttled by either S3 or Comprehend.|
70 | | INTERNAL_SERVER_ERROR_500 | InternalError | An internal error occurred while processing the file | Any other error occurred while processing the object |
71 | | FORBIDDEN_403 | AccessDenied | Document Contains PII | The requested document has been inferred to contain PII|
72 | 
73 | ## Metrics
74 | Metrics are published after each invocation of the Lambda function on a best-effort basis (failures in CloudWatch metric publishing are ignored)
75 | 
76 | All metrics will be under the Namespace: ComprehendS3ObjectLambda
77 | 
78 | ### Metrics for processed document
79 | |MetricName|Description|Unit|Dimensions|
80 | |---|---|---|---|
81 | |PiiDocumentsProcessed|Emitted after processing a document that contains PII|Count|S3ObjectLambdaAccessPoint, Language|
82 | |DocumentsProcessed|Emitted after processing any document|Count|S3ObjectLambdaAccessPoint, Language|
83 | |PiiDocumentTypesProcessed|Emitted after processing a document that contains PII, for each type of PII of interest|Count|S3ObjectLambdaAccessPoint, Language, PiiEntityType|
84 | 
85 | ### Metrics for Comprehend operations
86 | |MetricName|Description|Unit|Dimensions|
87 | |---|---|---|---|
88 | |Latency|The latency of the Comprehend DetectPiiEntities API|Milliseconds|Comprehend, DetectPiiEntities|
89 | |Latency|The latency of the Comprehend ContainsPiiEntities API|Milliseconds|Comprehend, ContainsPiiEntities|
90 | |ErrorCount|The error count of the Comprehend DetectPiiEntities API|Count|Comprehend, DetectPiiEntities|
91 | |ErrorCount|The error count of the Comprehend ContainsPiiEntities API|Count|Comprehend, ContainsPiiEntities|
92 | 
93 | ### Metrics for S3 operations
94 | |MetricName|Description|Unit|Dimensions|
95 | |---|---|---|---|
96 | |Latency|The latency of the S3 WriteGetObjectResponse API|Milliseconds|S3, WriteGetObjectResponse|
97 | |Latency|The latency of downloading a file from a presigned S3 url|Milliseconds|S3, DownloadPresignedUrl|
98 | |ErrorCount|The error count of the S3 WriteGetObjectResponse API|Count|S3, WriteGetObjectResponse|
99 | |ErrorCount|The error count of downloading a file from a presigned S3 url|Count|S3, DownloadPresignedUrl|
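As a sketch of how to consume these metrics (the access point ARN is a placeholder; metric and dimension names are as listed above):

```
import boto3
from datetime import datetime, timedelta

cloudwatch = boto3.client('cloudwatch')
# Sum of documents flagged for a specific PII type over the last hour.
stats = cloudwatch.get_metric_statistics(
    Namespace='ComprehendS3ObjectLambda',
    MetricName='PiiDocumentTypesProcessed',
    Dimensions=[
        {'Name': 'S3ObjectLambdaAccessPoint',
         'Value': 'arn:aws:s3-object-lambda:us-east-1:123456789012:accesspoint/my-pii-access-control-ap'},
        {'Name': 'Language', 'Value': 'en'},
        {'Name': 'PiiEntityType', 'Value': 'SSN'},
    ],
    StartTime=datetime.utcnow() - timedelta(hours=1),
    EndTime=datetime.utcnow(),
    Period=300,
    Statistics=['Sum'],
)
print(stats['Datapoints'])
```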
100 | 
101 | ## License Summary
102 | 
103 | This code is made available under the MIT-0 license. See the LICENSE file.
--------------------------------------------------------------------------------
/test/unit/test_s3_client.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | from unittest.mock import patch, MagicMock
3 | 
4 | from clients.s3_client import S3Client
5 | from constants import BEGIN_OFFSET, END_OFFSET, ENTITY_TYPE, SCORE, S3_STATUS_CODES, S3_ERROR_CODES, RANGE
6 | from exceptions import S3DownloadException, FileSizeLimitExceededException, UnsupportedFileException
7 | 
8 | PRESIGNED_URL_TEST = "https://s3ol-classifier.s3.amazonaws.com/test.txt"
9 | 
10 | 
11 | class MockResponse:
12 |     def __init__(self, content, status_code, headers):
13 |         self.status_code = status_code
14 |         self.content = content
15 |         self.headers = headers
16 | 
17 | 
18 | def get_s3_xml_response(code: str, message: str = '') -> str:
19 |     return f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Error><Code>{code}</Code><Message>{message}</Message></Error>"
20 | 
21 | 
22 | class S3ClientTest(TestCase):
23 |     @patch('clients.s3_client.boto3')
24 |     def test_s3_client_respond_back_with_error(self, mocked_boto3):
25 |         mocked_client = MagicMock()
26 |         mocked_boto3.client.return_value = mocked_client
27 |         s3_client = S3Client(s3ol_access_point="Random_access_point")
28 |         s3_client.respond_back_with_error(status_code=S3_STATUS_CODES.PRECONDITION_FAILED_412,
29 |                                           error_code=S3_ERROR_CODES.PreconditionFailed, error_message="Some Error",
30 |                                           request_route="Route", request_token="q2334")
31 | 
32 |         mocked_client.write_get_object_response.assert_called_once_with(StatusCode=412,
33 |                                                                         ErrorCode='PreconditionFailed',
34 |                                                                         ErrorMessage="Some Error",
35 |                                                                         RequestRoute='Route', RequestToken="q2334")
36 | 
37 |     @patch('clients.s3_client.boto3')
38 |     def test_s3_client_respond_back_with_data_default_status_code(self, mocked_boto3):
39 |         mocked_client = MagicMock()
40 |         mocked_boto3.client.return_value = mocked_client
41 |         s3_client = S3Client(s3ol_access_point="Random_access_point")
42 |         s3_client.respond_back_with_data(data='SomeData',
43 |                                          headers={"ContentRange": "0-100", "SomeRandomHeader": '0123', "Content-Length": "101"},
44 |                                          request_route="Route", request_token="q2334")
45 | 
46 |         mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentLength=101,
47 |                                                                         RequestRoute='Route', RequestToken="q2334",
48 |                                                                         StatusCode=200)
49 | 
50 |     @patch('clients.s3_client.boto3')
51 |     def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3):
52 |         mocked_client = MagicMock()
53 |         mocked_boto3.client.return_value = mocked_client
54 |         s3_client = S3Client(s3ol_access_point="Random_access_point")
55 |         s3_client.respond_back_with_data(data='SomeData', headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'},
56 |                                          request_route="Route", request_token="q2334", status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206)
57 | 
58 |         mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentRange="0-1200",
59 |                                                                         RequestRoute='Route', RequestToken="q2334",
60 |                                                                         StatusCode=206)
61 | 
62 |     @patch('clients.s3_client.requests.Session.get',
63 |            side_effect=lambda *args, **kwargs: MockResponse(b'Test', 200, {'Content-Length': '4'}))
64 |     def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get): 
65 | s3_client = S3Client(s3ol_access_point="Random_access_point") 66 | http_header = {'some-header': 'header-value'} 67 | text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header) 68 | assert text == 'Test' 69 | assert response_http_headers == {'Content-Length': '4'} 70 | assert status_code == S3_STATUS_CODES.OK_200 71 | mocked_get.assert_called_with(PRESIGNED_URL_TEST, timeout=10, headers=http_header) 72 | 73 | @patch('clients.s3_client.requests.Session.get', 74 | side_effect=lambda *args, **kwargs: MockResponse(b'Test', 206, {'Content-Length': '100'})) 75 | def test_s3_client_download_partial_file_from_presigned_url(self, mocked_get): 76 | s3_client = S3Client(s3ol_access_point="Random_access_point") 77 | http_header = {'some-header': 'header-value'} 78 | text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header) 79 | assert text == 'Test' 80 | assert response_http_headers == {'Content-Length': '100'} 81 | assert status_code == S3_STATUS_CODES.PARTIAL_CONTENT_206 82 | mocked_get.assert_called_with(PRESIGNED_URL_TEST, timeout=10, headers=http_header) 83 | 84 | @patch('clients.s3_client.requests.Session.get', 85 | side_effect=lambda *args, **kwargs: MockResponse(b'Test', 400, {'Content-Length': '4'})) 86 | def test_s3_client_download_file_from_presigned_url_400_from_get(self, mocked_get): 87 | s3_client = S3Client(s3ol_access_point="Random_access_point") 88 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 89 | 90 | assert mocked_get.call_count == 5 91 | 92 | @patch('clients.s3_client.requests.Session.get', 93 | side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200, {'Content-Length': str(11 * 1024 * 1024)})) 94 | def test_s3_client_download_file_from_presigned_url_file_size_limit_exceeded(self, mocked_get): 95 | s3_client = S3Client(s3ol_access_point="Random_access_point") 96 | self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 97 | 98 | mocked_get.assert_called_once() 99 | 100 | @patch('clients.s3_client.requests.Session.get', 101 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('AccessDenied').encode('utf-8'), 200, 102 | {'Content-Length': '4'})) 103 | def test_s3_client_download_file_from_presigned_url_access_denied(self, mocked_get): 104 | s3_client = S3Client(s3ol_access_point="Random_access_point") 105 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 106 | 107 | mocked_get.assert_called_once() 108 | 109 | @patch('clients.s3_client.requests.Session.get', 110 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('UnknownError').encode('utf-8'), 200, 111 | {'Content-Length': '4'})) 112 | def test_s3_client_download_file_from_presigned_url_unknown_error(self, mocked_get): 113 | s3_client = S3Client(s3ol_access_point="Random_access_point") 114 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 115 | 116 | assert mocked_get.call_count == 5 117 | 118 | @patch('clients.s3_client.requests.Session.get', 119 | side_effect=lambda *args, **kwargs: MockResponse(bytearray.fromhex('ff'), 200, {'Content-Length': '4'})) 120 | def test_s3_client_download_file_from_presigned_unicode_decode_error(self, mocked_get): 121 | s3_client = 
S3Client(s3ol_access_point="Random_access_point") 122 | self.assertRaises(UnsupportedFileException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 123 | 124 | mocked_get.assert_called_once() 125 | 126 | @patch('clients.s3_client.requests.Session.get', 127 | side_effect=lambda *args, **kwargs: MockResponse(bytearray.fromhex('ff'), 200, {'Content-Length': '4'})) 128 | def test_s3_client_download_file_from_presigned_unicode_decode_error_error(self, mocked_get): 129 | s3_client = S3Client(s3ol_access_point="Random_access_point") 130 | self.assertRaises(UnsupportedFileException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 131 | 132 | mocked_get.assert_called_once() 133 | 134 | @patch('clients.s3_client.requests.Session.get', 135 | side_effect=lambda *args, **kwargs: MockResponse(get_s3_xml_response('InternalError').encode('utf-8'), 200, 136 | {'Content-Length': '4'})) 137 | def test_s3_client_download_file_from_presigned_retry(self, mocked_get): 138 | s3_client = S3Client(s3ol_access_point="Random_access_point") 139 | self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) 140 | 141 | assert mocked_get.call_count == 5 142 | -------------------------------------------------------------------------------- /test/unit/test_processors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timeit 3 | from random import shuffle 4 | from unittest import TestCase 5 | 6 | from constants import REPLACE_WITH_PII_ENTITY_TYPE 7 | from data_object import Document, RedactionConfig 8 | from exceptions import InvalidConfigurationException 9 | from processors import Redactor, Segmenter 10 | 11 | this_module_path = os.path.dirname(__file__) 12 | 13 | 14 | class ProcessorsTest(TestCase): 15 | def test_segmenter_basic_text(self): 16 | segmentor = Segmenter(50, overlap_tokens=3) 17 | original_text = "Barack Hussein Obama II is an American politician and attorney who served as the " \ 18 | "44th president of the United States from 2009 to 2017." 19 | segments = segmentor.segment(original_text) 20 | expected_segments = [ 21 | "Barack Hussein Obama II is an American politician ", 22 | "an American politician and attorney who served as ", 23 | "who served as the 44th president of the United ", 24 | "of the United States from 2009 to 2017."] 25 | for expected_segment, actual_segment in zip(expected_segments, segments): 26 | assert expected_segment == actual_segment.text 27 | shuffle(segments) 28 | assert segmentor.de_segment(segments).text == original_text 29 | 30 | 31 | def test_segmenter_no_segmentation_needed(self): 32 | segmentor = Segmenter(5000, overlap_tokens=3) 33 | original_text = "Barack Hussein Obama II is an American politician and attorney who served as the " \ 34 | "44th president of the United States from 2009 to 2017." 35 | segments = segmentor.segment(original_text) 36 | assert len(segments) == 1 37 | assert segments[0].text == original_text 38 | assert segmentor.de_segment(segments).text == original_text 39 | 40 | def test_segmenter_max_chars_limit(self): 41 | segmentor = Segmenter(50, overlap_tokens=3, max_overlapping_chars=20) 42 | original_text = "BarackHusseinObamaIIisanAmerican politicianandattorneywhoservedasthe " \ 43 | "44th president of the United Statesfrom2009to2017." 
44 | segments = segmentor.segment(original_text) 45 | expected_segments = [ 46 | "BarackHusseinObamaIIisanAmerican ", 47 | "nObamaIIisanAmerican politician", 48 | "nAmerican politicianandattorneywhoservedasthe ", 49 | "torneywhoservedasthe 44th president of the United ", 50 | "of the United Statesfrom2009to2017.", ] 51 | for expected_segment, actual_segment in zip(expected_segments, segments): 52 | assert expected_segment == actual_segment.text 53 | shuffle(segments) 54 | assert segmentor.de_segment(segments).text == original_text 55 | 56 | def test_segmenter_unicode_chars(self): 57 | segmentor = Segmenter(100, overlap_tokens=3) 58 | original_text = "ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ᗷᙓ ò¥¥¥¥¥¥¥ᗢᖇᓮᘐᓰﬡᗩᒪ ℬ℮ ¢◎øł Bᴇ ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी विषयों पर प्रामाणिक और उपयोग, " \ 59 | "परिवर्तन व पुनर्वितरण के लिए स्वतन्त्र ज्ञानकोश बनाने hànbǎobāo, hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 – hamburger" 60 | segments = segmentor.segment(original_text) 61 | expected_segments = [ 62 | "ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ʕ•́ᴥ•̀ʔっ♡ Emoticons 😜 ᗷᙓ ", 63 | "Emoticons 😜 ᗷᙓ ò¥¥¥¥¥¥¥ᗢᖇᓮᘐᓰﬡᗩᒪ ℬ℮ ¢◎øł Bᴇ ", 64 | "ℬ℮ ¢◎øł Bᴇ ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी ", 65 | "ʏᴏᴜʀsᴇʟғ विकिपीडिया सभी विषयों पर ", 66 | "सभी विषयों पर प्रामाणिक और उपयोग, ", 67 | "प्रामाणिक और उपयोग, परिवर्तन व ", 68 | "उपयोग, परिवर्तन व पुनर्वितरण के लिए ", 69 | "पुनर्वितरण के लिए स्वतन्त्र ", 70 | "के लिए स्वतन्त्र ज्ञानकोश बनाने hànbǎobāo, ", 71 | "ज्ञानकोश बनाने hànbǎobāo, hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 ", 72 | "hànbǎo 汉堡包/漢堡包, 汉堡/漢堡 – hamburger"] 73 | assert len(expected_segments) == len(segments) 74 | for expected_segment, actual_segment in zip(expected_segments, segments): 75 | assert expected_segment == actual_segment.text 76 | assert segmentor.de_segment(segments).text == original_text 77 | 78 | def test_desegment_overlapping_results(self): 79 | segments = [ 80 | Document(text="Some Random SSN Some Random email-id Some Random name and address and some credit card number", char_offset=0, 81 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'NAME': 0.124, 'ADDRESS': 0.976}, 82 | pii_entities=[{'Score': 0.234, 'Type': 'SSN', 'BeginOffset': 12, 'EndOffset': 36}, 83 | {'Score': 0.765, 'Type': 'EMAIL', 'BeginOffset': 28, 'EndOffset': 36}, 84 | {'Score': 0.534, 'Type': 'NAME', 'BeginOffset': 49, 'EndOffset': 53}, 85 | {'Score': 0.234, 'Type': 'ADDRESS', 'BeginOffset': 58, 'EndOffset': 65}]), 86 | Document(text="Some Random name and address and some credit card number", char_offset=37, 87 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'USERNAME': 0.424, 'ADDRESS': 0.976}, 88 | pii_entities=[{'Score': 0.234, 'Type': 'USERNAME', 'BeginOffset': 12, 'EndOffset': 16}, 89 | {'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 17, 'EndOffset': 28}, 90 | {'Score': 0.234, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 38, 'EndOffset': 56}])] 91 | segmentor = Segmenter(5000) 92 | expected_merged_document = Document( 93 | text="Some Random SSN Some Random email-id Some Random name and address and some credit card number", char_offset=37, 94 | pii_classification={'SSN': 0.234, 'EMAIL': 0.765, 'NAME': 0.124, 'USERNAME': 0.424, 'ADDRESS': 0.976}, 95 | pii_entities=[{'Score': 0.234, 'Type': 'SSN', 'BeginOffset': 12, 'EndOffset': 36}, 96 | {'Score': 0.765, 'Type': 'EMAIL', 'BeginOffset': 28, 'EndOffset': 36}, 97 | {'Score': 0.534, 'Type': 'NAME', 'BeginOffset': 49, 'EndOffset': 53}, 98 | {'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 54, 'EndOffset': 65}, 99 | {'Score': 0.234, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 75, 'EndOffset': 93}]) 100 | actual_merged_doc = 
segmentor.de_segment(segments)
101 | assert expected_merged_document.text == actual_merged_doc.text
102 | assert expected_merged_document.pii_classification == actual_merged_doc.pii_classification
103 | assert expected_merged_document.pii_entities == actual_merged_doc.pii_entities
104 |
105 |
106 | def test_is_overlapping_annotations(self):
107 | segmentor = Segmenter(5000)
108 | assert segmentor._is_overlapping_annotations({'Score': 0.634, 'Type': 'ADDRESS', 'BeginOffset': 54, 'EndOffset': 65},
109 | {'Score': 0.234, 'Type': 'ADDRESS', 'BeginOffset': 58, 'EndOffset': 65}) == 0
110 |
111 | def test_segmenter_scalability_test(self):
112 | # 1MB of text should be segmented with around 30 ms latency
113 | setup = """
114 | import os
115 | from processors import Segmenter
116 |
117 | text=" Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
118 | one_mb_text=" "
119 | for i in range(7751):
120 | one_mb_text += text
121 | segmenter = Segmenter(overlap_tokens=20, max_doc_size=5000)
122 | """.format(this_module_path)
123 | segmentation_time = timeit.timeit("segmenter.segment(one_mb_text)", setup=setup, number=100)
124 | assert segmentation_time < 15
125 |
126 | def test_redaction_with_no_entities(self):
127 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
128 | redactor = Redactor(RedactionConfig())
129 | redacted_text = redactor.redact(text, [])
130 | assert text == redacted_text
131 |
132 | def test_redaction_default_redaction_config(self):
133 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
134 | redactor = Redactor(RedactionConfig())
135 | redacted_text = redactor.redact(text, [{'Score': 0.234, 'Type': 'NAME', 'BeginOffset': 6, 'EndOffset': 16},
136 | {'Score': 0.765, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 77, 'EndOffset': 96}])
137 | expected_redaction = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account ******************* has a minimum payment of $24.53"
138 | assert expected_redaction == redacted_text
139 |
140 | def test_redaction_with_replace_entity_type(self):
141 | text = "Hello Zhang Wei. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
142 | redactor = Redactor(RedactionConfig(pii_entity_types=['NAME'], mask_mode=REPLACE_WITH_PII_ENTITY_TYPE, confidence_threshold=0.6))
143 | redacted_text = redactor.redact(text, [{'Score': 0.634, 'Type': 'NAME', 'BeginOffset': 6, 'EndOffset': 15},
144 | {'Score': 0.765, 'Type': 'CREDIT_DEBIT_NUMBER', 'BeginOffset': 77, 'EndOffset': 96}])
145 | expected_redaction = "Hello [NAME]. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0000 has a minimum payment of $24.53"
146 | assert expected_redaction == redacted_text
147 |
148 | def test_segmenter_constructor_invalid_args(self):
149 | try:
150 | Segmenter(3)
151 | assert False, "Expected an InvalidConfigurationException"
152 | except InvalidConfigurationException:
153 | return
154 |
--------------------------------------------------------------------------------
/REDACTION_README.md:
--------------------------------------------------------------------------------
1 | # PII Redaction S3 Object Lambda function
2 |
3 | This serverless app helps you redact PII (Personally Identifiable Information) from valid text files present in S3.
4 | This app deploys a Lambda function that can be attached to an S3 Object Lambda access point.
5 | The Lambda function internally uses Amazon Comprehend to detect PII entities in the text.
6 |
7 | ## App Architecture
8 | ![Architecture Diagram](images/architecture_redaction.gif)
9 |
10 | The Lambda function is optimized to use Comprehend's ContainsPiiEntities and DetectPiiEntities APIs efficiently, saving time and cost for large documents.
11 | The ContainsPiiEntities API is a high-throughput API that acts as a filter, so only documents that contain PII are sent to the DetectPiiEntities API for the compute-heavy PII redaction task.
12 |
13 | 1. The Lambda function is invoked with a request containing information about the S3 object to get and transform.
14 | 2. The request contains an S3 presigned URL to fetch the requested object.
15 | 3. The data is split into chunks accepted by Comprehend's ContainsPiiEntities API, and the API is called with each chunk.
16 | 4. Each chunk that contains PII data is split further into smaller chunks of the maximum size supported by Comprehend's DetectPiiEntities API.
17 | 5. For each of these smaller PII chunks, Comprehend's DetectPiiEntities API is invoked to detect the text spans containing the PII entities of interest.
18 | 6. The responses from all chunks are aggregated.
19 | 7. The Lambda function calls back S3 with the response, i.e., the redacted document.
20 | 8. If any failure happens during processing, the Lambda function returns an appropriate error response to S3, which is returned to the original caller.
21 | 9. The Lambda function exits with code 0 (i.e., without error) if no error occurred; otherwise it fails.
22 |
23 | ## Installation Instructions
24 |
25 | 1. [Create an AWS account](https://portal.aws.amazon.com/gp/aws/developer/registration/index.html) if you do not already have one and login
26 | 1. Go to the app's page on the [Serverless Application Repository](https://console.aws.amazon.com/lambda/home#/create/app?applicationId=arn:aws:serverlessrepo:us-east-1:839782855223:applications/ComprehendPiiRedactionS3ObjectLambda)
27 | 1. Provide the required app parameters (see parameter details below) and click "Deploy"
28 |
29 | ## Parameters
30 | The following parameters can be tuned to get the desired behavior.
31 | #### Environment variables
32 | The following environment variables can be set on the Lambda function to get the desired behavior.
33 | 1. `LOG_LEVEL` - Log level for Lambda function logging, e.g., ERROR, INFO, DEBUG, etc. Default: `INFO`.
34 | 1. `UNSUPPORTED_FILE_HANDLING` Handling logic for unsupported files. Valid values are `PASS` and `FAIL` (Default: `FAIL`). If set to `FAIL`, an UnsupportedFileException is thrown when the requested object is of an unsupported type.
35 | 1. `IS_PARTIAL_OBJECT_SUPPORTED` Whether to support partial objects or not. Accessing a partial object through HTTP headers such as byte-range can corrupt the object and/or affect PII detection accuracy. Valid values are `TRUE` and `FALSE`. Default: `FALSE`.
36 | 1. `DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES` Maximum document size (in bytes) to be used for making calls to Comprehend's ContainsPiiEntities API for classifying the PII entity types present in the document. Default: 50000.
37 | 1. `DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES`: Maximum document size (in bytes) to be used for making calls to Comprehend's DetectPiiEntities API. Default: 5120, i.e., 5 KB.
38 | 1. `PII_ENTITY_TYPES` : List of comma-separated PII entity types to be considered for redaction. Refer to [Comprehend's documentation page](https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types) for the list of supported PII entity types. Default: `ALL`, which signifies all entity types that Comprehend supports.
39 | 1. `MASK_CHARACTER` : A character that replaces each character in the redacted PII entity. Default: `*`.
40 | 1. `MASK_MODE` : Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values: `MASK` and `REPLACE_WITH_PII_ENTITY_TYPE`. Default: `MASK`.
41 | 1. `SUBSEGMENT_OVERLAPPING_TOKENS` : Number of tokens/words to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 20.
42 | 1. `DOCUMENT_MAX_SIZE` : Maximum document size (in bytes) that this function can process; an exception is thrown for documents that are too large.
43 | 1. `CONFIDENCE_THRESHOLD` : The minimum prediction confidence score above which PII classification and detection results are considered final. Valid range: 0.5 to 1.0. Default: 0.5.
44 | 1. `MAX_CHARS_OVERLAP` : Maximum characters to overlap among segments of a document in case chunking is needed because of the maximum document size limit. Default: 2.
45 | 1. `DEFAULT_LANGUAGE_CODE` : Default language of the text to be processed. This code will be used for interacting with Comprehend. Default: en.
46 | 1. `DETECT_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's DetectPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 8.
47 | 1. `CONTAINS_PII_ENTITIES_THREAD_COUNT` : Number of threads to use for calling Comprehend's ContainsPiiEntities API. This controls the number of simultaneous calls that will be made from this Lambda function. Default: 20.
48 | 1. `PUBLISH_CLOUD_WATCH_METRICS` : This determines whether or not to publish metrics to CloudWatch. Default: true.
49 |
50 | #### Runtime variables
51 | You can add the following arguments to the S3 Object Lambda access point configuration payload to override the defaults used by the Lambda function. These values take precedence over environment variables. Provide these variables as a JSON string, as in the following example.
52 | ```
53 | ...
54 | "payload": "{\"pii_entity_types\" : [\"CREDIT_DEBIT_NUMBER\"],\"mask_mode\":\"MASK\", \"mask_character\" : \"*\",\"confidence_threshold\":0.6,\"language_code\":\"en\"}"
55 | ...
56 | ```
57 | Use these parameters to get desired behaviors from different access point configurations attached to the same Lambda function.
58 | 1. `pii_entity_types` : List of PII entity types to be considered for redaction, e.g., `["SSN","CREDIT_DEBIT_NUMBER"]`
59 | 1. `mask_mode` : Specifies whether the PII entity is redacted with the mask character or the entity type. Valid values: `MASK` and `REPLACE_WITH_PII_ENTITY_TYPE`.
60 | 1. `mask_character` : A character that replaces each character in the redacted PII entity.
61 | 1. `confidence_threshold` : The minimum prediction confidence score above which PII classification and detection results are considered final.
62 | 1. `language_code`: Language of the text. This will be used to interact with Comprehend.
63 |
64 | ## App Outputs
65 |
66 | #### Successful response
67 | If the text file contains PII, it is redacted and returned as the output of the GetObject API.
68 | #### Error responses
69 | The Lambda function forwards the standard [S3 error responses](https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html) it receives while downloading the file from S3.
70 |
71 | In addition, the following error responses can be thrown by the Lambda function:
72 |
73 | |Status Code|Error Code|Error Message|Description|
74 | |---|---|---|---|
75 | | BAD_REQUEST_400 | InvalidRequest | Lambda function has been incorrectly setup | An incorrect configuration that prevents the Lambda function from even starting to handle incoming events|
76 | | BAD_REQUEST_400 | UnexpectedContent | Unsupported file encountered for determining PII | Thrown when the caller tries to get an invalid UTF-8 file (e.g., an image) and the UNSUPPORTED_FILE_HANDLING variable is set to FAIL|
77 | | BAD_REQUEST_400 | EntityTooLarge | Size of the requested object exceeds maximum file size supported | Thrown when the caller tries to get an object that exceeds the maximum supported file size|
78 | | BAD_REQUEST_400 | RequestTimeout | Failed to complete document processing within time limit | Thrown when the Lambda function is not able to complete processing the document within the time limit. This could be because the file is too big or because requests are being throttled by either S3 or Comprehend.|
79 | | INTERNAL_SERVER_ERROR_500 | InternalError | An internal error occurred while processing the file | Any other error that occurred while processing the object |
80 |
81 | ## Metrics
82 | Metrics are published after each invocation of the Lambda function on a best-effort basis (failures in CloudWatch metric publishing are ignored).
83 |
84 | All metrics are published under the namespace: ComprehendS3ObjectLambda
85 |
86 | ### Metrics for processed documents
87 | |MetricName|Description|Unit|Dimensions|
88 | |---|---|---|---|
89 | |PiiDocumentsProcessed|Emitted after processing a document that contains PII|Count|S3ObjectLambdaAccessPoint, Language|
90 | |DocumentsProcessed|Emitted after processing any document|Count|S3ObjectLambdaAccessPoint, Language|
91 | |PiiDocumentTypesProcessed|Emitted after processing a document that contains PII, once for each type of PII of interest|Count|S3ObjectLambdaAccessPoint, Language, PiiEntityType|
92 |
93 | ### Metrics for Comprehend operations
94 | |MetricName|Description|Unit|Dimensions|
95 | |---|---|---|---|
96 | |Latency|The latency of the Comprehend DetectPiiEntities API|Milliseconds|Comprehend, DetectPiiEntities|
97 | |Latency|The latency of the Comprehend ContainsPiiEntities API|Milliseconds|Comprehend, ContainsPiiEntities|
98 | |ErrorCount|The error count of the Comprehend DetectPiiEntities API|Count|Comprehend, DetectPiiEntities|
99 | |ErrorCount|The error count of the Comprehend ContainsPiiEntities API|Count|Comprehend, ContainsPiiEntities|
100 |
101 | ### Metrics for S3 operations
102 | |MetricName|Description|Unit|Dimensions|
103 | |---|---|---|---|
104 | |Latency|The latency of the S3 WriteGetObjectResponse API|Milliseconds|S3, WriteGetObjectResponse|
105 | |Latency|The latency of downloading a file from a presigned S3 url|Milliseconds|S3, DownloadPresignedUrl|
106 | |ErrorCount|The fault count of the S3 WriteGetObjectResponse API|Count|S3, WriteGetObjectResponse|
107 | |ErrorCount|The fault count of downloading a file from a presigned S3 url|Count|S3, DownloadPresignedUrl|
108 |
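109 | ## Example usage
110 |
111 | The following is a minimal sketch of fetching a redacted object through the deployed S3 Object Lambda access point with boto3. The access point ARN and object key below are placeholders, not values from this repo; substitute the ones from your own deployment.
112 |
113 | ```python
114 | import boto3
115 |
116 | s3 = boto3.client("s3")
117 |
118 | # Hypothetical Object Lambda access point ARN attached to the redaction function.
119 | OLAP_ARN = "arn:aws:s3-object-lambda:us-east-1:123456789012:accesspoint/pii-redaction-olap"
120 |
121 | # A GetObject call against the Object Lambda access point invokes the Lambda
122 | # function, which returns the redacted content via WriteGetObjectResponse.
123 | response = s3.get_object(Bucket=OLAP_ARN, Key="pii_input.txt")
124 | print(response["Body"].read().decode("utf-8"))
125 | ```
126 |
127 | ## License Summary
128 |
129 | This code is made available under the MIT-0 license. See the LICENSE file.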
--------------------------------------------------------------------------------
/src/clients/s3_client.py:
--------------------------------------------------------------------------------
1 | """Client wrapper over aws services."""
2 | import re
3 | import time
4 | import urllib
5 | from typing import Tuple
6 |
7 | import boto3
8 | import botocore
9 | import requests
10 | from requests.adapters import HTTPAdapter
11 | from urllib3.util.retry import Retry
12 |
13 | import lambdalogging
14 | from clients.cloudwatch_client import Metrics
15 | from config import DOCUMENT_MAX_SIZE
16 | from constants import CONTENT_LENGTH, S3_STATUS_CODES, S3_ERROR_CODES, error_code_to_enums, WRITE_GET_OBJECT_RESPONSE, \
17 | DOWNLOAD_PRESIGNED_URL, S3_MAX_RETRIES, S3, http_status_code_to_s3_status_code
18 | from exceptions import UnsupportedFileException, FileSizeLimitExceededException, S3DownloadException
19 |
20 | LOG = lambdalogging.getLogger(__name__)
21 |
22 |
23 | class S3Client:
24 | """Wrapper over s3 client."""
25 |
26 | ERROR_RE = r'(?<=<Error>).*(?=<\/Error>)'
27 | CODE_RE = r'(?<=<Code>).*(?=<\/Code>)'
28 | MESSAGE_RE = r'(?<=<Message>).*(?=<\/Message>)'
29 | XML_HEADER = '<?xml version="1.0" encoding="UTF-8"?>'
30 | S3_DOWNLOAD_MAX_RETRIES = 5
31 | S3_RETRY_STATUS_CODES = [429, 500, 502, 503, 504]
32 | BACKOFF_FACTOR = 1.5
33 | MAX_GET_TIMEOUT = 10
34 | # Translation map from response headers of s3's getObject response to S3OL's WriteGetObjectResponse's request headers
35 | S3GET_TO_WGOR_HEADER_TRANSLATION_MAP = {
36 | "accept-ranges": ("AcceptRanges", str),
37 | "Cache-Control": ("CacheControl", str),
38 | "Content-Disposition": ("ContentDisposition", str),
39 | "Content-Encoding": ("ContentEncoding", str),
40 | "Content-Language": ("ContentLanguage", str),
41 | "Content-Length": ("ContentLength", int),
42 | "Content-Range": ("ContentRange", str),
43 | "Content-Type": ("ContentType", str),
44 | "x-amz-delete-marker": ("DeleteMarker", bool),
45 | "ETag": ("ETag", str),
46 | "Expires": ("Expires", str),
47 | "x-amz-expiration": ("Expiration", str),
48 | "Last-Modified": ("LastModified", str),
49 | "x-amz-missing-meta": ("MissingMeta", str),
50 | "x-amz-meta-": ("Metadata", str),
51 | "x-amz-object-lock-mode": ("ObjectLockMode", str),
52 | "x-amz-object-lock-legal-hold": ("ObjectLockLegalHoldStatus", str),
53 | "x-amz-object-lock-retain-until-date": ("ObjectLockRetainUntilDate", str),
54 | "x-amz-mp-parts-count": ("PartsCount", int),
55 | "x-amz-replication-status": ("ReplicationStatus", str),
56 | "x-amz-request-charged": ("RequestCharged", str),
57 | "x-amz-restore": ("Restore", str),
58 | "x-amz-server-side-encryption": ("ServerSideEncryption", str),
59 | "x-amz-server-side-encryption-customer-algorithm": ("SSECustomerAlgorithm", str),
60 | "x-amz-server-side-encryption-aws-kms-key-id": ("SSEKMSKeyId", str),
61 | "x-amz-server-side-encryption-customer-key-MD5": ("SSECustomerKeyMD5", str),
62 | "x-amz-storage-class": ("StorageClass", str),
63 | "x-amz-tagging-count": ("TagCount", int),
64 | "x-amz-version-id": ("VersionId", str),
65 | }
66 | # Restricted http headers that can't be sent to s3 as part of downloading an object using a presigned url
67 | # Adding these headers can cause a mismatch with the Sigv4 signature
68 | BLOCKED_REQUEST_HEADERS = ("Host",)
69 |
70 | def __init__(self, s3ol_access_point: str, max_file_supported=DOCUMENT_MAX_SIZE):
71 | self.max_file_supported = max_file_supported
72 | session_config = botocore.config.Config(
73 | retries={
74 | 'max_attempts': S3_MAX_RETRIES,
75 | 'mode': 'standard'
76 | })
77 | self.s3 = boto3.client('s3', config=session_config)
78 |
79 | self.session = requests.Session()
80 | self.session.mount("https://", adapter=HTTPAdapter(max_retries=Retry(
81 | total=self.S3_DOWNLOAD_MAX_RETRIES,
82 | status_forcelist=self.S3_RETRY_STATUS_CODES,
83 | method_whitelist=["GET"],
84 | backoff_factor=self.BACKOFF_FACTOR
85 | )))
86 |
87 | self.download_metrics = Metrics(service_name=S3, api=DOWNLOAD_PRESIGNED_URL, s3ol_access_point=s3ol_access_point)
88 | self.write_get_object_metrics = Metrics(service_name=S3, api=WRITE_GET_OBJECT_RESPONSE, s3ol_access_point=s3ol_access_point)
89 |
90 | def _contains_error(self, response) -> Tuple[bool, Tuple[str, str, S3_STATUS_CODES]]:
91 | text = response.content.decode('utf-8')
92 | lines = text.split('\n')
93 | # All 200-299 status codes are successful responses. 206 is for partial content.
94 | if response.status_code >= 300 or (len(lines) > 0 and lines[0] == self.XML_HEADER):
95 | xml = ''.join(lines[1:])
96 | LOG.info('Response status code >=300 or text contains xml.')
97 | error_match = re.search(self.ERROR_RE, xml)
98 | code_match = re.search(self.CODE_RE, xml)
99 | message_match = re.search(self.MESSAGE_RE, xml)
100 | if error_match and code_match and message_match:
101 | error_code = code_match[0]
102 | error_message = message_match[0]
103 | return True, (error_code, error_message, http_status_code_to_s3_status_code(response.status_code))
104 | elif response.status_code >= 300:
105 | return True, (
106 | S3_ERROR_CODES.InternalError.name, "Internal Server Error", http_status_code_to_s3_status_code(response.status_code))
107 | return False, ('', '', http_status_code_to_s3_status_code(response.status_code))
108 |
109 | def _parse_response_headers(self, headers):
110 | """
111 | Convert response headers received from s3's presigned download call to the argument format of the WriteGetObjectResponse API.
112 | :param headers: http headers received as part of the response from downloading the object from s3
113 | """
114 | transformed_headers = {}
115 | for header_name in headers:
116 | if header_name in self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP:
117 | header_value = self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP[header_name][1](headers[header_name])
118 | transformed_headers[self.S3GET_TO_WGOR_HEADER_TRANSLATION_MAP[header_name][0]] = header_value
119 |
120 | return transformed_headers
121 |
122 | def _filter_request_headers(self, presigned_url, headers={}):
123 | """
124 | Filter out restricted headers that shouldn't be passed along to s3 when downloading the object.
125 | :param presigned_url: presigned url from which the object will be downloaded
126 | :param headers: http headers from the incoming request
127 | :return: a filtered list of headers
128 | """
129 | filtered_headers = {}
130 | parsed_url = urllib.parse.urlparse(presigned_url)
131 | parsed_query_params = urllib.parse.parse_qs(parsed_url.query)
132 | signed_headers = set(parsed_query_params.get('X-Amz-SignedHeaders', []))
133 |
134 | for header in headers:
135 | if header in self.BLOCKED_REQUEST_HEADERS:
136 | continue
137 | if str(header).lower().startswith('x-amz-') and header not in signed_headers:
138 | continue
139 | filtered_headers[header] = headers[header]
140 | return filtered_headers
141 |
142 | def download_file_from_presigned_url(self, presigned_url, headers=None) -> Tuple[str, map, S3_STATUS_CODES]:
143 | """
144 | Download the file from an s3 presigned url. The Python AWS SDK doesn't provide any method to download from a presigned url directly, so we have to make a simple GET HTTP call.
145 | """ 146 | parsed_headers = self._filter_request_headers(presigned_url, headers) 147 | for i in range(self.S3_DOWNLOAD_MAX_RETRIES): 148 | start_time = time.time() 149 | LOG.debug(f"Downloading object with presigned url {presigned_url} and headers: {parsed_headers}") 150 | response = self.session.get(presigned_url, timeout=self.MAX_GET_TIMEOUT, headers=parsed_headers) 151 | end_time = time.time() 152 | try: 153 | # Since presigned urls do not return correct status codes when there is an error, 154 | # the xml must be parsed to find the error code and status 155 | error_detected, (error_code, error_message, response_status_code) = self._contains_error(response) 156 | if error_detected: 157 | status_code_enum, error_code_enum = error_code_to_enums(error_code) 158 | LOG.error(f"Error downloading file from presigned url. ({error_code}: {error_message})") 159 | status_code = int(status_code_enum.name[-3:]) 160 | if status_code not in self.S3_RETRY_STATUS_CODES or i == self.S3_DOWNLOAD_MAX_RETRIES - 1: 161 | LOG.error("Client error or max retries reached for downloading file from presigned url.") 162 | self.download_metrics.add_fault_count() 163 | raise S3DownloadException(error_code, error_message) 164 | else: 165 | text_content = response.content.decode('utf-8') 166 | if CONTENT_LENGTH in response.headers and int(response.headers.get(CONTENT_LENGTH)) > self.max_file_supported: 167 | raise FileSizeLimitExceededException("File too large to process") 168 | self.download_metrics.add_latency(start_time, end_time) 169 | return text_content, response.headers, response_status_code, 170 | time.sleep(max(1.0, i ** self.BACKOFF_FACTOR)) 171 | except UnicodeDecodeError: 172 | raise UnsupportedFileException(response.content, response.headers, "Not a valid utf-8 file") 173 | 174 | def respond_back_with_data(self, data, headers: map, request_route: str, request_token: str, 175 | status_code: S3_STATUS_CODES = S3_STATUS_CODES.OK_200): 176 | """Call S3's WriteGetObjectResponse API to return the processed object back to the original caller of get_object API.""" 177 | start_time = time.time() 178 | try: 179 | parsed_headers = self._parse_response_headers(headers) 180 | LOG.debug(f"Calling s3 WriteGetObjectResponse with RequestRoute:{request_route} , headers: {parsed_headers}," 181 | f" RequestToken: {request_token}") 182 | self.s3.write_get_object_response(StatusCode=status_code.get_http_status_code(), Body=data, RequestRoute=request_route, 183 | RequestToken=request_token, **parsed_headers) 184 | except Exception as error: 185 | LOG.error("Error occurred while calling s3 write get object response with data.", exc_info=True) 186 | self.write_get_object_metrics.add_fault_count() 187 | raise error 188 | finally: 189 | self.write_get_object_metrics.add_latency(start_time, time.time()) 190 | 191 | def respond_back_with_error(self, status_code: S3_STATUS_CODES, error_code: S3_ERROR_CODES, error_message: str, 192 | request_route: str, request_token: str): 193 | """Call S3's WriteGetObjectResponse API to return an error to the original caller of get_object API.""" 194 | start_time = time.time() 195 | try: 196 | self.s3.write_get_object_response(StatusCode=status_code.get_http_status_code(), ErrorCode=error_code.name, 197 | ErrorMessage=error_message, 198 | RequestRoute=request_route, RequestToken=request_token) 199 | except Exception as error: 200 | LOG.error("Error occurred while calling s3 write get object response with error.", exc_info=True) 201 | self.write_get_object_metrics.add_fault_count() 202 | raise 
error
203 | finally:
204 | self.write_get_object_metrics.add_latency(start_time, time.time())
205 |
--------------------------------------------------------------------------------
/src/processors.py:
--------------------------------------------------------------------------------
1 | """Text processors."""
2 |
3 | # must be the first import in files with lambda function handlers
4 | from copy import deepcopy
5 | from typing import List
6 |
7 | import lambdalogging
8 | from config import SUBSEGMENT_OVERLAPPING_TOKENS, MAX_CHARS_OVERLAP
9 | from constants import ENTITY_TYPE, BEGIN_OFFSET, END_OFFSET, ALL, REPLACE_WITH_PII_ENTITY_TYPE, SCORE
10 | from data_object import Document
11 | from data_object import RedactionConfig
12 | from exceptions import InvalidConfigurationException
13 |
14 | LOG = lambdalogging.getLogger(__name__)
15 |
16 |
17 | class Segmenter:
18 | """Offer functionality to segment and desegment."""
19 |
20 | def __init__(self, max_doc_size: int, overlap_tokens: int = SUBSEGMENT_OVERLAPPING_TOKENS,
21 | max_overlapping_chars: int = MAX_CHARS_OVERLAP, **kwargs):
22 | self.max_overlapping_chars = int(max_overlapping_chars)
23 | self.overlap_tokens = int(overlap_tokens)
24 | self.max_doc_size = int(max_doc_size)
25 | # A utf8 character can take up to 4 bytes
26 | if max_doc_size < 4:
27 | raise InvalidConfigurationException(
28 | f"Maximum text size limit ({self.max_doc_size} bytes) is too small to perform segmentation")
29 |
30 | def _trim_to_max_bytes(self, s, max_bytes):
31 | """
32 | Ensure that the UTF-8 encoding of a string has no more than max_bytes bytes.
33 |
34 | The table below summarizes the format of these different octet types.
35 | Char. number range | UTF-8 octet sequence
36 | (hexadecimal) | (binary)
37 | --------------------+---------------------------------------------
38 | 0000 0000-0000 007F | 0xxxxxxx
39 | 0000 0080-0000 07FF | 110xxxxx 10xxxxxx
40 | 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx
41 | 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
42 | """
43 |
44 | def safe_b_of_i(b, i):
45 | try:
46 | return b[i]
47 | except IndexError:
48 | return 0
49 |
50 | # Edge cases
51 | if s == '' or max_bytes < 1:
52 | return ''
53 |
54 | # cut it twice to avoid encoding potentially GBs of string just to get e.g. 10 bytes
55 | bytes_array = s[:max_bytes].encode('utf-8')[:max_bytes]
56 |
57 | # find the first byte from the end which is the starting byte of a utf8 character: a multi-byte
58 | # character starts with 11xxxxxx, 
while a single-byte character has the format 0xxxxxxx, as described above
59 | if bytes_array[-1] & 0b10000000:
60 | last_11xxxxxx_index = [
61 | i
62 | for i in range(-1, -5, -1)
63 | if safe_b_of_i(bytes_array, i) & 0b11000000 == 0b11000000
64 | ][0]
65 | # As described in the table above, we can determine the total size (in bytes) of a char from its first byte
66 | starting_byte = bytes_array[last_11xxxxxx_index]
67 | if not starting_byte & 0b00100000:
68 | last_char_length = 2
69 | elif not starting_byte & 0b00010000:
70 | last_char_length = 3
71 | elif not starting_byte & 0b00001000:
72 | last_char_length = 4
73 | else:
74 | raise Exception(f"Unexpected utf-8 {starting_byte} byte encountered")
75 |
76 | if last_char_length > -last_11xxxxxx_index:
77 | # remove the incomplete character
78 | bytes_array = bytes_array[:last_11xxxxxx_index]
79 |
80 | return bytes_array.decode('utf-8')
81 |
82 | def _trim_partial_trailing_word(self, text):
83 | # find the first space moving backwards
84 | original_length = len(text)
85 | k = original_length - 1
86 | # ensuring we have a hard limit on how far back we need to travel. We don't want to travel the whole sentence back
87 | # if there are no spaces in it. Using max_overlapping_chars as a proxy for this
88 | while text[k] != ' ' and k > 0 and original_length - k < self.max_overlapping_chars:
89 | k -= 1
90 | trimmed_text = text[:k + 1]
91 | return trimmed_text
92 |
93 | def _find_trailing_overlapping_tokens_start_index(self, text):
94 | word_count = 0
95 | original_length = len(text)
96 | k = original_length - 1
97 | while word_count < self.overlap_tokens:
98 | k -= 1
99 | # Moving backwards: find the beginning of a word (the current character is a space and the next one is not)
100 | while not (text[k + 1] != ' ' and text[k] == ' ') and k > 0 and original_length - k < self.max_overlapping_chars:
101 | k -= 1
102 | word_count += 1
103 | if k == 0:
104 | LOG.debug("Overlapping tokens for the next sentence start beyond the current sentence")
105 | break
106 | return k
107 |
108 | def _merge_classification_results(self, segment: Document, existing_results: dict = {}):
109 | for name, score in segment.pii_classification.items():
110 | if name not in existing_results or (
111 | name in existing_results and score > existing_results[name]):
112 | existing_results[name] = score
113 | return existing_results
114 |
115 | def _is_overlapping_annotations(self, entity_a, entity_b) -> int:
116 | """
117 | Determine if one entity overlaps with another.
118 | It returns:
119 | 1 if entity_b lies on the right side of entity_a
120 | 0 if entity_b overlaps with entity_a
121 | -1 if entity_b lies on the left side of entity_a
122 | """
123 | if entity_a[END_OFFSET] < entity_b[BEGIN_OFFSET]:
124 | return 1
125 | if entity_a[BEGIN_OFFSET] > entity_b[END_OFFSET]:
126 | return -1
127 | else:
128 | return 0
129 |
130 | def _resolve_overlapped_annotation(self, entity_a, entity_b) -> List:
131 | """Merge two overlapping entity annotations, keeping the one with the higher score."""
132 | if entity_a[SCORE] >= entity_b[SCORE]:
133 | return [entity_a]
134 | else:
135 | return [entity_b]
136 |
137 | def _merge_pii_annotation_results(self, segment: Document, existing_annotations: List = []):
138 |
139 | if not existing_annotations:
140 | existing_annotations.extend(segment.pii_entities)
141 | return
142 |
143 | for pii_entity in segment.pii_entities:
144 | k = len(existing_annotations) - 1
145 | while k > 0:
146 | overlap_result = self._is_overlapping_annotations(existing_annotations[k], pii_entity)
147 | if overlap_result > 0:
148 | existing_annotations.append(pii_entity)
149 | break
150 | elif overlap_result == 0:
151 | LOG.debug("Annotation: " + str(existing_annotations[k]) + " conflicts with: " + str(pii_entity))
152 | resolved_annotation = self._resolve_overlapped_annotation(existing_annotations[k], pii_entity)
153 | LOG.debug("Deleting annotation:" + str(existing_annotations[k]))
154 | del existing_annotations[k]
155 | for i, annotation in enumerate(resolved_annotation):
156 | LOG.debug("Adding annotation:" + str(annotation))
157 | existing_annotations.insert(k + i, annotation)
158 | break
159 | else:
160 | k -= 1
161 |
162 | return existing_annotations
163 |
164 | def _relocate_annotation(self, annotations: List, offset: int):
165 | """Shift the annotated entities by the given offset."""
166 | annotations_copy = deepcopy(annotations)
167 | for annotation in annotations_copy:
168 | annotation[END_OFFSET] += offset
169 | annotation[BEGIN_OFFSET] += offset
170 | return annotations_copy
171 |
172 | def segment(self, text: str, char_offset=0) -> List[Document]:
173 | """Segment the text into segments of max_doc_length with overlap_tokens."""
174 | segments = []
175 | starting_index = 0
176 | while len(text[starting_index:].encode()) > self.max_doc_size:
177 | trimmed_text = self._trim_to_max_bytes(text[starting_index:], self.max_doc_size)
178 | trimmed_text = self._trim_partial_trailing_word(trimmed_text)
179 | segments.append(Document(text=trimmed_text, char_offset=char_offset + starting_index))
180 | starting_index = starting_index + self._find_trailing_overlapping_tokens_start_index(trimmed_text) + 1
181 | # Add the remaining segment
182 | if starting_index < len(text) - 1:
183 | segments.append(Document(text=text[starting_index:], char_offset=char_offset + starting_index))
184 | return segments
185 |
186 | def de_segment(self, segments: List[Document]) -> Document:
187 | """
188 | Merge the segments back into one big text. It also merges back the pii classification results.
189 | Conflicting results on the text overlapping between two segments are handled in the following ways:
190 | 1. For pii classification, the maximum threshold for an entity amongst the segments is
191 | used as the threshold for that entity in the merged document
192 | 2. For pii entity annotations, when annotation spans conflict, priority
193 | is given to the one with the higher confidence score
194 | """
195 | merged_text = ""
196 | pii_classification = {}
197 | pii_entities = []
198 | segments.sort(key=lambda x: x.char_offset)
199 | for segment in segments:
200 | offset_adjusted_segment = Document(text=segment.text, char_offset=segment.char_offset,
201 | pii_entities=self._relocate_annotation(segment.pii_entities, segment.char_offset),
202 | pii_classification=segment.pii_classification)
203 | self._merge_classification_results(segment, pii_classification)
204 | self._merge_pii_annotation_results(offset_adjusted_segment, pii_entities)
205 | merged_text = merged_text + segment.text[len(merged_text) - segment.char_offset:]
206 | return Document(text=merged_text, char_offset=0, pii_classification=pii_classification, pii_entities=pii_entities)
207 |
208 |
209 | class Redactor:
210 | """Handle the logic of redacting discovered pii entities from the given text."""
211 |
212 | def __init__(self, redaction_config: RedactionConfig):
213 | self.redaction_config = redaction_config
214 |
215 | def redact(self, input_text, entities_list):
216 | """Redact the pii entities from the given text."""
217 | doc_parts_list = []
218 | prev_entity = None
219 | for entity in entities_list:
220 | if entity[SCORE] < self.redaction_config.confidence_threshold:
221 | continue
222 | entity_type = entity[ENTITY_TYPE]
223 | begin_offset = entity[BEGIN_OFFSET]
224 | end_offset = entity[END_OFFSET]
225 | if prev_entity is None:
226 | doc_parts_list.append(input_text[:begin_offset])
227 | else:
228 | doc_parts_list.append(input_text[prev_entity[END_OFFSET]:begin_offset])
229 |
230 | if ALL in self.redaction_config.pii_entity_types or entity_type in self.redaction_config.pii_entity_types:
231 | # Redact this entity type
232 | if self.redaction_config.mask_mode == REPLACE_WITH_PII_ENTITY_TYPE:
233 | # Replace with PII Entity Type
234 | doc_parts_list.append(f"[{entity_type}]")
235 | else:
236 | # Replace with MaskCharacter
237 | entity_length = end_offset - begin_offset
238 | doc_parts_list.append(self.redaction_config.mask_character * entity_length)
239 | else:
240 | # Don't redact this entity type
241 | doc_parts_list.append(input_text[begin_offset:end_offset])
242 |
243 | prev_entity = entity
244 | if prev_entity is not None:
245 | doc_parts_list.append(input_text[prev_entity[END_OFFSET]:])
246 | else:
247 | doc_parts_list.append(input_text)
248 | return ''.join(doc_parts_list)
249 |
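A minimal sketch (not part of the repo sources) of how Segmenter and Redactor compose, assuming src/ is on the Python path; the hand-written entity dict mimics one entry of a Comprehend DetectPiiEntities response, as in the unit tests above:

```python
from data_object import RedactionConfig
from processors import Redactor, Segmenter

# Split text into <=50-byte overlapping segments, then stitch them back together.
segmenter = Segmenter(50, overlap_tokens=3)
text = ("Barack Hussein Obama II is an American politician and attorney who served as the "
        "44th president of the United States from 2009 to 2017.")
segments = segmenter.segment(text)
assert segmenter.de_segment(segments).text == text  # lossless roundtrip

# The default config masks every entity type above the 0.5 confidence threshold with '*'.
redactor = Redactor(RedactionConfig())
redacted = redactor.redact(
    "Hello Zhang Wei.",
    [{'Score': 0.99, 'Type': 'NAME', 'BeginOffset': 6, 'EndOffset': 15}])
print(redacted)  # Hello *********.
```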
--------------------------------------------------------------------------------
/src/handler.py:
--------------------------------------------------------------------------------
1 | """Lambda function handler."""
2 |
3 | # must be the first import in files with lambda function handlers
4 | import time
5 | import traceback
6 |
7 | import lambdainit  # noqa: F401
8 | import json
9 | from typing import List
10 | import lambdalogging
11 | from clients.comprehend_client import ComprehendClient
12 | from clients.s3_client import S3Client
13 | from clients.cloudwatch_client import CloudWatchClient
14 | from config import DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES, DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES, DEFAULT_LANGUAGE_CODE, \
15 | PUBLISH_CLOUD_WATCH_METRICS, REDACTION_API_ONLY, COMPREHEND_ENDPOINT_URL
16 | from constants import ALL, REQUEST_ID, GET_OBJECT_CONTEXT, S3OL_ACCESS_POINT_ARN, \
17 | INPUT_S3_URL, S3OL_CONFIGURATION, REQUEST_ROUTE, REQUEST_TOKEN, PAYLOAD, DEFAULT_USER_AGENT, LANGUAGE_CODE, USER_REQUEST, \
18 | HEADERS, CONTENT_LENGTH, RESERVED_TIME_FOR_CLEANUP
19 | from data_object import Document, PiiConfig, RedactionConfig, ClassificationConfig
20 | from exception_handlers import ExceptionHandler
21 | from exceptions import RestrictedDocumentException
22 | from processors import Segmenter, Redactor
23 | from util import execute_task_with_timeout
24 | from validators import InputEventValidator, PartialObjectRequestValidator
25 |
26 | LOG = lambdalogging.getLogger(__name__)
27 |
28 |
29 | def get_interested_pii(document: Document, classification_config: PiiConfig):
30 | """
31 | Get a list of the pii entity types of interest from the document.
32 |
33 | Return a list of pii entity types of the given document with only the entities of interest
34 | and above the confidence threshold.
35 | """
36 | pii_entities = []
37 | for name, score in document.pii_classification.items():
38 | if name in classification_config.pii_entity_types or ALL in classification_config.pii_entity_types:
39 | if score >= classification_config.confidence_threshold:
40 | pii_entities.append(name)
41 | return pii_entities
42 |
43 |
44 | def publish_metrics(cloud_watch: CloudWatchClient, s3: S3Client, comprehend: ComprehendClient, processed_document: bool,
45 | processed_pii_document: bool, language_code: str, s3ol_access_point: str, pii_entities: List[str]):
46 | """Publish metrics from the function execution."""
47 | try:
48 | cloud_watch.publish_metrics(s3.download_metrics.metrics + s3.write_get_object_metrics.metrics +
49 | comprehend.classify_metrics.metrics + comprehend.detection_metrics.metrics)
50 | if processed_document:
51 | cloud_watch.put_document_processed_metric(language_code, s3ol_access_point)
52 | if processed_pii_document:
53 | cloud_watch.put_pii_document_processed_metric(language_code, s3ol_access_point)
54 | cloud_watch.put_pii_document_types_metric(pii_entities, language_code, s3ol_access_point)
55 | except Exception as e:
56 | LOG.error(f"Error publishing metrics to CloudWatch: {e} {traceback.format_exc()}")
57 |
58 |
59 | def redact(text, classification_segmenter: Segmenter, detection_segmenter: Segmenter,
60 | redactor: Redactor, comprehend: ComprehendClient, redaction_config: RedactionConfig, language_code) -> Document:
61 | """
62 | Redact pii data from the given text. Redaction logic:
63 |
64 | 1. Segment the text into subsegments of reasonable sizes (max doc size supported by comprehend) for doing the initial classification
65 | 2. For each subsegment,
66 | 2.1 call comprehend's contains-pii-entities api to determine if it contains any PII data
67 | 2.2 if it contains pii then split it into smaller chunks (e.g. <=5KB), else skip to the next subsegment
68 | 2.3 for each chunk
69 | 2.3.1 call comprehend's detect-pii-entities api to extract the pii entities
70 | 2.3.2 redact the pii entities from the chunk
71 | 2.4 merge all chunks
72 | 3. merge all subsegments
73 | """
74 | if REDACTION_API_ONLY:
75 | doc = Document(text)
76 | documents = [doc]
77 | docs_for_entity_detection = detection_segmenter.segment(doc.text, doc.char_offset)
78 | else:
79 | documents = comprehend.contains_pii_entities(classification_segmenter.segment(text), language_code)
80 | pii_docs = [doc for doc in documents if len(get_interested_pii(doc, redaction_config)) > 0]
81 | if not pii_docs:
82 | LOG.debug("Document doesn't have any pii. Nothing to redact.")
Nothing to redact.") 83 | text = classification_segmenter.de_segment(documents).text 84 | return Document(text, redacted_text=text) 85 | docs_for_entity_detection = [] 86 | for pii_doc in pii_docs: 87 | docs_for_entity_detection.extend(detection_segmenter.segment(pii_doc.text, pii_doc.char_offset)) 88 | 89 | docs_with_pii_entities = comprehend.detect_pii_documents(docs_for_entity_detection, language_code) 90 | resultant_doc = classification_segmenter.de_segment(documents + docs_with_pii_entities) 91 | assert len(resultant_doc.text) == len(text), "Not able to recover original document after segmentation and desegmentation." 92 | redacted_text = redactor.redact(text, resultant_doc.pii_entities) 93 | resultant_doc.redacted_text = redacted_text 94 | return resultant_doc 95 | 96 | 97 | def classify(text, classification_segmenter: Segmenter, comprehend: ComprehendClient, 98 | detection_config: ClassificationConfig, language_code) -> List[str]: 99 | """ 100 | Detect pii data from given text. Logic for detecting:- . 101 | 102 | 1. Segment text into segments of reasonable sizes (max doc size supported by comprehend) for 103 | doing initial classification 104 | 2. For each segment, 105 | 2.1 call comprehend's classify-pii-document api to determine if it contains any PII data 106 | 2.2 if it contains pii that is in the detection config then return those pii, else move to the next segment 107 | 3. If no pii detected, return empty list, else list of pii types found that is also in the detection config 108 | and above the given threshold 109 | """ 110 | pii_classified_documents = comprehend.contains_pii_entities(classification_segmenter.segment(text), language_code) 111 | pii_types = set() 112 | for doc in pii_classified_documents: 113 | doc_pii_types = get_interested_pii(doc, detection_config) 114 | pii_types |= set(doc_pii_types) 115 | return list(pii_types) 116 | 117 | 118 | def redact_pii_documents_handler(event, context): 119 | """Redaction Lambda function handler.""" 120 | LOG.info('Received event with requestId: %s', event[REQUEST_ID]) 121 | LOG.debug(f'Raw event {event}') 122 | 123 | InputEventValidator.validate(event) 124 | invoke_args = json.loads(event[S3OL_CONFIGURATION][PAYLOAD]) if event[S3OL_CONFIGURATION][PAYLOAD] else {} 125 | language_code = invoke_args.get(LANGUAGE_CODE, DEFAULT_LANGUAGE_CODE) 126 | redaction_config = RedactionConfig(**invoke_args) 127 | object_get_context = event[GET_OBJECT_CONTEXT] 128 | s3ol_access_point = event[S3OL_CONFIGURATION][S3OL_ACCESS_POINT_ARN] 129 | s3 = S3Client(s3ol_access_point) 130 | cloud_watch = CloudWatchClient() 131 | comprehend = ComprehendClient(s3ol_access_point=s3ol_access_point, session_id=event[REQUEST_ID], user_agent=DEFAULT_USER_AGENT, 132 | endpoint_url=COMPREHEND_ENDPOINT_URL) 133 | 134 | exception_handler = ExceptionHandler(s3) 135 | 136 | LOG.debug("Pii Entity Types to be redacted:" + str(redaction_config.pii_entity_types)) 137 | processed_document = False 138 | document = Document('') 139 | 140 | try: 141 | def time_bound_task(): 142 | nonlocal processed_document 143 | nonlocal document 144 | PartialObjectRequestValidator.validate(event) 145 | pii_classification_segmenter = Segmenter(DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES) 146 | pii_redaction_segmenter = Segmenter(DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES) 147 | redactor = Redactor(redaction_config) 148 | time1 = time.time() 149 | text, http_headers, status_code = s3.download_file_from_presigned_url(object_get_context[INPUT_S3_URL], 150 | event[USER_REQUEST][HEADERS]) 151 | time2 = 
time.time() 152 | LOG.info(f"Downloaded the file in : {(time2 - time1)} seconds") 153 | document = redact(text, pii_classification_segmenter, pii_redaction_segmenter, redactor, 154 | comprehend, redaction_config, language_code) 155 | processed_document = True 156 | time1 = time.time() 157 | LOG.info(f"Pii redaction completed within {(time1 - time2)} seconds. Returning back the response to S3") 158 | redacted_text_bytes = document.redacted_text.encode('utf-8') 159 | http_headers[CONTENT_LENGTH] = len(redacted_text_bytes) 160 | s3.respond_back_with_data(redacted_text_bytes, http_headers, object_get_context[REQUEST_ROUTE], 161 | object_get_context[REQUEST_TOKEN], status_code) 162 | 163 | execute_task_with_timeout(context.get_remaining_time_in_millis() - RESERVED_TIME_FOR_CLEANUP, time_bound_task) 164 | except Exception as generated_exception: 165 | exception_handler.handle_exception(generated_exception, object_get_context[REQUEST_ROUTE], object_get_context[REQUEST_TOKEN]) 166 | finally: 167 | if PUBLISH_CLOUD_WATCH_METRICS: 168 | pii_entities = get_interested_pii(document, redaction_config) 169 | publish_metrics(cloud_watch, s3, comprehend, processed_document, len(pii_entities) > 0, language_code, 170 | s3ol_access_point, pii_entities) 171 | 172 | LOG.info("Responded back to s3 successfully") 173 | 174 | 175 | def pii_access_control_handler(event, context): 176 | """Detect Lambda function handler.""" 177 | LOG.info(f'Received event with requestId: {event[REQUEST_ID]}') 178 | LOG.debug(f'Raw event {event}') 179 | 180 | InputEventValidator.validate(event) 181 | invoke_args = json.loads(event[S3OL_CONFIGURATION][PAYLOAD]) if event[S3OL_CONFIGURATION][PAYLOAD] else {} 182 | language_code = invoke_args.get(LANGUAGE_CODE, DEFAULT_LANGUAGE_CODE) 183 | detection_config = ClassificationConfig(**invoke_args) 184 | object_get_context = event[GET_OBJECT_CONTEXT] 185 | s3ol_access_point = event[S3OL_CONFIGURATION][S3OL_ACCESS_POINT_ARN] 186 | 187 | s3 = S3Client(s3ol_access_point) 188 | cloud_watch = CloudWatchClient() 189 | comprehend = ComprehendClient(session_id=event[REQUEST_ID], user_agent=DEFAULT_USER_AGENT, endpoint_url=COMPREHEND_ENDPOINT_URL, 190 | s3ol_access_point=s3ol_access_point) 191 | exception_handler = ExceptionHandler(s3) 192 | 193 | LOG.debug("Pii Entity Types to be detected:" + str(detection_config.pii_entity_types)) 194 | 195 | pii_classification_segmenter = Segmenter(DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES) 196 | 197 | processed_document = False 198 | processed_pii_document = False 199 | pii_entities = [] 200 | 201 | try: 202 | def time_bound_task(): 203 | nonlocal processed_document 204 | nonlocal processed_pii_document 205 | nonlocal pii_entities 206 | PartialObjectRequestValidator.validate(event) 207 | time1 = time.time() 208 | text, http_headers, status_code = s3.download_file_from_presigned_url(object_get_context[INPUT_S3_URL], 209 | event[USER_REQUEST][HEADERS]) 210 | time2 = time.time() 211 | LOG.info(f"Downloaded the file in : {(time2 - time1)} seconds") 212 | pii_entities = classify(text, pii_classification_segmenter, comprehend, detection_config, language_code) 213 | time1 = time.time() 214 | 215 | processed_document = True 216 | LOG.info(f"Pii detection completed within {(time1 - time2)} seconds. 
/src/constants.py:
--------------------------------------------------------------------------------
1 | """Collection of constants used in the code."""
2 | from enum import Enum, auto
3 |
4 | from typing import Tuple
5 |
6 | ENTITY_TYPE = "Type"
7 | BEGIN_OFFSET = "BeginOffset"
8 | END_OFFSET = "EndOffset"
9 | NAME = "Name"
10 | SCORE = "Score"
11 | ALL = "ALL"
12 | REPLACE_WITH_PII_ENTITY_TYPE = "REPLACE_WITH_PII_ENTITY_TYPE"
13 | MASK = "MASK"
14 | REQUEST_ID = "xAmzRequestId"
15 | USER_REQUEST = "userRequest"
16 | HEADERS = "headers"
17 | RANGE = "Range"
18 | PART_NUMBER = "PartNumber"
19 | REQUEST_ROUTE = "outputRoute"
20 | REQUEST_TOKEN = "outputToken"
21 | GET_OBJECT_CONTEXT = "getObjectContext"
22 | INPUT_S3_URL = "inputS3Url"
23 | S3OL_CONFIGURATION = "configuration"
24 | S3OL_ACCESS_POINT_ARN = "accessPointArn"
25 | CONTENT_LENGTH = "Content-Length"
26 | OVERLAP_TOKENS = "overlap_tokens"
27 | PAYLOAD = "payload"
28 | ONE_DOC_PER_LINE = "ONE_DOC_PER_LINE"
29 | ONE_DOC_PER_FILE = "ONE_DOC_PER_FILE"
30 | LANGUAGE_CODE = "language_code"
31 |
32 | DEFAULT_USER_AGENT = "S3ObjectLambda/1.0"
33 |
34 | RESERVED_TIME_FOR_CLEANUP = 2000  # We need at least this much time (in millis) to perform cleanup tasks like flushing the metrics
35 | COMPREHEND_MAX_RETRIES = 7
36 | S3_MAX_RETRIES = 10
37 | CLOUD_WATCH_NAMESPACE = "ComprehendS3ObjectLambda"
38 | LATENCY = "Latency"
39 | ERROR_COUNT = "ErrorCount"
40 | API = "API"
41 | CONTAINS_PII_ENTITIES = "ContainsPiiEntities"
42 | DETECT_PII_ENTITIES = "DetectPiiEntities"
43 | PII_DOCUMENTS_PROCESSED = "PiiDocumentsProcessed"
44 | DOCUMENTS_PROCESSED = "DocumentsProcessed"
45 | PII_DOCUMENT_TYPES_PROCESSED = "PiiDocumentTypesProcessed"
46 | PII_ENTITY_TYPE = "PiiEntityType"
47 | SERVICE = "Service"
48 | COMPREHEND = "Comprehend"
49 | S3 = "S3"
50 | WRITE_GET_OBJECT_RESPONSE = "WriteGetObjectResponse"
51 | DOWNLOAD_PRESIGNED_URL = "DownloadPresignedUrl"
52 | LANGUAGE = "Language"
53 | MILLISECONDS = "Milliseconds"
54 | COUNT = "Count"
55 | VALUE = "Value"
56 | S3OL_ACCESS_POINT = "S3ObjectLambdaAccessPoint"
57 | METRIC_NAME = "MetricName"
58 | UNIT = "Unit"
59 | DIMENSIONS = "Dimensions"
60 |
61 |
62 | class UNSUPPORTED_FILE_HANDLING_VALID_VALUES(Enum):
63 |     """Valid values for how unsupported files should be handled."""
64 |
65 |     PASS = auto()
66 |     FAIL = auto()
67 |
68 |
69 | class MASK_MODE_VALID_VALUES(Enum):
70 |     """Valid values for the MASK_MODE variable."""
71 |
72 |     MASK = auto()
73 |     REPLACE_WITH_PII_ENTITY_TYPE = auto()
74 |
75 |
76 | class S3_STATUS_CODES(Enum):
77 |     """
78 |     Valid HTTP status codes for S3.
79 |     Refer to https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#ErrorCodeList for more details on the status codes.
80 |     """
81 |
82 |     OK_200 = auto()
83 |     PARTIAL_CONTENT_206 = auto()
84 |     NOT_MODIFIED_304 = auto()
85 |     BAD_REQUEST_400 = auto()
86 |     UNAUTHORIZED_401 = auto()
87 |     FORBIDDEN_403 = auto()
88 |     NOT_FOUND_404 = auto()
89 |     METHOD_NOT_ALLOWED_405 = auto()
90 |     CONFLICT_409 = auto()
91 |     LENGTH_REQUIRED_411 = auto()
92 |     PRECONDITION_FAILED_412 = auto()
93 |     RANGE_NOT_SATISFIABLE_416 = auto()
94 |     INTERNAL_SERVER_ERROR_500 = auto()
95 |     SERVICE_UNAVAILABLE_503 = auto()
96 |
97 |     def get_http_status_code(self) -> int:
98 |         """Convert an S3 status code to its integer HTTP status code."""
99 |         return int(self.name.split('_')[-1])
100 |
101 |
102 | class S3_ERROR_CODES(Enum):
103 |     """
104 |     Valid error codes for S3.
105 |     Refer to https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#ErrorCodeList for more details on the error codes.
106 |     """
107 |
108 |     AccessDenied = auto()
109 |     AccountProblem = auto()
110 |     AllAccessDisabled = auto()
111 |     AmbiguousGrantByEmailAddress = auto()
112 |     AuthorizationHeaderMalformed = auto()
113 |     BadDigest = auto()
114 |     BucketAlreadyExists = auto()
115 |     BucketAlreadyOwnedByYou = auto()
116 |     BucketNotEmpty = auto()
117 |     CredentialsNotSupported = auto()
118 |     CrossLocationLoggingProhibited = auto()
119 |     EntityTooSmall = auto()
120 |     EntityTooLarge = auto()
121 |     ExpiredToken = auto()
122 |     IllegalLocationConstraintException = auto()
123 |     IllegalVersioningConfigurationException = auto()
124 |     IncompleteBody = auto()
125 |     IncorrectNumberOfFilesInPostRequest = auto()
126 |     InlineDataTooLarge = auto()
127 |     InternalError = auto()
128 |     InvalidAccessKeyId = auto()
129 |     InvalidAccessPoint = auto()
130 |     InvalidAddressingHeader = auto()
131 |     InvalidArgument = auto()
132 |     InvalidBucketName = auto()
133 |     InvalidBucketState = auto()
134 |     InvalidDigest = auto()
135 |     InvalidEncryptionAlgorithmError = auto()
136 |     InvalidLocationConstraint = auto()
137 |     InvalidObjectState = auto()
138 |     InvalidPart = auto()
139 |     InvalidPartOrder = auto()
140 |     InvalidPayer = auto()
141 |     InvalidPolicyDocument = auto()
142 |     InvalidRange = auto()
143 |     InvalidRequest = auto()
144 |     InvalidSecurity = auto()
145 |     InvalidSOAPRequest = auto()
146 |     InvalidStorageClass = auto()
147 |     InvalidTargetBucketForLogging = auto()
148 |     InvalidToken = auto()
149 |     InvalidURI = auto()
150 |     KeyTooLongError = auto()
151 |     MalformedACLError = auto()
152 |     MalformedPOSTRequest = auto()
153 |     MalformedXML = auto()
154 |     MaxMessageLengthExceeded = auto()
155 |     MaxPostPreDataLengthExceededError = auto()
156 |     MetadataTooLarge = auto()
157 |     MethodNotAllowed = auto()
158 |     MissingAttachment = auto()
159 |     MissingContentLength = auto()
160 |     MissingRequestBodyError = auto()
161 |     MissingSecurityElement = auto()
162 |     MissingSecurityHeader = auto()
163 |     NoLoggingStatusForKey = auto()
164 |     NoSuchBucket = auto()
165 |     NoSuchBucketPolicy = auto()
166 |     NoSuchKey = auto()
167 |     NoSuchLifecycleConfiguration = auto()
168 |     NoSuchUpload = auto()
169 |     NoSuchVersion = auto()
170 |     NotImplemented = auto()
171 |     NotSignedUp = auto()
172 |     OperationAborted = auto()
173 |     PermanentRedirect = auto()
174 |     PreconditionFailed = auto()
175 |     Redirect = auto()
176 |     RestoreAlreadyInProgress = auto()
177 |     RequestIsNotMultiPartContent = auto()
178 |     RequestTimeout = auto()
179 |     RequestTimeTooSkewed = auto()
180 |     RequestTorrentOfBucketError = auto()
181 |     ServerSideEncryptionConfigurationNotFoundError = auto()
182 |     ServiceUnavailable = auto()
183 |     SignatureDoesNotMatch = auto()
184 |     SlowDown = auto()
185 |     TemporaryRedirect = auto()
186 |     TokenRefreshRequired = auto()
187 |     TooManyAccessPoints = auto()
188 |     TooManyBuckets = auto()
189 |     UnexpectedContent = auto()
190 |     UnresolvableGrantByEmailAddress = auto()
191 |     UserKeyMustBeSpecified = auto()
192 |     NoSuchAccessPoint = auto()
193 |     InvalidTag = auto()
194 |     MalformedPolicy = auto()
195 |
196 |
197 | ERROR_CODE_STATUS_MAP = {
198 |     S3_ERROR_CODES.AccessDenied: S3_STATUS_CODES.FORBIDDEN_403,
199 |     S3_ERROR_CODES.AccountProblem: S3_STATUS_CODES.FORBIDDEN_403,
200 |     S3_ERROR_CODES.AllAccessDisabled: S3_STATUS_CODES.FORBIDDEN_403,
201 |     S3_ERROR_CODES.AmbiguousGrantByEmailAddress: S3_STATUS_CODES.BAD_REQUEST_400,
202 |     S3_ERROR_CODES.AuthorizationHeaderMalformed: S3_STATUS_CODES.BAD_REQUEST_400,
203 |     S3_ERROR_CODES.BadDigest: S3_STATUS_CODES.BAD_REQUEST_400,
204 |     S3_ERROR_CODES.BucketAlreadyExists: S3_STATUS_CODES.CONFLICT_409,
205 |     S3_ERROR_CODES.BucketAlreadyOwnedByYou: S3_STATUS_CODES.CONFLICT_409,
206 |     S3_ERROR_CODES.BucketNotEmpty: S3_STATUS_CODES.CONFLICT_409,
207 |     S3_ERROR_CODES.CredentialsNotSupported: S3_STATUS_CODES.BAD_REQUEST_400,
208 |     S3_ERROR_CODES.CrossLocationLoggingProhibited: S3_STATUS_CODES.FORBIDDEN_403,
209 |     S3_ERROR_CODES.EntityTooSmall: S3_STATUS_CODES.BAD_REQUEST_400,
210 |     S3_ERROR_CODES.EntityTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
211 |     S3_ERROR_CODES.ExpiredToken: S3_STATUS_CODES.BAD_REQUEST_400,
212 |     S3_ERROR_CODES.IllegalLocationConstraintException: S3_STATUS_CODES.BAD_REQUEST_400,
213 |     S3_ERROR_CODES.IllegalVersioningConfigurationException: S3_STATUS_CODES.BAD_REQUEST_400,
214 |     S3_ERROR_CODES.IncompleteBody: S3_STATUS_CODES.BAD_REQUEST_400,
215 |     S3_ERROR_CODES.IncorrectNumberOfFilesInPostRequest: S3_STATUS_CODES.BAD_REQUEST_400,
216 |     S3_ERROR_CODES.InlineDataTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
217 |     S3_ERROR_CODES.InternalError: S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500,
218 |     S3_ERROR_CODES.InvalidAccessKeyId: S3_STATUS_CODES.FORBIDDEN_403,
219 |     S3_ERROR_CODES.InvalidAccessPoint: S3_STATUS_CODES.BAD_REQUEST_400,
220 |     S3_ERROR_CODES.InvalidArgument: S3_STATUS_CODES.BAD_REQUEST_400,
221 |     S3_ERROR_CODES.InvalidBucketName: S3_STATUS_CODES.BAD_REQUEST_400,
222 |     S3_ERROR_CODES.InvalidBucketState: S3_STATUS_CODES.CONFLICT_409,
223 |     S3_ERROR_CODES.InvalidDigest: S3_STATUS_CODES.BAD_REQUEST_400,
224 |     S3_ERROR_CODES.InvalidEncryptionAlgorithmError: S3_STATUS_CODES.BAD_REQUEST_400,
225 |     S3_ERROR_CODES.InvalidLocationConstraint: S3_STATUS_CODES.BAD_REQUEST_400,
226 |     S3_ERROR_CODES.InvalidObjectState: S3_STATUS_CODES.FORBIDDEN_403,
227 |     S3_ERROR_CODES.InvalidPart: S3_STATUS_CODES.BAD_REQUEST_400,
228 |     S3_ERROR_CODES.InvalidPartOrder: S3_STATUS_CODES.BAD_REQUEST_400,
229 |     S3_ERROR_CODES.InvalidPayer: S3_STATUS_CODES.FORBIDDEN_403,
230 |     S3_ERROR_CODES.InvalidPolicyDocument: S3_STATUS_CODES.BAD_REQUEST_400,
231 |     S3_ERROR_CODES.InvalidRange: S3_STATUS_CODES.RANGE_NOT_SATISFIABLE_416,
232 |     S3_ERROR_CODES.InvalidRequest: S3_STATUS_CODES.BAD_REQUEST_400,
233 |     S3_ERROR_CODES.InvalidSecurity: S3_STATUS_CODES.FORBIDDEN_403,
234 |     S3_ERROR_CODES.InvalidSOAPRequest: S3_STATUS_CODES.BAD_REQUEST_400,
235 |     S3_ERROR_CODES.InvalidStorageClass: S3_STATUS_CODES.BAD_REQUEST_400,
236 |     S3_ERROR_CODES.InvalidTargetBucketForLogging: S3_STATUS_CODES.BAD_REQUEST_400,
237 |     S3_ERROR_CODES.InvalidToken: S3_STATUS_CODES.BAD_REQUEST_400,
238 |     S3_ERROR_CODES.InvalidURI: S3_STATUS_CODES.BAD_REQUEST_400,
239 |     S3_ERROR_CODES.KeyTooLongError: S3_STATUS_CODES.BAD_REQUEST_400,
240 |     S3_ERROR_CODES.MalformedACLError: S3_STATUS_CODES.BAD_REQUEST_400,
241 |     S3_ERROR_CODES.MalformedPOSTRequest: S3_STATUS_CODES.BAD_REQUEST_400,
242 |     S3_ERROR_CODES.MalformedXML: S3_STATUS_CODES.BAD_REQUEST_400,
243 |     S3_ERROR_CODES.MaxMessageLengthExceeded: S3_STATUS_CODES.BAD_REQUEST_400,
244 |     S3_ERROR_CODES.MaxPostPreDataLengthExceededError: S3_STATUS_CODES.BAD_REQUEST_400,
245 |     S3_ERROR_CODES.MetadataTooLarge: S3_STATUS_CODES.BAD_REQUEST_400,
246 |     S3_ERROR_CODES.MethodNotAllowed: S3_STATUS_CODES.METHOD_NOT_ALLOWED_405,
247 |     S3_ERROR_CODES.MissingContentLength: S3_STATUS_CODES.LENGTH_REQUIRED_411,
248 |     S3_ERROR_CODES.MissingRequestBodyError: S3_STATUS_CODES.BAD_REQUEST_400,
249 |     S3_ERROR_CODES.MissingSecurityElement: S3_STATUS_CODES.BAD_REQUEST_400,
250 |     S3_ERROR_CODES.MissingSecurityHeader: S3_STATUS_CODES.BAD_REQUEST_400,
251 |     S3_ERROR_CODES.NoLoggingStatusForKey: S3_STATUS_CODES.BAD_REQUEST_400,
252 |     S3_ERROR_CODES.NoSuchBucket: S3_STATUS_CODES.NOT_FOUND_404,
253 |     S3_ERROR_CODES.NoSuchBucketPolicy: S3_STATUS_CODES.NOT_FOUND_404,
254 |     S3_ERROR_CODES.NoSuchKey: S3_STATUS_CODES.NOT_FOUND_404,
255 |     S3_ERROR_CODES.NoSuchLifecycleConfiguration: S3_STATUS_CODES.NOT_FOUND_404,
256 |     S3_ERROR_CODES.NoSuchUpload: S3_STATUS_CODES.NOT_FOUND_404,
257 |     S3_ERROR_CODES.NoSuchVersion: S3_STATUS_CODES.NOT_FOUND_404,
258 |     S3_ERROR_CODES.NotSignedUp: S3_STATUS_CODES.FORBIDDEN_403,
259 |     S3_ERROR_CODES.OperationAborted: S3_STATUS_CODES.CONFLICT_409,
260 |     S3_ERROR_CODES.PreconditionFailed: S3_STATUS_CODES.PRECONDITION_FAILED_412,
261 |     S3_ERROR_CODES.RestoreAlreadyInProgress: S3_STATUS_CODES.CONFLICT_409,
262 |     S3_ERROR_CODES.RequestIsNotMultiPartContent: S3_STATUS_CODES.BAD_REQUEST_400,
263 |     S3_ERROR_CODES.RequestTimeout: S3_STATUS_CODES.BAD_REQUEST_400,
264 |     S3_ERROR_CODES.RequestTimeTooSkewed: S3_STATUS_CODES.FORBIDDEN_403,
265 |     S3_ERROR_CODES.RequestTorrentOfBucketError: S3_STATUS_CODES.BAD_REQUEST_400,
266 |     S3_ERROR_CODES.ServerSideEncryptionConfigurationNotFoundError: S3_STATUS_CODES.BAD_REQUEST_400,
267 |     S3_ERROR_CODES.ServiceUnavailable: S3_STATUS_CODES.SERVICE_UNAVAILABLE_503,
268 |     S3_ERROR_CODES.SignatureDoesNotMatch: S3_STATUS_CODES.FORBIDDEN_403,
269 |     S3_ERROR_CODES.SlowDown: S3_STATUS_CODES.SERVICE_UNAVAILABLE_503,
270 |     S3_ERROR_CODES.TokenRefreshRequired: S3_STATUS_CODES.BAD_REQUEST_400,
271 |     S3_ERROR_CODES.TooManyAccessPoints: S3_STATUS_CODES.BAD_REQUEST_400,
272 |     S3_ERROR_CODES.TooManyBuckets: S3_STATUS_CODES.BAD_REQUEST_400,
273 |     S3_ERROR_CODES.UnexpectedContent: S3_STATUS_CODES.BAD_REQUEST_400,
274 |     S3_ERROR_CODES.UnresolvableGrantByEmailAddress: S3_STATUS_CODES.BAD_REQUEST_400,
275 |     S3_ERROR_CODES.UserKeyMustBeSpecified: S3_STATUS_CODES.BAD_REQUEST_400,
276 |     S3_ERROR_CODES.NoSuchAccessPoint: S3_STATUS_CODES.NOT_FOUND_404,
277 |     S3_ERROR_CODES.InvalidTag: S3_STATUS_CODES.BAD_REQUEST_400,
278 |     S3_ERROR_CODES.MalformedPolicy: S3_STATUS_CODES.BAD_REQUEST_400
279 | }
280 |
281 |
282 | def error_code_to_enums(error_code: str) -> Tuple[S3_STATUS_CODES, S3_ERROR_CODES]:
283 |     """Map an S3 error code name to its (status code, error code) enum pair."""
284 |     for code, status in ERROR_CODE_STATUS_MAP.items():
285 |         if error_code == code.name:
286 |             return status, code
287 |     return S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500, S3_ERROR_CODES.InternalError
288 |
289 |
290 | def http_status_code_to_s3_status_code(http_status_code: int) -> S3_STATUS_CODES:
291 |     """Convert an integer HTTP status code to an S3 status code."""
292 |     for status_codes in S3_STATUS_CODES:
293 |         if str(http_status_code) == status_codes.name[-3:]:
294 |             return status_codes
295 |     return S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500
296 |
--------------------------------------------------------------------------------
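
A quick illustration of how these helpers compose (example code, not part of the module); the import assumes src/ is on the path, as the unit tests arrange via conftest.py. Note that both lookup functions fall back to a 500/InternalError result for unrecognized inputs rather than raising.

    from constants import (S3_ERROR_CODES, S3_STATUS_CODES, error_code_to_enums,
                           http_status_code_to_s3_status_code)

    # Known error codes map to their documented status codes...
    status, code = error_code_to_enums("NoSuchKey")
    assert status is S3_STATUS_CODES.NOT_FOUND_404 and code is S3_ERROR_CODES.NoSuchKey
    assert status.get_http_status_code() == 404

    # ...and unknown inputs fall back to 500/InternalError instead of raising.
    assert error_code_to_enums("NotARealCode") == (S3_STATUS_CODES.INTERNAL_SERVER_ERROR_500,
                                                   S3_ERROR_CODES.InternalError)
    assert http_status_code_to_s3_status_code(403) is S3_STATUS_CODES.FORBIDDEN_403
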
/test/integ/integ_base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import datetime, timedelta
3 | from time import sleep
4 | from unittest import TestCase
5 | import os
6 | import boto3
7 | import uuid
8 | import zipfile
9 | import json
10 | import time
11 |
12 | from dateutil.tz import tzutc
13 |
14 | LOG_LEVEL = "LOG_LEVEL"
15 | UNSUPPORTED_FILE_HANDLING = "UNSUPPORTED_FILE_HANDLING"
16 | DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES = "DOCUMENT_MAX_SIZE_CONTAINS_PII_ENTITIES"
17 | DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES = "DOCUMENT_MAX_SIZE_DETECT_PII_ENTITIES"
18 | PII_ENTITY_TYPES = "PII_ENTITY_TYPES"
19 | MASK_CHARACTER = "MASK_CHARACTER"
20 | MASK_MODE = "MASK_MODE"
21 | SUBSEGMENT_OVERLAPPING_TOKENS = "SUBSEGMENT_OVERLAPPING_TOKENS"
22 | DOCUMENT_MAX_SIZE = "DOCUMENT_MAX_SIZE"
23 | CONFIDENCE_THRESHOLD = "CONFIDENCE_THRESHOLD"
24 | MAX_CHARS_OVERLAP = "MAX_CHARS_OVERLAP"
25 | DEFAULT_LANGUAGE_CODE = "DEFAULT_LANGUAGE_CODE"
26 | DETECT_PII_ENTITIES_THREAD_COUNT = "DETECT_PII_ENTITIES_THREAD_COUNT"
27 | CONTAINS_PII_ENTITIES_THREAD_COUNT = "CONTAINS_PII_ENTITIES_THREAD_COUNT"
28 | REDACTION_API_ONLY = "REDACTION_API_ONLY"
29 | PUBLISH_CLOUD_WATCH_METRICS = "PUBLISH_CLOUD_WATCH_METRICS"
30 | CW_METRIC_PUBLISH_CHECK_ATTEMPT = 5
31 |
32 |
33 | class BasicIntegTest(TestCase):
34 |     REGION_NAME = 'us-east-1'
35 |     DATA_PATH = "test/data/integ"
36 |     PII_ENTITY_TYPES_IN_TEST_DOC = ['EMAIL', 'PHONE', 'BANK_ROUTING', 'BANK_ACCOUNT_NUMBER', 'ADDRESS', 'DATE_TIME']
37 |
38 |     @classmethod
39 |     def _zip_function_code(cls, build_dir):
40 |         source_dir = build_dir
41 |         output_filename = "s3ol_function.zip"
42 |         relroot = os.path.abspath(source_dir)
43 |         with zipfile.ZipFile(output_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
44 |             logging.info(f"Build Dir: {source_dir}")
45 |             for root, dirs, files in os.walk(source_dir):
46 |                 # add directory (needed for empty dirs)
47 |                 zipf.write(root, os.path.relpath(root, relroot))
48 |                 for file in files:
49 |                     filename = os.path.join(root, file)
50 |                     if os.path.isfile(filename):  # regular files only
51 |                         arcname = os.path.join(os.path.relpath(root, relroot), file)
52 |                         zipf.write(filename, arcname)
53 |
54 |         return output_filename
55 |
56 |     @classmethod
57 |     def _update_lambda_env_variables(cls, function_arn, env_variable_dict):
58 |         function_name = function_arn.split(':')[-1]
59 |         cls.lambda_client.update_function_configuration(FunctionName=function_name, Environment={'Variables': env_variable_dict})
60 |         waiter = cls.lambda_client.get_waiter('function_updated')
61 |         waiter.wait(FunctionName=function_name)
62 |
63 |     @classmethod
64 |     def _create_function(cls, function_name, handler_name, test_id, env_vars_dict, lambda_execution_role, build_dir):
65 |         function_unique_name = f"{function_name}-{test_id}"
66 |         zip_file_name = cls._zip_function_code(build_dir)
67 |         with open(zip_file_name, 'rb') as file_data:
68 |             bytes_content = file_data.read()
69 |
70 |         create_function_response = cls.lambda_client.create_function(
71 |             FunctionName=function_unique_name,
72 |             Runtime='python3.8',
73 |             Role=lambda_execution_role,
74 |             Handler=handler_name,
75 |             Code={
76 |                 'ZipFile': bytes_content
77 |             },
78 |             Timeout=60,
79 |             MemorySize=128,
80 |             Publish=True,
81 |             Environment={
82 |                 'Variables': env_vars_dict
83 |             }
84 |         )
85 |         return create_function_response['FunctionArn']
86 |
87 |     @classmethod
88 |     def _update_function_handler(cls, handler):
89 |         response = cls.lambda_client.update_function_configuration(
90 |             FunctionName=cls.function_name,
91 |             Handler=handler
92 |         )
93 |
94 |     @classmethod
95 |     def _create_role(cls, test_id):
96 |
97 |         cls.role_name = f"s3ol-integ-test-role-{test_id}"
98 |
99 |         assume_role_policy_document = {
100 |             "Version": "2012-10-17",
101 |             "Statement": [
102 |                 {
103 |                     "Effect": "Allow",
104 |                     "Principal": {
105 |                         "Service": "lambda.amazonaws.com"
106 |                     },
107 |                     "Action": "sts:AssumeRole"
108 |                 }
109 |             ]
110 |         }
111 |
112 |         create_iam_role_response = cls.iam_client.create_role(
113 |             Path='/s3ol-integ-test/',
114 |             RoleName=cls.role_name,
115 |             AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
116 |         )
117 |
118 |         role_arn = create_iam_role_response['Role']['Arn']
119 |
120 |         policy_document = {
121 |             "Statement": [
122 |                 {
123 |                     "Action": [
124 |                         "comprehend:DetectPiiEntities",
125 |                         "comprehend:ContainsPiiEntities"
126 |                     ],
127 |                     "Resource": "*",
128 |                     "Effect": "Allow",
129 |                     "Sid": "ComprehendPiiDetectionPolicy"
130 |                 },
131 |                 {
132 |                     "Action": [
133 |                         "s3-object-lambda:WriteGetObjectResponse"
134 |                     ],
135 |                     "Resource": "*",
136 |                     "Effect": "Allow",
137 |                     "Sid": "S3AccessPointCallbackPolicy"
138 |                 },
139 |                 {
140 |                     "Action": [
141 |                         "cloudwatch:PutMetricData"
142 |                     ],
143 |                     "Resource": "*",
144 |                     "Effect": "Allow",
145 |                     "Sid": "CloudWatchMetricsPolicy"
146 |                 }
147 |             ]
148 |         }
149 |
150 |         put_role_policy_response = cls.iam_client.put_role_policy(
151 |             RoleName=cls.role_name,
152 |             PolicyName='S3OLFunctionPolicy',
153 |             PolicyDocument=json.dumps(policy_document)
154 |         )
155 |
156 |         attach_role_policy_response = cls.iam_client.attach_role_policy(
157 |             RoleName=cls.role_name,
158 |             PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
159 |         )
160 |         # wait for some time to let iam role propagate
161 |         sleep(10)
162 |         return role_arn
163 |
164 |     @classmethod
165 |     def _create_bucket(cls, test_id):
166 |         cls.bucket_name = f"s3ol-integ-test-{test_id}"
167 |         bucket = cls.s3.Bucket(cls.bucket_name)
168 |         create_bucket_response = bucket.create()
169 |         cls.s3.BucketVersioning(cls.bucket_name).enable()
170 |
171 |     @classmethod
172 |     def _create_access_point(cls, test_id):
173 |         cls.access_point_name = f"s3ol-integ-test-ac-{test_id}"
174 |
175 |         create_access_point_response = cls.s3_ctrl.create_access_point(
176 |             AccountId=cls.account_id,
177 |             Name=cls.access_point_name,
178 |             Bucket=cls.bucket_name,
179 |         )
180 |         cls.access_point_arn = f"arn:aws:s3:{cls.REGION_NAME}:{cls.account_id}:accesspoint/{cls.access_point_name}"
181 |
182 |     @classmethod
183 |     def _create_s3ol_access_point(cls, test_id, lambda_function_arn):
184 |         cls.s3ol_access_point_name = f"s3ol-integ-test-bac-{test_id}"
185 |
186 |         create_s3ol_access_point_response = cls.s3_ctrl.create_access_point_for_object_lambda(
187 |             AccountId=cls.account_id,
188 |             Name=cls.s3ol_access_point_name,
189 |             Configuration={
190 |                 'SupportingAccessPoint': cls.access_point_arn,
191 |                 'TransformationConfigurations': [
192 |                     {
193 |                         'Actions': ['GetObject'],
194 |                         'ContentTransformation': {
195 |                             'AwsLambda': {
196 |                                 'FunctionArn': lambda_function_arn
197 |                             }
198 |                         }
199 |                     }
200 |                 ],
201 |                 "AllowedFeatures": ["GetObject-Range"]
202 |             }
203 |         )
204 |         return f"arn:aws:s3-object-lambda:{cls.REGION_NAME}:{cls.account_id}:accesspoint/{cls.s3ol_access_point_name}"
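
# Illustrative sketch (not part of this file): because the Object Lambda
# access point above is created with AllowedFeatures=["GetObject-Range"],
# callers can issue both full and ranged GETs through it. boto3 accepts the
# access point ARN in place of a bucket name; `s3ol_arn` below stands for the
# ARN returned by _create_s3ol_access_point().
#
#     s3 = boto3.client('s3', region_name='us-east-1')
#     whole_object = s3.get_object(Bucket=s3ol_arn, Key='pii_input.txt')
#     first_kb = s3.get_object(Bucket=s3ol_arn, Key='pii_input.txt',
#                              Range='bytes=0-1023')
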
205 |
206 |     @classmethod
207 |     def _create_temporary_data_files(cls):
208 |         non_pii_text = "\nentities identified in the input text. For each entity, the response provides the entity type, where the entity text begins and ends, and the level of confidence entities identified in the input text. For each entity, the response provides the entity type, where the entity text begins and ends, and the level of confidence that Amazon Comprehend has in the detectio"
209 |
210 |         def _create_file(input_file_name, output_file_name, repeats=4077):
211 |             with open(f"{cls.DATA_PATH}/{input_file_name}") as existing_file:
212 |                 modified_file_content = existing_file.read()
213 |             for i in range(0, repeats):
214 |                 modified_file_content += non_pii_text
215 |             with open(f"{cls.DATA_PATH}/{output_file_name}", 'w') as new_file:
216 |                 new_file.seek(0)
217 |                 new_file.write(modified_file_content)
218 |
219 |         _create_file('pii_input.txt', '1mb_pii_text')
220 |         _create_file('pii_input.txt', '5mb_pii_text', repeats=15000)
221 |         _create_file('pii_output.txt', '1mb_pii_redacted_text')
222 |
223 |     @classmethod
224 |     def _clear_temporary_files(cls):
225 |         os.remove(f"{cls.DATA_PATH}/1mb_pii_text")
226 |         os.remove(f"{cls.DATA_PATH}/5mb_pii_text")
227 |         os.remove(f"{cls.DATA_PATH}/1mb_pii_redacted_text")
228 |
229 |     @classmethod
230 |     def _upload_data(cls):
231 |         for filename in os.listdir(cls.DATA_PATH):
232 |             cls.s3_client.upload_file(f"{cls.DATA_PATH}/{filename}", cls.bucket_name, filename)
233 |
234 |     @classmethod
235 |     def setUpClass(cls):
236 |         cls.s3 = boto3.resource('s3', region_name=cls.REGION_NAME)
237 |         cls.s3_client = boto3.client('s3', region_name=cls.REGION_NAME)
238 |         cls.s3_ctrl = boto3.client('s3control', region_name=cls.REGION_NAME)
239 |         cls.lambda_client = boto3.client('lambda', region_name=cls.REGION_NAME)
240 |         cls.cloudwatch_client = boto3.client('cloudwatch', region_name=cls.REGION_NAME)
241 |         cls.iam = boto3.resource('iam')
242 |         cls.account_id = cls.iam.CurrentUser().arn.split(':')[4]
243 |         cls.iam_client = boto3.client('iam')
244 |         test_run_id = str(uuid.uuid4())[0:8]
245 |
246 |         cls.lambda_role_arn = cls._create_role(test_run_id)
247 |         cls._create_bucket(test_run_id)
248 |         cls._create_access_point(test_run_id)
249 |         cls._create_temporary_data_files()
250 |         cls._upload_data()
251 |
252 |     @classmethod
253 |     def tearDownClass(cls):
254 |         try:
255 |             delete_access_point_response = cls.s3_ctrl.delete_access_point(
256 |                 AccountId=cls.account_id,
257 |                 Name=cls.access_point_name
258 |             )
259 |         except Exception as e:
260 |             logging.error(e)
261 |         cls._clear_temporary_files()
262 |         bucket = cls.s3.Bucket(cls.bucket_name)
263 |         delete_object_response = bucket.object_versions.all().delete()
264 |         delete_bucket_response = bucket.delete()
265 |
266 |         delete_role_policy_response = cls.iam_client.delete_role_policy(
267 |             RoleName=cls.role_name,
268 |             PolicyName='S3OLFunctionPolicy'
269 |         )
270 |
271 |         detach_role_policy_response = cls.iam_client.detach_role_policy(
272 |             RoleName=cls.role_name,
273 |             PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
274 |         )
275 |
276 |         delete_role_response = cls.iam_client.delete_role(
277 |             RoleName=cls.role_name
278 |         )
279 |
280 |     def _validate_pii_count_metric_published(self, s3ol_access_point_arn, start_time: datetime, entity_types):
281 |         end_time = start_time + timedelta(minutes=1)
282 |         st_time = start_time - timedelta(minutes=1)
283 |
284 |         def _is_metric_published() -> bool:
285 |             for e_type in entity_types:
286 |                 pii_doc_processed_count_metric = \
287 |                     self.cloudwatch_client.get_metric_data(MetricDataQueries=[{'Id': 'a1232454353',
288 |                                                                                'MetricStat': {'Metric': {
289 |                                                                                    'MetricName': 'PiiDocumentTypesProcessed',
290 |                                                                                    'Namespace': 'ComprehendS3ObjectLambda',
291 |                                                                                    'Dimensions': [
292 |                                                                                        {'Name': 'PiiEntityType',
293 |                                                                                         'Value': e_type},
294 |                                                                                        {'Name': 'Language',
295 |                                                                                         'Value': 'en'},
296 |                                                                                        {'Name': 'S3ObjectLambdaAccessPoint',
297 |                                                                                         'Value': s3ol_access_point_arn}]},
298 |                                                                                'Period': 60,
299 |                                                                                'Stat': 'Sum'}}],
300 |                                                            StartTime=st_time,
301 |                                                            EndTime=end_time)
302 |                 for result in pii_doc_processed_count_metric['MetricDataResults']:
303 |                     if result['Id'] == "a1232454353":
304 |                         if len(result['Values']) == 0:
305 |                             return False
306 |             return True
307 |
308 |         attempts = CW_METRIC_PUBLISH_CHECK_ATTEMPT
309 |         while attempts > 0:
310 |             if _is_metric_published():
311 |                 return None
312 |             sleep(10)
313 |             attempts -= 1
314 |         assert False, f"No metrics published for s3ol arn {s3ol_access_point_arn} for one of the entity types: {entity_types}, " \
315 |                       f"StartTime: {st_time}, EndTime: {end_time}"
316 |
317 |     def _validate_api_call_latency_published(self, s3ol_access_point_arn, start_time: datetime, end_time=None,
318 |                                              is_pii_classification_performed: bool = True,
319 |                                              is_pii_detection_performed: bool = False,
320 |                                              is_s3ol_callback_done: bool = True):
321 |         end_time = start_time + timedelta(minutes=1) if not end_time else end_time
322 |         st_time = start_time - timedelta(minutes=1)
323 |
324 |         def _is_metric_published(api_name, service) -> bool:
325 |             pii_doc_processed_count_metric = \
326 |                 self.cloudwatch_client.get_metric_data(MetricDataQueries=[{'Id': 'a1232454353',
327 |                                                                            'MetricStat': {'Metric': {
328 |                                                                                'MetricName': 'Latency',
329 |                                                                                'Namespace': 'ComprehendS3ObjectLambda',
330 |                                                                                'Dimensions': [
331 |                                                                                    {'Name': 'API',
332 |                                                                                     'Value': api_name},
333 |                                                                                    {'Name': 'S3ObjectLambdaAccessPoint',
334 |                                                                                     'Value': s3ol_access_point_arn},
335 |                                                                                    {'Name': 'Service',
336 |                                                                                     'Value': service}
337 |                                                                                ]},
338 |                                                                            'Period': 60,
339 |                                                                            'Stat': 'Average'}}],
340 |                                                        StartTime=st_time,
341 |                                                        EndTime=end_time)
342 |             for result in pii_doc_processed_count_metric['MetricDataResults']:
343 |                 if result['Id'] == "a1232454353":
344 |                     if len(result['Values']) == 0:
345 |                         return False
346 |             return True
347 |
348 |         attempts = CW_METRIC_PUBLISH_CHECK_ATTEMPT
349 |         metric_checked = {
350 |             'ContainsPiiEntities': False,
351 |             'DetectPiiEntities': False,
352 |             'WriteGetObjectResponse': False
353 |         }
354 |         while attempts > 0:
355 |             if is_pii_classification_performed and not metric_checked['ContainsPiiEntities']:
356 |                 if _is_metric_published(api_name="ContainsPiiEntities", service="Comprehend"):
357 |                     metric_checked['ContainsPiiEntities'] = True
358 |             if is_pii_detection_performed and not metric_checked['DetectPiiEntities']:
359 |                 if _is_metric_published(api_name="DetectPiiEntities", service="Comprehend"):
360 |                     metric_checked['DetectPiiEntities'] = True
361 |             if is_s3ol_callback_done and not metric_checked['WriteGetObjectResponse']:
362 |                 if _is_metric_published(api_name="WriteGetObjectResponse", service="S3"):
363 |                     metric_checked['WriteGetObjectResponse'] = True
364 |             sleep(10)
365 |             attempts -= 1
366 |         if is_pii_classification_performed:
367 |             assert metric_checked['ContainsPiiEntities'], \
368 |                 f"Could not verify that metrics were published for the various API calls made between StartTime: {st_time}, EndTime: {end_time}"
369 |         if is_pii_detection_performed:
370 |             assert metric_checked['DetectPiiEntities'], \
371 |                 f"Could not verify that metrics were published for the various API calls made between StartTime: {st_time}, EndTime: {end_time}"
372 |         if is_s3ol_callback_done:
373 |             assert metric_checked['WriteGetObjectResponse'], \
374 |                 f"Could not verify that metrics were published for the API call made between StartTime: {st_time}, EndTime: {end_time}"
375 |
--------------------------------------------------------------------------------
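
Taken together, the pieces above are enough to assemble a concrete suite. The following sketch shows one way a suite such as test_redaction.py could build on BasicIntegTest; the handler string "handler.pii_redaction_handler", the test id "smoke" and the build directory "build" are illustrative assumptions, not values taken from the repository.

    from integ_base import BasicIntegTest, LOG_LEVEL


    class RedactionSmokeTest(BasicIntegTest):

        @classmethod
        def setUpClass(cls):
            super().setUpClass()  # creates the role, bucket, supporting access point and test data
            # Handler path and build directory are assumptions for illustration.
            function_arn = cls._create_function(
                "pii-redaction", "handler.pii_redaction_handler", "smoke",
                {LOG_LEVEL: "INFO"}, cls.lambda_role_arn, "build")
            cls.s3ol_arn = cls._create_s3ol_access_point("smoke", function_arn)

        def test_pii_is_redacted(self):
            # Reading through the Object Lambda access point returns the
            # transformed (redacted) object rather than the raw upload.
            transformed = self.s3_client.get_object(
                Bucket=self.s3ol_arn, Key='pii_input.txt')['Body'].read().decode('utf-8')
            with open(f"{self.DATA_PATH}/pii_input.txt") as f:
                self.assertNotEqual(transformed, f.read())
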