├── .coveragerc ├── .github └── workflows │ └── build_publish.yml ├── .gitignore ├── .prospector.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Pipfile ├── README.md ├── aws_lambda_decorators ├── __init__.py ├── classes.py ├── decoders.py ├── decorators.py ├── utils.py └── validators.py ├── buildspec.yml ├── examples ├── __init__.py ├── examples.py └── test_examples.py ├── pylintrc ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── test_classes.py ├── test_decoders.py ├── test_decorators.py ├── test_param.py └── test_utils.py └── tools └── dev └── coverage.sh /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | 3 | branch = True 4 | 5 | omit = 6 | *__init__.py 7 | examples/* 8 | setup.py 9 | tests/* 10 | 11 | [report] 12 | 13 | show_missing = True 14 | 15 | omit = 16 | *__init__.py 17 | examples/* 18 | setup.py 19 | tests/* 20 | 21 | fail_under = 100 22 | -------------------------------------------------------------------------------- /.github/workflows/build_publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | jobs: 8 | build-and-publish: 9 | name: Build and publish 🐍 distributions 📦 to PyPI 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out repository code 13 | uses: actions/checkout@v2 14 | 15 | - name: Fetch version 16 | run: | 17 | output="$(python setup.py --version)" 18 | echo "::set-output name=version::$output" 19 | id: setup_version 20 | - name: Print build version 21 | run: echo "${{ steps.setup_version.outputs.version }}" 22 | 23 | # Setup Python (faster than using Python container) 24 | - name: Setup Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: "3.7" 28 | 29 | - name: Install pipenv 30 | run: | 31 | python -m pip install --upgrade pipenv wheel 32 | - id: cache-pipenv 33 | uses: actions/cache@v1 34 | with: 35 | path: ~/.local/share/virtualenvs 36 | key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} 37 | 38 | - name: Install dependencies 39 | if: steps.cache-pipenv.outputs.cache-hit != 'true' 40 | run: | 41 | pipenv install --deploy --dev 42 | 43 | - name: Build a binary wheel and a source tarball 44 | run: | 45 | pipenv run build 46 | # git tag created using version specified in setup.py 47 | - name: Bump git tag version 48 | id: tag_version 49 | uses: mathieudutour/github-tag-action@v5.6 50 | with: 51 | github_token: ${{ secrets.GITHUB_TOKEN }} 52 | custom_tag: ${{ steps.setup_version.outputs.version }} 53 | - name: Create a GitHub release 54 | uses: ncipollo/release-action@v1 55 | with: 56 | tag: ${{ steps.tag_version.outputs.new_tag }} 57 | name: Release ${{ steps.tag_version.outputs.new_tag }} 58 | body: ${{ github.event.head_commit.message }} 59 | 60 | - name: Publish distribution 📦 to Test PyPI 61 | uses: pypa/gh-action-pypi-publish@master 62 | with: 63 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 64 | repository_url: https://test.pypi.org/legacy/ 65 | verbose: true 66 | skip_existing: true 67 | 68 | - name: Publish distribution 📦 to PyPI 69 | uses: pypa/gh-action-pypi-publish@master 70 | with: 71 | password: ${{ secrets.PYPI_API_TOKEN }} 72 | verbose: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | coverage.sh 50 | .project 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | .idea/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # Scripts 110 | *.sh 111 | 112 | # Mac 113 | .DS_Store 114 | */.DS_Store 115 | 116 | # Visual studio code 117 | .vscode 118 | 119 | # vim (sorry) 120 | *.swp 121 | -------------------------------------------------------------------------------- /.prospector.yaml: -------------------------------------------------------------------------------- 1 | output-format: text 2 | 3 | strictness: veryhigh 4 | test-warnings: true 5 | doc-warnings: false 6 | 7 | pep8: 8 | full: true 9 | options: 10 | max-line-length: 120 11 | 12 | pep257: 13 | disable: 14 | - D203 15 | - D212 16 | 17 | pylint: 18 | options: 19 | max-line-length: 120 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | If you want to contribute to this project, please submit an issue describing your proposed change. We will respond to you as soon as we can. 4 | 5 | If you want to work on that change, fork this Github repo and clone the fork locally. Install the requirements in a python3 virtual environment, using the appropriate version of pip: 6 | 7 | `pip install -r requirements.txt` 8 | 9 | Before submitting a PR, please ensure that: 10 | 11 | - you run [__Bandit__](https://pypi.org/project/bandit/) for security checking and all checks are passing: 12 | 13 | `bandit -r .` 14 | 15 | - you run [__Prospector__](https://pypi.org/project/prospector/) for code analysis and all checks are passing: 16 | 17 | `prospector` 18 | 19 | - you run [__Coverage__](https://pypi.org/project/coverage/) and all unit tests are passing: 20 | 21 | `coverage run --source='.' -m unittest` 22 | 23 | `coverage report` 24 | 25 | - you can run the test examples like this: 26 | 27 | `python -m unittest examples.test_examples` 28 | 29 | Thanks for contributing! 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Grid Smarter Cities 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | pyjwt = "==1.7.1" 8 | 9 | [dev-packages] 10 | pytest = "*" 11 | setuptools = "*" 12 | wheel = "*" 13 | twine = "*" 14 | tqdm = "*" 15 | pre-commit = "*" 16 | commitizen = "*" 17 | toml = "*" 18 | coverage = "*" 19 | bandit = "*" 20 | pylint_quotes = "*" 21 | schema ="*" 22 | PyJWT ="*" 23 | boto3="*" 24 | 25 | [requires] 26 | python_version = "3.7" 27 | 28 | [scripts] 29 | test = "pytest" 30 | build = "python3 setup.py sdist bdist_wheel" 31 | deploy = "twine upload dist/*" 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [Grid Smarter Cities](https://www.gridsmartercities.com/) 2 | 3 | ![Build Status](https://codebuild.eu-west-2.amazonaws.com/badges?uuid=eyJlbmNyeXB0ZWREYXRhIjoiSTZNdEsxUHdnWWdRMGwrS3FuaUxSb0g5c2hNdWdSNE94Y1RFRGNrdk96Zm9LWlZWWmpEK1FTWmcraGRnMEdzbmRjakF5SDVQUVBzcVpNL3hLSGw3TnpNPSIsIml2UGFyYW1ldGVyU3BlYyI6ImZsbHEwcUJGOFV2VXNpWHoiLCJtYXRlcmlhbFNldFNlcmlhbCI6MX0%3D&branch=master) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | ![Github Release](https://img.shields.io/github/release/gridsmartercities/aws-lambda-decorators.svg?style=flat) 6 | \ 7 | \ 8 | ![Python Versions](https://img.shields.io/pypi/pyversions/aws-lambda-decorators.svg?style=flat) 9 | ![PyPi Version](https://img.shields.io/pypi/v/aws-lambda-decorators.svg?style=flat) 10 | ![PyPi Status](https://img.shields.io/pypi/status/aws-lambda-decorators.svg?style=flat) 11 | ![Pypi Downloads](https://img.shields.io/pypi/dm/aws-lambda-decorators.svg?style=flat&logo=pypi) 12 | 13 | # aws-lambda-decorators 14 | 15 | A set of Python decorators to ease the development of AWS lambda functions. 16 | 17 | ## Installation 18 | 19 | The easiest way to use these AWS Lambda Decorators is to install them through Pip: 20 | 21 | `pip install aws-lambda-decorators` 22 | 23 | ## Logging 24 | 25 | The Logging level of the decorators can be controlled by setting a LOG_LEVEL environment variable. In python: 26 | 27 | `os.environ["LOG_LEVEL"] = "INFO"` 28 | 29 | The default value is "INFO" 30 | 31 | ## Package Contents 32 | 33 | ### [Decorators](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/decorators.py) 34 | 35 | The current list of AWS Lambda Python Decorators includes: 36 | 37 | * [__extract__](#extract): a decorator to extract and validate specific keys of a dictionary parameter passed to a AWS Lambda function. 38 | * [__extract_from_event__](#extract_from_event): a facade of [__extract__](#extract) to extract and validate keys from an AWS API Gateway lambda function _event_ parameter. 39 | * [__extract_from_context__](#extract_from_context): a facade of [__extract__](#extract) to extract and validate keys from an AWS API Gateway lambda function _context_ parameter. 40 | * [__extract_from_ssm__](#extract_from_ssm): a decorator to extract from AWS SSM the values of a set of parameter keys. 41 | * [__validate__](#validate): a decorator to validate a list of function parameters. 42 | * [__log__](#log): a decorator to log the parameters passed to the lambda function and/or the response of the lambda function. 43 | * [__handle_exceptions__](#handle_exceptions): a decorator to handle any type of declared exception generated by the lambda function. 44 | * [__response_body_as_json__](#response_body_as_json): a decorator to transform a response dictionary body to a json string. 45 | * [__handle_all_exceptions__](#handle_all_exceptions): a decorator to handle all exceptions thrown by the lambda function. 46 | * [__cors__](#cors): a decorator to add cors headers to a lambda function. 47 | * [__push_ws_errors__](#push_ws_errors): a decorator to push unsuccessful responses back to the calling user via websockets with api gateway. 48 | * [__push_ws_responses__](#push_ws_response): a decorator to push all responses back to the calling user via websockets with api gateway. 49 | 50 | 51 | ### [Validators](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/validators.py) 52 | 53 | Currently, the package offers 12 validators: 54 | 55 | * __Mandatory__: Checks if a parameter has a value. 56 | * __RegexValidator__: Checks a parameter against a regular expression. 57 | * __SchemaValidator__: Checks if an object adheres to the schema. Uses [schema](https://github.com/keleshev/schema) library. 58 | * __Minimum__: Checks if an optional numerical value is greater than a minimum value. 59 | * __Maximum__: Checks if an optional numerical value is less than a maximum value. 60 | * __MinLength__: Checks if an optional string value is longer than a minimum length. 61 | * __MaxLength__: Checks if an optional string value is shorter than a maximum length. 62 | * __Type__: Checks if an optional object value is of a given python type. 63 | * __EnumValidator__: Checks if an optional object value is in a list of valid values. 64 | * __NonEmpty__: Checks if an optional object value is not an empty value. 65 | * __DateValidator__: Checks if a given string is a valid date according to a passed in date format. 66 | * __CurrencyCodeValidator__: Checks if a given string is a valid currency code (ISO 4217). 67 | 68 | ### [Decoders](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/decoders.py) 69 | 70 | The package offers functions to decode from JSON and JWT. 71 | 72 | * __decode_json__: decodes/converts a json string to a python dictionary 73 | * __decode_jwt__: decodes/converts a JWT string to a python dictionary 74 | 75 | ## Examples 76 | 77 | You can see some basic examples in the [examples](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/examples/examples.py) folder. 78 | 79 | ### extract 80 | 81 | This decorator extracts and validates values from dictionary parameters passed to a Lambda Function. 82 | 83 | * The decorator takes a list of __Parameter__ objects. 84 | * Each __Parameter__ object requires a non-empty path to the parameter in the dictionary, and the name of the dictionary (func_param_name) 85 | * The parameter value is extracted and added as a kwarg to the lambda handler (or any other decorated function/method). 86 | * You can add the parameter to the handler signature, or access it in the handler through kwargs. 87 | * The name of the extracted parameter is defaulted to the last element of the path name, but can be changed by passing a (valid pythonic variable name) var_name 88 | * You can define a default value for the parameter in the __Parameter__ or in the lambda handler itself. 89 | * A 400 exception is raised when the parameter cannot be extracted or when it does not validate. 90 | * A variable path (e.g. '/headers/Authorization[jwt]/sub') can be annotated to specify a decoding. In the example, Authorization might contain a JWT, which needs to be decoded before accessing the "sub" element. 91 | 92 | Example: 93 | ```python 94 | @extract(parameters=[ 95 | Parameter(path='/parent/my_param', func_param_name='a_dictionary'), # extracts a non mandatory my_param from a_dictionary 96 | Parameter(path='/parent/missing_non_mandatory', func_param_name='a_dictionary', default='I am missing'), # extracts a non mandatory missing_non_mandatory from a_dictionary 97 | Parameter(path='/parent/missing_mandatory', func_param_name='a_dictionary'), # does not fail as the parameter is not validated as mandatory 98 | Parameter(path='/parent/child/id', validators=[Mandatory], var_name='user_id', func_param_name='another_dictionary') # extracts a mandatory id as "user_id" from another_dictionary 99 | ]) 100 | def extract_example(a_dictionary, another_dictionary, my_param='aDefaultValue', missing_non_mandatory='I am missing', missing_mandatory=None, user_id=None): 101 | """ 102 | Given these two dictionaries: 103 | 104 | a_dictionary = { 105 | 'parent': { 106 | 'my_param': 'Hello!' 107 | }, 108 | 'other': 'other value' 109 | } 110 | 111 | another_dictionary = { 112 | 'parent': { 113 | 'child': { 114 | 'id': '123' 115 | } 116 | } 117 | } 118 | 119 | you can now access the extracted parameters directly: 120 | """ 121 | return my_param, missing_non_mandatory, missing_mandatory, user_id 122 | ``` 123 | 124 | Or you can use kwargs instead of specific parameter names: 125 | 126 | Example: 127 | ```python 128 | @extract(parameters=[ 129 | Parameter(path='/parent/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary 130 | ]) 131 | def extract_to_kwargs_example(a_dictionary, **kwargs): 132 | """ 133 | a_dictionary = { 134 | 'parent': { 135 | 'my_param': 'Hello!' 136 | }, 137 | 'other': 'other value' 138 | } 139 | """ 140 | return kwargs['my_param'] # returns 'Hello!' 141 | ``` 142 | 143 | A missing mandatory parameter, or a parameter that fails validation, will raise an exception: 144 | 145 | Example: 146 | ```python 147 | @extract(parameters=[ 148 | Parameter(path='/parent/mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]) # extracts a mandatory mandatory_param from a_dictionary 149 | ]) 150 | def extract_mandatory_param_example(a_dictionary, mandatory_param=None): 151 | return 'Here!' # this part will never be reached, if the mandatory_param is missing 152 | 153 | response = extract_mandatory_param_example({'parent': {'my_param': 'Hello!'}, 'other': 'other value'} ) 154 | 155 | print(response) # prints { 'statusCode': 400, 'body': '{"message": [{"mandatory_param": ["Missing mandatory value"]}]}' } and logs a more detailed error 156 | 157 | ``` 158 | 159 | You can add custom error messages to all validators, and incorporate to those error messages the validated value and the validation condition: 160 | 161 | Example: 162 | ```python 163 | @extract(parameters=[ 164 | Parameter(path='/parent/an_int', func_param_name='a_dictionary', validators=[Minimum(100, 'Bad value {value}: should be at least {condition}')]) # extracts a mandatory mandatory_param from a_dictionary 165 | ]) 166 | def extract_minimum_param_with_custom_error_example(a_dictionary, mandatory_param=None): 167 | return 'Here!' # this part will never be reached, if the an_int param is less than 100 168 | 169 | response = extract_minimum_param_with_custom_error_example({'parent': {'an_int': 10}}) 170 | 171 | print(response) # prints { 'statusCode': 400, 'body': '{"message": [{"an_int": ["Bad value 10: should be at least 100"]}]}' } and logs a more detailed error 172 | 173 | ``` 174 | 175 | You can group the validation errors together (instead of exiting on first error). 176 | 177 | Example: 178 | ```python 179 | @extract(parameters=[ 180 | Parameter(path='/parent/mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]), # extracts two mandatory parameters from a_dictionary 181 | Parameter(path='/parent/another_mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]), 182 | Parameter(path='/parent/an_int', func_param_name='a_dictionary', validators=[Maximum(10), Minimum(5)]) 183 | ], group_errors=True) # groups both errors together 184 | def extract_multiple_param_example(a_dictionary, mandatory_param=None, another_mandatory_param=None, an_int=0): 185 | return 'Here!' # this part will never be reached, if the mandatory_param is missing 186 | 187 | response = extract_multiple_param_example({'parent': {'my_param': 'Hello!', 'an_int': 20}, 'other': 'other value'}) 188 | 189 | print(response) # prints {'statusCode': 400, 'body': '{"message": [{"mandatory_param": ["Missing mandatory value"]}, {"another_mandatory_param": ["Missing mandatory value"]}, {"an_int": ["\'20\' is greater than maximum value \'10\'"]}]}'} 190 | 191 | ``` 192 | 193 | You can decode any part of the parameter path from json or any other existing annotation. 194 | 195 | Example: 196 | ```python 197 | @extract(parameters=[ 198 | Parameter(path='/parent[json]/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary 199 | ]) 200 | def extract_from_json_example(a_dictionary, my_param=None): 201 | """ 202 | a_dictionary = { 203 | 'parent': '{"my_param": "Hello!" }', 204 | 'other': 'other value' 205 | } 206 | """ 207 | return my_param # returns 'Hello!' 208 | 209 | ``` 210 | 211 | You can also use an integer annotation to access an specific list element by index. 212 | 213 | Example: 214 | ```python 215 | @extract(parameters=[ 216 | Parameter(path='/parent[1]/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary 217 | ]) 218 | def extract_from_list_example(a_dictionary, my_param=None): 219 | """ 220 | a_dictionary = { 221 | 'parent': [ 222 | {'my_param': 'Hello!'}, 223 | {'my_param': 'Bye!'} 224 | ] 225 | } 226 | """ 227 | return my_param # returns 'Bye!' 228 | 229 | ``` 230 | 231 | You can extract all parameters into a dictionary 232 | 233 | Example: 234 | ```python 235 | @extract(parameters=[ 236 | Parameter(path='/params/my_param_1', func_param_name='a_dictionary'), # extracts a non mandatory my_param_1 from a_dictionary 237 | Parameter(path='/params/my_param_2', func_param_name='a_dictionary') # extracts a non mandatory my_param_2 from a_dictionary 238 | ]) 239 | def extract_dictionary_example(a_dictionary, **kwargs): 240 | """ 241 | a_dictionary = { 242 | 'params': { 243 | 'my_param_1': 'Hello!', 244 | 'my_param_2': 'Bye!' 245 | } 246 | } 247 | """ 248 | return kwargs # returns {'my_param_1': 'Hello!', 'my_param_2': 'Bye!'} 249 | 250 | ``` 251 | 252 | You can apply a transformation to an extracted value. The transformation will happen before validation. 253 | 254 | Example: 255 | ```python 256 | @extract(parameters=[ 257 | Parameter(path='/params/my_param', func_param_name='a_dictionary', transform=int) # extracts a non mandatory my_param from a_dictionary 258 | ]) 259 | def extract_with_transform_example(a_dictionary, my_param=None): 260 | """ 261 | a_dictionary = { 262 | 'params': { 263 | 'my_param': '2' # the original value is the string '2' 264 | } 265 | } 266 | """ 267 | return my_param # returns the int value 2 268 | 269 | ``` 270 | 271 | The transform function can be any function, with its own error handling. 272 | 273 | Example: 274 | ```python 275 | 276 | def to_int(arg): 277 | try: 278 | return int(arg) 279 | except Exception: 280 | raise Exception("My custom error message") 281 | 282 | @extract(parameters=[ 283 | Parameter(path='/params/my_param', func_param_name='a_dictionary', transform=to_int) # extracts a non mandatory my_param from a_dictionary 284 | ]) 285 | def extract_with_custom_transform_example(a_dictionary, my_param=None): 286 | return {} 287 | 288 | response = extract_with_custom_transform_example({'params': {'my_param': 'abc'}}) 289 | 290 | print(response) # prints {'statusCode': 400, 'body': '{"message": "Error extracting parameters"}'}, and the logs will contain the "My custom error message" message. 291 | 292 | 293 | ``` 294 | 295 | ### extract_from_event 296 | 297 | This decorator is just a facade to the [extract](#extract) method to be used in AWS Api Gateway Lambdas. It automatically extracts from the event lambda parameter. 298 | 299 | Example: 300 | ```python 301 | @extract_from_event(parameters=[ 302 | Parameter(path='/body[json]/my_param', validators=[Mandatory]), # extracts a mandatory my_param from the json body of the event 303 | Parameter(path='/headers/Authorization[jwt]/sub', validators=[Mandatory], var_name='user_id') # extract the mandatory sub value as user_id from the authorization JWT 304 | ]) 305 | def extract_from_event_example(event, context, my_param=None, user_id=None): 306 | """ 307 | event = { 308 | 'body': '{"my_param": "Hello!"}', 309 | 'headers': { 310 | 'Authorization': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c' 311 | } 312 | } 313 | """ 314 | return my_param, user_id # returns ('Hello!', '1234567890') 315 | ``` 316 | 317 | ### extract_from_context 318 | 319 | This decorator is just a facade to the [extract](#extract) method to be used in AWS Api Gateway Lambdas. It automatically extracts from the context lambda parameter. 320 | 321 | Example: 322 | ```python 323 | @extract_from_context(parameters=[ 324 | Parameter(path='/parent/my_param', validators=[Mandatory]) # extracts a mandatory my_param from the parent element in context 325 | ]) 326 | def extract_from_context_example(event, context, my_param=None): 327 | """ 328 | context = { 329 | 'parent': { 330 | 'my_param': 'Hello!' 331 | } 332 | } 333 | """ 334 | return my_param # returns 'Hello!' 335 | ``` 336 | 337 | ### extract_from_ssm 338 | 339 | This decorator extracts a parameter from AWS SSM and passes the parameter down to your function as a kwarg. 340 | 341 | * The decorator takes a list of __SSMParameter__ objects. 342 | * Each __SSMParameter__ object requires the name of the SSM parameter (ssm_name) 343 | * If no var_name is passed in, the extracted value is passed to the function with the ssm_name name 344 | 345 | Example: 346 | ```python 347 | @extract_from_ssm(ssm_parameters=[ 348 | SSMParameter(ssm_name='one_key'), # extracts the value of one_key from SSM as a kwarg named "one_key" 349 | SSMParameter(ssm_name='another_key', var_name="another") # extracts another_key as a kwarg named "another" 350 | ]) 351 | def extract_from_ssm_example(your_func_params, one_key=None, another=None): 352 | return your_func_params, one_key, another 353 | ``` 354 | 355 | ### validate 356 | 357 | This decorator validates a list of non dictionary parameters from your lambda function. 358 | 359 | * The decorator takes a list of __ValidatedParameter__ objects. 360 | * Each parameter object needs the name of the lambda function parameter that it is going to be validated, and the list of rules to validate. 361 | * A 400 exception is raised when the parameter does not validate. 362 | 363 | Example: 364 | ```python 365 | @validate(parameters=[ 366 | ValidatedParameter(func_param_name='a_param', validators=[Mandatory]), # validates a_param as mandatory 367 | ValidatedParameter(func_param_name='another_param', validators=[Mandatory, RegexValidator(r'\d+')]) # validates another_param as mandatory and containing only digits 368 | ValidatedParameter(func_param_name='param_with_schema', validators=[SchemaValidator(Schema({'a': Or(str, dict)}))]) # validates param_with_schema as an object with specified schema 369 | ]) 370 | def validate_example(a_param, another_param, param_with_schema): 371 | return a_param, another_param, param_with_schema # returns 'Hello!', '123456', {'a': {'b': 'c'}} 372 | 373 | validate_example('Hello!', '123456', {'a': {'b': 'c'}}) 374 | ``` 375 | 376 | Given the same function `validate_example`, a 400 exception is returned if at least one parameter does not validate (as per the [extract](#extract) decorator, you can group errors with the group_errors flag): 377 | 378 | ```python 379 | validate_example('Hello!', 'ABCD') # returns a 400 status code and an error message 380 | ``` 381 | 382 | ### log 383 | 384 | This decorator allows for logging the function arguments and/or the response. 385 | 386 | Example: 387 | ```python 388 | @log(parameters=True, response=True) 389 | def log_example(parameters): 390 | return 'Done!' 391 | 392 | log_example('Hello!') # logs 'Hello!' and 'Done!' 393 | ``` 394 | 395 | ### handle_exceptions 396 | 397 | This decorator handles a list of exceptions, returning a 400 response containing the specified friendly message to the caller. 398 | 399 | * The decorator takes a list of __ExceptionHandler__ objects. 400 | * Each __ExceptionHandler__ requires the type of exception to check, and an optional friendly message to return to the caller. 401 | 402 | Example: 403 | ```python 404 | @handle_exceptions(handlers=[ 405 | ExceptionHandler(ClientError, "Your message when a client error happens.") 406 | ]) 407 | def handle_exceptions_example(): 408 | dynamodb = boto3.resource('dynamodb') 409 | table = dynamodb.Table('non_existing_table') 410 | table.query(KeyConditionExpression=Key('user_id').eq(user_id)) 411 | # ... 412 | 413 | handle_exceptions_example() # returns {'body': '{"message": "Your message when a client error happens."}', 'statusCode': 400} 414 | ``` 415 | 416 | ### handle_all_exceptions 417 | 418 | This decorator handles all exceptions thrown by a lambda, returning a 400 response and the exception's message. 419 | 420 | Example: 421 | ```python 422 | @handle_all_exceptions() 423 | def handle_exceptions_example(): 424 | test_list = [1, 2, 3] 425 | invalid_value = test_list[5] 426 | # ... 427 | 428 | handle_all_exceptions_example() # returns {'body': '{"message": "list index out of range"}, 'statusCode': 400} 429 | ``` 430 | 431 | ### response_body_as_json 432 | 433 | This decorator ensures that, if the response contains a body, the body is dumped as json. 434 | 435 | * Returns a 500 error if the response body cannot be dumped as json. 436 | 437 | Example: 438 | ```python 439 | @response_body_as_json 440 | def response_body_as_json_example(): 441 | return {'statusCode': 400, 'body': {'param': 'hello!'}} 442 | 443 | response_body_as_json_example() # returns { 'statusCode': 400, 'body': "{'param': 'hello!'}" } 444 | ``` 445 | 446 | ### cors 447 | 448 | This decorator adds your defined CORS headers to the decorated function response. 449 | 450 | * Returns a 500 error if one or more of the CORS headers have an invalid type 451 | 452 | Example: 453 | ```python 454 | @cors(allow_origin='*', allow_methods='POST', allow_headers='Content-Type', max_age=86400) 455 | def cors_example(): 456 | return {'statusCode': 200} 457 | 458 | cors_example() # returns {'statusCode': 200, 'headers': {'access-control-allow-origin': '*', 'access-control-allow-methods': 'POST', 'access-control-allow-headers': 'Content-Type', 'access-control-max-age': 86400}} 459 | ``` 460 | 461 | ### hsts 462 | 463 | This decorator adds HSTS header to the decorated function response. Uses 2 years max-age (recommended default from https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security) unless custom value provided as parameter. 464 | 465 | Example: 466 | ```python 467 | @hsts() 468 | def hsts_example(): 469 | return {'statusCode': 200} 470 | 471 | hsts_example() # returns {'statusCode': 200, 'headers': {'Strict-Transport-Security': 'max-age=63072000'}} 472 | ``` 473 | 474 | ### push_ws_errors 475 | 476 | This decorator pushes unsuccessful responses back to the calling client over websockets built on api gateway 477 | 478 | This decorator requires the client is connected to the websocket api gateway instance, and will therefore have a connection id 479 | 480 | Example: 481 | ```py 482 | @push_ws_errors('https://api_id.execute_id.region.amazonaws.com/Prod') 483 | @handle_all_exceptions() 484 | def handler(event, context): 485 | return { 486 | 'statusCode': 400, 487 | 'body': { 488 | 'message': 'Bad request' 489 | } 490 | } 491 | 492 | # will push {'type': 'error', 'statusCode': 400, 'message': 'Bad request'} back to the client via websockets 493 | ``` 494 | 495 | ### push_ws_response 496 | 497 | This decorator pushes all responses back to the calling client over websockets built on api gateway 498 | 499 | This decorator requires the client is connected to the websocket api gateway instance, and will therefore have a connection id 500 | 501 | Example: 502 | ```py 503 | @push_ws_response('https://api_id.execute_id.region.amazonaws.com/Prod') 504 | def handler(event, context): 505 | return { 506 | 'statusCode': 200, 507 | 'body': 'Hello, world!' 508 | } 509 | 510 | # will push {'statusCode': 200, 'body': 'Hello, world!'} back to the client via websockets 511 | ``` 512 | 513 | ## Writing your own validators 514 | 515 | You can create your own validators by inheriting from the Validator class. 516 | 517 | Fix length validator example: 518 | 519 | ```python 520 | class FixLength(Validator): 521 | ERROR_MESSAGE = "'{value}' length should be '{condition}'" 522 | 523 | def __init__(self, fix_length: int, error_message=None): 524 | super().__init__(error_message=error_message, condition=fix_length) 525 | 526 | def validate(self, value=None): 527 | if value is None: 528 | return True 529 | 530 | return len(str(value)) == self._condition 531 | ``` 532 | 533 | ## Documentation 534 | 535 | You can get the docstring help by running: 536 | 537 | ```bash 538 | >>> from aws_lambda_decorators.decorators import extract 539 | >>> help(extract) 540 | ``` 541 | 542 | ## Links 543 | 544 | * [PyPi](https://pypi.org/project/aws-lambda-decorators/) 545 | * [Test PyPi](https://test.pypi.org/project/aws-lambda-decorators/) 546 | * [Github](https://github.com/gridsmartercities/aws-lambda-decorators) 547 | -------------------------------------------------------------------------------- /aws_lambda_decorators/__init__.py: -------------------------------------------------------------------------------- 1 | from aws_lambda_decorators.classes import * # noqa 2 | from aws_lambda_decorators.decoders import * # noqa 3 | from aws_lambda_decorators.decorators import * # noqa 4 | from aws_lambda_decorators.utils import * # noqa 5 | from aws_lambda_decorators.validators import * # noqa 6 | -------------------------------------------------------------------------------- /aws_lambda_decorators/classes.py: -------------------------------------------------------------------------------- 1 | """All the classes used as parameters for the decorators.""" 2 | from aws_lambda_decorators.decoders import decode 3 | from aws_lambda_decorators.utils import is_valid_variable_name 4 | 5 | PATH_DIVIDER = "/" 6 | ANNOTATIONS_START = "[" 7 | ANNOTATIONS_END = "]" 8 | 9 | 10 | class ExceptionHandler: 11 | """Class mapping a friendly error message to a given Exception.""" 12 | 13 | def __init__(self, exception, friendly_message=None, status_code=400): 14 | """ 15 | Sets the private variables of the ExceptionHandler object. 16 | 17 | Args: 18 | exception (object|Exception): An exception to be handled. 19 | friendly_message (str): Friendly Message to be returned if the exception is caught. 20 | """ 21 | self._exception = exception 22 | self._friendly_message = friendly_message 23 | self._status_code = status_code 24 | 25 | @property 26 | def friendly_message(self): 27 | """Getter for the friendly message parameter.""" 28 | return self._friendly_message 29 | 30 | @property 31 | def exception(self): 32 | """Getter for the exception parameter.""" 33 | return self._exception 34 | 35 | @property 36 | def status_code(self): 37 | """Getter for the status code parameter.""" 38 | return self._status_code 39 | 40 | 41 | class BaseParameter: # noqa: pylint - too-few-public-methods 42 | """Parent class of all parameter classes.""" 43 | def __init__(self, var_name): 44 | """ 45 | Set the private variables of the BaseParameter object. 46 | 47 | Args: 48 | var_name (str): The name of the variable where to store the extracted parameter. 49 | """ 50 | self._name = var_name 51 | 52 | def get_var_name(self): 53 | """Gets the name of the variable that represents the parameter.""" 54 | if self._name and not is_valid_variable_name(self._name): 55 | raise SyntaxError(self._name) 56 | return self._name 57 | 58 | 59 | class SSMParameter(BaseParameter): 60 | """Class used for defining the key and, optionally, the variable name for ssm parameter extraction.""" 61 | 62 | def __init__(self, ssm_name, var_name=None): 63 | """ 64 | Set the private variables of the SSMParameter object. 65 | 66 | Args: 67 | ssm_name (str): Key of the variable in the AWS parameter store 68 | var_name (str): Optional, the name of variable to store the extracted value to. Defaults to ssm_name. 69 | """ 70 | self._ssm_name = ssm_name 71 | BaseParameter.__init__(self, var_name if var_name else ssm_name) 72 | 73 | def get_ssm_name(self): 74 | """Getter for the ssm_name parameter.""" 75 | return self._ssm_name 76 | 77 | 78 | class ValidatedParameter: 79 | """Class used to encapsulate the validation methods parameter data.""" 80 | 81 | def __init__(self, func_param_name=None, validators=None): 82 | """ 83 | Sets the private variables of the ValidatedParameter object. 84 | Args: 85 | func_param_name (str): the name for the dictionary in the function signature 86 | def fun(event, context). To extract from context func_param_name has to be "context" 87 | validators (list): A list of validators the value must conform to (e.g. Mandatory, 88 | RegexValidator(my_regex), ...) 89 | """ 90 | self._func_param_name = func_param_name 91 | self._validators = validators 92 | 93 | @property 94 | def func_param_name(self): 95 | """Getter for the func_param_name parameter.""" 96 | return self._func_param_name 97 | 98 | @func_param_name.setter 99 | def func_param_name(self, value): 100 | """Setter for the func_param_name parameter.""" 101 | self._func_param_name = value 102 | 103 | def validate(self, value, group_errors): 104 | """ 105 | Validates a value against the passed in validators 106 | 107 | Args: 108 | value (any): value to be validated 109 | group_errors (bool): flag that indicates if error messages are to be grouped together 110 | (if set to False, validation will end on first error) 111 | 112 | Returns: 113 | A list of errors 114 | """ 115 | errors = [] 116 | 117 | if self._validators: 118 | for validator in self._validators: 119 | if not validator.validate(value): 120 | if hasattr(validator, "_error_message"): 121 | errors.append(validator.message(value)) 122 | else: # calling the validator statically 123 | errors.append(validator.ERROR_MESSAGE.format(value=value)) 124 | if not group_errors: 125 | return errors 126 | 127 | return errors 128 | 129 | 130 | class Parameter(ValidatedParameter, BaseParameter): 131 | """Class used to encapsulate the extract methods parameter data.""" 132 | 133 | def __init__(self, path="", func_param_name=None, validators=None, var_name=None, default=None, transform=None): # noqa: pylint - too-many-arguments 134 | """ 135 | Sets the private variables of the Parameter object. 136 | 137 | Args: 138 | path (str): The path to the variable we want to extract. Can use any annotation that has an existing 139 | equivalent decode function in decoders.py (like [jwt] or [json]). 140 | As an example, given the dictionary 141 | 142 | { 143 | "a": { 144 | "b": "{'c': 'hello'}", 145 | } 146 | } 147 | 148 | the path to c is "a/b[json]/c" 149 | func_param_name (str): the name for the dictionary in the function signature 150 | def fun(event, context). To extract from context func_param_name has to be "context" 151 | validators (list): A list of validators the value must conform to (e.g. Mandatory, 152 | RegexValidator(my_regex), ...) 153 | var_name (str): Optional, the name of the variable we want to assign the extracted value to. The default 154 | value is the last element of the path (e.g. "c" in the case above) 155 | default (any): Optional, a default value if the value is missing and not mandatory. 156 | The default value is None 157 | transform (function): Optional, a function to apply to the extracted value before checking validation rules. 158 | """ 159 | self._path = path 160 | self._default = default 161 | self._transform = transform 162 | ValidatedParameter.__init__(self, func_param_name, validators) 163 | BaseParameter.__init__(self, var_name) 164 | 165 | @property 166 | def path(self): 167 | """Getter for the path parameter.""" 168 | return self._path 169 | 170 | def extract_value(self, dict_value): 171 | """ 172 | Calculate and decode the value of the variable in the given path. 173 | 174 | Used by the extract_validated_value. 175 | 176 | Args: 177 | dict_value (dict): dictionary to be parsed. 178 | 179 | Returns: 180 | The extracted value 181 | """ 182 | for path_key in filter(lambda item: item != "", self._path.split(PATH_DIVIDER)): 183 | real_key, annotation = Parameter.get_annotations_from_key(path_key) 184 | if dict_value and real_key in dict_value: 185 | dict_value = decode(annotation, dict_value[real_key]) 186 | else: 187 | dict_value = self._default 188 | 189 | if not self._name: 190 | self._name = real_key 191 | 192 | if dict_value and self._transform: 193 | dict_value = self._transform(dict_value) 194 | 195 | return dict_value 196 | 197 | def validate_path(self, value, group_errors=False): 198 | """ 199 | Validates a value against the passed in validators 200 | 201 | Args: 202 | value: value to be validated 203 | group_errors (bool): flag that indicates if error messages are to be grouped together 204 | (if set to False, validation will end on first error) 205 | 206 | Returns: 207 | A list of validation key/pair errors 208 | """ 209 | key = self._path.split(PATH_DIVIDER)[-1] 210 | 211 | errors = self.validate(value, group_errors) 212 | 213 | return {key: errors} if errors else {} 214 | 215 | @staticmethod 216 | def get_annotations_from_key(key): 217 | """ 218 | Extract the key and the encoding type (annotation) from the string. 219 | 220 | Args: 221 | key (str): a combined string to extract key and annotation from. e.g. ("key[jwt]" -> "key", "jwt", 222 | "key" -> "key", None) 223 | """ 224 | if ANNOTATIONS_START in key and ANNOTATIONS_END in key: 225 | annotation = key[key.find(ANNOTATIONS_START) + 1:key.find(ANNOTATIONS_END)] 226 | return key.replace(ANNOTATIONS_START + annotation + ANNOTATIONS_END, ""), annotation 227 | return key, None 228 | -------------------------------------------------------------------------------- /aws_lambda_decorators/decoders.py: -------------------------------------------------------------------------------- 1 | """Decoder abstractions and functions for decoding/converting a string with a given annotation to a dictionary.""" 2 | import functools 3 | import json 4 | import logging 5 | import sys 6 | import jwt 7 | 8 | LOGGER = logging.getLogger() 9 | LOGGER.setLevel(logging.INFO) 10 | 11 | DECODE_FUNC_NAME = "decode_%s" 12 | DECODE_FUNC_MISSING_ERROR = "Missing decode function for annotation: %s" 13 | 14 | 15 | def decode(annotation, value): 16 | """ 17 | Converts an annotated string to a python dictionary. 18 | 19 | If :annotation: is not empty, use decode_:annotation:(:value:) to convert to dictionary. 20 | 21 | Existing decoders: 22 | annotation decoder 23 | [json] decode_json 24 | [jwt] decode_jwt 25 | [n] where n is a number. Decodes the value as an array, and picks item n from the array 26 | 27 | Args: 28 | annotation (str): the type of encoding of the value (e.g. 'json', 'jwt'). 29 | value (str): the value to be converted from given annotation to a dictionary. 30 | 31 | Returns: 32 | decoded dictionary. 33 | """ 34 | if annotation: 35 | if annotation.isdigit(): 36 | return value[int(annotation)] 37 | 38 | module_name = sys.modules[__name__] 39 | func_name = DECODE_FUNC_NAME % annotation 40 | if hasattr(module_name, func_name): 41 | func = getattr(module_name, func_name) 42 | return func(value) 43 | 44 | LOGGER.error(DECODE_FUNC_MISSING_ERROR, annotation) 45 | 46 | return value 47 | 48 | 49 | @functools.lru_cache() 50 | def decode_json(value): 51 | """Convert a json to a dictionary.""" 52 | return json.loads(value) 53 | 54 | 55 | @functools.lru_cache() 56 | def decode_jwt(value): 57 | """Convert a jwt to a dictionary.""" 58 | return jwt.decode(value, verify=False) 59 | -------------------------------------------------------------------------------- /aws_lambda_decorators/decorators.py: -------------------------------------------------------------------------------- 1 | """ 2 | AWS lambda decorators. 3 | 4 | A set of Python decorators to ease the development of AWS lambda functions. 5 | 6 | """ 7 | import json 8 | from http import HTTPStatus 9 | import boto3 10 | from aws_lambda_decorators.utils import (full_name, all_func_args, find_key_case_insensitive, failure, get_logger, 11 | find_websocket_connection_id, get_websocket_endpoint) 12 | 13 | 14 | LOGGER = get_logger(__name__) 15 | 16 | 17 | PARAM_LOG_MESSAGE = "Function: %s, Parameters: %s" 18 | RESPONSE_LOG_MESSAGE = "Function: %s, Response: %s" 19 | EXCEPTION_LOG_MESSAGE = "%s: %s in argument %s for path %s" 20 | EXCEPTION_LOG_MESSAGE_PATHLESS = "%s: %s in argument %s" 21 | ERROR_MESSAGE = "Error extracting parameters" 22 | VALIDATE_ERROR_MESSAGE = "Error validating parameters. Errors: %s" 23 | NON_SERIALIZABLE_ERROR_MESSAGE = "Response body is not JSON serializable" 24 | CORS_INVALID_TYPE_ERROR = "Invalid value type in CORS header" 25 | CORS_NON_DICT_ERROR = "Invalid response type for CORS headers" 26 | CORS_INVALID_TYPE_LOG_MESSAGE = "Cannot set %s header to a non %s value" 27 | NON_DICT_LOG_MESSAGE = "Cannot add headers to a non dictionary response" 28 | HSTS_NON_DICT_ERROR = "Invalid response type for HSTS header" 29 | 30 | UNKNOWN = "Unknown" 31 | 32 | 33 | def extract_from_event(parameters, group_errors=False, allow_none_defaults=False): 34 | """ 35 | Extracts a set of parameters from the event dictionary in a lambda handler. 36 | 37 | The extracted parameters are added as kwargs to the handler function. 38 | 39 | Usage: 40 | @extract_from_event([Parameter(path="/body[json]/my_param")]) 41 | def lambda_handler(event, context, my_param=None) 42 | pass 43 | 44 | Args: 45 | parameters (list): A collection of Parameter type items. 46 | group_errors (bool): flag that indicates if error messages are to be grouped together 47 | (if set to False, validation will end on first error) 48 | """ 49 | for param in parameters: 50 | param.func_param_name = "event" 51 | return extract(parameters, group_errors, allow_none_defaults) 52 | 53 | 54 | def extract_from_context(parameters, group_errors=False, allow_none_defaults=False): 55 | """ 56 | Extracts a set of parameters from the context dictionary in a lambda handler. 57 | 58 | The extracted parameters are added as kwargs to the handler function. 59 | 60 | Usage: 61 | @extract_from_context([Parameter(path="/parent/my_param")]) 62 | def lambda_handler(event, context, my_param=None) 63 | pass 64 | 65 | Args: 66 | parameters (list): A collection of Parameter type items. 67 | group_errors (bool): flag that indicates if error messages are to be grouped together 68 | (if set to False, validation will end on first error) 69 | """ 70 | for param in parameters: 71 | param.func_param_name = "context" 72 | return extract(parameters, group_errors, allow_none_defaults) 73 | 74 | 75 | def extract(parameters, group_errors=False, allow_none_defaults=False): 76 | """ 77 | Extracts a set of parameters from any function parameter passed to an AWS lambda handler. 78 | 79 | The extracted parameters are added as kwargs to the handler function. 80 | 81 | Usage: 82 | @extract([Parameter(path="headers/Authorization[jwt]/sub", var_name="user_id", func_param_name="event")]) 83 | def lambda_handler(event, context, user_id=None) 84 | pass 85 | 86 | Args: 87 | parameters (list): A collection of Parameter type items. 88 | group_errors (bool): flag that indicates if error messages are to be grouped together 89 | (if set to False, validation will end on first error) 90 | allow_none_defaults: A flag to allow None defaults. If True, None defaults will be passed into the kwargs. 91 | If the flag is set to False, the None defaults will not be added to kwargs, and the default will be 92 | picked up (if exists) from the method signature. 93 | """ 94 | def decorator(func): 95 | def wrapper(*args, **kwargs): 96 | try: 97 | param = None 98 | errors = [] 99 | arg_dictionary = all_func_args(func, args, kwargs) 100 | for param in parameters: 101 | param_val = arg_dictionary[param.func_param_name] 102 | return_val = param.extract_value(param_val) 103 | param_errors = param.validate_path(return_val, group_errors) 104 | if param_errors: 105 | errors.append(param_errors) 106 | if not group_errors: 107 | LOGGER.error(VALIDATE_ERROR_MESSAGE, errors) 108 | return failure(errors) 109 | elif allow_none_defaults or return_val is not None: 110 | kwargs[param.get_var_name()] = return_val 111 | except Exception as ex: # noqa: pylint - broad-except 112 | LOGGER.error(EXCEPTION_LOG_MESSAGE, full_name(ex), str(ex), 113 | param.func_param_name if param else UNKNOWN, 114 | param.path if param else UNKNOWN) 115 | return failure(ERROR_MESSAGE) 116 | else: 117 | if group_errors and errors: 118 | LOGGER.error(VALIDATE_ERROR_MESSAGE, errors) 119 | return failure(errors) 120 | 121 | return func(*args, **kwargs) 122 | return wrapper 123 | return decorator 124 | 125 | 126 | def handle_exceptions(handlers): 127 | """ 128 | Handles exceptions thrown by the wrapped/decorated function. 129 | 130 | Usage: 131 | @handle_exceptions([ExceptionHandler(exception=KeyError, friendly_message="Your message on KeyError except")]). 132 | def lambda_handler(params) 133 | pass 134 | 135 | Args: 136 | handlers (list): A collection of ExceptionHandler type items. 137 | """ 138 | def decorator(func): 139 | def wrapper(*args, **kwargs): 140 | try: 141 | return func(*args, **kwargs) 142 | except tuple(handler.exception for handler in handlers) as ex: # noqa: pylint - catching-non-exception 143 | failed_handler = [handler for handler in handlers if isinstance(ex, handler.exception)][0] 144 | message = failed_handler.friendly_message 145 | 146 | if message and str(ex): 147 | LOGGER.error("%s: %s", message, str(ex)) 148 | else: 149 | LOGGER.error(message if message else str(ex)) 150 | 151 | return failure(message if message else str(ex), failed_handler.status_code) 152 | return wrapper 153 | return decorator 154 | 155 | 156 | def log(parameters=False, response=False): 157 | """ 158 | Log parameters and/or response of the wrapped/decorated function using logging package 159 | 160 | Args: 161 | parameters: a flag indicating if the input parameters are to be logged 162 | response: a flag indicating if the returned response is to be logged 163 | """ 164 | def decorator(func): 165 | def wrapper(*args, **kwargs): 166 | if parameters: 167 | LOGGER.info(PARAM_LOG_MESSAGE, func.__name__, args) 168 | func_response = func(*args, **kwargs) 169 | if response: 170 | LOGGER.info(RESPONSE_LOG_MESSAGE, func.__name__, func_response) 171 | return func_response 172 | return wrapper 173 | return decorator 174 | 175 | 176 | def extract_from_ssm(ssm_parameters): 177 | """ 178 | Load given ssm parameters from AWS parameter store to the handler variables. 179 | 180 | Usage: 181 | @extract_from_ssm([SSMParameter(ssm_name="key", var_name="var")]) 182 | def lambda_handler(var=None) 183 | pass 184 | 185 | Args: 186 | ssm_parameters (list): A collection of SSMParameter type items. 187 | """ 188 | def decorator(func): 189 | def wrapper(*args, **kwargs): 190 | ssm = boto3.client("ssm") 191 | server_key_containers = ssm.get_parameters( 192 | Names=[ssm_parameter.get_ssm_name() for ssm_parameter in ssm_parameters], 193 | WithDecryption=True) 194 | for key_container in server_key_containers["Parameters"]: 195 | for ssm_parameter in ssm_parameters: # pragma: no cover 196 | if ssm_parameter.get_ssm_name() == key_container["Name"]: 197 | kwargs[ssm_parameter.get_var_name()] = key_container["Value"] 198 | break 199 | return func(*args, **kwargs) 200 | return wrapper 201 | return decorator 202 | 203 | 204 | def response_body_as_json(func): 205 | """ 206 | Convert the dictionary response of the wrapped/decorated function to a json string literal. 207 | 208 | Usage: 209 | @response_body_as_json 210 | def lambda_handler(): 211 | return {"statusCode": 200, "body": {"key": "value"}} 212 | 213 | will return {"statusCode": 200, "body": "{"key":"value"}"} 214 | """ 215 | def wrapper(*args, **kwargs): 216 | response = func(*args, **kwargs) 217 | if "body" in response: 218 | try: 219 | response["body"] = json.dumps(response["body"]) 220 | except TypeError: 221 | return failure(NON_SERIALIZABLE_ERROR_MESSAGE, 500) 222 | return response 223 | return wrapper 224 | 225 | 226 | def validate(parameters, group_errors=False): 227 | """ 228 | Validates a set of function parameters. 229 | 230 | Usage: 231 | @validate([ValidatedParameter(func_param_name="my_param", validators=[...])]) 232 | def func(my_param) 233 | pass 234 | 235 | Args: 236 | parameters (list): A collection of ValidatedParameter type items. 237 | group_errors (bool): flag that indicates if error messages are to be grouped together 238 | (if set to False, validation will end on first error) 239 | """ 240 | def decorator(func): 241 | def wrapper(*args, **kwargs): 242 | try: 243 | errors = [] 244 | arg_dictionary = all_func_args(func, args, kwargs) 245 | for param in parameters: 246 | param_val = arg_dictionary[param.func_param_name] 247 | param_errors = param.validate(param_val, group_errors) 248 | if param_errors: 249 | errors.append({param.func_param_name: param_errors}) 250 | if not group_errors: 251 | LOGGER.error(VALIDATE_ERROR_MESSAGE, errors) 252 | return failure(errors) 253 | except Exception as ex: # noqa: pylint - broad-except 254 | LOGGER.error(EXCEPTION_LOG_MESSAGE_PATHLESS, full_name(ex), str(ex), param.func_param_name) 255 | return failure(ERROR_MESSAGE) 256 | 257 | if group_errors and errors: 258 | LOGGER.error(VALIDATE_ERROR_MESSAGE, errors) 259 | return failure(errors) 260 | 261 | return func(*args, **kwargs) 262 | return wrapper 263 | return decorator 264 | 265 | 266 | def handle_all_exceptions(): 267 | """ 268 | Handles all exceptions thrown by the wrapped/decorated function. 269 | 270 | Usage: 271 | @handle_all_exceptions() 272 | def lambda_handler(params) 273 | pass 274 | """ 275 | def decorator(func): 276 | def wrapper(*args, **kwargs): 277 | try: 278 | return func(*args, **kwargs) 279 | except Exception as ex: # noqa: pylint - catching-non-exception 280 | LOGGER.error(str(ex)) 281 | return failure(str(ex)) 282 | return wrapper 283 | return decorator 284 | 285 | 286 | def cors(allow_origin=None, allow_methods=None, allow_headers=None, max_age=None): 287 | """ 288 | Adds CORS headers to the response of the decorated function 289 | 290 | Usage: 291 | @cors(allow_origin="http://example.com", allow_methods="POST,GET", allow_headers="Content-Type", max_age=86400) 292 | def func(my_param) 293 | pass 294 | 295 | Args: 296 | allow_origin: A string containing the comma-separated list of allowed origins 297 | allow_methods: A string containing the comma-separated list of allowed methods 298 | allow_headers: A string containing the comma-separated list of allowed headers 299 | max_age: An integer to indicate the caching time (in seconds) for the CORS pre-flight request 300 | 301 | Returns: 302 | The original decorated function response with the additional cors headers 303 | """ 304 | def decorator(func): 305 | def wrapper(*args, **kwargs): 306 | 307 | def update_header(headers, header_name, value, value_type): 308 | if value: 309 | if isinstance(value, value_type): 310 | header_key = find_key_case_insensitive(header_name, headers) 311 | headers[header_key] = f"{headers[header_key]},{value}" if header_key in headers else value 312 | else: 313 | LOGGER.error(CORS_INVALID_TYPE_LOG_MESSAGE, header_name, value_type) 314 | raise TypeError 315 | 316 | return headers 317 | 318 | response = func(*args, **kwargs) 319 | 320 | if isinstance(response, dict): 321 | headers_key = find_key_case_insensitive("headers", response) 322 | 323 | resp_headers = response[headers_key] if headers_key in response else {} 324 | 325 | try: 326 | resp_headers = update_header(resp_headers, "access-control-allow-origin", allow_origin, str) 327 | resp_headers = update_header(resp_headers, "access-control-allow-methods", allow_methods, str) 328 | resp_headers = update_header(resp_headers, "access-control-allow-headers", allow_headers, str) 329 | resp_headers = update_header(resp_headers, "access-control-max-age", max_age, int) 330 | 331 | response[headers_key] = resp_headers 332 | return response 333 | except TypeError: 334 | return failure(CORS_INVALID_TYPE_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) 335 | else: 336 | LOGGER.error(NON_DICT_LOG_MESSAGE) 337 | return failure(CORS_NON_DICT_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) 338 | return wrapper 339 | return decorator 340 | 341 | 342 | def push_ws_errors(websocket_endpoint_url: str): 343 | """ 344 | Handles and pushes any unsuccessful responses as errors to the calling client via websockets 345 | 346 | Usage: 347 | @push_ws_errors('https://api_id.execute_id.region.amazonaws.com/Prod') 348 | @handle_all_exceptions() 349 | def handler(event, context): 350 | return { 351 | 'statusCode': 400, 352 | 'body': { 353 | 'message': 'Bad request' 354 | } 355 | } 356 | 357 | Args: 358 | websocket_endpoint_url (str): The api gateway connection URL 359 | 360 | Returns: 361 | the original response from the lambda handler 362 | """ 363 | def decorator(func): 364 | def wrapper(*args, **kwargs): 365 | connection_id = find_websocket_connection_id(args) 366 | 367 | response = func(*args, **kwargs) 368 | success = response.get("statusCode", HTTPStatus.INTERNAL_SERVER_ERROR).value < 300 369 | 370 | if connection_id and not success: 371 | websocket_endpoint = get_websocket_endpoint(websocket_endpoint_url) 372 | 373 | ws_response = { 374 | "type": "error", 375 | "statusCode": response.get("statusCode", HTTPStatus.INTERNAL_SERVER_ERROR), 376 | "message": json.loads(response.get("body", "{}")).get("message") 377 | } 378 | 379 | websocket_endpoint.post_to_connection( 380 | ConnectionId=connection_id, 381 | Data=json.dumps(ws_response) 382 | ) 383 | 384 | return response 385 | return wrapper 386 | return decorator 387 | 388 | 389 | def push_ws_response(websocket_endpoint_url: str): 390 | """ 391 | Handles and pushes all responses to the calling client via websockets 392 | 393 | Usage: 394 | @push_ws_response('https://api_id.execute_id.region.amazonaws.com/Prod') 395 | def handler(event, context): 396 | return { 397 | 'statusCode': 200, 398 | 'body': 'Hello, world!' 399 | } 400 | 401 | Args: 402 | websocket_endpoint_url (str): The api gateway connection URL 403 | 404 | Returns: 405 | the original response from the lambda handler 406 | """ 407 | def decorator(func): 408 | def wrapper(*args, **kwargs): 409 | connection_id = find_websocket_connection_id(args) 410 | 411 | response = func(*args, **kwargs) 412 | 413 | if connection_id: 414 | websocket_endpoint = get_websocket_endpoint(websocket_endpoint_url) 415 | 416 | websocket_endpoint.post_to_connection( 417 | ConnectionId=connection_id, 418 | Data=json.dumps(response) 419 | ) 420 | 421 | return response 422 | return wrapper 423 | return decorator 424 | 425 | 426 | # pylint:disable=no-else-return 427 | def hsts(max_age: int = None): 428 | """ 429 | Adds HSTS header to the response of the decorated function 430 | 431 | Usage: 432 | @hsts(max_age=86400) 433 | def func(my_param) 434 | pass 435 | 436 | Args: 437 | max_age: An integer to indicate the time browser should remember your domain as HTTPS only communication. 438 | If not specified default value of 2 years is used. 439 | 440 | Returns: 441 | The original decorated function response with the additional hsts header 442 | """ 443 | def decorator(func): 444 | def wrapper(*args, **kwargs): 445 | response = func(*args, **kwargs) 446 | 447 | if isinstance(response, dict): 448 | headers_key = find_key_case_insensitive("headers", response) 449 | 450 | resp_headers = response[headers_key] if headers_key in response else {} 451 | 452 | header_key = find_key_case_insensitive("Strict-Transport-Security", resp_headers) 453 | header_value = f"max-age={max_age}" if max_age else "max-age=63072000" 454 | resp_headers[header_key] = header_value 455 | response[headers_key] = resp_headers 456 | return response 457 | else: 458 | LOGGER.error(NON_DICT_LOG_MESSAGE) 459 | return failure(HSTS_NON_DICT_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) 460 | return wrapper 461 | return decorator 462 | -------------------------------------------------------------------------------- /aws_lambda_decorators/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from functools import lru_cache 3 | from http import HTTPStatus 4 | import inspect 5 | import json 6 | import keyword 7 | import logging 8 | import os 9 | 10 | import boto3 11 | 12 | 13 | LOG_LEVEL = getattr(logging, os.getenv("LOG_LEVEL", "INFO")) 14 | 15 | 16 | def get_logger(name): 17 | logger = logging.getLogger(name) 18 | logger.setLevel(LOG_LEVEL) 19 | return logger 20 | 21 | 22 | def full_name(class_type): 23 | """ 24 | Gets the fully qualified name of a class type. 25 | 26 | From https://stackoverflow.com/questions/2020014/get-fully-qualified-class-name-of-an-object-in-python 27 | 28 | Args: 29 | class_type (type): the type of the class. 30 | 31 | Returns: 32 | the fully qualified name of the class type. 33 | """ 34 | module = class_type.__class__.__module__ 35 | if module is None or module == str.__class__.__module__: 36 | return class_type.__class__.__name__ # Avoid reporting __builtin__ 37 | return f"{module}.{class_type.__class__.__name__}" 38 | 39 | 40 | def is_type_in_list(item_type, items): 41 | """ 42 | Checks if there is an item of a given type in the list of items. 43 | 44 | Args: 45 | item_type (type): the type of the item. 46 | items (list): a list of items. 47 | 48 | Returns: 49 | true if an item of the given type exists in the list, otherwise false. 50 | """ 51 | return any(isinstance(item, item_type) for item in items) 52 | 53 | 54 | def is_valid_variable_name(name): 55 | """ 56 | Check if the given name is python allowed variable name. 57 | 58 | Args: 59 | name (str): the name to check. 60 | 61 | Returns: 62 | true if the name can be used as a python variable name. 63 | """ 64 | return name.isidentifier() and not keyword.iskeyword(name) 65 | 66 | 67 | def all_func_args(func, args, kwargs): 68 | """ 69 | Combine arguments and key word arguments to a dictionary. 70 | 71 | Args: 72 | func (function): function whose arguments should be extracted. 73 | args (list): list of function args (*args). 74 | kwargs (dict): dictionary of function kwargs (**kwargs). 75 | 76 | Returns: 77 | dictionary argument name -> argument value. 78 | """ 79 | arg_spec = inspect.getfullargspec(func)[0] 80 | arg_dictionary = {arg_spec[idx]: value for idx, value in enumerate(args)} 81 | arg_dictionary.update(kwargs) 82 | return arg_dictionary 83 | 84 | 85 | def find_key_case_insensitive(key_name, the_dict): 86 | """ 87 | Finds if a dictionary (the_dict) has a string key (key_name) in any string case 88 | 89 | Args: 90 | key_name: the key to search in the dictionary 91 | the_dict: the dictionary to search 92 | 93 | Returns: 94 | The found key name in its original case, if found. Otherwise, returns the searching key name 95 | 96 | """ 97 | for key in the_dict: 98 | if key.lower() == key_name: 99 | return key 100 | return key_name 101 | 102 | 103 | def failure(errors, status_code=HTTPStatus.BAD_REQUEST): 104 | """ 105 | Returns an error to the caller 106 | 107 | Args: 108 | errors (list): a list of errors to be returned 109 | status_code (int): the status code of the error 110 | 111 | Returns: 112 | An object that contains the status code and the list of errors 113 | """ 114 | return { 115 | "statusCode": status_code, 116 | "body": json.dumps({"message": errors}) 117 | } 118 | 119 | 120 | def find_websocket_connection_id(args: list) -> str: 121 | """ 122 | Finds an API Gateway connection id from the event dictionary in the 123 | arguments of a lambda 124 | 125 | Args: 126 | args (list): a list of arguments from a lambda (*args) 127 | 128 | Returns: 129 | The connection id of a user as a string if found 130 | None if not 131 | """ 132 | for arg in args: 133 | if isinstance(arg, dict) and "requestContext" in arg: 134 | return arg["requestContext"].get("connectionId") 135 | return None 136 | 137 | 138 | @lru_cache() 139 | def get_websocket_endpoint(endpoint_url: str) -> "botocore.client.ApiGatewayManagementApi": # noqa: pyflakes - F821 140 | """ 141 | Gets an instance of ApiGatewayManagementApi for sending messages 142 | through websockets 143 | 144 | Args: 145 | endpoint_url (str): an api gateway connection url (ish) 146 | 147 | Returns: 148 | The api gateway management client (cached) 149 | """ 150 | return boto3.client( 151 | "apigatewaymanagementapi", 152 | endpoint_url=endpoint_url 153 | ) 154 | -------------------------------------------------------------------------------- /aws_lambda_decorators/validators.py: -------------------------------------------------------------------------------- 1 | """Validation rules.""" 2 | import datetime 3 | import re 4 | from schema import SchemaError 5 | 6 | CURRENCIES = {"LKR", "ETB", "RWF", "NZD", "SBD", "MKD", "NPR", "LAK", "KWD", "INR", "HUF", "AFN", "BTN", "ISK", "MVR", 7 | "WST", "MNT", "AZN", "SAR", "JMD", "BIF", "BMD", "CAD", "GEL", "MXN", "BHD", "HKD", "RSD", "PKR", "SLL", 8 | "NGN", "TOP", "SCR", "SVC", "CHW", "UYW", "IDR", "IQD", "THB", "GBP", "MYR", "SDG", "CNY", "GNF", "LRD", 9 | "KHR", "TJS", "BYN", "SHP", "AED", "BOB", "CUC", "PHP", "SSP", "USN", "MZN", "COP", "SEK", "EUR", "CDF", 10 | "CRC", "KMF", "JPY", "ZWL", "ALL", "GHS", "GIP", "QAR", "GYD", "HTG", "VUV", "CZK", "ANG", "AWG", "AMD", 11 | "DOP", "TRY", "ZMW", "MGA", "KZT", "XUA", "ARS", "XPF", "BRL", "MXV", "LSL", "CLP", "KES", "PYG", "TND", 12 | "MAD", "DZD", "MWK", "BSD", "BBD", "FKP", "KGS", "BWP", "CVE", "HRK", "DKK", "COU", "SYP", "LYD", "PLN", 13 | "TZS", "KPW", "UGX", "BOV", "UAH", "NAD", "AOA", "VES", "SOS", "CUP", "SGD", "PAB", "UZS", "STN", "SRD", 14 | "CHE", "XOF", "DJF", "PGK", "UYI", "XCD", "BZD", "EGP", "ERN", "RON", "TWD", "USD", "FJD", "VND", "SZL", 15 | "BND", "HNL", "KRW", "XAF", "MDL", "BDT", "MUR", "PEN", "OMR", "NIO", "TMT", "YER", "TTD", "GMD", "XDR", 16 | "CHF", "NOK", "GTQ", "JOD", "KYD", "UYU", "RUB", "ZAR", "AUD", "BGN", "MOP", "LBP", "MRU", "CLF", "XSU", 17 | "BAM", "MMK", "IRR", "ILS"} 18 | 19 | 20 | class Validator: # noqa: pylint - too-few-public-methods 21 | """Validation rule to check if the given mandatory value exists.""" 22 | ERROR_MESSAGE = "Unknown error" 23 | 24 | def __init__(self, error_message, condition=None): 25 | """ 26 | Validates a parameter 27 | 28 | Args: 29 | error_message (str): A custom error message to output if validation fails 30 | condition (any): A condition to validate 31 | """ 32 | self._error_message = error_message or self.ERROR_MESSAGE 33 | self._condition = condition 34 | 35 | def message(self, value=None): # noqa: pylint - unused-argument 36 | """ 37 | Gets the formatted error message for a failed mandatory check 38 | 39 | Args: 40 | value (any): The validated value 41 | 42 | Returns: 43 | The error message 44 | """ 45 | return self._error_message.format(value=value, condition=self._condition) 46 | 47 | 48 | class Mandatory(Validator): # noqa: pylint - too-few-public-methods 49 | """Validation rule to check if the given mandatory value exists.""" 50 | ERROR_MESSAGE = "Missing mandatory value" 51 | 52 | def __init__(self, error_message=None): 53 | """ 54 | Checks if a parameter has a value 55 | 56 | Args: 57 | error_message (str): A custom error message to output if validation fails 58 | """ 59 | super().__init__(error_message) 60 | 61 | @staticmethod 62 | def validate(value=None): 63 | """ 64 | Check if the given mandatory value exists. 65 | 66 | Args: 67 | value (any): Value to be validated. 68 | """ 69 | return value is not None and str(value) 70 | 71 | 72 | class RegexValidator(Validator): # noqa: pylint - too-few-public-methods 73 | """Validation rule to check if a value matches a regular expression.""" 74 | ERROR_MESSAGE = "'{value}' does not conform to regular expression '{condition}'" 75 | 76 | def __init__(self, regex="", error_message=None): 77 | """ 78 | Compile a regular expression to a regular expression pattern. 79 | 80 | Args: 81 | regex (str): Regular expression for parameter validation. 82 | error_message (str): A custom error message to output if validation fails 83 | """ 84 | super().__init__(error_message, regex) 85 | self._regexp = re.compile(regex) 86 | 87 | def validate(self, value=None): 88 | """ 89 | Check if a value adheres to the defined regular expression. 90 | 91 | Args: 92 | value (str): Value to be validated. 93 | """ 94 | if value is None: 95 | return True 96 | 97 | return self._regexp.fullmatch(value) is not None 98 | 99 | 100 | class SchemaValidator(Validator): # noqa: pylint - too-few-public-methods 101 | """Validation rule to check if a value matches a regular expression.""" 102 | ERROR_MESSAGE = "'{value}' does not validate against schema '{condition}'" 103 | 104 | def __init__(self, schema, error_message=None): 105 | """ 106 | Set the schema field. 107 | 108 | Args: 109 | schema (Schema): The expected schema. 110 | error_message (str): A custom error message to output if validation fails 111 | """ 112 | super().__init__(error_message, schema) 113 | 114 | def validate(self, value=None): 115 | """ 116 | Check if the object adheres to the defined schema. 117 | 118 | Args: 119 | value (object): Value to be validated. 120 | """ 121 | try: 122 | if value is None: 123 | return True 124 | 125 | return self._condition.validate(value) == value 126 | except SchemaError: 127 | return False 128 | 129 | 130 | class Minimum(Validator): # noqa: pylint - too-few-public-methods 131 | """Validation rule to check if a value is greater than a minimum value.""" 132 | ERROR_MESSAGE = "'{value}' is less than minimum value '{condition}'" 133 | 134 | def __init__(self, minimum: (float, int), error_message=None): 135 | """ 136 | Set the minimum value. 137 | 138 | Args: 139 | minimum (float, int): The minimum value. 140 | error_message (str): A custom error message to output if validation fails 141 | """ 142 | super().__init__(error_message, minimum) 143 | 144 | def validate(self, value=None): 145 | """ 146 | Check if the value is greater than the minimum. 147 | 148 | Args: 149 | value (float, int): Value to be validated. 150 | """ 151 | if value is None: 152 | return True 153 | 154 | if isinstance(value, (float, int)): 155 | return self._condition <= value 156 | 157 | return False 158 | 159 | 160 | class Maximum(Validator): # noqa: pylint - too-few-public-methods 161 | """Validation rule to check if a value is less than a maximum value.""" 162 | ERROR_MESSAGE = "'{value}' is greater than maximum value '{condition}'" 163 | 164 | def __init__(self, maximum: (float, int), error_message=None): 165 | """ 166 | Set the maximum value. 167 | 168 | Args: 169 | maximum (float, int): The maximum value. 170 | error_message (str): A custom error message to output if validation fails 171 | """ 172 | super().__init__(error_message, maximum) 173 | 174 | def validate(self, value=None): 175 | """ 176 | Check if the value is less than the maximum. 177 | 178 | Args: 179 | value (float, int): Value to be validated. 180 | """ 181 | if value is None: 182 | return True 183 | 184 | if isinstance(value, (float, int)): 185 | return self._condition >= value 186 | 187 | return False 188 | 189 | 190 | class MinLength(Validator): # noqa: pylint - too-few-public-methods 191 | """Validation rule to check if a string is shorter than a minimum length.""" 192 | ERROR_MESSAGE = "'{value}' is shorter than minimum length '{condition}'" 193 | 194 | def __init__(self, min_length: int, error_message=None): 195 | """ 196 | Set the minimum length. 197 | 198 | Args: 199 | min_length (int): The minimum length. 200 | error_message (str): A custom error message to output if validation fails 201 | """ 202 | super().__init__(error_message, min_length) 203 | 204 | def validate(self, value=None): 205 | """ 206 | Check if a string is shorter than the minimum length. 207 | 208 | Args: 209 | value (str): String to be validated. 210 | """ 211 | if value is None: 212 | return True 213 | 214 | return len(str(value)) >= self._condition 215 | 216 | 217 | class MaxLength(Validator): # noqa: pylint - too-few-public-methods 218 | """Validation rule to check if a string is longer than a maximum length.""" 219 | ERROR_MESSAGE = "'{value}' is longer than maximum length '{condition}'" 220 | 221 | def __init__(self, max_length: int, error_message=None): 222 | """ 223 | Set the maximum length. 224 | 225 | Args: 226 | max_length (int): The maximum length. 227 | error_message (str): A custom error message to output if validation fails 228 | """ 229 | super().__init__(error_message, max_length) 230 | 231 | def validate(self, value=None): 232 | """ 233 | Check if a string is longer than the maximum length. 234 | 235 | Args: 236 | value (str): String to be validated. 237 | """ 238 | if value is None: 239 | return True 240 | 241 | return len(str(value)) <= self._condition 242 | 243 | 244 | class Type(Validator): 245 | ERROR_MESSAGE = "'{value}' is not of type '{condition.__name__}'" 246 | 247 | def __init__(self, valid_type: type, error_message=None): 248 | """ 249 | Set the valid type. 250 | 251 | Args: 252 | valid_type (type): The value type to check. 253 | error_message (str): A custom error message to output if validation fails 254 | """ 255 | super().__init__(error_message, valid_type) 256 | 257 | def validate(self, value=None): 258 | """ 259 | Check if a value is of the right type 260 | 261 | Args: 262 | value (object): object to be validated. 263 | """ 264 | if value is None: 265 | return True 266 | 267 | return isinstance(value, self._condition) 268 | 269 | 270 | class EnumValidator(Validator): 271 | ERROR_MESSAGE = "'{value}' is not in list '{condition}'" 272 | 273 | def __init__(self, *args: list, error_message=None): 274 | """ 275 | Set the list of valid values. 276 | 277 | Args: 278 | error_message (str): A custom error message to output if validation fails 279 | args (list): The list of valid values 280 | """ 281 | super().__init__(error_message, args) 282 | 283 | def validate(self, value=None): 284 | """ 285 | Check if a value is in a list of valid values 286 | 287 | Args: 288 | value (object): object to be validated. 289 | """ 290 | if value is None: 291 | return True 292 | 293 | return value in self._condition 294 | 295 | 296 | class NonEmpty(Validator): # noqa: pylint - too-few-public-methods 297 | """Validation rule to check if the given value is empty.""" 298 | ERROR_MESSAGE = "Value is empty" 299 | 300 | def __init__(self, error_message=None): 301 | """ 302 | Checks if a parameter has a non empty value 303 | 304 | Args: 305 | error_message (str): A custom error message to output if validation fails 306 | """ 307 | super().__init__(error_message) 308 | 309 | @staticmethod 310 | def validate(value=None): 311 | """ 312 | Check if the given value is non empty. 313 | 314 | Args: 315 | value (any): Value to be validated. 316 | """ 317 | if value is None or value in (0, 0.0, 0j): 318 | return True 319 | 320 | return bool(value) 321 | 322 | 323 | class DateValidator(Validator): 324 | """Validation rule to check if a string is a valid date according to some format.""" 325 | ERROR_MESSAGE = "'{value}' is not a '{condition}' date" 326 | 327 | def __init__(self, date_format: str, error_message=None): 328 | """ 329 | Checks if a string is a date with a given format 330 | 331 | Args: 332 | date_format (str): The date format to check against 333 | error_message (str): A custom error message to output if validation fails 334 | """ 335 | super().__init__(error_message, date_format) 336 | 337 | def validate(self, value=None): 338 | """ 339 | Check if a string is a date with a given format 340 | 341 | Args: 342 | value (str): string date to validate against a format 343 | """ 344 | if value is None: 345 | return True 346 | 347 | try: 348 | datetime.datetime.strptime(value, self._condition) 349 | except ValueError: 350 | return False 351 | else: 352 | return True 353 | 354 | 355 | class CurrencyValidator(Validator): 356 | """Validation rule to check if a string is a valid currency according to ISO 4217 Currency Code.""" 357 | ERROR_MESSAGE = "'{value}' is not a valid currency code." 358 | 359 | def __init__(self, error_message=None): 360 | """ 361 | Checks if a string is a valid currency based on ISO 4217 362 | 363 | Args: 364 | error_message (str): A custom error message to output if validation fails 365 | """ 366 | super().__init__(error_message) 367 | 368 | @staticmethod 369 | def validate(value=None): 370 | """ 371 | Check if a string is a valid currency based on ISO 4217 372 | 373 | Args: 374 | value (str): value to validate against a ISO 4217 375 | """ 376 | 377 | if value is None: 378 | return True 379 | 380 | return value.upper() in CURRENCIES 381 | -------------------------------------------------------------------------------- /buildspec.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | 3 | phases: 4 | install: 5 | commands: 6 | - pip install --upgrade pip 7 | - pip install -q boto3 bandit coverage==4.5.4 schema pylint_quotes prospector==1.3.1 PyJWT==1.7.1 8 | pre_build: 9 | commands: 10 | - export LOG_LEVEL=CRITICAL 11 | - export OUR_COMMIT_SHA=`git rev-parse HEAD` 12 | - bandit -r -q . 13 | - prospector 14 | - coverage run --source='.' -m unittest 15 | - coverage report -m --fail-under=100 --omit=*/__init__.py,tests/*,setup.py,examples/* 16 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gridsmartercities/aws-lambda-decorators/16dbe6ae1b9982f312d593336682c4ebbcd4f52d/examples/__init__.py -------------------------------------------------------------------------------- /examples/examples.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import boto3 3 | from boto3.dynamodb.conditions import Key 4 | from botocore.exceptions import ClientError 5 | from schema import Schema, Or 6 | from aws_lambda_decorators import (extract, extract_from_event, extract_from_context, extract_from_ssm, validate, log, 7 | handle_exceptions, response_body_as_json, Parameter, SSMParameter, 8 | ValidatedParameter, ExceptionHandler, Mandatory, RegexValidator, 9 | handle_all_exceptions, cors, SchemaValidator, Maximum, Minimum, Type, EnumValidator, 10 | NonEmpty, DateValidator, CurrencyValidator, hsts) 11 | 12 | 13 | @extract(parameters=[ 14 | # extracts a non mandatory my_param from a_dictionary 15 | Parameter(path="/parent/my_param", func_param_name="a_dictionary"), 16 | # extracts a non mandatory missing_non_mandatory from a_dictionary 17 | Parameter(path="/parent/missing_non_mandatory", func_param_name="a_dictionary"), 18 | # does not fail as the parameter is not validated as mandatory 19 | Parameter(path="/parent/missing_mandatory", func_param_name="a_dictionary"), 20 | # extracts a mandatory id as "user_id" from another_dictionary 21 | Parameter(path="/parent/child/id", validators=[Mandatory], var_name="user_id", 22 | func_param_name="another_dictionary") 23 | ]) 24 | def extract_example(a_dictionary, another_dictionary, my_param="aDefaultValue", 25 | missing_non_mandatory="I am missing", missing_mandatory=None, user_id=None): 26 | # you can now access the extracted parameters directly: 27 | return my_param, missing_non_mandatory, missing_mandatory, user_id 28 | 29 | 30 | @extract(parameters=[ 31 | # extracts a non mandatory my_param from a_dictionary 32 | Parameter(path="/parent/my_param", func_param_name="a_dictionary") 33 | ]) 34 | def extract_to_kwargs_example(a_dictionary, **kwargs): 35 | return kwargs["my_param"] 36 | 37 | 38 | @extract(parameters=[ 39 | # extracts a mandatory my_param from a_dictionary 40 | Parameter(path="/parent/mandatory_param", func_param_name="a_dictionary", validators=[Mandatory]) 41 | ]) 42 | def extract_mandatory_param_example(a_dictionary, mandatory_param=None): 43 | return "Here!" # this part will never be reached, if the mandatory_param is missing 44 | 45 | 46 | @extract(parameters=[ 47 | # extracts two mandatory parameters from a_dictionary 48 | Parameter(path="/parent/mandatory_param", func_param_name="a_dictionary", validators=[Mandatory]), 49 | Parameter(path="/parent/another_mandatory_param", func_param_name="a_dictionary", validators=[Mandatory]), 50 | Parameter(path="/parent/an_int", func_param_name="a_dictionary", validators=[Maximum(10), Minimum(5)]) 51 | ], group_errors=True) # groups both errors together 52 | def extract_multiple_param_example(a_dictionary, mandatory_param=None, another_mandatory_param=None, an_int=0): 53 | return "Here!" # this part will never be reached, if the mandatory_param is missing 54 | 55 | 56 | @extract(parameters=[ 57 | # extracts a non mandatory my_param from a_dictionary 58 | Parameter(path="/parent[json]/my_param", func_param_name="a_dictionary") 59 | ]) 60 | def extract_from_json_example(a_dictionary, my_param=None): 61 | return my_param 62 | 63 | 64 | @extract_from_event(parameters=[ 65 | # extracts a mandatory my_param from the json body of the event 66 | Parameter(path="/body[json]/my_param", validators=[Mandatory]), 67 | # extract the mandatory sub value as user_id from the authorization JWT 68 | Parameter(path="/headers/Authorization[jwt]/sub", validators=[Mandatory], var_name="user_id") 69 | ]) 70 | def extract_from_event_example(event, context, my_param=None, user_id=None): 71 | return my_param, user_id # returns ("Hello!", "1234567890") 72 | 73 | 74 | @extract_from_context(parameters=[ 75 | # extracts a mandatory my_param from the parent element in context 76 | Parameter(path="/parent/my_param", validators=[Mandatory]) 77 | ]) 78 | def extract_from_context_example(event, context, my_param=None): 79 | return my_param # returns "Hello!" 80 | 81 | 82 | @extract_from_ssm(ssm_parameters=[ 83 | # extracts the value of one_key from SSM as a kwarg named "one_key" 84 | SSMParameter(ssm_name="one_key"), 85 | # extracts another_key as a kwarg named "another" 86 | SSMParameter(ssm_name="another_key", var_name="another") 87 | ]) 88 | def extract_from_ssm_example(your_func_params, one_key=None, another=None): 89 | return your_func_params, one_key, another 90 | 91 | 92 | @validate(parameters=[ 93 | # validates a_param as mandatory 94 | ValidatedParameter(func_param_name="a_param", validators=[Mandatory]), 95 | # validates another_param as mandatory and containing only digits 96 | ValidatedParameter(func_param_name="another_param", validators=[Mandatory, RegexValidator(r"\d+")]), 97 | # validates param_with_schema as an object with specified schema 98 | ValidatedParameter(func_param_name="param_with_schema", validators=[SchemaValidator(Schema({"a": Or(str, dict)}))]) 99 | ]) 100 | def validate_example(a_param, another_param, param_with_schema): 101 | return a_param, another_param, param_with_schema # returns a_param, another_param, param_with_schema 102 | 103 | 104 | @log(parameters=True, response=True) 105 | def log_example(parameters): 106 | return "Done!" 107 | 108 | 109 | @handle_exceptions(handlers=[ 110 | ExceptionHandler(ClientError, "Your message when a client error happens.") 111 | ]) 112 | def handle_exceptions_example(): 113 | dynamodb = boto3.resource("dynamodb") 114 | table = dynamodb.Table("your_table_name") 115 | table.query(KeyConditionExpression=Key("user_id").eq("1234")) 116 | # ... 117 | 118 | 119 | @response_body_as_json 120 | def response_body_as_json_example(): 121 | return {"statusCode": 400, "body": {"param": "hello!"}} 122 | 123 | 124 | @extract(parameters=[ 125 | # extracts a non mandatory my_param from a_dictionary 126 | Parameter(path="/parent[1]/my_param", func_param_name="a_dictionary") 127 | ]) 128 | def extract_from_list_example(a_dictionary, my_param=None): 129 | return my_param 130 | 131 | 132 | @handle_all_exceptions() 133 | def handle_all_exceptions_example(): 134 | test_list = [1, 2, 3] 135 | test_list[5] 136 | # ... 137 | 138 | 139 | @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=86400) 140 | def cors_example(): 141 | return {"statusCode": 200} 142 | 143 | 144 | @hsts() 145 | def hsts_example(): 146 | return {'statusCode': 200} 147 | 148 | 149 | @extract(parameters=[ 150 | Parameter(path="/parent/an_int", func_param_name="a_dictionary", 151 | validators=[Minimum(100, "Bad value {value}: should be at least {condition}")]) 152 | ]) 153 | def extract_minimum_param_with_custom_error_example(a_dictionary, an_int=None): 154 | return {} 155 | 156 | 157 | @extract(parameters=[ 158 | Parameter(path="/params/my_param_1", func_param_name="a_dictionary"), 159 | Parameter(path="/params/my_param_2", func_param_name="a_dictionary") 160 | ]) 161 | def extract_dictionary_example(a_dictionary, **kwargs): 162 | return kwargs 163 | 164 | 165 | @extract(parameters=[ 166 | Parameter(path="/params/a_bool", func_param_name="a_dictionary", validators=[Type(bool)]) 167 | ]) 168 | def extract_type_param(a_dictionary, a_bool=False): 169 | return a_bool 170 | 171 | 172 | @extract(parameters=[ 173 | Parameter(path="/params/an_enum", func_param_name="a_dictionary", validators=[EnumValidator("Hello", "Bye")]) 174 | ]) 175 | def extract_enum_param(a_dictionary, an_enum=None): 176 | return an_enum 177 | 178 | 179 | @extract(parameters=[ 180 | Parameter(path="/params/non_empty", func_param_name="a_dictionary", validators=[NonEmpty]) 181 | ]) 182 | def extract_non_empty_param(a_dictionary, non_empty=None): 183 | return non_empty 184 | 185 | 186 | @extract(parameters=[ 187 | Parameter(path="/params/date_example", func_param_name="a_dictionary", 188 | validators=[DateValidator("%Y-%m-%d %H:%M:%S")]) 189 | ]) 190 | def extract_date_param(a_dictionary, date_example=None): 191 | return date_example 192 | 193 | 194 | @extract(parameters=[ 195 | Parameter(path="/params/my_param", func_param_name="a_dictionary", transform=int) 196 | ]) 197 | def extract_with_transform_example(a_dictionary, my_param=None): 198 | return my_param 199 | 200 | 201 | def to_int(arg): 202 | try: 203 | return int(arg) 204 | except Exception: 205 | raise Exception("My custom error message") 206 | 207 | 208 | @extract(parameters=[ 209 | Parameter(path="/params/my_param", func_param_name="a_dictionary", transform=to_int) 210 | ]) 211 | def extract_with_custom_transform_example(a_dictionary, my_param=None): 212 | return {} 213 | 214 | 215 | @extract(parameters=[ 216 | Parameter(path="/params/currency_example", func_param_name="a_dictionary", 217 | validators=[CurrencyValidator]) 218 | ]) 219 | def extract_currency_param(a_dictionary, currency_example=None): 220 | return currency_example 221 | -------------------------------------------------------------------------------- /examples/test_examples.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock, call 3 | from botocore.exceptions import ClientError 4 | from examples.examples import (extract_example, extract_to_kwargs_example, extract_mandatory_param_example, 5 | extract_from_json_example, extract_from_event_example, extract_from_context_example, 6 | extract_from_ssm_example, validate_example, log_example, handle_exceptions_example, 7 | response_body_as_json_example, extract_from_list_example, handle_all_exceptions_example, 8 | cors_example, extract_multiple_param_example, 9 | extract_minimum_param_with_custom_error_example, extract_dictionary_example, 10 | extract_type_param, extract_enum_param, extract_non_empty_param, extract_date_param, 11 | extract_with_transform_example, extract_with_custom_transform_example, 12 | extract_currency_param, hsts_example) 13 | 14 | 15 | # pylint:disable=too-many-public-methods 16 | class ExamplesTests(unittest.TestCase): 17 | 18 | def test_extract_example(self): 19 | # Given these two dictionaries: 20 | a_dict = { 21 | "parent": { 22 | "my_param": "Hello!" 23 | }, 24 | "other": "other value" 25 | } 26 | b_dict = { 27 | "parent": { 28 | "child": { 29 | "id": "123" 30 | } 31 | } 32 | } 33 | 34 | # calling the decorated extract_example: 35 | response = extract_example(a_dict, b_dict) 36 | 37 | # will return the values of the extracted parameters. 38 | self.assertEqual(("Hello!", "I am missing", None, "123"), response) 39 | 40 | def test_extract_to_kwargs_example(self): 41 | # Given this dictionary: 42 | dictionary = { 43 | "parent": { 44 | "my_param": "Hello!" 45 | }, 46 | "other": "other value" 47 | } 48 | # we can extract "my_param". 49 | response = extract_to_kwargs_example(dictionary) 50 | 51 | # and get the value from kwargs. 52 | self.assertEqual("Hello!", response) 53 | 54 | def test_extract_missing_mandatory_example(self): 55 | # Given this dictionary: 56 | dictionary = { 57 | "parent": { 58 | "my_param": "Hello!" 59 | }, 60 | "other": "other value" 61 | } 62 | # we can try to extract a missing mandatory parameter. 63 | response = extract_mandatory_param_example(dictionary) 64 | 65 | # but we will get an error response as it is missing and it was mandatory. 66 | self.assertEqual( 67 | {"statusCode": 400, "body": "{\"message\": [{\"mandatory_param\": [\"Missing mandatory value\"]}]}"}, 68 | response) 69 | 70 | def test_extract_multiple_param_example(self): 71 | # Given this dictionary: 72 | dictionary = { 73 | "parent": { 74 | "my_param": "Hello!", 75 | "an_int": 20 76 | }, 77 | "other": "other value" 78 | } 79 | # we can try to extract a missing mandatory parameter. 80 | response = extract_multiple_param_example(dictionary) 81 | 82 | # but we will get an error response as it is missing and it was mandatory. 83 | self.assertEqual({"statusCode": 400, 84 | "body": "{\"message\": [{\"mandatory_param\": [\"Missing mandatory value\"]}, " 85 | "{\"another_mandatory_param\": [\"Missing mandatory value\"]}, " 86 | "{\"an_int\": [\"'20' is greater than maximum value '10'\"]}]}"}, response) 87 | 88 | def test_extract_not_missing_mandatory_example(self): 89 | # Given this dictionary: 90 | dictionary = { 91 | "parent": { 92 | "mandatory_param": "Hello!" 93 | } 94 | } 95 | 96 | # we can try to extract a mandatory parameter. 97 | response = extract_mandatory_param_example(dictionary) 98 | 99 | # we will get the coded lambda response as the parameter is not missing. 100 | self.assertEqual("Here!", response) 101 | 102 | def test_extract_from_json_example(self): 103 | # Given this dictionary: 104 | dictionary = { 105 | "parent": "{\"my_param\": \"Hello!\"}", 106 | "other": "other value" 107 | } 108 | 109 | # we can extract from a json string by adding the [json] annotation to parent. 110 | response = extract_from_json_example(dictionary) 111 | 112 | # and we will get the value of the "my_param" parameter inside the "parent" json. 113 | self.assertEqual("Hello!", response) 114 | 115 | def test_extract_from_event_example(self): 116 | # Given this API Gateway event: 117 | event = { 118 | "body": "{\"my_param\": \"Hello!\"}", 119 | "headers": { 120 | "Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9l" 121 | "IiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" 122 | } 123 | } 124 | 125 | # we can extract "my_param" from event, and "sub" from Authorization JWT 126 | # using the "extract_from_event" decorator. 127 | response = extract_from_event_example(event, None) 128 | 129 | # and return both values in the lambda response. 130 | self.assertEqual(("Hello!", "1234567890"), response) 131 | 132 | def test_extract_from_context_example(self): 133 | # Given this context: 134 | context = { 135 | "parent": { 136 | "my_param": "Hello!" 137 | } 138 | } 139 | 140 | # we can extract "my_param" from context using the "extract_from_context" decorator. 141 | response = extract_from_context_example(None, context) 142 | 143 | # and return the value in the lambda response. 144 | self.assertEqual("Hello!", response) 145 | 146 | @patch("boto3.client") 147 | def test_extract_from_ssm_example(self, mock_boto_client): 148 | # Mocking the extraction of an SSM parameter from AWS SSM: 149 | mock_ssm = MagicMock() 150 | mock_ssm.get_parameters.return_value = { 151 | "Parameters": [ 152 | { 153 | "Value": "test1", 154 | "Name": "one_key" 155 | }, 156 | { 157 | "Value": "test2", 158 | "Name": "another_key" 159 | } 160 | ] 161 | } 162 | mock_boto_client.return_value = mock_ssm 163 | 164 | # we can extract the value of that parameter. 165 | response = extract_from_ssm_example(None) 166 | 167 | # and return the mocked SSM parameters from the lambda. 168 | self.assertEqual((None, "test1", "test2"), response) 169 | 170 | def test_validate_example(self): 171 | # We can validate non-dictionary parameters too, using the "validate" decorator. 172 | response = validate_example("Hello!", "123456", {"a": {"b": "c"}}) 173 | 174 | # in this case the parameters are valid and are returned by the function. 175 | self.assertEqual(("Hello!", "123456", {"a": {"b": "c"}}), response) 176 | 177 | def test_validate_raises_exception_example(self): 178 | # We can validate non-dictionary parameters too, using the "validate" decorator. 179 | response = validate_example("Hello!", "123456", {"a": 123456}) 180 | 181 | # in this case at least one parameter is not valid and a 400 error is returned to the caller. 182 | self.assertEqual({ 183 | "statusCode": 400, 184 | "body": "{\"message\": [{\"param_with_schema\": [\"'{'a': 123456}' does not validate against schema" 185 | " 'Schema({'a': Or(, )})'\"]}]}"}, response) 186 | 187 | @staticmethod 188 | @patch("aws_lambda_decorators.decorators.LOGGER") 189 | def test_log_example(mock_logger): 190 | # We can use the "log" decorator to log the parameters passed to a lambda and/or the response from the lambda. 191 | log_example("Hello!") # logs "Hello!" and "Done!" 192 | 193 | # and check the log messages were produced. 194 | mock_logger.info.assert_has_calls([ 195 | call("Function: %s, Parameters: %s", "log_example", ("Hello!",)), 196 | call("Function: %s, Response: %s", "log_example", "Done!") 197 | ]) 198 | 199 | @patch("boto3.resource") 200 | def test_handle_exceptions_example(self, mock_dynamo): 201 | # Mocking the dynamo query to return a ClientError. 202 | mock_table = MagicMock() 203 | client_error = ClientError({}, "") 204 | mock_table.query.side_effect = client_error 205 | 206 | mock_dynamo.return_value.Table.return_value = mock_table 207 | 208 | # we can automatically handle the ClientError, using the "exception_handler" decorator. 209 | response = handle_exceptions_example() # noqa: pylint - assignment-from-no-return 210 | 211 | # and return the error supplied to the caller. 212 | self.assertEqual(response["statusCode"], 400) 213 | self.assertEqual(response["body"], "{\"message\": \"Your message when a client error happens.\"}") 214 | 215 | def test_response_as_json_example(self): 216 | # We can automatically json dump a body dictionary: 217 | response = response_body_as_json_example() 218 | 219 | # the response body is a string. 220 | self.assertEqual("{\"param\": \"hello!\"}", response["body"]) 221 | 222 | def test_extract_from_list_example(self): 223 | # Given this dictionary: 224 | dictionary = { 225 | "parent": [ 226 | {"my_param": "Hello!"}, 227 | {"my_param": "Bye!"} 228 | ], 229 | "other": "other value" 230 | } 231 | 232 | # we can extract from a json string by adding the [json] annotation to parent. 233 | response = extract_from_list_example(dictionary) 234 | 235 | # and we will get the value of the "my_param" parameter inside the "parent" correct item. 236 | self.assertEqual("Bye!", response) 237 | 238 | def test_handle_all_exceptions_example(self): 239 | # we can automatically handle any exceptions, using the "handle_all_exceptions" decorator. 240 | response = handle_all_exceptions_example() # noqa: pylint - assignment-from-no-return 241 | 242 | # and return the error to the caller. 243 | self.assertEqual(response["statusCode"], 400) 244 | self.assertEqual(response["body"], "{\"message\": \"list index out of range\"}") 245 | 246 | def test_cors(self): 247 | # you can automatically add CORS headers to any function, using the "cors" decorator. 248 | response = cors_example() 249 | 250 | # the response has been decorated with the access-control cors headers. 251 | self.assertEqual(response["statusCode"], 200) 252 | self.assertEqual(response["headers"]["access-control-allow-origin"], "*") 253 | self.assertEqual(response["headers"]["access-control-allow-methods"], "POST") 254 | self.assertEqual(response["headers"]["access-control-allow-headers"], "Content-Type") 255 | self.assertEqual(response["headers"]["access-control-max-age"], 86400) 256 | 257 | def test_hsts(self): 258 | # You can automatically add HSTS header to any function, using the "hsts" decorator. 259 | response = hsts_example() 260 | 261 | # The response has been decorated with HSTS header 262 | self.assertEqual(response["statusCode"], 200) 263 | self.assertEqual(response["headers"]["Strict-Transport-Security"], "max-age=63072000") 264 | 265 | def test_extract_minimum_param_with_custom_error_example(self): 266 | # You can add custom error messages to all validators, and incorporate to those error messages 267 | # the validated value and the validation condition. 268 | response = extract_minimum_param_with_custom_error_example({"parent": {"an_int": 10}}) 269 | 270 | # The error response contains the custom error message, with the correct value to validate 271 | # and condition to check. 272 | self.assertEqual(response["statusCode"], 400) 273 | self.assertEqual(response["body"], "{\"message\": [{\"an_int\": [\"Bad value 10: should be at least 100\"]}]}") 274 | 275 | def test_extract_dictionary_example(self): 276 | # Given this dictionary: 277 | a_dictionary = { 278 | "params": { 279 | "my_param_1": "Hello!", 280 | "my_param_2": "Bye!" 281 | } 282 | } 283 | 284 | # calling the decorated extract_dictionary_example: 285 | response = extract_dictionary_example(a_dictionary) 286 | 287 | # will return the extracted values in a dictionary. 288 | self.assertEqual({"my_param_1": "Hello!", "my_param_2": "Bye!"}, response) 289 | 290 | def test_extract_type_param(self): 291 | # Given this dictionary: 292 | a_dictionary = { 293 | "params": { 294 | "a_bool": True 295 | } 296 | } 297 | 298 | # calling the decorated extract_dictionary_example: 299 | response = extract_type_param(a_dictionary) 300 | 301 | # will return the extracted values in a dictionary. 302 | self.assertEqual(True, response) 303 | 304 | def test_extract_enum_param(self): 305 | # Given this dictionary: 306 | a_dictionary = { 307 | "params": { 308 | "an_enum": "Bye" 309 | } 310 | } 311 | 312 | # calling the decorated extract_dictionary_example: 313 | response = extract_enum_param(a_dictionary) 314 | 315 | # will return the extracted values in a dictionary. 316 | self.assertEqual("Bye", response) 317 | 318 | def test_extract_non_empty_param(self): 319 | # Given this dictionary: 320 | a_dictionary = { 321 | "params": { 322 | "non_empty": ["first value"] 323 | } 324 | } 325 | 326 | # calling the decorated extract_dictionary_example: 327 | response = extract_non_empty_param(a_dictionary) 328 | 329 | # will return the extracted values in a dictionary. 330 | self.assertEqual(["first value"], response) 331 | 332 | def test_extract_date_param(self): 333 | # Given this dictionary: 334 | a_dictionary = { 335 | "params": { 336 | "date_example": "2001-01-01 00:00:00" 337 | } 338 | } 339 | 340 | # calling the decorated extract_dictionary_example: 341 | response = extract_date_param(a_dictionary) 342 | 343 | # will return the extracted values in a dictionary. 344 | self.assertEqual("2001-01-01 00:00:00", response) 345 | 346 | def test_extract_currency_param(self): 347 | # Given this dictionary: 348 | a_dictionary = { 349 | "params": { 350 | "currency_example": "GBP" 351 | } 352 | } 353 | 354 | # calling the decorated extract_dictionary_example: 355 | response = extract_currency_param(a_dictionary) 356 | 357 | # will return the extracted values in a dictionary. 358 | self.assertEqual("GBP", response) 359 | 360 | def test_extract_with_transform(self): 361 | # Given this dictionary: 362 | a_dictionary = { 363 | "params": { 364 | "my_param": "2" 365 | } 366 | } 367 | 368 | # calling the decorated extract_with_transform_example: 369 | response = extract_with_transform_example(a_dictionary) 370 | 371 | # will return the integer value 2 ("2" transform to an int) 372 | self.assertEqual(2, response) 373 | 374 | @patch("aws_lambda_decorators.decorators.LOGGER") 375 | def test_extract_with_custom_transform(self, mock_logger): 376 | # Given this dictionary: 377 | a_dictionary = { 378 | "params": { 379 | "my_param": "abc" 380 | } 381 | } 382 | 383 | # calling the decorated extract_with_custom_transform_example: 384 | response = extract_with_custom_transform_example(a_dictionary) 385 | 386 | # The error response contains a generic error 387 | self.assertEqual(response["statusCode"], 400) 388 | self.assertEqual(response["body"], "{\"message\": \"Error extracting parameters\"}") 389 | 390 | # and the logs will contain the "My custom error message" message 391 | mock_logger.error.assert_called_once_with("%s: %s in argument %s for path %s", 392 | "Exception", 393 | "My custom error message", 394 | "a_dictionary", 395 | "/params/my_param") 396 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code. 6 | extension-pkg-whitelist= 7 | 8 | # Add files or directories to the blacklist. They should be base names, not 9 | # paths. 10 | ignore=CVS 11 | 12 | # Add files or directories matching the regex patterns to the blacklist. The 13 | # regex matches against base names, not paths. 14 | ignore-patterns= 15 | 16 | # Python code to execute, usually for sys.path manipulation such as 17 | # pygtk.require(). 18 | #init-hook= 19 | 20 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 21 | # number of processors available to use. 22 | jobs=1 23 | 24 | # Control the amount of potential inferred values when inferring a single 25 | # object. This can help the performance when dealing with large functions or 26 | # complex, nested conditions. 27 | limit-inference-results=100 28 | 29 | # List of plugins (as comma separated values of python modules names) to load, 30 | # usually to register additional checkers. 31 | load-plugins=pylint_quotes 32 | 33 | # Pickle collected data for later comparisons. 34 | persistent=yes 35 | 36 | # Specify a configuration file. 37 | #rcfile= 38 | 39 | # When enabled, pylint would attempt to guess common misconfiguration and emit 40 | # user-friendly hints instead of false-positive error messages. 41 | suggestion-mode=yes 42 | 43 | # Allow loading of arbitrary C extensions. Extensions are imported into the 44 | # active Python interpreter and may run arbitrary code. 45 | unsafe-load-any-extension=no 46 | 47 | 48 | [MESSAGES CONTROL] 49 | 50 | # Only show warnings with the listed confidence levels. Leave empty to show 51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 52 | confidence= 53 | 54 | # Disable the message, report, category or checker with the given id(s). You 55 | # can either give multiple identifiers separated by comma (,) or put this 56 | # option multiple times (only on the command line, not in the configuration 57 | # file where it should appear only once). You can also use "--disable=all" to 58 | # disable everything first and then reenable specific checks. For example, if 59 | # you want to run only the similarities checker, you can use "--disable=all 60 | # --enable=similarities". If you want to run only the classes checker, but have 61 | # no Warning level messages displayed, use "--disable=all --enable=classes 62 | # --disable=W". 63 | disable=missing-docstring, 64 | print-statement, 65 | parameter-unpacking, 66 | unpacking-in-except, 67 | old-raise-syntax, 68 | backtick, 69 | long-suffix, 70 | old-ne-operator, 71 | old-octal-literal, 72 | import-star-module-level, 73 | non-ascii-bytes-literal, 74 | raw-checker-failed, 75 | bad-inline-option, 76 | locally-disabled, 77 | locally-enabled, 78 | file-ignored, 79 | suppressed-message, 80 | useless-suppression, 81 | deprecated-pragma, 82 | use-symbolic-message-instead, 83 | apply-builtin, 84 | basestring-builtin, 85 | buffer-builtin, 86 | cmp-builtin, 87 | coerce-builtin, 88 | execfile-builtin, 89 | file-builtin, 90 | long-builtin, 91 | raw_input-builtin, 92 | reduce-builtin, 93 | standarderror-builtin, 94 | unicode-builtin, 95 | xrange-builtin, 96 | coerce-method, 97 | delslice-method, 98 | getslice-method, 99 | setslice-method, 100 | no-absolute-import, 101 | old-division, 102 | dict-iter-method, 103 | dict-view-method, 104 | next-method-called, 105 | metaclass-assignment, 106 | indexing-exception, 107 | raising-string, 108 | reload-builtin, 109 | oct-method, 110 | hex-method, 111 | nonzero-method, 112 | cmp-method, 113 | input-builtin, 114 | round-builtin, 115 | intern-builtin, 116 | unichr-builtin, 117 | map-builtin-not-iterating, 118 | zip-builtin-not-iterating, 119 | range-builtin-not-iterating, 120 | filter-builtin-not-iterating, 121 | using-cmp-argument, 122 | eq-without-hash, 123 | div-method, 124 | idiv-method, 125 | rdiv-method, 126 | exception-message-attribute, 127 | invalid-str-codec, 128 | sys-max-int, 129 | bad-python3-import, 130 | deprecated-string-function, 131 | deprecated-str-translate-call, 132 | deprecated-itertools-function, 133 | deprecated-types-field, 134 | next-method-defined, 135 | dict-items-not-iterating, 136 | dict-keys-not-iterating, 137 | dict-values-not-iterating, 138 | deprecated-operator-function, 139 | deprecated-urllib-function, 140 | xreadlines-attribute, 141 | deprecated-sys-function, 142 | exception-escape, 143 | comprehension-escape 144 | 145 | # Enable the message, report, category or checker with the given id(s). You can 146 | # either give multiple identifier separated by comma (,) or put this option 147 | # multiple time (only on the command line, not in the configuration file where 148 | # it should appear only once). See also the "--disable" option for examples. 149 | enable=c-extension-no-member 150 | 151 | 152 | [REPORTS] 153 | 154 | # Python expression which should return a note less than 10 (10 is the highest 155 | # note). You have access to the variables errors warning, statement which 156 | # respectively contain the number of errors / warnings messages and the total 157 | # number of statements analyzed. This is used by the global evaluation report 158 | # (RP0004). 159 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 160 | 161 | # Template used to display messages. This is a python new-style format string 162 | # used to format the message information. See doc for all details. 163 | #msg-template= 164 | 165 | # Set the output format. Available formats are text, parseable, colorized, json 166 | # and msvs (visual studio). You can also give a reporter class, e.g. 167 | # mypackage.mymodule.MyReporterClass. 168 | output-format=text 169 | 170 | # Tells whether to display a full report or only the messages. 171 | reports=no 172 | 173 | # Activate the evaluation score. 174 | score=yes 175 | 176 | 177 | [REFACTORING] 178 | 179 | # Maximum number of nested blocks for function / method body 180 | max-nested-blocks=5 181 | 182 | # Complete name of functions that never returns. When checking for 183 | # inconsistent-return-statements if a never returning function is called then 184 | # it will be considered as an explicit return statement and no message will be 185 | # printed. 186 | never-returning-functions=sys.exit 187 | 188 | 189 | [LOGGING] 190 | 191 | # Logging modules to check that the string format arguments are in logging 192 | # function parameter format. 193 | logging-modules=logging 194 | 195 | 196 | [SPELLING] 197 | 198 | # Limits count of emitted suggestions for spelling mistakes. 199 | max-spelling-suggestions=4 200 | 201 | # Spelling dictionary name. Available dictionaries: none. To make it working 202 | # install python-enchant package.. 203 | spelling-dict= 204 | 205 | # List of comma separated words that should not be checked. 206 | spelling-ignore-words= 207 | 208 | # A path to a file that contains private dictionary; one word per line. 209 | spelling-private-dict-file= 210 | 211 | # Tells whether to store unknown words to indicated private dictionary in 212 | # --spelling-private-dict-file option instead of raising a message. 213 | spelling-store-unknown-words=no 214 | 215 | 216 | [MISCELLANEOUS] 217 | 218 | # List of note tags to take in consideration, separated by a comma. 219 | notes=FIXME, 220 | XXX, 221 | TODO 222 | 223 | 224 | [TYPECHECK] 225 | 226 | # List of decorators that produce context managers, such as 227 | # contextlib.contextmanager. Add to this list to register other decorators that 228 | # produce valid context managers. 229 | contextmanager-decorators=contextlib.contextmanager 230 | 231 | # List of members which are set dynamically and missed by pylint inference 232 | # system, and so shouldn't trigger E1101 when accessed. Python regular 233 | # expressions are accepted. 234 | generated-members= 235 | 236 | # Tells whether missing members accessed in mixin class should be ignored. A 237 | # mixin class is detected if its name ends with "mixin" (case insensitive). 238 | ignore-mixin-members=yes 239 | 240 | # Tells whether to warn about missing members when the owner of the attribute 241 | # is inferred to be None. 242 | ignore-none=yes 243 | 244 | # This flag controls whether pylint should warn about no-member and similar 245 | # checks whenever an opaque object is returned when inferring. The inference 246 | # can return multiple potential results while evaluating a Python object, but 247 | # some branches might not be evaluated, which results in partial inference. In 248 | # that case, it might be useful to still emit no-member and other checks for 249 | # the rest of the inferred objects. 250 | ignore-on-opaque-inference=yes 251 | 252 | # List of class names for which member attributes should not be checked (useful 253 | # for classes with dynamically set attributes). This supports the use of 254 | # qualified names. 255 | ignored-classes=optparse.Values,thread._local,_thread._local 256 | 257 | # List of module names for which member attributes should not be checked 258 | # (useful for modules/projects where namespaces are manipulated during runtime 259 | # and thus existing member attributes cannot be deduced by static analysis. It 260 | # supports qualified module names, as well as Unix pattern matching. 261 | ignored-modules= 262 | 263 | # Show a hint with possible names when a member name was not found. The aspect 264 | # of finding the hint is based on edit distance. 265 | missing-member-hint=yes 266 | 267 | # The minimum edit distance a name should have in order to be considered a 268 | # similar match for a missing member name. 269 | missing-member-hint-distance=1 270 | 271 | # The total number of similar names that should be taken in consideration when 272 | # showing a hint for a missing member. 273 | missing-member-max-choices=1 274 | 275 | 276 | [VARIABLES] 277 | 278 | # List of additional names supposed to be defined in builtins. Remember that 279 | # you should avoid to define new builtins when possible. 280 | additional-builtins= 281 | 282 | # Tells whether unused global variables should be treated as a violation. 283 | allow-global-unused-variables=yes 284 | 285 | # List of strings which can identify a callback function by name. A callback 286 | # name must start or end with one of those strings. 287 | callbacks=cb_, 288 | _cb 289 | 290 | # A regular expression matching the name of dummy variables (i.e. expected to 291 | # not be used). 292 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 293 | 294 | # Argument names that match this expression will be ignored. Default to name 295 | # with leading underscore. 296 | ignored-argument-names=_.*|^ignored_|^unused_ 297 | 298 | # Tells whether we should check for unused import in __init__ files. 299 | init-import=no 300 | 301 | # List of qualified module names which can have objects that can redefine 302 | # builtins. 303 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 304 | 305 | 306 | [FORMAT] 307 | 308 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 309 | expected-line-ending-format= 310 | 311 | # Regexp for a line that is allowed to be longer than the limit. 312 | ignore-long-lines=^\s*(# )??$ 313 | 314 | # Number of spaces of indent required inside a hanging or continued line. 315 | indent-after-paren=4 316 | 317 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 318 | # tab). 319 | indent-string=' ' 320 | 321 | # Maximum number of characters on a single line. 322 | max-line-length=120 323 | 324 | # Maximum number of lines in a module. 325 | max-module-lines=1000 326 | 327 | # List of optional constructs for which whitespace checking is disabled. `dict- 328 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 329 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 330 | # `empty-line` allows space-only lines. 331 | no-space-check=trailing-comma, 332 | dict-separator 333 | 334 | # Allow the body of a class to be on the same line as the declaration if body 335 | # contains single statement. 336 | single-line-class-stmt=no 337 | 338 | # Allow the body of an if to be on the same line as the test if there is no 339 | # else. 340 | single-line-if-stmt=no 341 | 342 | 343 | [SIMILARITIES] 344 | 345 | # Ignore comments when computing similarities. 346 | ignore-comments=yes 347 | 348 | # Ignore docstrings when computing similarities. 349 | ignore-docstrings=yes 350 | 351 | # Ignore imports when computing similarities. 352 | ignore-imports=no 353 | 354 | # Minimum lines number of a similarity. 355 | min-similarity-lines=4 356 | 357 | 358 | [BASIC] 359 | 360 | # Naming style matching correct argument names. 361 | argument-naming-style=snake_case 362 | 363 | # Regular expression matching correct argument names. Overrides argument- 364 | # naming-style. 365 | #argument-rgx= 366 | 367 | # Naming style matching correct attribute names. 368 | attr-naming-style=snake_case 369 | 370 | # Regular expression matching correct attribute names. Overrides attr-naming- 371 | # style. 372 | #attr-rgx= 373 | 374 | # Bad variable names which should always be refused, separated by a comma. 375 | bad-names=foo, 376 | bar, 377 | baz, 378 | toto, 379 | tutu, 380 | tata 381 | 382 | # Naming style matching correct class attribute names. 383 | class-attribute-naming-style=any 384 | 385 | # Regular expression matching correct class attribute names. Overrides class- 386 | # attribute-naming-style. 387 | #class-attribute-rgx= 388 | 389 | # Naming style matching correct class names. 390 | class-naming-style=PascalCase 391 | 392 | # Regular expression matching correct class names. Overrides class-naming- 393 | # style. 394 | #class-rgx= 395 | 396 | # Naming style matching correct constant names. 397 | const-naming-style=UPPER_CASE 398 | 399 | # Regular expression matching correct constant names. Overrides const-naming- 400 | # style. 401 | #const-rgx= 402 | 403 | # Minimum line length for functions/classes that require docstrings, shorter 404 | # ones are exempt. 405 | docstring-min-length=-1 406 | 407 | # Naming style matching correct function names. 408 | function-naming-style=snake_case 409 | 410 | # Regular expression matching correct function names. Overrides function- 411 | # naming-style. 412 | #function-rgx= 413 | 414 | # Good variable names which should always be accepted, separated by a comma. 415 | good-names=i, 416 | j, 417 | k, 418 | ex, 419 | Run, 420 | _ 421 | 422 | # Include a hint for the correct naming format with invalid-name. 423 | include-naming-hint=no 424 | 425 | # Naming style matching correct inline iteration names. 426 | inlinevar-naming-style=any 427 | 428 | # Regular expression matching correct inline iteration names. Overrides 429 | # inlinevar-naming-style. 430 | #inlinevar-rgx= 431 | 432 | # Naming style matching correct method names. 433 | method-naming-style=snake_case 434 | 435 | # Regular expression matching correct method names. Overrides method-naming- 436 | # style. 437 | #method-rgx= 438 | 439 | # Naming style matching correct module names. 440 | module-naming-style=snake_case 441 | 442 | # Regular expression matching correct module names. Overrides module-naming- 443 | # style. 444 | #module-rgx= 445 | 446 | # Colon-delimited sets of names that determine each other's naming style when 447 | # the name regexes allow several styles. 448 | name-group= 449 | 450 | # Regular expression which should only match function or class names that do 451 | # not require a docstring. 452 | no-docstring-rgx=^_ 453 | 454 | # List of decorators that produce properties, such as abc.abstractproperty. Add 455 | # to this list to register other decorators that produce valid properties. 456 | # These decorators are taken in consideration only for invalid-name. 457 | property-classes=abc.abstractproperty 458 | 459 | # Naming style matching correct variable names. 460 | variable-naming-style=snake_case 461 | 462 | # Regular expression matching correct variable names. Overrides variable- 463 | # naming-style. 464 | #variable-rgx= 465 | 466 | 467 | [STRING_QUOTES] 468 | 469 | # The quote character for triple-quoted docstrings. 470 | docstring-quote=double 471 | 472 | # The quote character for string literals. 473 | string-quote=double 474 | 475 | # The quote character for triple-quoted strings (non-docstring). 476 | triple-quote=double 477 | 478 | 479 | [IMPORTS] 480 | 481 | # Allow wildcard imports from modules that define __all__. 482 | allow-wildcard-with-all=no 483 | 484 | # Analyse import fallback blocks. This can be used to support both Python 2 and 485 | # 3 compatible code, which means that the block might have code that exists 486 | # only in one or another interpreter, leading to false positives when analysed. 487 | analyse-fallback-blocks=no 488 | 489 | # Deprecated modules which should not be used, separated by a comma. 490 | deprecated-modules=optparse,tkinter.tix 491 | 492 | # Create a graph of external dependencies in the given file (report RP0402 must 493 | # not be disabled). 494 | ext-import-graph= 495 | 496 | # Create a graph of every (i.e. internal and external) dependencies in the 497 | # given file (report RP0402 must not be disabled). 498 | import-graph= 499 | 500 | # Create a graph of internal dependencies in the given file (report RP0402 must 501 | # not be disabled). 502 | int-import-graph= 503 | 504 | # Force import order to recognize a module as part of the standard 505 | # compatibility libraries. 506 | known-standard-library= 507 | 508 | # Force import order to recognize a module as part of a third party library. 509 | known-third-party=enchant 510 | 511 | 512 | [CLASSES] 513 | 514 | # List of method names used to declare (i.e. assign) instance attributes. 515 | defining-attr-methods=__init__, 516 | __new__, 517 | setUp 518 | 519 | # List of member names, which should be excluded from the protected access 520 | # warning. 521 | exclude-protected=_asdict, 522 | _fields, 523 | _replace, 524 | _source, 525 | _make 526 | 527 | # List of valid names for the first argument in a class method. 528 | valid-classmethod-first-arg=cls 529 | 530 | # List of valid names for the first argument in a metaclass class method. 531 | valid-metaclass-classmethod-first-arg=cls 532 | 533 | 534 | [DESIGN] 535 | 536 | # Maximum number of arguments for function / method. 537 | max-args=5 538 | 539 | # Maximum number of attributes for a class (see R0902). 540 | max-attributes=7 541 | 542 | # Maximum number of boolean expressions in an if statement. 543 | max-bool-expr=5 544 | 545 | # Maximum number of branch for function / method body. 546 | max-branches=12 547 | 548 | # Maximum number of locals for function / method body. 549 | max-locals=15 550 | 551 | # Maximum number of parents for a class (see R0901). 552 | max-parents=7 553 | 554 | # Maximum number of public methods for a class (see R0904). 555 | max-public-methods=20 556 | 557 | # Maximum number of return / yield for function / method body. 558 | max-returns=6 559 | 560 | # Maximum number of statements in function / method body. 561 | max-statements=50 562 | 563 | # Minimum number of public methods for a class (see R0903). 564 | min-public-methods=2 565 | 566 | 567 | [EXCEPTIONS] 568 | 569 | # Exceptions that will emit a warning when being caught. Defaults to 570 | # "Exception". 571 | overgeneral-exceptions=Exception 572 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bandit 2 | boto3 3 | prospector 4 | PyJWT==1.7.1 5 | pylint_quotes 6 | schema 7 | coverage==4.5.4 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | LONG_DESCRIPTION = open("README.md").read() 4 | 5 | setup(name="aws-lambda-decorators", 6 | version="0.53", 7 | description="A set of python decorators to simplify aws python lambda development", 8 | long_description=LONG_DESCRIPTION, 9 | long_description_content_type="text/markdown", 10 | url="https://github.com/gridsmartercities/aws-lambda-decorators", 11 | author="Grid Smarter Cities", 12 | author_email="open-source@gridsmartercities.com", 13 | license="MIT", 14 | classifiers=["Intended Audience :: Developers", 15 | "Development Status :: 4 - Beta", 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | "Natural Language :: English" 20 | ], 21 | keywords="aws lambda decorator", 22 | packages=find_packages(exclude=("tests", "examples",)), 23 | install_requires=[ 24 | "boto3", 25 | "PyJWT", 26 | "schema" 27 | ], 28 | zip_safe=False 29 | ) 30 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gridsmartercities/aws-lambda-decorators/16dbe6ae1b9982f312d593336682c4ebbcd4f52d/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_classes.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from aws_lambda_decorators.classes import Parameter, SSMParameter, BaseParameter 3 | 4 | 5 | class ParamTests(unittest.TestCase): 6 | 7 | def test_can_create_base_parameter(self): 8 | base_param = BaseParameter("var_name") 9 | self.assertEqual("var_name", base_param.get_var_name()) 10 | 11 | def test_annotations_from_key_returns_annotation(self): 12 | key = "simple[annotation]" 13 | response = Parameter.get_annotations_from_key(key) 14 | self.assertTrue(response[0] == "simple") 15 | self.assertTrue(response[1] == "annotation") 16 | 17 | def test_can_not_add_non_pythonic_var_name_to_ssm_parameter(self): 18 | param = SSMParameter("tests", "with space") 19 | 20 | with self.assertRaises(SyntaxError): 21 | param.get_var_name() 22 | -------------------------------------------------------------------------------- /tests/test_decoders.py: -------------------------------------------------------------------------------- 1 | # pylint:disable=no-self-use 2 | import json 3 | import unittest 4 | from unittest.mock import patch 5 | from aws_lambda_decorators.decoders import decode, decode_json, decode_jwt 6 | from aws_lambda_decorators.decorators import extract 7 | from aws_lambda_decorators.classes import Parameter 8 | 9 | 10 | TEST_JWT = "eyJraWQiOiJEQlwvK0lGMVptekNWOGNmRE1XVUxBRlBwQnVObW5CU2NcL2RoZ3pnTVhcL2NzPSIsImFsZyI6IlJTMjU2In0." \ 11 | "eyJzdWIiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJhdWQiOiIycjdtMW1mdWFiODg3ZmZvdG9iNWFjcX" \ 12 | "Q2aCIsImNvZ25pdG86Z3JvdXBzIjpbIkRBU0gtQ3VzdG9tZXIiXSwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV2ZW50X2lkIjoiZDU4" \ 13 | "NzU0ZjUtMTdlMC0xMWU5LTg2NzAtMjVkOTNhNWNiMjAwIiwidG9rZW5fdXNlIjoiaWQiLCJhdXRoX3RpbWUiOjE1NDc0NTkwMDMsIm" \ 14 | "lzcyI6Imh0dHBzOlwvXC9jb2duaXRvLWlkcC5ldS13ZXN0LTIuYW1hem9uYXdzLmNvbVwvZXUtd2VzdC0yX1B4bEdzMU11SiIsImNv" \ 15 | "Z25pdG86dXNlcm5hbWUiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJleHAiOjE1NDc0NjI2MDMsImlhdC" \ 16 | "I6MTU0NzQ1OTAwMywiZW1haWwiOiJjdXN0b21lckBleGFtcGxlLmNvbSJ9.CNSDu4a9azT40maHAF9tnQTWbfEeiTZ9PfkR9_RU_VG" \ 17 | "4QTA1y4R0F2zWVpsa3CkVMq4Uv2NWOwG6zXf-7XaWTEjoGOQR07sq54IEWU3WIxgkgtRAI-aR7nIvllMXXR0RE3e5jzn5SmefG1j-O" \ 18 | "NYiD1yYExrKOEMPJVgkdYG6x2cBiucHihVliJQUf9u-ebpu2Cpm_ACvUTUilB6sBL06D3sRobvNLbNNnSjsA66ULNpPTPOVYJxhFbu" \ 19 | "ceQ1EICp0oICw2ncJch78RAFY5TeqiVa-uBybxwd36zJmZkXeJPWAKd32IOIJXNUyDOJtmXtSQW51pZGYTsihjZHz3kNlfg" 20 | 21 | 22 | class DecodersTests(unittest.TestCase): 23 | 24 | @patch("aws_lambda_decorators.decoders.LOGGER") 25 | def test_decode_function_missing_logs_error(self, mock_logger): 26 | decode("[random]", None) 27 | mock_logger.error.assert_called_once_with("Missing decode function for annotation: %s", "[random]") 28 | 29 | @patch("aws_lambda_decorators.decorators.LOGGER") 30 | def test_extract_returns_400_on_json_decode_error(self, mock_logger): 31 | path = "/a/b[json]/c" 32 | dictionary = { 33 | "a": { 34 | "b": "{'c'}" 35 | } 36 | } 37 | 38 | @extract([Parameter(path, "event")]) 39 | def handler(event, context, c=None): # noqa 40 | return {} 41 | 42 | response = handler(dictionary, None) 43 | 44 | self.assertEqual(400, response["statusCode"]) 45 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 46 | 47 | mock_logger.error.assert_called_once_with( 48 | "%s: %s in argument %s for path %s", 49 | "json.decoder.JSONDecodeError", 50 | "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", 51 | "event", 52 | "/a/b[json]/c") 53 | 54 | @patch("aws_lambda_decorators.decorators.LOGGER") 55 | def test_extract_returns_400_on_jwt_decode_error(self, mock_logger): 56 | path = "/a/b[jwt]/c" 57 | dictionary = { 58 | "a": { 59 | "b": "wrong.jwt" 60 | } 61 | } 62 | 63 | @extract([Parameter(path, "event")]) 64 | def handler(event, context, c=None): # noqa 65 | return {} 66 | 67 | response = handler(dictionary, None) 68 | 69 | self.assertEqual(400, response["statusCode"]) 70 | self.assertTrue("{\"message\": \"Error extracting parameters\"}" in response["body"]) 71 | 72 | mock_logger.error.assert_called_once_with( 73 | "%s: %s in argument %s for path %s", 74 | "jwt.exceptions.DecodeError", 75 | "Not enough segments", 76 | "event", 77 | "/a/b[jwt]/c") 78 | 79 | def test_extracts_from_list_by_index_annotation_successfully(self): 80 | path = "/a/b[1]/c" 81 | dictionary = { 82 | "a": { 83 | "b": [ 84 | { 85 | "c": 2 86 | }, 87 | { 88 | "c": 3 89 | } 90 | ] 91 | } 92 | } 93 | 94 | @extract([Parameter(path, "event")]) 95 | def handler(event, context, c=None): # noqa 96 | return c 97 | 98 | response = handler(dictionary, None) 99 | 100 | self.assertEqual(3, response) 101 | 102 | @patch("aws_lambda_decorators.decorators.LOGGER") 103 | def test_extracts_from_list_by_index_out_of_range_fails_with_400(self, mock_logger): 104 | path = "/a/b[4]/c" 105 | dictionary = { 106 | "a": { 107 | "b": [ 108 | { 109 | "c": 2 110 | }, 111 | { 112 | "c": 3 113 | } 114 | ] 115 | } 116 | } 117 | 118 | @extract([Parameter(path, "event")]) 119 | def handler(event, context, c=None): # noqa 120 | return c 121 | 122 | response = handler(dictionary, None) 123 | 124 | self.assertEqual(400, response["statusCode"]) # noqa 125 | self.assertTrue("{\"message\": \"Error extracting parameters\"}" in response["body"]) # noqa 126 | 127 | mock_logger.error.assert_called_once_with( 128 | "%s: %s in argument %s for path %s", 129 | "IndexError", 130 | "list index out of range", 131 | "event", 132 | "/a/b[4]/c") 133 | 134 | def test_extract_multiple_parameters_from_json_hits_cache(self): 135 | dictionary = { 136 | "a": json.dumps({ 137 | "b": 123, 138 | "c": 456 139 | }) 140 | } 141 | 142 | initial_cache_info = decode_json.cache_info() 143 | 144 | @extract([ 145 | Parameter("a[json]/b", "event", var_name="b"), 146 | Parameter("a[json]/c", "event", var_name="c") 147 | ]) # noqa: pylint - invalid-name 148 | def handler(event, b=None, c=None): # noqa: pylint - unused-argument 149 | return {} 150 | 151 | handler(dictionary) 152 | 153 | self.assertEqual(decode_json.cache_info().hits, initial_cache_info.hits + 1) 154 | self.assertEqual(decode_json.cache_info().misses, initial_cache_info.misses + 1) 155 | self.assertEqual(decode_json.cache_info().currsize, initial_cache_info.currsize + 1) 156 | 157 | def test_extract_multiple_parameters_from_jwt_hits_cache(self): 158 | dictionary = { 159 | "a": TEST_JWT 160 | } 161 | 162 | initial_cache_info = decode_jwt.cache_info() 163 | 164 | @extract([ 165 | Parameter("a[jwt]/sub", "event", var_name="sub"), 166 | Parameter("a[jwt]/aud", "event", var_name="aud") 167 | ]) 168 | def handler(event, sub=None, aud=None): # noqa: pylint - unused-argument 169 | return {} 170 | 171 | handler(dictionary) 172 | 173 | self.assertEqual(decode_jwt.cache_info().hits, initial_cache_info.hits + 1) 174 | self.assertEqual(decode_jwt.cache_info().misses, initial_cache_info.misses + 1) 175 | self.assertEqual(decode_jwt.cache_info().currsize, initial_cache_info.currsize + 1) 176 | -------------------------------------------------------------------------------- /tests/test_decorators.py: -------------------------------------------------------------------------------- 1 | # pylint:disable=too-many-lines 2 | from http import HTTPStatus 3 | import json 4 | from json import JSONDecodeError 5 | import unittest 6 | from unittest.mock import patch, MagicMock 7 | from uuid import uuid4 8 | 9 | from botocore.exceptions import ClientError 10 | from schema import Schema, And, Optional 11 | 12 | from aws_lambda_decorators.classes import ExceptionHandler, Parameter, SSMParameter, ValidatedParameter 13 | from aws_lambda_decorators.decorators import extract, extract_from_event, extract_from_context, handle_exceptions, \ 14 | log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors, push_ws_errors, \ 15 | push_ws_response, hsts 16 | from aws_lambda_decorators.utils import get_websocket_endpoint 17 | from aws_lambda_decorators.validators import Mandatory, RegexValidator, SchemaValidator, Minimum, Maximum, MaxLength, \ 18 | MinLength, Type, EnumValidator, NonEmpty, DateValidator, CurrencyValidator 19 | 20 | TEST_JWT = "eyJraWQiOiJEQlwvK0lGMVptekNWOGNmRE1XVUxBRlBwQnVObW5CU2NcL2RoZ3pnTVhcL2NzPSIsImFsZyI6IlJTMjU2In0." \ 21 | "eyJzdWIiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJhdWQiOiIycjdtMW1mdWFiODg3ZmZvdG9iNWFjcX" \ 22 | "Q2aCIsImNvZ25pdG86Z3JvdXBzIjpbIkRBU0gtQ3VzdG9tZXIiXSwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV2ZW50X2lkIjoiZDU4" \ 23 | "NzU0ZjUtMTdlMC0xMWU5LTg2NzAtMjVkOTNhNWNiMjAwIiwidG9rZW5fdXNlIjoiaWQiLCJhdXRoX3RpbWUiOjE1NDc0NTkwMDMsIm" \ 24 | "lzcyI6Imh0dHBzOlwvXC9jb2duaXRvLWlkcC5ldS13ZXN0LTIuYW1hem9uYXdzLmNvbVwvZXUtd2VzdC0yX1B4bEdzMU11SiIsImNv" \ 25 | "Z25pdG86dXNlcm5hbWUiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJleHAiOjE1NDc0NjI2MDMsImlhdC" \ 26 | "I6MTU0NzQ1OTAwMywiZW1haWwiOiJjdXN0b21lckBleGFtcGxlLmNvbSJ9.CNSDu4a9azT40maHAF9tnQTWbfEeiTZ9PfkR9_RU_VG" \ 27 | "4QTA1y4R0F2zWVpsa3CkVMq4Uv2NWOwG6zXf-7XaWTEjoGOQR07sq54IEWU3WIxgkgtRAI-aR7nIvllMXXR0RE3e5jzn5SmefG1j-O" \ 28 | "NYiD1yYExrKOEMPJVgkdYG6x2cBiucHihVliJQUf9u-ebpu2Cpm_ACvUTUilB6sBL06D3sRobvNLbNNnSjsA66ULNpPTPOVYJxhFbu" \ 29 | "ceQ1EICp0oICw2ncJch78RAFY5TeqiVa-uBybxwd36zJmZkXeJPWAKd32IOIJXNUyDOJtmXtSQW51pZGYTsihjZHz3kNlfg" 30 | 31 | 32 | class DecoratorsTests(unittest.TestCase): # noqa: pylint - too-many-public-methods 33 | 34 | def test_can_get_value_from_dict_by_path(self): 35 | path = "/a/b/c" 36 | dictionary = { 37 | "a": { 38 | "b": { 39 | "c": "hello" 40 | } 41 | } 42 | } 43 | param = Parameter(path) 44 | response = param.extract_value(dictionary) 45 | self.assertEqual("hello", response) 46 | 47 | def test_can_get_dict_value_from_dict_by_path(self): 48 | path = "/a/b" 49 | dictionary = { 50 | "a": { 51 | "b": { 52 | "c": "hello" 53 | } 54 | } 55 | } 56 | param = Parameter(path) 57 | response = param.extract_value(dictionary) 58 | self.assertEqual({"c": "hello"}, response) 59 | 60 | def test_raises_decode_error_convert_json_string_to_dict(self): 61 | path = "/a/b[json]/c" 62 | dictionary = { 63 | "a": { 64 | "b": "{ 'c': 'hello' }", 65 | "c": "bye" 66 | } 67 | } 68 | param = Parameter(path) 69 | with self.assertRaises(JSONDecodeError) as context: 70 | param.extract_value(dictionary) 71 | 72 | self.assertTrue("Expecting property name enclosed in double quotes" in context.exception.msg) 73 | 74 | def test_can_get_value_from_dict_with_json_by_path(self): 75 | path = "/a/b[json]/c" 76 | dictionary = { 77 | "a": { 78 | "b": "{\"c\": \"hello\"}", 79 | "c": "bye" 80 | } 81 | } 82 | param = Parameter(path, "event") 83 | response = param.extract_value(dictionary) 84 | self.assertEqual("hello", response) 85 | 86 | def test_can_get_value_from_dict_with_jwt_by_path(self): 87 | path = "/a/b[jwt]/sub" 88 | dictionary = { 89 | "a": { 90 | "b": TEST_JWT 91 | } 92 | } 93 | param = Parameter(path, "event") 94 | response = param.extract_value(dictionary) 95 | self.assertEqual("aadd1e0e-5807-4763-b1e8-5823bf631bb6", response) 96 | 97 | def test_extract_from_event_calls_function_with_extra_kwargs(self): 98 | path = "/a/b/c" 99 | dictionary = { 100 | "a": { 101 | "b": { 102 | "c": "hello" 103 | } 104 | } 105 | } 106 | 107 | @extract_from_event([Parameter(path)]) 108 | def handler(event, context, c=None): # noqa 109 | return c 110 | 111 | self.assertEqual(handler(dictionary, None), "hello") 112 | 113 | def test_extract_from_event_calls_function_with_extra_kwargs_bool_true(self): 114 | path = "/a/b/c" 115 | dictionary = { 116 | "a": { 117 | "b": { 118 | "c": True 119 | } 120 | } 121 | } 122 | 123 | @extract_from_event([Parameter(path)]) 124 | def handler(event, context, c=None): # noqa 125 | return c 126 | 127 | self.assertFalse(handler(dictionary, None) is None) 128 | self.assertEqual(True, handler(dictionary, None)) 129 | 130 | def test_extract_from_event_calls_function_with_extra_kwargs_bool_false(self): 131 | path = "/a/b/c" 132 | dictionary = { 133 | "a": { 134 | "b": { 135 | "c": False 136 | } 137 | } 138 | } 139 | 140 | @extract_from_event([Parameter(path)]) 141 | def handler(event, context, c=None): # noqa 142 | return c 143 | 144 | self.assertFalse(handler(dictionary, None) is None) 145 | self.assertEqual(False, handler(dictionary, None)) 146 | 147 | def test_extract_from_context_calls_function_with_extra_kwargs(self): 148 | path = "/a/b/c" 149 | dictionary = { 150 | "a": { 151 | "b": { 152 | "c": "hello" 153 | } 154 | } 155 | } 156 | 157 | @extract_from_context([Parameter(path)]) 158 | def handler(event, context, c=None): # noqa 159 | return c 160 | 161 | self.assertEqual(handler(None, dictionary), "hello") 162 | 163 | def test_extract_returns_400_on_empty_path(self): 164 | path = None 165 | dictionary = { 166 | "a": { 167 | "b": { 168 | } 169 | } 170 | } 171 | 172 | @extract([Parameter(path, "event")]) 173 | def handler(event, context, c=None): # noqa 174 | return {} 175 | 176 | response = handler(dictionary, None) 177 | 178 | self.assertEqual(400, response["statusCode"]) 179 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 180 | 181 | def test_extract_returns_400_on_missing_mandatory_key(self): 182 | path = "/a/b/c" 183 | dictionary = { 184 | "a": { 185 | "b": { 186 | } 187 | } 188 | } 189 | 190 | @extract([Parameter(path, "event", validators=[Mandatory])]) 191 | def handler(event, context, c=None): # noqa 192 | return {} 193 | 194 | response = handler(dictionary, None) 195 | 196 | self.assertEqual(400, response["statusCode"]) 197 | self.assertEqual("{\"message\": [{\"c\": [\"Missing mandatory value\"]}]}", response["body"]) 198 | 199 | def test_can_add_name_to_parameter(self): 200 | path = "/a/b" 201 | dictionary = { 202 | "a": { 203 | "b": "hello" 204 | } 205 | } 206 | 207 | @extract([Parameter(path, "event", validators=[Mandatory], var_name="custom")]) 208 | def handler(event, context, custom=None): # noqa 209 | return custom 210 | 211 | response = handler(dictionary, None) 212 | 213 | self.assertEqual("hello", response) 214 | 215 | @patch("aws_lambda_decorators.decorators.LOGGER") 216 | def test_can_not_add_non_pythonic_var_name_to_parameter(self, mock_logger): 217 | path = "/a/b" 218 | dictionary = { 219 | "a": { 220 | "b": "hello" 221 | } 222 | } 223 | 224 | @extract_from_event([Parameter(path, validators=[Mandatory], var_name="with space")]) 225 | def handler(event, context): # noqa 226 | return {} 227 | 228 | response = handler(dictionary, None) 229 | 230 | self.assertEqual(400, response["statusCode"]) 231 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 232 | 233 | mock_logger.error.assert_called_once_with( 234 | "%s: %s in argument %s for path %s", 235 | "SyntaxError", 236 | "with space", 237 | "event", 238 | "/a/b") 239 | 240 | @patch("aws_lambda_decorators.decorators.LOGGER") 241 | def test_can_not_add_pythonic_keyword_as_name_to_parameter(self, mock_logger): 242 | path = "/a/b" 243 | dictionary = { 244 | "a": { 245 | "b": "hello" 246 | } 247 | } 248 | 249 | @extract_from_event([Parameter(path, validators=[Mandatory], var_name="class")]) 250 | def handler(event, context): # noqa 251 | return {} 252 | 253 | response = handler(dictionary, None) 254 | 255 | self.assertEqual(400, response["statusCode"]) 256 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 257 | 258 | mock_logger.error.assert_called_once_with( 259 | "%s: %s in argument %s for path %s", 260 | "SyntaxError", 261 | "class", 262 | "event", 263 | "/a/b") 264 | 265 | def test_extract_does_not_raise_an_error_on_missing_optional_key(self): 266 | path = "/a/b/c" 267 | dictionary = { 268 | "a": { 269 | "b": { 270 | } 271 | } 272 | } 273 | 274 | @extract([Parameter(path, "event")]) 275 | def handler(event, context, c=None): # noqa 276 | return {} 277 | 278 | response = handler(dictionary, None) 279 | 280 | self.assertEqual({}, response) 281 | 282 | @patch("aws_lambda_decorators.decorators.LOGGER") 283 | def test_extract_returns_400_on_invalid_regex_key(self, mock_logger): 284 | path = "/a/b/c" 285 | dictionary = { 286 | "a": { 287 | "b": { 288 | "c": "hello" 289 | } 290 | } 291 | } 292 | 293 | # Expect a number 294 | @extract([Parameter(path, "event", [RegexValidator(r"\d+")])]) 295 | def handler(event, context, c=None): # noqa 296 | return {} 297 | 298 | response = handler(dictionary, None) 299 | self.assertEqual(400, response["statusCode"]) 300 | self.assertEqual("{\"message\": [{\"c\": [\"\'hello\' does not conform to regular expression \'\\\\d+\'\"]}]}", 301 | response["body"]) 302 | 303 | mock_logger.error.assert_called_once_with( 304 | "Error validating parameters. Errors: %s", 305 | [{"c": ["'hello' does not conform to regular expression '\\d+'"]}] 306 | ) 307 | 308 | def test_extract_does_not_raise_an_error_on_valid_regex_key(self): 309 | path = "/a/b/c" 310 | dictionary = { 311 | "a": { 312 | "b": { 313 | "c": "2019" 314 | } 315 | } 316 | } 317 | 318 | # Expect a number 319 | @extract([Parameter(path, "event", [RegexValidator(r"\d+")])]) 320 | def handler(event, context, c=None): # noqa 321 | return {} 322 | 323 | response = handler(dictionary, None) 324 | 325 | self.assertEqual({}, response) 326 | 327 | @patch("aws_lambda_decorators.decorators.LOGGER") 328 | def test_validate_raises_an_error_on_invalid_variables(self, mock_logger): 329 | @validate([ 330 | ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")]), 331 | ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")]) 332 | ]) 333 | def handler(var1=None, var2=None): # noqa: pylint - unused-argument 334 | return {} 335 | 336 | response = handler("2019", "abcd") 337 | 338 | self.assertEqual(400, response["statusCode"]) 339 | self.assertEqual( 340 | "{\"message\": [{\"var2\": [\"\'abcd\' does not conform to regular expression \'\\\\d+\'\"]}]}", 341 | response["body"] 342 | ) 343 | 344 | mock_logger.error.assert_called_once_with( 345 | "Error validating parameters. Errors: %s", 346 | [{"var2": ["\'abcd\' does not conform to regular expression \'\\d+\'"]}] 347 | ) 348 | 349 | @patch("aws_lambda_decorators.decorators.LOGGER") 350 | def test_validate_raises_multiple_errors_on_exit_on_error_false(self, mock_logger): 351 | @validate([ 352 | ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")]), 353 | ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")]) 354 | ], True) 355 | def handler(var1=None, var2=None): # noqa: pylint - unused-argument 356 | return {} 357 | 358 | response = handler("20wq19", "abcd") 359 | 360 | self.assertEqual(400, response["statusCode"]) 361 | self.assertEqual( 362 | "{\"message\": [{\"var1\": [\"\'20wq19\' does not conform to regular expression \'\\\\d+\'\"]}, " 363 | "{\"var2\": [\"\'abcd\' does not conform to regular expression \'\\\\d+\'\"]}]}", 364 | response["body"]) 365 | 366 | mock_logger.error.assert_called_once_with( 367 | "Error validating parameters. Errors: %s", 368 | [ 369 | {"var1": ["'20wq19' does not conform to regular expression '\\d+'"]}, 370 | {"var2": ["'abcd' does not conform to regular expression '\\d+'"]} 371 | ] 372 | ) 373 | 374 | @patch("aws_lambda_decorators.decorators.LOGGER") 375 | def test_can_not_validate_non_pythonic_var_name(self, mock_logger): 376 | @validate([ 377 | ValidatedParameter(func_param_name="var 1", validators=[RegexValidator(r"\d+")]), 378 | ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")]) 379 | ], True) 380 | def handler(var1=None, var2=None): # noqa: pylint - unused-argument 381 | return {} 382 | 383 | response = handler("20wq19", "abcd") 384 | 385 | self.assertEqual(400, response["statusCode"]) 386 | self.assertEqual( 387 | "{\"message\": \"Error extracting parameters\"}", 388 | response["body"]) 389 | 390 | mock_logger.error.assert_called_once_with("%s: %s in argument %s", "KeyError", "'var 1'", "var 1") 391 | 392 | def test_validate_does_not_raise_an_error_on_valid_variables(self): 393 | @validate([ 394 | ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")]), 395 | ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"[ab]+")]) 396 | ]) 397 | def handler(var1, var2=None): # noqa: pylint - unused-argument 398 | return {} 399 | 400 | response = handler("2019", var2="abba") 401 | self.assertEqual({}, response) 402 | 403 | def test_extract_returns_400_on_type_error(self): 404 | path = "/a/b[json]/c" 405 | dictionary = { 406 | "a": { 407 | "b": { 408 | "c": "hello" 409 | } 410 | } 411 | } 412 | 413 | @extract([Parameter(path)]) 414 | def handler(event, context, c=None): # noqa 415 | return {} 416 | 417 | response = handler(dictionary, None) 418 | 419 | self.assertEqual(400, response["statusCode"]) 420 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 421 | 422 | @patch("aws_lambda_decorators.decorators.LOGGER") 423 | def test_exception_handler_raises_exception(self, mock_logger): 424 | 425 | @handle_exceptions(handlers=[ExceptionHandler(KeyError, "msg")]) 426 | def handler(): 427 | raise KeyError("blank") 428 | 429 | response = handler() # noqa 430 | 431 | self.assertEqual(400, response["statusCode"]) 432 | self.assertTrue("msg" in response["body"]) 433 | 434 | mock_logger.error.assert_called_once_with("%s: %s", "msg", "'blank'") 435 | 436 | @patch("aws_lambda_decorators.decorators.LOGGER") 437 | def test_exception_handler_raises_exception_without_friendly_message(self, mock_logger): 438 | 439 | @handle_exceptions(handlers=[ExceptionHandler(KeyError)]) 440 | def handler(): 441 | raise KeyError("blank") 442 | 443 | response = handler() # noqa 444 | 445 | self.assertEqual(400, response["statusCode"]) 446 | self.assertTrue("blank" in response["body"]) 447 | 448 | mock_logger.error.assert_called_once_with("'blank'") 449 | 450 | @patch("aws_lambda_decorators.decorators.LOGGER") 451 | def test_exception_handler_raises_exception_with_status_code(self, mock_logger): 452 | 453 | @handle_exceptions(handlers=[ExceptionHandler(KeyError, "error", 500)]) 454 | def handler(): 455 | raise KeyError("blank") 456 | 457 | response = handler() # noqa 458 | 459 | self.assertEqual(500, response["statusCode"]) 460 | self.assertEqual("""{"message": "error"}""", response["body"]) 461 | 462 | mock_logger.error.assert_called_once_with("%s: %s", "error", "'blank'") 463 | 464 | @patch("aws_lambda_decorators.decorators.LOGGER") 465 | def test_exception_handler_raises_exception_with_inherited_exception(self, mock_logger): 466 | 467 | @handle_exceptions(handlers=[ExceptionHandler(Exception)]) 468 | def handler(): 469 | raise KeyError("blank") 470 | 471 | response = handler() 472 | 473 | self.assertEqual(400, response["statusCode"]) 474 | self.assertTrue("blank" in response["body"]) 475 | 476 | mock_logger.error.assert_called_once_with("'blank'") 477 | 478 | @patch("aws_lambda_decorators.decorators.LOGGER") 479 | def test_log_decorator_can_log_params(self, mock_logger): # noqa: pylint - no-self-use 480 | 481 | @log(True, False) 482 | def handler(event, context, an_other): # noqa 483 | return {} 484 | 485 | handler("first", "{\"tests\": \"a\"}", "another") 486 | 487 | mock_logger.info.assert_called_once_with( 488 | "Function: %s, Parameters: %s", "handler", ("first", "{\"tests\": \"a\"}", "another")) 489 | 490 | @patch("aws_lambda_decorators.decorators.LOGGER") 491 | def test_log_decorator_can_log_response(self, mock_logger): # noqa: pylint - no-self-use 492 | 493 | @log(False, True) 494 | def handler(): 495 | return {"statusCode": 201} 496 | 497 | handler() 498 | 499 | mock_logger.info.assert_called_once_with("Function: %s, Response: %s", "handler", {"statusCode": 201}) 500 | 501 | @patch("boto3.client") 502 | def test_get_valid_ssm_parameter(self, mock_boto_client): 503 | mock_ssm = MagicMock() 504 | mock_ssm.get_parameters.return_value = { 505 | "Parameters": [ 506 | { 507 | "Name": "key", 508 | "Value": "tests" 509 | } 510 | ] 511 | } 512 | mock_boto_client.return_value = mock_ssm 513 | 514 | @extract_from_ssm([SSMParameter("key")]) 515 | def handler(key=None): 516 | return key 517 | 518 | self.assertEqual(handler(), "tests") 519 | 520 | @patch("boto3.client") 521 | def test_get_valid_ssm_parameter_custom_name(self, mock_boto_client): 522 | mock_ssm = MagicMock() 523 | mock_ssm.get_parameters.return_value = { 524 | "Parameters": [ 525 | { 526 | "Name": "key", 527 | "Value": "tests" 528 | } 529 | ] 530 | } 531 | mock_boto_client.return_value = mock_ssm 532 | 533 | @extract_from_ssm([SSMParameter("key", "custom")]) 534 | def handler(custom=None): 535 | return custom 536 | 537 | self.assertEqual(handler(), "tests") 538 | 539 | @patch("boto3.client") 540 | def test_get_valid_ssm_parameters(self, mock_boto_client): 541 | mock_ssm = MagicMock() 542 | mock_ssm.get_parameters.return_value = { 543 | "Parameters": [ 544 | { 545 | "Name": "key2", 546 | "Value": "test2" 547 | }, 548 | { 549 | "Name": "key1", 550 | "Value": "test1" 551 | } 552 | ] 553 | } 554 | mock_boto_client.return_value = mock_ssm 555 | 556 | @extract_from_ssm([SSMParameter("key1", "key1"), SSMParameter("key2", "key2")]) 557 | def handler(key1=None, key2=None): 558 | return [key1, key2] 559 | 560 | self.assertEqual(handler(), ["test1", "test2"]) 561 | 562 | @patch("boto3.client") 563 | def test_get_ssm_parameter_missing_parameter_raises_client_error(self, mock_boto_client): 564 | mock_ssm = MagicMock() 565 | mock_ssm.get_parameters.side_effect = ClientError({}, "") 566 | mock_boto_client.return_value = mock_ssm 567 | 568 | @extract_from_ssm([SSMParameter("")]) 569 | def handler(key=None): 570 | return key 571 | 572 | with self.assertRaises(ClientError): 573 | handler() 574 | 575 | @patch("boto3.client") 576 | def test_get_ssm_parameter_empty_key_container_raises_key_error(self, mock_boto_client): 577 | mock_ssm = MagicMock() 578 | mock_ssm.get_parameters.return_value = { 579 | } 580 | mock_boto_client.return_value = mock_ssm 581 | 582 | @extract_from_ssm([SSMParameter("")]) 583 | def handler(key=None): 584 | return key 585 | 586 | with self.assertRaises(KeyError): 587 | handler() 588 | 589 | def test_body_gets_dumped_as_json(self): 590 | 591 | @response_body_as_json 592 | def handler(): 593 | return {"statusCode": 200, "body": {"a": "b"}} 594 | 595 | response = handler() 596 | 597 | self.assertEqual(response, {"statusCode": 200, "body": "{\"a\": \"b\"}"}) 598 | 599 | def test_body_dump_raises_exception_on_invalid_json(self): 600 | 601 | @response_body_as_json 602 | def handler(): 603 | return {"statusCode": 200, "body": {"a"}} 604 | 605 | response = handler() 606 | 607 | self.assertEqual( 608 | response, 609 | {"statusCode": 500, "body": "{\"message\": \"Response body is not JSON serializable\"}"}) 610 | 611 | def test_response_as_json_invalid_application_does_nothing(self): 612 | 613 | @response_body_as_json 614 | def handler(): 615 | return {"statusCode": 200} 616 | 617 | response = handler() 618 | 619 | self.assertEqual(response, {"statusCode": 200}) 620 | 621 | @patch("aws_lambda_decorators.decorators.LOGGER") 622 | def test_handle_all_exceptions(self, mock_logger): 623 | 624 | @handle_all_exceptions() 625 | def handler(): 626 | raise KeyError("blank") 627 | 628 | response = handler() # noqa 629 | 630 | self.assertEqual(400, response["statusCode"]) 631 | self.assertTrue("blank" in response["body"]) 632 | 633 | mock_logger.error.assert_called_once_with("'blank'") 634 | 635 | def test_cors_no_headers_in_response(self): 636 | 637 | @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12) 638 | def handler(): 639 | return {} 640 | 641 | response = handler() 642 | 643 | self.assertEqual(response["headers"]["access-control-allow-headers"], "Content-Type") 644 | self.assertEqual(response["headers"]["access-control-allow-methods"], "POST") 645 | self.assertEqual(response["headers"]["access-control-allow-origin"], "*") 646 | self.assertEqual(response["headers"]["access-control-max-age"], 12) 647 | 648 | def test_cors_adds_correct_headers_only(self): 649 | 650 | @cors(allow_origin="*") 651 | def handler(): 652 | return {} 653 | 654 | response = handler() 655 | 656 | self.assertEqual(response["headers"]["access-control-allow-origin"], "*") 657 | self.assertTrue("access-control-allow-methods" not in response["headers"]) 658 | self.assertTrue("access-control-allow-headers" not in response["headers"]) 659 | self.assertTrue("access-control-max-age" not in response["headers"]) 660 | 661 | def test_cors_with_headers_in_response(self): 662 | 663 | @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12) 664 | def handler(): 665 | return { 666 | "headers": { 667 | "content-type": "application/json", 668 | "access-control-allow-origin": "http://example.com" 669 | } 670 | } 671 | 672 | response = handler() 673 | 674 | self.assertEqual(response["headers"]["access-control-allow-headers"], "Content-Type") 675 | self.assertEqual(response["headers"]["access-control-allow-methods"], "POST") 676 | self.assertEqual(response["headers"]["access-control-allow-origin"], "http://example.com,*") 677 | self.assertEqual(response["headers"]["access-control-max-age"], 12) 678 | 679 | def test_cors_with_headers_a_none_value_does_not_remove_headers(self): 680 | 681 | @cors(allow_origin=None) 682 | def handler(): 683 | return { 684 | "headers": { 685 | "access-control-allow-origin": "http://example.com" 686 | } 687 | } 688 | 689 | response = handler() 690 | 691 | self.assertEqual(response["headers"]["access-control-allow-origin"], "http://example.com") 692 | self.assertTrue("access-control-allow-methods" not in response["headers"]) 693 | self.assertTrue("access-control-allow-headers" not in response["headers"]) 694 | self.assertTrue("access-control-max-age" not in response["headers"]) 695 | 696 | def test_cors_with_headers_an_empty_value_does_not_remove_headers(self): 697 | 698 | @cors(allow_origin="") 699 | def handler(): 700 | return { 701 | "headers": { 702 | "access-control-allow-origin": "http://example.com" 703 | } 704 | } 705 | 706 | response = handler() 707 | 708 | self.assertEqual(response["headers"]["access-control-allow-origin"], "http://example.com") 709 | 710 | def test_cors_with_uppercase_headers_in_response(self): 711 | 712 | @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12) 713 | def handler(): 714 | return { 715 | "Headers": { 716 | "content-type": "application/json", 717 | "Access-Control-Allow-Origin": "http://example.com" 718 | } 719 | } 720 | 721 | response = handler() 722 | 723 | self.assertEqual(response["Headers"]["access-control-allow-headers"], "Content-Type") 724 | self.assertEqual(response["Headers"]["access-control-allow-methods"], "POST") 725 | self.assertEqual(response["Headers"]["Access-Control-Allow-Origin"], "http://example.com,*") 726 | self.assertEqual(response["Headers"]["access-control-max-age"], 12) 727 | 728 | @patch("aws_lambda_decorators.decorators.LOGGER") 729 | def test_cors_invalid_max_age_logs_error(self, mock_logger): 730 | 731 | @cors(max_age="12") 732 | def handler(): 733 | return {} 734 | 735 | response = handler() 736 | 737 | self.assertEqual(response["statusCode"], 500) 738 | self.assertEqual(response["body"], "{\"message\": \"Invalid value type in CORS header\"}") 739 | 740 | mock_logger.error.assert_called_once_with("Cannot set %s header to a non %s value", 741 | "access-control-max-age", 742 | int) 743 | 744 | @patch("aws_lambda_decorators.decorators.LOGGER") 745 | def test_cors_cannot_decorate_non_dict(self, mock_logger): 746 | 747 | @cors(allow_origin="*") 748 | def handler(): 749 | return "I am a string" 750 | 751 | response = handler() 752 | 753 | self.assertEqual(response["statusCode"], 500) # noqa: pylint-invalid-sequence-index 754 | self.assertEqual(response["body"], "{\"message\": \"Invalid response type for CORS headers\"}") # noqa: pylint-invalid-sequence-index 755 | 756 | mock_logger.error.assert_called_once_with("Cannot add headers to a non dictionary response") 757 | 758 | def test_extract_returns_400_on_invalid_dictionary_schema(self): 759 | path = "/a" 760 | dictionary = { 761 | "a": { 762 | "b": { 763 | "c": 3 764 | } 765 | } 766 | } 767 | 768 | schema = Schema( 769 | { 770 | "b": And(dict, { 771 | "c": str 772 | }) 773 | } 774 | ) 775 | 776 | @extract([Parameter(path, "event", validators=[SchemaValidator(schema)])]) 777 | def handler(event, context, c=None): # noqa 778 | return {} 779 | 780 | response = handler(dictionary, None) 781 | 782 | self.assertEqual(400, response["statusCode"]) 783 | self.assertEqual( 784 | "{\"message\": [{\"a\": [\"\'{\'b\': {\'c\': 3}}\' " 785 | "does not validate against schema " 786 | "\'Schema({\'b\': And(, {\'c\': })})\'\"]}]}", 787 | response["body"]) 788 | 789 | def test_extract_valid_dictionary_schema(self): 790 | path = "/a" 791 | dictionary = { 792 | "a": { 793 | "b": { 794 | "c": "d" 795 | } 796 | } 797 | } 798 | 799 | schema = Schema( 800 | { 801 | "b": And(dict, { 802 | "c": str 803 | }), 804 | Optional("j"): str 805 | } 806 | ) 807 | 808 | @extract([Parameter(path, "event", validators=[SchemaValidator(schema)])]) 809 | def handler(event, context, a=None): # noqa 810 | return a 811 | 812 | response = handler(dictionary, None) 813 | 814 | expected = { 815 | "b": { 816 | "c": "d" 817 | } 818 | } 819 | self.assertEqual(expected, response) 820 | 821 | def test_extract_schema_when_property_is_none(self): 822 | path = "/a/b" 823 | dictionary = { 824 | "a": {} 825 | } 826 | 827 | schema = Schema( 828 | { 829 | "b": And(dict, { 830 | "c": str 831 | }), 832 | Optional("j"): str 833 | } 834 | ) 835 | 836 | @extract([Parameter(path, "event", validators=[SchemaValidator(schema)])]) 837 | def handler(event, context, b=None): # noqa 838 | return b 839 | 840 | response = handler(dictionary, None) 841 | 842 | self.assertEqual(None, response) 843 | 844 | def test_extract_parameter_with_minimum(self): 845 | event = { 846 | "value": 20 847 | } 848 | 849 | @extract([Parameter("/value", "event", validators=[Minimum(10.0)])]) 850 | def handler(event, value=None): # noqa: pylint - unused-argument 851 | return {} 852 | 853 | response = handler(event) 854 | self.assertEqual({}, response) 855 | 856 | def test_error_extracting_parameter_with_minimum(self): 857 | event = { 858 | "value": 5 859 | } 860 | 861 | @extract([Parameter("/value", "event", validators=[Minimum(10.0)])]) 862 | def handler(event, value=None): # noqa: pylint - unused-argument 863 | return {} 864 | 865 | response = handler(event) 866 | self.assertEqual(response["statusCode"], 400) 867 | self.assertEqual( 868 | "{\"message\": [{\"value\": [\"\'5\' is less than minimum value \'10.0\'\"]}]}", 869 | response["body"]) 870 | 871 | def test_error_extracting_non_numeric_parameter_with_minimum(self): 872 | event = { 873 | "value": "20" 874 | } 875 | 876 | @extract([Parameter("/value", "event", validators=[Minimum(10.0)])]) 877 | def handler(event, value=None): # noqa: pylint - unused-argument 878 | return {} 879 | 880 | response = handler(event) 881 | self.assertEqual(response["statusCode"], 400) 882 | self.assertEqual( 883 | "{\"message\": [{\"value\": [\"\'20\' is less than minimum value \'10.0\'\"]}]}", response["body"]) 884 | 885 | def test_extract_optional_null_parameter_with_minimum(self): 886 | event = { 887 | } 888 | 889 | @extract([Parameter("/value", "event", validators=[Minimum(10.0)])]) 890 | def handler(event, value=None): # noqa: pylint - unused-argument 891 | return {} 892 | 893 | response = handler(event) 894 | self.assertEqual({}, response) 895 | 896 | def test_extract_mandatory_parameter_with_minimum(self): 897 | event = { 898 | "value": 20 899 | } 900 | 901 | @extract([Parameter("/value", "event", validators=[Minimum(10.0), Mandatory])]) 902 | def handler(event, value=None): # noqa: pylint - unused-argument 903 | return {} 904 | 905 | response = handler(event) 906 | self.assertEqual({}, response) 907 | 908 | def test_extract_parameter_with_maximum(self): 909 | event = { 910 | "value": 20 911 | } 912 | 913 | @extract([Parameter("/value", "event", validators=[Maximum(100.0)])]) 914 | def handler(event, value=None): # noqa: pylint - unused-argument 915 | return {} 916 | 917 | response = handler(event) 918 | self.assertEqual({}, response) 919 | 920 | def test_error_extracting_parameter_with_maximum(self): 921 | event = { 922 | "value": 105 923 | } 924 | 925 | @extract([Parameter("/value", "event", validators=[Maximum(100.0)])]) 926 | def handler(event, value=None): # noqa: pylint - unused-argument 927 | return {} 928 | 929 | response = handler(event) 930 | self.assertEqual(response["statusCode"], 400) 931 | self.assertEqual( 932 | "{\"message\": [{\"value\": [\"\'105\' is greater than maximum value \'100.0\'\"]}]}", response["body"]) 933 | 934 | def test_error_extracting_non_numeric_parameter_with_maximum(self): 935 | event = { 936 | "value": "20" 937 | } 938 | 939 | @extract([Parameter("/value", "event", validators=[Maximum(100.0)])]) 940 | def handler(event, value=None): # noqa: pylint - unused-argument 941 | return {} 942 | 943 | response = handler(event) 944 | self.assertEqual(response["statusCode"], 400) 945 | self.assertEqual( 946 | "{\"message\": [{\"value\": [\"\'20\' is greater than maximum value \'100.0\'\"]}]}", response["body"]) 947 | 948 | def test_extract_optional_null_parameter_with_maximum(self): 949 | event = { 950 | } 951 | 952 | @extract([Parameter("/value", "event", validators=[Maximum(10.0)])]) 953 | def handler(event, value=None): # noqa: pylint - unused-argument 954 | return {} 955 | 956 | response = handler(event) 957 | self.assertEqual({}, response) 958 | 959 | def test_extract_mandatory_parameter_with_maximum(self): 960 | event = { 961 | "value": 20 962 | } 963 | 964 | @extract([Parameter("/value", "event", validators=[Maximum(100.0), Mandatory])]) 965 | def handler(event, value=None): # noqa: pylint - unused-argument 966 | return {} 967 | 968 | response = handler(event) 969 | self.assertEqual({}, response) 970 | 971 | def test_extract_mandatory_parameter_with_range(self): 972 | event = { 973 | "value": 20 974 | } 975 | 976 | @extract([Parameter("/value", "event", validators=[Minimum(10.0), Maximum(100.0), Mandatory])]) 977 | def handler(event, value=None): # noqa: pylint - unused-argument 978 | return {} 979 | 980 | response = handler(event) 981 | self.assertEqual({}, response) 982 | 983 | def test_extract_parameter_with_maximum_length(self): 984 | event = { 985 | "value": "correct" 986 | } 987 | 988 | @extract([Parameter("/value", "event", validators=[MaxLength(20)])]) 989 | def handler(event, value=None): # noqa: pylint - unused-argument 990 | return {} 991 | 992 | response = handler(event) 993 | self.assertEqual({}, response) 994 | 995 | def test_error_extracting_parameter_with_max_length(self): 996 | event = { 997 | "value": "too long" 998 | } 999 | 1000 | @extract([Parameter("/value", "event", validators=[MaxLength(5)])]) 1001 | def handler(event, value=None): # noqa: pylint - unused-argument 1002 | return {} 1003 | 1004 | response = handler(event) 1005 | self.assertEqual(response["statusCode"], 400) 1006 | self.assertEqual( 1007 | "{\"message\": [{\"value\": [\"\'too long\' is longer than maximum length \'5\'\"]}]}", 1008 | response["body"]) 1009 | 1010 | def test_values_are_stringified_in_max_length_validator(self): 1011 | event = { 1012 | "value": 20 1013 | } 1014 | 1015 | @extract([Parameter("/value", "event", validators=[MaxLength(5)])]) 1016 | def handler(event, value=None): # noqa: pylint - unused-argument 1017 | return {} 1018 | 1019 | response = handler(event) 1020 | self.assertEqual({}, response) 1021 | 1022 | def test_extract_optional_null_parameter_with_max_length(self): 1023 | event = { 1024 | } 1025 | 1026 | @extract([Parameter("/value", "event", validators=[MaxLength(5)])]) 1027 | def handler(event, value=None): # noqa: pylint - unused-argument 1028 | return {} 1029 | 1030 | response = handler(event) 1031 | self.assertEqual({}, response) 1032 | 1033 | def test_extract_mandatory_parameter_with_max_length(self): 1034 | event = { 1035 | "value": "aa" 1036 | } 1037 | 1038 | @extract([Parameter("/value", "event", validators=[MaxLength(5), Mandatory])]) 1039 | def handler(event, value=None): # noqa: pylint - unused-argument 1040 | return {} 1041 | 1042 | response = handler(event) 1043 | self.assertEqual({}, response) 1044 | 1045 | def test_extract_parameter_with_mainimum_length(self): 1046 | event = { 1047 | "value": "correct" 1048 | } 1049 | 1050 | @extract([Parameter("/value", "event", validators=[MinLength(4)])]) 1051 | def handler(event, value=None): # noqa: pylint - unused-argument 1052 | return {} 1053 | 1054 | response = handler(event) 1055 | self.assertEqual({}, response) 1056 | 1057 | def test_error_extracting_parameter_with_min_length(self): 1058 | event = { 1059 | "value": "too short" 1060 | } 1061 | 1062 | @extract([Parameter("/value", "event", validators=[MinLength(15)])]) 1063 | def handler(event, value=None): # noqa: pylint - unused-argument 1064 | return {} 1065 | 1066 | response = handler(event) 1067 | self.assertEqual(response["statusCode"], 400) 1068 | self.assertEqual( 1069 | "{\"message\": [{\"value\": [\"\'too short\' is shorter than minimum length \'15\'\"]}]}", 1070 | response["body"]) 1071 | 1072 | def test_values_are_stringified_in_min_length_validator(self): 1073 | event = { 1074 | "value": 20 1075 | } 1076 | 1077 | @extract([Parameter("/value", "event", validators=[MinLength(1)])]) 1078 | def handler(event, value=None): # noqa: pylint - unused-argument 1079 | return {} 1080 | 1081 | response = handler(event) 1082 | self.assertEqual({}, response) 1083 | 1084 | def test_extract_optional_null_parameter_with_min_length(self): 1085 | event = { 1086 | } 1087 | 1088 | @extract([Parameter("/value", "event", validators=[MinLength(5)])]) 1089 | def handler(event, value=None): # noqa: pylint - unused-argument 1090 | return {} 1091 | 1092 | response = handler(event) 1093 | self.assertEqual({}, response) 1094 | 1095 | def test_extract_mandatory_parameter_with_min_length(self): 1096 | event = { 1097 | "value": "aa" 1098 | } 1099 | 1100 | @extract([Parameter("/value", "event", validators=[MaxLength(2), Mandatory])]) 1101 | def handler(event, value=None): # noqa: pylint - unused-argument 1102 | return {} 1103 | 1104 | response = handler(event) 1105 | self.assertEqual({}, response) 1106 | 1107 | def test_extract_mandatory_parameter_with_length_range(self): 1108 | event = { 1109 | "value": "right in the middle" 1110 | } 1111 | 1112 | @extract([Parameter("/value", "event", validators=[MinLength(10), MaxLength(100), Mandatory])]) 1113 | def handler(event, value=None): # noqa: pylint - unused-argument 1114 | return {} 1115 | 1116 | response = handler(event) 1117 | self.assertEqual({}, response) 1118 | 1119 | @patch("aws_lambda_decorators.decorators.LOGGER") 1120 | def test_exit_on_error_false_bundles_all_errors(self, mock_logger): 1121 | path_1 = "/a/b/c" 1122 | path_2 = "/a/b/d" 1123 | path_3 = "/a/b/e" 1124 | path_4 = "/a/b/f" 1125 | path_5 = "/a/b/g" 1126 | dictionary = { 1127 | "a": { 1128 | "b": { 1129 | "e": 23, 1130 | "f": 15, 1131 | "g": "a" 1132 | } 1133 | } 1134 | } 1135 | 1136 | schema = Schema( 1137 | { 1138 | "g": int 1139 | } 1140 | ) 1141 | 1142 | @extract([ 1143 | Parameter(path_1, "event", validators=[Mandatory], var_name="c"), 1144 | Parameter(path_2, "event", validators=[Mandatory]), 1145 | Parameter(path_3, "event", validators=[Minimum(30)]), 1146 | Parameter(path_4, "event", validators=[Maximum(10)]), 1147 | Parameter(path_5, "event", validators=[ 1148 | RegexValidator(r"[0-9]+"), 1149 | RegexValidator(r"[1][0-9]+"), 1150 | SchemaValidator(schema), 1151 | MinLength(2), 1152 | MaxLength(0) 1153 | ]) 1154 | ], True) 1155 | def handler(event, context, c=None, d=None): # noqa: pylint - unused-argument 1156 | return {} 1157 | 1158 | response = handler(dictionary, None) 1159 | self.assertEqual(400, response["statusCode"]) 1160 | self.assertEqual( 1161 | "{\"message\": [{\"c\": [\"Missing mandatory value\"]}, " 1162 | "{\"d\": [\"Missing mandatory value\"]}, " 1163 | "{\"e\": [\"\'23\' is less than minimum value \'30\'\"]}, " 1164 | "{\"f\": [\"\'15\' is greater than maximum value \'10\'\"]}, " 1165 | "{\"g\": [\"\'a\' does not conform to regular expression \'[0-9]+\'\", " 1166 | "\"\'a\' does not conform to regular expression \'[1][0-9]+\'\", " 1167 | "\"\'a\' does not validate against schema \'Schema({\'g\': })\'\", " 1168 | "\"\'a\' is shorter than minimum length \'2\'\", " 1169 | "\"\'a\' is longer than maximum length \'0\'\"" 1170 | "]}]}", 1171 | response["body"]) 1172 | 1173 | mock_logger.error.assert_called_once_with( 1174 | "Error validating parameters. Errors: %s", 1175 | [ 1176 | {"c": ["Missing mandatory value"]}, 1177 | {"d": ["Missing mandatory value"]}, 1178 | {"e": ["'23' is less than minimum value '30'"]}, 1179 | {"f": ["'15' is greater than maximum value '10'"]}, 1180 | {"g": [ 1181 | "'a' does not conform to regular expression '[0-9]+'", 1182 | "'a' does not conform to regular expression '[1][0-9]+'", 1183 | "'a' does not validate against schema 'Schema({'g': })'", 1184 | "'a' is shorter than minimum length '2'", 1185 | "'a' is longer than maximum length '0'" 1186 | ]} 1187 | ] 1188 | ) 1189 | 1190 | def test_group_errors_true_returns_ok(self): 1191 | path = "/a/b" 1192 | dictionary = { 1193 | "a": { 1194 | "b": "hello" 1195 | } 1196 | } 1197 | 1198 | @extract([Parameter(path, "event", validators=[Mandatory])], True) 1199 | def handler(event, context, b=None): # noqa 1200 | return b 1201 | 1202 | response = handler(dictionary, None) 1203 | 1204 | self.assertEqual("hello", response) 1205 | 1206 | def test_mandatory_parameter_with_default_returns_error_on_empty(self): 1207 | event = { 1208 | "var": "" 1209 | } 1210 | 1211 | @extract([ 1212 | Parameter("/var", "event", validators=[Mandatory], default="hello") 1213 | ]) 1214 | def handler(event, context, var=None): # noqa: pylint - unused-argument 1215 | return {} 1216 | 1217 | response = handler(event, None) 1218 | 1219 | self.assertEqual(response["statusCode"], 400) 1220 | self.assertEqual("{\"message\": [{\"var\": [\"Missing mandatory value\"]}]}", response["body"]) 1221 | 1222 | def test_group_errors_true_on_extract_from_event_returns_ok(self): 1223 | path = "/a/b" 1224 | dictionary = { 1225 | "a": { 1226 | "b": "hello" 1227 | } 1228 | } 1229 | 1230 | @extract_from_event([Parameter(path, validators=[Mandatory])], True) 1231 | def handler(event, context, b=None): # noqa 1232 | return b 1233 | 1234 | response = handler(dictionary, None) 1235 | 1236 | self.assertEqual("hello", response) 1237 | 1238 | def test_group_errors_true_on_extract_from_context_returns_ok(self): 1239 | path = "/a/b" 1240 | dictionary = { 1241 | "a": { 1242 | "b": "hello" 1243 | } 1244 | } 1245 | 1246 | @extract_from_context([Parameter(path, validators=[Mandatory])], True) 1247 | def handler(event, context, b=None): # noqa 1248 | return b 1249 | 1250 | response = handler(None, dictionary) 1251 | 1252 | self.assertEqual("hello", response) 1253 | 1254 | @patch("aws_lambda_decorators.decorators.LOGGER") 1255 | def test_can_output_custom_error_message_on_validation_failure(self, mock_logger): 1256 | path_1 = "/a/b/c" 1257 | path_2 = "/a/b/d" 1258 | path_3 = "/a/b/e" 1259 | path_4 = "/a/b/f" 1260 | path_5 = "/a/b/g" 1261 | dictionary = { 1262 | "a": { 1263 | "b": { 1264 | "e": 23, 1265 | "f": 15, 1266 | "g": "a" 1267 | } 1268 | } 1269 | } 1270 | 1271 | schema = Schema( 1272 | { 1273 | "g": int 1274 | } 1275 | ) 1276 | 1277 | @extract([ 1278 | Parameter(path_1, "event", validators=[Mandatory("Missing c")], var_name="c"), 1279 | Parameter(path_2, "event", validators=[Mandatory("Missing d")]), 1280 | Parameter(path_3, "event", validators=[Minimum(30, "Bad e value {value}, should be at least {condition}")]), 1281 | Parameter(path_4, "event", validators=[Maximum(10, "Bad f")]), 1282 | Parameter(path_5, "event", validators=[ 1283 | RegexValidator(r"[0-9]+", "Bad g regex 1"), 1284 | RegexValidator(r"[1][0-9]+", "Bad g regex 2"), 1285 | SchemaValidator(schema, "Bad g schema"), 1286 | MinLength(2, "Bad g min length"), 1287 | MaxLength(0, "Bad g max length") 1288 | ]) 1289 | ], True) 1290 | def handler(event, context, c=None, d=None): # noqa: pylint - unused-argument 1291 | return {} 1292 | 1293 | response = handler(dictionary, None) 1294 | 1295 | self.assertEqual(400, response["statusCode"]) 1296 | self.assertEqual( 1297 | "{\"message\": [{\"c\": [\"Missing c\"]}, " 1298 | "{\"d\": [\"Missing d\"]}, " 1299 | "{\"e\": [\"Bad e value 23, should be at least 30\"]}, " 1300 | "{\"f\": [\"Bad f\"]}, " 1301 | "{\"g\": [\"Bad g regex 1\", " 1302 | "\"Bad g regex 2\", " 1303 | "\"Bad g schema\", " 1304 | "\"Bad g min length\", " 1305 | "\"Bad g max length\"" 1306 | "]}]}", 1307 | response["body"]) 1308 | 1309 | mock_logger.error.assert_called_once_with( 1310 | "Error validating parameters. Errors: %s", 1311 | [ 1312 | {"c": ["Missing c"]}, 1313 | {"d": ["Missing d"]}, 1314 | {"e": ["Bad e value 23, should be at least 30"]}, 1315 | {"f": ["Bad f"]}, 1316 | {"g": [ 1317 | "Bad g regex 1", 1318 | "Bad g regex 2", 1319 | "Bad g schema", 1320 | "Bad g min length", 1321 | "Bad g max length" 1322 | ]} 1323 | ] 1324 | ) 1325 | 1326 | def test_extract_returns_400_on_missing_mandatory_key_with_regex(self): 1327 | path = "/a/b/c" 1328 | dictionary = { 1329 | "a": { 1330 | "b": { 1331 | } 1332 | } 1333 | } 1334 | 1335 | @extract([Parameter(path, "event", validators=[Mandatory, RegexValidator("[0-9]+")])], group_errors=True) 1336 | def handler(event, context, c=None): # noqa 1337 | return {} 1338 | 1339 | response = handler(dictionary, None) 1340 | 1341 | self.assertEqual(400, response["statusCode"]) 1342 | self.assertEqual("{\"message\": [{\"c\": [\"Missing mandatory value\"]}]}", response["body"]) 1343 | 1344 | def test_extract_nulls_are_returned(self): 1345 | path = "/a/b" 1346 | dictionary = { 1347 | "a": { 1348 | } 1349 | } 1350 | 1351 | @extract([Parameter(path, "event", default=None)], allow_none_defaults=True) 1352 | def handler(event, context, **kwargs): # noqa 1353 | return kwargs["b"] 1354 | 1355 | response = handler(dictionary, None) 1356 | 1357 | self.assertEqual(None, response) 1358 | 1359 | def test_extract_nulls_raises_exception_when_extracted_from_kwargs_if_allow_none_defaults_is_false(self): 1360 | path = "/a/b" 1361 | dictionary = { 1362 | "a": { 1363 | } 1364 | } 1365 | 1366 | @extract([Parameter(path, "event", default=None)], allow_none_defaults=False) 1367 | def handler(event, context, **kwargs): # noqa 1368 | return kwargs["b"] 1369 | 1370 | with self.assertRaises(KeyError): 1371 | handler(dictionary, None) 1372 | 1373 | def test_extract_nulls_preserve_signature_defaults(self): 1374 | path = "/a/b" 1375 | dictionary = { 1376 | "a": { 1377 | } 1378 | } 1379 | 1380 | @extract([Parameter(path, "event")]) 1381 | def handler(event, context, b="Hello"): # noqa 1382 | return b 1383 | 1384 | response = handler(dictionary, None) 1385 | 1386 | self.assertEqual("Hello", response) 1387 | 1388 | def test_extract_nulls_default_on_decorator_takes_precedence(self): 1389 | path = "/a/b" 1390 | dictionary = { 1391 | "a": { 1392 | } 1393 | } 1394 | 1395 | @extract([Parameter(path, "event", default="bye")]) 1396 | def handler(event, context, b="Hello"): # noqa 1397 | return b 1398 | 1399 | response = handler(dictionary, None) 1400 | 1401 | self.assertEqual("bye", response) 1402 | 1403 | @patch("aws_lambda_decorators.decorators.LOGGER") 1404 | def test_extract_returns_400_on_invalid_bool_type(self, mock_logger): 1405 | path = "/a/b/c" 1406 | dictionary = { 1407 | "a": { 1408 | "b": { 1409 | "c": 1 1410 | } 1411 | } 1412 | } 1413 | 1414 | @extract([Parameter(path, "event", [Type(bool)])]) 1415 | def handler(event, context, c=None): # noqa 1416 | return {} 1417 | 1418 | response = handler(dictionary, None) 1419 | self.assertEqual(400, response["statusCode"]) 1420 | self.assertEqual("{\"message\": [{\"c\": [\"\'1\' is not of type \'bool'\"]}]}", response["body"]) 1421 | 1422 | mock_logger.error.assert_called_once_with( 1423 | "Error validating parameters. Errors: %s", 1424 | [{"c": ["'1' is not of type 'bool'"]}] 1425 | ) 1426 | 1427 | @patch("aws_lambda_decorators.decorators.LOGGER") 1428 | def test_extract_returns_400_on_invalid_float_type(self, mock_logger): 1429 | path = "/a/b/c" 1430 | dictionary = { 1431 | "a": { 1432 | "b": { 1433 | "c": 1 1434 | } 1435 | } 1436 | } 1437 | 1438 | @extract([Parameter(path, "event", [Type(float)])]) 1439 | def handler(event, context, c=None): # noqa 1440 | return {} 1441 | 1442 | response = handler(dictionary, None) 1443 | self.assertEqual(400, response["statusCode"]) 1444 | self.assertEqual("{\"message\": [{\"c\": [\"\'1\' is not of type \'float'\"]}]}", response["body"]) 1445 | 1446 | mock_logger.error.assert_called_once_with( 1447 | "Error validating parameters. Errors: %s", 1448 | [{"c": ["'1' is not of type 'float'"]}] 1449 | ) 1450 | 1451 | def test_type_validator_returns_true_when_none_is_passed_in(self): 1452 | path = "/a/b/c" 1453 | dictionary = { 1454 | "a": { 1455 | "b": { 1456 | "c": None 1457 | } 1458 | } 1459 | } 1460 | 1461 | @extract([Parameter(path, "event", [Type(float)])]) 1462 | def handler(event, context, c=None): # noqa 1463 | return c 1464 | 1465 | response = handler(dictionary, None) 1466 | self.assertEqual(None, response) 1467 | 1468 | def test_extract_succeeds_with_valid_type_validation(self): 1469 | path = "/a/b/c" 1470 | dictionary = { 1471 | "a": { 1472 | "b": { 1473 | "c": 1 1474 | } 1475 | } 1476 | } 1477 | 1478 | @extract([Parameter(path, "event", [Type(int)])]) 1479 | def handler(event, context, c=None): # noqa 1480 | return c 1481 | 1482 | response = handler(dictionary, None) 1483 | self.assertEqual(1, response) 1484 | 1485 | @patch("aws_lambda_decorators.decorators.LOGGER") 1486 | def test_extract_returns_400_on_value_not_in_list(self, mock_logger): 1487 | path = "/a/b/c" 1488 | dictionary = { 1489 | "a": { 1490 | "b": { 1491 | "c": "Hello" 1492 | } 1493 | } 1494 | } 1495 | 1496 | @extract([Parameter(path, "event", [EnumValidator("bye", "test", "another")])]) 1497 | def handler(event, context, c=None): # noqa 1498 | return {} 1499 | 1500 | response = handler(dictionary, None) 1501 | self.assertEqual(400, response["statusCode"]) 1502 | self.assertEqual( 1503 | "{\"message\": [{\"c\": [\"\'Hello\' is not in list \'(\'bye\', \'test\', \'another\')'\"]}]}", 1504 | response["body"]) 1505 | 1506 | mock_logger.error.assert_called_once_with( 1507 | "Error validating parameters. Errors: %s", 1508 | [{"c": ["'Hello' is not in list '('bye', 'test', 'another')'"]}] 1509 | ) 1510 | 1511 | def test_extract_suceeds_with_valid_enum_validation(self): 1512 | path = "/a/b/c" 1513 | dictionary = { 1514 | "a": { 1515 | "b": { 1516 | "c": 123 1517 | } 1518 | } 1519 | } 1520 | 1521 | @extract([Parameter(path, "event", [EnumValidator("Hello", 123)])]) 1522 | def handler(event, context, c=None): # noqa 1523 | return c 1524 | 1525 | response = handler(dictionary, None) 1526 | self.assertEqual(123, response) 1527 | 1528 | def test_enum_validator_returns_true_when_none_is_passed_in(self): 1529 | path = "/a/b/c" 1530 | dictionary = { 1531 | "a": { 1532 | "b": { 1533 | "c": None 1534 | } 1535 | } 1536 | } 1537 | 1538 | @extract([Parameter(path, "event", [EnumValidator("Test", "another")])]) 1539 | def handler(event, context, c=None): # noqa 1540 | return c 1541 | 1542 | response = handler(dictionary, None) 1543 | self.assertEqual(None, response) 1544 | 1545 | def test_extract_from_event_missing_parameter_path(self): 1546 | event = { 1547 | "body": "{}" 1548 | } 1549 | 1550 | @extract_from_event(parameters=[Parameter(path="body[json]/optional/value", default="Hello")]) 1551 | def handler(event, context, **kwargs): # noqa 1552 | return { 1553 | "statusCode": HTTPStatus.OK, 1554 | "body": json.dumps(kwargs) 1555 | } 1556 | 1557 | expected_body = json.dumps({ 1558 | "value": "Hello" 1559 | }) 1560 | 1561 | response = handler(event, None) 1562 | 1563 | self.assertEqual(HTTPStatus.OK, response["statusCode"]) 1564 | self.assertEqual(expected_body, response["body"]) 1565 | 1566 | def test_extract_non_empty_parameter(self): 1567 | event = { 1568 | "value": 20 1569 | } 1570 | 1571 | @extract([Parameter("/value", "event", validators=[NonEmpty])]) 1572 | def handler(event, value=None): # noqa: pylint - unused-argument 1573 | return value 1574 | 1575 | response = handler(event) 1576 | self.assertEqual(20, response) 1577 | 1578 | def test_extract_missing_non_empty_parameter(self): 1579 | event = { 1580 | "a": 20 1581 | } 1582 | 1583 | @extract([Parameter("/b", "event", validators=[NonEmpty])]) 1584 | def handler(event, b=None): # noqa: pylint - unused-argument 1585 | return b 1586 | 1587 | response = handler(event) 1588 | self.assertEqual(None, response) 1589 | 1590 | @patch("aws_lambda_decorators.decorators.LOGGER") 1591 | def test_extract_non_empty_parameter_that_is_empty(self, mock_logger): 1592 | event = { 1593 | "a": {} 1594 | } 1595 | 1596 | @extract([Parameter("/a", "event", validators=[NonEmpty])]) 1597 | def handler(event, a=None): # noqa: pylint - unused-argument 1598 | return {} 1599 | 1600 | response = handler(event, None) 1601 | 1602 | self.assertEqual(400, response["statusCode"]) 1603 | self.assertEqual( 1604 | "{\"message\": [{\"a\": [\"Value is empty\"]}]}", 1605 | response["body"]) 1606 | 1607 | mock_logger.error.assert_called_once_with( 1608 | "Error validating parameters. Errors: %s", 1609 | [{"a": ["Value is empty"]}] 1610 | ) 1611 | 1612 | @patch("aws_lambda_decorators.decorators.LOGGER") 1613 | def test_extract_non_empty_parameter_that_is_empty_with_custom_message(self, mock_logger): 1614 | event = { 1615 | "a": {} 1616 | } 1617 | 1618 | @extract([Parameter("/a", "event", validators=[NonEmpty("The value was empty")])]) 1619 | def handler(event, a=None): # noqa: pylint - unused-argument 1620 | return {} 1621 | 1622 | response = handler(event, None) 1623 | 1624 | self.assertEqual(400, response["statusCode"]) 1625 | self.assertEqual( 1626 | "{\"message\": [{\"a\": [\"The value was empty\"]}]}", 1627 | response["body"]) 1628 | 1629 | mock_logger.error.assert_called_once_with( 1630 | "Error validating parameters. Errors: %s", 1631 | [{"a": ["The value was empty"]}] 1632 | ) 1633 | 1634 | def test_extract_date_parameter(self): 1635 | event = { 1636 | "a": "2001-01-01 00:00:00" 1637 | } 1638 | 1639 | @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])]) 1640 | def handler(event, a=None): # noqa: pylint - unused-argument 1641 | return a 1642 | 1643 | response = handler(event) 1644 | self.assertEqual("2001-01-01 00:00:00", response) 1645 | 1646 | @patch("aws_lambda_decorators.decorators.LOGGER") 1647 | def test_extract_date_parameter_fails_on_invalid_date(self, mock_logger): 1648 | event = { 1649 | "a": "2001-01-01 35:00:00" 1650 | } 1651 | 1652 | @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])]) 1653 | def handler(event, a=None): # noqa: pylint - unused-argument 1654 | return {} 1655 | 1656 | response = handler(event, None) 1657 | 1658 | self.assertEqual(400, response["statusCode"]) 1659 | self.assertEqual("{\"message\": [{\"a\": [\"'2001-01-01 35:00:00' is not a '%Y-%m-%d %H:%M:%S' date\"]}]}", 1660 | response["body"]) 1661 | 1662 | mock_logger.error.assert_called_once_with( 1663 | "Error validating parameters. Errors: %s", 1664 | [{"a": ["'2001-01-01 35:00:00' is not a '%Y-%m-%d %H:%M:%S' date"]}] 1665 | ) 1666 | 1667 | @patch("aws_lambda_decorators.decorators.LOGGER") 1668 | def test_extract_date_parameter_fails_with_custom_error(self, mock_logger): 1669 | event = { 1670 | "a": "2001-01-01 35:00:00" 1671 | } 1672 | 1673 | @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S", "Not a valid date!")])]) 1674 | def handler(event, a=None): # noqa: pylint - unused-argument 1675 | return {} 1676 | 1677 | response = handler(event, None) 1678 | 1679 | self.assertEqual(400, response["statusCode"]) 1680 | self.assertEqual("{\"message\": [{\"a\": [\"Not a valid date!\"]}]}", response["body"]) 1681 | 1682 | mock_logger.error.assert_called_once_with( 1683 | "Error validating parameters. Errors: %s", 1684 | [{"a": ["Not a valid date!"]}] 1685 | ) 1686 | 1687 | def test_extract_date_parameter_valid_on_empty(self): 1688 | event = { 1689 | "a": None 1690 | } 1691 | 1692 | @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])]) 1693 | def handler(event, a=None): # noqa: pylint - unused-argument 1694 | return a 1695 | 1696 | response = handler(event) 1697 | self.assertEqual(None, response) 1698 | 1699 | def test_extract_currency_parameter(self): 1700 | event = { 1701 | "a": "GBP" 1702 | } 1703 | 1704 | @extract([Parameter("/a", "event", [CurrencyValidator])]) 1705 | def handler(event, a=None): # noqa: pylint - unused-argument 1706 | return a 1707 | 1708 | response = handler(event) 1709 | self.assertEqual("GBP", response) 1710 | 1711 | def test_currency_validator_returns_true_when_none_is_passed_in(self): 1712 | path = "/a/b/c" 1713 | dictionary = { 1714 | "a": { 1715 | "b": { 1716 | "c": None 1717 | } 1718 | } 1719 | } 1720 | 1721 | @extract([Parameter(path, "event", [CurrencyValidator])]) 1722 | def handler(event, c=None): # noqa 1723 | return c 1724 | 1725 | response = handler(dictionary, None) 1726 | self.assertEqual(None, response) 1727 | 1728 | def test_currency_validator_returns_false_when_invalid_code_passed_in(self): 1729 | event = { 1730 | "a": "GBT" 1731 | } 1732 | 1733 | @extract([Parameter("/a", "event", [CurrencyValidator])]) 1734 | def handler(event, a=None): # noqa: pylint - unused-argument 1735 | return {} 1736 | 1737 | response = handler(event) 1738 | self.assertEqual(400, response["statusCode"]) 1739 | self.assertEqual("{\"message\": [{\"a\": [\"\'GBT\' is not a valid currency code.\"]}]}", 1740 | response["body"]) 1741 | 1742 | def test_currency_validator_can_be_called_non_statically(self): 1743 | event = { 1744 | "a": "GBP" 1745 | } 1746 | 1747 | @extract([Parameter("/a", "event", [CurrencyValidator()])]) 1748 | def handler(event, a=None): # noqa: pylint - unused-argument 1749 | return a 1750 | 1751 | response = handler(event) 1752 | self.assertEqual("GBP", response) 1753 | 1754 | def test_can_apply_transformation(self): 1755 | event = { 1756 | "a": "2" 1757 | } 1758 | 1759 | @extract([Parameter("/a", "event", transform=float)]) 1760 | def handler(event, a=None): # noqa: pylint - unused-argument 1761 | return a 1762 | 1763 | response = handler(event) 1764 | self.assertEqual(2, response) 1765 | 1766 | def test_apply_transformation_on_none_value(self): 1767 | event = { 1768 | "a": None 1769 | } 1770 | 1771 | @extract([Parameter("/a", "event", transform=float)]) 1772 | def handler(event, a=None): # noqa: pylint - unused-argument 1773 | return a 1774 | 1775 | response = handler(event) 1776 | self.assertEqual(None, response) 1777 | 1778 | def test_apply_custom_transformation(self): 1779 | event = { 1780 | "a": "2" 1781 | } 1782 | 1783 | def to_float(arg): 1784 | return float(arg) 1785 | 1786 | @extract([Parameter("/a", "event", transform=to_float)]) 1787 | def handler(event, a=None): # noqa: pylint - unused-argument 1788 | return a 1789 | 1790 | response = handler(event) 1791 | self.assertEqual(2, response) 1792 | 1793 | @patch("aws_lambda_decorators.decorators.LOGGER") 1794 | def test_apply_custom_transformation_with_error_handling(self, mock_logger): 1795 | event = { 1796 | "a": "abc" 1797 | } 1798 | 1799 | def to_float(arg): 1800 | try: 1801 | return float(arg) 1802 | except Exception: 1803 | raise Exception(f"Custom error message: value '{arg}' cannot be converted to float") 1804 | 1805 | @extract([Parameter("/a", "event", transform=to_float)]) 1806 | def handler(event, a=None): # noqa: pylint - unused-argument 1807 | return {} 1808 | 1809 | response = handler(event) 1810 | self.assertEqual(400, response["statusCode"]) 1811 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 1812 | 1813 | mock_logger.error.assert_called_once_with("%s: %s in argument %s for path %s", 1814 | "Exception", 1815 | "Custom error message: value 'abc' cannot be converted to float", 1816 | "event", 1817 | "/a") 1818 | 1819 | @patch("aws_lambda_decorators.decorators.LOGGER") 1820 | def test_apply_invalid_transformation_raises_error(self, mock_logger): 1821 | event = { 1822 | "a": "abc" 1823 | } 1824 | 1825 | @extract([Parameter("/a", "event", transform=float)]) 1826 | def handler(event, a=None): # noqa: pylint - unused-argument 1827 | return {} 1828 | 1829 | response = handler(event) 1830 | self.assertEqual(400, response["statusCode"]) 1831 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"]) 1832 | 1833 | mock_logger.error.assert_called_once_with("%s: %s in argument %s for path %s", 1834 | "ValueError", 1835 | "could not convert string to float: 'abc'", 1836 | "event", 1837 | "/a") 1838 | 1839 | @patch("boto3.client") 1840 | def test_push_ws_errors_missing_parameter(self, mock_boto3_client): 1841 | get_websocket_endpoint.cache_clear() 1842 | 1843 | event = { 1844 | "requestContext": { 1845 | "connectionId": "test_connection_id" 1846 | }, 1847 | "body": json.dumps({ 1848 | "invalid_property": "invalid_value" 1849 | }) 1850 | } 1851 | 1852 | @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") 1853 | @extract_from_event(parameters=[ 1854 | Parameter(path="body[json]/valid_property", validators=[Mandatory]) 1855 | ]) 1856 | def lambda_func(event, context, valid_property=None): # noqa: pylint - unused-argument 1857 | return {"statusCode": HTTPStatus.OK} 1858 | 1859 | response = lambda_func(event, None) 1860 | 1861 | expected_data = { 1862 | "type": "error", 1863 | "statusCode": 400, 1864 | "message": [{ 1865 | "valid_property": ["Missing mandatory value"] 1866 | }] 1867 | } 1868 | 1869 | self.assertEqual(response["statusCode"], HTTPStatus.BAD_REQUEST) 1870 | 1871 | mock_boto3_client.assert_called_once_with( 1872 | "apigatewaymanagementapi", 1873 | endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod" 1874 | ) 1875 | 1876 | mock_boto3_client.return_value.post_to_connection.assert_called_once_with( 1877 | ConnectionId="test_connection_id", 1878 | Data=json.dumps(expected_data) 1879 | ) 1880 | 1881 | @patch("boto3.client") 1882 | def test_push_ws_errors_no_action_on_success(self, mock_boto3_client): 1883 | get_websocket_endpoint.cache_clear() 1884 | 1885 | event = { 1886 | "requestContext": { 1887 | "connectionId": "test_connection_id" 1888 | } 1889 | } 1890 | 1891 | @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") 1892 | def lambda_func(event, context): # noqa: pylint - unused-argument 1893 | return {"statusCode": HTTPStatus.OK} 1894 | 1895 | response = lambda_func(event, None) 1896 | 1897 | self.assertEqual(response["statusCode"], 200) 1898 | 1899 | mock_boto3_client.return_value.post_to_connection.assert_not_called() 1900 | 1901 | @patch("boto3.client") 1902 | def test_push_ws_errors_no_connection_id(self, mock_boto3_client): 1903 | get_websocket_endpoint.cache_clear() 1904 | 1905 | event = { 1906 | "body": { 1907 | "property": "value" 1908 | } 1909 | } 1910 | 1911 | @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") 1912 | def lambda_func(event, context): # noqa: pylint - unused-argument 1913 | return {"statusCode": HTTPStatus.BAD_REQUEST} 1914 | 1915 | response = lambda_func(event, None) 1916 | 1917 | self.assertEqual(response["statusCode"], 400) 1918 | 1919 | mock_boto3_client.return_value.post_to_connection.assert_not_called() 1920 | 1921 | @patch("boto3.client") 1922 | def test_push_ws_response(self, mock_boto3_client): 1923 | get_websocket_endpoint.cache_clear() 1924 | 1925 | event = { 1926 | "requestContext": { 1927 | "connectionId": "test_connection_id" 1928 | } 1929 | } 1930 | 1931 | @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") 1932 | def lambda_func(event, context): # noqa: pylint - unused-argument 1933 | return { 1934 | "statusCode": HTTPStatus.OK, 1935 | "body": "Hello, world!" 1936 | } 1937 | 1938 | response = lambda_func(event, None) 1939 | 1940 | self.assertEqual(response["statusCode"], 200) 1941 | self.assertEqual(response["body"], "Hello, world!") 1942 | 1943 | mock_boto3_client.return_value.post_to_connection.assert_called_once_with( 1944 | ConnectionId="test_connection_id", 1945 | Data="{\"statusCode\": 200, \"body\": \"Hello, world!\"}" 1946 | ) 1947 | 1948 | @patch("boto3.client") 1949 | def test_push_ws_response_no_connection_id(self, mock_boto3_client): 1950 | get_websocket_endpoint.cache_clear() 1951 | 1952 | event = { 1953 | "body": "Hello, world!" 1954 | } 1955 | 1956 | @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod") 1957 | @extract_from_event(parameters=[ 1958 | Parameter(path="body") 1959 | ]) 1960 | def lambda_func(event, context, body=None): # noqa: pylint - unused-argument 1961 | return { 1962 | "statusCode": HTTPStatus.OK, 1963 | "body": body 1964 | } 1965 | 1966 | response = lambda_func(event, None) 1967 | 1968 | self.assertEqual(response["statusCode"], 200) 1969 | self.assertEqual(response["body"], "Hello, world!") 1970 | 1971 | mock_boto3_client.return_value.post_to_connection.assert_not_called() 1972 | 1973 | def test_hsts_returns_headers_in_response(self): 1974 | 1975 | @hsts() 1976 | def handler(): 1977 | return {} 1978 | 1979 | response = handler() 1980 | 1981 | self.assertEqual(response["headers"]["Strict-Transport-Security"], "max-age=63072000") 1982 | 1983 | def test_hsts_returns_headers_in_response_with_custom_age(self): 1984 | 1985 | @hsts(max_age=121212) 1986 | def handler(): 1987 | return {} 1988 | 1989 | response = handler() 1990 | 1991 | self.assertEqual(response["headers"]["Strict-Transport-Security"], "max-age=121212") 1992 | 1993 | def test_hsts_function_returns_non_dictionary(self): 1994 | 1995 | @hsts() 1996 | def handler(): 1997 | return "I am a string" 1998 | 1999 | response = handler() 2000 | 2001 | self.assertEqual(response["statusCode"], HTTPStatus.INTERNAL_SERVER_ERROR) # noqa: pylint-invalid-sequence-index 2002 | self.assertEqual(response["body"], "{\"message\": \"Invalid response type for HSTS header\"}") # noqa: pylint-invalid-sequence-index 2003 | 2004 | 2005 | class IsolatedDecoderTests(unittest.TestCase): 2006 | # Tests have been named so they run in a specific order 2007 | 2008 | ID_PATTERN = "^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" 2009 | 2010 | PARAMETERS = [ 2011 | Parameter(path="pathParameters/test_id", validators=[Mandatory, RegexValidator(ID_PATTERN)]), 2012 | Parameter(path="body[json]/name", validators=[MaxLength(255)]) 2013 | ] 2014 | 2015 | def test_01_extract_from_event_400(self): 2016 | event = { 2017 | "pathParameters": {} 2018 | } 2019 | 2020 | @extract_from_event(parameters=self.PARAMETERS, group_errors=True, allow_none_defaults=False) 2021 | def handler(event, context, **kwargs): # noqa 2022 | return kwargs 2023 | 2024 | response = handler(event, None) 2025 | self.assertEqual(HTTPStatus.BAD_REQUEST, response["statusCode"]) 2026 | 2027 | def test_02_extract_from_event_200(self): 2028 | test_id = str(uuid4()) 2029 | 2030 | event = { 2031 | "pathParameters": { 2032 | "test_id": test_id 2033 | }, 2034 | "body": json.dumps({ 2035 | "name": "Gird" 2036 | }) 2037 | } 2038 | 2039 | @extract_from_event(parameters=self.PARAMETERS, group_errors=True, allow_none_defaults=False) 2040 | def handler(event, context, **kwargs): # noqa 2041 | return { 2042 | "statusCode": HTTPStatus.OK, 2043 | "body": json.dumps(kwargs) 2044 | } 2045 | 2046 | expected_body = json.dumps({ 2047 | "test_id": test_id, 2048 | "name": "Gird" 2049 | }) 2050 | 2051 | response = handler(event, None) 2052 | 2053 | self.assertEqual(HTTPStatus.OK, response["statusCode"]) 2054 | self.assertEqual(expected_body, response["body"]) 2055 | -------------------------------------------------------------------------------- /tests/test_param.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from aws_lambda_decorators.classes import Parameter 3 | 4 | 5 | class ParamTests(unittest.TestCase): 6 | 7 | def test_annotations_from_key_returns_none_when_no_annotations(self): 8 | key = "simple" 9 | response = Parameter.get_annotations_from_key(key) 10 | self.assertTrue(response[0] == "simple") 11 | self.assertTrue(response[1] is None) 12 | 13 | def test_annotations_from_key_returns_annotation(self): 14 | key = "simple[annotation]" 15 | response = Parameter.get_annotations_from_key(key) 16 | self.assertTrue(response[0] == "simple") 17 | self.assertTrue(response[1] == "annotation") 18 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from aws_lambda_decorators.validators import Mandatory, RegexValidator 3 | from aws_lambda_decorators.utils import is_type_in_list 4 | 5 | 6 | class UtilsTests(unittest.TestCase): 7 | 8 | def test_is_type_in_list_returns_false_if_item_of_type_missing(self): 9 | items = [Mandatory, Mandatory, Mandatory] 10 | self.assertFalse(is_type_in_list(RegexValidator, items)) 11 | 12 | def test_is_type_in_list_returns_true_if_item_of_type_exists(self): 13 | items = [Mandatory, RegexValidator(), Mandatory] 14 | self.assertTrue(is_type_in_list(RegexValidator, items)) 15 | -------------------------------------------------------------------------------- /tools/dev/coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | coverage run --branch --include='aws_lambda_decorators/*.py' -m unittest tests/test_*.py 4 | coverage report -m --fail-under=100 --omit=tests/*,it/*,env* --------------------------------------------------------------------------------