├── .github ├── ISSUE_TEMPLATE │ ├── bug.md │ ├── doc.md │ └── feature-request.md └── pull_request_template.md ├── .gitignore ├── .readthedocs.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.rst ├── THIRD-PARTY ├── VERSION ├── doc ├── Makefile ├── choicerules.rst ├── compute.rst ├── conf.py ├── images │ ├── create.png │ └── execute.png ├── index.rst ├── make.bat ├── pipelines.rst ├── placeholders.rst ├── readmelink.rst ├── requirements.txt ├── sagemaker.rst ├── services.rst ├── states.rst ├── steps.rst └── workflow.rst ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── stepfunctions │ ├── __init__.py │ ├── exceptions.py │ ├── inputs │ ├── __init__.py │ ├── placeholders.py │ └── utils.py │ ├── steps │ ├── __init__.py │ ├── choice_rule.py │ ├── compute.py │ ├── fields.py │ ├── integration_resources.py │ ├── sagemaker.py │ ├── service.py │ ├── states.py │ └── utils.py │ ├── template │ ├── __init__.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── common.py │ │ ├── inference.py │ │ └── train.py │ └── utils.py │ └── workflow │ ├── __init__.py │ ├── cloudformation.py │ ├── stepfunctions.py │ ├── utils.py │ └── widgets │ ├── __init__.py │ ├── events_table.py │ ├── executions_table.py │ ├── graph.py │ ├── utils.py │ └── workflows_table.py ├── tests ├── __init__.py ├── data │ ├── one_p_mnist │ │ ├── mnist.npy.gz │ │ ├── mnist.pkl.gz │ │ ├── sklearn_mnist_estimator.py │ │ ├── sklearn_mnist_preprocessor.py │ │ └── transform_input.csv │ ├── pytorch_mnist │ │ ├── MNIST │ │ │ └── processed │ │ │ │ ├── test.pt │ │ │ │ └── training.pt │ │ └── mnist.py │ ├── sklearn_mnist │ │ ├── mnist.py │ │ ├── test │ │ │ └── test.npz │ │ └── train │ │ │ └── train.npz │ └── sklearn_processing │ │ └── preprocessor.py ├── integ │ ├── __init__.py │ ├── conftest.py │ ├── resources │ │ ├── SageMaker-TrustPolicy.json │ │ ├── StepFunctionsMLWorkflowExecutionFullAccess-Policy.json │ │ └── StepFunctionsMLWorkflowExecutionFullAccess-TrustPolicy.json │ ├── 
test_inference_pipeline.py │ ├── test_sagemaker_steps.py │ ├── test_state_machine_definition.py │ ├── test_training_pipeline_estimators.py │ ├── test_training_pipeline_framework_estimator.py │ ├── timeout.py │ └── utils.py └── unit │ ├── __init__.py │ ├── test_choice_rule.py │ ├── test_compute_steps.py │ ├── test_graph.py │ ├── test_pipeline.py │ ├── test_placeholders.py │ ├── test_placeholders_with_steps.py │ ├── test_placeholders_with_workflows.py │ ├── test_sagemaker_steps.py │ ├── test_service_steps.py │ ├── test_steps.py │ ├── test_steps_utils.py │ ├── test_template_utils.py │ ├── test_widget_utils.py │ ├── test_widgets.py │ ├── test_workflow.py │ ├── test_workflow_utils.py │ ├── utils.py │ └── widgets │ ├── __init__.py │ ├── test_events_table_widget.py │ ├── test_executions_table_widget.py │ └── test_workflows_table_widget.py └── tox.ini /.github/ISSUE_TEMPLATE/bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Report a bug 4 | title: "short issue description" 5 | labels: bug, needs-triage 6 | --- 7 | 8 | 11 | 12 | 13 | 14 | 15 | ### Reproduction Steps 16 | 17 | 23 | 24 | ### What did you expect to happen? 25 | 26 | 29 | 30 | ### What actually happened? 
31 | 32 | 35 | 36 | 37 | ### Environment 38 | 39 | - **AWS Step Functions Data Science Python SDK version :** 40 | - **Python Version:** 41 | 42 | ### Other 43 | 44 | 45 | 46 | 47 | 48 | 49 | --- 50 | 51 | This is :bug: Bug Report -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/doc.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "📕 Documentation Issue" 3 | about: Issue in the reference documentation 4 | title: "short issue description" 5 | labels: feature-request, documentation, needs-triage 6 | --- 7 | 8 | 11 | 12 | 15 | 16 | 17 | 18 | 21 | 22 | 23 | 24 | 25 | 26 | --- 27 | 28 | This is a 📕 documentation issue 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: Request a new feature 4 | title: "short issue description" 5 | labels: feature-request, needs-triage 6 | --- 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | ### Use Case 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | ### Proposed Solution 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | ### Other 31 | 32 | 36 | 37 | 38 | 39 | 40 | 41 | * [ ] :wave: I may be able to implement this feature request 42 | * [ ] :warning: This feature might incur a breaking change 43 | 44 | --- 45 | 46 | This is a :rocket: Feature Request 47 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | 3 | Please include a summary of the change being made. 4 | 5 | Fixes #(issue) 6 | 7 | ### Why is the change necessary? 8 | 9 | What capability does it enable? What problem does it solve? 10 | 11 | ### Solution 12 | 13 | Please include an overview of the solution. 
Discuss trade-offs made, caveats, alternatives, etc. 14 | 15 | ### Testing 16 | 17 | How was this change tested? 18 | 19 | ---- 20 | 21 | ### Pull Request Checklist 22 | 23 | Please check all boxes (including N/A items) 24 | 25 | #### Testing 26 | 27 | - [ ] Unit tests added 28 | - [ ] Integration test added 29 | - [ ] Manual testing - why was it necessary? could it be automated? 30 | 31 | #### Documentation 32 | 33 | - [ ] __docs__: All relevant [docs](https://github.com/aws/aws-step-functions-data-science-sdk-python/tree/main/doc) updated 34 | - [ ] __docstrings__: All public APIs documented 35 | 36 | ### Title and description 37 | 38 | - [ ] __Change type__: Title is prefixed with change type: and follows [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) 39 | - [ ] __References__: Indicate issues fixed via: `Fixes #xxx` 40 | 41 | ---- 42 | 43 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license. 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | .DS_Store 128 | .vscode/ 129 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # ReadTheDocs environment customization to allow us to use conda to install 2 | # libraries which have C dependencies for the doc build. See: 3 | # https://docs.readthedocs.io/en/latest/config-file/v2.html 4 | 5 | version: 2 6 | 7 | python: 8 | version: 3.6 9 | install: 10 | - method: pip 11 | path: . 12 | - requirements: doc/requirements.txt 13 | 14 | sphinx: 15 | configuration: doc/conf.py 16 | fail_on_warning: false # http://www.sphinx-doc.org/en/master/man/sphinx-build.html#id6 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 5 | 6 | 1. Definitions. 7 | 8 | "License" shall mean the terms and conditions for use, reproduction, 9 | and distribution as defined by Sections 1 through 9 of this document. 10 | 11 | "Licensor" shall mean the copyright owner or entity authorized by 12 | the copyright owner that is granting the License. 
13 | 14 | "Legal Entity" shall mean the union of the acting entity and all 15 | other entities that control, are controlled by, or are under common 16 | control with that entity. For the purposes of this definition, 17 | "control" means (i) the power, direct or indirect, to cause the 18 | direction or management of such entity, whether by contract or 19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 20 | outstanding shares, or (iii) beneficial ownership of such entity. 21 | 22 | "You" (or "Your") shall mean an individual or Legal Entity 23 | exercising permissions granted by this License. 24 | 25 | "Source" form shall mean the preferred form for making modifications, 26 | including but not limited to software source code, documentation 27 | source, and configuration files. 28 | 29 | "Object" form shall mean any form resulting from mechanical 30 | transformation or translation of a Source form, including but 31 | not limited to compiled object code, generated documentation, 32 | and conversions to other media types. 33 | 34 | "Work" shall mean the work of authorship, whether in Source or 35 | Object form, made available under the License, as indicated by a 36 | copyright notice that is included in or attached to the work 37 | (an example is provided in the Appendix below). 38 | 39 | "Derivative Works" shall mean any work, whether in Source or Object 40 | form, that is based on (or derived from) the Work and for which the 41 | editorial revisions, annotations, elaborations, or other modifications 42 | represent, as a whole, an original work of authorship. For the purposes 43 | of this License, Derivative Works shall not include works that remain 44 | separable from, or merely link (or bind by name) to the interfaces of, 45 | the Work and Derivative Works thereof. 
46 | 47 | "Contribution" shall mean any work of authorship, including 48 | the original version of the Work and any modifications or additions 49 | to that Work or Derivative Works thereof, that is intentionally 50 | submitted to Licensor for inclusion in the Work by the copyright owner 51 | or by an individual or Legal Entity authorized to submit on behalf of 52 | the copyright owner. For the purposes of this definition, "submitted" 53 | means any form of electronic, verbal, or written communication sent 54 | to the Licensor or its representatives, including but not limited to 55 | communication on electronic mailing lists, source code control systems, 56 | and issue tracking systems that are managed by, or on behalf of, the 57 | Licensor for the purpose of discussing and improving the Work, but 58 | excluding communication that is conspicuously marked or otherwise 59 | designated in writing by the copyright owner as "Not a Contribution." 60 | 61 | "Contributor" shall mean Licensor and any individual or Legal Entity 62 | on behalf of whom a Contribution has been received by Licensor and 63 | subsequently incorporated within the Work. 64 | 65 | 2. Grant of Copyright License. Subject to the terms and conditions of 66 | this License, each Contributor hereby grants to You a perpetual, 67 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 68 | copyright license to reproduce, prepare Derivative Works of, 69 | publicly display, publicly perform, sublicense, and distribute the 70 | Work and such Derivative Works in Source or Object form. 71 | 72 | 3. Grant of Patent License. 
Subject to the terms and conditions of 73 | this License, each Contributor hereby grants to You a perpetual, 74 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 75 | (except as stated in this section) patent license to make, have made, 76 | use, offer to sell, sell, import, and otherwise transfer the Work, 77 | where such license applies only to those patent claims licensable 78 | by such Contributor that are necessarily infringed by their 79 | Contribution(s) alone or by combination of their Contribution(s) 80 | with the Work to which such Contribution(s) was submitted. If You 81 | institute patent litigation against any entity (including a 82 | cross-claim or counterclaim in a lawsuit) alleging that the Work 83 | or a Contribution incorporated within the Work constitutes direct 84 | or contributory patent infringement, then any patent licenses 85 | granted to You under this License for that Work shall terminate 86 | as of the date such litigation is filed. 87 | 88 | 4. Redistribution. 
You may reproduce and distribute copies of the 89 | Work or Derivative Works thereof in any medium, with or without 90 | modifications, and in Source or Object form, provided that You 91 | meet the following conditions: 92 | 93 | (a) You must give any other recipients of the Work or 94 | Derivative Works a copy of this License; and 95 | 96 | (b) You must cause any modified files to carry prominent notices 97 | stating that You changed the files; and 98 | 99 | (c) You must retain, in the Source form of any Derivative Works 100 | that You distribute, all copyright, patent, trademark, and 101 | attribution notices from the Source form of the Work, 102 | excluding those notices that do not pertain to any part of 103 | the Derivative Works; and 104 | 105 | (d) If the Work includes a "NOTICE" text file as part of its 106 | distribution, then any Derivative Works that You distribute must 107 | include a readable copy of the attribution notices contained 108 | within such NOTICE file, excluding those notices that do not 109 | pertain to any part of the Derivative Works, in at least one 110 | of the following places: within a NOTICE text file distributed 111 | as part of the Derivative Works; within the Source form or 112 | documentation, if provided along with the Derivative Works; or, 113 | within a display generated by the Derivative Works, if and 114 | wherever such third-party notices normally appear. The contents 115 | of the NOTICE file are for informational purposes only and 116 | do not modify the License. You may add Your own attribution 117 | notices within Derivative Works that You distribute, alongside 118 | or as an addendum to the NOTICE text from the Work, provided 119 | that such additional attribution notices cannot be construed 120 | as modifying the License. 
121 | 122 | You may add Your own copyright statement to Your modifications and 123 | may provide additional or different license terms and conditions 124 | for use, reproduction, or distribution of Your modifications, or 125 | for any such Derivative Works as a whole, provided Your use, 126 | reproduction, and distribution of the Work otherwise complies with 127 | the conditions stated in this License. 128 | 129 | 5. Submission of Contributions. Unless You explicitly state otherwise, 130 | any Contribution intentionally submitted for inclusion in the Work 131 | by You to the Licensor shall be under the terms and conditions of 132 | this License, without any additional terms or conditions. 133 | Notwithstanding the above, nothing herein shall supersede or modify 134 | the terms of any separate license agreement you may have executed 135 | with Licensor regarding such Contributions. 136 | 137 | 6. Trademarks. This License does not grant permission to use the trade 138 | names, trademarks, service marks, or product names of the Licensor, 139 | except as required for reasonable and customary use in describing the 140 | origin of the Work and reproducing the content of the NOTICE file. 141 | 142 | 7. Disclaimer of Warranty. Unless required by applicable law or 143 | agreed to in writing, Licensor provides the Work (and each 144 | Contributor provides its Contributions) on an "AS IS" BASIS, 145 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 146 | implied, including, without limitation, any warranties or conditions 147 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 148 | PARTICULAR PURPOSE. You are solely responsible for determining the 149 | appropriateness of using or redistributing the Work and assume any 150 | risks associated with Your exercise of permissions under this License. 151 | 152 | 8. Limitation of Liability. 
In no event and under no legal theory, 153 | whether in tort (including negligence), contract, or otherwise, 154 | unless required by applicable law (such as deliberate and grossly 155 | negligent acts) or agreed to in writing, shall any Contributor be 156 | liable to You for damages, including any direct, indirect, special, 157 | incidental, or consequential damages of any character arising as a 158 | result of this License or out of the use or inability to use the 159 | Work (including but not limited to damages for loss of goodwill, 160 | work stoppage, computer failure or malfunction, or any and all 161 | other commercial damages or losses), even if such Contributor 162 | has been advised of the possibility of such damages. 163 | 164 | 9. Accepting Warranty or Additional Liability. While redistributing 165 | the Work or Derivative Works thereof, You may choose to offer, 166 | and charge a fee for, acceptance of support, warranty, indemnity, 167 | or other liability obligations and/or rights consistent with this 168 | License. However, in accepting such obligations, You may act only 169 | on Your own behalf and on Your sole responsibility, not on behalf 170 | of any other Contributor, and only if You agree to indemnify, 171 | defend, and hold each Contributor harmless for any liability 172 | incurred by, or claims asserted against, such Contributor by reason 173 | of your accepting any such warranty or additional liability. 174 | 175 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | AWS Step Functions Data Science SDK Python 2 | Copyright 2019 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 3 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 2.3.0 2 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = stepfunctions 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/choicerules.rst: -------------------------------------------------------------------------------- 1 | Choice Rules 2 | ============= 3 | 4 | This module defines the `choice rules `__ for a Choice state. 5 | 6 | Use the :py:meth:`~stepfunctions.steps.Choice.add_choice` method to add a branch to a Choice step. 7 | 8 | .. code-block:: python 9 | 10 | my_choice_state.add_choice( 11 | rule=ChoiceRule.BooleanEquals(variable=previous_state.output()["Success"], value=True), 12 | next_step=happy_path 13 | ) 14 | my_choice_state.add_choice( 15 | ChoiceRule.BooleanEquals(variable=previous_state.output()["Success"], value=False), 16 | next_step=sad_state 17 | ) 18 | 19 | 20 | In this example, choice rules are added to the Choice state 21 | ``my_choice_state`` using :py:meth:`~stepfunctions.steps.Choice.add_choice`. 
22 | Logic in a `Choice state `__ 23 | is implemented with the help of Choice Rules. A Choice Rule encapsulates a 24 | comparison, which contains the following: 25 | 26 | - An input **variable** to compare 27 | - The **type** of comparison 28 | - The **value** to compare the variable to 29 | 30 | The type of comparison is abstracted by the classes provided in this module. Multiple choice rules can be 31 | compounded together using the :py:meth:`~stepfunctions.steps.choice_rule.ChoiceRule.And` or 32 | :py:meth:`~stepfunctions.steps.choice_rule.ChoiceRule.Or` classes. A choice rule can be negated using 33 | the :py:meth:`~stepfunctions.steps.choice_rule.ChoiceRule.Not` class. 34 | 35 | .. autoclass:: stepfunctions.steps.choice_rule.BaseRule 36 | 37 | .. autoclass:: stepfunctions.steps.choice_rule.Rule 38 | 39 | .. autoclass:: stepfunctions.steps.choice_rule.CompoundRule 40 | 41 | .. autoclass:: stepfunctions.steps.choice_rule.NotRule 42 | 43 | .. autoclass:: stepfunctions.steps.choice_rule.ChoiceRule 44 | -------------------------------------------------------------------------------- /doc/compute.rst: -------------------------------------------------------------------------------- 1 | Compute 2 | ========= 3 | 4 | This module provides classes to build steps that integrate with AWS Lambda, AWS Batch, AWS Glue, and Amazon ECS. 5 | 6 | .. autoclass:: stepfunctions.steps.compute.LambdaStep 7 | 8 | .. autoclass:: stepfunctions.steps.compute.GlueStartJobRunStep 9 | 10 | .. autoclass:: stepfunctions.steps.compute.BatchSubmitJobStep 11 | 12 | .. autoclass:: stepfunctions.steps.compute.EcsRunTaskStep 13 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 
4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import os 16 | import pkg_resources 17 | import sys 18 | from datetime import datetime 19 | from unittest.mock import MagicMock 20 | 21 | class Mock(MagicMock): 22 | @classmethod 23 | def __getattr__(cls, name): 24 | """ 25 | Args: 26 | name: 27 | """ 28 | if name == "__version__": 29 | return "1.4.0" 30 | else: 31 | return MagicMock() 32 | 33 | 34 | MOCK_MODULES = [ 35 | 'boto3', 36 | 'sagemaker', 37 | 'sagemaker.model', 38 | 'sagemaker.model_monitor', 39 | 'sagemaker.pipeline', 40 | 'sagemaker.sklearn', 41 | 'sagemaker.sklearn.estimator', 42 | 'sagemaker.utils', 43 | 'sagemaker.workflow', 44 | 'sagemaker.workflow.airflow', 45 | 'IPython', 46 | 'IPython.core', 47 | 'IPython.core.display' 48 | ] 49 | 50 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 51 | 52 | project = u"stepfunctions" 53 | version = pkg_resources.require(project)[0].version 54 | 55 | # Add any Sphinx extension module names here, as strings. They can be extensions 56 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 57 | extensions = [ 58 | "sphinx.ext.autodoc", 59 | "sphinx.ext.autosummary", 60 | "sphinx.ext.intersphinx", 61 | "sphinx.ext.todo", 62 | "sphinx.ext.coverage", 63 | "sphinx.ext.napoleon", 64 | "sphinx.ext.autosectionlabel", 65 | ] 66 | 67 | # Add any paths that contain templates here, relative to this directory. 68 | templates_path = ["_templates"] 69 | 70 | source_suffix = ".rst" # The suffix of source filenames. 
71 | master_doc = "index" # The master toctree document. 72 | 73 | copyright = u"%s, Amazon" % datetime.now().year 74 | 75 | # The full version, including alpha/beta/rc tags. 76 | release = version 77 | 78 | # List of directories, relative to source directory, that shouldn't be searched 79 | # for source files. 80 | exclude_trees = ["_build"] 81 | 82 | pygments_style = "default" 83 | 84 | autoclass_content = "both" 85 | autodoc_default_flags = ["show-inheritance", "members", "no-undoc-members"] 86 | autodoc_member_order = "bysource" 87 | 88 | if "READTHEDOCS" in os.environ: 89 | html_theme = "default" 90 | else: 91 | html_theme = "haiku" 92 | 93 | html_static_path = [] 94 | htmlhelp_basename = "%sdoc" % project 95 | 96 | intersphinx_mapping = { 97 | "https://docs.python.org/3.6/": None, 98 | "https://boto3.readthedocs.io/en/latest": None, 99 | "https://sagemaker.readthedocs.io/en/stable": None, 100 | } 101 | 102 | # autosummary 103 | autosummary_generate = True 104 | 105 | # autosectionlabel 106 | autosectionlabel_prefix_document = True 107 | -------------------------------------------------------------------------------- /doc/images/create.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/aws-step-functions-data-science-sdk-python/0e970ccde4ef7cc532f80040598e8d4d55f6486d/doc/images/create.png -------------------------------------------------------------------------------- /doc/images/execute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/aws-step-functions-data-science-sdk-python/0e970ccde4ef7cc532f80040598e8d4d55f6486d/doc/images/execute.png -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | ########################################## 2 | AWS Step Functions Data Science Python SDK 3 | 
########################################## 4 | 5 | The AWS Step Functions Data Science SDK is an open source library that allows data 6 | scientists to easily create workflows that process and publish machine learning 7 | models using AWS SageMaker and AWS Step Functions. You can create 8 | multi-step machine learning workflows in Python that orchestrate AWS 9 | infrastructure at scale, without having to provision and integrate the AWS 10 | services separately. 11 | 12 | In addition to creating production-ready workflows directly in Python, the AWS 13 | Step Functions Data Science SDK allows you to copy that workflow, experiment with 14 | new options, and then put the refined workflow in production. 15 | 16 | 17 | ***************** 18 | Overview 19 | ***************** 20 | 21 | The following topics cover how the AWS Step Functions Data Science SDK integrates 22 | with with AWS services to help process and publish machine learning (ML) models 23 | using Python. 24 | 25 | .. toctree:: 26 | :maxdepth: 3 27 | 28 | readmelink 29 | 30 | 31 | ***************** 32 | API Documentation 33 | ***************** 34 | 35 | The AWS Step Functions Data Science SDK consists of a few primary modules. 36 | 37 | .. toctree:: 38 | :maxdepth: 3 39 | 40 | steps 41 | pipelines 42 | workflow 43 | placeholders 44 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/pipelines.rst: -------------------------------------------------------------------------------- 1 | Pipelines 2 | ========= 3 | 4 | This module provides pre-built templates that make it easy to build generic data 5 | science workflows. The templates are constructed from steps. 6 | 7 | .. autoclass:: stepfunctions.template.pipeline.train.TrainingPipeline 8 | 9 | .. autoclass:: stepfunctions.template.pipeline.inference.InferencePipeline 10 | -------------------------------------------------------------------------------- /doc/placeholders.rst: -------------------------------------------------------------------------------- 1 | Placeholders 2 | ============= 3 | 4 | Once defined, a workflow is static unless you update it explicitly. But, you can pass 5 | input to workflow executions. You can have dynamic values 6 | that you use in the **parameters** fields of the steps in your workflow. For this, 7 | the AWS Step Functions Data Science SDK provides a way to define placeholders to pass around when you 8 | create your workflow. There are 2 mechanisms for passing dynamic values in a workflow. 9 | 10 | The first mechanism is a global input to the workflow execution. This input is 11 | accessible to all the steps in the workflow. 
The SDK provides :py:meth:`stepfunctions.inputs.ExecutionInput` 12 | to define the schema for this input, and to access the values in your workflow. 13 | 14 | .. code-block:: python 15 | 16 | # Create an instance of ExecutionInput class, and define a schema. Defining 17 | # a schema is optional, but it is a good practice 18 | 19 | my_execution_input = ExecutionInput(schema={ 20 | 'myDynamicInput': str 21 | }) 22 | 23 | lambda_state = LambdaStep( 24 | state_id="MyLambdaStep", 25 | parameters={ 26 | "FunctionName": "MyLambda", 27 | "Payload": { 28 | "input": my_execution_input["myDynamicInput"] #Use as a 29 | #Python dictionary 30 | } 31 | } 32 | ) 33 | 34 | # Workflow is created with the placeholders 35 | workflow = Workflow( 36 | name='MyLambdaWorkflowWithGlobalInput', 37 | definition=lambda_state, 38 | role=workflow_execution_role, 39 | execution_input=my_execution_input # Provide the execution_input when 40 | # defining your workflow 41 | ) 42 | 43 | # Create the workflow on AWS Step Functions 44 | workflow.create() 45 | 46 | # The placeholder is assigned a value during execution. The SDK will 47 | # verify that all placeholder values are assigned values, and that 48 | # these values are of the expected type based on the defined schema 49 | # before the execution starts. 50 | 51 | workflow.execute(inputs={'myDynamicInput': "WorldHello"}) 52 | 53 | The second mechanism is for passing dynamic values from one step to the next 54 | step. The output of one step becomes the input of the next step. 55 | The SDK provides the :py:meth:`stepfunctions.inputs.StepInput` class for this. 56 | 57 | By default, each step has an output method :py:meth:`stepfunctions.steps.states.State.output` 58 | that returns the placeholder output for that step. 59 | 60 | .. 
code-block:: python 61 | 62 | lambda_state_first = LambdaStep( 63 | state_id="MyFirstLambdaStep", 64 | parameters={ 65 | "FunctionName": "MakeApiCall", 66 | "Payload": { 67 | "input": "20192312" 68 | } 69 | } 70 | ) 71 | 72 | lambda_state_second = LambdaStep( 73 | state_id="MySecondLambdaStep", 74 | parameters={ 75 | "FunctionName": "ProcessCallResult", 76 | "Payload": { 77 | "input": lambda_state_first.output()["Response"] #Use as a Python dictionary 78 | } 79 | } 80 | ) 81 | 82 | definition = Chain([lambda_state_first, lambda_state_second]) 83 | 84 | 85 | 86 | .. autoclass:: stepfunctions.inputs.Placeholder 87 | 88 | .. autoclass:: stepfunctions.inputs.ExecutionInput 89 | :inherited-members: 90 | 91 | .. autoclass:: stepfunctions.inputs.StepInput 92 | :inherited-members: 93 | -------------------------------------------------------------------------------- /doc/readmelink.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. toctree:: 4 | :maxdepth: 3 5 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==1.7.9 -------------------------------------------------------------------------------- /doc/sagemaker.rst: -------------------------------------------------------------------------------- 1 | SageMaker 2 | ========= 3 | 4 | This module provides classes to build steps that integrate with Amazon SageMaker. 5 | 6 | .. autoclass:: stepfunctions.steps.sagemaker.TrainingStep 7 | 8 | .. autoclass:: stepfunctions.steps.sagemaker.TransformStep 9 | 10 | .. autoclass:: stepfunctions.steps.sagemaker.TuningStep 11 | 12 | .. autoclass:: stepfunctions.steps.sagemaker.ModelStep 13 | 14 | .. autoclass:: stepfunctions.steps.sagemaker.EndpointConfigStep 15 | 16 | .. autoclass:: stepfunctions.steps.sagemaker.EndpointStep 17 | 18 | .. 
autoclass:: stepfunctions.steps.sagemaker.ProcessingStep -------------------------------------------------------------------------------- /doc/services.rst: -------------------------------------------------------------------------------- 1 | Service Integrations 2 | ===================== 3 | 4 | This module provides classes to build steps that integrate with Amazon DynamoDB, Amazon SNS, Amazon SQS and Amazon EMR. 5 | 6 | 7 | **Table of Contents** 8 | 9 | - `Amazon DynamoDB <#amazon-dynamodb>`__ 10 | 11 | - `Amazon EKS <#amazon-eks>`__ 12 | 13 | - `Amazon EMR <#amazon-emr>`__ 14 | 15 | - `Amazon EventBridge <#amazon-eventbridge>`__ 16 | 17 | - `AWS Glue DataBrew <#aws-glue-databrew>`__ 18 | 19 | - `Amazon SNS <#amazon-sns>`__ 20 | 21 | - `Amazon SQS <#amazon-sqs>`__ 22 | 23 | - `AWS Step Functions <#aws-step-functions>`__ 24 | 25 | 26 | Amazon DynamoDB 27 | ---------------- 28 | .. autoclass:: stepfunctions.steps.service.DynamoDBGetItemStep 29 | 30 | .. autoclass:: stepfunctions.steps.service.DynamoDBPutItemStep 31 | 32 | .. autoclass:: stepfunctions.steps.service.DynamoDBDeleteItemStep 33 | 34 | .. autoclass:: stepfunctions.steps.service.DynamoDBUpdateItemStep 35 | 36 | 37 | Amazon EKS 38 | ---------- 39 | .. autoclass:: stepfunctions.steps.service.EksCallStep 40 | 41 | .. autoclass:: stepfunctions.steps.service.EksCreateClusterStep 42 | 43 | .. autoclass:: stepfunctions.steps.service.EksCreateFargateProfileStep 44 | 45 | .. autoclass:: stepfunctions.steps.service.EksCreateNodegroupStep 46 | 47 | .. autoclass:: stepfunctions.steps.service.EksDeleteClusterStep 48 | 49 | .. autoclass:: stepfunctions.steps.service.EksDeleteFargateProfileStep 50 | 51 | .. autoclass:: stepfunctions.steps.service.EksDeleteNodegroupStep 52 | 53 | .. autoclass:: stepfunctions.steps.service.EksRunJobStep 54 | 55 | 56 | Amazon EMR 57 | ----------- 58 | .. autoclass:: stepfunctions.steps.service.EmrCreateClusterStep 59 | 60 | .. 
autoclass:: stepfunctions.steps.service.EmrTerminateClusterStep 61 | 62 | .. autoclass:: stepfunctions.steps.service.EmrAddStepStep 63 | 64 | .. autoclass:: stepfunctions.steps.service.EmrCancelStepStep 65 | 66 | .. autoclass:: stepfunctions.steps.service.EmrSetClusterTerminationProtectionStep 67 | 68 | .. autoclass:: stepfunctions.steps.service.EmrModifyInstanceFleetByNameStep 69 | 70 | .. autoclass:: stepfunctions.steps.service.EmrModifyInstanceGroupByNameStep 71 | 72 | Amazon EventBridge 73 | ------------------ 74 | .. autoclass:: stepfunctions.steps.service.EventBridgePutEventsStep 75 | 76 | AWS Glue DataBrew 77 | ----------------- 78 | .. autoclass:: stepfunctions.steps.service.GlueDataBrewStartJobRunStep 79 | 80 | Amazon SNS 81 | ----------- 82 | .. autoclass:: stepfunctions.steps.service.SnsPublishStep 83 | 84 | Amazon SQS 85 | ----------- 86 | .. autoclass:: stepfunctions.steps.service.SqsSendMessageStep 87 | 88 | AWS Step Functions 89 | ------------------ 90 | .. autoclass:: stepfunctions.steps.service.StepFunctionsStartExecutionStep 91 | 92 | -------------------------------------------------------------------------------- /doc/states.rst: -------------------------------------------------------------------------------- 1 | States 2 | ========= 3 | 4 | This module implements elements of the `Amazon States Language `__. 5 | 6 | .. autoclass:: stepfunctions.steps.states.Block 7 | 8 | .. autoclass:: stepfunctions.steps.states.State 9 | 10 | .. autoclass:: stepfunctions.steps.states.Chain 11 | 12 | .. autoclass:: stepfunctions.steps.states.Pass 13 | 14 | .. autoclass:: stepfunctions.steps.states.Succeed 15 | 16 | .. autoclass:: stepfunctions.steps.states.Fail 17 | 18 | .. autoclass:: stepfunctions.steps.states.Wait 19 | 20 | .. autoclass:: stepfunctions.steps.states.Task 21 | 22 | .. autoclass:: stepfunctions.steps.states.Choice 23 | 24 | .. autoclass:: stepfunctions.steps.states.Parallel 25 | 26 | .. autoclass:: stepfunctions.steps.states.Map 27 | 28 | .. 
autoclass:: stepfunctions.steps.states.Retry 29 | 30 | .. autoclass:: stepfunctions.steps.states.Catch 31 | -------------------------------------------------------------------------------- /doc/steps.rst: -------------------------------------------------------------------------------- 1 | 2 | ########################### 3 | Steps 4 | ########################### 5 | 6 | Steps are the basic building block of workflows in the AWS Step Functions Data 7 | Science SDK. Once you create steps, you chain them together to create a workflow, 8 | create that workflow in AWS Step Functions, and execute the workflow in the 9 | AWS cloud. 10 | 11 | Step Functions creates workflows out of steps called `States `__, 12 | and expresses that workflow in the `Amazon States Language `__. 13 | When you create a workflow in the AWS Step Functions Data Science SDK, it 14 | creates a State Machine representing your workflow and steps in AWS Step 15 | Functions. 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | 20 | states 21 | choicerules 22 | compute 23 | sagemaker 24 | services 25 | -------------------------------------------------------------------------------- /doc/workflow.rst: -------------------------------------------------------------------------------- 1 | Workflow 2 | ========= 3 | 4 | This module provides classes which abstract workflow and workflow executions for 5 | AWS Step Functions. These classes are used for interacting directly with the 6 | AWS Step Functions service in the cloud. 7 | 8 | .. autoclass:: stepfunctions.workflow.ExecutionStatus 9 | 10 | .. autoclass:: stepfunctions.workflow.Workflow 11 | 12 | .. 
autoclass:: stepfunctions.workflow.Execution 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sagemaker>=2.1.0 2 | boto3>=1.14.38 3 | pyyaml 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Test args for pytest; disable stdout capturing by default. 2 | [tool:pytest] 3 | addopts = 4 | -vv 5 | testpaths = tests 6 | 7 | [aliases] 8 | test=pytest 9 | 10 | [metadata] 11 | description-file = README.rst 12 | license_file = LICENSE 13 | 14 | [wheel] 15 | universal = 1 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import os 16 | from glob import glob 17 | import sys 18 | 19 | from setuptools import setup, find_packages 20 | 21 | 22 | def read(fname): 23 | with open(os.path.join(os.path.dirname(__file__), fname)) as f: 24 | return f.read() 25 | 26 | 27 | def read_version(): 28 | return read("VERSION").strip() 29 | 30 | 31 | # Declare minimal set for installation 32 | required_packages = [ 33 | "sagemaker>=2.1.0", 34 | "boto3>=1.14.38", 35 | "pyyaml" 36 | ] 37 | 38 | # enum is introduced in Python 3.4. Installing enum back port 39 | if sys.version_info < (3, 4): 40 | required_packages.append("enum34>=1.1.6") 41 | 42 | setup( 43 | name="stepfunctions", 44 | version=read_version(), 45 | description="Open source library for developing data science workflows on AWS Step Functions.", 46 | packages=find_packages("src"), 47 | package_dir={"": "src"}, 48 | py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")], 49 | long_description=read("README.rst"), 50 | author="Amazon Web Services", 51 | url="https://github.com/aws/aws-step-functions-data-science-sdk-python", 52 | license="Apache License 2.0", 53 | keywords="ML Amazon AWS AI Tensorflow MXNet", 54 | classifiers=[ 55 | "Intended Audience :: Developers", 56 | "Natural Language :: English", 57 | "License :: OSI Approved :: Apache Software License", 58 | "Programming Language :: Python", 59 | "Programming Language :: Python :: 3.6", 60 | ], 61 | install_requires=required_packages, 62 | extras_require={ 63 | "test": [ 64 | "tox>=3.13.1", 65 | "pytest>=4.4.1", 66 | "stopit==1.1.2", 67 | "tensorflow>=1.3.0", 68 | "mock>=2.0.0", 69 | "contextlib2==0.5.5", 70 | "IPython", 71 | ] 72 | } 73 | ) 74 | -------------------------------------------------------------------------------- /src/stepfunctions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. 
All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import logging
import pkg_resources
import sys

__version__ = pkg_resources.require("stepfunctions")[0].version
__useragent__ = "aws-step-functions-data-science-sdk-python"

# disable logging.warning() from import packages
logging.getLogger().setLevel(logging.ERROR)

from stepfunctions import steps
from stepfunctions import workflow
from stepfunctions import template
from stepfunctions.workflow.utils import CustomColorFormatter

def set_stream_logger(level=logging.INFO):
    """Route 'stepfunctions' log records to stdout with color formatting.

    Replaces any previously attached handlers, so calling this repeatedly
    does not produce duplicate log lines.

    Args:
        level (int, optional): Minimum log level emitted (default: logging.INFO).
    """
    logger = logging.getLogger('stepfunctions')
    # setup logger config
    logger.setLevel(level)
    # don't bubble records up to the root logger (set to ERROR above)
    logger.propagate = False
    # avoid attaching multiple identical stream handlers
    logger.handlers = []
    # add stream handler to logger
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(CustomColorFormatter())
    logger.addHandler(handler)


# http://docs.python.org/3.3/howto/logging.html#configuring-logging-for-a-library
class NullHandler(logging.Handler):
    """No-op handler: keeps the library silent unless the application configures logging."""
    def emit(self, record):
        pass

logging.getLogger('stepfunctions').addHandler(NullHandler())

--------------------------------------------------------------------------------
/src/stepfunctions/exceptions.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import


class WorkflowNotFound(Exception):
    """Raised when a referenced workflow (state machine) does not exist."""
    pass


class MissingRequiredParameter(Exception):
    """Raised when a call is missing a parameter it requires."""
    pass


class DuplicateStatesInChain(Exception):
    """Raised when the same state instance appears more than once in a Chain."""
    pass
--------------------------------------------------------------------------------
/src/stepfunctions/inputs/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

from stepfunctions.inputs.placeholders import Placeholder, ExecutionInput, StepInput
--------------------------------------------------------------------------------
/src/stepfunctions/inputs/utils.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

# NOTE(review): the parameter name `input` shadows the builtin; it is kept
# because renaming it would break keyword callers.
def flatten(input, parent_key='', sep='.'):
    """Flatten a nested dict into a single-level dict whose keys join the
    nesting path with ``sep``, e.g. ``{'a': {'b': 1}}`` -> ``{'a.b': 1}``.

    Args:
        input (dict): Possibly-nested dictionary to flatten.
        parent_key (str, optional): Prefix prepended to every produced key
            (used internally during recursion).
        sep (str, optional): Separator placed between path segments.

    Returns:
        dict: A new single-level dictionary; ``input`` is not modified.
    """
    items = []
    for k, v in input.items():
        if parent_key:
            flattened_key = parent_key + sep + k
        else:
            flattened_key = k
        if isinstance(v, dict):
            # Nested dict: recurse and splice its flattened pairs in.
            items.extend(flatten(v, flattened_key, sep=sep).items())
        else:
            items.append((flattened_key, v))
    return dict(items)

def replace_type_with_str(schema):
    """Return a copy of ``schema`` with every leaf replaced by its
    ``__name__`` (e.g. ``str`` -> ``'str'``), recursing into nested dicts.

    Leaves are expected to be type objects (anything exposing ``__name__``).

    Returns:
        dict: A new dictionary; ``schema`` is not modified.
    """
    schema_with_str = {}
    for k,v in schema.items():
        if isinstance(v, dict):
            schema_with_str[k] = replace_type_with_str(v)
        else:
            schema_with_str[k] = v.__name__
    return schema_with_str
--------------------------------------------------------------------------------
/src/stepfunctions/steps/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | from stepfunctions.steps.choice_rule import ChoiceRule 16 | 17 | from stepfunctions.steps.states import Pass, Succeed, Fail, Wait, Choice, Parallel, Map, Task, Chain, Retry, Catch 18 | from stepfunctions.steps.states import Graph, FrozenGraph 19 | from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, TuningStep, ProcessingStep 20 | from stepfunctions.steps.compute import LambdaStep, BatchSubmitJobStep, GlueStartJobRunStep, EcsRunTaskStep 21 | from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep 22 | 23 | from stepfunctions.steps.service import ( 24 | EksCallStep, 25 | EksCreateClusterStep, 26 | EksCreateFargateProfileStep, 27 | EksCreateNodegroupStep, 28 | EksDeleteClusterStep, 29 | EksDeleteFargateProfileStep, 30 | EksDeleteNodegroupStep, 31 | EksRunJobStep, 32 | ) 33 | from stepfunctions.steps.service import EmrCreateClusterStep, EmrTerminateClusterStep, EmrAddStepStep, EmrCancelStepStep, EmrSetClusterTerminationProtectionStep, EmrModifyInstanceFleetByNameStep, EmrModifyInstanceGroupByNameStep 34 | from stepfunctions.steps.service import EventBridgePutEventsStep 35 | from stepfunctions.steps.service import GlueDataBrewStartJobRunStep 36 | from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep 37 | from stepfunctions.steps.service import StepFunctionsStartExecutionStep 38 | 
--------------------------------------------------------------------------------
/src/stepfunctions/steps/fields.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

from enum import Enum


class Field(Enum):
    """
    Enum of the fields a state can carry. Each member's value is the
    snake_case name for that field (presumably the keyword-argument /
    attribute name used by the state classes — confirm against states.py).
    """

    # Common fields
    Comment = 'comment'
    InputPath = 'input_path'
    OutputPath = 'output_path'
    Parameters = 'parameters'
    ResultPath = 'result_path'
    Next = 'next'
    Retry = 'retry'
    Catch = 'catch'
    Branches = 'branches'
    End = 'end'
    Version = 'version'

    # Pass state fields
    Result = 'result'

    # Fail state fields
    Error = 'error'
    Cause = 'cause'

    # Wait state fields
    Seconds = 'seconds'
    Timestamp = 'timestamp'
    SecondsPath = 'seconds_path'
    TimestampPath = 'timestamp_path'

    # Choice state fields
    Choices = 'choices'
    Default = 'default'

    # Map state fields
    Iterator = 'iterator'
    ItemsPath = 'items_path'
    MaxConcurrency = 'max_concurrency'

    # Task state fields
    Resource = 'resource'
    TimeoutSeconds = 'timeout_seconds'
    TimeoutSecondsPath = 'timeout_seconds_path'
    HeartbeatSeconds = 'heartbeat_seconds'
    HeartbeatSecondsPath = 'heartbeat_seconds_path'

    # Retry and catch fields
    ErrorEquals = 'error_equals'
    IntervalSeconds = 'interval_seconds'
    MaxAttempts = 'max_attempts'
    BackoffRate = 'backoff_rate'
    NextStep = 'next_step'
--------------------------------------------------------------------------------
/src/stepfunctions/steps/integration_resources.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import absolute_import

from enum import Enum
from stepfunctions.steps.utils import get_aws_partition


class IntegrationPattern(Enum):
    """
    Integration pattern enum classes for task integration resource arn builder.
    Each value is the ARN suffix appended for that pattern (CallAndContinue
    appends nothing).
    """

    WaitForTaskToken = "waitForTaskToken"
    WaitForCompletion = "sync"
    CallAndContinue = ""


def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.CallAndContinue, version=None):

    """
    ARN builder for task integration
    Args:
        service (str): The service name for the service integration
        api: The api of the service integration; expected to be an Enum member
            (its ``.value`` is read)
        integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.CallAndContinue)
        version (int, optional): The version of the resource to use. (Default: None)
    Returns:
        str: Resource ARN of the form
            ``arn:<partition>:states:::<service>:<api>[.<pattern>][:<version>]``
    """
    arn = ""
    if integration_pattern == IntegrationPattern.CallAndContinue:
        # CallAndContinue is the plain request/response form: no ".<pattern>" suffix.
        arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}"
    else:
        arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}"

    if version:
        arn = f"{arn}:{str(version)}"

    return arn


def is_integration_pattern_valid(integration_pattern, supported_integration_patterns):
    """Validate an integration pattern against the patterns a step supports.

    Raises:
        TypeError: If ``integration_pattern`` is not an IntegrationPattern.
        ValueError: If it is not in ``supported_integration_patterns``.
    """
    if not isinstance(integration_pattern, IntegrationPattern):
        raise TypeError(f"Integration pattern must be of type {IntegrationPattern}")
    elif integration_pattern not in supported_integration_patterns:
        raise ValueError(f"Integration Pattern ({integration_pattern.name}) is not supported for this step - "
                         f"Please use one of the following: "
                         f"{[integ_type.name for integ_type in supported_integration_patterns]}")
--------------------------------------------------------------------------------
/src/stepfunctions/steps/utils.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
13 | from __future__ import absolute_import 14 | 15 | import boto3 16 | import logging 17 | from stepfunctions.inputs import Placeholder 18 | 19 | logger = logging.getLogger('stepfunctions') 20 | 21 | 22 | def tags_dict_to_kv_list(tags_dict): 23 | kv_list = [{"Key": k, "Value": v} for k, v in tags_dict.items()] 24 | return kv_list 25 | 26 | 27 | def get_aws_partition(): 28 | 29 | """ 30 | Returns the aws partition for the current boto3 session. 31 | Defaults to 'aws' if the region could not be detected. 32 | """ 33 | 34 | partitions = boto3.session.Session().get_available_partitions() 35 | cur_region = boto3.session.Session().region_name 36 | cur_partition = "aws" 37 | 38 | if cur_region is None: 39 | logger.warning("No region detected for the boto3 session. Using default partition: aws") 40 | return cur_partition 41 | 42 | for partition in partitions: 43 | regions = boto3.session.Session().get_available_regions("stepfunctions", partition) 44 | if cur_region in regions: 45 | cur_partition = partition 46 | return cur_partition 47 | 48 | return cur_partition 49 | 50 | 51 | def merge_dicts(target, source): 52 | """ 53 | Merges source dictionary into the target dictionary. 54 | Values in the target dict are updated with the values of the source dict. 55 | Args: 56 | target (dict): Base dictionary into which source is merged 57 | source (dict): Dictionary used to update target. 
If the same key is present in both dictionaries, source's value 58 | will overwrite target's value for the corresponding key 59 | """ 60 | if isinstance(target, dict) and isinstance(source, dict): 61 | for key, value in source.items(): 62 | if key in target: 63 | if isinstance(target[key], dict) and isinstance(source[key], dict): 64 | merge_dicts(target[key], source[key]) 65 | elif target[key] == value: 66 | pass 67 | else: 68 | logger.info( 69 | f"Property: <{key}> with value: <{target[key]}>" 70 | f" will be overwritten with provided value: <{value}>") 71 | target[key] = source[key] 72 | else: 73 | target[key] = source[key] 74 | -------------------------------------------------------------------------------- /src/stepfunctions/template/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | from stepfunctions.template.pipeline import TrainingPipeline, InferencePipeline 16 | -------------------------------------------------------------------------------- /src/stepfunctions/template/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 
4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | from stepfunctions.template.pipeline.train import TrainingPipeline 16 | from stepfunctions.template.pipeline.inference import InferencePipeline 17 | -------------------------------------------------------------------------------- /src/stepfunctions/template/pipeline/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | from enum import Enum 16 | from datetime import datetime 17 | 18 | from stepfunctions.steps import Task 19 | from stepfunctions.template.utils import replace_parameters_with_context_object 20 | 21 | 22 | class StepId(Enum): 23 | 24 | Train = 'Training' 25 | CreateModel = 'Create Model' 26 | ConfigureEndpoint = 'Configure Endpoint' 27 | Deploy = 'Deploy' 28 | 29 | TrainPreprocessor = 'Train Preprocessor' 30 | CreatePreprocessorModel = 'Create Preprocessor Model' 31 | TransformInput = 'Transform Input' 32 | CreatePipelineModel = 'Create Pipeline Model' 33 | 34 | 35 | class WorkflowTemplate(object): 36 | 37 | def __init__(self, s3_bucket, workflow, role, client, **kwargs): 38 | self.workflow = workflow 39 | self.role = role 40 | self.s3_bucket = s3_bucket 41 | 42 | def render_graph(self, portrait=False): 43 | return self.workflow.render_graph(portrait=portrait) 44 | 45 | def get_workflow(self): 46 | return self.workflow 47 | 48 | def _generate_timestamp(self): 49 | return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 50 | 51 | def _extract_input_template(self, definition): 52 | input_template = {} 53 | 54 | for step in definition.steps: 55 | if isinstance(step, Task): 56 | input_template[step.state_id] = step.parameters.copy() 57 | step.update_parameters(replace_parameters_with_context_object(step)) 58 | 59 | return input_template 60 | 61 | def build_workflow_definition(self): 62 | raise NotImplementedError() 63 | 64 | def create(self): 65 | return self.workflow.create() 66 | 67 | def execute(self, **kwargs): 68 | raise NotImplementedError() 69 | 70 | def __repr__(self): 71 | return '{}(s3_bucket={!r}, workflow={!r}, role={!r})'.format( 72 | self.__class__.__name__, 73 | self.s3_bucket, self.workflow, self.role 74 | ) 75 | -------------------------------------------------------------------------------- /src/stepfunctions/template/pipeline/train.py: 
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.utils import base_name_from_image
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel

from stepfunctions.steps import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Fail, Catch
from stepfunctions.workflow import Workflow
from stepfunctions.template.pipeline.common import StepId, WorkflowTemplate


class TrainingPipeline(WorkflowTemplate):

    """
    Creates a standard training pipeline with the following steps in order:

    1. Train estimator
    2. Create estimator model
    3. Endpoint configuration
    4. Deploy model
    """

    __allowed_kwargs = ('pipeline_name',)

    def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
        """
        Args:
            estimator (sagemaker.estimator.EstimatorBase): The estimator to use for training. Can be a BYO estimator, Framework estimator or Amazon algorithm estimator.
            role (str): An AWS IAM role (either name or full Amazon Resource Name (ARN)). This role is used to create, manage, and execute the Step Functions workflows.
            inputs: Information about the training data. Please refer to the `fit()` method of the associated estimator, as this can take any of the following forms:

                * (str) - The S3 location where training data is saved.
                * (dict[str, str] or dict[str, `sagemaker.inputs.TrainingInput`]) - If using multiple channels for training data, you can specify a dict mapping channel names to strings or `sagemaker.inputs.TrainingInput` objects.
                * (`sagemaker.inputs.TrainingInput`) - Channel configuration for S3 data sources that can provide additional information about the training dataset. See `sagemaker.inputs.TrainingInput` for full details.
                * (`sagemaker.amazon.amazon_estimator.RecordSet`) - A collection of Amazon `Record` objects serialized and stored in S3. For use with an estimator for an Amazon algorithm.
                * (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
            s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
            client (SFN.Client, optional): boto3 client to use for creating and interacting with the training pipeline in Step Functions. (default: None)

        Keyword Args:
            pipeline_name (str, optional): Name of the pipeline. This name will be used to name jobs (if not provided when calling execute()), models, endpoints, and S3 objects created by the pipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. (default:None)
        """
        self.estimator = estimator
        self.inputs = inputs

        for key in self.__class__.__allowed_kwargs:
            setattr(self, key, kwargs.pop(key, None))

        # Always define the uniqueness flag so the attribute exists whether or
        # not a custom pipeline_name was supplied. (Previously it was only set
        # on the auto-generated-name branch, leaving it undefined otherwise.)
        self.__pipeline_name_unique = False
        if not self.pipeline_name:
            self.__pipeline_name_unique = True
            self.pipeline_name = 'training-pipeline-{date}'.format(date=self._generate_timestamp())

        self.definition = self.build_workflow_definition()
        # Snapshot of each step's parameters; execute() fills it per run.
        self.input_template = self._extract_input_template(self.definition)

        workflow = Workflow(name=self.pipeline_name, definition=self.definition, role=role, format_json=True, client=client)

        super(TrainingPipeline, self).__init__(s3_bucket=s3_bucket, workflow=workflow, role=role, client=client)

    def build_workflow_definition(self):
        """
        Build the workflow definition for the training pipeline with all the states involved.

        Returns:
            :class:`~stepfunctions.steps.states.Chain`: Workflow definition as a chain of states involved in the the training pipeline.
        """
        default_name = self.pipeline_name

        # Instance settings are taken from the estimator and reused for hosting.
        instance_type = self.estimator.instance_type
        instance_count = self.estimator.instance_count

        training_step = TrainingStep(
            StepId.Train.value,
            estimator=self.estimator,
            job_name=default_name + '/estimator-source',
            data=self.inputs,
        )

        model = self.estimator.create_model()
        model_step = ModelStep(
            StepId.CreateModel.value,
            instance_type=instance_type,
            model=model,
            model_name=default_name
        )

        endpoint_config_step = EndpointConfigStep(
            StepId.ConfigureEndpoint.value,
            endpoint_config_name=default_name,
            model_name=default_name,
            initial_instance_count=instance_count,
            instance_type=instance_type
        )
        deploy_step = EndpointStep(
            StepId.Deploy.value,
            endpoint_name=default_name,
            endpoint_config_name=default_name,
        )

        return Chain([training_step, model_step, endpoint_config_step, deploy_step])

    def execute(self, job_name=None, hyperparameters=None):
        """
        Run the training pipeline.

        Args:
            job_name (str, optional): Name for the training job. If one is not provided, a job name will be auto-generated. (default: None)
            hyperparameters (dict, optional): Hyperparameters for the estimator training. (default: None)

        Returns:
            :py:class:`~stepfunctions.workflow.Execution`: Running instance of the training pipeline.
        """
        inputs = self.input_template.copy()

        if hyperparameters is not None:
            # Step Functions/SageMaker expect hyperparameter values as strings.
            inputs[StepId.Train.value]['HyperParameters'] = {
                k: str(v) for k, v in hyperparameters.items()
            }

        if job_name is None:
            job_name = '{base_name}-{timestamp}'.format(base_name='training-pipeline', timestamp=self._generate_timestamp())

        # Configure training and model
        inputs[StepId.Train.value]['TrainingJobName'] = 'estimator-' + job_name
        inputs[StepId.Train.value]['OutputDataConfig']['S3OutputPath'] = 's3://{s3_bucket}/{pipeline_name}/models'.format(
            s3_bucket=self.s3_bucket,
            pipeline_name=self.workflow.name
        )
        inputs[StepId.CreateModel.value]['ModelName'] = job_name

        # Configure endpoint
        inputs[StepId.ConfigureEndpoint.value]['EndpointConfigName'] = job_name
        for variant in inputs[StepId.ConfigureEndpoint.value]['ProductionVariants']:
            variant['ModelName'] = job_name
        inputs[StepId.Deploy.value]['EndpointConfigName'] = job_name
        inputs[StepId.Deploy.value]['EndpointName'] = job_name

        # Configure the path to model artifact
        inputs[StepId.CreateModel.value]['PrimaryContainer']['ModelDataUrl'] = '{s3_uri}/{job}/output/model.tar.gz'.format(
            s3_uri=inputs[StepId.Train.value]['OutputDataConfig']['S3OutputPath'],
            job=inputs[StepId.Train.value]['TrainingJobName']
        )

        return self.workflow.execute(inputs=inputs, name=job_name)
from __future__ import absolute_import

import copy


def replace_parameters_with_context_object(step):
    """Map every parameter of *step* to a context-object JSONPath.

    Returns a new dict where each key ``k`` becomes ``"k.$"`` and its value
    points at ``$$.Execution.Input['<state_id>'].k``, so the state reads its
    parameters from the workflow execution input instead of literal values.
    """
    updated_parameters = {}
    for k in step.parameters.keys():
        updated_parameters['{key}.$'.format(key=k)] = "$$.Execution.Input['{state_id}'].{key}".format(state_id=step.state_id, key=k)
    return updated_parameters


def replace_parameters_with_jsonpath(step, params):
    """Return a copy of ``step.parameters`` with JSONPath overrides applied.

    Args:
        step: State whose ``parameters`` dict is used as the base.
        params (dict): Nested dict mirroring the parameter structure. Any key
            ending in ``$`` (JSONPath convention, e.g. ``"S3Uri.$"``) replaces
            the corresponding static key (``"S3Uri"``) at the same nesting
            level in the result.

    Returns:
        dict: New parameters dict; ``step.parameters`` is left untouched.
    """

    def search_and_replace(src_params, dest_params, key):
        """Swap the static entry for its JSONPath twin, in-place."""
        original_key = key[:-2]  # Remove the trailing ".$" to find the static key
        del src_params[original_key]
        src_params[key] = dest_params[key]

    def replace_values(src_params, dest_params):
        if isinstance(dest_params, dict):
            for key in dest_params.keys():
                if key.endswith('$'):
                    search_and_replace(src_params, dest_params, key)
                else:
                    # Not a JSONPath key: recurse into the nested structure.
                    replace_values(src_params[key], dest_params[key])

    # Deep copy is required: replace_values mutates nested dicts in place, so
    # a shallow .copy() would propagate those deletions/insertions back into
    # step.parameters (corrupting the step for subsequent executions).
    task_parameters = copy.deepcopy(step.parameters)
    replace_values(task_parameters, params)

    return task_parameters
5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | from stepfunctions.workflow.stepfunctions import Workflow, Execution, ExecutionStatus 16 | -------------------------------------------------------------------------------- /src/stepfunctions/workflow/cloudformation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
from __future__ import absolute_import

import copy
import json
import yaml
import logging

logger = logging.getLogger('stepfunctions')


def repr_str(dumper, data):
    """YAML representer: emit multi-line strings in literal block style (|)."""
    if '\n' in data:
        return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
    return dumper.org_represent_str(data)


# Register representers on the SafeDumper: keep dict key order and render
# multi-line strings (the JSON state machine definition) as YAML blocks.
yaml.SafeDumper.org_represent_str = yaml.SafeDumper.represent_str
yaml.add_representer(dict, lambda self, data: yaml.representer.SafeRepresenter.represent_dict(self, data.items()), Dumper=yaml.SafeDumper)
yaml.add_representer(str, repr_str, Dumper=yaml.SafeDumper)

# Skeleton CloudFormation document; the None placeholders are filled in on a
# per-call deep copy by build_cloudformation_template().
CLOUDFORMATION_BASE_TEMPLATE = {
    "AWSTemplateFormatVersion": '2010-09-09',
    "Description": None,
    "Resources": {
        "StateMachineComponent": {
            "Type": "AWS::StepFunctions::StateMachine",
            "Properties": {
                "StateMachineName": None,
                "DefinitionString": None,
                "RoleArn": None,
            }
        }
    }
}


def build_cloudformation_template(workflow, description=None):
    """
    Creates a CloudFormation template from the provided Workflow

    Args:
        workflow (Workflow): Step Functions workflow instance
        description (str, optional): Description of the template. If none provided, the default description will be used: "CloudFormation template for AWS Step Functions - State Machine"

    Returns:
        str: The rendered CloudFormation template as a YAML document.
    """
    logger.warning('To reuse the CloudFormation template in different regions, please make sure to update the region specific AWS resources in the StateMachine definition.')

    # deepcopy, not .copy(): the base template contains nested dicts, so a
    # shallow copy would share Resources/Properties with the module-level
    # constant and every call would overwrite it in place.
    template = copy.deepcopy(CLOUDFORMATION_BASE_TEMPLATE)

    template["Description"] = description if description else "CloudFormation template for AWS Step Functions - State Machine"
    template["Resources"]["StateMachineComponent"]["Properties"]["StateMachineName"] = workflow.name

    definition = workflow.definition.to_dict()

    template["Resources"]["StateMachineComponent"]["Properties"]["DefinitionString"] = json.dumps(definition, indent=2)
    template["Resources"]["StateMachineComponent"]["Properties"]["RoleArn"] = workflow.role

    return yaml.safe_dump(template, default_flow_style=False)
from __future__ import absolute_import

import logging


def append_user_agent_to_client(boto_client):
    """Append this package's name/version to the boto client's user agent.

    Idempotent: the suffix is only added if not already present, so the
    client can safely pass through this function more than once.
    """
    # Imported lazily to avoid a circular import: this module lives inside
    # the stepfunctions package whose metadata it reads.
    import stepfunctions
    user_agent_suffix = ' {package_useragent}/{package_version}'.format(
        package_useragent=stepfunctions.__useragent__,
        package_version=stepfunctions.__version__)
    if user_agent_suffix not in boto_client._client_config.user_agent:
        boto_client._client_config.user_agent += user_agent_suffix


class CustomColorFormatter(logging.Formatter):
    """Log formatter that colors each record by severity using ANSI escapes."""

    GREY = "\x1b[0m"
    GREEN = "\x1b[32m"
    YELLOW = "\x1b[33m"
    RED = "\x1b[31m"
    BOLD_RED = "\x1b[31;1m"
    RESET = "\x1b[0m"
    FORMAT = "[%(levelname)s] %(message)s"

    LEVEL_FORMATS = {
        logging.DEBUG: GREY + FORMAT + RESET,
        logging.INFO: GREEN + FORMAT + RESET,
        logging.WARNING: YELLOW + FORMAT + RESET,
        logging.ERROR: RED + FORMAT + RESET,
        logging.CRITICAL: BOLD_RED + FORMAT + RESET
    }

    # Precompute one Formatter per level (instead of allocating a new
    # logging.Formatter on every record, as the previous implementation did).
    # The iterable of a class-body comprehension is evaluated in class scope,
    # so LEVEL_FORMATS.items() is visible here.
    _LEVEL_FORMATTERS = {level: logging.Formatter(fmt) for level, fmt in LEVEL_FORMATS.items()}
    # Fallback for unknown levels: logging.Formatter(None) -> "%(message)s".
    _DEFAULT_FORMATTER = logging.Formatter(None)

    def format(self, record):
        """Format *record* with the color scheme for its level."""
        formatter = self._LEVEL_FORMATTERS.get(record.levelno, self._DEFAULT_FORMATTER)
        return formatter.format(record)
13 | from __future__ import absolute_import 14 | 15 | from stepfunctions.workflow.widgets.graph import WorkflowGraphWidget, ExecutionGraphWidget 16 | from stepfunctions.workflow.widgets.events_table import EventsTableWidget 17 | from stepfunctions.workflow.widgets.executions_table import ExecutionsTableWidget 18 | from stepfunctions.workflow.widgets.workflows_table import WorkflowsTableWidget 19 | -------------------------------------------------------------------------------- /src/stepfunctions/workflow/widgets/executions_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | from datetime import datetime 16 | from string import Template 17 | 18 | from stepfunctions.workflow.widgets.utils import format_time, create_sfn_execution_url, AWS_TABLE_CSS 19 | 20 | TABLE_TEMPLATE = """ 21 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | {table_rows} 36 | 37 |
NameStatusStartedEnd Time
38 | """ 39 | 40 | TABLE_ROW_TEMPLATE = """ 41 | 42 | 43 | $name 44 | 45 | $status 46 | $start_date 47 | $stop_date 48 | 49 | """ 50 | 51 | TABLE_ROW_TEMPLATE = """ 52 | 53 | 54 | $name 55 | 56 | $status 57 | $start_date 58 | $stop_date 59 | 60 | """ 61 | 62 | CSS_TEMPLATE = """ 63 | * { 64 | box-sizing: border-box; 65 | } 66 | 67 | .table-widget { 68 | min-width: 100%; 69 | font-size: 14px; 70 | line-height: 28px; 71 | color: #545b64; 72 | border-spacing: 0; 73 | background-color: #fff; 74 | border-color: grey; 75 | background: #fafafa; 76 | } 77 | 78 | .table-widget thead th { 79 | text-align: left !important; 80 | color: #879596; 81 | padding: 0.3em 2em; 82 | border-bottom: 1px solid #eaeded; 83 | min-height: 4rem; 84 | line-height: 28px; 85 | } 86 | 87 | .table-widget td { 88 | /* padding: 24px 18px; */ 89 | padding: 0.4em 2em; 90 | line-height: 28px; 91 | text-align: left !important; 92 | background: #fff; 93 | border-bottom: 1px solid #eaeded; 94 | border-top: 1px solid transparent; 95 | } 96 | 97 | .table-widget td:before { 98 | content: ""; 99 | height: 3rem; 100 | } 101 | 102 | .table-widget .clickable-cell { 103 | cursor: pointer; 104 | } 105 | 106 | .hide { 107 | display: none; 108 | } 109 | 110 | .triangle-right { 111 | width: 0; 112 | height: 0; 113 | border-top: 5px solid transparent; 114 | border-left: 8px solid #545b64; 115 | border-bottom: 5px solid transparent; 116 | margin-right: 5px; 117 | } 118 | 119 | a.awsui { 120 | text-decoration: none !important; 121 | color: #007dbc; 122 | } 123 | 124 | a.awsui:hover { 125 | text-decoration: underline !important; 126 | } 127 | """ 128 | 129 | class ExecutionsTableWidget(object): 130 | def __init__(self, executions): 131 | table_rows = [Template(TABLE_ROW_TEMPLATE).substitute( 132 | execution_url=create_sfn_execution_url(execution.execution_arn), 133 | name=execution.name, 134 | status=execution.status, 135 | start_date=format_time(execution.start_date), 136 | stop_date=format_time(execution.stop_date) 137 
| ) for execution in executions] 138 | 139 | self.template = Template(TABLE_TEMPLATE.format(table_rows='\n'.join(table_rows))) 140 | 141 | def show(self): 142 | return self.template.substitute({ 143 | 'aws_table_css': AWS_TABLE_CSS, 144 | 'custom_css': CSS_TEMPLATE 145 | }) 146 | -------------------------------------------------------------------------------- /src/stepfunctions/workflow/widgets/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import random 16 | import json 17 | import logging 18 | 19 | logger = logging.getLogger('stepfunctions') 20 | 21 | try: 22 | from IPython.core.display import HTML 23 | __IPYTHON_IMPORTED__ = True 24 | except ImportError as e: 25 | logger.warning("IPython failed to import. Visualization features will be impaired or broken.") 26 | __IPYTHON_IMPORTED__ = False 27 | 28 | from string import Template 29 | 30 | from stepfunctions.workflow.widgets.utils import create_sfn_execution_url 31 | 32 | JSLIB_URL = 'https://do0of8uwbahzz.cloudfront.net/sfn' 33 | CSS_URL = 'https://do0of8uwbahzz.cloudfront.net/graph.css' 34 | 35 | HTML_TEMPLATE = """ 36 | 37 |
38 | $graph_legend_template 39 | 40 | {console_snippet} 41 |
42 | 43 | 46 | """ 47 | 48 | EXECUTION_URL_TEMPLATE = """ Inspect in AWS Step Functions """ 49 | 50 | WORKFLOW_GRAPH_SCRIPT_TEMPLATE = """ 51 | require.config({ 52 | paths: { 53 | sfn: "$jslib", 54 | } 55 | }); 56 | 57 | require(['sfn'], function(sfn) { 58 | var element = document.getElementById('$element_id') 59 | 60 | var options = { 61 | width: parseFloat(getComputedStyle(element, null).width.replace("px", "")), 62 | height: 600, 63 | layout: '$layout', 64 | resizeHeight: true 65 | }; 66 | 67 | var definition = $definition; 68 | var elementId = '#$element_id'; 69 | 70 | var graph = new sfn.StateMachineGraph(definition, elementId, options); 71 | graph.render(); 72 | }); 73 | """ 74 | 75 | EXECUTION_GRAPH_SCRIPT_TEMPLATE = """ 76 | require.config({ 77 | paths: { 78 | sfn: "$jslib", 79 | } 80 | }); 81 | 82 | require(['sfn'], function(sfn) { 83 | var element = document.getElementById('$element_id') 84 | 85 | var options = { 86 | width: parseFloat(getComputedStyle(element, null).width.replace("px", "")), 87 | height: 1000, 88 | layout: '$layout', 89 | resizeHeight: true 90 | }; 91 | 92 | var definition = $definition; 93 | var elementId = '#$element_id'; 94 | var events = { 'events': $events }; 95 | 96 | var graph = new sfn.StateMachineExecutionGraph(definition, events, elementId, options); 97 | graph.render(); 98 | }); 99 | """ 100 | 101 | EXECUTION_GRAPH_LEGEND_TEMPLATE = """ 102 | 130 |
131 |
    132 |
  • 133 |
    134 | Success 135 |
  • 136 |
  • 137 |
    138 | Failed 139 |
  • 140 |
  • 141 |
    142 | Cancelled 143 |
  • 144 |
  • 145 |
    146 | In Progress 147 |
  • 148 |
  • 149 |
    150 | Caught Error 151 |
  • 152 |
153 |
"""

class WorkflowGraphWidget(object):
    """Jupyter widget that renders a state machine definition as a D3 graph."""

    def __init__(self, json_definition):
        # json_definition: the state machine definition, as a JSON string.
        self.json_definition = json_definition
        # Randomized DOM id so multiple widgets can coexist on one page.
        self.element_id = 'graph-%d' % random.randint(0, 999)
        self.layout = 'TB'
        self.template = Template(HTML_TEMPLATE.format(
            code_snippet=WORKFLOW_GRAPH_SCRIPT_TEMPLATE,
            console_snippet=''))

    def show(self, portrait=True):
        """Return an IPython HTML object that renders the workflow graph.

        Args:
            portrait (bool): True lays the graph out top-to-bottom ('TB'),
                False left-to-right ('LR'). (default: True)
        """
        if __IPYTHON_IMPORTED__ is False:
            logger.error("IPython failed to import. Widgets/graphs cannot be visualized.")
            return ""
        if portrait is False:
            self.layout = 'LR'
        else:
            self.layout = 'TB'

        return HTML(self.template.substitute({
            'element_id': self.element_id,
            'definition': self.json_definition,
            'layout': self.layout,
            'css': CSS_URL,
            'jslib': JSLIB_URL,
            'graph_legend_template': ""
        }))

class ExecutionGraphWidget(object):
    """Jupyter widget that overlays an execution's progress on its graph."""

    def __init__(self, json_definition, json_events, execution_arn):
        # json_definition: state machine definition, as a JSON string.
        # json_events: execution history events, as a JSON string.
        # execution_arn: used to build the AWS console deep link.
        self.json_definition = json_definition
        self.json_events = json_events
        self.element_id = 'graph-%d' % random.randint(0, 999)
        self.layout = 'TB'
        self.template = Template(HTML_TEMPLATE.format(
            code_snippet=EXECUTION_GRAPH_SCRIPT_TEMPLATE,
            console_snippet=EXECUTION_URL_TEMPLATE))
        self.console_url = create_sfn_execution_url(execution_arn)

    def show(self, portrait=True):
        """Return an IPython HTML object that renders the execution graph.

        Includes the status legend and a link to the AWS console execution page.

        Args:
            portrait (bool): True lays the graph out top-to-bottom ('TB'),
                False left-to-right ('LR'). (default: True)
        """
        if __IPYTHON_IMPORTED__ is False:
            logger.error("IPython failed to import. Widgets/graphs cannot be visualized.")
            return ""
        if portrait is False:
            self.layout = 'LR'
        else:
            self.layout = 'TB'

        return HTML(self.template.substitute({
            'element_id': self.element_id,
            'definition': self.json_definition,
            'events': self.json_events,
            'layout': self.layout,
            'css': CSS_URL,
            'jslib': JSLIB_URL,
            'graph_legend_template': EXECUTION_GRAPH_LEGEND_TEMPLATE,
            'console': self.console_url
        }))
from __future__ import absolute_import

import sys
import time

from datetime import datetime

# AWS console deep-link templates, filled in by the helpers below.
AWS_SAGEMAKER_URL = "https://console.aws.amazon.com/sagemaker/home?region={region}#/{resource_type}/{resource}"
AWS_SFN_EXECUTIONS_DETAIL_URL = "https://console.aws.amazon.com/states/home?region={region}#/executions/details/{execution_arn}"
AWS_SFN_STATE_MACHINE_URL = "https://console.aws.amazon.com/states/home?region={region}#/statemachines/view/{state_machine_arn}"

# Shared CSS injected into the table widgets.
AWS_TABLE_CSS = """
.table-widget {
    width: 100%;
    font-size: 14px;
    line-height: 28px;
    color: #545b64;
    border-spacing: 0;
    background-color: #fff;
    border-color: grey;
    background: #fafafa;
}

.table-widget thead th {
    text-align: left !important;
    color: #879596;
    padding: 0.3em 2em;
    border-bottom: 1px solid #eaeded;
    min-height: 4rem;
    line-height: 28px;
}

.table-widget thead th:first-of-type {
}

.table-widget td {
    overflow-wrap: break-word;
    padding: 0.4em 2em;
    line-height: 28px;
    text-align: left !important;
    background: #fff;
    border-bottom: 1px solid #eaeded;
    border-top: 1px solid transparent;
}

.table-widget td:before {
    content: "";
    height: 3rem;
}

a {
    cursor: pointer;
    text-decoration: none !important;
    color: #007dbc;
}

a:hover {
    text-decoration: underline !important;
}

a.disabled {
    color: black;
    cursor: default;
    pointer-events: none;
}

.hide {
    display: none;
}

pre {
    white-space: pre-wrap;
}
"""

def format_time(timestamp):
    """Format a datetime like 'Mar 05, 2021 02:30:15.123 PM'; '-' for None."""
    if timestamp is None:
        return "-"
    # [:-3] trims %f (microseconds) down to milliseconds. Local renamed from
    # 'time' to avoid shadowing the time module imported above.
    formatted = timestamp.strftime("%b %d, %Y %I:%M:%S.%f")[:-3]
    return formatted + timestamp.strftime(" %p")

def get_timestamp(date):
    """Return the POSIX timestamp for *date*.

    datetime.timestamp() was added in Python 3.3; older interpreters fall
    back to time.mktime on the local time tuple. (The previous element-wise
    check `version_info[0] < 3 or version_info[1] < 4` misclassified 3.x
    versions with minor < 4; tuple comparison expresses the intent directly.)
    """
    if sys.version_info < (3, 3):
        return time.mktime(date.timetuple())
    return date.timestamp()

def get_elapsed_ms(start_datetime, end_datetime):
    """Return the elapsed time between two datetimes, in milliseconds.

    Uses total_seconds() rather than timedelta.microseconds: the latter is
    only the sub-second component (0-999999) and silently drops whole
    seconds and days from the measurement.
    """
    return (end_datetime - start_datetime).total_seconds() * 1000

def create_sfn_execution_url(execution_arn):
    """Console deep link for a Step Functions execution (region from the ARN)."""
    arn_segments = execution_arn.split(":")
    region_name = arn_segments[3]  # arn:partition:service:region:account:...
    return AWS_SFN_EXECUTIONS_DETAIL_URL.format(region=region_name, execution_arn=execution_arn)

def create_sfn_workflow_url(state_machine_arn):
    """Console deep link for a Step Functions state machine."""
    arn_segments = state_machine_arn.split(":")
    region_name = arn_segments[3]
    return AWS_SFN_STATE_MACHINE_URL.format(region=region_name, state_machine_arn=state_machine_arn)

def sagemaker_console_link(resource_type, resource):
    """Console deep link for a SageMaker resource, in the session's region."""
    import boto3  # imported lazily; this is its only use in the module
    region_name = boto3.session.Session().region_name
    return AWS_SAGEMAKER_URL.format(region=region_name, resource_type=resource_type, resource=resource)
13 | from __future__ import absolute_import 14 | 15 | from datetime import datetime 16 | from string import Template 17 | 18 | from stepfunctions.workflow.widgets.utils import format_time, create_sfn_workflow_url, AWS_TABLE_CSS 19 | 20 | TABLE_TEMPLATE = """ 21 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | {table_rows} 34 | 35 |
NameCreation Date
36 | """ 37 | 38 | TABLE_ROW_TEMPLATE = """ 39 | 40 | 41 | $name 42 | 43 | $creation_date 44 | 45 | """ 46 | 47 | CSS_TEMPLATE = """ 48 | * { 49 | box-sizing: border-box; 50 | } 51 | 52 | .table-widget { 53 | min-width: 100%; 54 | font-size: 14px; 55 | line-height: 28px; 56 | color: #545b64; 57 | border-spacing: 0; 58 | background-color: #fff; 59 | border-color: grey; 60 | background: #fafafa; 61 | } 62 | 63 | .table-widget thead th { 64 | text-align: left !important; 65 | color: #879596; 66 | padding: 0.3em 2em; 67 | border-bottom: 1px solid #eaeded; 68 | min-height: 4rem; 69 | line-height: 28px; 70 | } 71 | 72 | .table-widget td { 73 | /* padding: 24px 18px; */ 74 | padding: 0.4em 2em; 75 | line-height: 28px; 76 | text-align: left !important; 77 | background: #fff; 78 | border-bottom: 1px solid #eaeded; 79 | border-top: 1px solid transparent; 80 | } 81 | 82 | .table-widget td:before { 83 | content: ""; 84 | height: 3rem; 85 | } 86 | 87 | .table-widget .clickable-cell { 88 | cursor: pointer; 89 | } 90 | 91 | .hide { 92 | display: none; 93 | } 94 | 95 | .triangle-right { 96 | width: 0; 97 | height: 0; 98 | border-top: 5px solid transparent; 99 | border-left: 8px solid #545b64; 100 | border-bottom: 5px solid transparent; 101 | margin-right: 5px; 102 | } 103 | 104 | a.awsui { 105 | text-decoration: none !important; 106 | color: #007dbc; 107 | } 108 | 109 | a.awsui:hover { 110 | text-decoration: underline !important; 111 | } 112 | """ 113 | 114 | class WorkflowsTableWidget(object): 115 | 116 | def __init__(self, workflows): 117 | table_rows = [Template(TABLE_ROW_TEMPLATE).substitute( 118 | state_machine_url=create_sfn_workflow_url(workflow['stateMachineArn']), 119 | name=workflow['name'], 120 | creation_date=format_time(workflow['creationDate']), 121 | ) for workflow in workflows] 122 | 123 | self.template = Template(TABLE_TEMPLATE.format(table_rows='\n'.join(table_rows))) 124 | 125 | def show(self): 126 | return self.template.substitute({ 127 | 'aws_table_css': 
AWS_TABLE_CSS, 128 | 'custom_css': CSS_TEMPLATE 129 | }) 130 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import -------------------------------------------------------------------------------- /tests/data/one_p_mnist/mnist.npy.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/aws-step-functions-data-science-sdk-python/0e970ccde4ef7cc532f80040598e8d4d55f6486d/tests/data/one_p_mnist/mnist.npy.gz -------------------------------------------------------------------------------- /tests/data/one_p_mnist/mnist.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/aws-step-functions-data-science-sdk-python/0e970ccde4ef7cc532f80040598e8d4d55f6486d/tests/data/one_p_mnist/mnist.pkl.gz -------------------------------------------------------------------------------- /tests/data/one_p_mnist/sklearn_mnist_estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 
4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import print_function, absolute_import 14 | 15 | import argparse 16 | import numpy as np 17 | import os 18 | import json 19 | 20 | from six import BytesIO 21 | 22 | from sklearn import svm 23 | from sklearn.externals import joblib 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | 29 | # Data and model checkpoints directories 30 | parser.add_argument("--epochs", type=int, default=-1) 31 | parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) 32 | parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) 33 | parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) 34 | args = parser.parse_args() 35 | 36 | # Load the data into memory as numpy arrays 37 | data_path = os.path.join(args.train, "mnist.npy.gz.out") 38 | with open(data_path, 'r') as f: 39 | jsonarray = json.loads(f.read()) 40 | data = np.array(jsonarray) 41 | train_set = data 42 | train_file = {'x': train_set[:, 1:], 'y': train_set[:, 0]} 43 | 44 | # Set up a Support Vector Machine classifier to predict digit from images 45 | clf = svm.SVC(gamma=0.001, C=100.0, max_iter=args.epochs) 46 | 47 | train_images = train_file['x'] 48 | train_labels = train_file['y'] 49 | # Fit the SVM classifier with the images and the corresponding labels 50 | clf.fit(train_images, train_labels) 51 | 52 | # Print the coefficients of the trained classifier, and save the coefficients 53 | joblib.dump(clf, os.path.join(args.model_dir, 
"model.joblib")) 54 | 55 | 56 | def input_fn(input_data, content_type): 57 | # Load the data into memory as numpy arrays 58 | buf = BytesIO(input_data) 59 | jsonarray = json.loads(buf.read()) 60 | data = np.array(jsonarray) 61 | return data 62 | 63 | 64 | def predict_fn(data, model): 65 | train_set = data[:, 1:] 66 | return model.predict(train_set) 67 | 68 | 69 | def model_fn(model_dir): 70 | clf = joblib.load(os.path.join(model_dir, "model.joblib")) 71 | return clf 72 | -------------------------------------------------------------------------------- /tests/data/one_p_mnist/sklearn_mnist_preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
def predict_fn(data, model):
    """Apply the fitted transformer to the feature columns of ``data``.

    Column 0 (the label) is passed through untouched; columns 1..N are run
    through ``model.transform`` and re-joined with the label column.
    """
    labels = data[:, 0].reshape(-1, 1)
    features = model.transform(data[:, 1:])
    return np.concatenate((labels, features), axis=1)
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
    """Small CNN for MNIST digit classification.

    Two 5x5 conv layers (dropout on the second) followed by two fully
    connected layers; emits log-probabilities over the 10 digit classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        # NOTE: layer attribute names are part of the saved state_dict
        # ('model.pth'), so they must stay stable.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        pooled1 = F.max_pool2d(self.conv1(x), 2)
        x = F.relu(pooled1)
        pooled2 = F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)
        x = F.relu(pooled2)
        flat = x.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def _average_gradients(model):
    """All-reduce gradients across workers and divide by the world size.

    Used in the multi-machine CPU case, where DistributedDataParallel
    does not average gradients for us.
    """
    world_size = float(dist.get_world_size())
    for parameter in model.parameters():
        grad = parameter.grad.data
        dist.all_reduce(grad, op=dist.reduce_op.SUM)
        grad /= world_size
* len(train_loader.sampler) / len(train_loader.dataset) 113 | )) 114 | 115 | logger.debug("Processes {}/{} ({:.0f}%) of test data".format( 116 | len(test_loader.sampler), len(test_loader.dataset), 117 | 100. * len(test_loader.sampler) / len(test_loader.dataset) 118 | )) 119 | 120 | model = Net().to(device) 121 | if is_distributed and use_cuda: 122 | # multi-machine multi-gpu case 123 | model = torch.nn.parallel.DistributedDataParallel(model) 124 | else: 125 | # single-machine multi-gpu case or single-machine or multi-machine cpu case 126 | model = torch.nn.DataParallel(model) 127 | 128 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 129 | 130 | for epoch in range(1, args.epochs + 1): 131 | model.train() 132 | for batch_idx, (data, target) in enumerate(train_loader, 1): 133 | data, target = data.to(device), target.to(device) 134 | optimizer.zero_grad() 135 | output = model(data) 136 | loss = F.nll_loss(output, target) 137 | loss.backward() 138 | if is_distributed and not use_cuda: 139 | # average gradients manually for multi-machine cpu case only 140 | _average_gradients(model) 141 | optimizer.step() 142 | if batch_idx % args.log_interval == 0: 143 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( 144 | epoch, batch_idx * len(data), len(train_loader.sampler), 145 | 100. 
def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and log average loss/accuracy."""
    model.eval()
    total_loss = 0.0
    num_correct = 0
    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)
            log_probs = model(batch)
            # Accumulate the summed (not averaged) loss so the final
            # division by the dataset size gives a true per-sample mean.
            total_loss += F.nll_loss(log_probs, labels, size_average=False).item()
            # Index of the max log-probability is the predicted class.
            predictions = log_probs.max(1, keepdim=True)[1]
            num_correct += predictions.eq(labels.view_as(predictions)).sum().item()

    total_loss /= len(test_loader.dataset)
    logger.info('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        total_loss, num_correct, len(test_loader.dataset),
        100. * num_correct / len(test_loader.dataset)))
(default: 10)') 193 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 194 | help='learning rate (default: 0.01)') 195 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 196 | help='SGD momentum (default: 0.5)') 197 | parser.add_argument('--seed', type=int, default=1, metavar='S', 198 | help='random seed (default: 1)') 199 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 200 | help='how many batches to wait before logging training status') 201 | parser.add_argument('--backend', type=str, default=None, 202 | help='backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)') 203 | 204 | # Container environment 205 | parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS'])) 206 | parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST']) 207 | parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) 208 | parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING']) 209 | parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS']) 210 | 211 | train(parser.parse_args()) 212 | -------------------------------------------------------------------------------- /tests/data/sklearn_mnist/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. 
def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format):
    """Reshape, cast and rescale raw MNIST arrays.

    ``raw`` is a dict with 'x' (flattened images) and 'y' (labels).
    ``ndim`` selects the image layout: 1 -> flat vectors, 2 -> (N, 28, 28),
    3 -> (N, 1, 28, 28), optionally broadcast to three RGB channels when
    ``rgb_format`` is set. Pixel values are rescaled from [0, 255] to
    [0, scale]. Returns images, or (images, labels) when ``withlabel``.
    """
    images = raw["x"]
    if ndim == 3:
        images = images.reshape(-1, 1, 28, 28)
        if rgb_format:
            # Repeat the single grey channel across three RGB channels.
            images = np.broadcast_to(images, (len(images), 3) + images.shape[2:])
    elif ndim == 2:
        images = images.reshape(-1, 28, 28)
    elif ndim != 1:
        raise ValueError("invalid ndim for MNIST dataset")
    # astype always copies here, so the in-place rescale never mutates raw data.
    images = images.astype(image_dtype)
    images *= scale / 255.0

    if not withlabel:
        return images
    return images, raw["y"].astype(label_dtype)
def model_fn(model_dir):
    """Load the SVM classifier persisted by the training entry point."""
    return joblib.load(os.path.join(model_dir, "model.joblib"))
def print_shape(df):
    """Print the dataframe shape and its positive/negative 'income' counts.

    Assumes 'income' is binary-encoded as 0 (negative) / 1 (positive).
    """
    # minlength=2 keeps the two-way unpack valid even when one class is
    # absent from the data (plain bincount would return a length-1 array
    # and the unpack would raise ValueError).
    negative_examples, positive_examples = np.bincount(df['income'], minlength=2)
    print('Data shape: {}, {} positive examples, {} negative examples'.format(df.shape, positive_examples, negative_examples))
print('Running preprocessing and feature engineering transformations') 53 | train_features = preprocess.fit_transform(X_train) 54 | test_features = preprocess.transform(X_test) 55 | 56 | print('Train data shape after preprocessing: {}'.format(train_features.shape)) 57 | print('Test data shape after preprocessing: {}'.format(test_features.shape)) 58 | 59 | train_features_output_path = os.path.join('/opt/ml/processing/train', 'train_features.csv') 60 | train_labels_output_path = os.path.join('/opt/ml/processing/train', 'train_labels.csv') 61 | 62 | test_features_output_path = os.path.join('/opt/ml/processing/test', 'test_features.csv') 63 | test_labels_output_path = os.path.join('/opt/ml/processing/test', 'test_labels.csv') 64 | 65 | print('Saving training features to {}'.format(train_features_output_path)) 66 | pd.DataFrame(train_features).to_csv(train_features_output_path, header=False, index=False) 67 | 68 | print('Saving test features to {}'.format(test_features_output_path)) 69 | pd.DataFrame(test_features).to_csv(test_features_output_path, header=False, index=False) 70 | 71 | print('Saving training labels to {}'.format(train_labels_output_path)) 72 | y_train.to_csv(train_labels_output_path, header=False, index=False) 73 | 74 | print('Saving test labels to {}'.format(test_labels_output_path)) 75 | y_test.to_csv(test_labels_output_path, header=False, index=False) -------------------------------------------------------------------------------- /tests/integ/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. 
This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import os 16 | 17 | from stepfunctions.steps import Retry 18 | 19 | DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") 20 | DEFAULT_TIMEOUT_MINUTES = 25 21 | 22 | # Default retry strategy for SageMaker steps used in integration tests 23 | SAGEMAKER_RETRY_STRATEGY = Retry( 24 | error_equals=["SageMaker.AmazonSageMakerException"], 25 | interval_seconds=5, 26 | max_attempts=5, 27 | backoff_rate=2 28 | ) -------------------------------------------------------------------------------- /tests/integ/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
@pytest.fixture(scope="session")
def train_set():
    """Load the MNIST training split from the gzipped pickle test fixture.

    Returns only the training portion of the (train, valid, test) tuple
    stored in the pickle.
    """
    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")

    # The fixture was pickled under Python 2, so its byte strings must be
    # decoded as latin1 when unpickling on Python 3. (The previous
    # sys.version_info check for Python 2 was dead code: this module is
    # Python-3 only -- it uses f-strings elsewhere.)
    with gzip.open(data_path, "rb") as f:
        train_set, _, _ = pickle.load(f, encoding="latin1")

    return train_set
"ecs:RunTask", 30 | "ecs:StopTask", 31 | "ecs:DescribeTasks", 32 | "dynamodb:GetItem", 33 | "dynamodb:PutItem", 34 | "dynamodb:UpdateItem", 35 | "dynamodb:DeleteItem", 36 | "batch:SubmitJob", 37 | "batch:DescribeJobs", 38 | "batch:TerminateJob", 39 | "glue:StartJobRun", 40 | "glue:GetJobRun", 41 | "glue:GetJobRuns", 42 | "glue:BatchStopJobRun" 43 | ], 44 | "Resource": "*" 45 | }, 46 | { 47 | "Effect": "Allow", 48 | "Action": [ 49 | "iam:PassRole" 50 | ], 51 | "Resource": "*", 52 | "Condition": { 53 | "StringEquals": { 54 | "iam:PassedToService": "sagemaker.amazonaws.com" 55 | } 56 | } 57 | }, 58 | { 59 | "Effect": "Allow", 60 | "Action": [ 61 | "events:PutTargets", 62 | "events:PutRule", 63 | "events:DescribeRule" 64 | ], 65 | "Resource": [ 66 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTrainingJobsRule", 67 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTransformJobsRule", 68 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTuningJobsRule", 69 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerProcessingJobsRule", 70 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForECSTaskRule", 71 | "arn:aws:events:*:*:rule/StepFunctionsGetEventsForBatchJobsRule" 72 | ] 73 | } 74 | ] 75 | } 76 | -------------------------------------------------------------------------------- /tests/integ/resources/StepFunctionsMLWorkflowExecutionFullAccess-TrustPolicy.json: -------------------------------------------------------------------------------- 1 | { 2 | "Version": "2012-10-17", 3 | "Statement": [ 4 | { 5 | "Sid": "", 6 | "Effect": "Allow", 7 | "Principal": { 8 | "Service": "states.amazonaws.com" 9 | }, 10 | "Action": "sts:AssumeRole" 11 | } 12 | ] 13 | } -------------------------------------------------------------------------------- /tests/integ/test_inference_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
@pytest.fixture(scope="module")
def sklearn_preprocessor(sagemaker_role_arn, sagemaker_session):
    """SKLearn estimator wrapping the MNIST preprocessing entry point."""
    entry_point = os.path.join(DATA_DIR, 'one_p_mnist', 'sklearn_mnist_preprocessor.py')
    return SKLearn(
        framework_version='0.20.0',
        py_version='py3',
        entry_point=entry_point,
        role=sagemaker_role_arn,
        instance_type="ml.m5.large",
        sagemaker_session=sagemaker_session,
        hyperparameters={"epochs": 1},
    )
63 | framework_version='0.20.0', 64 | py_version='py3', 65 | entry_point=script_path, 66 | role=sagemaker_role_arn, 67 | instance_type="ml.m5.large", 68 | sagemaker_session=sagemaker_session, 69 | hyperparameters={"epochs": 1}, 70 | input_mode='File' 71 | ) 72 | return sklearn_estimator 73 | 74 | 75 | @pytest.fixture(scope="module") 76 | def inputs(sagemaker_session): 77 | data_path = os.path.join(DATA_DIR, "one_p_mnist", COMPRESSED_NPY_DATA) 78 | inputs = sagemaker_session.upload_data( 79 | path=data_path, key_prefix='dataset/one_p_mnist' 80 | ) 81 | return inputs 82 | 83 | 84 | def test_inference_pipeline_framework( 85 | sfn_client, 86 | sagemaker_session, 87 | sfn_role_arn, 88 | sagemaker_role_arn, 89 | sklearn_preprocessor, 90 | sklearn_estimator, 91 | inputs): 92 | bucket_name = sagemaker_session.default_bucket() 93 | unique_name = '{}-{}'.format(BASE_NAME, datetime.now().strftime('%Y%m%d%H%M%S')) 94 | with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): 95 | pipeline = InferencePipeline( 96 | preprocessor=sklearn_preprocessor, 97 | estimator=sklearn_estimator, 98 | inputs={'train': inputs, 'test': inputs}, 99 | s3_bucket=bucket_name, 100 | role=sfn_role_arn, 101 | compression_type='Gzip', 102 | content_type='application/x-npy', 103 | pipeline_name=unique_name 104 | ) 105 | 106 | _ = pipeline.create() 107 | execution = pipeline.execute(job_name=unique_name) 108 | out = execution.get_output(wait=True) 109 | assert out # If fails, out is None. 
110 | 111 | execution_info = execution.describe() 112 | 113 | execution_arn = execution.execution_arn 114 | state_machine_definition = sfn_client.describe_state_machine_for_execution(executionArn=execution_arn) 115 | state_machine_definition['definition'] = json.loads(state_machine_definition['definition']) 116 | assert state_machine_definition['definition'] == pipeline.workflow.definition.to_dict() 117 | 118 | state_machine_arn = state_machine_definition['stateMachineArn'] 119 | job_name = execution_info['name'] 120 | 121 | client_info = sfn_client.describe_execution(executionArn=execution_arn) 122 | client_info['input'] = json.loads(client_info['input']) 123 | _ = client_info.pop('ResponseMetadata') 124 | _ = client_info.pop('output') 125 | 126 | assert client_info['input'] == json.loads(execution_info['input']) 127 | 128 | state_machine_delete_wait(sfn_client, state_machine_arn) 129 | delete_sagemaker_endpoint(job_name, sagemaker_session) 130 | delete_sagemaker_endpoint_config(job_name, sagemaker_session) 131 | delete_sagemaker_model(job_name, sagemaker_session) 132 | -------------------------------------------------------------------------------- /tests/integ/test_training_pipeline_estimators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
from __future__ import absolute_import

import os
import sys
import gzip
import pickle
import pytest
import numpy as np
import json
from datetime import datetime

import boto3

# import Sagemaker
from sagemaker.amazon.pca import PCA
from sagemaker.image_uris import retrieve

# import StepFunctions
from stepfunctions.template.pipeline import TrainingPipeline

from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout
from tests.integ.utils import (
    state_machine_delete_wait,
    delete_sagemaker_model,
    delete_sagemaker_endpoint_config,
    delete_sagemaker_endpoint,
)


# Constants
BASE_NAME = 'training-pipeline-integtest'


# Fixtures
@pytest.fixture(scope="module")
def pca_estimator(sagemaker_role_arn):
    """PCA estimator configured for a single-component MNIST reduction."""
    pca_estimator = PCA(
        role=sagemaker_role_arn,
        num_components=1,
        instance_count=1,
        instance_type='ml.m5.large',
    )

    # Bug fix: the original `pca_estimator.subtract_mean=True,` had a trailing
    # comma, so the attribute was assigned the tuple (True,) instead of the
    # boolean True.
    pca_estimator.feature_dim = 500
    pca_estimator.subtract_mean = True
    pca_estimator.algorithm_mode = 'randomized'
    pca_estimator.mini_batch_size = 128

    return pca_estimator

@pytest.fixture(scope="module")
def inputs(pca_estimator):
    """Load the first 100 MNIST training rows and wrap them in a RecordSet."""
    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}

    # Load the data into memory as numpy arrays
    with gzip.open(data_path, "rb") as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    inputs = pca_estimator.record_set(train=train_set[0][:100])
    return inputs


def test_pca_estimator(sfn_client, sagemaker_session, sagemaker_role_arn, sfn_role_arn, pca_estimator, inputs):
    """End-to-end check of TrainingPipeline with a built-in (PCA) estimator."""
    bucket_name = sagemaker_session.default_bucket()
    # Timestamp suffix keeps resource names unique across test runs.
    unique_name = '{}-{}'.format(BASE_NAME, datetime.now().strftime('%Y%m%d%H%M%S'))
    hyperparams = pca_estimator.hyperparameters()

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        tp = TrainingPipeline(
            estimator=pca_estimator,
            role=sfn_role_arn,
            inputs=inputs,
            s3_bucket=bucket_name,
            pipeline_name=unique_name
        )
        tp.create()

        execution = tp.execute(job_name=unique_name, hyperparameters=hyperparams)
        out = execution.get_output(wait=True)
        assert out  # If fails, out is None.
        endpoint_arn = out['EndpointArn']  # also validates the key is present

        workflow_execution_info = execution.describe()

        # The deployed state machine definition must match the local one.
        execution_arn = execution.execution_arn
        state_machine_definition = sfn_client.describe_state_machine_for_execution(executionArn=execution_arn)
        state_machine_definition['definition'] = json.loads(state_machine_definition['definition'])
        assert state_machine_definition['definition'] == tp.workflow.definition.to_dict()

        state_machine_arn = state_machine_definition['stateMachineArn']
        job_name = workflow_execution_info['name']
        s3_manifest_uri = inputs.s3_data
        status = 'SUCCEEDED'
        estimator_image_uri = retrieve(region=sagemaker_session.boto_region_name, framework='pca')

        execution_info = sfn_client.describe_execution(executionArn=execution_arn)
        execution_info['input'] = json.loads(execution_info['input'])
        _ = execution_info.pop('ResponseMetadata')
        _ = execution_info.pop('output')

        s3_output_path = 's3://{bucket_name}/{workflow_name}/models'.format(bucket_name=bucket_name, workflow_name=unique_name)
        # Fix: the original dict literal listed the 'inputDetails' and
        # 'outputDetails' keys twice; duplicate keys in a literal are silently
        # collapsed, so each key now appears exactly once.
        expected_execution_info = {
            'executionArn': execution_arn,
            'stateMachineArn': state_machine_arn,
            'name': job_name,
            'status': status,
            'startDate': execution_info['startDate'],
            'stopDate': execution_info['stopDate'],
            'inputDetails': {'included': True},
            'outputDetails': {'included': True},
            'input': {
                'Training': {
                    'AlgorithmSpecification': {
                        'TrainingImage': estimator_image_uri,
                        'TrainingInputMode': 'File'
                    },
                    'OutputDataConfig': {'S3OutputPath': s3_output_path},
                    'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
                    'ResourceConfig': {
                        'InstanceCount': 1,
                        'InstanceType': 'ml.m5.large',
                        'VolumeSizeInGB': 30
                    },
                    'RoleArn': sagemaker_role_arn,
                    'InputDataConfig': [{
                        'DataSource': {
                            'S3DataSource': {
                                'S3DataDistributionType': 'ShardedByS3Key',
                                'S3DataType': 'ManifestFile',
                                'S3Uri': s3_manifest_uri
                            }
                        },
                        'ChannelName': 'train'
                    }],
                    'HyperParameters': hyperparams,
                    'TrainingJobName': 'estimator-' + job_name
                },
                'Create Model': {
                    'ModelName': job_name,
                    'PrimaryContainer': {
                        'Image': estimator_image_uri,
                        'Environment': {},
                        'ModelDataUrl': 's3://' + bucket_name + '/' + unique_name + '/models/' + 'estimator-' + job_name + '/output/model.tar.gz'
                    },
                    'ExecutionRoleArn': sagemaker_role_arn
                },
                'Configure Endpoint': {
                    'EndpointConfigName': job_name,
                    'ProductionVariants': [{
                        'ModelName': job_name,
                        'InstanceType': 'ml.m5.large',
                        'InitialInstanceCount': 1,
                        'VariantName': 'AllTraffic'
                    }]
                },
                'Deploy': {
                    'EndpointName': job_name,
                    'EndpointConfigName': job_name
                }
            }
        }
        assert execution_info == expected_execution_info

        # Cleanup
        state_machine_delete_wait(sfn_client, state_machine_arn)
        delete_sagemaker_endpoint(job_name, sagemaker_session)
        delete_sagemaker_endpoint_config(job_name, sagemaker_session)
        delete_sagemaker_model(job_name, sagemaker_session)

# ---------------------------------------------------------------------------
# /tests/integ/test_training_pipeline_framework_estimator.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import sagemaker
import os

from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout
from stepfunctions.template import TrainingPipeline
from sagemaker.pytorch import PyTorch
from sagemaker.sklearn import SKLearn
from tests.integ.utils import (
    state_machine_delete_wait,
    delete_sagemaker_model,
    delete_sagemaker_endpoint_config,
    delete_sagemaker_endpoint,
    get_resource_name_from_arn
)

@pytest.fixture(scope="module")
def torch_estimator(sagemaker_role_arn):
    """PyTorch framework estimator for the MNIST training script."""
    script_path = os.path.join(DATA_DIR, "pytorch_mnist", "mnist.py")
    return PyTorch(
        py_version='py3',
        entry_point=script_path,
        role=sagemaker_role_arn,
        framework_version='1.2.0',
        instance_count=1,
        instance_type='ml.m5.large',
        hyperparameters={
            'epochs': 6,
            'backend': 'gloo'
        }
    )

@pytest.fixture(scope="module")
def sklearn_estimator(sagemaker_role_arn):
    """SKLearn framework estimator for the MNIST training script."""
    script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
    return SKLearn(
        framework_version='0.20.0',
        py_version='py3',
        entry_point=script_path,
        role=sagemaker_role_arn,
        instance_count=1,
        instance_type='ml.m5.large',
        hyperparameters={
            "epochs": 1
        }
    )


def _get_endpoint_name(execution_output):
    """Extract the endpoint name from an execution output, or None if absent."""
    endpoint_arn = execution_output.get('EndpointArn', None)
    endpoint_name = None

    if endpoint_arn is not None:
        # The ARN resource part is "endpoint/<name>"; keep only the name.
        resource_name = get_resource_name_from_arn(endpoint_arn)
        endpoint_name = resource_name.split("/")[-1]

    return endpoint_name


def _pipeline_test_suite(sagemaker_client, training_job_name, model_name, endpoint_name):
    """Assert that the pipeline created the training job, model and endpoint."""
    assert sagemaker_client.describe_training_job(TrainingJobName=training_job_name).get('TrainingJobName') == training_job_name
    # Bug fix: the described model name must be compared against `model_name`;
    # the original compared it against `endpoint_name`, which only passed
    # because the callers happen to pass the same value for both parameters.
    assert sagemaker_client.describe_model(ModelName=model_name).get('ModelName') == model_name
    assert sagemaker_client.describe_endpoint(EndpointName=endpoint_name).get('EndpointName') == endpoint_name


def _pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline):
    """Delete the SageMaker resources (if any exist) and the state machine."""
    if endpoint_name is not None:
        delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
        delete_sagemaker_endpoint_config(endpoint_name, sagemaker_session)
        delete_sagemaker_model(endpoint_name, sagemaker_session)

    state_machine_delete_wait(sfn_client, pipeline.workflow.state_machine_arn)


def test_torch_training_pipeline(sfn_client, sagemaker_client, torch_estimator, sagemaker_session, sfn_role_arn):
    """End-to-end TrainingPipeline run with a PyTorch framework estimator."""
    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # upload input data
        data_path = os.path.join(DATA_DIR, "pytorch_mnist")
        inputs = sagemaker_session.upload_data(
            path=data_path,
            bucket=sagemaker_session.default_bucket(),
            key_prefix='integ-test-data/torch_mnist/train'
        )

        # create training pipeline
        pipeline = TrainingPipeline(
            torch_estimator,
            sfn_role_arn,
            inputs,
            sagemaker_session.default_bucket(),
            sfn_client
        )
        pipeline.create()
        # execute pipeline
        execution = pipeline.execute()

        # get pipeline output and extract endpoint name
        execution_output = execution.get_output(wait=True)
        assert execution_output  # If fails, execution_output is None.

        endpoint_name = _get_endpoint_name(execution_output)

        # assertions
        _pipeline_test_suite(sagemaker_client, training_job_name='estimator-'+endpoint_name, model_name=endpoint_name, endpoint_name=endpoint_name)

        # teardown
        _pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)


def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimator, sagemaker_session, sfn_role_arn):
    """End-to-end TrainingPipeline run with an SKLearn framework estimator."""
    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # upload input data
        data_path = os.path.join(DATA_DIR, "sklearn_mnist")
        inputs = sagemaker_session.upload_data(
            path=os.path.join(data_path, "train"),
            bucket=sagemaker_session.default_bucket(),
            key_prefix="integ-test-data/sklearn_mnist/train"
        )

        # create training pipeline
        pipeline = TrainingPipeline(
            sklearn_estimator,
            sfn_role_arn,
            inputs,
            sagemaker_session.default_bucket(),
            sfn_client
        )
        pipeline.create()
        # run pipeline
        execution = pipeline.execute()

        # get pipeline output and extract endpoint name
        execution_output = execution.get_output(wait=True)
        assert execution_output  # If fails, execution_output is None.

        endpoint_name = _get_endpoint_name(execution_output)

        # assertions
        _pipeline_test_suite(sagemaker_client, training_job_name='estimator-'+endpoint_name, model_name=endpoint_name, endpoint_name=endpoint_name)

        # teardown
        _pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)

# ---------------------------------------------------------------------------
# /tests/integ/timeout.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

from contextlib import contextmanager
import stopit

@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
    """
    Limit the execution time of a block of code (thread-based, via stopit).
    When several time units are given, they are summed to form the limit.
    Usage:
        with timeout(seconds=5):
            my_slow_function(...)
    Args:
        - seconds: The time limit, in seconds.
        - minutes: The time limit, in minutes.
        - hours: The time limit, in hours.
    """

    # Collapse all units into a single limit expressed in seconds.
    total_seconds = 3600 * hours + 60 * minutes + seconds

    # swallow_exc=False makes the timeout exception propagate to the caller
    # instead of being silently absorbed by the context manager.
    with stopit.ThreadingTimeout(total_seconds, swallow_exc=False) as timer:
        yield [timer]

# ---------------------------------------------------------------------------
# /tests/integ/utils.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import time

def state_machine_delete_wait(client, state_machine_arn, sleep_interval=10):
    """Request deletion of a state machine and poll until it is gone.

    Args:
        client: boto3 Step Functions client.
        state_machine_arn (str): ARN of the state machine to delete.
        sleep_interval (int): Seconds to sleep between describe calls.
    """
    # The original stored the response in an unused variable; dropped.
    client.delete_state_machine(stateMachineArn=state_machine_arn)
    state_machine_status = "DELETING"

    # Deletion is asynchronous: keep describing the state machine until the
    # call fails, which indicates the state machine no longer exists.
    while state_machine_status is not None:
        try:
            state_machine_describe = client.describe_state_machine(stateMachineArn=state_machine_arn)
            state_machine_status = state_machine_describe.get("status")
            time.sleep(sleep_interval)
        except Exception:
            # Fix: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed so those propagate. Any client error
            # here is treated as "the state machine is deleted".
            state_machine_status = None

def delete_sagemaker_model(model_name, sagemaker_session):
    """Delete a SageMaker model by name."""
    sagemaker_session.delete_model(model_name=model_name)

def delete_sagemaker_endpoint_config(endpoint_config_name, sagemaker_session):
    """Delete a SageMaker endpoint configuration by name."""
    sagemaker_session.delete_endpoint_config(endpoint_config_name=endpoint_config_name)

def delete_sagemaker_endpoint(endpoint_name, sagemaker_session, sleep_interval=10):
    """Wait for an endpoint to finish creating, then delete it."""
    sagemaker_session.wait_for_endpoint(endpoint=endpoint_name, poll=sleep_interval)
    sagemaker_session.delete_endpoint(endpoint_name=endpoint_name)

def get_resource_name_from_arn(arn):
    """Return the resource portion (text after the last ':') of an ARN."""
    return arn.split(":")[-1]

# ---------------------------------------------------------------------------
# /tests/unit/__init__.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied.
# See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

# ---------------------------------------------------------------------------
# /tests/unit/test_choice_rule.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from stepfunctions.steps import Pass, Succeed, Fail, Wait, ChoiceRule


def test_variable_must_start_with_prefix():
    """A choice-rule variable must be a JSONPath, i.e. start with '$'."""
    with pytest.raises(ValueError):
        ChoiceRule.StringEquals('Variable', '42')

def test_variable_value_must_be_consistent():
    """Each comparator factory must reject values of the wrong type."""
    # String comparators reject non-string values.
    string_functions = (
        'StringEquals',
        'StringLessThan',
        'StringGreaterThan',
        'StringLessThanEquals',
        'StringGreaterThanEquals',
    )
    for string_function in string_functions:
        func = getattr(ChoiceRule, string_function)
        with pytest.raises(ValueError):
            func('$.Variable', 42)

    # Numeric comparators reject non-numeric values.
    numeric_functions = (
        'NumericEquals',
        'NumericLessThan',
        'NumericGreaterThan',
        'NumericLessThanEquals',
        'NumericGreaterThanEquals',
    )
    for numeric_function in numeric_functions:
        func = getattr(ChoiceRule, numeric_function)
        with pytest.raises(ValueError):
            func('$.Variable', 'ABC')

    # The boolean comparator rejects non-boolean values.
    with pytest.raises(ValueError):
        ChoiceRule.BooleanEquals('$.Variable', 42)

    # Timestamp comparators reject non-timestamp values.
    timestamp_functions = (
        'TimestampEquals',
        'TimestampLessThan',
        'TimestampGreaterThan',
        'TimestampLessThanEquals',
        'TimestampGreaterThanEquals',
    )
    for timestamp_function in timestamp_functions:
        func = getattr(ChoiceRule, timestamp_function)
        with pytest.raises(ValueError):
            func('$.Variable', True)

def test_path_comparator_raises_error_when_value_is_not_a_path():
    """Path comparators require the comparison value to be a JSONPath."""
    path_comparators = {
        'StringEqualsPath',
        'NumericEqualsPath',
        'TimestampEqualsPath',
        'BooleanEqualsPath'
    }
    for path_comparator in path_comparators:
        func = getattr(ChoiceRule, path_comparator)
        with pytest.raises(ValueError):
            func('$.Variable', 'string')

def test_is_comparator_raises_error_when_value_is_not_a_bool():
    """Type-check comparators (Is*) only accept boolean values."""
    type_comparators = {
        'IsPresent',
        'IsNull',
        'IsString',
        'IsNumeric',
        'IsBoolean',
        'IsTimestamp'
    }

    for type_comparator in type_comparators:
        func = getattr(ChoiceRule, type_comparator)
        with pytest.raises(ValueError):
            func('$.Variable', 'string')
        with pytest.raises(ValueError):
            func('$.Variable', 101)

def test_static_comparator_serialization():
    """Static comparators serialize as {'Variable': path, <Comparator>: value}."""
    string_timestamp_static_comparators = {
        'StringEquals',
        'StringLessThan',
        'StringLessThanEquals',
        'StringGreaterThan',
        'StringGreaterThanEquals',
        'TimestampEquals',
        'TimestampLessThan',
        'TimestampGreaterThan',
        'TimestampLessThanEquals'
    }

    for string_timestamp_static_comparator in string_timestamp_static_comparators:
        type_rule = getattr(ChoiceRule, string_timestamp_static_comparator)('$.input', 'hello')
        expected_dict = {}
        expected_dict['Variable'] = '$.input'
        expected_dict[string_timestamp_static_comparator] = 'hello'
        assert type_rule.to_dict() == expected_dict

    number_static_comparators = {
        'NumericEquals',
        'NumericLessThan',
        'NumericGreaterThan',
        'NumericLessThanEquals',
        'NumericGreaterThanEquals'
    }

    for number_static_comparator in number_static_comparators:
        type_rule = getattr(ChoiceRule, number_static_comparator)('$.input', 123)
        expected_dict = {}
        expected_dict['Variable'] = '$.input'
        expected_dict[number_static_comparator] = 123
        assert type_rule.to_dict() == expected_dict

    boolean_static_comparators = {
        'BooleanEquals'
    }

    for boolean_static_comparator in boolean_static_comparators:
        type_rule = getattr(ChoiceRule, boolean_static_comparator)('$.input', False)
        expected_dict = {}
        expected_dict['Variable'] = '$.input'
        expected_dict[boolean_static_comparator] = False
        assert type_rule.to_dict() == expected_dict

def test_dynamic_comparator_serialization():
    """Path (dynamic) comparators serialize with the comparison path as value."""
    dynamic_comparators = {
        'StringEqualsPath',
        'StringLessThanPath',
        'StringLessThanEqualsPath',
        'StringGreaterThanPath',
        'StringGreaterThanEqualsPath',
        'TimestampEqualsPath',
        'TimestampLessThanPath',
        'TimestampGreaterThanPath',
        'TimestampLessThanEqualsPath',
        'NumericEqualsPath',
        'NumericLessThanPath',
        'NumericGreaterThanPath',
        'NumericLessThanEqualsPath',
        'NumericGreaterThanEqualsPath',
        'BooleanEqualsPath'
    }

    for dynamic_comparator in dynamic_comparators:
        type_rule = getattr(ChoiceRule, dynamic_comparator)('$.input', '$.input2')
        expected_dict = {}
        expected_dict['Variable'] = '$.input'
        expected_dict[dynamic_comparator] = '$.input2'
        assert type_rule.to_dict() == expected_dict

def test_type_check_comparators_serialization():
    """Type-check comparators serialize with their boolean flag as the value."""
    type_comparators = {
        'IsPresent',
        'IsNull',
        'IsString',
        'IsNumeric',
        'IsBoolean',
        'IsTimestamp'
    }

    for type_comparator in type_comparators:
        type_rule = getattr(ChoiceRule, type_comparator)('$.input', True)
        expected_dict = {}
        expected_dict['Variable'] = '$.input'
        expected_dict[type_comparator] = True
        assert type_rule.to_dict() == expected_dict

def test_string_matches_serialization():
    """StringMatches keeps wildcard '*' and escaped '\\*' characters intact."""
    string_matches_rule = ChoiceRule.StringMatches('$.input', 'hello*world\\*')
    assert string_matches_rule.to_dict() == {
        'Variable': '$.input',
        'StringMatches': 'hello*world\\*'
    }

def test_rule_serialization():
    """And/Or/Not compound rules nest their operand rules' serialized dicts."""
    bool_rule = ChoiceRule.BooleanEquals('$.BooleanVariable', True)
    assert bool_rule.to_dict() == {
        'Variable': '$.BooleanVariable',
        'BooleanEquals': True
    }

    string_rule = ChoiceRule.StringEquals('$.StringVariable', 'ABC')
    assert string_rule.to_dict() == {
        'Variable': '$.StringVariable',
        'StringEquals': 'ABC'
    }

    and_rule = ChoiceRule.And([bool_rule, string_rule])
    assert and_rule.to_dict() == {
        'And': [
            {
                'Variable': '$.BooleanVariable',
                'BooleanEquals': True
            },
            {
                'Variable': '$.StringVariable',
                'StringEquals': 'ABC'
            }
        ]
    }

    not_rule = ChoiceRule.Not(string_rule)
    assert not_rule.to_dict() == {
        'Not': {
            'Variable': '$.StringVariable',
            'StringEquals': 'ABC'
        }
    }

    compound_rule = ChoiceRule.Or([and_rule, not_rule])
    assert compound_rule.to_dict() == {
        'Or': [
            {
                'And': [{
                    'Variable': '$.BooleanVariable',
                    'BooleanEquals': True
                },
                {
                    'Variable': '$.StringVariable',
                    'StringEquals': 'ABC'
                }],
            },
            {
                'Not': {
                    'Variable': '$.StringVariable',
                    'StringEquals': 'ABC'
                }
            }
        ]
    }

# ---------------------------------------------------------------------------
# /tests/unit/test_compute_steps.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import boto3

from unittest.mock import patch
from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep


# Each test patches the boto3 session region so the generated service
# integration ARNs are deterministic.
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_lambda_step_creation():
    """LambdaStep uses the invoke ARN; .waitForTaskToken with callback."""
    step = LambdaStep('Echo')

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::lambda:invoke',
        'End': True
    }

    step = LambdaStep('lambda', wait_for_callback=True, parameters={
        'Payload': {
            'model.$': '$.new_model',
            'token.$': '$$.Task.Token'
        }
    })

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::lambda:invoke.waitForTaskToken',
        'Parameters': {
            'Payload': {
                'model.$': '$.new_model',
                'token.$': '$$.Task.Token'
            },
        },
        'End': True
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_glue_start_job_run_step_creation():
    """GlueStartJobRunStep defaults to .sync; plain ARN without completion."""
    step = GlueStartJobRunStep('Glue Job', wait_for_completion=False)

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::glue:startJobRun',
        'End': True
    }

    step = GlueStartJobRunStep('Glue Job', parameters={
        'JobName': 'Job'
    })

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::glue:startJobRun.sync',
        'Parameters': {
            'JobName': 'Job',
        },
        'End': True
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_batch_submit_job_step_creation():
    """BatchSubmitJobStep defaults to .sync; plain ARN without completion."""
    step = BatchSubmitJobStep('Batch Job', wait_for_completion=False)

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::batch:submitJob',
        'End': True
    }

    step = BatchSubmitJobStep('Batch Job', parameters={
        'JobName': 'Job',
        'JobQueue': 'JobQueue'
    })

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::batch:submitJob.sync',
        'Parameters': {
            'JobName': 'Job',
            'JobQueue': 'JobQueue'
        },
        'End': True
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_ecs_run_task_step_creation():
    """EcsRunTaskStep defaults to .sync; plain ARN without completion."""
    step = EcsRunTaskStep('Ecs Job', wait_for_completion=False)

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::ecs:runTask',
        'End': True
    }

    step = EcsRunTaskStep('Ecs Job', parameters={
        'TaskDefinition': 'Task'
    })

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::ecs:runTask.sync',
        'Parameters': {
            'TaskDefinition': 'Task'
        },
        'End': True
    }

# ---------------------------------------------------------------------------
# /tests/unit/test_graph.py
# ---------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from stepfunctions.steps import Pass, Succeed, Fail, Wait, Choice, ChoiceRule, Parallel, Map, Task, Chain, Graph, FrozenGraph
from stepfunctions.steps.states import State


def test_nested_parallel_example():
    """Parallel branches nest as complete sub-state-machines in the ASL dict."""
    nested_level_1 = Parallel('NestedStateLevel1')
    nested_level_1.add_branch(Succeed('NestedStateLevel2'))

    first_state = Parallel('FirstState')
    first_state.add_branch(nested_level_1)

    result = Graph(first_state, comment='This is a test.', version='1.0', timeout_seconds=3600).to_dict()
    assert result == {
        'StartAt': 'FirstState',
        'Comment': 'This is a test.',
        'Version': '1.0',
        'TimeoutSeconds': 3600,
        'States': {
            'FirstState': {
                'Type': 'Parallel',
                'Branches': [
                    {
                        'StartAt': 'NestedStateLevel1',
                        'States': {
                            'NestedStateLevel1': {
                                'Type': 'Parallel',
                                'Branches': [
                                    {
                                        'StartAt': 'NestedStateLevel2',
                                        'States': {
                                            'NestedStateLevel2': {
                                                'Type': 'Succeed'
                                            }
                                        }
                                    }
                                ],
                                'End': True
                            }
                        }
                    }
                ],
                'End': True
            }
        }
    }

def test_wait_loop():
    """A Choice that loops back through a retry Chain serializes as a cycle."""
    first_state = Task('FirstState', resource='arn:aws:lambda:us-east-1:1234567890:function:FirstState')
    retry = Chain([Pass('Retry'), Pass('Cleanup'), first_state])

    choice_state = Choice('Is Completed?')
    choice_state.add_choice(ChoiceRule.BooleanEquals('$.Completed', True), Succeed('Complete'))
    choice_state.add_choice(ChoiceRule.BooleanEquals('$.Completed', False), retry)
    first_state.next(choice_state)

    result = Graph(first_state).to_dict()
    assert result == {
        'StartAt': 'FirstState',
        'States': {
            'FirstState': {
                'Type': 'Task',
                'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:FirstState',
                'Next': 'Is Completed?'
            },
            'Is Completed?': {
                'Type': 'Choice',
                'Choices': [
                    {
                        'Variable': '$.Completed',
                        'BooleanEquals': True,
                        'Next': 'Complete'
                    },
                    {
                        'Variable': '$.Completed',
                        'BooleanEquals': False,
                        'Next': 'Retry'
                    }
                ]
            },
            'Complete': {
                'Type': 'Succeed'
            },
            'Retry': {
                'Type': 'Pass',
                'Next': 'Cleanup',
            },
            'Cleanup': {
                'Type': 'Pass',
                'Next': 'FirstState'
            }
        }
    }

def test_wait_example():
    """Wait states serialize Seconds/Timestamp/*Path variants (truncated in this chunk)."""
    chain = Chain()
    chain.append(Task('FirstState', resource='arn:aws:lambda:us-east-1:1234567890:function:StartState'))
    chain.append(Wait('wait_using_seconds', seconds=10))
    chain.append(Wait('wait_using_timestamp', timestamp='2015-09-04T01:59:00Z'))
    chain.append(Wait('wait_using_timestamp_path', timestamp_path='$.expirydate'))
    chain.append(Wait('wait_using_seconds_path', seconds_path='$.expiryseconds'))
    chain.append(Task('FinalState', resource='arn:aws:lambda:us-east-1:1234567890:function:EndLambda'))

    result = Graph(chain).to_dict()
    assert result == {
        'StartAt': 'FirstState',
        'States': {
            'FirstState': {
                'Type': 'Task',
                'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:StartState',
                'Next': 'wait_using_seconds'
            },
            'wait_using_seconds': {
                'Type': 'Wait',
                'Seconds': 10,
                'Next': 'wait_using_timestamp'
            },
            'wait_using_timestamp': {
                'Type': 'Wait',
                'Timestamp': '2015-09-04T01:59:00Z',
                'Next':
'wait_using_timestamp_path' 137 | }, 138 | 'wait_using_timestamp_path': { 139 | 'Type': 'Wait', 140 | 'TimestampPath': '$.expirydate', 141 | 'Next': 'wait_using_seconds_path' 142 | }, 143 | 'wait_using_seconds_path': { 144 | 'Type': 'Wait', 145 | 'SecondsPath': '$.expiryseconds', 146 | 'Next': 'FinalState', 147 | }, 148 | 'FinalState': { 149 | 'Type': 'Task', 150 | 'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:EndLambda', 151 | 'End': True 152 | } 153 | } 154 | } 155 | 156 | def test_choice_example(): 157 | next_state = Task('NextState', resource='arn:aws:lambda:us-east-1:1234567890:function:NextState') 158 | 159 | choice_state = Choice('ChoiceState') 160 | choice_state.default_choice(Fail('DefaultState', error='DefaultStateError', cause='No Matches!')) 161 | choice_state.add_choice(ChoiceRule.NumericEquals(variable='$.foo', value=1), Chain([ 162 | Task('FirstMatchState', resource='arn:aws:lambda:us-east-1:1234567890:function:FirstMatchState'), 163 | next_state 164 | ])) 165 | 166 | choice_state.add_choice(ChoiceRule.NumericEquals(variable='$.foo', value=2), Chain([ 167 | Task('SecondMatchState', resource='arn:aws:lambda:us-east-1:1234567890:function:SecondMatchState'), 168 | next_state 169 | ])) 170 | 171 | chain = Chain() 172 | chain.append(Task('FirstState', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda')) 173 | chain.append(choice_state) 174 | 175 | result = Graph(chain).to_dict() 176 | assert result == { 177 | 'StartAt': 'FirstState', 178 | 'States': { 179 | 'FirstState': { 180 | 'Type': 'Task', 181 | 'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:StartLambda', 182 | 'Next': 'ChoiceState' 183 | }, 184 | 'ChoiceState': { 185 | 'Type': 'Choice', 186 | 'Choices': [ 187 | { 188 | 'Variable': '$.foo', 189 | 'NumericEquals': 1, 190 | 'Next': 'FirstMatchState' 191 | }, 192 | { 193 | 'Variable': '$.foo', 194 | 'NumericEquals': 2, 195 | 'Next': 'SecondMatchState' 196 | } 197 | ], 198 | 'Default': 'DefaultState' 199 | }, 200 | 
'FirstMatchState': { 201 | 'Type': 'Task', 202 | 'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:FirstMatchState', 203 | 'Next': 'NextState' 204 | }, 205 | 'SecondMatchState': { 206 | 'Type': 'Task', 207 | 'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:SecondMatchState', 208 | 'Next': 'NextState' 209 | }, 210 | 'DefaultState': { 211 | 'Type': 'Fail', 212 | 'Error': 'DefaultStateError', 213 | 'Cause': 'No Matches!' 214 | }, 215 | 'NextState': { 216 | 'Type': 'Task', 217 | 'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:NextState', 218 | 'End': True 219 | } 220 | } 221 | } 222 | 223 | def test_graph_from_string(): 224 | g = Graph(Chain([Pass('HelloWorld')])) 225 | g1 = FrozenGraph.from_json(g.to_json()) 226 | assert isinstance(g1, Graph) 227 | assert g.to_dict() == g1.to_dict() 228 | -------------------------------------------------------------------------------- /tests/unit/test_placeholders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | import json 17 | 18 | from stepfunctions.inputs import ExecutionInput, StepInput 19 | 20 | def test_placeholder_creation_with_subscript_operator(): 21 | step_input = StepInput() 22 | placeholder_variable = step_input["A"] 23 | assert placeholder_variable.name == "A" 24 | assert placeholder_variable.type is None 25 | 26 | def test_placeholder_creation_with_type(): 27 | workflow_input = ExecutionInput() 28 | placeholder_variable = workflow_input["A"]["b"].get("C", float) 29 | assert placeholder_variable.name == "C" 30 | assert placeholder_variable.type == float 31 | 32 | def test_placeholder_creation_with_int_key(): 33 | workflow_input = ExecutionInput() 34 | placeholder_variable = workflow_input["A"][0] 35 | assert placeholder_variable.name == 0 36 | assert placeholder_variable.type == None 37 | 38 | def test_placeholder_creation_with_invalid_key(): 39 | step_input = StepInput() 40 | with pytest.raises(ValueError): 41 | step_input["A"][1.3] 42 | with pytest.raises(ValueError): 43 | step_input["A"].get(1.2, str) 44 | 45 | def test_placeholder_creation_failure_with_type(): 46 | workflow_input = ExecutionInput() 47 | placeholder_variable = workflow_input["A"]["b"].get("C", float) 48 | with pytest.raises(ValueError): 49 | workflow_input["A"]["b"].get("C", int) 50 | 51 | def test_placeholder_path(): 52 | workflow_input = ExecutionInput() 53 | placeholder_variable = workflow_input["A"]["b"]["C"] 54 | expected_path = ["A", "b", "C"] 55 | assert placeholder_variable._get_path() == expected_path 56 | 57 | def test_placeholder_contains(): 58 | step_input = StepInput() 59 | var_one = step_input["Key01"] 60 | var_two = step_input["Key02"]["Key03"] 61 | var_three = step_input["Key01"]["Key04"] 62 | var_four = step_input["Key05"] 63 | 64 | step_input_two = StepInput() 65 | var_five = step_input_two["Key07"] 66 | 67 | assert step_input.contains(var_three) == True 68 | assert step_input.contains(var_five) == 
False 69 | assert step_input_two.contains(var_three) == False 70 | 71 | def test_placeholder_schema_as_dict(): 72 | workflow_input = ExecutionInput() 73 | workflow_input["A"]["b"].get("C", float) 74 | workflow_input["Message"] 75 | workflow_input["Key01"]["Key02"] 76 | workflow_input["Key03"] 77 | workflow_input["Key03"]["Key04"] 78 | 79 | expected_schema = { 80 | "A": { 81 | "b": { 82 | "C": float 83 | } 84 | }, 85 | "Message": str, 86 | "Key01": { 87 | "Key02": str 88 | }, 89 | "Key03": { 90 | "Key04": str 91 | } 92 | } 93 | 94 | assert workflow_input.get_schema_as_dict() == expected_schema 95 | 96 | def test_placeholder_schema_as_json(): 97 | step_input = StepInput() 98 | step_input["Response"].get("StatusCode", int) 99 | step_input["Hello"]["World"] 100 | step_input["A"] 101 | step_input["Hello"]["World"].get("Test", str) 102 | 103 | expected_schema = { 104 | "Response": { 105 | "StatusCode": "int" 106 | }, 107 | "Hello": { 108 | "World": { 109 | "Test": "str" 110 | } 111 | }, 112 | "A": "str" 113 | } 114 | 115 | assert step_input.get_schema_as_json() == json.dumps(expected_schema) 116 | 117 | def test_placeholder_is_empty(): 118 | workflow_input = ExecutionInput() 119 | placeholder_variable = workflow_input["A"]["B"]["C"] 120 | assert placeholder_variable._is_empty() == True 121 | workflow_input["A"]["B"]["C"]["D"] 122 | assert placeholder_variable._is_empty() == False 123 | 124 | 125 | def test_placeholder_make_immutable(): 126 | workflow_input = ExecutionInput() 127 | workflow_input["A"]["b"].get("C", float) 128 | workflow_input["Message"] 129 | workflow_input["Key01"]["Key02"] 130 | workflow_input["Key03"] 131 | workflow_input["Key03"]["Key04"] 132 | 133 | assert check_immutable(workflow_input) == False 134 | 135 | workflow_input._make_immutable() 136 | assert check_immutable(workflow_input) == True 137 | 138 | 139 | def test_placeholder_with_schema(): 140 | test_schema = { 141 | "A": { 142 | "B":{ 143 | "C": int 144 | } 145 | }, 146 | "Request": { 147 | 
"Status": str 148 | }, 149 | "Hello": float 150 | } 151 | workflow_input = ExecutionInput(schema=test_schema) 152 | assert workflow_input.get_schema_as_dict() == test_schema 153 | assert workflow_input.immutable == True 154 | 155 | with pytest.raises(ValueError): 156 | workflow_input["A"]["B"]["D"] 157 | 158 | with pytest.raises(ValueError): 159 | workflow_input["A"]["B"].get("C", float) 160 | 161 | def test_workflow_input_jsonpath(): 162 | workflow_input = ExecutionInput() 163 | placeholder_variable = workflow_input["A"]["b"].get("C", float) 164 | assert placeholder_variable.to_jsonpath() == "$$.Execution.Input['A']['b']['C']" 165 | 166 | def test_step_input_jsonpath(): 167 | step_input = StepInput() 168 | placeholder_variable = step_input["A"]["b"].get(0, float) 169 | assert placeholder_variable.to_jsonpath() == "$['A']['b'][0]" 170 | 171 | # UTILS 172 | 173 | def check_immutable(placeholder): 174 | if placeholder.immutable is True: 175 | if placeholder._is_empty(): 176 | return True 177 | else: 178 | for k, v in placeholder.store.items(): 179 | return check_immutable(v) 180 | else: 181 | return False -------------------------------------------------------------------------------- /tests/unit/test_placeholders_with_workflows.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
from __future__ import absolute_import

import pytest
from mock import MagicMock

from datetime import datetime

from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps import Pass, Chain
from stepfunctions.workflow import Workflow

@pytest.fixture(scope="module")
def client():
    # Module-scoped mock SFN client: call assertions accumulate across tests in this module.
    mock = MagicMock()
    return mock

@pytest.fixture(scope="module")
def workflow(client):
    # Builds a three-state workflow whose parameters mix execution-input
    # placeholders, literals, and upstream step outputs.
    execution_input = ExecutionInput()

    test_step_01 = Pass(
        state_id='StateOne',
        parameters={
            'ParamA': execution_input['Key02']['Key03'],
            'ParamD': execution_input['Key01']['Key03'],
        }
    )

    test_step_02 = Pass(
        state_id='StateTwo',
        parameters={
            'ParamC': execution_input["Key05"],
            "ParamB": "SampleValueB",
            # References the output of the previous step, not the execution input.
            "ParamE": test_step_01.output()["Response"]["Key04"]
        }
    )

    test_step_03 = Pass(
        state_id='StateThree',
        parameters={
            'ParamG': "SampleValueG",
            "ParamF": execution_input["Key06"],
            "ParamH": "SampleValueH",
            "ParamI": test_step_02.output()
        }
    )

    workflow_definition = Chain([test_step_01, test_step_02, test_step_03])
    workflow = Workflow(
        name='TestWorkflow',
        definition=workflow_definition,
        role='testRoleArn',
        execution_input=execution_input,
        client=client
    )
    return workflow


def test_workflow_execute_with_invalid_input(workflow):
    """execute() validates inputs against the execution-input schema before starting."""

    # Missing every declared key.
    with pytest.raises(ValueError):
        workflow.execute(inputs={})

    # All keys present but Key06 has the wrong type (int where the schema
    # inferred str) — presumably why this is rejected; see valid case below.
    with pytest.raises(ValueError):
        workflow.execute(inputs={
            "Key02": {
                "Key03": "Hello"
            },
            "Key01": {
                "Key03": "World"
            },
            "Key05": "Test",
            "Key06": 123
        })

def test_workflow_execute_with_valid_input(client, workflow):
    """A schema-conforming input passes validation and reaches the SFN client."""

    workflow.create()
    execution = workflow.execute(inputs={
        "Key02": {
            "Key03": "Hello"
        },
        "Key01": {
            "Key03": "World"
        },
        "Key05": "Test01",
        "Key06": "Test02"
    })

    client.create_state_machine.assert_called()
    client.start_execution.assert_called()

--------------------------------------------------------------------------------
/tests/unit/test_steps_utils.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Test if boto3 session can fetch correct aws partition info from test environment

import boto3
import logging
import pytest

from enum import Enum
from unittest.mock import patch

from stepfunctions.steps.utils import get_aws_partition, merge_dicts
from stepfunctions.steps.integration_resources import IntegrationPattern,\
    get_service_integration_arn, is_integration_pattern_valid


testService = "sagemaker"


class TestApi(Enum):
    # Minimal stand-in for a service API enum used by the ARN builder.
    CreateTrainingJob = "createTrainingJob"


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_util_get_aws_partition_aws():
    """Commercial regions map to the 'aws' partition."""
    cur_partition = get_aws_partition()
    assert cur_partition == "aws"


@patch.object(boto3.session.Session, 'region_name', 'cn-north-1')
def test_util_get_aws_partition_aws_cn():
    """China regions map to the 'aws-cn' partition."""
    cur_partition = get_aws_partition()
    assert cur_partition == "aws-cn"


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_arn_builder_sagemaker_no_wait():
    """Without an integration pattern, the service integration ARN has no suffix."""
    arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob)
    assert arn == "arn:aws:states:::sagemaker:createTrainingJob"


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_arn_builder_sagemaker_wait_completion():
    """WaitForCompletion appends the '.sync' suffix to the integration ARN."""
    arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob,
                                      IntegrationPattern.WaitForCompletion)
    assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync"


def test_merge_dicts():
    """merge_dicts deep-merges d2 into d1 in place; d2 values win on conflict."""
    d1 = {
        'a': {
            'aa': 1,
            'bb': 2,
            'cc': 3
        },
        'b': 1
    }

    d2 = {
        'a': {
            'bb': {
                'aaa': 1,
                'bbb': 2
            }
        },
        'b': 2,
        'c': 3
    }

    merge_dicts(d1, d2)
    assert d1 == {
        'a': {
            'aa': 1,
            # scalar replaced by d2's nested dict
            'bb': {
                'aaa': 1,
                'bbb': 2
            },
            'cc': 3
        },
        'b': 2,
        'c': 3
    }


@pytest.mark.parametrize("service_integration_type", [
    None,
    "IntegrationPatternStr",
    0
])
def test_is_integration_pattern_valid_with_invalid_type_raises_type_error(service_integration_type):
    """Anything that is not an IntegrationPattern enum member raises TypeError."""
    with pytest.raises(TypeError):
        is_integration_pattern_valid(service_integration_type, [IntegrationPattern.WaitForTaskToken])


def test_is_integration_pattern_valid_with_non_supported_type_raises_value_error():
    """A valid enum member that is not in the supported list raises ValueError."""
    with pytest.raises(ValueError):
        is_integration_pattern_valid(IntegrationPattern.WaitForTaskToken, [IntegrationPattern.WaitForCompletion])

--------------------------------------------------------------------------------
/tests/unit/test_template_utils.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | 17 | from stepfunctions.steps.states import Task 18 | from stepfunctions.template.utils import ( 19 | replace_parameters_with_context_object, 20 | replace_parameters_with_jsonpath 21 | ) 22 | 23 | 24 | def test_context_object_replacement(): 25 | task = Task('LambdaTask', parameters={ 26 | 'InputObject': 'Result', 27 | 'OutputObject': { 28 | 'Location': 'Unknown' 29 | } 30 | }) 31 | 32 | task.update_parameters(replace_parameters_with_context_object(task)) 33 | 34 | assert task.parameters == { 35 | 'InputObject.$': "$$.Execution.Input['LambdaTask'].InputObject", 36 | 'OutputObject.$': "$$.Execution.Input['LambdaTask'].OutputObject", 37 | } 38 | 39 | def test_jsonpath_replacement(): 40 | task = Task('LambdaTask', parameters={ 41 | 'InputObject': 'Result', 42 | 'OutputObject': { 43 | 'Location': { 44 | 'Country': 'Unknown', 45 | 'City': 'Unknown' 46 | }, 47 | 'Coordinates': { 48 | 'Latitude': 0, 49 | 'Longitude': 0 50 | } 51 | } 52 | }) 53 | 54 | params = replace_parameters_with_jsonpath(task, { 55 | 'InputObject.$': '$.InputObject', 56 | 'OutputObject': { 57 | 'Location.$': '$.OutputLocation', 58 | 'Coordinates': { 59 | 'Latitude.$': '$.Latitude', 60 | 'Longitude': 0 61 | } 62 | } 63 | }) 64 | 65 | print(params) 66 | 67 | assert params == { 68 | 'InputObject.$': '$.InputObject', 69 | 'OutputObject': { 70 | 'Location.$': '$.OutputLocation', 71 | 'Coordinates': { 72 | 'Latitude.$': '$.Latitude', 73 | 'Longitude': 0 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /tests/unit/test_widget_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from datetime import datetime

from stepfunctions.workflow import Execution
from stepfunctions.workflow.widgets.utils import (
    AWS_SFN_EXECUTIONS_DETAIL_URL,
    format_time,
    get_elapsed_ms,
    create_sfn_execution_url
)

REGION = 'us-east-1'
WORKFLOW_NAME = 'HelloWorld'
STATUS = 'RUNNING'

EXECUTION_ARN = 'arn:aws:states:{}:1234567890:execution:{}:execution-1'.format(REGION, WORKFLOW_NAME)
expected_aws_sfn_executions_detail_url = "https://console.aws.amazon.com/states/home?region={region}#/executions/details/{execution_arn}"


def test_sfn_console_url():
    """The console URL template constant matches the expected format string."""
    assert AWS_SFN_EXECUTIONS_DETAIL_URL == expected_aws_sfn_executions_detail_url

def test_format_time():
    """None renders as a dash; datetimes render as 'Mon DD, YYYY HH:MM:SS.mmm AM/PM'."""
    none_time = format_time(None)
    formatted_time = format_time(datetime(2020, 1, 2, 13, 30, 45, 123000))

    assert none_time == "-"
    assert formatted_time == 'Jan 02, 2020 01:30:45.123 PM'

def test_get_elapsed_ms():
    """Elapsed time between two datetimes is reported in milliseconds."""
    elapsed_time_microseconds = 10000
    start_time = datetime(2020, 1, 2, 13, 30, 45, 123000)
    end_time = datetime(2020, 1, 2, 13, 30, 45, 123000 +
                        elapsed_time_microseconds)
    calculated_elapsed_milliseconds = get_elapsed_ms(start_time, end_time)

    assert calculated_elapsed_milliseconds == elapsed_time_microseconds / 1000

def test_create_sfn_execution_url():
    """The execution URL embeds the ARN's region and the full execution ARN."""
    sfn_execution_url = create_sfn_execution_url(EXECUTION_ARN)
    assert sfn_execution_url == expected_aws_sfn_executions_detail_url.format(
        region=REGION, execution_arn=EXECUTION_ARN)

--------------------------------------------------------------------------------
/tests/unit/test_widgets.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import json

from stepfunctions.steps import *
from stepfunctions.workflow.widgets import WorkflowGraphWidget, ExecutionGraphWidget


REGION = 'test-region'
WORKFLOW = 'HelloWorld'
EXECUTION_ARN = 'arn:aws:states:{region}:1234567890:execution:{workflow}:execution-1'.format(region=REGION,
                                                                                            workflow=WORKFLOW)


def test_workflow_graph():
    """portrait=False renders left-to-right; portrait=True renders top-to-bottom."""
    graph = Graph(Chain([Pass('Prepare Data'), Pass('Start Training'), Pass('Batch Transform'), Pass('Deploy')]))

    widget = WorkflowGraphWidget(graph.to_json())
    html_snippet = widget.show(portrait=False)

    assert "layout: 'LR'" in html_snippet.data
    assert 'var graph = new sfn.StateMachineGraph(definition, elementId, options);' in html_snippet.data

    html_snippet = widget.show(portrait=True)
    assert "layout: 'TB'" in html_snippet.data


def test_execution_graph():
    """The execution widget embeds the events JSON and a console deep-link for the ARN."""
    graph = Graph(Chain([Pass('Prepare Data'), Pass('Start Training'), Pass('Batch Transform'), Pass('Deploy')]))
    events = [{}, {}, {}]

    widget = ExecutionGraphWidget(graph.to_json(), json.dumps(events), EXECUTION_ARN)
    html_snippet = widget.show()
    assert 'var graph = new sfn.StateMachineExecutionGraph(definition, events, elementId, options);' in html_snippet.data

    expected_aws_sfn_executions_detail_url=' Inspect in AWS Step Functions '
    assert expected_aws_sfn_executions_detail_url.format(region=REGION,
                                                         execution_arn=EXECUTION_ARN) in html_snippet.data

--------------------------------------------------------------------------------
/tests/unit/test_workflow.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import uuid
import boto3
import yaml
import json

from datetime import datetime
from unittest.mock import MagicMock, Mock
from stepfunctions import steps
from stepfunctions.exceptions import WorkflowNotFound, MissingRequiredParameter
from stepfunctions.workflow import Workflow, Execution, ExecutionStatus


# Shared fixture constants: one state machine, one role, one execution ARN.
state_machine_name = 'HelloWorld'
state_machine_arn = 'arn:aws:states:us-east-1:1234567890:stateMachine:HelloWorld'
role_arn = 'arn:aws:iam::1234567890:role/service-role/StepFunctionsRole'
execution_arn = 'arn:aws:states:us-east-1:1234567890:execution:HelloWorld:execution-1'
definition = steps.Chain([steps.Pass('HelloWorld'), steps.Succeed('Complete')])


@pytest.fixture
def client():
    # Real boto3 client object, but every API call used by Workflow is mocked out.
    sfn = boto3.client('stepfunctions')
    sfn.describe_state_machine = MagicMock(return_value={
        'creationDate': datetime(2019, 9, 9, 9, 59, 59, 276000),
        'definition': steps.Graph(definition).to_json(),
        'name': state_machine_name,
        'roleArn': role_arn,
        'stateMachineArn': state_machine_arn,
        'status': 'ACTIVE'
    })
    sfn.create_state_machine = MagicMock(return_value={
        'creationDate': datetime.now(),
        'stateMachineArn': state_machine_arn
    })
    sfn.delete_state_machine = MagicMock(return_value=None)
    sfn.start_execution = MagicMock(return_value={
        'executionArn': execution_arn,
        'startDate': datetime.now(),
        'stateMachineArn': state_machine_arn,
        'status': 'RUNNING'
    })
    return sfn


@pytest.fixture
def workflow(client):
    # A Workflow that has already been create()d against the mocked client.
    workflow = Workflow(
        name=state_machine_name,
        definition=definition,
        role=role_arn,
        client=client
    )
    workflow.create()
    return workflow


def test_workflow_creation(client, workflow):
    """create() records the ARN returned by create_state_machine."""
    assert workflow.state_machine_arn == state_machine_arn


def test_workflow_creation_failure_duplicate_state_ids(client):
    """Two states with the same id in one definition are rejected at construction."""
    improper_definition = steps.Chain([steps.Pass('HelloWorld'), steps.Succeed('HelloWorld')])
    with pytest.raises(ValueError):
        workflow = Workflow(
            name=state_machine_name,
            definition=improper_definition,
            role=role_arn,
            client=client
        )


# calling update() before create()
def test_workflow_update_when_statemachinearn_is_none(client):
    workflow = Workflow(
        name=state_machine_name,
        definition=definition,
        role=role_arn,
        client=client
    )
    new_definition = steps.Pass('HelloWorld')
    with pytest.raises(WorkflowNotFound):
        workflow.update(definition=new_definition)


# calling update() after create() without arguments
def test_workflow_update_when_arguments_are_missing(client, workflow):
    with pytest.raises(MissingRequiredParameter):
        workflow.update()


# calling update() after create()
def test_workflow_update(client, workflow):
    client.update_state_machine = MagicMock(return_value={
        'updateDate': datetime.now()
    })
    new_definition = steps.Pass('HelloWorld')
    new_role = 'arn:aws:iam::1234567890:role/service-role/StepFunctionsRoleNew'
    assert workflow.update(definition=new_definition, role=new_role) == state_machine_arn


def test_attach_existing_workflow(client):
    """attach() rebuilds a Workflow from describe_state_machine output."""
    workflow = Workflow.attach(state_machine_arn, client)
    assert workflow.name == state_machine_name
    assert workflow.role == role_arn
    assert workflow.state_machine_arn == state_machine_arn


def test_workflow_list_executions(client, workflow):
    """list_executions() pages through executions; returns [] when ARN is unset."""
    paginator = client.get_paginator('list_executions')
    paginator.paginate = MagicMock(return_value=[
        {
            'executions': [
                {
                    'stateMachineArn': state_machine_arn,
                    'executionArn': execution_arn,
                    'startDate': datetime.now(),
                    'status': 'RUNNING',
                    'name': 'HelloWorld'
                }
            ]
        }
    ])
    client.get_paginator = MagicMock(return_value=paginator)

    execution = workflow.execute()
    assert execution.workflow.state_machine_arn == workflow.state_machine_arn
    assert execution.execution_arn == execution_arn

    executions = workflow.list_executions()
    assert len(executions) == 1
    assert isinstance(executions[0], Execution)

    # Without a state machine ARN there is nothing to list.
    workflow.state_machine_arn = None
    assert workflow.list_executions() == []


def test_workflow_makes_deletion_call(client, workflow):
    """delete() forwards exactly one delete_state_machine call with the ARN."""
    client.delete_state_machine = MagicMock(return_value=None)
    workflow.delete()

    client.delete_state_machine.assert_called_once_with(stateMachineArn=state_machine_arn)


def test_workflow_execute_creation(client, workflow):
    """execute() returns a running Execution; name/inputs are forwarded verbatim."""
    execution = workflow.execute()
    assert execution.workflow.state_machine_arn == state_machine_arn
    assert execution.execution_arn == execution_arn
    assert execution.status == ExecutionStatus.Running

    client.start_execution = MagicMock(return_value={
        'executionArn': 'arn:aws:states:us-east-1:1234567890:execution:HelloWorld:TestRun',
        'startDate': datetime.now()
    })

    execution = workflow.execute(name='TestRun', inputs={})
    client.start_execution.assert_called_once_with(
        stateMachineArn=state_machine_arn,
        name='TestRun',
        input='{}'
    )


def test_workflow_execute_when_statemachinearn_is_none(client, workflow):
    """execute() without a created state machine raises WorkflowNotFound."""
    workflow.state_machine_arn = None
    with pytest.raises(WorkflowNotFound):
        workflow.execute()


def test_execution_makes_describe_call(client, workflow):
    """Execution.describe() delegates to describe_execution."""
    execution = workflow.execute()

    client.describe_execution = MagicMock(return_value={})
    execution.describe()

    client.describe_execution.assert_called_once()


def test_execution_makes_stop_call(client, workflow):
    """Execution.stop() passes cause/error through only when provided."""
    execution = workflow.execute()

    client.stop_execution = MagicMock(return_value={})

    execution.stop()
    client.stop_execution.assert_called_with(
        executionArn=execution_arn
    )

    execution.stop(cause='Test', error='Error')
    client.stop_execution.assert_called_with(
        executionArn=execution_arn,
        cause='Test',
        error='Error'
    )


def test_execution_list_events(client, workflow):
    """list_events() maps max_items/reverse_order onto the paginator config."""
    paginator = client.get_paginator('get_execution_history')
    paginator.paginate = MagicMock(return_value=[
        {
            'events': [
                {
                    'timestamp': datetime(2019, 1, 1),
                    'type': 'TaskFailed',
                    'id': 123,
                    'previousEventId': 456,
                    'taskFailedEventDetails': {
                        'resourceType': 'type',
                        'resource': 'resource',
                        'error': 'error',
                        'cause': 'test'
                    }
                }
            ],
            'NextToken': 'Token'
        }
    ])
    client.get_paginator = MagicMock(return_value=paginator)

    execution = workflow.execute()
    execution.list_events(max_items=999, reverse_order=True)

    paginator.paginate.assert_called_with(
        executionArn=execution_arn,
        reverseOrder=True,
        PaginationConfig={
            'MaxItems': 999,
            'PageSize': 1000
        }
    )


def test_list_workflows(client):
    """Workflow.list_workflows() maps max_items onto the paginator config."""
    paginator = client.get_paginator('list_state_machines')
    paginator.paginate = MagicMock(return_value=[
        {
            'stateMachines': [
                {
                    'stateMachineArn': state_machine_arn,
                    'name': state_machine_name,
                    'creationDate': datetime(2019, 1, 1)
                }
            ],
            'NextToken': 'Token'
        }
    ])

    client.get_paginator = MagicMock(return_value=paginator)
    workflows = Workflow.list_workflows(max_items=999, client=client)

    paginator.paginate.assert_called_with(
        PaginationConfig={
            'MaxItems': 999,
            'PageSize': 1000
        }
    )


def test_cloudformation_export_with_simple_definition(workflow):
    """The generated CFN template contains the state machine resource and role."""
    cfn_template = workflow.get_cloudformation_template()
    cfn_template = yaml.safe_load(cfn_template)
    assert 'StateMachineComponent' in cfn_template['Resources']
    assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
    assert cfn_template['Description'] == "CloudFormation template for AWS Step Functions - State Machine"


def test_cloudformation_export_with_sagemaker_execution_role(workflow):
    """A custom description and a SageMaker definition round-trip through the template."""
    workflow.definition.to_dict = MagicMock(return_value={
        'StartAt': 'Training',
        'States': {
            'Training': {
                'Type': 'Task',
                'Parameters': {
                    'AlgorithmSpecification': {
                        'TrainingImage': '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1',
                        'TrainingInputMode': 'File'
                    },
                    'OutputDataConfig': {
                        'S3OutputPath': 's3://sagemaker/models'
                    },
                    'RoleArn': 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
                },
                'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
                'End': True
            }
        }
    })
    cfn_template = workflow.get_cloudformation_template(description="CloudFormation template with Sagemaker role")
    cfn_template = yaml.safe_load(cfn_template)
    assert json.dumps(workflow.definition.to_dict(), indent=2) == cfn_template['Resources']['StateMachineComponent']['Properties']['DefinitionString']
    assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
    assert cfn_template['Description'] == "CloudFormation template with Sagemaker role"

--------------------------------------------------------------------------------
/tests/unit/test_workflow_utils.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | import boto3 17 | 18 | from unittest.mock import Mock, patch 19 | from stepfunctions.workflow.utils import append_user_agent_to_client 20 | 21 | @pytest.fixture 22 | def client(): 23 | sfn = boto3.client('stepfunctions') 24 | sfn._client_config = Mock() 25 | sfn._client_config.user_agent = "abc/1.2.3 def/4.5.6" 26 | return sfn 27 | 28 | @patch('stepfunctions.__useragent__', 'helloworld') 29 | @patch('stepfunctions.__version__', '9.8.7') 30 | def test_append_user_agent_to_client(client): 31 | append_user_agent_to_client(client) 32 | user_agent_suffix = client._client_config.user_agent.split()[-1] 33 | assert user_agent_suffix == "helloworld/9.8.7" 34 | -------------------------------------------------------------------------------- /tests/unit/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. 
# =========================================================================
# tests/unit/utils.py
# =========================================================================
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import botocore

# Keep a handle on the genuine implementation so operations we do not stub
# can still reach the real AWS endpoint.
boto_true_api_call = botocore.client.BaseClient._make_api_call


def mock_boto_api_call(self, operation_name, kwarg):
    """Stand-in for botocore's BaseClient._make_api_call in unit tests.

    GetCallerIdentity and CreateBucket are answered locally with canned
    payloads; every other operation falls through to the real API call
    captured in ``boto_true_api_call`` above.
    """
    if operation_name == "GetCallerIdentity":
        return {"Account": "ABCXYZ"}
    if operation_name == "CreateBucket":
        return {"Location": "s3://abcxyz-path/"}
    return boto_true_api_call(self, operation_name, kwarg)


# =========================================================================
# tests/unit/widgets/__init__.py — contains only the standard Apache-2.0
# header and a single `from __future__ import absolute_import`; no code.
# =========================================================================
5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | 17 | from datetime import datetime 18 | from stepfunctions.workflow.widgets import EventsTableWidget 19 | PASS_WORKFLOW_EVENTS = [ 20 | { 21 | 'timestamp': datetime(2019, 9, 1, 13, 45, 47, 940000), 22 | 'type': 'ExecutionStarted', 23 | 'id': 1, 24 | 'previousEventId': 0, 25 | 'executionStartedEventDetails': { 26 | 'input': '{}', 27 | 'roleArn': 'arn:aws:iam::099764291644:role/stepfunctions' 28 | } 29 | }, 30 | { 31 | 'timestamp': datetime(2019, 9, 1, 13, 46, 50, 215000), 32 | 'type': 'PassStateEntered', 33 | 'id': 2, 34 | 'previousEventId': 0, 35 | 'stateEnteredEventDetails': { 36 | 'name': 'HelloWorld1234', 37 | 'input': '{}' 38 | } 39 | }, 40 | { 41 | 'timestamp': datetime(2019, 9, 1, 13, 47, 9, 396000), 42 | 'type': 'PassStateExited', 43 | 'id': 3, 44 | 'previousEventId': 2, 45 | 'stateExitedEventDetails': { 46 | 'name': 'HelloWorld1234', 47 | 'output': '"Hello World!"' 48 | } 49 | }, 50 | { 51 | 'timestamp': datetime(2019, 9, 1, 13, 47, 24, 724000), 52 | 'type': 'ExecutionSucceeded', 53 | 'id': 4, 54 | 'previousEventId': 3, 55 | 'executionSucceededEventDetails': { 56 | 'output': '"Hello World!"' 57 | } 58 | } 59 | ] 60 | LAMBDA_STATE_NAME = "Lambda Name" 61 | LAMBDA_RESOURCE_NAME = "Lambda" 62 | LAMBDA_WORKFLOW_EVENTS = [ 63 | { 64 | 'timestamp': datetime(2019, 9, 1, 13, 54, 21, 366000), 65 | 'type': 'ExecutionStarted', 66 | 'id': 1, 67 | 'previousEventId': 0, 68 | 'executionStartedEventDetails': { 69 | 'input': '{}', 70 | 'roleArn': 'arn:aws:iam::012345678901:role/stepfunctions' 71 | } 72 | }, 73 | 
{ 74 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 169000), 75 | 'type': 'TaskStateEntered', 76 | 'id': 2, 77 | 'previousEventId': 1, 78 | 'stateEnteredEventDetails': { 79 | 'name': LAMBDA_STATE_NAME, 80 | 'input': '"Hello World!"' 81 | } 82 | }, 83 | { 84 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 169000), 85 | 'type': 'LambdaFunctionScheduled', 86 | 'id': 3, 87 | 'previousEventId': 2, 88 | 'lambdaFunctionScheduledEventDetails': { 89 | 'resource': 'arn:aws:lambda:us-east-1:012345678901:function:serverlessrepo-sample-lambda-helloworld-V3LUDGZAC5KB', 90 | 'input': '"Hello World!"' 91 | } 92 | }, 93 | { 94 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 192000), 95 | 'type': 'LambdaFunctionStarted', 96 | 'id': 4, 97 | 'previousEventId': 3 98 | }, 99 | { 100 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 241000), 101 | 'type': 'LambdaFunctionSucceeded', 102 | 'id': 5, 103 | 'previousEventId': 4, 104 | 'lambdaFunctionSucceededEventDetails': { 105 | 'output': 'null' 106 | } 107 | }, 108 | { 109 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 241000), 110 | 'type': 'TaskStateExited', 111 | 'id': 6, 112 | 'previousEventId': 5, 113 | 'stateExitedEventDetails': { 114 | 'name': LAMBDA_STATE_NAME, 115 | 'output': 'null' 116 | } 117 | }, 118 | { 119 | 'timestamp': datetime(2019, 9, 1, 13, 55, 30, 241000), 120 | 'type': 'ExecutionSucceeded', 121 | 'id': 7, 122 | 'previousEventId': 6, 123 | 'executionSucceededEventDetails': { 124 | 'output': 'null' 125 | } 126 | } 127 | ] 128 | 129 | 130 | def test_empty_events_table(): 131 | widget = EventsTableWidget([]) 132 | html_snippet = widget.show() 133 | assert '' not in html_snippet 134 | assert '' in html_snippet 135 | 136 | 137 | def test_events_table(): 138 | widget = EventsTableWidget(PASS_WORKFLOW_EVENTS) 139 | html_snippet = widget.show() 140 | assert html_snippet.count( 141 | '') == len(PASS_WORKFLOW_EVENTS) 142 | 143 | 144 | def test_lambda_workflow_events_table(): 145 | widget = 
EventsTableWidget(LAMBDA_WORKFLOW_EVENTS) 146 | html_snippet = widget.show() 147 | assert html_snippet.count( 148 | '') == len(LAMBDA_WORKFLOW_EVENTS) 149 | assert html_snippet.count(''.format(LAMBDA_STATE_NAME)) == 5 150 | assert html_snippet.count( 151 | """""".format(LAMBDA_RESOURCE_NAME)) == 3 -------------------------------------------------------------------------------- /tests/unit/widgets/test_executions_table_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | 17 | from unittest.mock import MagicMock 18 | from datetime import datetime 19 | 20 | from stepfunctions.workflow.widgets import ExecutionsTableWidget 21 | from stepfunctions.workflow import Execution 22 | 23 | REGION = 'us-east-1' 24 | WORKFLOW_NAME = 'HelloWorld' 25 | STATUS = 'RUNNING' 26 | 27 | @pytest.fixture 28 | def executions(): 29 | workflow = MagicMock() 30 | workflow.state_machine_arn = 'arn:aws:states:{}:1234567890:stateMachine:{}'.format(REGION, WORKFLOW_NAME) 31 | execution_arn = 'arn:aws:states:{}:1234567890:execution:{}:execution-1'.format(REGION, WORKFLOW_NAME) 32 | 33 | executions = [ 34 | Execution( 35 | name=WORKFLOW_NAME, 36 | workflow=workflow, 37 | execution_arn=execution_arn, 38 | start_date=datetime.now(), 39 | stop_date=None, 40 | status=STATUS, 41 | client=None 42 | ) 43 | ] 44 | return executions 45 | 46 | def test_empty_executions_table(): 47 | widget = ExecutionsTableWidget([]) 48 | html_snippet = widget.show() 49 | assert '' not in html_snippet 50 | assert '
{}{}
' in html_snippet 51 | 52 | def test_executions_table(executions): 53 | widget = ExecutionsTableWidget(executions) 54 | html_snippet = widget.show() 55 | assert html_snippet.count('') == len(executions) 56 | assert WORKFLOW_NAME in html_snippet 57 | assert STATUS in html_snippet 58 | -------------------------------------------------------------------------------- /tests/unit/widgets/test_workflows_table_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import pytest 16 | 17 | from unittest.mock import MagicMock 18 | from datetime import datetime 19 | 20 | from stepfunctions.workflow.widgets import WorkflowsTableWidget 21 | 22 | 23 | WORKFLOW_NAME = 'HelloWorld' 24 | STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:1234567890:stateMachine:HelloWorld' 25 | 26 | 27 | @pytest.fixture 28 | def workflows(): 29 | return [ 30 | { 31 | 'stateMachineArn': STATE_MACHINE_ARN, 32 | 'name': WORKFLOW_NAME, 33 | 'creationDate': datetime(2019, 1, 1) 34 | } 35 | ] 36 | 37 | def test_empty_workflows_table(): 38 | widget = WorkflowsTableWidget([]) 39 | html_snippet = widget.show() 40 | assert '' not in html_snippet 41 | assert '
' in html_snippet 42 | 43 | def test_workflows_table(workflows): 44 | widget = WorkflowsTableWidget(workflows) 45 | html_snippet = widget.show() 46 | assert html_snippet.count('') == len(workflows) 47 | assert WORKFLOW_NAME in html_snippet 48 | assert STATE_MACHINE_ARN in html_snippet 49 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # tox (https://tox.readthedocs.io/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = python3.6 8 | 9 | skip_missing_interpreters = False 10 | 11 | [testenv] 12 | deps = .[test] 13 | passenv = 14 | AWS_ACCESS_KEY_ID 15 | AWS_SECRET_ACCESS_KEY 16 | AWS_DEFAULT_REGION 17 | 18 | # {posargs} can be passed in by additional arguments specified when invoking tox. 19 | # Can be used to specify which tests to run, e.g.: tox -- -s 20 | commands = 21 | pytest {posargs} --------------------------------------------------------------------------------