├── .coveragerc
├── .github
│   └── workflows
│       └── build_publish.yml
├── .gitignore
├── .prospector.yaml
├── CONTRIBUTING.md
├── LICENSE
├── Pipfile
├── README.md
├── aws_lambda_decorators
│   ├── __init__.py
│   ├── classes.py
│   ├── decoders.py
│   ├── decorators.py
│   ├── utils.py
│   └── validators.py
├── buildspec.yml
├── examples
│   ├── __init__.py
│   ├── examples.py
│   └── test_examples.py
├── pylintrc
├── requirements.txt
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_classes.py
│   ├── test_decoders.py
│   ├── test_decorators.py
│   ├── test_param.py
│   └── test_utils.py
└── tools
    └── dev
        └── coverage.sh
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 |
3 | branch = True
4 |
5 | omit =
6 | *__init__.py
7 | examples/*
8 | setup.py
9 | tests/*
10 |
11 | [report]
12 |
13 | show_missing = True
14 |
15 | omit =
16 | *__init__.py
17 | examples/*
18 | setup.py
19 | tests/*
20 |
21 | fail_under = 100
22 |
--------------------------------------------------------------------------------
/.github/workflows/build_publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish 🐍 distributions 📦 to PyPI
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |
8 | jobs:
9 |   build-and-publish:
10 |     name: Build and publish 🐍 distributions 📦 to PyPI
11 |     # NOTE(review): the Pipfile pins Python 3.7; confirm actions/setup-python can still
12 |     # provide 3.7 on this runner image, or pin runs-on to an image that ships it.
13 |     runs-on: ubuntu-latest
14 |     steps:
15 |       - name: Check out repository code
16 |         uses: actions/checkout@v4
17 |
18 |       # Set up Python (faster than using a Python container) before anything runs setup.py,
19 |       # so the version is read with the same interpreter the package is built with.
20 |       - name: Setup Python
21 |         uses: actions/setup-python@v5
22 |         with:
23 |           python-version: "3.7"
24 |
25 |       # Expose the version declared in setup.py as steps.setup_version.outputs.version.
26 |       # Step outputs are written to $GITHUB_OUTPUT (the `::set-output` command is deprecated).
27 |       - name: Fetch version
28 |         id: setup_version
29 |         run: |
30 |           output="$(python setup.py --version)"
31 |           echo "version=$output" >> "$GITHUB_OUTPUT"
32 |
33 |       - name: Print build version
34 |         run: echo "${{ steps.setup_version.outputs.version }}"
35 |
36 |       - name: Install pipenv
37 |         run: |
38 |           python -m pip install --upgrade pipenv wheel
39 |
40 |       # Cache the pipenv virtualenv keyed on Pipfile.lock; a cache hit skips the install step below.
41 |       - name: Cache pipenv virtualenv
42 |         id: cache-pipenv
43 |         uses: actions/cache@v4
44 |         with:
45 |           path: ~/.local/share/virtualenvs
46 |           key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }}
47 |
48 |       - name: Install dependencies
49 |         if: steps.cache-pipenv.outputs.cache-hit != 'true'
50 |         run: |
51 |           pipenv install --deploy --dev
52 |
53 |       - name: Build a binary wheel and a source tarball
54 |         run: |
55 |           pipenv run build
56 |
57 |       # git tag created using the version specified in setup.py
58 |       - name: Bump git tag version
59 |         id: tag_version
60 |         uses: mathieudutour/github-tag-action@v5.6
61 |         with:
62 |           github_token: ${{ secrets.GITHUB_TOKEN }}
63 |           custom_tag: ${{ steps.setup_version.outputs.version }}
64 |
65 |       - name: Create a GitHub release
66 |         uses: ncipollo/release-action@v1
67 |         with:
68 |           tag: ${{ steps.tag_version.outputs.new_tag }}
69 |           name: Release ${{ steps.tag_version.outputs.new_tag }}
70 |           body: ${{ github.event.head_commit.message }}
71 |
72 |       # The `master` ref of gh-action-pypi-publish is sunset; `release/v1` is the supported
73 |       # ref and takes hyphenated input names (repository-url, skip-existing).
74 |       - name: Publish distribution 📦 to Test PyPI
75 |         uses: pypa/gh-action-pypi-publish@release/v1
76 |         with:
77 |           password: ${{ secrets.TEST_PYPI_API_TOKEN }}
78 |           repository-url: https://test.pypi.org/legacy/
79 |           verbose: true
80 |           skip-existing: true
81 |
82 |       - name: Publish distribution 📦 to PyPI
83 |         uses: pypa/gh-action-pypi-publish@release/v1
84 |         with:
85 |           password: ${{ secrets.PYPI_API_TOKEN }}
86 |           verbose: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | coverage.sh
50 | .project
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 | .idea/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
109 | # Scripts
110 | *.sh
111 |
112 | # Mac
113 | .DS_Store
114 | */.DS_Store
115 |
116 | # Visual studio code
117 | .vscode
118 |
119 | # vim (sorry)
120 | *.swp
121 |
--------------------------------------------------------------------------------
/.prospector.yaml:
--------------------------------------------------------------------------------
1 | output-format: text
2 |
3 | strictness: veryhigh
4 | test-warnings: true
5 | doc-warnings: false
6 |
7 | pep8:
8 | full: true
9 | options:
10 | max-line-length: 120
11 |
12 | pep257:
13 | disable:
14 | - D203
15 | - D212
16 |
17 | pylint:
18 | options:
19 | max-line-length: 120
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | If you want to contribute to this project, please submit an issue describing your proposed change. We will respond to you as soon as we can.
4 |
5 | If you want to work on that change, fork this GitHub repo and clone the fork locally. Install the requirements in a Python 3 virtual environment, using the appropriate version of pip:
6 |
7 | `pip install -r requirements.txt`
8 |
9 | Before submitting a PR, please ensure that:
10 |
11 | - you run [__Bandit__](https://pypi.org/project/bandit/) for security checking and all checks are passing:
12 |
13 | `bandit -r .`
14 |
15 | - you run [__Prospector__](https://pypi.org/project/prospector/) for code analysis and all checks are passing:
16 |
17 | `prospector`
18 |
19 | - you run [__Coverage__](https://pypi.org/project/coverage/) and all unit tests are passing:
20 |
21 | `coverage run --source='.' -m unittest`
22 |
23 | `coverage report`
24 |
25 | - you can run the test examples like this:
26 |
27 | `python -m unittest examples.test_examples`
28 |
29 | Thanks for contributing!
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Grid Smarter Cities
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 |
6 | [packages]
7 | pyjwt = "==1.7.1"
8 |
9 | [dev-packages]
10 | pytest = "*"
11 | setuptools = "*"
12 | wheel = "*"
13 | twine = "*"
14 | tqdm = "*"
15 | pre-commit = "*"
16 | commitizen = "*"
17 | toml = "*"
18 | coverage = "*"
19 | bandit = "*"
20 | pylint_quotes = "*"
21 | schema ="*"
22 | PyJWT ="*"
23 | boto3="*"
24 |
25 | [requires]
26 | python_version = "3.7"
27 |
28 | [scripts]
29 | test = "pytest"
30 | build = "python3 setup.py sdist bdist_wheel"
31 | deploy = "twine upload dist/*"
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [
](https://www.gridsmartercities.com/)
2 |
3 | 
4 | [](https://opensource.org/licenses/MIT)
5 | 
6 | \
7 | \
8 | 
9 | 
10 | 
11 | 
12 |
13 | # aws-lambda-decorators
14 |
15 | A set of Python decorators to ease the development of AWS lambda functions.
16 |
17 | ## Installation
18 |
19 | The easiest way to use these AWS Lambda Decorators is to install them through Pip:
20 |
21 | `pip install aws-lambda-decorators`
22 |
23 | ## Logging
24 |
25 | The logging level of the decorators can be controlled by setting the LOG_LEVEL environment variable. In Python:
26 |
27 | `os.environ["LOG_LEVEL"] = "INFO"`
28 |
29 | The default value is "INFO".
30 |
31 | ## Package Contents
32 |
33 | ### [Decorators](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/decorators.py)
34 |
35 | The current list of AWS Lambda Python Decorators includes:
36 |
37 | * [__extract__](#extract): a decorator to extract and validate specific keys of a dictionary parameter passed to an AWS Lambda function.
38 | * [__extract_from_event__](#extract_from_event): a facade of [__extract__](#extract) to extract and validate keys from an AWS API Gateway lambda function _event_ parameter.
39 | * [__extract_from_context__](#extract_from_context): a facade of [__extract__](#extract) to extract and validate keys from an AWS API Gateway lambda function _context_ parameter.
40 | * [__extract_from_ssm__](#extract_from_ssm): a decorator to extract from AWS SSM the values of a set of parameter keys.
41 | * [__validate__](#validate): a decorator to validate a list of function parameters.
42 | * [__log__](#log): a decorator to log the parameters passed to the lambda function and/or the response of the lambda function.
43 | * [__handle_exceptions__](#handle_exceptions): a decorator to handle any type of declared exception generated by the lambda function.
44 | * [__response_body_as_json__](#response_body_as_json): a decorator to transform a response dictionary body to a json string.
45 | * [__handle_all_exceptions__](#handle_all_exceptions): a decorator to handle all exceptions thrown by the lambda function.
46 | * [__cors__](#cors): a decorator to add cors headers to a lambda function.
47 | * [__push_ws_errors__](#push_ws_errors): a decorator to push unsuccessful responses back to the calling user via websockets with api gateway.
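48 | * [__push_ws_response__](#push_ws_response): a decorator to push all responses back to the calling user via websockets with api gateway.
49 | * [__hsts__](#hsts): a decorator to add an HSTS (Strict-Transport-Security) header to a lambda function's response.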
50 |
51 | ### [Validators](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/validators.py)
52 |
53 | Currently, the package offers 12 validators:
54 |
55 | * __Mandatory__: Checks if a parameter has a value.
56 | * __RegexValidator__: Checks a parameter against a regular expression.
57 | * __SchemaValidator__: Checks if an object adheres to the schema. Uses [schema](https://github.com/keleshev/schema) library.
58 | * __Minimum__: Checks if an optional numerical value is greater than a minimum value.
59 | * __Maximum__: Checks if an optional numerical value is less than a maximum value.
60 | * __MinLength__: Checks if an optional string value is longer than a minimum length.
61 | * __MaxLength__: Checks if an optional string value is shorter than a maximum length.
62 | * __Type__: Checks if an optional object value is of a given python type.
63 | * __EnumValidator__: Checks if an optional object value is in a list of valid values.
64 | * __NonEmpty__: Checks if an optional object value is not an empty value.
65 | * __DateValidator__: Checks if a given string is a valid date according to a passed in date format.
66 | * __CurrencyCodeValidator__: Checks if a given string is a valid currency code (ISO 4217).
67 |
68 | ### [Decoders](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/aws_lambda_decorators/decoders.py)
69 |
70 | The package offers functions to decode from JSON and JWT.
71 |
72 | * __decode_json__: decodes/converts a json string to a python dictionary
73 | * __decode_jwt__: decodes/converts a JWT string to a python dictionary
74 |
75 | ## Examples
76 |
77 | You can see some basic examples in the [examples](https://github.com/gridsmartercities/aws-lambda-decorators/blob/master/examples/examples.py) folder.
78 |
79 | ### extract
80 |
81 | This decorator extracts and validates values from dictionary parameters passed to a Lambda Function.
82 |
83 | * The decorator takes a list of __Parameter__ objects.
84 | * Each __Parameter__ object requires a non-empty path to the parameter in the dictionary, and the name of the dictionary (func_param_name).
85 | * The parameter value is extracted and added as a kwarg to the lambda handler (or any other decorated function/method).
86 | * You can add the parameter to the handler signature, or access it in the handler through kwargs.
87 | * The name of the extracted parameter defaults to the last element of the path, but can be changed by passing a var_name (which must be a valid Python variable name).
88 | * You can define a default value for the parameter in the __Parameter__ or in the lambda handler itself.
89 | * A 400 error response is returned when the parameter cannot be extracted or when it does not validate.
90 | * A variable path (e.g. '/headers/Authorization[jwt]/sub') can be annotated to specify a decoding. In the example, Authorization might contain a JWT, which needs to be decoded before accessing the "sub" element.
91 |
92 | Example:
93 | ```python
94 | @extract(parameters=[
95 | Parameter(path='/parent/my_param', func_param_name='a_dictionary'), # extracts a non mandatory my_param from a_dictionary
96 | Parameter(path='/parent/missing_non_mandatory', func_param_name='a_dictionary', default='I am missing'), # extracts a non mandatory missing_non_mandatory from a_dictionary
97 | Parameter(path='/parent/missing_mandatory', func_param_name='a_dictionary'), # does not fail as the parameter is not validated as mandatory
98 | Parameter(path='/parent/child/id', validators=[Mandatory], var_name='user_id', func_param_name='another_dictionary') # extracts a mandatory id as "user_id" from another_dictionary
99 | ])
100 | def extract_example(a_dictionary, another_dictionary, my_param='aDefaultValue', missing_non_mandatory='I am missing', missing_mandatory=None, user_id=None):
101 | """
102 | Given these two dictionaries:
103 |
104 | a_dictionary = {
105 | 'parent': {
106 | 'my_param': 'Hello!'
107 | },
108 | 'other': 'other value'
109 | }
110 |
111 | another_dictionary = {
112 | 'parent': {
113 | 'child': {
114 | 'id': '123'
115 | }
116 | }
117 | }
118 |
119 | you can now access the extracted parameters directly:
120 | """
121 | return my_param, missing_non_mandatory, missing_mandatory, user_id
122 | ```
123 |
124 | Or you can use kwargs instead of specific parameter names:
125 |
126 | Example:
127 | ```python
128 | @extract(parameters=[
129 | Parameter(path='/parent/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary
130 | ])
131 | def extract_to_kwargs_example(a_dictionary, **kwargs):
132 | """
133 | a_dictionary = {
134 | 'parent': {
135 | 'my_param': 'Hello!'
136 | },
137 | 'other': 'other value'
138 | }
139 | """
140 | return kwargs['my_param'] # returns 'Hello!'
141 | ```
142 |
143 | A missing mandatory parameter, or a parameter that fails validation, will raise an exception:
144 |
145 | Example:
146 | ```python
147 | @extract(parameters=[
148 | Parameter(path='/parent/mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]) # extracts a mandatory mandatory_param from a_dictionary
149 | ])
150 | def extract_mandatory_param_example(a_dictionary, mandatory_param=None):
151 | return 'Here!' # this part will never be reached, if the mandatory_param is missing
152 |
153 | response = extract_mandatory_param_example({'parent': {'my_param': 'Hello!'}, 'other': 'other value'} )
154 |
155 | print(response) # prints { 'statusCode': 400, 'body': '{"message": [{"mandatory_param": ["Missing mandatory value"]}]}' } and logs a more detailed error
156 |
157 | ```
158 |
159 | You can add custom error messages to all validators, and incorporate to those error messages the validated value and the validation condition:
160 |
161 | Example:
162 | ```python
163 | @extract(parameters=[
164 | Parameter(path='/parent/an_int', func_param_name='a_dictionary', validators=[Minimum(100, 'Bad value {value}: should be at least {condition}')]) # extracts an_int from a_dictionary and validates it against a minimum of 100
165 | ])
166 | def extract_minimum_param_with_custom_error_example(a_dictionary, an_int=None):
167 | return 'Here!' # this part will never be reached, if the an_int param is less than 100
168 |
169 | response = extract_minimum_param_with_custom_error_example({'parent': {'an_int': 10}})
170 |
171 | print(response) # prints { 'statusCode': 400, 'body': '{"message": [{"an_int": ["Bad value 10: should be at least 100"]}]}' } and logs a more detailed error
172 |
173 | ```
174 |
175 | You can group the validation errors together (instead of exiting on first error).
176 |
177 | Example:
178 | ```python
179 | @extract(parameters=[
180 | Parameter(path='/parent/mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]), # extracts two mandatory parameters from a_dictionary
181 | Parameter(path='/parent/another_mandatory_param', func_param_name='a_dictionary', validators=[Mandatory]),
182 | Parameter(path='/parent/an_int', func_param_name='a_dictionary', validators=[Maximum(10), Minimum(5)])
183 | ], group_errors=True) # groups both errors together
184 | def extract_multiple_param_example(a_dictionary, mandatory_param=None, another_mandatory_param=None, an_int=0):
185 | return 'Here!' # this part will never be reached, if the mandatory_param is missing
186 |
187 | response = extract_multiple_param_example({'parent': {'my_param': 'Hello!', 'an_int': 20}, 'other': 'other value'})
188 |
189 | print(response) # prints {'statusCode': 400, 'body': '{"message": [{"mandatory_param": ["Missing mandatory value"]}, {"another_mandatory_param": ["Missing mandatory value"]}, {"an_int": ["\'20\' is greater than maximum value \'10\'"]}]}'}
190 |
191 | ```
192 |
193 | You can decode any part of the parameter path from json or any other existing annotation.
194 |
195 | Example:
196 | ```python
197 | @extract(parameters=[
198 | Parameter(path='/parent[json]/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary
199 | ])
200 | def extract_from_json_example(a_dictionary, my_param=None):
201 | """
202 | a_dictionary = {
203 | 'parent': '{"my_param": "Hello!" }',
204 | 'other': 'other value'
205 | }
206 | """
207 | return my_param # returns 'Hello!'
208 |
209 | ```
210 |
211 | You can also use an integer annotation to access a specific list element by index.
212 |
213 | Example:
214 | ```python
215 | @extract(parameters=[
216 | Parameter(path='/parent[1]/my_param', func_param_name='a_dictionary') # extracts a non mandatory my_param from a_dictionary
217 | ])
218 | def extract_from_list_example(a_dictionary, my_param=None):
219 | """
220 | a_dictionary = {
221 | 'parent': [
222 | {'my_param': 'Hello!'},
223 | {'my_param': 'Bye!'}
224 | ]
225 | }
226 | """
227 | return my_param # returns 'Bye!'
228 |
229 | ```
230 |
231 | You can extract all parameters into a dictionary
232 |
233 | Example:
234 | ```python
235 | @extract(parameters=[
236 | Parameter(path='/params/my_param_1', func_param_name='a_dictionary'), # extracts a non mandatory my_param_1 from a_dictionary
237 | Parameter(path='/params/my_param_2', func_param_name='a_dictionary') # extracts a non mandatory my_param_2 from a_dictionary
238 | ])
239 | def extract_dictionary_example(a_dictionary, **kwargs):
240 | """
241 | a_dictionary = {
242 | 'params': {
243 | 'my_param_1': 'Hello!',
244 | 'my_param_2': 'Bye!'
245 | }
246 | }
247 | """
248 | return kwargs # returns {'my_param_1': 'Hello!', 'my_param_2': 'Bye!'}
249 |
250 | ```
251 |
252 | You can apply a transformation to an extracted value. The transformation will happen before validation.
253 |
254 | Example:
255 | ```python
256 | @extract(parameters=[
257 | Parameter(path='/params/my_param', func_param_name='a_dictionary', transform=int) # extracts a non mandatory my_param from a_dictionary
258 | ])
259 | def extract_with_transform_example(a_dictionary, my_param=None):
260 | """
261 | a_dictionary = {
262 | 'params': {
263 | 'my_param': '2' # the original value is the string '2'
264 | }
265 | }
266 | """
267 | return my_param # returns the int value 2
268 |
269 | ```
270 |
271 | The transform function can be any function, with its own error handling.
272 |
273 | Example:
274 | ```python
275 |
276 | def to_int(arg):
277 | try:
278 | return int(arg)
279 | except Exception:
280 | raise Exception("My custom error message")
281 |
282 | @extract(parameters=[
283 | Parameter(path='/params/my_param', func_param_name='a_dictionary', transform=to_int) # extracts a non mandatory my_param from a_dictionary
284 | ])
285 | def extract_with_custom_transform_example(a_dictionary, my_param=None):
286 | return {}
287 |
288 | response = extract_with_custom_transform_example({'params': {'my_param': 'abc'}})
289 |
290 | print(response) # prints {'statusCode': 400, 'body': '{"message": "Error extracting parameters"}'}, and the logs will contain the "My custom error message" message.
291 |
292 |
293 | ```
294 |
295 | ### extract_from_event
296 |
297 | This decorator is just a facade to the [extract](#extract) method to be used in AWS Api Gateway Lambdas. It automatically extracts from the event lambda parameter.
298 |
299 | Example:
300 | ```python
301 | @extract_from_event(parameters=[
302 | Parameter(path='/body[json]/my_param', validators=[Mandatory]), # extracts a mandatory my_param from the json body of the event
303 | Parameter(path='/headers/Authorization[jwt]/sub', validators=[Mandatory], var_name='user_id') # extract the mandatory sub value as user_id from the authorization JWT
304 | ])
305 | def extract_from_event_example(event, context, my_param=None, user_id=None):
306 | """
307 | event = {
308 | 'body': '{"my_param": "Hello!"}',
309 | 'headers': {
310 | 'Authorization': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c'
311 | }
312 | }
313 | """
314 | return my_param, user_id # returns ('Hello!', '1234567890')
315 | ```
316 |
317 | ### extract_from_context
318 |
319 | This decorator is just a facade to the [extract](#extract) method to be used in AWS Api Gateway Lambdas. It automatically extracts from the context lambda parameter.
320 |
321 | Example:
322 | ```python
323 | @extract_from_context(parameters=[
324 | Parameter(path='/parent/my_param', validators=[Mandatory]) # extracts a mandatory my_param from the parent element in context
325 | ])
326 | def extract_from_context_example(event, context, my_param=None):
327 | """
328 | context = {
329 | 'parent': {
330 | 'my_param': 'Hello!'
331 | }
332 | }
333 | """
334 | return my_param # returns 'Hello!'
335 | ```
336 |
337 | ### extract_from_ssm
338 |
339 | This decorator extracts a parameter from AWS SSM and passes the parameter down to your function as a kwarg.
340 |
341 | * The decorator takes a list of __SSMParameter__ objects.
342 | * Each __SSMParameter__ object requires the name of the SSM parameter (ssm_name)
343 | * If no var_name is passed in, the extracted value is passed to the function with the ssm_name name
344 |
345 | Example:
346 | ```python
347 | @extract_from_ssm(ssm_parameters=[
348 | SSMParameter(ssm_name='one_key'), # extracts the value of one_key from SSM as a kwarg named "one_key"
349 | SSMParameter(ssm_name='another_key', var_name="another") # extracts another_key as a kwarg named "another"
350 | ])
351 | def extract_from_ssm_example(your_func_params, one_key=None, another=None):
352 | return your_func_params, one_key, another
353 | ```
354 |
355 | ### validate
356 |
357 | This decorator validates a list of non-dictionary parameters from your lambda function.
358 |
359 | * The decorator takes a list of __ValidatedParameter__ objects.
360 | * Each parameter object needs the name of the lambda function parameter that is going to be validated, and the list of validators to apply.
361 | * A 400 exception is raised when the parameter does not validate.
362 |
363 | Example:
364 | ```python
365 | @validate(parameters=[
366 | ValidatedParameter(func_param_name='a_param', validators=[Mandatory]), # validates a_param as mandatory
367 | ValidatedParameter(func_param_name='another_param', validators=[Mandatory, RegexValidator(r'\d+')]), # validates another_param as mandatory and containing only digits
368 | ValidatedParameter(func_param_name='param_with_schema', validators=[SchemaValidator(Schema({'a': Or(str, dict)}))]) # validates param_with_schema as an object with specified schema
369 | ])
370 | def validate_example(a_param, another_param, param_with_schema):
371 | return a_param, another_param, param_with_schema # returns 'Hello!', '123456', {'a': {'b': 'c'}}
372 |
373 | validate_example('Hello!', '123456', {'a': {'b': 'c'}})
374 | ```
375 |
376 | Given the same function `validate_example`, a 400 exception is returned if at least one parameter does not validate (as per the [extract](#extract) decorator, you can group errors with the group_errors flag):
377 |
378 | ```python
379 | validate_example('Hello!', 'ABCD', {'a': {'b': 'c'}}) # returns a 400 status code and an error message
380 | ```
381 |
382 | ### log
383 |
384 | This decorator allows for logging the function arguments and/or the response.
385 |
386 | Example:
387 | ```python
388 | @log(parameters=True, response=True)
389 | def log_example(parameters):
390 | return 'Done!'
391 |
392 | log_example('Hello!') # logs 'Hello!' and 'Done!'
393 | ```
394 |
395 | ### handle_exceptions
396 |
397 | This decorator handles a list of exceptions, returning a 400 response containing the specified friendly message to the caller.
398 |
399 | * The decorator takes a list of __ExceptionHandler__ objects.
400 | * Each __ExceptionHandler__ requires the type of exception to check, and an optional friendly message to return to the caller.
401 |
402 | Example:
403 | ```python
404 | @handle_exceptions(handlers=[
405 | ExceptionHandler(ClientError, "Your message when a client error happens.")
406 | ])
407 | def handle_exceptions_example():
408 | dynamodb = boto3.resource('dynamodb')
409 | table = dynamodb.Table('non_existing_table')
410 | table.query(KeyConditionExpression=Key('user_id').eq(user_id))
411 | # ...
412 |
413 | handle_exceptions_example() # returns {'body': '{"message": "Your message when a client error happens."}', 'statusCode': 400}
414 | ```
415 |
416 | ### handle_all_exceptions
417 |
418 | This decorator handles all exceptions thrown by a lambda, returning a 400 response and the exception's message.
419 |
420 | Example:
421 | ```python
422 | @handle_all_exceptions()
423 | def handle_all_exceptions_example():
424 | test_list = [1, 2, 3]
425 | invalid_value = test_list[5]
426 | # ...
427 |
428 | handle_all_exceptions_example() # returns {'body': '{"message": "list index out of range"}', 'statusCode': 400}
429 | ```
430 |
431 | ### response_body_as_json
432 |
433 | This decorator ensures that, if the response contains a body, the body is dumped as json.
434 |
435 | * Returns a 500 error if the response body cannot be dumped as json.
436 |
437 | Example:
438 | ```python
439 | @response_body_as_json
440 | def response_body_as_json_example():
441 | return {'statusCode': 400, 'body': {'param': 'hello!'}}
442 |
443 | response_body_as_json_example() # returns { 'statusCode': 400, 'body': '{"param": "hello!"}' }
444 | ```
445 |
446 | ### cors
447 |
448 | This decorator adds your defined CORS headers to the decorated function response.
449 |
450 | * Returns a 500 error if one or more of the CORS headers have an invalid type
451 |
452 | Example:
453 | ```python
454 | @cors(allow_origin='*', allow_methods='POST', allow_headers='Content-Type', max_age=86400)
455 | def cors_example():
456 | return {'statusCode': 200}
457 |
458 | cors_example() # returns {'statusCode': 200, 'headers': {'access-control-allow-origin': '*', 'access-control-allow-methods': 'POST', 'access-control-allow-headers': 'Content-Type', 'access-control-max-age': 86400}}
459 | ```
460 |
461 | ### hsts
462 |
463 | This decorator adds an HSTS header to the decorated function's response. It uses a max-age of 2 years (the recommended default from https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security) unless a custom value is provided as a parameter.
464 |
465 | Example:
466 | ```python
467 | @hsts()
468 | def hsts_example():
469 | return {'statusCode': 200}
470 |
471 | hsts_example() # returns {'statusCode': 200, 'headers': {'Strict-Transport-Security': 'max-age=63072000'}}
472 | ```
473 |
474 | ### push_ws_errors
475 |
476 | This decorator pushes unsuccessful responses back to the calling client over WebSockets built on API Gateway.
477 |
478 | This decorator requires the client to be connected to the API Gateway WebSocket instance, and therefore to have a connection id.
479 |
480 | Example:
481 | ```py
482 | @push_ws_errors('https://api_id.execute_id.region.amazonaws.com/Prod')
483 | @handle_all_exceptions()
484 | def handler(event, context):
485 | return {
486 | 'statusCode': 400,
487 | 'body': {
488 | 'message': 'Bad request'
489 | }
490 | }
491 |
492 | # will push {'type': 'error', 'statusCode': 400, 'message': 'Bad request'} back to the client via websockets
493 | ```
494 |
495 | ### push_ws_response
496 |
497 | This decorator pushes all responses back to the calling client over WebSockets built on API Gateway.
498 |
499 | This decorator requires the client to be connected to the API Gateway WebSocket instance, and therefore to have a connection id.
500 |
501 | Example:
502 | ```py
503 | @push_ws_response('https://api_id.execute_id.region.amazonaws.com/Prod')
504 | def handler(event, context):
505 | return {
506 | 'statusCode': 200,
507 | 'body': 'Hello, world!'
508 | }
509 |
510 | # will push {'statusCode': 200, 'body': 'Hello, world!'} back to the client via websockets
511 | ```
512 |
513 | ## Writing your own validators
514 |
515 | You can create your own validators by inheriting from the Validator class.
516 |
517 | Fixed-length validator example:
518 |
519 | ```python
520 | class FixLength(Validator):
521 | ERROR_MESSAGE = "'{value}' length should be '{condition}'"
522 |
523 | def __init__(self, fix_length: int, error_message=None):
524 | super().__init__(error_message=error_message, condition=fix_length)
525 |
526 | def validate(self, value=None):
527 | if value is None:
528 | return True
529 |
530 | return len(str(value)) == self._condition
531 | ```
532 |
533 | ## Documentation
534 |
535 | You can get the docstring help by running:
536 |
537 | ```python
538 | >>> from aws_lambda_decorators.decorators import extract
539 | >>> help(extract)
540 | ```
541 |
542 | ## Links
543 |
544 | * [PyPI](https://pypi.org/project/aws-lambda-decorators/)
545 | * [Test PyPI](https://test.pypi.org/project/aws-lambda-decorators/)
546 | * [GitHub](https://github.com/gridsmartercities/aws-lambda-decorators)
547 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/__init__.py:
--------------------------------------------------------------------------------
1 | from aws_lambda_decorators.classes import * # noqa
2 | from aws_lambda_decorators.decoders import * # noqa
3 | from aws_lambda_decorators.decorators import * # noqa
4 | from aws_lambda_decorators.utils import * # noqa
5 | from aws_lambda_decorators.validators import * # noqa
6 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/classes.py:
--------------------------------------------------------------------------------
1 | """All the classes used as parameters for the decorators."""
2 | from aws_lambda_decorators.decoders import decode
3 | from aws_lambda_decorators.utils import is_valid_variable_name
4 |
# Separator between the segments of a Parameter path, e.g. "body[json]/user/id".
PATH_DIVIDER = "/"
# Delimiters of the optional annotation suffix of a path segment, e.g. the "[json]" in "body[json]".
ANNOTATIONS_START, ANNOTATIONS_END = "[", "]"
8 |
9 |
class ExceptionHandler:
    """Associates an exception type with the message and status code reported when it is caught."""

    def __init__(self, exception, friendly_message=None, status_code=400):
        """
        Store the exception mapping.

        Args:
            exception (type): The exception class to be handled.
            friendly_message (str): Optional message returned to the caller instead of the exception's own text.
            status_code (int): Status code of the failure response. Defaults to 400.
        """
        self._exception, self._friendly_message, self._status_code = exception, friendly_message, status_code

    @property
    def exception(self):
        """The exception class handled by this mapping."""
        return self._exception

    @property
    def friendly_message(self):
        """The message to return when the exception is caught (may be None)."""
        return self._friendly_message

    @property
    def status_code(self):
        """The status code used for the failure response."""
        return self._status_code
39 |
40 |
class BaseParameter:  # noqa: pylint - too-few-public-methods
    """Common base for parameter descriptors whose value is stored under a variable name."""

    def __init__(self, var_name):
        """
        Remember the variable name.

        Args:
            var_name (str): Name of the keyword argument the extracted value is stored under.
        """
        self._name = var_name

    def get_var_name(self):
        """
        Return the variable name, checking that it is usable as a Python identifier.

        Returns:
            The stored name (returned unchecked when it is empty/None).

        Raises:
            SyntaxError: if a non-empty name is not a valid, non-keyword identifier.
        """
        name = self._name
        if not name:
            return name
        if is_valid_variable_name(name):
            return name
        raise SyntaxError(name)
57 |
58 |
class SSMParameter(BaseParameter):
    """Describes an AWS Parameter Store key and the variable its value is injected into."""

    def __init__(self, ssm_name, var_name=None):
        """
        Store the Parameter Store key and the target variable name.

        Args:
            ssm_name (str): Key of the parameter in the AWS Parameter Store.
            var_name (str): Optional name of the variable to store the value in; falls back to ssm_name
                when empty/None.
        """
        self._ssm_name = ssm_name
        BaseParameter.__init__(self, var_name or ssm_name)

    def get_ssm_name(self):
        """Return the key of the parameter in the AWS Parameter Store."""
        return self._ssm_name
76 |
77 |
class ValidatedParameter:
    """Pairs the name of a decorated function's parameter with the validators its value must satisfy."""

    def __init__(self, func_param_name=None, validators=None):
        """
        Store the parameter name and validators.

        Args:
            func_param_name (str): Name of the parameter in the decorated function's signature
                (e.g. "event" or "context" for def fun(event, context)).
            validators (list): Validators the value must conform to (e.g. Mandatory,
                RegexValidator(my_regex), ...). Entries may be validator instances or validator classes.
        """
        self._func_param_name = func_param_name
        self._validators = validators

    @property
    def func_param_name(self):
        """Name of the decorated function's parameter this object refers to."""
        return self._func_param_name

    @func_param_name.setter
    def func_param_name(self, value):
        """Replace the parameter name (extract_from_event/extract_from_context use this)."""
        self._func_param_name = value

    def validate(self, value, group_errors):
        """
        Run every validator against a value.

        Args:
            value (any): The value to validate.
            group_errors (bool): When True all failing validators are reported; when False validation
                stops at the first failure.

        Returns:
            A list of error messages (empty when the value is valid or no validators are configured).
        """
        errors = []
        for validator in self._validators or []:
            if validator.validate(value):
                continue
            if hasattr(validator, "_error_message"):
                # Validator instance: it formats its own (possibly custom) message.
                errors.append(validator.message(value))
            else:
                # Validator class used statically: only the class-level template is available.
                errors.append(validator.ERROR_MESSAGE.format(value=value))
            if not group_errors:
                break
        return errors
128 |
129 |
class Parameter(ValidatedParameter, BaseParameter):
    """Describes a value to extract (and validate) from one of the decorated function's arguments."""

    def __init__(self, path="", func_param_name=None, validators=None, var_name=None, default=None, transform=None):  # noqa: pylint - too-many-arguments
        """
        Store the extraction settings.

        Args:
            path (str): Slash-separated path to the value. Segments may carry an annotation that has a
                matching decode function in decoders.py (like [jwt] or [json]). Given

                {
                    "a": {
                        "b": "{'c': 'hello'}",
                    }
                }

                the path to c is "a/b[json]/c".
            func_param_name (str): Name of the decorated function's parameter to read from
                (e.g. "event" or "context").
            validators (list): Validators the value must conform to (e.g. Mandatory, RegexValidator(my_regex)).
            var_name (str): Optional name of the variable to assign the value to. Defaults to the last path
                segment with its annotation removed (e.g. "c" above).
            default (any): Value used when a path segment is missing. Defaults to None.
            transform (function): Optional function applied to a truthy extracted value before validation.
        """
        ValidatedParameter.__init__(self, func_param_name, validators)
        BaseParameter.__init__(self, var_name)
        self._path = path
        self._default = default
        self._transform = transform

    @property
    def path(self):
        """The path of the value to extract."""
        return self._path

    def extract_value(self, dict_value):
        """
        Walk the path through ``dict_value``, decoding annotated segments along the way.

        Used by the extract decorators.

        Args:
            dict_value (dict): The structure to read from.

        Returns:
            The extracted (and, if configured and truthy, transformed) value, or the default when a
            segment is missing.
        """
        current = dict_value
        segments = [segment for segment in self._path.split(PATH_DIVIDER) if segment != ""]
        for segment in segments:
            key, annotation = Parameter.get_annotations_from_key(segment)
            found = bool(current) and key in current
            current = decode(annotation, current[key]) if found else self._default

        if not self._name:
            # Default the variable name to the last path segment (annotation stripped).
            self._name = key

        if current and self._transform:
            current = self._transform(current)

        return current

    def validate_path(self, value, group_errors=False):
        """
        Validate a value and key any errors by the last segment of the path.

        Args:
            value: The value to validate.
            group_errors (bool): When True all failing validators are reported; when False validation
                stops at the first failure.

        Returns:
            {last_path_segment: [errors]} when validation fails, otherwise {}.
        """
        last_segment = self._path.rsplit(PATH_DIVIDER, 1)[-1]
        errors = self.validate(value, group_errors)
        return {last_segment: errors} if errors else {}

    @staticmethod
    def get_annotations_from_key(key):
        """
        Split a path segment into its key and annotation.

        Args:
            key (str): A path segment, e.g. "key[jwt]" -> ("key", "jwt") or "key" -> ("key", None).

        Returns:
            A (key, annotation) tuple; annotation is None when the segment has none.
        """
        has_annotation = ANNOTATIONS_START in key and ANNOTATIONS_END in key
        if not has_annotation:
            return key, None
        start = key.find(ANNOTATIONS_START) + 1
        annotation = key[start:key.find(ANNOTATIONS_END)]
        return key.replace(f"{ANNOTATIONS_START}{annotation}{ANNOTATIONS_END}", ""), annotation
228 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/decoders.py:
--------------------------------------------------------------------------------
1 | """Decoder abstractions and functions for decoding/converting a string with a given annotation to a dictionary."""
import copy
import functools
import json
import logging
import sys

import jwt
7 |
# Template for the decoder function handling an annotation, e.g. "json" -> "decode_json".
DECODE_FUNC_NAME = "decode_%s"
# Logged when no decoder function exists for an annotation.
DECODE_FUNC_MISSING_ERROR = "Missing decode function for annotation: %s"

# NOTE: logging.getLogger() without a name returns the *root* logger, so importing this module sets
# the root logger's level to INFO.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)
13 |
14 |
def decode(annotation, value):
    """
    Decode a value according to a path annotation.

    Supported annotations:
        annotation  behaviour
        (empty)     value returned unchanged
        [n]         n is a non-negative integer: returns value[n]
        [json]      decode_json(value)
        [jwt]       decode_jwt(value)
        [other]     decode_other(value) if such a function exists in this module; otherwise an error is
                    logged and the value is returned unchanged.

    Args:
        annotation (str): The annotation (e.g. "json", "jwt", "0").
        value: The value to decode.

    Returns:
        The decoded value.
    """
    if not annotation:
        return value

    if annotation.isdigit():
        # Numeric annotations index into a sequence, e.g. "items[0]".
        return value[int(annotation)]

    decoder = getattr(sys.modules[__name__], DECODE_FUNC_NAME % annotation, None)
    if decoder is None:
        LOGGER.error(DECODE_FUNC_MISSING_ERROR, annotation)
        return value
    return decoder(value)
47 |
48 |
@functools.lru_cache()
def _cached_json_loads(value):
    """Parse a JSON string, memoising the result per distinct input (the result must never be handed out)."""
    return json.loads(value)


def decode_json(value):
    """
    Convert a JSON string to the equivalent Python object (usually a dict).

    Parsing is memoised, but every call returns a deep copy so that a caller mutating the result cannot
    corrupt the cached object returned to later calls with the same input.

    Args:
        value (str): The JSON document.

    Returns:
        The decoded object.

    Raises:
        TypeError: if ``value`` is unhashable (the cache hashes its argument) or not a JSON-decodable type.
        json.JSONDecodeError: if ``value`` is not valid JSON.
    """
    return copy.deepcopy(_cached_json_loads(value))


# Keep the cache-management API that the lru_cache-decorated function used to expose.
decode_json.cache_info = _cached_json_loads.cache_info
decode_json.cache_clear = _cached_json_loads.cache_clear
53 |
54 |
@functools.lru_cache()
def _cached_jwt_decode(value):
    """Decode a JWT payload (signature NOT verified), memoising the result per token."""
    return jwt.decode(value, verify=False)


def decode_jwt(value):
    """
    Convert a JWT to its payload dictionary.

    WARNING: the token is decoded with ``verify=False``, so its signature is not checked and the claims
    must not be trusted unless the token was verified upstream.
    NOTE(review): ``verify=False`` is the PyJWT 1.x API; PyJWT 2.x removed it - confirm the pinned version.

    Decoding is memoised, but every call returns a deep copy so that a caller mutating the payload cannot
    corrupt the cached object returned to later calls with the same token.

    Args:
        value (str): The encoded JWT.

    Returns:
        The decoded payload.
    """
    return copy.deepcopy(_cached_jwt_decode(value))


# Keep the cache-management API that the lru_cache-decorated function used to expose.
decode_jwt.cache_info = _cached_jwt_decode.cache_info
decode_jwt.cache_clear = _cached_jwt_decode.cache_clear
59 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/decorators.py:
--------------------------------------------------------------------------------
1 | """
2 | AWS lambda decorators.
3 |
4 | A set of Python decorators to ease the development of AWS lambda functions.
5 |
6 | """
7 | import json
8 | from http import HTTPStatus
9 | import boto3
10 | from aws_lambda_decorators.utils import (full_name, all_func_args, find_key_case_insensitive, failure, get_logger,
11 | find_websocket_connection_id, get_websocket_endpoint)
12 |
13 |
LOGGER = get_logger(__name__)

# Log message templates (%-style arguments are supplied at the logging call).
PARAM_LOG_MESSAGE = "Function: %s, Parameters: %s"
RESPONSE_LOG_MESSAGE = "Function: %s, Response: %s"
EXCEPTION_LOG_MESSAGE = "%s: %s in argument %s for path %s"
EXCEPTION_LOG_MESSAGE_PATHLESS = "%s: %s in argument %s"
VALIDATE_ERROR_MESSAGE = "Error validating parameters. Errors: %s"
CORS_INVALID_TYPE_LOG_MESSAGE = "Cannot set %s header to a non %s value"
NON_DICT_LOG_MESSAGE = "Cannot add headers to a non dictionary response"

# Messages returned to the caller inside failure responses.
ERROR_MESSAGE = "Error extracting parameters"
NON_SERIALIZABLE_ERROR_MESSAGE = "Response body is not JSON serializable"
CORS_INVALID_TYPE_ERROR = "Invalid value type in CORS header"
CORS_NON_DICT_ERROR = "Invalid response type for CORS headers"
HSTS_NON_DICT_ERROR = "Invalid response type for HSTS header"

# Logged in place of the argument/path when the failing parameter is not known.
UNKNOWN = "Unknown"
31 |
32 |
def extract_from_event(parameters, group_errors=False, allow_none_defaults=False):
    """
    Extract parameters from the ``event`` argument of a lambda handler into keyword arguments.

    Every Parameter's func_param_name is overwritten with "event" before delegating to ``extract``.

    Usage:
        @extract_from_event([Parameter(path="/body[json]/my_param")])
        def lambda_handler(event, context, my_param=None)
            pass

    Args:
        parameters (list): A collection of Parameter items.
        group_errors (bool): When True all validation errors are collected; when False the first error ends
            validation.
        allow_none_defaults (bool): When True, None values are passed as keyword arguments too.
    """
    for parameter in parameters:
        parameter.func_param_name = "event"
    return extract(parameters, group_errors=group_errors, allow_none_defaults=allow_none_defaults)
52 |
53 |
def extract_from_context(parameters, group_errors=False, allow_none_defaults=False):
    """
    Extract parameters from the ``context`` argument of a lambda handler into keyword arguments.

    Every Parameter's func_param_name is overwritten with "context" before delegating to ``extract``.

    Usage:
        @extract_from_context([Parameter(path="/parent/my_param")])
        def lambda_handler(event, context, my_param=None)
            pass

    Args:
        parameters (list): A collection of Parameter items.
        group_errors (bool): When True all validation errors are collected; when False the first error ends
            validation.
        allow_none_defaults (bool): When True, None values are passed as keyword arguments too.
    """
    for parameter in parameters:
        parameter.func_param_name = "context"
    return extract(parameters, group_errors=group_errors, allow_none_defaults=allow_none_defaults)
73 |
74 |
def extract(parameters, group_errors=False, allow_none_defaults=False):
    """
    Extract values from any of the decorated function's arguments and pass them in as keyword arguments.

    Usage:
        @extract([Parameter(path="headers/Authorization[jwt]/sub", var_name="user_id", func_param_name="event")])
        def lambda_handler(event, context, user_id=None)
            pass

    Args:
        parameters (list): A collection of Parameter items.
        group_errors (bool): When True all validation errors are collected and returned together; when False
            the first error ends validation.
        allow_none_defaults (bool): When True, None values are added to kwargs. When False they are left out,
            so the default from the function signature (if any) applies.

    Returns:
        A decorator. The wrapped function returns ``failure(...)`` when extraction raises or validation fails;
        otherwise the decorated function's result. Exceptions raised by the decorated function itself are
        not caught.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            current = None  # the Parameter being processed; reported if an exception escapes
            collected = []
            try:
                named_args = all_func_args(func, args, kwargs)
                for current in parameters:
                    value = current.extract_value(named_args[current.func_param_name])
                    failures = current.validate_path(value, group_errors)
                    if not failures:
                        if allow_none_defaults or value is not None:
                            kwargs[current.get_var_name()] = value
                        continue
                    collected.append(failures)
                    if not group_errors:
                        LOGGER.error(VALIDATE_ERROR_MESSAGE, collected)
                        return failure(collected)
            except Exception as ex:  # noqa: pylint - broad-except
                LOGGER.error(EXCEPTION_LOG_MESSAGE, full_name(ex), str(ex),
                             current.func_param_name if current else UNKNOWN,
                             current.path if current else UNKNOWN)
                return failure(ERROR_MESSAGE)

            if group_errors and collected:
                LOGGER.error(VALIDATE_ERROR_MESSAGE, collected)
                return failure(collected)

            return func(*args, **kwargs)
        return wrapper
    return decorator
124 |
125 |
def handle_exceptions(handlers):
    """
    Turn selected exceptions raised by the decorated function into failure responses.

    Usage:
        @handle_exceptions([ExceptionHandler(exception=KeyError, friendly_message="Your message on KeyError except")]).
        def lambda_handler(params)
            pass

    Args:
        handlers (list): A collection of ExceptionHandler items. The first handler whose exception type
            matches a raised exception decides the message and status code.

    Returns:
        A decorator; exceptions not covered by any handler propagate unchanged.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except tuple(handler.exception for handler in handlers) as ex:  # noqa: pylint - catching-non-exception
                matched = next(handler for handler in handlers if isinstance(ex, handler.exception))
                message = matched.friendly_message
                detail = str(ex)

                if message and detail:
                    LOGGER.error("%s: %s", message, detail)
                else:
                    LOGGER.error(message or detail)

                return failure(message or detail, matched.status_code)
        return wrapper
    return decorator
154 |
155 |
def log(parameters=False, response=False):
    """
    Log the positional arguments and/or the return value of the decorated function at INFO level.

    Usage:
        @log(parameters=True, response=True)
        def lambda_handler(event, context)
            pass

    Args:
        parameters (bool): If truthy, log the function name and its positional arguments before the call
            (keyword arguments are not logged).
        response (bool): If truthy, log the function name and its return value after the call.

    Returns:
        A decorator; the wrapped function returns the decorated function's result unchanged.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            if parameters:
                LOGGER.info(PARAM_LOG_MESSAGE, func.__name__, args)
            result = func(*args, **kwargs)
            if response:
                LOGGER.info(RESPONSE_LOG_MESSAGE, func.__name__, result)
            return result
        return wrapper
    return decorator
174 |
175 |
# AWS SSM GetParameters accepts at most this many names per request.
_SSM_GET_PARAMETERS_MAX_NAMES = 10


def extract_from_ssm(ssm_parameters):
    """
    Load parameters from the AWS Parameter Store (decrypted) into the handler's keyword arguments.

    Usage:
        @extract_from_ssm([SSMParameter(ssm_name="key", var_name="var")])
        def lambda_handler(var=None)
            pass

    Names are requested in batches of at most 10, the GetParameters limit, so any number of parameters
    can be loaded. Names that SSM does not return are silently skipped (their kwargs are not set).

    Args:
        ssm_parameters (list): A collection of SSMParameter items.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            ssm = boto3.client("ssm")
            names = [ssm_parameter.get_ssm_name() for ssm_parameter in ssm_parameters]
            for start in range(0, len(names), _SSM_GET_PARAMETERS_MAX_NAMES):
                server_key_containers = ssm.get_parameters(
                    Names=names[start:start + _SSM_GET_PARAMETERS_MAX_NAMES],
                    WithDecryption=True)
                for key_container in server_key_containers["Parameters"]:
                    for ssm_parameter in ssm_parameters:  # pragma: no cover
                        if ssm_parameter.get_ssm_name() == key_container["Name"]:
                            kwargs[ssm_parameter.get_var_name()] = key_container["Value"]
                            break
            return func(*args, **kwargs)
        return wrapper
    return decorator
202 |
203 |
def response_body_as_json(func):
    """
    Serialise the ``body`` of the decorated function's response to a JSON string.

    Usage:
        @response_body_as_json
        def lambda_handler():
            return {"statusCode": 200, "body": {"key": "value"}}

    will return {"statusCode": 200, "body": "{"key": "value"}"}

    Responses without a "body" key are returned untouched; a body json.dumps cannot serialise yields a
    500 failure response.
    """
    def wrapper(*args, **kwargs):
        response = func(*args, **kwargs)
        if "body" not in response:
            return response
        try:
            encoded = json.dumps(response["body"])
            response["body"] = encoded
        except TypeError:
            return failure(NON_SERIALIZABLE_ERROR_MESSAGE, 500)
        return response
    return wrapper
224 |
225 |
def validate(parameters, group_errors=False):
    """
    Validate the decorated function's arguments before calling it.

    Usage:
        @validate([ValidatedParameter(func_param_name="my_param", validators=[...])])
        def func(my_param)
            pass

    Args:
        parameters (list): A collection of ValidatedParameter items.
        group_errors (bool): When True all validation errors are collected and returned together; when False
            the first error ends validation.

    Returns:
        A decorator. The wrapped function returns ``failure(errors)`` on validation errors and
        ``failure(ERROR_MESSAGE)`` if validation itself raises; otherwise the decorated function's result.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            # Bound before the try so the except clause can report it even when all_func_args raises
            # before the loop starts (previously this raised UnboundLocalError instead of returning a failure).
            param = None
            errors = []
            try:
                arg_dictionary = all_func_args(func, args, kwargs)
                for param in parameters:
                    param_val = arg_dictionary[param.func_param_name]
                    param_errors = param.validate(param_val, group_errors)
                    if param_errors:
                        errors.append({param.func_param_name: param_errors})
                        if not group_errors:
                            LOGGER.error(VALIDATE_ERROR_MESSAGE, errors)
                            return failure(errors)
            except Exception as ex:  # noqa: pylint - broad-except
                LOGGER.error(EXCEPTION_LOG_MESSAGE_PATHLESS, full_name(ex), str(ex),
                             param.func_param_name if param else UNKNOWN)
                return failure(ERROR_MESSAGE)

            if group_errors and errors:
                LOGGER.error(VALIDATE_ERROR_MESSAGE, errors)
                return failure(errors)

            return func(*args, **kwargs)
        return wrapper
    return decorator
264 |
265 |
def handle_all_exceptions():
    """
    Turn any Exception raised by the decorated function into a failure response.

    Usage:
        @handle_all_exceptions()
        def lambda_handler(params)
            pass

    Returns:
        A decorator; the exception text is logged and returned via ``failure`` with its default status code.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as ex:  # noqa: pylint - broad-except
                message = str(ex)
                LOGGER.error(message)
                return failure(message)
        return wrapper
    return decorator
284 |
285 |
def cors(allow_origin=None, allow_methods=None, allow_headers=None, max_age=None):
    """
    Add CORS headers to the response of the decorated function.

    Usage:
        @cors(allow_origin="http://example.com", allow_methods="POST,GET", allow_headers="Content-Type", max_age=86400)
        def func(my_param)
            pass

    Args:
        allow_origin (str): Comma-separated list of allowed origins.
        allow_methods (str): Comma-separated list of allowed methods.
        allow_headers (str): Comma-separated list of allowed headers.
        max_age (int): Caching time (in seconds) for the CORS pre-flight request.

    Falsy arguments are skipped. When a header already exists (any case), the new value is appended after
    a comma.

    Returns:
        A decorator. The response gains the CORS headers; a 500 failure is returned when an argument has
        the wrong type or the response is not a dict.
    """
    settings = (
        ("access-control-allow-origin", allow_origin, str),
        ("access-control-allow-methods", allow_methods, str),
        ("access-control-allow-headers", allow_headers, str),
        ("access-control-max-age", max_age, int),
    )

    def add_header(headers, name, value, expected_type):
        if not value:
            return
        if not isinstance(value, expected_type):
            LOGGER.error(CORS_INVALID_TYPE_LOG_MESSAGE, name, expected_type)
            raise TypeError
        key = find_key_case_insensitive(name, headers)
        headers[key] = f"{headers[key]},{value}" if key in headers else value

    def decorator(func):
        def wrapper(*args, **kwargs):
            response = func(*args, **kwargs)

            if not isinstance(response, dict):
                LOGGER.error(NON_DICT_LOG_MESSAGE)
                return failure(CORS_NON_DICT_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)

            headers_key = find_key_case_insensitive("headers", response)
            headers = response[headers_key] if headers_key in response else {}

            try:
                for name, value, expected_type in settings:
                    add_header(headers, name, value, expected_type)
                response[headers_key] = headers
                return response
            except TypeError:
                return failure(CORS_INVALID_TYPE_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
        return wrapper
    return decorator
340 |
341 |
def _ws_error_message(body):
    """
    Return the "message" field of a lambda response body, for a websocket error push.

    Args:
        body (str|dict): The response body, either JSON text (as built by ``failure``) or an
            already-built dict (as in the ``push_ws_errors`` usage example).

    Returns:
        The body's "message" value; the raw text when a string body is not valid JSON; otherwise None.
    """
    if isinstance(body, (str, bytes, bytearray)):
        try:
            body = json.loads(body)
        except ValueError:
            # Plain-text body: forward it as the message instead of crashing the decorator.
            return body if isinstance(body, str) else None
    return body.get("message") if isinstance(body, dict) else None


def push_ws_errors(websocket_endpoint_url: str):
    """
    Push unsuccessful responses (status code >= 300) to the calling client via websockets.

    Usage:
        @push_ws_errors('https://api_id.execute_id.region.amazonaws.com/Prod')
        @handle_all_exceptions()
        def handler(event, context):
            return {
                'statusCode': 400,
                'body': {
                    'message': 'Bad request'
                }
            }

    A missing statusCode counts as 500. The pushed payload is
    {"type": "error", "statusCode": ..., "message": <body["message"]>}. Nothing is pushed when the
    arguments carry no websocket connection id.

    Args:
        websocket_endpoint_url (str): The API Gateway connection URL.

    Returns:
        The original response from the lambda handler.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            connection_id = find_websocket_connection_id(args)

            response = func(*args, **kwargs)
            status_code = response.get("statusCode", HTTPStatus.INTERNAL_SERVER_ERROR)
            # int() accepts plain ints (e.g. ExceptionHandler's default 400) as well as HTTPStatus members
            # (HTTPStatus is an IntEnum); ``.value`` only worked for the latter.
            success = int(status_code) < 300

            if connection_id and not success:
                websocket_endpoint = get_websocket_endpoint(websocket_endpoint_url)

                ws_response = {
                    "type": "error",
                    "statusCode": status_code,
                    "message": _ws_error_message(response.get("body", "{}"))
                }

                websocket_endpoint.post_to_connection(
                    ConnectionId=connection_id,
                    Data=json.dumps(ws_response)
                )

            return response
        return wrapper
    return decorator
387 |
388 |
def push_ws_response(websocket_endpoint_url: str):
    """
    Push every response of the decorated handler to the calling client via websockets.

    Usage:
        @push_ws_response('https://api_id.execute_id.region.amazonaws.com/Prod')
        def handler(event, context):
            return {
                'statusCode': 200,
                'body': 'Hello, world!'
            }

    The whole response is sent as JSON. Nothing is pushed when the arguments carry no websocket
    connection id.

    Args:
        websocket_endpoint_url (str): The API Gateway connection URL.

    Returns:
        The original response from the lambda handler.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            connection_id = find_websocket_connection_id(args)
            response = func(*args, **kwargs)

            if not connection_id:
                return response

            get_websocket_endpoint(websocket_endpoint_url).post_to_connection(
                ConnectionId=connection_id,
                Data=json.dumps(response)
            )
            return response
        return wrapper
    return decorator
424 |
425 |
def hsts(max_age: int = None):
    """
    Add the Strict-Transport-Security header to the response of the decorated function.

    Usage:
        @hsts(max_age=86400)
        def func(my_param)
            pass

    Args:
        max_age (int): How long (in seconds) browsers should remember to use HTTPS only for the domain.
            When omitted (None), 2 years (63072000 s) is used. 0 is honoured, which tells browsers to
            expire any stored HSTS policy.

    Returns:
        A decorator. The response gains the header (replacing an existing one with the same name, in
        any case); a 500 failure is returned when the response is not a dict.
    """
    default_max_age = 63072000  # 2 years, the value recommended by MDN
    # Compare against None rather than truthiness so that max_age=0 is not replaced by the default.
    header_value = f"max-age={default_max_age if max_age is None else max_age}"

    def decorator(func):
        def wrapper(*args, **kwargs):
            response = func(*args, **kwargs)

            if not isinstance(response, dict):
                LOGGER.error(NON_DICT_LOG_MESSAGE)
                return failure(HSTS_NON_DICT_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)

            headers_key = find_key_case_insensitive("headers", response)
            resp_headers = response[headers_key] if headers_key in response else {}

            header_key = find_key_case_insensitive("Strict-Transport-Security", resp_headers)
            resp_headers[header_key] = header_value
            response[headers_key] = resp_headers
            return response
        return wrapper
    return decorator
462 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions."""
2 | from functools import lru_cache
3 | from http import HTTPStatus
4 | import inspect
5 | import json
6 | import keyword
7 | import logging
8 | import os
9 |
10 | import boto3
11 |
12 |
def get_log_level(level_name):
    """
    Translate a logging level name to its numeric value.

    Args:
        level_name (str): A level name in any case, e.g. "debug", "INFO", "Warning".

    Returns:
        The numeric level, or logging.INFO when the name is not a registered logging level (instead of
        failing at import time, as a bare getattr(logging, ...) lookup did).
    """
    level = logging.getLevelName(str(level_name).strip().upper())
    return level if isinstance(level, int) else logging.INFO


LOG_LEVEL = get_log_level(os.getenv("LOG_LEVEL", "INFO"))
14 |
15 |
def get_logger(name):
    """
    Return the logger called ``name`` with its level set to the module-wide LOG_LEVEL.

    Args:
        name (str): The logger name (typically ``__name__``).

    Returns:
        The configured logging.Logger.
    """
    named_logger = logging.getLogger(name)
    named_logger.setLevel(LOG_LEVEL)
    return named_logger
20 |
21 |
def full_name(class_type):
    """
    Return the fully qualified class name of an object (despite the parameter name, an instance).

    From https://stackoverflow.com/questions/2020014/get-fully-qualified-class-name-of-an-object-in-python

    Args:
        class_type (object): The object whose class name is wanted (e.g. an exception instance).

    Returns:
        "module.ClassName", or just "ClassName" for builtins or classes without a module.
    """
    cls = class_type.__class__
    module = cls.__module__
    if module is None or module == str.__class__.__module__:
        return cls.__name__  # Avoid reporting __builtin__
    return f"{module}.{cls.__name__}"
38 |
39 |
def is_type_in_list(item_type, items):
    """
    Tell whether any element of ``items`` is an instance of ``item_type``.

    Args:
        item_type (type): The type to look for.
        items (list): The items to inspect.

    Returns:
        True on the first matching item, False if none match (including an empty list).
    """
    for item in items:
        if isinstance(item, item_type):
            return True
    return False
52 |
53 |
def is_valid_variable_name(name):
    """
    Tell whether ``name`` can be used as a Python variable name.

    Args:
        name (str): The candidate name.

    Returns:
        True if ``name`` is an identifier and not a reserved keyword.
    """
    if not name.isidentifier():
        return False
    return not keyword.iskeyword(name)
65 |
66 |
def all_func_args(func, args, kwargs):
    """
    Map a call's positional and keyword arguments to parameter names.

    Args:
        func (function): The function whose named positional parameters are used.
        args (list): The positional arguments of the call (*args).
        kwargs (dict): The keyword arguments of the call (**kwargs).

    Returns:
        dict of argument name -> value. Positional values beyond the named parameters (those a ``*args``
        parameter would capture) are ignored instead of raising IndexError; keyword arguments override
        positional ones with the same name.
    """
    positional_names = inspect.getfullargspec(func).args
    arg_dictionary = dict(zip(positional_names, args))
    arg_dictionary.update(kwargs)
    return arg_dictionary
83 |
84 |
def find_key_case_insensitive(key_name, the_dict):
    """
    Find the key of a dictionary that matches ``key_name`` ignoring case.

    Both sides are lowercased before comparing, so mixed-case search names such as
    "Strict-Transport-Security" match existing keys like "strict-transport-security". Non-string keys
    are skipped.

    Args:
        key_name (str): The key to search for.
        the_dict (dict): The dictionary to search.

    Returns:
        The matching key in its original case if found; otherwise ``key_name`` unchanged.
    """
    target = key_name.lower()
    for key in the_dict:
        if isinstance(key, str) and key.lower() == target:
            return key
    return key_name
101 |
102 |
def failure(errors, status_code=HTTPStatus.BAD_REQUEST):
    """
    Build an error response for the caller.

    Args:
        errors (list|str): The error(s) to report; placed under "message" in the JSON body.
        status_code (int): The response status code. Defaults to HTTPStatus.BAD_REQUEST (400).

    Returns:
        {"statusCode": status_code, "body": '{"message": errors}'} with the body JSON-encoded.
    """
    body = json.dumps({"message": errors})
    return {"statusCode": status_code, "body": body}
118 |
119 |
def find_websocket_connection_id(args: list) -> str:
    """
    Find the API Gateway websocket connection id among a lambda's arguments.

    Args:
        args (list): The positional arguments of a lambda (*args).

    Returns:
        The "connectionId" from the "requestContext" of the first dict argument that has a
        "requestContext" key (None if that context has no id), or None when no such argument exists.
    """
    event = next((arg for arg in args if isinstance(arg, dict) and "requestContext" in arg), None)
    if event is None:
        return None
    return event["requestContext"].get("connectionId")
136 |
137 |
@lru_cache()
def get_websocket_endpoint(endpoint_url: str) -> "botocore.client.ApiGatewayManagementApi":  # noqa: pyflakes - F821
    """
    Return an API Gateway Management API client for sending websocket messages.

    Clients are cached per endpoint URL by lru_cache.

    Args:
        endpoint_url (str): The API Gateway connection URL.

    Returns:
        The (cached) boto3 "apigatewaymanagementapi" client.
    """
    client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
    return client
154 |
--------------------------------------------------------------------------------
/aws_lambda_decorators/validators.py:
--------------------------------------------------------------------------------
1 | """Validation rules."""
2 | import datetime
3 | import re
4 | from schema import SchemaError
5 |
# ISO 4217 currency codes accepted by CurrencyValidator (it upper-cases the input before the lookup).
CURRENCIES = set("""
    LKR ETB RWF NZD SBD MKD NPR LAK KWD INR HUF AFN BTN ISK MVR
    WST MNT AZN SAR JMD BIF BMD CAD GEL MXN BHD HKD RSD PKR SLL
    NGN TOP SCR SVC CHW UYW IDR IQD THB GBP MYR SDG CNY GNF LRD
    KHR TJS BYN SHP AED BOB CUC PHP SSP USN MZN COP SEK EUR CDF
    CRC KMF JPY ZWL ALL GHS GIP QAR GYD HTG VUV CZK ANG AWG AMD
    DOP TRY ZMW MGA KZT XUA ARS XPF BRL MXV LSL CLP KES PYG TND
    MAD DZD MWK BSD BBD FKP KGS BWP CVE HRK DKK COU SYP LYD PLN
    TZS KPW UGX BOV UAH NAD AOA VES SOS CUP SGD PAB UZS STN SRD
    CHE XOF DJF PGK UYI XCD BZD EGP ERN RON TWD USD FJD VND SZL
    BND HNL KRW XAF MDL BDT MUR PEN OMR NIO TMT YER TTD GMD XDR
    CHF NOK GTQ JOD KYD UYU RUB ZAR AUD BGN MOP LBP MRU CLF XSU
    BAM MMK IRR ILS
""".split())
18 |
19 |
class Validator:  # noqa: pylint - too-few-public-methods
    """Base class for validation rules: holds an error message template and the condition being checked."""
    ERROR_MESSAGE = "Unknown error"

    def __init__(self, error_message, condition=None):
        """
        Stores the error message template and the validation condition

        Args:
            error_message (str): A custom error message template; when falsy, the class ERROR_MESSAGE is used
            condition (any): The condition to validate against (exposed to the template as {condition})
        """
        self._condition = condition
        self._error_message = error_message if error_message else self.ERROR_MESSAGE

    def message(self, value=None):  # noqa: pylint - unused-argument
        """
        Renders the error message template for a failed validation

        Args:
            value (any): The validated value (exposed to the template as {value})

        Returns:
            The formatted error message
        """
        placeholders = {"value": value, "condition": self._condition}
        return self._error_message.format(**placeholders)
46 |
47 |
class Mandatory(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule that rejects missing (None) values and values whose string form is empty."""
    ERROR_MESSAGE = "Missing mandatory value"

    def __init__(self, error_message=None):
        """
        Creates a mandatory-value rule

        Args:
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message)

    @staticmethod
    def validate(value=None):
        """
        Check that a value is present.

        Args:
            value (any): Value to be validated.

        Returns:
            False when value is None; otherwise str(value), which is falsy only for an empty string.
        """
        if value is None:
            return False
        return str(value)
70 |
71 |
class RegexValidator(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check if a value fully matches a regular expression."""
    ERROR_MESSAGE = "'{value}' does not conform to regular expression '{condition}'"

    def __init__(self, regex="", error_message=None):
        """
        Compile a regular expression to a regular expression pattern.

        Args:
            regex (str): Regular expression for parameter validation.
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, regex)
        self._regexp = re.compile(regex)

    def validate(self, value=None):
        """
        Check if a value adheres to the defined regular expression.

        Args:
            value (str): Value to be validated.

        Returns:
            True when value is None or the whole value matches the pattern. False when it does not match,
            or when value is of a type the compiled pattern cannot be applied to (e.g. an int).
        """
        if value is None:
            return True

        try:
            return self._regexp.fullmatch(value) is not None
        except TypeError:
            # fullmatch raises TypeError for non-string input (e.g. a number taken from a JSON body);
            # treat that as a validation failure instead of letting the error escape.
            return False
98 |
99 |
class SchemaValidator(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check if a value validates against a schema."""
    ERROR_MESSAGE = "'{value}' does not validate against schema '{condition}'"

    def __init__(self, schema, error_message=None):
        """
        Store the schema to validate against.

        Args:
            schema (Schema): The expected schema.
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, schema)

    def validate(self, value=None):
        """
        Check if the object adheres to the defined schema.

        Args:
            value (object): Value to be validated.

        Returns:
            True when value is None, or when schema.validate(value) returns data equal to value.
            False when it returns different data or raises SchemaError.
        """
        if value is None:
            return True

        try:
            validated = self._condition.validate(value)
            return validated == value
        except SchemaError:
            return False
128 |
129 |
class Minimum(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check that a numeric value is not below a minimum value."""
    ERROR_MESSAGE = "'{value}' is less than minimum value '{condition}'"

    def __init__(self, minimum: (float, int), error_message=None):
        """
        Store the minimum value.

        Args:
            minimum (float, int): The minimum value (inclusive).
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, minimum)

    def validate(self, value=None):
        """
        Check that the value is greater than or equal to the minimum.

        Args:
            value (float, int): Value to be validated.

        Returns:
            True for None; for a float/int, whether minimum <= value; False for any other type.
        """
        if value is None:
            return True

        return isinstance(value, (float, int)) and self._condition <= value
158 |
159 |
class Maximum(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check that a numeric value does not exceed a maximum value."""
    ERROR_MESSAGE = "'{value}' is greater than maximum value '{condition}'"

    def __init__(self, maximum: (float, int), error_message=None):
        """
        Store the maximum value.

        Args:
            maximum (float, int): The maximum value (inclusive).
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, maximum)

    def validate(self, value=None):
        """
        Check that the value is less than or equal to the maximum.

        Args:
            value (float, int): Value to be validated.

        Returns:
            True for None; for a float/int, whether maximum >= value; False for any other type.
        """
        if value is None:
            return True

        return isinstance(value, (float, int)) and self._condition >= value
188 |
189 |
class MinLength(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check that a value's string form is at least a minimum length."""
    ERROR_MESSAGE = "'{value}' is shorter than minimum length '{condition}'"

    def __init__(self, min_length: int, error_message=None):
        """
        Store the minimum length.

        Args:
            min_length (int): The minimum length.
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, min_length)

    def validate(self, value=None):
        """
        Check that a value is not shorter than the minimum length.

        The length is measured on str(value), so non-string values are measured by their string representation.

        Args:
            value (str): Value to be validated.

        Returns:
            True when value is None or len(str(value)) >= the minimum length, False otherwise.
        """
        if value is None:
            return True

        actual_length = len(str(value))
        return actual_length >= self._condition
215 |
216 |
class MaxLength(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check that a value's string form is at most a maximum length."""
    ERROR_MESSAGE = "'{value}' is longer than maximum length '{condition}'"

    def __init__(self, max_length: int, error_message=None):
        """
        Store the maximum length.

        Args:
            max_length (int): The maximum length.
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, max_length)

    def validate(self, value=None):
        """
        Check that a value is not longer than the maximum length.

        The length is measured on str(value), so non-string values are measured by their string representation.

        Args:
            value (str): Value to be validated.

        Returns:
            True when value is None or len(str(value)) <= the maximum length, False otherwise.
        """
        if value is None:
            return True

        actual_length = len(str(value))
        return actual_length <= self._condition
242 |
243 |
class Type(Validator):
    """Validation rule to check that a value is an instance of a given type."""
    ERROR_MESSAGE = "'{value}' is not of type '{condition.__name__}'"

    def __init__(self, valid_type: type, error_message=None):
        """
        Store the valid type.

        Args:
            valid_type (type): The value type to check.
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, valid_type)

    def validate(self, value=None):
        """
        Check if a value is of the right type

        Args:
            value (object): object to be validated.

        Returns:
            True when value is None, otherwise isinstance(value, valid_type).
        """
        return value is None or isinstance(value, self._condition)
268 |
269 |
class EnumValidator(Validator):
    """Validation rule to check that a value is one of an allowed set of values."""
    ERROR_MESSAGE = "'{value}' is not in list '{condition}'"

    def __init__(self, *args: list, error_message=None):
        """
        Store the allowed values.

        Args:
            error_message (str): A custom error message to output if validation fails
            args (list): The valid values, passed positionally (stored as a tuple)
        """
        super().__init__(error_message, args)

    def validate(self, value=None):
        """
        Check if a value is in a list of valid values

        Args:
            value (object): object to be validated.

        Returns:
            True when value is None, otherwise whether value is one of the allowed values.
        """
        return True if value is None else value in self._condition
294 |
295 |
class NonEmpty(Validator):  # noqa: pylint - too-few-public-methods
    """Validation rule to check that a given value is not empty."""
    ERROR_MESSAGE = "Value is empty"

    def __init__(self, error_message=None):
        """
        Creates a non-empty rule

        Args:
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message)

    @staticmethod
    def validate(value=None):
        """
        Check if the given value is non empty.

        Args:
            value (any): Value to be validated.

        Returns:
            True for None and for values equal to a numeric zero (0, 0.0, 0j; this includes False).
            Otherwise, the truthiness of value.
        """
        if value is None:
            return True

        # Zero is a legitimate value, not an "empty" one, even though it is falsy
        if value in (0, 0.0, 0j):
            return True

        return bool(value)
321 |
322 |
class DateValidator(Validator):
    """Validation rule to check if a string is a valid date according to some format."""
    ERROR_MESSAGE = "'{value}' is not a '{condition}' date"

    def __init__(self, date_format: str, error_message=None):
        """
        Checks if a string is a date with a given format

        Args:
            date_format (str): The date format to check against (datetime.strptime syntax)
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message, date_format)

    def validate(self, value=None):
        """
        Check if a string is a date with a given format

        Args:
            value (str): string date to validate against a format

        Returns:
            True when value is None or parses with the format. False when parsing fails or value is not a string.
        """
        if value is None:
            return True

        try:
            datetime.datetime.strptime(value, self._condition)
        except (TypeError, ValueError):
            # ValueError: the string does not match the format.
            # TypeError: value is not a string (e.g. a number from a JSON body); fail validation instead of crashing.
            return False
        return True
353 |
354 |
class CurrencyValidator(Validator):
    """Validation rule to check if a string is a valid currency according to ISO 4217 Currency Code."""
    ERROR_MESSAGE = "'{value}' is not a valid currency code."

    def __init__(self, error_message=None):
        """
        Checks if a string is a valid currency based on ISO 4217

        Args:
            error_message (str): A custom error message to output if validation fails
        """
        super().__init__(error_message)

    @staticmethod
    def validate(value=None):
        """
        Check if a string is a valid currency based on ISO 4217

        Args:
            value (str): value to validate against a ISO 4217

        Returns:
            True when value is None or, upper-cased, is in CURRENCIES. False otherwise, including non-string values.
        """
        if value is None:
            return True

        # Non-strings (e.g. numbers) have no .upper(); reject them instead of raising AttributeError
        if not isinstance(value, str):
            return False

        return value.upper() in CURRENCIES
381 |
--------------------------------------------------------------------------------
/buildspec.yml:
--------------------------------------------------------------------------------
version: 0.2

phases:
  install:
    commands:
      # Install the tools used by the pre_build checks below
      - pip install --upgrade pip
      - pip install -q boto3 bandit coverage==4.5.4 schema pylint_quotes prospector==1.3.1 PyJWT==1.7.1
  pre_build:
    commands:
      - export LOG_LEVEL=CRITICAL
      - export OUR_COMMIT_SHA="$(git rev-parse HEAD)"
      # Security scan, lint, then tests under coverage with a 100% threshold
      - bandit -r -q .
      - prospector
      - coverage run --source=. -m unittest
      - coverage report -m --fail-under=100 --omit='*/__init__.py,tests/*,setup.py,examples/*'
16 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gridsmartercities/aws-lambda-decorators/16dbe6ae1b9982f312d593336682c4ebbcd4f52d/examples/__init__.py
--------------------------------------------------------------------------------
/examples/examples.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | import boto3
3 | from boto3.dynamodb.conditions import Key
4 | from botocore.exceptions import ClientError
5 | from schema import Schema, Or
6 | from aws_lambda_decorators import (extract, extract_from_event, extract_from_context, extract_from_ssm, validate, log,
7 | handle_exceptions, response_body_as_json, Parameter, SSMParameter,
8 | ValidatedParameter, ExceptionHandler, Mandatory, RegexValidator,
9 | handle_all_exceptions, cors, SchemaValidator, Maximum, Minimum, Type, EnumValidator,
10 | NonEmpty, DateValidator, CurrencyValidator, hsts)
11 |
12 |
@extract(parameters=[
    # optional my_param read from a_dictionary; when found it replaces the function default
    Parameter(func_param_name="a_dictionary", path="/parent/my_param"),
    # optional and absent from the input, so the function default is kept
    Parameter(func_param_name="a_dictionary", path="/parent/missing_non_mandatory"),
    # also absent, but without a Mandatory validator this is not an error
    Parameter(func_param_name="a_dictionary", path="/parent/missing_mandatory"),
    # mandatory id read from another_dictionary and passed to the function as user_id
    Parameter(func_param_name="another_dictionary", path="/parent/child/id", var_name="user_id",
              validators=[Mandatory])
])
def extract_example(a_dictionary, another_dictionary, my_param="aDefaultValue",
                    missing_non_mandatory="I am missing", missing_mandatory=None, user_id=None):
    """Returns the extracted parameters (or their defaults) as a tuple."""
    extracted = (my_param, missing_non_mandatory, missing_mandatory, user_id)
    return extracted
28 |
29 |
@extract(parameters=[
    # my_param is not a named argument of the function, so it is delivered through **kwargs
    Parameter(func_param_name="a_dictionary", path="/parent/my_param")
])
def extract_to_kwargs_example(a_dictionary, **kwargs):
    """Returns the my_param value received in kwargs."""
    my_param = kwargs["my_param"]
    return my_param
36 |
37 |
@extract(parameters=[
    # mandatory_param must be present under /parent in a_dictionary
    Parameter(func_param_name="a_dictionary", path="/parent/mandatory_param", validators=[Mandatory])
])
def extract_mandatory_param_example(a_dictionary, mandatory_param=None):
    # only reached when mandatory_param was found; otherwise the decorator responds with a 400 error
    return "Here!"
44 |
45 |
@extract(parameters=[
    # two mandatory parameters and a range-checked integer, all read from a_dictionary
    Parameter(func_param_name="a_dictionary", path="/parent/mandatory_param", validators=[Mandatory]),
    Parameter(func_param_name="a_dictionary", path="/parent/another_mandatory_param", validators=[Mandatory]),
    Parameter(func_param_name="a_dictionary", path="/parent/an_int",
              validators=[Maximum(maximum=10), Minimum(minimum=5)])
], group_errors=True)  # report every failing parameter in one response instead of stopping at the first
def extract_multiple_param_example(a_dictionary, mandatory_param=None, another_mandatory_param=None, an_int=0):
    # only reached when every parameter passes its validators
    return "Here!"
54 |
55 |
@extract(parameters=[
    # the [json] annotation parses the "parent" string as JSON before my_param is read from it
    Parameter(func_param_name="a_dictionary", path="/parent[json]/my_param")
])
def extract_from_json_example(a_dictionary, my_param=None):
    """Returns my_param as read from the JSON-encoded parent value."""
    return my_param
62 |
63 |
@extract_from_event(parameters=[
    # mandatory my_param read from the JSON-encoded request body
    Parameter(path="/body[json]/my_param", validators=[Mandatory]),
    # mandatory "sub" claim of the Authorization JWT, passed to the function as user_id
    Parameter(path="/headers/Authorization[jwt]/sub", var_name="user_id", validators=[Mandatory])
])
def extract_from_event_example(event, context, my_param=None, user_id=None):
    """Returns the body parameter and JWT subject, e.g. ("Hello!", "1234567890")."""
    return my_param, user_id
72 |
73 |
@extract_from_context(parameters=[
    # mandatory my_param read from the "parent" element of the lambda context
    Parameter(validators=[Mandatory], path="/parent/my_param")
])
def extract_from_context_example(event, context, my_param=None):
    """Returns my_param taken from the context, e.g. "Hello!"."""
    return my_param
80 |
81 |
@extract_from_ssm(ssm_parameters=[
    # without var_name, the value is passed under its SSM name: one_key
    SSMParameter(ssm_name="one_key"),
    # var_name renames the kwarg: the value of another_key arrives as "another"
    SSMParameter(var_name="another", ssm_name="another_key")
])
def extract_from_ssm_example(your_func_params, one_key=None, another=None):
    """Returns the original argument followed by both SSM values."""
    return your_func_params, one_key, another
90 |
91 |
@validate(parameters=[
    # a_param must be present
    ValidatedParameter(validators=[Mandatory], func_param_name="a_param"),
    # another_param must be present and made of digits only
    ValidatedParameter(validators=[Mandatory, RegexValidator(regex=r"\d+")], func_param_name="another_param"),
    # param_with_schema must match the schema: key "a" holding a str or a dict
    ValidatedParameter(validators=[SchemaValidator(schema=Schema({"a": Or(str, dict)}))],
                       func_param_name="param_with_schema")
])
def validate_example(a_param, another_param, param_with_schema):
    """Returns the three validated arguments unchanged."""
    return a_param, another_param, param_with_schema
102 |
103 |
@log(response=True, parameters=True)
def log_example(parameters):
    """Returns a fixed message; the decorator logs both the call parameters and this response."""
    return "Done!"
107 |
108 |
@handle_exceptions(handlers=[
    # a botocore ClientError is turned into an error response carrying this message
    ExceptionHandler(ClientError, "Your message when a client error happens.")
])
def handle_exceptions_example():
    users_table = boto3.resource("dynamodb").Table("your_table_name")
    users_table.query(KeyConditionExpression=Key("user_id").eq("1234"))
    # ...
117 |
118 |
@response_body_as_json
def response_body_as_json_example():
    """Returns a dict body; the decorator serialises it to a JSON string."""
    body = {"param": "hello!"}
    return {"statusCode": 400, "body": body}
122 |
123 |
@extract(parameters=[
    # [1] selects the second element of the "parent" list before my_param is read from it
    Parameter(func_param_name="a_dictionary", path="/parent[1]/my_param")
])
def extract_from_list_example(a_dictionary, my_param=None):
    """Returns my_param taken from the selected list item."""
    return my_param
130 |
131 |
@handle_all_exceptions()
def handle_all_exceptions_example():
    numbers = [1, 2, 3]
    # deliberately out of range: the IndexError is converted into an error response by the decorator
    _ = numbers[5]
    # ...
137 |
138 |
@cors(allow_methods="POST", allow_origin="*", max_age=86400, allow_headers="Content-Type")
def cors_example():
    """Returns a bare 200 response; the decorator adds the access-control headers."""
    response = {"statusCode": 200}
    return response
142 |
143 |
@hsts()
def hsts_example():
    """Returns a bare 200 response; the decorator adds the Strict-Transport-Security header."""
    return {"statusCode": 200}
147 |
148 |
@extract(parameters=[
    # the custom message may reference the validated value ({value}) and the rule's condition ({condition})
    Parameter(func_param_name="a_dictionary", path="/parent/an_int",
              validators=[Minimum(minimum=100, error_message="Bad value {value}: should be at least {condition}")])
])
def extract_minimum_param_with_custom_error_example(a_dictionary, an_int=None):
    return {}
155 |
156 |
@extract(parameters=[
    # neither value is a named argument, so both arrive in **kwargs
    Parameter(func_param_name="a_dictionary", path="/params/my_param_1"),
    Parameter(func_param_name="a_dictionary", path="/params/my_param_2")
])
def extract_dictionary_example(a_dictionary, **kwargs):
    """Returns the extracted values as the kwargs dictionary."""
    return kwargs
163 |
164 |
@extract(parameters=[
    # a_bool must be an instance of bool when present
    Parameter(func_param_name="a_dictionary", path="/params/a_bool", validators=[Type(valid_type=bool)])
])
def extract_type_param(a_dictionary, a_bool=False):
    return a_bool
170 |
171 |
@extract(parameters=[
    # an_enum must be either "Hello" or "Bye" when present
    Parameter(func_param_name="a_dictionary", path="/params/an_enum", validators=[EnumValidator("Hello", "Bye")])
])
def extract_enum_param(a_dictionary, an_enum=None):
    return an_enum
177 |
178 |
@extract(parameters=[
    # non_empty must not be an empty value when present
    Parameter(func_param_name="a_dictionary", path="/params/non_empty", validators=[NonEmpty])
])
def extract_non_empty_param(a_dictionary, non_empty=None):
    return non_empty
184 |
185 |
@extract(parameters=[
    # date_example must parse with the given strptime format when present
    Parameter(func_param_name="a_dictionary", path="/params/date_example",
              validators=[DateValidator(date_format="%Y-%m-%d %H:%M:%S")])
])
def extract_date_param(a_dictionary, date_example=None):
    return date_example
192 |
193 |
@extract(parameters=[
    # the extracted value is passed through int() before reaching the function
    Parameter(func_param_name="a_dictionary", path="/params/my_param", transform=int)
])
def extract_with_transform_example(a_dictionary, my_param=None):
    return my_param
199 |
200 |
def to_int(arg):
    """
    Converts arg to an int, replacing any conversion error with a custom message.

    Args:
        arg: the value to convert

    Returns:
        int(arg)

    Raises:
        Exception: "My custom error message", chained to the original conversion error
    """
    try:
        return int(arg)
    except Exception as err:
        # chain explicitly so the original failure stays available as __cause__ (pylint raise-missing-from)
        raise Exception("My custom error message") from err
206 |
207 |
@extract(parameters=[
    # to_int replaces conversion failures with its own error message
    Parameter(func_param_name="a_dictionary", path="/params/my_param", transform=to_int)
])
def extract_with_custom_transform_example(a_dictionary, my_param=None):
    return {}
213 |
214 |
@extract(parameters=[
    # currency_example must be an ISO 4217 currency code when present
    Parameter(func_param_name="a_dictionary", path="/params/currency_example",
              validators=[CurrencyValidator])
])
def extract_currency_param(a_dictionary, currency_example=None):
    return currency_example
221 |
--------------------------------------------------------------------------------
/examples/test_examples.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch, MagicMock, call
3 | from botocore.exceptions import ClientError
4 | from examples.examples import (extract_example, extract_to_kwargs_example, extract_mandatory_param_example,
5 | extract_from_json_example, extract_from_event_example, extract_from_context_example,
6 | extract_from_ssm_example, validate_example, log_example, handle_exceptions_example,
7 | response_body_as_json_example, extract_from_list_example, handle_all_exceptions_example,
8 | cors_example, extract_multiple_param_example,
9 | extract_minimum_param_with_custom_error_example, extract_dictionary_example,
10 | extract_type_param, extract_enum_param, extract_non_empty_param, extract_date_param,
11 | extract_with_transform_example, extract_with_custom_transform_example,
12 | extract_currency_param, hsts_example)
13 |
14 |
15 | # pylint:disable=too-many-public-methods
16 | class ExamplesTests(unittest.TestCase):
17 |
18 | def test_extract_example(self):
19 | # Given these two dictionaries:
20 | a_dict = {
21 | "parent": {
22 | "my_param": "Hello!"
23 | },
24 | "other": "other value"
25 | }
26 | b_dict = {
27 | "parent": {
28 | "child": {
29 | "id": "123"
30 | }
31 | }
32 | }
33 |
34 | # calling the decorated extract_example:
35 | response = extract_example(a_dict, b_dict)
36 |
37 | # will return the values of the extracted parameters.
38 | self.assertEqual(("Hello!", "I am missing", None, "123"), response)
39 |
40 | def test_extract_to_kwargs_example(self):
41 | # Given this dictionary:
42 | dictionary = {
43 | "parent": {
44 | "my_param": "Hello!"
45 | },
46 | "other": "other value"
47 | }
48 | # we can extract "my_param".
49 | response = extract_to_kwargs_example(dictionary)
50 |
51 | # and get the value from kwargs.
52 | self.assertEqual("Hello!", response)
53 |
54 | def test_extract_missing_mandatory_example(self):
55 | # Given this dictionary:
56 | dictionary = {
57 | "parent": {
58 | "my_param": "Hello!"
59 | },
60 | "other": "other value"
61 | }
62 | # we can try to extract a missing mandatory parameter.
63 | response = extract_mandatory_param_example(dictionary)
64 |
65 | # but we will get an error response as it is missing and it was mandatory.
66 | self.assertEqual(
67 | {"statusCode": 400, "body": "{\"message\": [{\"mandatory_param\": [\"Missing mandatory value\"]}]}"},
68 | response)
69 |
70 | def test_extract_multiple_param_example(self):
71 | # Given this dictionary:
72 | dictionary = {
73 | "parent": {
74 | "my_param": "Hello!",
75 | "an_int": 20
76 | },
77 | "other": "other value"
78 | }
79 | # we can try to extract a missing mandatory parameter.
80 | response = extract_multiple_param_example(dictionary)
81 |
82 | # but we will get an error response as it is missing and it was mandatory.
83 | self.assertEqual({"statusCode": 400,
84 | "body": "{\"message\": [{\"mandatory_param\": [\"Missing mandatory value\"]}, "
85 | "{\"another_mandatory_param\": [\"Missing mandatory value\"]}, "
86 | "{\"an_int\": [\"'20' is greater than maximum value '10'\"]}]}"}, response)
87 |
88 | def test_extract_not_missing_mandatory_example(self):
89 | # Given this dictionary:
90 | dictionary = {
91 | "parent": {
92 | "mandatory_param": "Hello!"
93 | }
94 | }
95 |
96 | # we can try to extract a mandatory parameter.
97 | response = extract_mandatory_param_example(dictionary)
98 |
99 | # we will get the coded lambda response as the parameter is not missing.
100 | self.assertEqual("Here!", response)
101 |
102 | def test_extract_from_json_example(self):
103 | # Given this dictionary:
104 | dictionary = {
105 | "parent": "{\"my_param\": \"Hello!\"}",
106 | "other": "other value"
107 | }
108 |
109 | # we can extract from a json string by adding the [json] annotation to parent.
110 | response = extract_from_json_example(dictionary)
111 |
112 | # and we will get the value of the "my_param" parameter inside the "parent" json.
113 | self.assertEqual("Hello!", response)
114 |
115 | def test_extract_from_event_example(self):
116 | # Given this API Gateway event:
117 | event = {
118 | "body": "{\"my_param\": \"Hello!\"}",
119 | "headers": {
120 | "Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9l"
121 | "IiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
122 | }
123 | }
124 |
125 | # we can extract "my_param" from event, and "sub" from Authorization JWT
126 | # using the "extract_from_event" decorator.
127 | response = extract_from_event_example(event, None)
128 |
129 | # and return both values in the lambda response.
130 | self.assertEqual(("Hello!", "1234567890"), response)
131 |
132 | def test_extract_from_context_example(self):
133 | # Given this context:
134 | context = {
135 | "parent": {
136 | "my_param": "Hello!"
137 | }
138 | }
139 |
140 | # we can extract "my_param" from context using the "extract_from_context" decorator.
141 | response = extract_from_context_example(None, context)
142 |
143 | # and return the value in the lambda response.
144 | self.assertEqual("Hello!", response)
145 |
146 | @patch("boto3.client")
147 | def test_extract_from_ssm_example(self, mock_boto_client):
148 | # Mocking the extraction of an SSM parameter from AWS SSM:
149 | mock_ssm = MagicMock()
150 | mock_ssm.get_parameters.return_value = {
151 | "Parameters": [
152 | {
153 | "Value": "test1",
154 | "Name": "one_key"
155 | },
156 | {
157 | "Value": "test2",
158 | "Name": "another_key"
159 | }
160 | ]
161 | }
162 | mock_boto_client.return_value = mock_ssm
163 |
164 | # we can extract the value of that parameter.
165 | response = extract_from_ssm_example(None)
166 |
167 | # and return the mocked SSM parameters from the lambda.
168 | self.assertEqual((None, "test1", "test2"), response)
169 |
170 | def test_validate_example(self):
171 | # We can validate non-dictionary parameters too, using the "validate" decorator.
172 | response = validate_example("Hello!", "123456", {"a": {"b": "c"}})
173 |
174 | # in this case the parameters are valid and are returned by the function.
175 | self.assertEqual(("Hello!", "123456", {"a": {"b": "c"}}), response)
176 |
177 | def test_validate_raises_exception_example(self):
178 | # We can validate non-dictionary parameters too, using the "validate" decorator.
179 | response = validate_example("Hello!", "123456", {"a": 123456})
180 |
181 | # in this case at least one parameter is not valid and a 400 error is returned to the caller.
182 | self.assertEqual({
183 | "statusCode": 400,
184 | "body": "{\"message\": [{\"param_with_schema\": [\"'{'a': 123456}' does not validate against schema"
185 | " 'Schema({'a': Or(, )})'\"]}]}"}, response)
186 |
187 | @staticmethod
188 | @patch("aws_lambda_decorators.decorators.LOGGER")
189 | def test_log_example(mock_logger):
190 | # We can use the "log" decorator to log the parameters passed to a lambda and/or the response from the lambda.
191 | log_example("Hello!") # logs "Hello!" and "Done!"
192 |
193 | # and check the log messages were produced.
194 | mock_logger.info.assert_has_calls([
195 | call("Function: %s, Parameters: %s", "log_example", ("Hello!",)),
196 | call("Function: %s, Response: %s", "log_example", "Done!")
197 | ])
198 |
199 | @patch("boto3.resource")
200 | def test_handle_exceptions_example(self, mock_dynamo):
201 | # Mocking the dynamo query to return a ClientError.
202 | mock_table = MagicMock()
203 | client_error = ClientError({}, "")
204 | mock_table.query.side_effect = client_error
205 |
206 | mock_dynamo.return_value.Table.return_value = mock_table
207 |
208 | # we can automatically handle the ClientError, using the "exception_handler" decorator.
209 | response = handle_exceptions_example() # noqa: pylint - assignment-from-no-return
210 |
211 | # and return the error supplied to the caller.
212 | self.assertEqual(response["statusCode"], 400)
213 | self.assertEqual(response["body"], "{\"message\": \"Your message when a client error happens.\"}")
214 |
215 | def test_response_as_json_example(self):
216 | # We can automatically json dump a body dictionary:
217 | response = response_body_as_json_example()
218 |
219 | # the response body is a string.
220 | self.assertEqual("{\"param\": \"hello!\"}", response["body"])
221 |
222 | def test_extract_from_list_example(self):
223 | # Given this dictionary:
224 | dictionary = {
225 | "parent": [
226 | {"my_param": "Hello!"},
227 | {"my_param": "Bye!"}
228 | ],
229 | "other": "other value"
230 | }
231 |
232 | # we can extract from a json string by adding the [json] annotation to parent.
233 | response = extract_from_list_example(dictionary)
234 |
235 | # and we will get the value of the "my_param" parameter inside the "parent" correct item.
236 | self.assertEqual("Bye!", response)
237 |
238 | def test_handle_all_exceptions_example(self):
239 | # we can automatically handle any exceptions, using the "handle_all_exceptions" decorator.
240 | response = handle_all_exceptions_example() # noqa: pylint - assignment-from-no-return
241 |
242 | # and return the error to the caller.
243 | self.assertEqual(response["statusCode"], 400)
244 | self.assertEqual(response["body"], "{\"message\": \"list index out of range\"}")
245 |
def test_cors(self):
    # The "cors" decorator can be applied to any function to add CORS headers to its response.
    response = cors_example()

    self.assertEqual(response["statusCode"], 200)

    # Each access-control header must be present with the value configured on the example.
    headers = response["headers"]
    expected_headers = (
        ("access-control-allow-origin", "*"),
        ("access-control-allow-methods", "POST"),
        ("access-control-allow-headers", "Content-Type"),
        ("access-control-max-age", 86400),
    )
    for header_name, header_value in expected_headers:
        self.assertEqual(headers[header_name], header_value)
256 |
def test_hsts(self):
    # The "hsts" decorator can be applied to any function to add an HSTS header to its response.
    response = hsts_example()

    # The status code is left alone and the header carries the configured max-age.
    self.assertEqual(response["statusCode"], 200)
    hsts_header = response["headers"]["Strict-Transport-Security"]
    self.assertEqual(hsts_header, "max-age=63072000")
264 |
def test_extract_minimum_param_with_custom_error_example(self):
    # Validators accept custom error messages, and those messages may embed both the
    # value being validated and the condition it was checked against.
    event = {"parent": {"an_int": 10}}
    response = extract_minimum_param_with_custom_error_example(event)

    # The 400 body carries the custom message, filled in with the offending value (10)
    # and the minimum it was compared with (100).
    expected_body = "{\"message\": [{\"an_int\": [\"Bad value 10: should be at least 100\"]}]}"
    self.assertEqual(response["statusCode"], 400)
    self.assertEqual(response["body"], expected_body)
274 |
def test_extract_dictionary_example(self):
    # Both values sit under the "params" key.
    a_dictionary = {"params": {"my_param_1": "Hello!", "my_param_2": "Bye!"}}

    # The decorated extract_dictionary_example hands the extracted values back as a dictionary.
    result = extract_dictionary_example(a_dictionary)

    expected = {"my_param_1": "Hello!", "my_param_2": "Bye!"}
    self.assertEqual(expected, result)
289 |
def test_extract_type_param(self):
    # Given a dictionary holding a boolean parameter:
    a_dictionary = {
        "params": {
            "a_bool": True
        }
    }

    # calling the decorated extract_type_param:
    response = extract_type_param(a_dictionary)

    # will return the extracted boolean value itself (not a dictionary).
    self.assertEqual(True, response)
303 |
def test_extract_enum_param(self):
    # Given a dictionary holding an enum-like string parameter:
    a_dictionary = {
        "params": {
            "an_enum": "Bye"
        }
    }

    # calling the decorated extract_enum_param:
    response = extract_enum_param(a_dictionary)

    # will return the extracted "an_enum" value as a plain string.
    self.assertEqual("Bye", response)
317 |
def test_extract_non_empty_param(self):
    # Given a dictionary holding a non-empty list parameter:
    a_dictionary = {
        "params": {
            "non_empty": ["first value"]
        }
    }

    # calling the decorated extract_non_empty_param:
    response = extract_non_empty_param(a_dictionary)

    # will return the extracted list itself.
    self.assertEqual(["first value"], response)
331 |
def test_extract_date_param(self):
    # Given a dictionary holding a date string parameter:
    a_dictionary = {
        "params": {
            "date_example": "2001-01-01 00:00:00"
        }
    }

    # calling the decorated extract_date_param:
    response = extract_date_param(a_dictionary)

    # will return the extracted date as the original string (it is not converted to a datetime).
    self.assertEqual("2001-01-01 00:00:00", response)
345 |
def test_extract_currency_param(self):
    # Given a dictionary holding a currency code parameter:
    a_dictionary = {
        "params": {
            "currency_example": "GBP"
        }
    }

    # calling the decorated extract_currency_param:
    response = extract_currency_param(a_dictionary)

    # will return the extracted currency code string.
    self.assertEqual("GBP", response)
359 |
def test_extract_with_transform(self):
    # "my_param" arrives as the string "2" ...
    a_dictionary = {"params": {"my_param": "2"}}

    # ... and the transform configured on extract_with_transform_example turns it into
    # an int, so the decorated function returns the integer 2 rather than the string.
    result = extract_with_transform_example(a_dictionary)
    self.assertEqual(2, result)
373 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_with_custom_transform(self, mock_logger):
    # The custom transform on this example fails for the value "abc".
    a_dictionary = {"params": {"my_param": "abc"}}

    response = extract_with_custom_transform_example(a_dictionary)

    # The caller only sees a generic 400 error ...
    self.assertEqual(response["statusCode"], 400)
    self.assertEqual(response["body"], "{\"message\": \"Error extracting parameters\"}")

    # ... while the logged error carries the custom message, the argument name and the path.
    expected_log_args = (
        "%s: %s in argument %s for path %s",
        "Exception",
        "My custom error message",
        "a_dictionary",
        "/params/my_param",
    )
    mock_logger.error.assert_called_once_with(*expected_log_args)
396 |
--------------------------------------------------------------------------------
/pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loaded into the active Python interpreter and may
5 | # run arbitrary code.
6 | extension-pkg-whitelist=
7 |
8 | # Add files or directories to the blacklist. They should be base names, not
9 | # paths.
10 | ignore=CVS
11 |
12 | # Add files or directories matching the regex patterns to the blacklist. The
13 | # regex matches against base names, not paths.
14 | ignore-patterns=
15 |
16 | # Python code to execute, usually for sys.path manipulation such as
17 | # pygtk.require().
18 | #init-hook=
19 |
20 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
21 | # number of processors available to use.
22 | jobs=1
23 |
24 | # Control the amount of potential inferred values when inferring a single
25 | # object. This can help the performance when dealing with large functions or
26 | # complex, nested conditions.
27 | limit-inference-results=100
28 |
29 | # List of plugins (as comma separated values of python modules names) to load,
30 | # usually to register additional checkers.
31 | load-plugins=pylint_quotes
32 |
33 | # Pickle collected data for later comparisons.
34 | persistent=yes
35 |
36 | # Specify a configuration file.
37 | #rcfile=
38 |
39 | # When enabled, pylint would attempt to guess common misconfiguration and emit
40 | # user-friendly hints instead of false-positive error messages.
41 | suggestion-mode=yes
42 |
43 | # Allow loading of arbitrary C extensions. Extensions are imported into the
44 | # active Python interpreter and may run arbitrary code.
45 | unsafe-load-any-extension=no
46 |
47 |
48 | [MESSAGES CONTROL]
49 |
50 | # Only show warnings with the listed confidence levels. Leave empty to show
51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
52 | confidence=
53 |
54 | # Disable the message, report, category or checker with the given id(s). You
55 | # can either give multiple identifiers separated by comma (,) or put this
56 | # option multiple times (only on the command line, not in the configuration
57 | # file where it should appear only once). You can also use "--disable=all" to
58 | # disable everything first and then reenable specific checks. For example, if
59 | # you want to run only the similarities checker, you can use "--disable=all
60 | # --enable=similarities". If you want to run only the classes checker, but have
61 | # no Warning level messages displayed, use "--disable=all --enable=classes
62 | # --disable=W".
63 | disable=missing-docstring,
64 | print-statement,
65 | parameter-unpacking,
66 | unpacking-in-except,
67 | old-raise-syntax,
68 | backtick,
69 | long-suffix,
70 | old-ne-operator,
71 | old-octal-literal,
72 | import-star-module-level,
73 | non-ascii-bytes-literal,
74 | raw-checker-failed,
75 | bad-inline-option,
76 | locally-disabled,
77 | locally-enabled,
78 | file-ignored,
79 | suppressed-message,
80 | useless-suppression,
81 | deprecated-pragma,
82 | use-symbolic-message-instead,
83 | apply-builtin,
84 | basestring-builtin,
85 | buffer-builtin,
86 | cmp-builtin,
87 | coerce-builtin,
88 | execfile-builtin,
89 | file-builtin,
90 | long-builtin,
91 | raw_input-builtin,
92 | reduce-builtin,
93 | standarderror-builtin,
94 | unicode-builtin,
95 | xrange-builtin,
96 | coerce-method,
97 | delslice-method,
98 | getslice-method,
99 | setslice-method,
100 | no-absolute-import,
101 | old-division,
102 | dict-iter-method,
103 | dict-view-method,
104 | next-method-called,
105 | metaclass-assignment,
106 | indexing-exception,
107 | raising-string,
108 | reload-builtin,
109 | oct-method,
110 | hex-method,
111 | nonzero-method,
112 | cmp-method,
113 | input-builtin,
114 | round-builtin,
115 | intern-builtin,
116 | unichr-builtin,
117 | map-builtin-not-iterating,
118 | zip-builtin-not-iterating,
119 | range-builtin-not-iterating,
120 | filter-builtin-not-iterating,
121 | using-cmp-argument,
122 | eq-without-hash,
123 | div-method,
124 | idiv-method,
125 | rdiv-method,
126 | exception-message-attribute,
127 | invalid-str-codec,
128 | sys-max-int,
129 | bad-python3-import,
130 | deprecated-string-function,
131 | deprecated-str-translate-call,
132 | deprecated-itertools-function,
133 | deprecated-types-field,
134 | next-method-defined,
135 | dict-items-not-iterating,
136 | dict-keys-not-iterating,
137 | dict-values-not-iterating,
138 | deprecated-operator-function,
139 | deprecated-urllib-function,
140 | xreadlines-attribute,
141 | deprecated-sys-function,
142 | exception-escape,
143 | comprehension-escape
144 |
145 | # Enable the message, report, category or checker with the given id(s). You can
146 | # either give multiple identifiers separated by comma (,) or put this option
147 | # multiple times (only on the command line, not in the configuration file where
148 | # it should appear only once). See also the "--disable" option for examples.
149 | enable=c-extension-no-member
150 |
151 |
152 | [REPORTS]
153 |
154 | # Python expression which should return a note less than 10 (10 is the highest
155 | # note). You have access to the variables error, warning, refactor, convention
156 | # and statement, which contain the message counts per category and the total
157 | # number of statements analyzed. This is used by the global evaluation report
158 | # (RP0004).
159 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
160 |
161 | # Template used to display messages. This is a python new-style format string
162 | # used to format the message information. See doc for all details.
163 | #msg-template=
164 |
165 | # Set the output format. Available formats are text, parseable, colorized, json
166 | # and msvs (visual studio). You can also give a reporter class, e.g.
167 | # mypackage.mymodule.MyReporterClass.
168 | output-format=text
169 |
170 | # Tells whether to display a full report or only the messages.
171 | reports=no
172 |
173 | # Activate the evaluation score.
174 | score=yes
175 |
176 |
177 | [REFACTORING]
178 |
179 | # Maximum number of nested blocks for function / method body
180 | max-nested-blocks=5
181 |
182 | # Complete name of functions that never return. When checking for
183 | # inconsistent-return-statements if a never returning function is called then
184 | # it will be considered as an explicit return statement and no message will be
185 | # printed.
186 | never-returning-functions=sys.exit
187 |
188 |
189 | [LOGGING]
190 |
191 | # Logging modules to check that the string format arguments are in logging
192 | # function parameter format.
193 | logging-modules=logging
194 |
195 |
196 | [SPELLING]
197 |
198 | # Limits count of emitted suggestions for spelling mistakes.
199 | max-spelling-suggestions=4
200 |
201 | # Spelling dictionary name. Available dictionaries: none. To make it work,
202 | # install the python-enchant package.
203 | spelling-dict=
204 |
205 | # List of comma separated words that should not be checked.
206 | spelling-ignore-words=
207 |
208 | # A path to a file that contains private dictionary; one word per line.
209 | spelling-private-dict-file=
210 |
211 | # Tells whether to store unknown words to indicated private dictionary in
212 | # --spelling-private-dict-file option instead of raising a message.
213 | spelling-store-unknown-words=no
214 |
215 |
216 | [MISCELLANEOUS]
217 |
218 | # List of note tags to take into consideration, separated by a comma.
219 | notes=FIXME,
220 | XXX,
221 | TODO
222 |
223 |
224 | [TYPECHECK]
225 |
226 | # List of decorators that produce context managers, such as
227 | # contextlib.contextmanager. Add to this list to register other decorators that
228 | # produce valid context managers.
229 | contextmanager-decorators=contextlib.contextmanager
230 |
231 | # List of members which are set dynamically and missed by pylint inference
232 | # system, and so shouldn't trigger E1101 when accessed. Python regular
233 | # expressions are accepted.
234 | generated-members=
235 |
236 | # Tells whether missing members accessed in mixin class should be ignored. A
237 | # mixin class is detected if its name ends with "mixin" (case insensitive).
238 | ignore-mixin-members=yes
239 |
240 | # Tells whether to warn about missing members when the owner of the attribute
241 | # is inferred to be None.
242 | ignore-none=yes
243 |
244 | # This flag controls whether pylint should warn about no-member and similar
245 | # checks whenever an opaque object is returned when inferring. The inference
246 | # can return multiple potential results while evaluating a Python object, but
247 | # some branches might not be evaluated, which results in partial inference. In
248 | # that case, it might be useful to still emit no-member and other checks for
249 | # the rest of the inferred objects.
250 | ignore-on-opaque-inference=yes
251 |
252 | # List of class names for which member attributes should not be checked (useful
253 | # for classes with dynamically set attributes). This supports the use of
254 | # qualified names.
255 | ignored-classes=optparse.Values,thread._local,_thread._local
256 |
257 | # List of module names for which member attributes should not be checked
258 | # (useful for modules/projects where namespaces are manipulated during runtime
259 | # and thus existing member attributes cannot be deduced by static analysis. It
260 | # supports qualified module names, as well as Unix pattern matching.
261 | ignored-modules=
262 |
263 | # Show a hint with possible names when a member name was not found. The aspect
264 | # of finding the hint is based on edit distance.
265 | missing-member-hint=yes
266 |
267 | # The minimum edit distance a name should have in order to be considered a
268 | # similar match for a missing member name.
269 | missing-member-hint-distance=1
270 |
271 | # The total number of similar names that should be taken into consideration when
272 | # showing a hint for a missing member.
273 | missing-member-max-choices=1
274 |
275 |
276 | [VARIABLES]
277 |
278 | # List of additional names supposed to be defined in builtins. Remember that
279 | # you should avoid defining new builtins when possible.
280 | additional-builtins=
281 |
282 | # Tells whether unused global variables should be treated as a violation.
283 | allow-global-unused-variables=yes
284 |
285 | # List of strings which can identify a callback function by name. A callback
286 | # name must start or end with one of those strings.
287 | callbacks=cb_,
288 | _cb
289 |
290 | # A regular expression matching the name of dummy variables (i.e. expected to
291 | # not be used).
292 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
293 |
294 | # Argument names that match this expression will be ignored. Default to name
295 | # with leading underscore.
296 | ignored-argument-names=_.*|^ignored_|^unused_
297 |
298 | # Tells whether we should check for unused import in __init__ files.
299 | init-import=no
300 |
301 | # List of qualified module names which can have objects that can redefine
302 | # builtins.
303 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
304 |
305 |
306 | [FORMAT]
307 |
308 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
309 | expected-line-ending-format=
310 |
311 | # Regexp for a line that is allowed to be longer than the limit.
312 | ignore-long-lines=^\s*(# )??$
313 |
314 | # Number of spaces of indent required inside a hanging or continued line.
315 | indent-after-paren=4
316 |
317 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
318 | # tab).
319 | indent-string=' '
320 |
321 | # Maximum number of characters on a single line.
322 | max-line-length=120
323 |
324 | # Maximum number of lines in a module.
325 | max-module-lines=1000
326 |
327 | # List of optional constructs for which whitespace checking is disabled. `dict-
328 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
329 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
330 | # `empty-line` allows space-only lines.
331 | no-space-check=trailing-comma,
332 | dict-separator
333 |
334 | # Allow the body of a class to be on the same line as the declaration if body
335 | # contains single statement.
336 | single-line-class-stmt=no
337 |
338 | # Allow the body of an if to be on the same line as the test if there is no
339 | # else.
340 | single-line-if-stmt=no
341 |
342 |
343 | [SIMILARITIES]
344 |
345 | # Ignore comments when computing similarities.
346 | ignore-comments=yes
347 |
348 | # Ignore docstrings when computing similarities.
349 | ignore-docstrings=yes
350 |
351 | # Ignore imports when computing similarities.
352 | ignore-imports=no
353 |
354 | # Minimum lines number of a similarity.
355 | min-similarity-lines=4
356 |
357 |
358 | [BASIC]
359 |
360 | # Naming style matching correct argument names.
361 | argument-naming-style=snake_case
362 |
363 | # Regular expression matching correct argument names. Overrides argument-
364 | # naming-style.
365 | #argument-rgx=
366 |
367 | # Naming style matching correct attribute names.
368 | attr-naming-style=snake_case
369 |
370 | # Regular expression matching correct attribute names. Overrides attr-naming-
371 | # style.
372 | #attr-rgx=
373 |
374 | # Bad variable names which should always be refused, separated by a comma.
375 | bad-names=foo,
376 | bar,
377 | baz,
378 | toto,
379 | tutu,
380 | tata
381 |
382 | # Naming style matching correct class attribute names.
383 | class-attribute-naming-style=any
384 |
385 | # Regular expression matching correct class attribute names. Overrides class-
386 | # attribute-naming-style.
387 | #class-attribute-rgx=
388 |
389 | # Naming style matching correct class names.
390 | class-naming-style=PascalCase
391 |
392 | # Regular expression matching correct class names. Overrides class-naming-
393 | # style.
394 | #class-rgx=
395 |
396 | # Naming style matching correct constant names.
397 | const-naming-style=UPPER_CASE
398 |
399 | # Regular expression matching correct constant names. Overrides const-naming-
400 | # style.
401 | #const-rgx=
402 |
403 | # Minimum line length for functions/classes that require docstrings, shorter
404 | # ones are exempt.
405 | docstring-min-length=-1
406 |
407 | # Naming style matching correct function names.
408 | function-naming-style=snake_case
409 |
410 | # Regular expression matching correct function names. Overrides function-
411 | # naming-style.
412 | #function-rgx=
413 |
414 | # Good variable names which should always be accepted, separated by a comma.
415 | good-names=i,
416 | j,
417 | k,
418 | ex,
419 | Run,
420 | _
421 |
422 | # Include a hint for the correct naming format with invalid-name.
423 | include-naming-hint=no
424 |
425 | # Naming style matching correct inline iteration names.
426 | inlinevar-naming-style=any
427 |
428 | # Regular expression matching correct inline iteration names. Overrides
429 | # inlinevar-naming-style.
430 | #inlinevar-rgx=
431 |
432 | # Naming style matching correct method names.
433 | method-naming-style=snake_case
434 |
435 | # Regular expression matching correct method names. Overrides method-naming-
436 | # style.
437 | #method-rgx=
438 |
439 | # Naming style matching correct module names.
440 | module-naming-style=snake_case
441 |
442 | # Regular expression matching correct module names. Overrides module-naming-
443 | # style.
444 | #module-rgx=
445 |
446 | # Colon-delimited sets of names that determine each other's naming style when
447 | # the name regexes allow several styles.
448 | name-group=
449 |
450 | # Regular expression which should only match function or class names that do
451 | # not require a docstring.
452 | no-docstring-rgx=^_
453 |
454 | # List of decorators that produce properties, such as abc.abstractproperty. Add
455 | # to this list to register other decorators that produce valid properties.
456 | # These decorators are taken into consideration only for invalid-name.
457 | property-classes=abc.abstractproperty
458 |
459 | # Naming style matching correct variable names.
460 | variable-naming-style=snake_case
461 |
462 | # Regular expression matching correct variable names. Overrides variable-
463 | # naming-style.
464 | #variable-rgx=
465 |
466 |
467 | [STRING_QUOTES]
468 |
469 | # The quote character for triple-quoted docstrings.
470 | docstring-quote=double
471 |
472 | # The quote character for string literals.
473 | string-quote=double
474 |
475 | # The quote character for triple-quoted strings (non-docstring).
476 | triple-quote=double
477 |
478 |
479 | [IMPORTS]
480 |
481 | # Allow wildcard imports from modules that define __all__.
482 | allow-wildcard-with-all=no
483 |
484 | # Analyse import fallback blocks. This can be used to support both Python 2 and
485 | # 3 compatible code, which means that the block might have code that exists
486 | # only in one or another interpreter, leading to false positives when analysed.
487 | analyse-fallback-blocks=no
488 |
489 | # Deprecated modules which should not be used, separated by a comma.
490 | deprecated-modules=optparse,tkinter.tix
491 |
492 | # Create a graph of external dependencies in the given file (report RP0402 must
493 | # not be disabled).
494 | ext-import-graph=
495 |
496 | # Create a graph of every (i.e. internal and external) dependencies in the
497 | # given file (report RP0402 must not be disabled).
498 | import-graph=
499 |
500 | # Create a graph of internal dependencies in the given file (report RP0402 must
501 | # not be disabled).
502 | int-import-graph=
503 |
504 | # Force import order to recognize a module as part of the standard
505 | # compatibility libraries.
506 | known-standard-library=
507 |
508 | # Force import order to recognize a module as part of a third party library.
509 | known-third-party=enchant
510 |
511 |
512 | [CLASSES]
513 |
514 | # List of method names used to declare (i.e. assign) instance attributes.
515 | defining-attr-methods=__init__,
516 | __new__,
517 | setUp
518 |
519 | # List of member names, which should be excluded from the protected access
520 | # warning.
521 | exclude-protected=_asdict,
522 | _fields,
523 | _replace,
524 | _source,
525 | _make
526 |
527 | # List of valid names for the first argument in a class method.
528 | valid-classmethod-first-arg=cls
529 |
530 | # List of valid names for the first argument in a metaclass class method.
531 | valid-metaclass-classmethod-first-arg=cls
532 |
533 |
534 | [DESIGN]
535 |
536 | # Maximum number of arguments for function / method.
537 | max-args=5
538 |
539 | # Maximum number of attributes for a class (see R0902).
540 | max-attributes=7
541 |
542 | # Maximum number of boolean expressions in an if statement.
543 | max-bool-expr=5
544 |
545 | # Maximum number of branch for function / method body.
546 | max-branches=12
547 |
548 | # Maximum number of locals for function / method body.
549 | max-locals=15
550 |
551 | # Maximum number of parents for a class (see R0901).
552 | max-parents=7
553 |
554 | # Maximum number of public methods for a class (see R0904).
555 | max-public-methods=20
556 |
557 | # Maximum number of return / yield for function / method body.
558 | max-returns=6
559 |
560 | # Maximum number of statements in function / method body.
561 | max-statements=50
562 |
563 | # Minimum number of public methods for a class (see R0903).
564 | min-public-methods=2
565 |
566 |
567 | [EXCEPTIONS]
568 |
569 | # Exceptions that will emit a warning when being caught. Defaults to
570 | # "Exception".
571 | overgeneral-exceptions=Exception
572 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bandit
2 | boto3
3 | prospector
4 | PyJWT==1.7.1
5 | pylint_quotes
6 | schema
7 | coverage==4.5.4
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# The README is published as the package's long description (declared as Markdown below).
# The context manager closes the file handle promptly instead of leaking it until garbage
# collection, and the explicit encoding avoids depending on the platform's locale default.
with open("README.md", encoding="utf-8") as readme_file:
    LONG_DESCRIPTION = readme_file.read()

setup(name="aws-lambda-decorators",
      version="0.53",
      description="A set of python decorators to simplify aws python lambda development",
      long_description=LONG_DESCRIPTION,
      long_description_content_type="text/markdown",
      url="https://github.com/gridsmartercities/aws-lambda-decorators",
      author="Grid Smarter Cities",
      author_email="open-source@gridsmartercities.com",
      license="MIT",
      classifiers=["Intended Audience :: Developers",
                   "Development Status :: 4 - Beta",
                   "Programming Language :: Python :: 3",
                   "License :: OSI Approved :: MIT License",
                   "Operating System :: OS Independent",
                   "Natural Language :: English"
                   ],
      keywords="aws lambda decorator",
      # Test and example packages are not shipped in the distribution.
      packages=find_packages(exclude=("tests", "examples",)),
      install_requires=[
          "boto3",
          "PyJWT",
          "schema"
      ],
      zip_safe=False
      )
30 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gridsmartercities/aws-lambda-decorators/16dbe6ae1b9982f312d593336682c4ebbcd4f52d/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_classes.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from aws_lambda_decorators.classes import Parameter, SSMParameter, BaseParameter
3 |
4 |
class ParamTests(unittest.TestCase):
    """Tests for BaseParameter, Parameter and SSMParameter from aws_lambda_decorators.classes."""

    def test_can_create_base_parameter(self):
        base_param = BaseParameter("var_name")
        self.assertEqual("var_name", base_param.get_var_name())

    def test_annotations_from_key_returns_annotation(self):
        # "simple[annotation]" is split into the bare key and the annotation inside the brackets.
        key = "simple[annotation]"
        response = Parameter.get_annotations_from_key(key)
        # assertEqual reports the actual value on failure, unlike assertTrue(a == b).
        self.assertEqual("simple", response[0])
        self.assertEqual("annotation", response[1])

    def test_can_not_add_non_pythonic_var_name_to_ssm_parameter(self):
        # "with space" is not a valid Python identifier, so resolving the var name must fail.
        param = SSMParameter("tests", "with space")

        with self.assertRaises(SyntaxError):
            param.get_var_name()
22 |
--------------------------------------------------------------------------------
/tests/test_decoders.py:
--------------------------------------------------------------------------------
1 | # pylint:disable=no-self-use
2 | import json
3 | import unittest
4 | from unittest.mock import patch
5 | from aws_lambda_decorators.decoders import decode, decode_json, decode_jwt
6 | from aws_lambda_decorators.decorators import extract
7 | from aws_lambda_decorators.classes import Parameter
8 |
9 |
# A sample signed JWT (header.payload.signature) used by the [jwt] decoding tests.
# The pieces are joined by implicit string concatenation inside the parentheses.
TEST_JWT = (
    "eyJraWQiOiJEQlwvK0lGMVptekNWOGNmRE1XVUxBRlBwQnVObW5CU2NcL2RoZ3pnTVhcL2NzPSIsImFsZyI6IlJTMjU2In0."
    "eyJzdWIiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJhdWQiOiIycjdtMW1mdWFiODg3ZmZvdG9iNWFjcX"
    "Q2aCIsImNvZ25pdG86Z3JvdXBzIjpbIkRBU0gtQ3VzdG9tZXIiXSwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV2ZW50X2lkIjoiZDU4"
    "NzU0ZjUtMTdlMC0xMWU5LTg2NzAtMjVkOTNhNWNiMjAwIiwidG9rZW5fdXNlIjoiaWQiLCJhdXRoX3RpbWUiOjE1NDc0NTkwMDMsIm"
    "lzcyI6Imh0dHBzOlwvXC9jb2duaXRvLWlkcC5ldS13ZXN0LTIuYW1hem9uYXdzLmNvbVwvZXUtd2VzdC0yX1B4bEdzMU11SiIsImNv"
    "Z25pdG86dXNlcm5hbWUiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJleHAiOjE1NDc0NjI2MDMsImlhdC"
    "I6MTU0NzQ1OTAwMywiZW1haWwiOiJjdXN0b21lckBleGFtcGxlLmNvbSJ9.CNSDu4a9azT40maHAF9tnQTWbfEeiTZ9PfkR9_RU_VG"
    "4QTA1y4R0F2zWVpsa3CkVMq4Uv2NWOwG6zXf-7XaWTEjoGOQR07sq54IEWU3WIxgkgtRAI-aR7nIvllMXXR0RE3e5jzn5SmefG1j-O"
    "NYiD1yYExrKOEMPJVgkdYG6x2cBiucHihVliJQUf9u-ebpu2Cpm_ACvUTUilB6sBL06D3sRobvNLbNNnSjsA66ULNpPTPOVYJxhFbu"
    "ceQ1EICp0oICw2ncJch78RAFY5TeqiVa-uBybxwd36zJmZkXeJPWAKd32IOIJXNUyDOJtmXtSQW51pZGYTsihjZHz3kNlfg"
)
20 |
21 |
22 | class DecodersTests(unittest.TestCase):
23 |
24 | @patch("aws_lambda_decorators.decoders.LOGGER")
25 | def test_decode_function_missing_logs_error(self, mock_logger):
26 | decode("[random]", None)
27 | mock_logger.error.assert_called_once_with("Missing decode function for annotation: %s", "[random]")
28 |
29 | @patch("aws_lambda_decorators.decorators.LOGGER")
30 | def test_extract_returns_400_on_json_decode_error(self, mock_logger):
31 | path = "/a/b[json]/c"
32 | dictionary = {
33 | "a": {
34 | "b": "{'c'}"
35 | }
36 | }
37 |
38 | @extract([Parameter(path, "event")])
39 | def handler(event, context, c=None): # noqa
40 | return {}
41 |
42 | response = handler(dictionary, None)
43 |
44 | self.assertEqual(400, response["statusCode"])
45 | self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"])
46 |
47 | mock_logger.error.assert_called_once_with(
48 | "%s: %s in argument %s for path %s",
49 | "json.decoder.JSONDecodeError",
50 | "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
51 | "event",
52 | "/a/b[json]/c")
53 |
54 | @patch("aws_lambda_decorators.decorators.LOGGER")
55 | def test_extract_returns_400_on_jwt_decode_error(self, mock_logger):
56 | path = "/a/b[jwt]/c"
57 | dictionary = {
58 | "a": {
59 | "b": "wrong.jwt"
60 | }
61 | }
62 |
63 | @extract([Parameter(path, "event")])
64 | def handler(event, context, c=None): # noqa
65 | return {}
66 |
67 | response = handler(dictionary, None)
68 |
69 | self.assertEqual(400, response["statusCode"])
70 | self.assertTrue("{\"message\": \"Error extracting parameters\"}" in response["body"])
71 |
72 | mock_logger.error.assert_called_once_with(
73 | "%s: %s in argument %s for path %s",
74 | "jwt.exceptions.DecodeError",
75 | "Not enough segments",
76 | "event",
77 | "/a/b[jwt]/c")
78 |
79 | def test_extracts_from_list_by_index_annotation_successfully(self):
80 | path = "/a/b[1]/c"
81 | dictionary = {
82 | "a": {
83 | "b": [
84 | {
85 | "c": 2
86 | },
87 | {
88 | "c": 3
89 | }
90 | ]
91 | }
92 | }
93 |
94 | @extract([Parameter(path, "event")])
95 | def handler(event, context, c=None): # noqa
96 | return c
97 |
98 | response = handler(dictionary, None)
99 |
100 | self.assertEqual(3, response)
101 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extracts_from_list_by_index_out_of_range_fails_with_400(self, mock_logger):
    """An ``[index]`` past the end of the list yields a 400 and logs the IndexError."""
    path = "/a/b[4]/c"
    dictionary = {
        "a": {
            "b": [
                {"c": 2},
                {"c": 3}
            ]
        }
    }

    @extract([Parameter(path, "event")])
    def handler(event, context, c=None):  # noqa
        return c

    response = handler(dictionary, None)

    self.assertEqual(400, response["statusCode"])  # noqa
    # assertIn reports the actual body on failure, unlike assertTrue(... in ...).
    self.assertIn("{\"message\": \"Error extracting parameters\"}", response["body"])  # noqa

    mock_logger.error.assert_called_once_with(
        "%s: %s in argument %s for path %s",
        "IndexError",
        "list index out of range",
        "event",
        "/a/b[4]/c")
133 |
def test_extract_multiple_parameters_from_json_hits_cache(self):
    """Two parameters read from the same JSON string decode it only once.

    ``decode_json``'s cache persists between tests. It is cleared first so the
    hit/miss/size arithmetic below does not depend on which tests ran earlier
    or on how full the cache already is.
    """
    decode_json.cache_clear()

    dictionary = {
        "a": json.dumps({
            "b": 123,
            "c": 456
        })
    }

    initial_cache_info = decode_json.cache_info()

    @extract([
        Parameter("a[json]/b", "event", var_name="b"),
        Parameter("a[json]/c", "event", var_name="c")
    ])
    def handler(event, b=None, c=None):  # noqa: pylint - unused-argument
        return {}

    handler(dictionary)

    final_cache_info = decode_json.cache_info()
    # One miss populates the cache for the shared string; the second lookup hits it.
    self.assertEqual(final_cache_info.hits, initial_cache_info.hits + 1)
    self.assertEqual(final_cache_info.misses, initial_cache_info.misses + 1)
    self.assertEqual(final_cache_info.currsize, initial_cache_info.currsize + 1)
156 |
def test_extract_multiple_parameters_from_jwt_hits_cache(self):
    """Two claims read from the same token decode it only once.

    ``decode_jwt``'s cache persists between tests, and TEST_JWT may already
    have been decoded by another test. It is cleared first so the expected
    first lookup really is a miss.
    """
    decode_jwt.cache_clear()

    dictionary = {
        "a": TEST_JWT
    }

    initial_cache_info = decode_jwt.cache_info()

    @extract([
        Parameter("a[jwt]/sub", "event", var_name="sub"),
        Parameter("a[jwt]/aud", "event", var_name="aud")
    ])
    def handler(event, sub=None, aud=None):  # noqa: pylint - unused-argument
        return {}

    handler(dictionary)

    final_cache_info = decode_jwt.cache_info()
    # One miss populates the cache for the token; the second claim lookup hits it.
    self.assertEqual(final_cache_info.hits, initial_cache_info.hits + 1)
    self.assertEqual(final_cache_info.misses, initial_cache_info.misses + 1)
    self.assertEqual(final_cache_info.currsize, initial_cache_info.currsize + 1)
176 |
--------------------------------------------------------------------------------
/tests/test_decorators.py:
--------------------------------------------------------------------------------
1 | # pylint:disable=too-many-lines
2 | from http import HTTPStatus
3 | import json
4 | from json import JSONDecodeError
5 | import unittest
6 | from unittest.mock import patch, MagicMock
7 | from uuid import uuid4
8 |
9 | from botocore.exceptions import ClientError
10 | from schema import Schema, And, Optional
11 |
12 | from aws_lambda_decorators.classes import ExceptionHandler, Parameter, SSMParameter, ValidatedParameter
13 | from aws_lambda_decorators.decorators import extract, extract_from_event, extract_from_context, handle_exceptions, \
14 | log, response_body_as_json, extract_from_ssm, validate, handle_all_exceptions, cors, push_ws_errors, \
15 | push_ws_response, hsts
16 | from aws_lambda_decorators.utils import get_websocket_endpoint
17 | from aws_lambda_decorators.validators import Mandatory, RegexValidator, SchemaValidator, Minimum, Maximum, MaxLength, \
18 | MinLength, Type, EnumValidator, NonEmpty, DateValidator, CurrencyValidator
19 |
# Sample RS256-signed token (header.payload.signature) issued by an AWS Cognito
# user pool. The tests below read claims such as ``sub`` from its payload.
TEST_JWT = (
    "eyJraWQiOiJEQlwvK0lGMVptekNWOGNmRE1XVUxBRlBwQnVObW5CU2NcL2RoZ3pnTVhcL2NzPSIsImFsZyI6IlJTMjU2In0."
    "eyJzdWIiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJhdWQiOiIycjdtMW1mdWFiODg3ZmZvdG9iNWFjcX"
    "Q2aCIsImNvZ25pdG86Z3JvdXBzIjpbIkRBU0gtQ3VzdG9tZXIiXSwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV2ZW50X2lkIjoiZDU4"
    "NzU0ZjUtMTdlMC0xMWU5LTg2NzAtMjVkOTNhNWNiMjAwIiwidG9rZW5fdXNlIjoiaWQiLCJhdXRoX3RpbWUiOjE1NDc0NTkwMDMsIm"
    "lzcyI6Imh0dHBzOlwvXC9jb2duaXRvLWlkcC5ldS13ZXN0LTIuYW1hem9uYXdzLmNvbVwvZXUtd2VzdC0yX1B4bEdzMU11SiIsImNv"
    "Z25pdG86dXNlcm5hbWUiOiJhYWRkMWUwZS01ODA3LTQ3NjMtYjFlOC01ODIzYmY2MzFiYjYiLCJleHAiOjE1NDc0NjI2MDMsImlhdC"
    "I6MTU0NzQ1OTAwMywiZW1haWwiOiJjdXN0b21lckBleGFtcGxlLmNvbSJ9.CNSDu4a9azT40maHAF9tnQTWbfEeiTZ9PfkR9_RU_VG"
    "4QTA1y4R0F2zWVpsa3CkVMq4Uv2NWOwG6zXf-7XaWTEjoGOQR07sq54IEWU3WIxgkgtRAI-aR7nIvllMXXR0RE3e5jzn5SmefG1j-O"
    "NYiD1yYExrKOEMPJVgkdYG6x2cBiucHihVliJQUf9u-ebpu2Cpm_ACvUTUilB6sBL06D3sRobvNLbNNnSjsA66ULNpPTPOVYJxhFbu"
    "ceQ1EICp0oICw2ncJch78RAFY5TeqiVa-uBybxwd36zJmZkXeJPWAKd32IOIJXNUyDOJtmXtSQW51pZGYTsihjZHz3kNlfg"
)
30 |
31 |
32 | class DecoratorsTests(unittest.TestCase): # noqa: pylint - too-many-public-methods
33 |
def test_can_get_value_from_dict_by_path(self):
    """A slash-separated path walks nested dictionaries down to a leaf value."""
    nested = {"a": {"b": {"c": "hello"}}}
    extracted = Parameter("/a/b/c").extract_value(nested)
    self.assertEqual("hello", extracted)
46 |
def test_can_get_dict_value_from_dict_by_path(self):
    """A path that stops at an inner node returns that whole sub-dictionary."""
    nested = {"a": {"b": {"c": "hello"}}}
    extracted = Parameter("/a/b").extract_value(nested)
    self.assertEqual({"c": "hello"}, extracted)
59 |
def test_raises_decode_error_convert_json_string_to_dict(self):
    """A ``[json]`` annotation on a string that is not valid JSON raises JSONDecodeError."""
    path = "/a/b[json]/c"
    dictionary = {
        "a": {
            "b": "{ 'c': 'hello' }",
            "c": "bye"
        }
    }
    param = Parameter(path)
    with self.assertRaises(JSONDecodeError) as context:
        param.extract_value(dictionary)

    # assertIn reports the actual message on failure, unlike assertTrue(... in ...).
    self.assertIn("Expecting property name enclosed in double quotes", context.exception.msg)
73 |
def test_can_get_value_from_dict_with_json_by_path(self):
    """A ``[json]`` annotation decodes an embedded JSON string before descending into it."""
    source = {
        "a": {
            "b": "{\"c\": \"hello\"}",
            "c": "bye"
        }
    }
    json_param = Parameter("/a/b[json]/c", "event")
    self.assertEqual("hello", json_param.extract_value(source))
85 |
def test_can_get_value_from_dict_with_jwt_by_path(self):
    """A ``[jwt]`` annotation decodes a token so that its claims can be read."""
    source = {"a": {"b": TEST_JWT}}
    claim = Parameter("/a/b[jwt]/sub", "event").extract_value(source)
    self.assertEqual("aadd1e0e-5807-4763-b1e8-5823bf631bb6", claim)
96 |
def test_extract_from_event_calls_function_with_extra_kwargs(self):
    """extract_from_event injects the value found at the path as a keyword argument."""
    event_data = {"a": {"b": {"c": "hello"}}}

    @extract_from_event([Parameter("/a/b/c")])
    def handler(event, context, c=None):  # noqa
        return c

    result = handler(event_data, None)
    self.assertEqual(result, "hello")
112 |
def test_extract_from_event_calls_function_with_extra_kwargs_bool_true(self):
    """A ``True`` leaf value is injected as-is instead of being treated as missing."""
    path = "/a/b/c"
    dictionary = {
        "a": {
            "b": {
                "c": True
            }
        }
    }

    @extract_from_event([Parameter(path)])
    def handler(event, context, c=None):  # noqa
        return c

    # Call once and inspect that single result.
    result = handler(dictionary, None)
    self.assertIsNotNone(result)
    # assertIs pins the exact bool; assertEqual(True, 1) would also pass.
    self.assertIs(True, result)
129 |
def test_extract_from_event_calls_function_with_extra_kwargs_bool_false(self):
    """A ``False`` leaf value is injected as-is instead of being treated as missing."""
    path = "/a/b/c"
    dictionary = {
        "a": {
            "b": {
                "c": False
            }
        }
    }

    @extract_from_event([Parameter(path)])
    def handler(event, context, c=None):  # noqa
        return c

    # Call once and inspect that single result.
    result = handler(dictionary, None)
    self.assertIsNotNone(result)
    # assertIs pins the exact bool; assertEqual(False, 0) would also pass.
    self.assertIs(False, result)
146 |
def test_extract_from_context_calls_function_with_extra_kwargs(self):
    """extract_from_context reads the path from the context argument, not the event."""
    context_data = {"a": {"b": {"c": "hello"}}}

    @extract_from_context([Parameter("/a/b/c")])
    def handler(event, context, c=None):  # noqa
        return c

    result = handler(None, context_data)
    self.assertEqual(result, "hello")
162 |
def test_extract_returns_400_on_empty_path(self):
    """A parameter with no path produces a generic 400 extraction error."""
    event_data = {"a": {"b": {}}}

    @extract([Parameter(None, "event")])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(event_data, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])
180 |
def test_extract_returns_400_on_missing_mandatory_key(self):
    """A Mandatory parameter absent from the event is reported per field with a 400."""
    event_data = {"a": {"b": {}}}
    mandatory_c = Parameter("/a/b/c", "event", validators=[Mandatory])

    @extract([mandatory_c])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(event_data, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": [{\"c\": [\"Missing mandatory value\"]}]}", result["body"])
198 |
def test_can_add_name_to_parameter(self):
    """``var_name`` overrides the keyword the extracted value is injected under."""
    event_data = {"a": {"b": "hello"}}
    renamed = Parameter("/a/b", "event", validators=[Mandatory], var_name="custom")

    @extract([renamed])
    def handler(event, context, custom=None):  # noqa
        return custom

    self.assertEqual("hello", handler(event_data, None))
214 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_can_not_add_non_pythonic_var_name_to_parameter(self, mock_logger):
    """A ``var_name`` that is not a valid identifier is rejected with a 400 and logged."""
    event_data = {"a": {"b": "hello"}}
    bad_name = Parameter("/a/b", validators=[Mandatory], var_name="with space")

    @extract_from_event([bad_name])
    def handler(event, context):  # noqa
        return {}

    result = handler(event_data, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])

    mock_logger.error.assert_called_once_with(
        "%s: %s in argument %s for path %s", "SyntaxError", "with space", "event", "/a/b")
239 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_can_not_add_pythonic_keyword_as_name_to_parameter(self, mock_logger):
    """A ``var_name`` that is a Python keyword is rejected with a 400 and logged."""
    event_data = {"a": {"b": "hello"}}
    keyword_name = Parameter("/a/b", validators=[Mandatory], var_name="class")

    @extract_from_event([keyword_name])
    def handler(event, context):  # noqa
        return {}

    result = handler(event_data, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])

    mock_logger.error.assert_called_once_with(
        "%s: %s in argument %s for path %s", "SyntaxError", "class", "event", "/a/b")
264 |
def test_extract_does_not_raise_an_error_on_missing_optional_key(self):
    """An optional parameter missing from the event does not cause an error response."""
    event_data = {"a": {"b": {}}}

    @extract([Parameter("/a/b/c", "event")])
    def handler(event, context, c=None):  # noqa
        return {}

    self.assertEqual({}, handler(event_data, None))
281 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_returns_400_on_invalid_regex_key(self, mock_logger):
    """A value failing RegexValidator yields a 400 listing the violation, which is also logged."""
    event_data = {"a": {"b": {"c": "hello"}}}
    digits_only = RegexValidator(r"\d+")  # expects a number

    @extract([Parameter("/a/b/c", "event", [digits_only])])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(event_data, None)
    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": [{\"c\": [\"\'hello\' does not conform to regular expression \'\\\\d+\'\"]}]}",
                     result["body"])

    expected_errors = [{"c": ["'hello' does not conform to regular expression '\\d+'"]}]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", expected_errors)
307 |
def test_extract_does_not_raise_an_error_on_valid_regex_key(self):
    """A value matching the RegexValidator passes through to the handler."""
    event_data = {"a": {"b": {"c": "2019"}}}
    digits_only = RegexValidator(r"\d+")  # expects a number

    @extract([Parameter("/a/b/c", "event", [digits_only])])
    def handler(event, context, c=None):  # noqa
        return {}

    self.assertEqual({}, handler(event_data, None))
326 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_validate_raises_an_error_on_invalid_variables(self, mock_logger):
    """validate reports the argument that fails its validator with a 400 and logs it."""
    checked = [
        ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")]),
        ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")])
    ]

    @validate(checked)
    def handler(var1=None, var2=None):  # noqa: pylint - unused-argument
        return {}

    result = handler("2019", "abcd")

    self.assertEqual(400, result["statusCode"])
    expected_body = "{\"message\": [{\"var2\": [\"\'abcd\' does not conform to regular expression \'\\\\d+\'\"]}]}"
    self.assertEqual(expected_body, result["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [{"var2": ["\'abcd\' does not conform to regular expression \'\\d+\'"]}]
    )
348 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_validate_raises_multiple_errors_on_exit_on_error_false(self, mock_logger):
    """With validate's second argument set to True, errors for every failing argument are collected."""
    checked = [
        ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")]),
        ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")])
    ]

    @validate(checked, True)
    def handler(var1=None, var2=None):  # noqa: pylint - unused-argument
        return {}

    result = handler("20wq19", "abcd")

    self.assertEqual(400, result["statusCode"])
    self.assertEqual(
        "{\"message\": [{\"var1\": [\"\'20wq19\' does not conform to regular expression \'\\\\d+\'\"]}, "
        "{\"var2\": [\"\'abcd\' does not conform to regular expression \'\\\\d+\'\"]}]}",
        result["body"])

    expected_errors = [
        {"var1": ["'20wq19' does not conform to regular expression '\\d+'"]},
        {"var2": ["'abcd' does not conform to regular expression '\\d+'"]}
    ]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", expected_errors)
373 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_can_not_validate_non_pythonic_var_name(self, mock_logger):
    """A func_param_name that is not a valid identifier fails with a generic 400."""
    checked = [
        ValidatedParameter(func_param_name="var 1", validators=[RegexValidator(r"\d+")]),
        ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"\d+")])
    ]

    @validate(checked, True)
    def handler(var1=None, var2=None):  # noqa: pylint - unused-argument
        return {}

    result = handler("20wq19", "abcd")

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])

    mock_logger.error.assert_called_once_with("%s: %s in argument %s", "KeyError", "'var 1'", "var 1")
391 |
def test_validate_does_not_raise_an_error_on_valid_variables(self):
    """Arguments satisfying their validators pass, whether given positionally or by keyword."""
    digits = ValidatedParameter(func_param_name="var1", validators=[RegexValidator(r"\d+")])
    letters = ValidatedParameter(func_param_name="var2", validators=[RegexValidator(r"[ab]+")])

    @validate([digits, letters])
    def handler(var1, var2=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler("2019", var2="abba"))
402 |
def test_extract_returns_400_on_type_error(self):
    """Applying ``[json]`` to a value that is already a dict yields a generic 400."""
    event_data = {"a": {"b": {"c": "hello"}}}

    @extract([Parameter("/a/b[json]/c")])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(event_data, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])
421 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_exception_handler_raises_exception(self, mock_logger):
    """A handled exception becomes a 400 carrying the friendly message, which is also logged."""

    @handle_exceptions(handlers=[ExceptionHandler(KeyError, "msg")])
    def handler():
        raise KeyError("blank")

    response = handler()  # noqa

    self.assertEqual(400, response["statusCode"])
    # assertIn reports the actual body on failure, unlike assertTrue(... in ...).
    self.assertIn("msg", response["body"])

    mock_logger.error.assert_called_once_with("%s: %s", "msg", "'blank'")
435 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_exception_handler_raises_exception_without_friendly_message(self, mock_logger):
    """Without a friendly message, the exception's own text is returned and logged."""

    @handle_exceptions(handlers=[ExceptionHandler(KeyError)])
    def handler():
        raise KeyError("blank")

    response = handler()  # noqa

    self.assertEqual(400, response["statusCode"])
    # assertIn reports the actual body on failure, unlike assertTrue(... in ...).
    self.assertIn("blank", response["body"])

    mock_logger.error.assert_called_once_with("'blank'")
449 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_exception_handler_raises_exception_with_status_code(self, mock_logger):
    """An ExceptionHandler's custom status code and message are used in the response."""
    key_error_as_500 = ExceptionHandler(KeyError, "error", 500)

    @handle_exceptions(handlers=[key_error_as_500])
    def handler():
        raise KeyError("blank")

    result = handler()  # noqa

    self.assertEqual(500, result["statusCode"])
    self.assertEqual("""{"message": "error"}""", result["body"])

    mock_logger.error.assert_called_once_with("%s: %s", "error", "'blank'")
463 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_exception_handler_raises_exception_with_inherited_exception(self, mock_logger):
    """A handler registered for a base class also handles its subclasses (KeyError via Exception)."""

    @handle_exceptions(handlers=[ExceptionHandler(Exception)])
    def handler():
        raise KeyError("blank")

    response = handler()

    self.assertEqual(400, response["statusCode"])
    # assertIn reports the actual body on failure, unlike assertTrue(... in ...).
    self.assertIn("blank", response["body"])

    mock_logger.error.assert_called_once_with("'blank'")
477 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_log_decorator_can_log_params(self, mock_logger):  # noqa: pylint - no-self-use
    """log(True, False) records the function name and its positional arguments."""
    call_args = ("first", "{\"tests\": \"a\"}", "another")

    @log(True, False)
    def handler(event, context, an_other):  # noqa
        return {}

    handler(*call_args)

    mock_logger.info.assert_called_once_with("Function: %s, Parameters: %s", "handler", call_args)
489 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_log_decorator_can_log_response(self, mock_logger):  # noqa: pylint - no-self-use
    """log(False, True) records the function name and its return value."""
    expected_response = {"statusCode": 201}

    @log(False, True)
    def handler():
        return {"statusCode": 201}

    handler()

    mock_logger.info.assert_called_once_with("Function: %s, Response: %s", "handler", expected_response)
500 |
@patch("boto3.client")
def test_get_valid_ssm_parameter(self, mock_boto_client):
    """extract_from_ssm injects an SSM value under the parameter's own name."""
    ssm_client = MagicMock()
    ssm_client.get_parameters.return_value = {"Parameters": [{"Name": "key", "Value": "tests"}]}
    mock_boto_client.return_value = ssm_client

    @extract_from_ssm([SSMParameter("key")])
    def handler(key=None):
        return key

    self.assertEqual(handler(), "tests")
519 |
@patch("boto3.client")
def test_get_valid_ssm_parameter_custom_name(self, mock_boto_client):
    """SSMParameter's second argument sets the keyword the value is injected under."""
    ssm_client = MagicMock()
    ssm_client.get_parameters.return_value = {"Parameters": [{"Name": "key", "Value": "tests"}]}
    mock_boto_client.return_value = ssm_client

    @extract_from_ssm([SSMParameter("key", "custom")])
    def handler(custom=None):
        return custom

    self.assertEqual(handler(), "tests")
538 |
@patch("boto3.client")
def test_get_valid_ssm_parameters(self, mock_boto_client):
    """Values are matched to keywords by name, regardless of the order SSM returns them."""
    returned = [
        {"Name": "key2", "Value": "test2"},
        {"Name": "key1", "Value": "test1"}
    ]
    ssm_client = MagicMock()
    ssm_client.get_parameters.return_value = {"Parameters": returned}
    mock_boto_client.return_value = ssm_client

    @extract_from_ssm([SSMParameter("key1", "key1"), SSMParameter("key2", "key2")])
    def handler(key1=None, key2=None):
        return [key1, key2]

    self.assertEqual(handler(), ["test1", "test2"])
561 |
@patch("boto3.client")
def test_get_ssm_parameter_missing_parameter_raises_client_error(self, mock_boto_client):
    """A ClientError raised by the SSM client propagates to the caller."""
    ssm_client = MagicMock()
    ssm_client.get_parameters.side_effect = ClientError({}, "")
    mock_boto_client.return_value = ssm_client

    @extract_from_ssm([SSMParameter("")])
    def handler(key=None):
        return key

    self.assertRaises(ClientError, handler)
574 |
@patch("boto3.client")
def test_get_ssm_parameter_empty_key_container_raises_key_error(self, mock_boto_client):
    """An SSM response without the expected keys raises KeyError."""
    ssm_client = MagicMock()
    ssm_client.get_parameters.return_value = {}
    mock_boto_client.return_value = ssm_client

    @extract_from_ssm([SSMParameter("")])
    def handler(key=None):
        return key

    self.assertRaises(KeyError, handler)
588 |
def test_body_gets_dumped_as_json(self):
    """A dict body is serialised to a JSON string; the status code is kept."""
    expected = {"statusCode": 200, "body": "{\"a\": \"b\"}"}

    @response_body_as_json
    def handler():
        return {"statusCode": 200, "body": {"a": "b"}}

    self.assertEqual(handler(), expected)
598 |
def test_body_dump_raises_exception_on_invalid_json(self):
    """A body that cannot be JSON-serialised (a set) is replaced by a 500 error response."""
    expected = {"statusCode": 500, "body": "{\"message\": \"Response body is not JSON serializable\"}"}

    @response_body_as_json
    def handler():
        return {"statusCode": 200, "body": {"a"}}

    self.assertEqual(handler(), expected)
610 |
def test_response_as_json_invalid_application_does_nothing(self):
    """A response without a body is returned unchanged."""

    @response_body_as_json
    def handler():
        return {"statusCode": 200}

    self.assertEqual(handler(), {"statusCode": 200})
620 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_handle_all_exceptions(self, mock_logger):
    """handle_all_exceptions turns any exception into a 400 with its text, and logs it."""

    @handle_all_exceptions()
    def handler():
        raise KeyError("blank")

    response = handler()  # noqa

    self.assertEqual(400, response["statusCode"])
    # assertIn reports the actual body on failure, unlike assertTrue(... in ...).
    self.assertIn("blank", response["body"])

    mock_logger.error.assert_called_once_with("'blank'")
634 |
def test_cors_no_headers_in_response(self):
    """cors adds every configured header when the response has no headers yet."""

    @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12)
    def handler():
        return {}

    headers = handler()["headers"]

    self.assertEqual(headers["access-control-allow-headers"], "Content-Type")
    self.assertEqual(headers["access-control-allow-methods"], "POST")
    self.assertEqual(headers["access-control-allow-origin"], "*")
    self.assertEqual(headers["access-control-max-age"], 12)
647 |
def test_cors_adds_correct_headers_only(self):
    """Only the CORS headers that were configured are added to the response."""

    @cors(allow_origin="*")
    def handler():
        return {}

    response = handler()

    self.assertEqual(response["headers"]["access-control-allow-origin"], "*")
    # assertNotIn shows the actual headers on failure, unlike assertTrue(... not in ...).
    self.assertNotIn("access-control-allow-methods", response["headers"])
    self.assertNotIn("access-control-allow-headers", response["headers"])
    self.assertNotIn("access-control-max-age", response["headers"])
660 |
def test_cors_with_headers_in_response(self):
    """An existing allow-origin value gets the configured origin appended after a comma."""

    @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12)
    def handler():
        existing = {
            "content-type": "application/json",
            "access-control-allow-origin": "http://example.com"
        }
        return {"headers": existing}

    headers = handler()["headers"]

    self.assertEqual(headers["access-control-allow-headers"], "Content-Type")
    self.assertEqual(headers["access-control-allow-methods"], "POST")
    self.assertEqual(headers["access-control-allow-origin"], "http://example.com,*")
    self.assertEqual(headers["access-control-max-age"], 12)
678 |
def test_cors_with_headers_a_none_value_does_not_remove_headers(self):
    """A ``None`` allow_origin leaves the existing header alone and adds no other CORS headers."""

    @cors(allow_origin=None)
    def handler():
        return {
            "headers": {
                "access-control-allow-origin": "http://example.com"
            }
        }

    response = handler()

    self.assertEqual(response["headers"]["access-control-allow-origin"], "http://example.com")
    # assertNotIn shows the actual headers on failure, unlike assertTrue(... not in ...).
    self.assertNotIn("access-control-allow-methods", response["headers"])
    self.assertNotIn("access-control-allow-headers", response["headers"])
    self.assertNotIn("access-control-max-age", response["headers"])
695 |
def test_cors_with_headers_an_empty_value_does_not_remove_headers(self):
    """An empty allow_origin leaves an existing allow-origin header untouched."""

    @cors(allow_origin="")
    def handler():
        return {"headers": {"access-control-allow-origin": "http://example.com"}}

    origin = handler()["headers"]["access-control-allow-origin"]
    self.assertEqual(origin, "http://example.com")
709 |
def test_cors_with_uppercase_headers_in_response(self):
    """A capitalised ``Headers`` mapping and header name are honoured; new headers are lower case."""

    @cors(allow_origin="*", allow_methods="POST", allow_headers="Content-Type", max_age=12)
    def handler():
        existing = {
            "content-type": "application/json",
            "Access-Control-Allow-Origin": "http://example.com"
        }
        return {"Headers": existing}

    headers = handler()["Headers"]

    self.assertEqual(headers["access-control-allow-headers"], "Content-Type")
    self.assertEqual(headers["access-control-allow-methods"], "POST")
    self.assertEqual(headers["Access-Control-Allow-Origin"], "http://example.com,*")
    self.assertEqual(headers["access-control-max-age"], 12)
727 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_cors_invalid_max_age_logs_error(self, mock_logger):
    """A non-int max_age produces a 500 and logs the expected type."""

    @cors(max_age="12")
    def handler():
        return {}

    result = handler()

    self.assertEqual(result["statusCode"], 500)
    self.assertEqual(result["body"], "{\"message\": \"Invalid value type in CORS header\"}")

    mock_logger.error.assert_called_once_with(
        "Cannot set %s header to a non %s value", "access-control-max-age", int)
743 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_cors_cannot_decorate_non_dict(self, mock_logger):
    """cors replaces a non-dict response with a 500 and logs why."""

    @cors(allow_origin="*")
    def handler():
        return "I am a string"

    result = handler()
    expected_body = "{\"message\": \"Invalid response type for CORS headers\"}"

    self.assertEqual(result["statusCode"], 500)  # noqa: pylint-invalid-sequence-index
    self.assertEqual(result["body"], expected_body)  # noqa: pylint-invalid-sequence-index

    mock_logger.error.assert_called_once_with("Cannot add headers to a non dictionary response")
757 |
def test_extract_returns_400_on_invalid_dictionary_schema(self):
    """A value that does not match its SchemaValidator schema yields a 400 quoting value and schema."""
    path = "/a"
    dictionary = {
        "a": {
            "b": {
                "c": 3
            }
        }
    }

    schema = Schema(
        {
            "b": And(dict, {
                "c": str
            })
        }
    )

    @extract([Parameter(path, "event", validators=[SchemaValidator(schema)])])
    def handler(event, context, c=None):  # noqa
        return {}

    response = handler(dictionary, None)

    self.assertEqual(400, response["statusCode"])
    # The schema repr embeds the reprs of the type objects it checks
    # (<class 'dict'>, <class 'str'>), so they must appear in the message.
    self.assertEqual(
        "{\"message\": [{\"a\": [\"\'{\'b\': {\'c\': 3}}\' "
        "does not validate against schema "
        "\'Schema({\'b\': And(<class \'dict\'>, {\'c\': <class \'str\'>})})\'\"]}]}",
        response["body"])
788 |
def test_extract_valid_dictionary_schema(self):
    """A value matching its SchemaValidator schema is injected unchanged."""
    expected = {"b": {"c": "d"}}
    event_data = {"a": {"b": {"c": "d"}}}
    schema = Schema({
        "b": And(dict, {"c": str}),
        Optional("j"): str
    })

    @extract([Parameter("/a", "event", validators=[SchemaValidator(schema)])])
    def handler(event, context, a=None):  # noqa
        return a

    self.assertEqual(expected, handler(event_data, None))
820 |
def test_extract_schema_when_property_is_none(self):
    """A missing optional value passes SchemaValidator and is injected as None."""
    event_data = {"a": {}}
    schema = Schema({
        "b": And(dict, {"c": str}),
        Optional("j"): str
    })

    @extract([Parameter("/a/b", "event", validators=[SchemaValidator(schema)])])
    def handler(event, context, b=None):  # noqa
        return b

    self.assertEqual(None, handler(event_data, None))
843 |
def test_extract_parameter_with_minimum(self):
    """A numeric value above the Minimum bound passes validation."""
    payload = {"value": 20}
    at_least_ten = Minimum(10.0)

    @extract([Parameter("/value", "event", validators=[at_least_ten])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler(payload))
855 |
def test_error_extracting_parameter_with_minimum(self):
    """A value below the Minimum bound is rejected with a 400 and a descriptive message."""
    payload = {"value": 5}
    expected_body = "{\"message\": [{\"value\": [\"\'5\' is less than minimum value \'10.0\'\"]}]}"

    @extract([Parameter("/value", "event", validators=[Minimum(10.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
870 |
def test_error_extracting_non_numeric_parameter_with_minimum(self):
    """A non-numeric value fails Minimum validation and is reported as below the bound."""
    payload = {"value": "20"}
    expected_body = "{\"message\": [{\"value\": [\"\'20\' is less than minimum value \'10.0\'\"]}]}"

    @extract([Parameter("/value", "event", validators=[Minimum(10.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
884 |
def test_extract_optional_null_parameter_with_minimum(self):
    """A missing optional value does not trigger a Minimum error."""
    payload = {}

    @extract([Parameter("/value", "event", validators=[Minimum(10.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler(payload))
895 |
def test_extract_mandatory_parameter_with_minimum(self):
    """Minimum and Mandatory can be combined; a present in-range value passes."""
    payload = {"value": 20}
    checks = [Minimum(10.0), Mandatory]

    @extract([Parameter("/value", "event", validators=checks)])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler(payload))
907 |
def test_extract_parameter_with_maximum(self):
    """A numeric value below the Maximum bound passes validation."""
    payload = {"value": 20}
    at_most_hundred = Maximum(100.0)

    @extract([Parameter("/value", "event", validators=[at_most_hundred])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler(payload))
919 |
def test_error_extracting_parameter_with_maximum(self):
    """A value above the Maximum bound is rejected with a 400 and a descriptive message."""
    payload = {"value": 105}
    expected_body = "{\"message\": [{\"value\": [\"\'105\' is greater than maximum value \'100.0\'\"]}]}"

    @extract([Parameter("/value", "event", validators=[Maximum(100.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
933 |
def test_error_extracting_non_numeric_parameter_with_maximum(self):
    """A non-numeric value fails Maximum validation and is reported as above the bound."""
    payload = {"value": "20"}
    expected_body = "{\"message\": [{\"value\": [\"\'20\' is greater than maximum value \'100.0\'\"]}]}"

    @extract([Parameter("/value", "event", validators=[Maximum(100.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
947 |
def test_extract_optional_null_parameter_with_maximum(self):
    """A missing optional value does not trigger a Maximum error."""
    payload = {}

    @extract([Parameter("/value", "event", validators=[Maximum(10.0)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual({}, handler(payload))
958 |
def test_extract_mandatory_parameter_with_maximum(self):
    """A present, mandatory value below its Maximum passes both validators."""
    payload = {"value": 20}

    @extract([Parameter("/value", "event", validators=[Maximum(100.0), Mandatory])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
970 |
def test_extract_mandatory_parameter_with_range(self):
    """A mandatory value inside [Minimum, Maximum] passes all three validators."""
    payload = {"value": 20}
    range_validators = [Minimum(10.0), Maximum(100.0), Mandatory]

    @extract([Parameter("/value", "event", validators=range_validators)])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
982 |
def test_extract_parameter_with_maximum_length(self):
    """A string shorter than its MaxLength passes validation."""
    payload = {"value": "correct"}

    @extract([Parameter("/value", "event", validators=[MaxLength(20)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
994 |
def test_error_extracting_parameter_with_max_length(self):
    """A string longer than its MaxLength is rejected with a 400 and a descriptive message."""
    payload = {"value": "too long"}

    @extract([Parameter("/value", "event", validators=[MaxLength(5)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    expected_body = "{\"message\": [{\"value\": [\"\'too long\' is longer than maximum length \'5\'\"]}]}"
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
1009 |
def test_values_are_stringified_in_max_length_validator(self):
    """MaxLength works on non-string values: the integer 20 is accepted under a limit of 5."""
    payload = {"value": 20}

    @extract([Parameter("/value", "event", validators=[MaxLength(5)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
1021 |
def test_extract_optional_null_parameter_with_max_length(self):
    """MaxLength does not reject an optional parameter that is absent from the event."""
    empty_event = {}

    @extract([Parameter("/value", "event", validators=[MaxLength(5)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(empty_event), {})
1032 |
def test_extract_mandatory_parameter_with_max_length(self):
    """A present, mandatory string within MaxLength passes both validators."""
    payload = {"value": "aa"}

    @extract([Parameter("/value", "event", validators=[MaxLength(5), Mandatory])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
1044 |
def test_extract_parameter_with_mainimum_length(self):
    """A string longer than its MinLength (minimum length) passes validation."""
    payload = {"value": "correct"}

    @extract([Parameter("/value", "event", validators=[MinLength(4)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
1056 |
def test_error_extracting_parameter_with_min_length(self):
    """A string shorter than its MinLength is rejected with a 400 and a descriptive message."""
    payload = {"value": "too short"}

    @extract([Parameter("/value", "event", validators=[MinLength(15)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    expected_body = "{\"message\": [{\"value\": [\"\'too short\' is shorter than minimum length \'15\'\"]}]}"
    self.assertEqual(result["statusCode"], 400)
    self.assertEqual(expected_body, result["body"])
1071 |
def test_values_are_stringified_in_min_length_validator(self):
    """MinLength works on non-string values: the integer 20 satisfies a minimum of 1."""
    payload = {"value": 20}

    @extract([Parameter("/value", "event", validators=[MinLength(1)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
1083 |
def test_extract_optional_null_parameter_with_min_length(self):
    """MinLength does not reject an optional parameter that is absent from the event."""
    empty_event = {}

    @extract([Parameter("/value", "event", validators=[MinLength(5)])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(empty_event), {})
1094 |
def test_extract_mandatory_parameter_with_min_length(self):
    """A present, mandatory string whose length equals MinLength (2 vs 2) passes validation."""
    event = {
        "value": "aa"
    }

    # MinLength (not MaxLength) so the test exercises the validator its name refers to.
    @extract([Parameter("/value", "event", validators=[MinLength(2), Mandatory])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event)
    self.assertEqual({}, response)
1106 |
def test_extract_mandatory_parameter_with_length_range(self):
    """A mandatory string whose length lies between MinLength and MaxLength passes."""
    payload = {"value": "right in the middle"}
    length_validators = [MinLength(10), MaxLength(100), Mandatory]

    @extract([Parameter("/value", "event", validators=length_validators)])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return {}

    self.assertEqual(handler(payload), {})
1118 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_exit_on_error_false_bundles_all_errors(self, mock_logger):
    """With group_errors=True all validation failures are bundled into one 400 response.

    c and d are missing, e and f break their numeric bounds, and g fails all five of its
    validators. The body and the single logged error must list every failure in order.
    The SchemaValidator message embeds the schema's repr, which for ``Schema({"g": int})``
    is ``Schema({'g': <class 'int'>})``. The ``<class 'int'>`` part is required for the
    expectations below to match.
    """
    path_1 = "/a/b/c"
    path_2 = "/a/b/d"
    path_3 = "/a/b/e"
    path_4 = "/a/b/f"
    path_5 = "/a/b/g"
    dictionary = {
        "a": {
            "b": {
                "e": 23,
                "f": 15,
                "g": "a"
            }
        }
    }

    schema = Schema(
        {
            "g": int
        }
    )

    @extract([
        Parameter(path_1, "event", validators=[Mandatory], var_name="c"),
        Parameter(path_2, "event", validators=[Mandatory]),
        Parameter(path_3, "event", validators=[Minimum(30)]),
        Parameter(path_4, "event", validators=[Maximum(10)]),
        Parameter(path_5, "event", validators=[
            RegexValidator(r"[0-9]+"),
            RegexValidator(r"[1][0-9]+"),
            SchemaValidator(schema),
            MinLength(2),
            MaxLength(0)
        ])
    ], True)
    def handler(event, context, c=None, d=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(dictionary, None)
    self.assertEqual(400, response["statusCode"])
    self.assertEqual(
        "{\"message\": [{\"c\": [\"Missing mandatory value\"]}, "
        "{\"d\": [\"Missing mandatory value\"]}, "
        "{\"e\": [\"\'23\' is less than minimum value \'30\'\"]}, "
        "{\"f\": [\"\'15\' is greater than maximum value \'10\'\"]}, "
        "{\"g\": [\"\'a\' does not conform to regular expression \'[0-9]+\'\", "
        "\"\'a\' does not conform to regular expression \'[1][0-9]+\'\", "
        "\"\'a\' does not validate against schema \'Schema({\'g\': <class \'int\'>})\'\", "
        "\"\'a\' is shorter than minimum length \'2\'\", "
        "\"\'a\' is longer than maximum length \'0\'\""
        "]}]}",
        response["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [
            {"c": ["Missing mandatory value"]},
            {"d": ["Missing mandatory value"]},
            {"e": ["'23' is less than minimum value '30'"]},
            {"f": ["'15' is greater than maximum value '10'"]},
            {"g": [
                "'a' does not conform to regular expression '[0-9]+'",
                "'a' does not conform to regular expression '[1][0-9]+'",
                "'a' does not validate against schema 'Schema({'g': <class 'int'>})'",
                "'a' is shorter than minimum length '2'",
                "'a' is longer than maximum length '0'"
            ]}
        ]
    )
1189 |
def test_group_errors_true_returns_ok(self):
    """With group_errors=True and nothing failing, the extracted value reaches the handler."""
    nested = {"a": {"b": "hello"}}

    @extract([Parameter("/a/b", "event", validators=[Mandatory])], True)
    def handler(event, context, b=None):  # noqa
        return b

    self.assertEqual("hello", handler(nested, None))
1205 |
def test_mandatory_parameter_with_default_returns_error_on_empty(self):
    """A Parameter default does not satisfy Mandatory when the event holds an empty string."""
    payload = {"var": ""}
    var_param = Parameter("/var", "event", validators=[Mandatory], default="hello")

    @extract([var_param])
    def handler(event, context, var=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload, None)

    self.assertEqual(result["statusCode"], 400)
    self.assertEqual("{\"message\": [{\"var\": [\"Missing mandatory value\"]}]}", result["body"])
1221 |
def test_group_errors_true_on_extract_from_event_returns_ok(self):
    """extract_from_event with group_errors=True passes a valid value through to the handler."""
    nested = {"a": {"b": "hello"}}

    @extract_from_event([Parameter("/a/b", validators=[Mandatory])], True)
    def handler(event, context, b=None):  # noqa
        return b

    self.assertEqual("hello", handler(nested, None))
1237 |
def test_group_errors_true_on_extract_from_context_returns_ok(self):
    """extract_from_context with group_errors=True reads from the second (context) argument."""
    nested = {"a": {"b": "hello"}}

    @extract_from_context([Parameter("/a/b", validators=[Mandatory])], True)
    def handler(event, context, b=None):  # noqa
        return b

    self.assertEqual("hello", handler(None, nested))
1253 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_can_output_custom_error_message_on_validation_failure(self, mock_logger):
    """Custom messages given to each validator replace the defaults in the body and the log.

    The Minimum message also shows that ``{value}`` and ``{condition}`` placeholders are
    filled in (23 and 30 here).
    """
    nested_event = {"a": {"b": {"e": 23, "f": 15, "g": "a"}}}
    g_schema = Schema({"g": int})

    @extract([
        Parameter("/a/b/c", "event", validators=[Mandatory("Missing c")], var_name="c"),
        Parameter("/a/b/d", "event", validators=[Mandatory("Missing d")]),
        Parameter("/a/b/e", "event", validators=[Minimum(30, "Bad e value {value}, should be at least {condition}")]),
        Parameter("/a/b/f", "event", validators=[Maximum(10, "Bad f")]),
        Parameter("/a/b/g", "event", validators=[
            RegexValidator(r"[0-9]+", "Bad g regex 1"),
            RegexValidator(r"[1][0-9]+", "Bad g regex 2"),
            SchemaValidator(g_schema, "Bad g schema"),
            MinLength(2, "Bad g min length"),
            MaxLength(0, "Bad g max length")
        ])
    ], True)
    def handler(event, context, c=None, d=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(nested_event, None)

    expected_body = (
        "{\"message\": [{\"c\": [\"Missing c\"]}, "
        "{\"d\": [\"Missing d\"]}, "
        "{\"e\": [\"Bad e value 23, should be at least 30\"]}, "
        "{\"f\": [\"Bad f\"]}, "
        "{\"g\": [\"Bad g regex 1\", "
        "\"Bad g regex 2\", "
        "\"Bad g schema\", "
        "\"Bad g min length\", "
        "\"Bad g max length\""
        "]}]}"
    )
    self.assertEqual(400, result["statusCode"])
    self.assertEqual(expected_body, result["body"])

    logged_errors = [
        {"c": ["Missing c"]},
        {"d": ["Missing d"]},
        {"e": ["Bad e value 23, should be at least 30"]},
        {"f": ["Bad f"]},
        {"g": [
            "Bad g regex 1",
            "Bad g regex 2",
            "Bad g schema",
            "Bad g min length",
            "Bad g max length"
        ]}
    ]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", logged_errors)
1325 |
def test_extract_returns_400_on_missing_mandatory_key_with_regex(self):
    """A missing mandatory key reports only the Mandatory error; the regex adds nothing."""
    nested = {"a": {"b": {}}}
    c_param = Parameter("/a/b/c", "event", validators=[Mandatory, RegexValidator("[0-9]+")])

    @extract([c_param], group_errors=True)
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(nested, None)

    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": [{\"c\": [\"Missing mandatory value\"]}]}", result["body"])
1343 |
def test_extract_nulls_are_returned(self):
    """With allow_none_defaults=True an explicit None default is injected into the handler kwargs."""
    nested = {"a": {}}

    @extract([Parameter("/a/b", "event", default=None)], allow_none_defaults=True)
    def handler(event, context, **kwargs):  # noqa
        return kwargs["b"]

    self.assertIsNone(handler(nested, None))
1358 |
def test_extract_nulls_raises_exception_when_extracted_from_kwargs_if_allow_none_defaults_is_false(self):
    """With allow_none_defaults=False a None value is not injected, so kwargs["b"] raises KeyError."""
    nested = {"a": {}}

    @extract([Parameter("/a/b", "event", default=None)], allow_none_defaults=False)
    def handler(event, context, **kwargs):  # noqa
        return kwargs["b"]

    with self.assertRaises(KeyError):
        handler(nested, None)
1372 |
def test_extract_nulls_preserve_signature_defaults(self):
    """A missing value with no Parameter default leaves the handler's own default in place."""
    nested = {"a": {}}

    @extract([Parameter("/a/b", "event")])
    def handler(event, context, b="Hello"):  # noqa
        return b

    self.assertEqual("Hello", handler(nested, None))
1387 |
def test_extract_nulls_default_on_decorator_takes_precedence(self):
    """For a missing value, the Parameter default wins over the handler signature default."""
    nested = {"a": {}}

    @extract([Parameter("/a/b", "event", default="bye")])
    def handler(event, context, b="Hello"):  # noqa
        return b

    self.assertEqual("bye", handler(nested, None))
1402 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_returns_400_on_invalid_bool_type(self, mock_logger):
    """Type(bool) rejects the integer 1 with a 400 and logs the validation error once."""
    nested = {"a": {"b": {"c": 1}}}

    @extract([Parameter("/a/b/c", "event", [Type(bool)])])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(nested, None)
    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": [{\"c\": [\"\'1\' is not of type \'bool'\"]}]}", result["body"])

    logged_errors = [{"c": ["'1' is not of type 'bool'"]}]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", logged_errors)
1426 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_returns_400_on_invalid_float_type(self, mock_logger):
    """Type(float) rejects the integer 1 with a 400 and logs the validation error once."""
    nested = {"a": {"b": {"c": 1}}}

    @extract([Parameter("/a/b/c", "event", [Type(float)])])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(nested, None)
    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": [{\"c\": [\"\'1\' is not of type \'float'\"]}]}", result["body"])

    logged_errors = [{"c": ["'1' is not of type 'float'"]}]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", logged_errors)
1450 |
def test_type_validator_returns_true_when_none_is_passed_in(self):
    """Type(...) accepts a null value: the handler runs and sees c as None."""
    nested = {"a": {"b": {"c": None}}}

    @extract([Parameter("/a/b/c", "event", [Type(float)])])
    def handler(event, context, c=None):  # noqa
        return c

    self.assertIsNone(handler(nested, None))
1467 |
def test_extract_succeeds_with_valid_type_validation(self):
    """A value of the declared Type passes and is handed to the handler."""
    nested = {"a": {"b": {"c": 1}}}

    @extract([Parameter("/a/b/c", "event", [Type(int)])])
    def handler(event, context, c=None):  # noqa
        return c

    self.assertEqual(1, handler(nested, None))
1484 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_returns_400_on_value_not_in_list(self, mock_logger):
    """EnumValidator rejects a value outside its allowed set with a 400 and one logged error."""
    nested = {"a": {"b": {"c": "Hello"}}}

    @extract([Parameter("/a/b/c", "event", [EnumValidator("bye", "test", "another")])])
    def handler(event, context, c=None):  # noqa
        return {}

    result = handler(nested, None)
    expected_body = "{\"message\": [{\"c\": [\"\'Hello\' is not in list \'(\'bye\', \'test\', \'another\')'\"]}]}"
    self.assertEqual(400, result["statusCode"])
    self.assertEqual(expected_body, result["body"])

    logged_errors = [{"c": ["'Hello' is not in list '('bye', 'test', 'another')'"]}]
    mock_logger.error.assert_called_once_with("Error validating parameters. Errors: %s", logged_errors)
1510 |
def test_extract_suceeds_with_valid_enum_validation(self):
    """A value contained in the EnumValidator set (mixed types allowed) reaches the handler."""
    nested = {"a": {"b": {"c": 123}}}

    @extract([Parameter("/a/b/c", "event", [EnumValidator("Hello", 123)])])
    def handler(event, context, c=None):  # noqa
        return c

    self.assertEqual(123, handler(nested, None))
1527 |
def test_enum_validator_returns_true_when_none_is_passed_in(self):
    """EnumValidator accepts a null value: the handler runs and sees c as None."""
    nested = {"a": {"b": {"c": None}}}

    @extract([Parameter("/a/b/c", "event", [EnumValidator("Test", "another")])])
    def handler(event, context, c=None):  # noqa
        return c

    self.assertIsNone(handler(nested, None))
1544 |
def test_extract_from_event_missing_parameter_path(self):
    """A path missing from the decoded JSON body falls back to the Parameter default."""
    payload = {"body": "{}"}
    optional_param = Parameter(path="body[json]/optional/value", default="Hello")

    @extract_from_event(parameters=[optional_param])
    def handler(event, context, **kwargs):  # noqa
        return {
            "statusCode": HTTPStatus.OK,
            "body": json.dumps(kwargs)
        }

    result = handler(payload, None)

    self.assertEqual(HTTPStatus.OK, result["statusCode"])
    self.assertEqual(json.dumps({"value": "Hello"}), result["body"])
1565 |
def test_extract_non_empty_parameter(self):
    """A non-empty value passes NonEmpty and is handed to the handler."""
    payload = {"value": 20}

    @extract([Parameter("/value", "event", validators=[NonEmpty])])
    def handler(event, value=None):  # noqa: pylint - unused-argument
        return value

    self.assertEqual(20, handler(payload))
1577 |
def test_extract_missing_non_empty_parameter(self):
    """NonEmpty does not reject a key that is absent; the handler sees its default None."""
    payload = {"a": 20}

    @extract([Parameter("/b", "event", validators=[NonEmpty])])
    def handler(event, b=None):  # noqa: pylint - unused-argument
        return b

    self.assertIsNone(handler(payload))
1589 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_non_empty_parameter_that_is_empty(self, mock_logger):
    """An empty dict fails NonEmpty: 400 with "Value is empty" and exactly one logged error."""
    event = {
        "a": {}
    }

    # The handler is called as handler(event, None); declare `context` so that positional
    # None does not bind to `a`, which extract supplies as a keyword argument.
    @extract([Parameter("/a", "event", validators=[NonEmpty])])
    def handler(event, context, a=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event, None)

    self.assertEqual(400, response["statusCode"])
    self.assertEqual(
        "{\"message\": [{\"a\": [\"Value is empty\"]}]}",
        response["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [{"a": ["Value is empty"]}]
    )
1611 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_non_empty_parameter_that_is_empty_with_custom_message(self, mock_logger):
    """NonEmpty with a custom message reports that message in the 400 body and the log."""
    event = {
        "a": {}
    }

    # The handler is called as handler(event, None); declare `context` so that positional
    # None does not bind to `a`, which extract supplies as a keyword argument.
    @extract([Parameter("/a", "event", validators=[NonEmpty("The value was empty")])])
    def handler(event, context, a=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event, None)

    self.assertEqual(400, response["statusCode"])
    self.assertEqual(
        "{\"message\": [{\"a\": [\"The value was empty\"]}]}",
        response["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [{"a": ["The value was empty"]}]
    )
1633 |
def test_extract_date_parameter(self):
    """A string matching the DateValidator format passes and is returned unchanged."""
    timestamp = "2001-01-01 00:00:00"
    payload = {"a": timestamp}

    @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertEqual("2001-01-01 00:00:00", handler(payload))
1645 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_date_parameter_fails_on_invalid_date(self, mock_logger):
    """An out-of-range hour (35) fails DateValidator: 400 plus one logged validation error."""
    event = {
        "a": "2001-01-01 35:00:00"
    }

    # The handler is called as handler(event, None); declare `context` so that positional
    # None does not bind to `a`, which extract supplies as a keyword argument.
    @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])])
    def handler(event, context, a=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event, None)

    self.assertEqual(400, response["statusCode"])
    self.assertEqual("{\"message\": [{\"a\": [\"'2001-01-01 35:00:00' is not a '%Y-%m-%d %H:%M:%S' date\"]}]}",
                     response["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [{"a": ["'2001-01-01 35:00:00' is not a '%Y-%m-%d %H:%M:%S' date"]}]
    )
1666 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_extract_date_parameter_fails_with_custom_error(self, mock_logger):
    """DateValidator with a custom message reports that message in the 400 body and the log."""
    event = {
        "a": "2001-01-01 35:00:00"
    }

    # The handler is called as handler(event, None); declare `context` so that positional
    # None does not bind to `a`, which extract supplies as a keyword argument.
    @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S", "Not a valid date!")])])
    def handler(event, context, a=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event, None)

    self.assertEqual(400, response["statusCode"])
    self.assertEqual("{\"message\": [{\"a\": [\"Not a valid date!\"]}]}", response["body"])

    mock_logger.error.assert_called_once_with(
        "Error validating parameters. Errors: %s",
        [{"a": ["Not a valid date!"]}]
    )
1686 |
def test_extract_date_parameter_valid_on_empty(self):
    """DateValidator accepts a null value: the handler runs and sees a as None."""
    payload = {"a": None}

    @extract([Parameter("/a", "event", validators=[DateValidator("%Y-%m-%d %H:%M:%S")])])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertIsNone(handler(payload))
1698 |
def test_extract_currency_parameter(self):
    """A known currency code passes CurrencyValidator (used as a class) and is returned."""
    payload = {"a": "GBP"}

    @extract([Parameter("/a", "event", [CurrencyValidator])])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertEqual("GBP", handler(payload))
1710 |
def test_currency_validator_returns_true_when_none_is_passed_in(self):
    """CurrencyValidator accepts a null value: the handler runs and sees c as None."""
    path = "/a/b/c"
    dictionary = {
        "a": {
            "b": {
                "c": None
            }
        }
    }

    # The handler is called as handler(dictionary, None); declare `context` so that
    # positional None does not bind to `c`, which extract supplies as a keyword argument.
    @extract([Parameter(path, "event", [CurrencyValidator])])
    def handler(event, context, c=None):  # noqa
        return c

    response = handler(dictionary, None)
    self.assertEqual(None, response)
1727 |
def test_currency_validator_returns_false_when_invalid_code_passed_in(self):
    """An unknown currency code is rejected with a 400 and a descriptive message."""
    payload = {"a": "GBT"}

    @extract([Parameter("/a", "event", [CurrencyValidator])])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    expected_body = "{\"message\": [{\"a\": [\"\'GBT\' is not a valid currency code.\"]}]}"
    self.assertEqual(400, result["statusCode"])
    self.assertEqual(expected_body, result["body"])
1741 |
def test_currency_validator_can_be_called_non_statically(self):
    """CurrencyValidator also works when passed as an instance rather than the class."""
    payload = {"a": "GBP"}
    validator_instance = CurrencyValidator()

    @extract([Parameter("/a", "event", [validator_instance])])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertEqual("GBP", handler(payload))
1753 |
def test_can_apply_transformation(self):
    """A builtin such as float can be used as transform to convert the extracted value."""
    payload = {"a": "2"}

    @extract([Parameter("/a", "event", transform=float)])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertEqual(2, handler(payload))
1765 |
def test_apply_transformation_on_none_value(self):
    """A null value is not fed to the transform; the handler sees a as None."""
    payload = {"a": None}

    @extract([Parameter("/a", "event", transform=float)])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertIsNone(handler(payload))
1777 |
def test_apply_custom_transformation(self):
    """A user-defined callable can be used as transform."""
    payload = {"a": "2"}

    def as_float(raw):
        return float(raw)

    @extract([Parameter("/a", "event", transform=as_float)])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return a

    self.assertEqual(2, handler(payload))
1792 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_apply_custom_transformation_with_error_handling(self, mock_logger):
    """An exception raised by a custom transform produces a generic 400 response.

    The error is logged once with its type name, message, argument name and path.
    """
    event = {
        "a": "abc"
    }

    def to_float(arg):
        try:
            return float(arg)
        except Exception as err:
            # Chain the original error so the root cause survives in tracebacks
            # (pylint raise-missing-from); the logged type and message are unchanged.
            raise Exception(f"Custom error message: value '{arg}' cannot be converted to float") from err

    @extract([Parameter("/a", "event", transform=to_float)])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return {}

    response = handler(event)
    self.assertEqual(400, response["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", response["body"])

    mock_logger.error.assert_called_once_with("%s: %s in argument %s for path %s",
                                              "Exception",
                                              "Custom error message: value 'abc' cannot be converted to float",
                                              "event",
                                              "/a")
1818 |
@patch("aws_lambda_decorators.decorators.LOGGER")
def test_apply_invalid_transformation_raises_error(self, mock_logger):
    """A ValueError raised by the transform is logged and turned into a generic 400."""
    payload = {"a": "abc"}

    @extract([Parameter("/a", "event", transform=float)])
    def handler(event, a=None):  # noqa: pylint - unused-argument
        return {}

    result = handler(payload)
    self.assertEqual(400, result["statusCode"])
    self.assertEqual("{\"message\": \"Error extracting parameters\"}", result["body"])

    mock_logger.error.assert_called_once_with(
        "%s: %s in argument %s for path %s",
        "ValueError",
        "could not convert string to float: 'abc'",
        "event",
        "/a"
    )
1838 |
@patch("boto3.client")
def test_push_ws_errors_missing_parameter(self, mock_boto3_client):
    """A 400 from a nested extract decorator is posted to the websocket connection as an error."""
    get_websocket_endpoint.cache_clear()

    endpoint = "https://api_id.execute_id.region.amazonaws.com/Prod"
    ws_event = {
        "requestContext": {"connectionId": "test_connection_id"},
        "body": json.dumps({"invalid_property": "invalid_value"})
    }

    @push_ws_errors(websocket_endpoint_url=endpoint)
    @extract_from_event(parameters=[
        Parameter(path="body[json]/valid_property", validators=[Mandatory])
    ])
    def lambda_func(event, context, valid_property=None):  # noqa: pylint - unused-argument
        return {"statusCode": HTTPStatus.OK}

    result = lambda_func(ws_event, None)

    self.assertEqual(result["statusCode"], HTTPStatus.BAD_REQUEST)

    mock_boto3_client.assert_called_once_with("apigatewaymanagementapi", endpoint_url=endpoint)

    pushed_payload = json.dumps({
        "type": "error",
        "statusCode": 400,
        "message": [{"valid_property": ["Missing mandatory value"]}]
    })
    mock_boto3_client.return_value.post_to_connection.assert_called_once_with(
        ConnectionId="test_connection_id",
        Data=pushed_payload
    )
1880 |
@patch("boto3.client")
def test_push_ws_errors_no_action_on_success(self, mock_boto3_client):
    """A successful (200) response is returned and nothing is posted to the connection."""
    get_websocket_endpoint.cache_clear()

    ws_event = {"requestContext": {"connectionId": "test_connection_id"}}

    @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
    def lambda_func(event, context):  # noqa: pylint - unused-argument
        return {"statusCode": HTTPStatus.OK}

    self.assertEqual(lambda_func(ws_event, None)["statusCode"], 200)

    mock_boto3_client.return_value.post_to_connection.assert_not_called()
1900 |
@patch("boto3.client")
def test_push_ws_errors_no_connection_id(self, mock_boto3_client):
    """Without requestContext.connectionId the error response is returned but nothing is posted."""
    get_websocket_endpoint.cache_clear()

    no_connection_event = {"body": {"property": "value"}}

    @push_ws_errors(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
    def lambda_func(event, context):  # noqa: pylint - unused-argument
        return {"statusCode": HTTPStatus.BAD_REQUEST}

    self.assertEqual(lambda_func(no_connection_event, None)["statusCode"], 400)

    mock_boto3_client.return_value.post_to_connection.assert_not_called()
1920 |
@patch("boto3.client")
def test_push_ws_response(self, mock_boto3_client):
    """push_ws_response returns the handler's response and posts its JSON form to the connection."""
    get_websocket_endpoint.cache_clear()

    ws_event = {"requestContext": {"connectionId": "test_connection_id"}}

    @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
    def lambda_func(event, context):  # noqa: pylint - unused-argument
        return {"statusCode": HTTPStatus.OK, "body": "Hello, world!"}

    result = lambda_func(ws_event, None)

    self.assertEqual(result["statusCode"], 200)
    self.assertEqual(result["body"], "Hello, world!")

    expected_frame = "{\"statusCode\": 200, \"body\": \"Hello, world!\"}"
    mock_boto3_client.return_value.post_to_connection.assert_called_once_with(
        ConnectionId="test_connection_id",
        Data=expected_frame
    )
1947 |
@patch("boto3.client")
def test_push_ws_response_no_connection_id(self, mock_boto3_client):
    """Without requestContext.connectionId the response is returned but nothing is posted."""
    get_websocket_endpoint.cache_clear()

    no_connection_event = {"body": "Hello, world!"}

    @push_ws_response(websocket_endpoint_url="https://api_id.execute_id.region.amazonaws.com/Prod")
    @extract_from_event(parameters=[Parameter(path="body")])
    def lambda_func(event, context, body=None):  # noqa: pylint - unused-argument
        return {"statusCode": HTTPStatus.OK, "body": body}

    result = lambda_func(no_connection_event, None)

    self.assertEqual(result["statusCode"], 200)
    self.assertEqual(result["body"], "Hello, world!")

    mock_boto3_client.return_value.post_to_connection.assert_not_called()
1972 |
1973 | def test_hsts_returns_headers_in_response(self):
1974 |
1975 | @hsts()
1976 | def handler():
1977 | return {}
1978 |
1979 | response = handler()
1980 |
1981 | self.assertEqual(response["headers"]["Strict-Transport-Security"], "max-age=63072000")
1982 |
1983 | def test_hsts_returns_headers_in_response_with_custom_age(self):
1984 |
1985 | @hsts(max_age=121212)
1986 | def handler():
1987 | return {}
1988 |
1989 | response = handler()
1990 |
1991 | self.assertEqual(response["headers"]["Strict-Transport-Security"], "max-age=121212")
1992 |
1993 | def test_hsts_function_returns_non_dictionary(self):
1994 |
1995 | @hsts()
1996 | def handler():
1997 | return "I am a string"
1998 |
1999 | response = handler()
2000 |
2001 | self.assertEqual(response["statusCode"], HTTPStatus.INTERNAL_SERVER_ERROR) # noqa: pylint-invalid-sequence-index
2002 | self.assertEqual(response["body"], "{\"message\": \"Invalid response type for HSTS header\"}") # noqa: pylint-invalid-sequence-index
2003 |
2004 |
2005 | class IsolatedDecoderTests(unittest.TestCase):
2006 | # Tests have been named so they run in a specific order
2007 |
2008 | ID_PATTERN = "^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
2009 |
2010 | PARAMETERS = [
2011 | Parameter(path="pathParameters/test_id", validators=[Mandatory, RegexValidator(ID_PATTERN)]),
2012 | Parameter(path="body[json]/name", validators=[MaxLength(255)])
2013 | ]
2014 |
2015 | def test_01_extract_from_event_400(self):
2016 | event = {
2017 | "pathParameters": {}
2018 | }
2019 |
2020 | @extract_from_event(parameters=self.PARAMETERS, group_errors=True, allow_none_defaults=False)
2021 | def handler(event, context, **kwargs): # noqa
2022 | return kwargs
2023 |
2024 | response = handler(event, None)
2025 | self.assertEqual(HTTPStatus.BAD_REQUEST, response["statusCode"])
2026 |
2027 | def test_02_extract_from_event_200(self):
2028 | test_id = str(uuid4())
2029 |
2030 | event = {
2031 | "pathParameters": {
2032 | "test_id": test_id
2033 | },
2034 | "body": json.dumps({
2035 | "name": "Gird"
2036 | })
2037 | }
2038 |
2039 | @extract_from_event(parameters=self.PARAMETERS, group_errors=True, allow_none_defaults=False)
2040 | def handler(event, context, **kwargs): # noqa
2041 | return {
2042 | "statusCode": HTTPStatus.OK,
2043 | "body": json.dumps(kwargs)
2044 | }
2045 |
2046 | expected_body = json.dumps({
2047 | "test_id": test_id,
2048 | "name": "Gird"
2049 | })
2050 |
2051 | response = handler(event, None)
2052 |
2053 | self.assertEqual(HTTPStatus.OK, response["statusCode"])
2054 | self.assertEqual(expected_body, response["body"])
2055 |
--------------------------------------------------------------------------------
/tests/test_param.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from aws_lambda_decorators.classes import Parameter
3 |
4 |
5 | class ParamTests(unittest.TestCase):
6 |
7 | def test_annotations_from_key_returns_none_when_no_annotations(self):
8 | key = "simple"
9 | response = Parameter.get_annotations_from_key(key)
10 | self.assertTrue(response[0] == "simple")
11 | self.assertTrue(response[1] is None)
12 |
13 | def test_annotations_from_key_returns_annotation(self):
14 | key = "simple[annotation]"
15 | response = Parameter.get_annotations_from_key(key)
16 | self.assertTrue(response[0] == "simple")
17 | self.assertTrue(response[1] == "annotation")
18 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from aws_lambda_decorators.validators import Mandatory, RegexValidator
3 | from aws_lambda_decorators.utils import is_type_in_list
4 |
5 |
6 | class UtilsTests(unittest.TestCase):
7 |
8 | def test_is_type_in_list_returns_false_if_item_of_type_missing(self):
9 | items = [Mandatory, Mandatory, Mandatory]
10 | self.assertFalse(is_type_in_list(RegexValidator, items))
11 |
12 | def test_is_type_in_list_returns_true_if_item_of_type_exists(self):
13 | items = [Mandatory, RegexValidator(), Mandatory]
14 | self.assertTrue(is_type_in_list(RegexValidator, items))
15 |
--------------------------------------------------------------------------------
/tools/dev/coverage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Run the unit tests under coverage, then report branch coverage for the package.
3 | # Exits non-zero if any test fails or coverage falls below 100%.
4 | set -euo pipefail
5 |
6 | # tests/test_*.py is deliberately unquoted so the shell expands it to the test files.
7 | coverage run --branch --include='aws_lambda_decorators/*.py' -m unittest tests/test_*.py
8 | # Quote the omit patterns so the shell never tries to glob them.
9 | coverage report -m --fail-under=100 --omit='tests/*,it/*,env*'
--------------------------------------------------------------------------------