├── .coveragerc_py310 ├── .coveragerc_py38 ├── .coveragerc_py39 ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── documentation-request.md │ └── feature_request.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .pylintrc ├── CHANGELOG.md ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── ENVIRONMENT_VARIABLES.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── VERSION ├── branding └── icon │ └── sagemaker-banner.png ├── buildspec-deploy.yml ├── buildspec-release.yml ├── buildspec.yml ├── setup.cfg ├── setup.py ├── src ├── __init__.py └── sagemaker_training │ ├── __init__.py │ ├── _entry_point_type.py │ ├── c │ ├── gethostname.c │ ├── jsmn.c │ └── jsmn.h │ ├── cli │ ├── __init__.py │ └── train.py │ ├── content_types.py │ ├── encoders.py │ ├── entry_point.py │ ├── environment.py │ ├── errors.py │ ├── files.py │ ├── functions.py │ ├── intermediate_output.py │ ├── logging_config.py │ ├── mapping.py │ ├── modules.py │ ├── mpi.py │ ├── params.py │ ├── process.py │ ├── pytorch_xla.py │ ├── record_pb2.py │ ├── recordio.py │ ├── runner.py │ ├── smdataparallel.py │ ├── timeout.py │ ├── torch_distributed.py │ └── trainer.py ├── test ├── __init__.py ├── conftest.py ├── container │ ├── dummy │ │ ├── Dockerfile │ │ ├── requirements.txt │ │ └── train.py │ └── tensorflow │ │ ├── Dockerfile │ │ └── train.py ├── fake_ml_framework.py ├── functional │ ├── __init__.py │ ├── simple_framework.py │ ├── test_download_and_import.py │ ├── test_intermediate_output.py │ ├── test_mpi.py │ └── test_training_framework.py ├── integration │ └── local │ │ ├── test_dummy.py │ │ └── test_tensorflow.py ├── resources │ └── openmpi │ │ ├── Dockerfile │ │ ├── Dockerfile.base │ │ ├── launcher.sh │ │ └── script.py └── unit │ ├── __init__.py │ ├── c │ └── test_gethostname.py │ ├── cli │ ├── __init__.py │ └── test_train.py │ ├── dummy │ ├── __init__.py │ ├── dummy.py │ └── tensorflow │ │ ├── __init__.py │ │ └── compiler │ │ ├── __init__.py │ │ └── xla │ 
│ ├── __init__.py │ │ └── dummy_xla.py │ ├── test_encoder.py │ ├── test_entry_point.py │ ├── test_entry_point_type.py │ ├── test_environment.py │ ├── test_errors.py │ ├── test_files.py │ ├── test_functions.py │ ├── test_intermediate_output.py │ ├── test_mapping.py │ ├── test_modules.py │ ├── test_mpi.py │ ├── test_process.py │ ├── test_pytorch_xla.py │ ├── test_runner.py │ ├── test_smdataparallel.py │ ├── test_timeout.py │ ├── test_torch_distributed.py │ └── test_trainer.py └── tox.ini /.coveragerc_py310: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | timid = True 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | pragma: py3 no cover 9 | if six.PY2 10 | elif six.PY2 11 | 12 | partial_branches = 13 | pragma: no cover 14 | pragma: py3 no cover 15 | if six.PY3 16 | elif six.PY3 17 | 18 | show_missing = True 19 | 20 | fail_under = 88 21 | -------------------------------------------------------------------------------- /.coveragerc_py38: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | timid = True 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | pragma: py3 no cover 9 | if six.PY2 10 | elif six.PY2 11 | 12 | partial_branches = 13 | pragma: no cover 14 | pragma: py3 no cover 15 | if six.PY3 16 | elif six.PY3 17 | 18 | show_missing = True 19 | 20 | fail_under = 88 21 | -------------------------------------------------------------------------------- /.coveragerc_py39: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | timid = True 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | pragma: py3 no cover 9 | if six.PY2 10 | elif six.PY2 11 | 12 | partial_branches = 13 | pragma: no cover 14 | pragma: py3 no cover 15 | if six.PY3 16 | elif six.PY3 17 | 18 | show_missing = True 19 | 20 | fail_under = 88 21 | 
-------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | application_import_names = sagemaker_training, test, gethostname 3 | import-order-style = google 4 | ignore = E501, W503 5 | exclude = 6 | build/ 7 | .git 8 | __pycache__ 9 | .tox 10 | tests/data/ 11 | *venv/ 12 | ./src/sagemaker_training/record_pb2.py -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: File a report to help us reproduce and fix the problem 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To reproduce** 14 | A clear, step-by-step set of instructions to reproduce the bug. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots or logs** 20 | If applicable, add screenshots or logs to help explain your problem. 21 | 22 | **System information** 23 | A description of your system. 24 | - Include the version of SageMaker Training Toolkit you are using. 25 | - If you are using a [prebuilt Amazon SageMaker Docker image](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html), provide the URL. 26 | - If you are using a custom Docker image, provide: 27 | - framework name (eg. PyTorch) 28 | - framework version 29 | - Python version 30 | - processing unit type (ie. CPU or GPU) 31 | 32 | **Additional context** 33 | Add any other context about the problem here. 
34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Ask a question 4 | url: https://stackoverflow.com/questions/tagged/amazon-sagemaker 5 | about: Use Stack Overflow to ask and answer questions 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation request 3 | about: Request improved documentation 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **What did you find confusing? Please describe.** 11 | A clear and concise description of what you found confusing. Ex. I tried to [...] but I didn't understand how to [...] 12 | 13 | **Describe how documentation can be improved** 14 | A clear and concise description of where documentation was lacking and how it can be improved. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the documentation request here. 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest new functionality for this toolkit 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the feature you'd like** 11 | A clear and concise description of the functionality you want. 12 | 13 | **How would this feature be used? Please describe.** 14 | A clear and concise description of the use case for this feature. Please provide an example, if possible. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | *Issue #, if available:* 2 | 3 | *Description of changes:* 4 | 5 | *Testing done:* 6 | 7 | ## Merge Checklist 8 | 9 | _Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request._ 10 | 11 | #### General 12 | 13 | - [ ] I have read the [CONTRIBUTING](https://github.com/aws/sagemaker-training-toolkit/blob/master/CONTRIBUTING.md) doc 14 | - [ ] I used the commit message format described in [CONTRIBUTING](https://github.com/aws/sagemaker-training-toolkit/blob/master/CONTRIBUTING.md#committing-your-change) 15 | - [ ] I have used the regional endpoint when creating S3 and/or STS clients (if appropriate) 16 | - [ ] I have updated any necessary documentation, including [READMEs](https://github.com/aws/sagemaker-training-toolkit/blob/master/README.md) 17 | 18 | #### Tests 19 | 20 | - [ ] I have added tests that prove my fix is effective or that my feature works (if appropriate) 21 | - [ ] I have checked that my tests are not configured for a specific region or account (if appropriate) 22 | 23 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 
24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | .cache/ 4 | build/ 5 | dist/ 6 | **/__pycache__/ 7 | .coverage 8 | .pytest_cache/ 9 | .tox/ 10 | htmlcov 11 | **/sagemaker_training.egg-info 12 | **/*.pyc 13 | .mypy_cache/** 14 | .DS_Store 15 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @aws/sagemaker-jobs-platform 2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Submitting bug reports and feature requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 
13 | 14 | When filing an issue, please check [existing open](https://github.com/aws/sagemaker-training-toolkit/issues), or [recently closed](https://github.com/aws/sagemaker-training-toolkit/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 15 | reported the issue. To create a new issue, select the template that most closely matches what you're writing about (ie. "Bug report", "Documentation request", or "Feature request"). Please fill out all information requested in the issue template. 16 | 17 | ## Contributing via pull requests 18 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 19 | 20 | - You are working against the latest source on the *master* branch. 21 | - You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 22 | - You open an issue to discuss any significant work - we would hate for your time to be wasted. 23 | 24 | To send us a pull request, please: 25 | 26 | 1. Fork the repository. 27 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 28 | 3. Ensure local tests pass. 29 | 4. Commit to your fork using [clear commit messages](#committing-your-change). 30 | 5. Send us a pull request, answering any default questions in the pull request interface. 31 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 32 | The [sagemaker-bot](https://github.com/sagemaker-bot) will comment on the pull request with a link to the build logs. 33 | 34 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 35 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 36 | 37 | ### Running the unit tests 38 | 39 | 1. 
Install tox using `pip install tox` 40 | 1. Install coverage using `pip install .[test]` 41 | 1. cd into the sagemaker-training-toolkit folder: `cd sagemaker-training-toolkit` 42 | 1. Run the following tox command and verify that all code checks and unit tests pass: `tox test/unit` 43 | 44 | You can also run a single test with the following command: `tox -e py38 -- -s -vv test/unit/test_entry_point.py::test_install_module` 45 | * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` 46 | * Example: `export IGNORE_COVERAGE=- ; tox -e py38 -- -s -vv test/unit/test_entry_point.py::test_install_module ; unset IGNORE_COVERAGE` 47 | 48 | 49 | ### Running the integration tests 50 | 51 | Our CI system runs integration tests (the ones in the `test/integration` directory), in parallel, for every pull request. 52 | You should only worry about manually running any new integration tests that you write, or integration tests that test an area of code that you've modified. 53 | 54 | 1. Follow the instructions at [Set Up the AWS Command Line Interface (AWS CLI)](https://docs.aws.amazon.com/polly/latest/dg/setup-aws-cli.html). 55 | 1. To run a test, specify the test file and method you want to run per the following command: `tox -e py38 -- -s -vv test/integration/local/test_dummy.py::test_install_requirements` 56 | * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` 57 | * Example: `export IGNORE_COVERAGE=- ; tox -e py38 -- -s -vv test/integration/local/test_dummy.py::test_install_requirements ; unset IGNORE_COVERAGE` 58 | 59 | 60 | ### Making and testing your change 61 | 62 | 1. Create a new git branch: 63 | ```shell 64 | git checkout -b my-fix-branch master 65 | ``` 66 | 1. Make your changes, **including unit tests** and, if appropriate, integration tests.
67 | 1. Include unit tests when you contribute new features or make bug fixes, as they help to: 68 | 1. Prove that your code works correctly. 69 | 1. Guard against future breaking changes to lower the maintenance cost. 70 | 1. Please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 71 | 1. Run all the unit tests as per [Running the unit tests](#running-the-unit-tests), and verify that all checks and tests pass. 72 | 1. Note that this also runs tools that may be necessary for the automated build to pass (ex: code reformatting by 'black'). 73 | 74 | 75 | ### Committing your change 76 | 77 | We use commit messages to update the project version number and generate changelog entries, so it's important for them to follow the right format. Valid commit messages include a prefix, separated from the rest of the message by a colon and a space. Here are a few examples: 78 | 79 | ``` 80 | feature: support VPC config for hyperparameter tuning 81 | fix: fix flake8 errors 82 | documentation: add MXNet documentation 83 | ``` 84 | 85 | Valid prefixes are listed in the table below. 86 | 87 | | Prefix | Use for... | 88 | |----------------:|:-----------------------------------------------------------------------------------------------| 89 | | `breaking` | Incompatible API changes. | 90 | | `deprecation` | Deprecating an existing API or feature, or removing something that was previously deprecated. | 91 | | `feature` | Adding a new feature. | 92 | | `fix` | Bug fixes. | 93 | | `change` | Any other code change. | 94 | | `documentation` | Documentation changes. | 95 | 96 | Some of the prefixes allow abbreviation ; `break`, `feat`, `depr`, and `doc` are all valid. If you omit a prefix, the commit will be treated as a `change`. 97 | 98 | For the rest of the message, use imperative style and keep things concise but informative. 
See [How to Write a Git Commit Message](https://chris.beams.io/posts/git-commit/) for guidance. 99 | 100 | 101 | ### Sending a pull request 102 | 103 | GitHub provides additional document on [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 104 | 105 | Please remember to: 106 | * Use commit messages (and PR titles) that follow the guidelines under [Committing your change](#committing-your-change). 107 | * Send us a pull request, answering any default questions in the pull request interface. 108 | * Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 109 | 110 | ## Finding contributions to work on 111 | Looking at the [existing issues](https://github.com/aws/sagemaker-training-toolkit/issues) is a great place to start. 112 | 113 | 114 | ## Code of Conduct 115 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 116 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 117 | opensource-codeofconduct@amazon.com with any additional questions or comments. 118 | 119 | 120 | ## Security issue notifications 121 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 122 | 123 | 124 | ## Licensing 125 | 126 | See the [LICENSE](https://github.com/aws/sagemaker-training-toolkit/blob/master/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 127 | 128 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 
129 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | include LICENSE 3 | include README.md 4 | 5 | recursive-exclude * __pycache__ 6 | recursive-exclude * *.py[co] 7 | include src/sagemaker_training/c/jsmn.h 8 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | SageMaker Training Toolkit 2 | Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 5.0.1.dev0 2 | -------------------------------------------------------------------------------- /branding/icon/sagemaker-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-training-toolkit/a0073959d4de7e8311b370a1f143e44dcf801d1e/branding/icon/sagemaker-banner.png -------------------------------------------------------------------------------- /buildspec-deploy.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | 3 | phases: 4 | build: 5 | commands: 6 | - PACKAGE_FILE="$CODEBUILD_SRC_DIR_ARTIFACT_1/sagemaker_training-*.tar.gz" 7 | 8 | # publish to pypi 9 | - publish-pypi-package $PACKAGE_FILE 10 | -------------------------------------------------------------------------------- /buildspec-release.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | 3 | phases: 4 | pre_build: 5 | commands: 6 | - start-dockerd 7 | 8 | build: 9 | commands: 10 | # prepare the release (update versions, changelog etc.)
11 | - git-release --prepare 12 | 13 | # run linters 14 | - tox -e flake8,twine,pylint 15 | 16 | # run format verification 17 | - tox -e black-check 18 | 19 | # run unit tests 20 | - AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN= 21 | AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION= 22 | tox -e py38,py39,py310 -- test/unit 23 | 24 | # run functional tests 25 | - $(aws ecr get-login --no-include-email --region us-west-2) 26 | - IGNORE_COVERAGE=- tox -e py38,py39,py310 -- test/functional 27 | 28 | 29 | # build dummy container 30 | - python setup.py sdist 31 | - cp dist/sagemaker_training-*.tar.gz test/container/dummy/sagemaker_training.tar.gz 32 | - cd test/container 33 | - docker build -t sagemaker-training-toolkit-test:dummy -f dummy/Dockerfile . 34 | - rm dummy/sagemaker_training.tar.gz 35 | 36 | # build tensorflow container 37 | - cd ../.. 38 | - aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com 39 | - cp dist/sagemaker_training-*.tar.gz test/container/dummy/sagemaker_training.tar.gz 40 | - cd test/container 41 | - docker build -t sagemaker-training-toolkit-test:tensorflow -f tensorflow/Dockerfile . 42 | #- docker tag sagemaker-training-toolkit-test:tensorflow sagemaker-training-toolkit-test:tensorflow$(($(date +%s%N)/1000000)) #append currenttime to the tag name. 43 | - rm dummy/sagemaker_training.tar.gz 44 | 45 | - cd ../.. 
46 | 47 | # run local integration tests 48 | - IGNORE_COVERAGE=- tox -e py38,py39,py310 -- test/integration/local 49 | 50 | # generate the distribution package 51 | - python3 setup.py sdist 52 | 53 | # publish the release to github 54 | - git-release --publish 55 | 56 | artifacts: 57 | files: 58 | - dist/sagemaker_training*.tar.gz 59 | name: ARTIFACT_1 60 | discard-paths: yes 61 | -------------------------------------------------------------------------------- /buildspec.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | 3 | phases: 4 | pre_build: 5 | commands: 6 | - start-dockerd 7 | 8 | build: 9 | commands: 10 | # run linters 11 | - TOX_PARALLEL_NO_SPINNER=1 12 | - tox -e flake8,black-check,pylint --parallel all 13 | 14 | # run README check 15 | - tox -e twine 16 | 17 | # run unit tests 18 | - AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN= 19 | AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION= 20 | tox -v -e py38,py39,py310 -- test/unit 21 | 22 | # build toolkit 23 | - python setup.py sdist 24 | 25 | # run functional tests 26 | - $(aws ecr get-login --no-include-email --region us-west-2) 27 | - IGNORE_COVERAGE=- tox -v -e py38,py39,py310 -- test/functional 28 | 29 | # build dummy container 30 | - cp dist/sagemaker_training-*.tar.gz test/container/dummy/sagemaker_training.tar.gz 31 | - cd test/container 32 | - docker build -t sagemaker-training-toolkit-test:dummy -f dummy/Dockerfile . 33 | - rm dummy/sagemaker_training.tar.gz 34 | 35 | # build tensorflow container 36 | - cd ../.. 37 | - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin 763104351884.dkr.ecr.${AWS_DEFAULT_REGION}.amazonaws.com 38 | - cp dist/sagemaker_training-*.tar.gz test/container/dummy/sagemaker_training.tar.gz 39 | - cd test/container 40 | - docker build -t sagemaker-training-toolkit-test:tensorflow -f tensorflow/Dockerfile . 
41 | #- docker tag sagemaker-training-toolkit-test:tensorflow sagemaker-training-toolkit-test:tensorflow$(($(date +%s%N)/1000000)) #append currenttime to the tag name. 42 | - rm dummy/sagemaker_training.tar.gz 43 | 44 | - cd ../.. 45 | 46 | # run local integration tests 47 | - IGNORE_COVERAGE=- tox -v -e py38,py39,py310 -- test/integration/local 48 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [metadata] 5 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | from glob import glob 16 | import os 17 | import sys 18 | 19 | import setuptools 20 | 21 | 22 | def read(file_name): 23 | return open(os.path.join(os.path.dirname(__file__), file_name)).read() 24 | 25 | 26 | def read_version(): 27 | return read("VERSION").strip() 28 | 29 | 30 | packages = setuptools.find_packages(where="src", exclude=("test",)) 31 | 32 | required_packages = [ 33 | "numpy", 34 | "boto3", 35 | "six", 36 | "pip", 37 | "retrying>=1.3.3", 38 | "gevent", 39 | "inotify_simple==1.2.1", 40 | "werkzeug>=0.15.5", 41 | "paramiko>=2.4.2", 42 | "psutil>=5.6.7", 43 | "protobuf>=5.28.1", 44 | "scipy>=1.2.2", 45 | "boto3>=1.28.57", 46 | "botocore>=1.31.57", 47 | ] 48 | 49 | # enum is introduced in Python 3.4. Installing enum back port 50 | if sys.version_info < (3, 4): 51 | required_packages.append("enum34 >= 1.1.6") 52 | 53 | gethostname = setuptools.Extension( 54 | "gethostname", 55 | sources=["src/sagemaker_training/c/gethostname.c", "src/sagemaker_training/c/jsmn.c"], 56 | include_dirs=["src/sagemaker_training/c"], 57 | extra_compile_args=["-Wall", "-shared", "-export-dynamic", "-ldl"], 58 | ) 59 | 60 | setuptools.setup( 61 | name="sagemaker_training", 62 | version=read_version(), 63 | description="Open source library for creating containers to run on Amazon SageMaker.", 64 | packages=packages, 65 | package_dir={"sagemaker_training": "src/sagemaker_training"}, 66 | py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")], 67 | ext_modules=[gethostname], 68 | long_description=read("README.md"), 69 | long_description_content_type="text/markdown", 70 | author="Amazon Web Services", 71 | url="https://github.com/aws/sagemaker-training-toolkit/", 72 | license="Apache License 2.0", 73 | classifiers=[ 74 | "Development Status :: 5 - Production/Stable", 75 | "Intended Audience :: Developers", 76 | "Natural Language :: English", 77 | "License :: OSI Approved :: Apache Software License", 78 
| "Programming Language :: Python", 79 | "Programming Language :: Python :: 3.8", 80 | "Programming Language :: Python :: 3.9", 81 | "Programming Language :: Python :: 3.10", 82 | ], 83 | install_requires=required_packages, 84 | extras_require={ 85 | "test": [ 86 | "tox==4.6.4", 87 | "pytest==4.4.1", 88 | "pytest-cov", 89 | "mock", 90 | "sagemaker[local]>=2.172.0,<3", 91 | "black==22.3.0 ; python_version >= '3.8'", 92 | ] 93 | }, 94 | entry_points={"console_scripts": ["train=sagemaker_training.cli.train:main"]}, 95 | ) 96 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | -------------------------------------------------------------------------------- /src/sagemaker_training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. 
This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This file is executed when the sagemaker_training package is imported.""" 14 | from __future__ import absolute_import 15 | 16 | # list of errors: To show user error message on the SM Training job page 17 | # [x for x in dir(__builtins__) if 'Error' in x or 'Exception' in x] 18 | _PYTHON_ERRORS_ = [ 19 | "BaseException", 20 | "Exception", 21 | "ArithmeticError", 22 | "AssertionError", 23 | "AttributeError", 24 | "BlockingIOError", 25 | "BrokenPipeError", 26 | "BufferError", 27 | "ChildProcessError", 28 | "ConnectionAbortedError", 29 | "ConnectionError", 30 | "ConnectionRefusedError", 31 | "ConnectionResetError", 32 | "EOFError", 33 | "EnvironmentError", 34 | "FileExistsError", 35 | "FileNotFoundError", 36 | "FloatingPointError", 37 | "IOError", 38 | "ImportError", 39 | "IndentationError", 40 | "IndexError", 41 | "InterruptedError", 42 | "IsADirectoryError", 43 | "KeyError", 44 | "LookupError", 45 | "MemoryError", 46 | "ModuleNotFoundError", 47 | "NameError", 48 | "NotADirectoryError", 49 | "NotImplementedError", 50 | "OSError", 51 | "OverflowError", 52 | "PermissionError", 53 | "ProcessLookupError", 54 | "RecursionError", 55 | "ReferenceError", 56 | "RuntimeError", 57 | "SyntaxError", 58 | "SystemError", 59 | "TabError", 60 | "TimeoutError", 61 | "TypeError", 62 | "UnboundLocalError", 63 | "UnicodeDecodeError", 64 | "UnicodeEncodeError", 65 | "UnicodeError", 66 | "UnicodeTranslateError", 67 | "ValueError", 68 | "ZeroDivisionError", 69 | "Invalid requirement", 70 | "ResourceExhaustedError", 71 | "OutOfRangeError", 72 | "InvalidArgumentError", 73 | ] 74 | 75 | _MPI_ERRORS_ = ["mpirun.real", "ORTE"] 76 | 77 | SM_EFA_NCCL_INSTANCES = [ 78 | "ml.g4dn.8xlarge", 79 | "ml.g4dn.12xlarge", 80 | "ml.g5.48xlarge", 81 | 
"ml.p3dn.24xlarge", 82 | "ml.p4d.24xlarge", 83 | "ml.p4de.24xlarge", 84 | "ml.p5.48xlarge", 85 | "ml.trn1.32xlarge", 86 | ] 87 | 88 | SM_EFA_RDMA_INSTANCES = [ 89 | "ml.p4d.24xlarge", 90 | "ml.p4de.24xlarge", 91 | "ml.trn1.32xlarge", 92 | ] 93 | 94 | SM_TRAINING_COMPILER_PATHS = [ 95 | "tensorflow/compiler/xla", 96 | "tensorflow/compiler/tf2xla", 97 | "tensorflow/python/compiler/xla", 98 | "torch_xla/", 99 | ] 100 | -------------------------------------------------------------------------------- /src/sagemaker_training/_entry_point_type.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This module contains an enumerated type and helper functions related 14 | to different types of training entry points (Python package, Python 15 | script, bash script, etc.) 
"""This module contains an enumerated type and helper functions related
to different types of training entry points (Python package, Python
script, bash script, etc.)
"""
import enum
import os


class _EntryPointType(enum.Enum):
    """Enumerated type consisting of valid types of training entry points."""

    PYTHON_PACKAGE = "PYTHON_PACKAGE"
    PYTHON_PROGRAM = "PYTHON_PROGRAM"
    COMMAND = "COMMAND"


PYTHON_PACKAGE = _EntryPointType.PYTHON_PACKAGE
PYTHON_PROGRAM = _EntryPointType.PYTHON_PROGRAM
COMMAND = _EntryPointType.COMMAND


def get(path, name):  # type: (str, str) -> _EntryPointType
    """Resolve the type of a training entry point from its file name and directory.

    Args:
        path (string): Directory where the entry point is located.
        name (string): Name of the entry point file.

    Returns:
        (_EntryPointType): The type of the entry point.
    """
    # Shell scripts are always run as commands, even when a setup.py exists.
    if name.endswith(".sh"):
        return COMMAND
    # A setup.py sibling marks the directory as an installable Python package.
    if "setup.py" in os.listdir(path):
        return PYTHON_PACKAGE
    return PYTHON_PROGRAM if name.endswith(".py") else COMMAND
13 | 14 | #include 15 | #include 16 | #include 17 | #include "jsmn.h" 18 | 19 | static int jsoneq(const char *json, jsmntok_t *tok, const char *s) 20 | { 21 | if (tok->type == JSMN_STRING && (int) strlen(s) == tok->end - tok->start && 22 | strncmp(json + tok->start, s, tok->end - tok->start) == 0) 23 | { 24 | return 0; 25 | } 26 | return -1; 27 | } 28 | 29 | int gethostname(char *name, size_t len) 30 | { 31 | int r; 32 | FILE *file = fopen("/opt/ml/input/config/resourceconfig.json", "r"); 33 | 34 | fseek(file, 0, SEEK_END); 35 | 36 | long length = ftell(file); 37 | fseek(file, 0, SEEK_SET); 38 | 39 | char *json_string = malloc(length + 1); 40 | fread(json_string, 1, length, file); 41 | fclose(file); 42 | json_string[length] = '\0'; 43 | 44 | jsmn_parser parser; 45 | jsmntok_t token[1024]; 46 | 47 | 48 | jsmn_init(&parser); 49 | r = jsmn_parse(&parser, json_string, strlen(json_string), token, sizeof(token) / sizeof(token[0])); 50 | 51 | 52 | if (r < 0) { 53 | printf("Failed to parse JSON: %d\n", r); 54 | return 1; 55 | } 56 | 57 | /* Assume the top-level element is an object */ 58 | if (r < 1 || token[0].type != JSMN_OBJECT) { 59 | printf("Object expected\n"); 60 | return 1; 61 | } 62 | 63 | /* Loop over all keys of the root object */ 64 | int i; 65 | for (i = 1; i < r; i++) 66 | { 67 | if (jsoneq(json_string, &token[i], "current_host") == 0) 68 | { 69 | // strndup guarantees that val is null terminated. See https://linux.die.net/man/3/strndup 70 | char *val = strndup(json_string + token[i + 1].start, token[i + 1].end - token[i + 1].start); 71 | 72 | // Copy val into name. 
If strlen(val) > strlen(name) only len characters are copied 73 | strncpy(name, val, len); 74 | 75 | // As per posix (http://man7.org/linux/man-pages/man2/gethostname.2.html), 76 | // len is the size of the buffer, so we null terminate the last 77 | // position in the buffer 78 | name[len - 1] = '\0'; 79 | 80 | free(val); 81 | free(json_string); 82 | return 0; 83 | } 84 | } 85 | 86 | free(json_string); 87 | return 1; 88 | } 89 | 90 | static PyObject *gethostname_call(PyObject *self, PyObject *args) 91 | { 92 | long unsigned command; 93 | char name[40]; 94 | 95 | if (!PyArg_ParseTuple(args, "k", &command)) 96 | { 97 | return NULL; 98 | } 99 | 100 | gethostname(name, command); 101 | 102 | return Py_BuildValue("s", name); 103 | } 104 | 105 | static PyMethodDef GetHostnameMethods[] = { 106 | { 107 | "call", 108 | gethostname_call, 109 | METH_VARARGS, 110 | }, 111 | {NULL, NULL, 0, NULL}, // sentinel 112 | }; 113 | 114 | #if PY_MAJOR_VERSION >= 3 115 | static PyModuleDef gethostnamemodule = { 116 | PyModuleDef_HEAD_INIT, 117 | "gethostname", 118 | "Returns the current host defined in resourceconfig.json", 119 | -1, 120 | GetHostnameMethods, 121 | }; 122 | 123 | PyMODINIT_FUNC PyInit_gethostname() 124 | { 125 | return PyModule_Create(&gethostnamemodule); 126 | } 127 | #else 128 | PyMODINIT_FUNC initgethostname() 129 | { 130 | PyObject *module; 131 | 132 | module = Py_InitModule3( 133 | "gethostname", GetHostnameMethods, "Returns the current host defined in resourceconfig.json"); 134 | } 135 | #endif 136 | -------------------------------------------------------------------------------- /src/sagemaker_training/c/jsmn.h: -------------------------------------------------------------------------------- 1 | // original code from https://github.com/zserge/jsmn under the MIT license 2 | // https://github.com/zserge/jsmn/blob/master/LICENSE 3 | #ifndef __JSMN_H_ 4 | #define __JSMN_H_ 5 | 6 | #include 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | /** 13 | * 
/**
 * JSON type identifier. Basic types are:
 *   o Object
 *   o Array
 *   o String
 *   o Other primitive: number, boolean (true/false) or null
 */
typedef enum {
    JSMN_UNDEFINED = 0,
    JSMN_OBJECT = 1,
    JSMN_ARRAY = 2,
    JSMN_STRING = 3,
    JSMN_PRIMITIVE = 4
} jsmntype_t;

/* Error codes returned by jsmn_parse(); all are negative so that a
 * non-negative return value always means "number of tokens parsed". */
enum jsmnerr {
    /* Not enough tokens were provided */
    JSMN_ERROR_NOMEM = -1,
    /* Invalid character inside JSON string */
    JSMN_ERROR_INVAL = -2,
    /* The string is not a full JSON packet, more bytes expected */
    JSMN_ERROR_PART = -3
};

/**
 * JSON token description.
 * type   type (object, array, string etc.)
 * start  start position in JSON data string
 * end    end position in JSON data string
 * size   number of child tokens (for objects/arrays)
 */
typedef struct {
    jsmntype_t type;
    int start;
    int end;
    int size;
#ifdef JSMN_PARENT_LINKS
    int parent;
#endif
} jsmntok_t;

/**
 * JSON parser. Contains an array of token blocks available. Also stores
 * the string being parsed now and current position in that string
 */
typedef struct {
    unsigned int pos;     /* offset in the JSON string */
    unsigned int toknext; /* next token to allocate */
    int toksuper;         /* superior token node, e.g parent object or array */
} jsmn_parser;

/**
 * Create JSON parser over an array of tokens
 */
void jsmn_init(jsmn_parser *parser);

/**
 * Run JSON parser. It parses a JSON data string into and array of tokens, each describing
 * a single JSON object.
 */
"""This module contains the entry point for the image."""
from sagemaker_training import trainer


def main():
    """Calls the function that runs training in the container.

    This is the console-script entry point baked into the training image;
    all real work (entry-point download, execution, error reporting) is
    delegated to ``trainer.train()``.
    """
    trainer.train()


if __name__ == "__main__":
    main()
"""This module contains constants representing IANA media types."""
JSON = "application/json"
CSV = "text/csv"
OCTET_STREAM = "application/octet-stream"
# Wildcard accept type: matches any content type.
ANY = "*/*"
# numpy's binary .npy serialization format.
NPY = "application/x-npy"
# Content types whose payloads are UTF-8 text and safe to decode as str.
UTF8_TYPES = [JSON, CSV]
def array_to_npy(array_like):
    """Serialize an array-like object to NPY-format bytes.

    To understand what an array-like object is, please see:
    https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

    Args:
        array_like (np.array or Iterable or int or float): Array-like object to be converted to NPY.

    Returns:
        (obj): NPY array.
    """
    sink = BytesIO()
    np.save(sink, array_like)
    return sink.getvalue()


def npy_to_numpy(npy_array):
    """Deserialize NPY-format bytes into a numpy array.

    Args:
        npy_array (npy array): NPY array to be converted.
    Returns:
        (np.array): Converted numpy array.
    """
    return np.load(BytesIO(npy_array), allow_pickle=True)
def array_to_json(array_like):
    """Convert an array-like object to JSON.

    To understand what an array-like object is, please see:
    https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

    Args:
        array_like (np.array or Iterable or int or float): Array-like object to be
            converted to JSON.

    Returns:
        (str): Object serialized to JSON.
    """

    def default(_array_like):
        # numpy arrays and scalars expose tolist(); anything else is delegated
        # to the stock encoder, which raises TypeError for unserializable data.
        if hasattr(_array_like, "tolist"):
            return _array_like.tolist()
        return json.JSONEncoder().default(_array_like)

    return json.dumps(array_like, default=default)


def json_to_numpy(string_like, dtype=None):
    """Convert a JSON object to a numpy array.

    Args:
        string_like (str): JSON string.
        dtype (dtype, optional): Data type of the resulting array. If None,
            the dtypes will be determined by the contents of each column,
            individually. This argument can only be used to 'upcast' the
            array. For downcasting, use the .astype(t) method.
    Returns:
        (np.array): Numpy array.
    """
    data = json.loads(string_like)
    return np.array(data, dtype=dtype)


def csv_to_numpy(string_like, dtype=None):
    """Convert a CSV object to a numpy array.

    Args:
        string_like (str): CSV string.
        dtype (dtype, optional): Data type of the resulting array. If None, the
            dtypes will be determined by the contents of each column,
            individually. This argument can only be used to 'upcast' the
            array. For downcasting, use the .astype(t) method.
    Returns:
        (np.array): Numpy array.

    Raises:
        errors.ClientError: If the CSV cannot be parsed, or cannot be cast to
            an explicitly requested dtype.
    """
    # Fix: parse in its own try-block. The original wrapped parsing and the
    # astype() cast in one try, and when parsing itself raised ValueError with
    # dtype=None it fell through the handler and crashed with
    # UnboundLocalError on `return array`.
    try:
        stream = StringIO(string_like)
        reader = csv.reader(stream, delimiter=",", quotechar='"', doublequote=True, strict=True)
        array = np.array([row for row in reader]).squeeze()
    except Exception as e:  # pylint: disable=broad-except
        raise errors.ClientError("Error while decoding csv: {}".format(e))

    try:
        array = array.astype(dtype)
    except ValueError as e:
        if dtype is not None:
            raise errors.ClientError(
                "Error while writing numpy array: {}. dtype is: {}".format(e, dtype)
            )
        # dtype=None means astype(None) -> float64; non-numeric data stays as
        # strings. This best-effort behavior matches the original code.
    except Exception as e:  # pylint: disable=broad-except
        raise errors.ClientError("Error while decoding csv: {}".format(e))
    return array
def array_to_csv(array_like):
    """Convert an array like object to CSV.

    To understand what an array-like object is, please see:
    https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

    Args:
        array_like (np.array or Iterable or int or float): Array-like object to be converted to CSV.

    Returns:
        (str): Object serialized to CSV.
    """
    data = np.array(array_like)
    # A flat vector is written as one column, one value per line.
    if data.ndim == 1:
        data = data.reshape((data.shape[0], 1))  # pylint: disable=unsubscriptable-object

    try:
        sink = StringIO()
        csv.writer(
            sink, lineterminator="\n", delimiter=",", quotechar='"', doublequote=True, strict=True
        ).writerows(data)
        return sink.getvalue()
    except csv.Error as e:
        raise errors.ClientError("Error while encoding csv: {}".format(e))
def array_to_recordio_protobuf(array_like, labels=None):
    """Convert an array like object to recordio-protobuf format.

    To understand what an array-like object is, please see:
    https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

    Args:
        array_like (np.array or scipy.sparse.csr_matrix): Array-like object to be
            converted to recordio-protobuf.
        labels (np.array or scipy.sparse.csr_matrix): Array-like object representing
            the labels to be encoded.

    Returns:
        buffer: Bytes buffer recordio-protobuf.

    Raises:
        ValueError: If ``array_like`` has more than two dimensions.
    """
    if len(array_like.shape) == 1:
        array_like = array_like.reshape(1, array_like.shape[0])
    # Fix: raise an explicit exception instead of `assert`, which is silently
    # stripped when Python runs with optimizations enabled (-O).
    if len(array_like.shape) != 2:
        raise ValueError("Expecting a 1 or 2 dimensional array")

    buffer = io.BytesIO()

    # Sparse and dense matrices use different recordio tensor encodings.
    if issparse(array_like):
        _write_spmatrix_to_sparse_tensor(buffer, array_like, labels)
    else:
        _write_numpy_to_dense_tensor(buffer, array_like, labels)
    buffer.seek(0)
    return buffer.getvalue()


# Registry of default encoders/decoders, keyed by IANA content type.
encoders_map = {
    content_types.NPY: array_to_npy,
    content_types.CSV: array_to_csv,
    content_types.JSON: array_to_json,
}
_decoders_map = {
    content_types.NPY: npy_to_numpy,
    content_types.CSV: csv_to_numpy,
    content_types.JSON: json_to_numpy,
}


def decode(obj, content_type):
    """Decode an object of one of the default content types to a numpy array.

    Args:
        obj (object): Object to be decoded.
        content_type (str): Content type to be used.

    Returns:
        np.array: Decoded object.

    Raises:
        errors.UnsupportedFormatError: If no decoder is registered for
            ``content_type``.
    """
    try:
        decoder = _decoders_map[content_type]
        return decoder(obj)
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)


def encode(array_like, content_type):
    """Encode an array-like object in a specific content_type to a numpy array.

    To understand what an array-like object is, please see:
    https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

    Args:
        array_like (np.array or Iterable or int or float): Array-like object to be
            converted to numpy.
        content_type (str): Content type to be used.

    Returns:
        (np.array): Object converted as numpy array.

    Raises:
        errors.UnsupportedFormatError: If no encoder is registered for
            ``content_type``.
    """
    try:
        encoder = encoders_map[content_type]
        return encoder(array_like)
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)
"""This module contains functions to install and run the user-provided training
entry point.
"""
from __future__ import absolute_import

import logging
import os
import socket
import sys


from retrying import retry

from sagemaker_training import _entry_point_type, environment, files, modules, runner

logger = logging.getLogger(__name__)


def run(
    uri,
    user_entry_point,
    args,
    env_vars=None,
    wait=True,
    capture_error=False,
    runner_type=runner.ProcessRunnerType,
    extra_opts=None,
):
    """Download, prepare and execute a compressed tar file from S3 or provided directory as a user
    entry point. Run the user entry point, passing env_vars as environment variables and args
    as command arguments.

    If the entry point is:
        - A Python package: executes the packages as >>> env_vars python -m module_name + args
        - A Python script: executes the script as >>> env_vars python module_name + args
        - Any other: executes the command as >>> env_vars /bin/sh -c ./module_name + args

    Example:
         >>>from sagemaker_training import entry_point, environment, mapping

         >>>env = environment.Environment()
         {'channel-input-dirs': {'training': '/opt/ml/input/training'},
          'model_dir': '/opt/ml/model', ...}


         >>>hyperparameters = environment.hyperparameters
         {'batch-size': 128, 'model_dir': '/opt/ml/model'}

         >>>args = mapping.to_cmd_args(hyperparameters)
         ['--batch-size', '128', '--model_dir', '/opt/ml/model']

         >>>env_vars = mapping.to_env_vars()
         ['SAGEMAKER_CHANNELS':'training', 'SAGEMAKER_CHANNEL_TRAINING':'/opt/ml/input/training',
          'MODEL_DIR':'/opt/ml/model', ...}

         >>>entry_point.run('user_script', args, env_vars)
         SAGEMAKER_CHANNELS=training SAGEMAKER_CHANNEL_TRAINING=/opt/ml/input/training \
         SAGEMAKER_MODEL_DIR=/opt/ml/model python -m user_script --batch-size 128
         --model_dir /opt/ml/model

    Args:
        uri (str): The location of the module or script. This can be an S3 uri, a path to
            a local directory, or a path to a local tarball.
        user_entry_point (str): Name of the user provided entry point.
        args ([str]):  A list of program arguments.
        env_vars (dict(str,str)): A map containing the environment variables to be written
            (default: None).
        wait (bool): If the user entry point should be run to completion before this method returns
            (default: True).
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
        runner_type (sagemaker_training.runner.RunnerType): The type of runner object to
            be created (default: sagemaker_training.runner.ProcessRunnerType).
        extra_opts (dict(str,str)): Additional options for running the entry point (default: None).
            Currently, this only applies for MPI.

    Returns:
        sagemaker_training.process.ProcessRunner: The runner object responsible for
            executing the entry point.
    """
    env_vars = env_vars or {}
    # Copy so the caller's dict is not mutated by anything downstream.
    env_vars = env_vars.copy()

    # Order matters: the entry point must be present and installed before the
    # environment variables are written and the runner is started.
    files.download_and_extract(uri=uri, path=environment.code_dir)
    install(name=user_entry_point, path=environment.code_dir, capture_error=capture_error)

    environment.write_env_vars(env_vars)

    _sm_studio_local_mode = os.environ.get("SM_STUDIO_LOCAL_MODE", "False").lower() == "true"

    # Studio Local Mode runs outside a real training cluster, so the hosts in
    # resourceconfig are not resolvable there and the DNS wait is skipped.
    if not _sm_studio_local_mode:
        _wait_hostname_resolution()
    else:
        logger.info("Bypass DNS check in case of Studio Local Mode execution.")

    return runner.get(runner_type, user_entry_point, args, env_vars, extra_opts).run(
        wait, capture_error
    )
def install(name, path=environment.code_dir, capture_error=False):
    """Install the user provided entry point to be executed as follows:
        - add the path to sys path
        - if the user entry point is a command, gives exec permissions to the script

    Args:
        name (str): Name of the script or module.
        path (str): Path to directory where the entry point will be installed.
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
    """
    if path not in sys.path:
        sys.path.insert(0, path)

    entrypoint_kind = _entry_point_type.get(path, name)

    if entrypoint_kind is _entry_point_type.PYTHON_PACKAGE:
        modules.install(path, capture_error)
    elif entrypoint_kind is _entry_point_type.PYTHON_PROGRAM and modules.has_requirements(path):
        modules.install_requirements(path, capture_error)

    # Scripts executed as commands need the executable bit set (0o777 == 511).
    if entrypoint_kind is _entry_point_type.COMMAND:
        os.chmod(os.path.join(path, name), 0o777)


@retry(stop_max_delay=1000 * 60 * 15, wait_exponential_multiplier=100, wait_exponential_max=30000)
def _dns_lookup(host):
    """Retrying DNS lookup on host."""
    return socket.gethostbyname(host)


def _wait_hostname_resolution():
    """Wait for the hostname resolution of the container. This is known behavior as the cluster
    boots up and has been documented here:
     https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-dist-training
    """
    cluster_hosts = environment.Environment().hosts
    for cluster_host in cluster_hosts:
        _dns_lookup(cluster_host)
class ClientError(Exception):
    """Error class used to separate framework and user errors."""


class SMTrainingCompilerConfigurationError(Exception):
    """Error class used to separate configuration errors"""


class _CalledProcessError(ClientError):
    """This exception is raised when a process run by check_call() or
    check_output() returns a non-zero exit status.

    Attributes:
        cmd, return_code, output, extra_info
    """

    def __init__(self, cmd, return_code=None, output=None, info=None):
        # return_code is stored as a string so it formats uniformly in __str__.
        self.return_code = str(return_code)
        self.cmd = cmd
        self.output = output
        self.extra_info = info
        super(_CalledProcessError, self).__init__()

    def __str__(self):
        # On Python 3 the captured output may be bytes; decode it so the
        # failure reason is readable in the job description.
        if six.PY3 and self.output:
            if isinstance(self.output, bytes):
                error_msg = "%s" % self.output.decode("utf-8")
            else:
                error_msg = "%s" % self.output
        elif self.output:
            error_msg = "%s" % self.output
        else:
            error_msg = ""
        # Build the message incrementally instead of duplicating two nearly
        # identical format strings (the original also carried a commented-out
        # latin1 decode line, now removed).
        message = '%s:\nExitCode %s\nErrorMessage "%s"' % (
            type(self).__name__,
            self.return_code,
            error_msg,
        )
        if self.extra_info is not None:
            message += '\nExtraInfo "%s"' % self.extra_info
        message += '\nCommand "%s"' % self.cmd
        return message.strip()


class InstallModuleError(_CalledProcessError):
    """Error class indicating a module failed to install."""


class InstallRequirementsError(_CalledProcessError):
    """Error class indicating a requirements file failed to install."""
class ImportModuleError(ClientError):
    """Error class indicating a module failed to import."""


class ExecuteUserScriptError(_CalledProcessError):
    """Error class indicating a user script failed to execute."""


class ChannelDoesNotExistError(Exception):
    """Error class indicating a channel does not exist."""

    def __init__(self, channel_name):
        super(ChannelDoesNotExistError, self).__init__(
            "Channel %s is not a valid channel" % channel_name
        )


class UnsupportedFormatError(Exception):
    """Error class indicating a content type is not supported by the current framework."""

    def __init__(self, content_type, **kwargs):
        # Fix: the user-facing message contained a duplicated word
        # ("input_fn to to deserialize").
        self.message = textwrap.dedent(
            """Content type %s is not supported by this framework.

            Please implement input_fn to deserialize the request data or an output_fn to
            serialize the response. For more information, see the SageMaker Python SDK README."""
            % content_type
        )
        super(UnsupportedFormatError, self).__init__(self.message, **kwargs)
def write_success_file():  # type: () -> None
    """Create a file 'success' when training is successful. This file doesn't need to
    have any content.
    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    """
    file_path = os.path.join(environment.output_dir, "success")
    empty_content = ""
    write_file(file_path, empty_content)


def write_failure_file(failure_msg):  # type: (str) -> None
    """Create a file 'failure' if training fails after all algorithm output (for example,
    logging) completes, the failure description should be written to this file. In a
    DescribeTrainingJob response, Amazon SageMaker returns the first 1024 characters from
    this file as FailureReason.
    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    Args:
        failure_msg: The description of failure.
    """
    file_path = os.path.join(environment.output_dir, "failure")

    # Write failure file only if it does not exist, so the first (root-cause)
    # failure reason is the one SageMaker reports.
    if not os.path.exists(file_path):
        write_file(file_path, failure_msg)
    else:
        logger.info("Failure file exists. Skipping creation....")


@contextlib.contextmanager
def tmpdir(suffix="", prefix="tmp", directory=None):  # type: (str, str, str) -> None
    """Create a temporary directory with a context manager. The file is deleted when the
    context exits.

    The prefix, suffix, and dir arguments are the same as for mkstemp().

    Args:
        suffix (str): If suffix is specified, the file name will end with that suffix,
            otherwise there will be no suffix.
        prefix (str): If prefix is specified, the file name will begin with that prefix;
            otherwise, a default prefix is used.
        directory (str): If directory is specified, the file will be created in that directory;
            otherwise, a default directory is used.
    Returns:
        str: Path to the directory.
    """
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
    # Fix: the original called rmtree after a bare yield, so any exception in
    # the with-body skipped cleanup and leaked the temporary directory.
    try:
        yield tmp
    finally:
        shutil.rmtree(tmp)


def write_file(path, data, mode="w"):  # type: (str, str, str) -> None
    """Write data to a file.

    Args:
        path (str): Path to the file.
        data (str): Data to be written to the file.
        mode (str): Mode which the file will be open.
    """
    with open(path, mode) as f:
        f.write(data)


def read_file(path, mode="r"):
    """Read data from a file.

    Args:
        path (str): Path to the file.
        mode (str): mode which the file will be open.

    Returns:
        (str): The file contents.
    """
    with open(path, mode) as f:
        return f.read()


def read_json(path):  # type: (str) -> dict
    """Read a JSON file.

    Args:
        path (str): Path to the file.

    Returns:
        (dict[object, object]): A dictionary representation of the JSON file.
    """
    with open(path, "r") as f:
        return json.load(f)
It will not download and 130 | install the if the path already has the user entry point. 131 | """ 132 | if not os.path.exists(path): 133 | os.makedirs(path) 134 | if not os.listdir(path): 135 | with tmpdir() as tmp: 136 | if uri.startswith("s3://"): 137 | dst = os.path.join(tmp, "tar_file") 138 | s3_download(uri, dst) 139 | 140 | with tarfile.open(name=dst, mode="r:gz") as t: 141 | t.extractall(path=path) 142 | 143 | elif os.path.isdir(uri): 144 | if uri == path: 145 | return 146 | if os.path.exists(path): 147 | shutil.rmtree(path) 148 | shutil.copytree(uri, path) 149 | elif tarfile.is_tarfile(uri): 150 | with tarfile.open(name=uri, mode="r:gz") as t: 151 | t.extractall(path=path) 152 | else: 153 | shutil.copy2(uri, path) 154 | 155 | 156 | def s3_download(url, dst): # type: (str, str) -> None 157 | """Download a file from S3. 158 | 159 | Args: 160 | url (str): the s3 url of the file. 161 | dst (str): the destination where the file will be saved. 162 | """ 163 | url = parse.urlparse(url) 164 | 165 | if url.scheme != "s3": 166 | raise ValueError("Expecting 's3' scheme, got: %s in %s" % (url.scheme, url)) 167 | 168 | bucket, key = url.netloc, url.path.lstrip("/") 169 | 170 | region = os.environ.get("AWS_REGION", os.environ.get(params.REGION_NAME_ENV)) 171 | endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None) 172 | s3 = boto3.resource("s3", region_name=region, endpoint_url=endpoint_url) 173 | 174 | s3.Bucket(bucket).download_file(key, dst) 175 | -------------------------------------------------------------------------------- /src/sagemaker_training/functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. 
# A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains utilities related to function arguments and
function wrappers.
"""
from __future__ import absolute_import

import collections
import inspect
import sys

# Fix: inspect.ArgSpec (like inspect.getargspec) was removed in Python 3.11,
# so the legacy result type is declared locally with the same fields that
# callers rely on (args, varargs, keywords, defaults).
ArgSpec = collections.namedtuple("ArgSpec", ["args", "varargs", "keywords", "defaults"])


def matching_args(fn, dictionary):
    """Given a function fn and a dict dictionary, returns the function
    arguments that match the dict keys.

    Example:

        def train(channel_dirs, model_dir): pass

        dictionary = {'channel_dirs': {}, 'model_dir': '/opt/ml/model', 'other_args': None}

        args = functions.matching_args(train, dictionary)  # {'channel_dirs': {},
                                                           #  'model_dir': '/opt/ml/model'}
        train(**args)

    Args:
        fn (function): A function.
        dictionary (dict): The dictionary with the keys to compare against the
            function arguments.

    Returns:
        (dict) A dictionary with only matching arguments.
    """
    # Imported lazily so this low-level utility module has no import-time
    # dependency on sagemaker_training.mapping.
    from sagemaker_training import mapping

    arg_spec = getargspec(fn)

    # A **kwargs parameter accepts every key, so no filtering is needed.
    if arg_spec.keywords:
        return dictionary

    return mapping.split_by_criteria(dictionary, arg_spec.args).included


def getargspec(fn):
    """Get the names and default values of a function's arguments.

    Args:
        fn (function): A function.

    Returns:
        `ArgSpec`: A collections.namedtuple with the following attributes:

            * args (list): A list of the argument names (it may contain nested lists).
            * varargs (str): Name of the * argument or None.
            * keywords (str): Name of the ** argument or None.
            * defaults (tuple): An n-tuple of the default values of the last n arguments.
    """
    full_arg_spec = inspect.getfullargspec(fn)
    # FullArgSpec's `varkw` maps to the legacy `keywords` field.
    return ArgSpec(
        full_arg_spec.args, full_arg_spec.varargs, full_arg_spec.varkw, full_arg_spec.defaults
    )


def error_wrapper(fn, error_class):
    """Wraps function fn in a try catch block that re-raises error_class.

    Args:
        fn (function): Function to be wrapped.
        error_class (Exception): Error class to be re-raised.

    Returns:
        (object): Function wrapped in a try catch.
    """

    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
            # Equivalent to six.reraise(error_class, error_class(e), tb):
            # re-raise as error_class while preserving the original traceback.
            raise error_class(e).with_traceback(sys.exc_info()[2])

    return wrapper


# --------------------------------------------------------------------------
# /src/sagemaker_training/intermediate_output.py:
# --------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# (The module docstring opens in the next segment:)
# """This module contains functionality related to the storage of intermediate
# training information in "opt/ml/output/intermediate".
"""This module contains functionality related to the storage of intermediate
training information in "opt/ml/output/intermediate".
"""
from __future__ import absolute_import

import concurrent.futures as futures
import multiprocessing
import os
import shutil
import time

import boto3
import boto3.s3.transfer as s3transfer
import inotify_simple
from six.moves.urllib.parse import urlparse

from sagemaker_training import environment, logging_config

logger = logging_config.get_logger()

# Directory watched for intermediate output, and the marker files whose
# appearance tells the watcher that training has finished.
intermediate_path = environment.output_intermediate_dir  # type: str
failure_file_path = os.path.join(environment.output_dir, "failure")  # type: str
success_file_path = os.path.join(environment.output_dir, "success")  # type: str
# Hidden staging area: files are copied here before being uploaded to S3.
tmp_dir_path = os.path.join(intermediate_path, ".tmp.sagemaker_s3_sync")  # type: str


def _timestamp():
    """Return a timestamp with microsecond precision."""
    moment = time.time()
    # repr() keeps the fractional digits that strftime alone would drop.
    moment_us = repr(moment).split(".")[1]
    return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_us), time.gmtime(moment))


def _upload_to_s3(s3_uploader, relative_path, file_path, filename):
    """Upload a staged file to S3 and delete the staged copy afterwards."""
    try:
        key = os.path.join(s3_uploader["key_prefix"], relative_path, filename)
        s3_uploader["transfer"].upload_file(file_path, s3_uploader["bucket"], key)
    except FileNotFoundError:  # noqa ignore=F821
        # Broken link or deleted
        pass
    except Exception:  # pylint: disable=broad-except
        logger.exception("Failed to upload file to s3.")
    finally:
        # delete the original (staged) file
        if os.path.exists(file_path):
            os.remove(file_path)


def _copy_file(executor, s3_uploader, relative_path, filename):
    """Copy a file to the temporary directory and queue its S3 upload.

    The copy gets a microsecond-precision timestamp prefix so it cannot be
    modified (or collide) while the upload is in flight.
    """
    try:
        src = os.path.join(intermediate_path, relative_path, filename)
        dst = os.path.join(tmp_dir_path, relative_path, "{}.{}".format(_timestamp(), filename))
        shutil.copy2(src, dst)
        executor.submit(_upload_to_s3, s3_uploader, relative_path, dst, filename)
    except FileNotFoundError:  # noqa ignore=F821
        # Broken link or deleted
        pass
    except Exception:  # pylint: disable=broad-except
        logger.exception("Failed to copy file to the temporary directory.")


def _watch(inotify, watchers, watch_flags, s3_uploader):
    """As soon as a user is done with a file under `/opt/ml/output/intermediate`
    we will be notified by inotify. We will copy this file under
    `/opt/ml/output/intermediate/.tmp.sagemaker_s3_sync` folder preserving
    the same folder structure to prevent it from being further modified.
    As we copy the file we will add a timestamp with microseconds precision
    to avoid modification during s3 upload.
    After that we copy the file to s3 in a separate Thread.
    We keep the queue of the files we need to move as FIFO.
    """
    # initialize a thread pool with 1 worker
    # to be used for uploading files to s3 in a separate thread
    executor = futures.ThreadPoolExecutor(max_workers=1)

    last_pass_done = False
    stop_file_exists = False

    # after we see a stop file, do one additional pass to make sure we didn't miss anything
    while not last_pass_done:  # pylint: disable=too-many-nested-blocks
        # wait for any events in the directory for 1 sec and then re-check exit conditions
        for event in inotify.read(timeout=1000):
            for flag in inotify_simple.flags.from_mask(event.mask):
                # If a new directory was created, traverse the directory tree to
                # recursively add all created folders to the watchers list, and
                # upload files to s3 if there are any.
                # There is a potential race condition: we may upload a file and
                # then see a notification for it. That shouldn't cause problems
                # because copies in the temp dir get a unique timestamp with
                # microsecond precision.
                if flag is inotify_simple.flags.ISDIR and inotify_simple.flags.CREATE & event.mask:
                    path = os.path.join(intermediate_path, watchers[event.wd], event.name)
                    for folder, _, files in os.walk(path):
                        wd = inotify.add_watch(folder, watch_flags)
                        relative_path = os.path.relpath(folder, intermediate_path)
                        watchers[wd] = relative_path
                        tmp_sub_folder = os.path.join(tmp_dir_path, relative_path)
                        if not os.path.exists(tmp_sub_folder):
                            os.makedirs(tmp_sub_folder)
                        for file in files:
                            _copy_file(executor, s3_uploader, relative_path, file)
                elif flag is inotify_simple.flags.CLOSE_WRITE:
                    _copy_file(executor, s3_uploader, watchers[event.wd], event.name)

        last_pass_done = stop_file_exists
        stop_file_exists = os.path.exists(success_file_path) or os.path.exists(failure_file_path)

    # wait for all the s3 upload tasks to finish and shutdown the executor
    executor.shutdown(wait=True)


def start_sync(
    s3_output_location, region, endpoint_url=None
):  # pylint: disable=inconsistent-return-statements
    """Start intermediate folder sync, which copies files from 'opt/ml/output/intermediate'
    directory to the provided s3 output location as files are created or modified.
    If files are deleted, it doesn't delete them from s3.

    It starts intermediate folder sync as a daemonic process only if the directory
    doesn't exist yet. If the directory does exist, it indicates that the platform is
    taking care of syncing files to S3 and the container should not interfere.

    Args:
        s3_output_location (str): The S3 location (s3:// URI) to which intermediate
            output is synced.
        region (str): The AWS region of the S3 bucket.
        endpoint_url (str): An alternative endpoint URL to connect to.

    Returns:
        (multiprocessing.Process): The intermediate output sync daemonic process,
            or None when syncing is disabled or handled by the platform.
    """
    if not s3_output_location or os.path.exists(intermediate_path):
        logger.debug("Could not initialize intermediate folder sync to s3.")
        return None

    # create intermediate and intermediate_tmp directories
    os.makedirs(intermediate_path)
    os.makedirs(tmp_dir_path)

    # configure unique s3 output location similar to how the SageMaker platform does it
    # or link it to the local output directory
    url = urlparse(s3_output_location)
    if url.scheme == "file":
        logger.debug("Local directory is used for output. No need to sync any intermediate output.")
        return None
    elif url.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: %s in %s" % (url.scheme, url))

    # create s3 transfer client
    client = boto3.client("s3", region, endpoint_url=endpoint_url)
    s3_transfer = s3transfer.S3Transfer(client)
    s3_uploader = {
        "transfer": s3_transfer,
        "bucket": url.netloc,
        "key_prefix": os.path.join(
            url.path.lstrip("/"), os.environ.get("TRAINING_JOB_NAME", ""), "output", "intermediate"
        ),
    }

    # Add intermediate folder to the watch list
    inotify = inotify_simple.INotify()
    watch_flags = inotify_simple.flags.CLOSE_WRITE | inotify_simple.flags.CREATE
    watchers = {}
    wd = inotify.add_watch(intermediate_path, watch_flags)
    watchers[wd] = ""
    # start subprocess to sync any files from intermediate folder to s3
    p = multiprocessing.Process(target=_watch, args=[inotify, watchers, watch_flags, s3_uploader])
    # Make the process daemonic as a safety switch to prevent the training job from
    # hanging forever in case something goes wrong and the main container process
    # exits in an unexpected way
    p.daemon = True
    p.start()
    return p


# --------------------------------------------------------------------------
# /src/sagemaker_training/logging_config.py:
# --------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains utilities related to logging."""
from __future__ import absolute_import

import json
import logging

import sagemaker_training


def get_logger():
    """Return the shared 'sagemaker-training-toolkit' logger, creating it
    if necessary.
    """
    return logging.getLogger("sagemaker-training-toolkit")


def configure_logger(level, log_format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s"):
    # type: (int, str) -> None
    """Set logger configuration.

    Args:
        level (int): Logger level.
        log_format (str): Logger format.
    """
    logging.basicConfig(format=log_format, level=level)

    if level >= logging.INFO:
        # Align the chatty AWS SDK loggers with the requested (non-debug)
        # verbosity so they do not flood the training log.
        for sdk_logger_name in ("boto3", "s3transfer", "botocore"):
            logging.getLogger(sdk_logger_name).setLevel(level)


# log_script_invocation is defined in the next segment of this file.
def log_script_invocation(cmd, env_vars, logger=None):
    """Log a message with level INFO including information on the user script invoked.

    Args:
        cmd (str): Command used to invoke the script.
        env_vars (dict): Environment variables.
        logger (logging.Logger): Logger used to log the message.
    """
    logger = logger or get_logger()

    prefix = "\n".join(["%s=%s" % (key, value) for key, value in env_vars.items()])
    env = sagemaker_training.environment.Environment()
    message = """Invoking user script

Training Env:

%s

Environment variables:

%s

Invoking script with the following command:

%s

""" % (
        json.dumps(dict(env), indent=4),
        prefix,
        " ".join(cmd),
    )
    logger.info(message)


# --------------------------------------------------------------------------
# /src/sagemaker_training/mapping.py:
# --------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the 'License').
#
# This module contains utilities related to dictionaries. These include
# transforming dictionaries into environment variables, transforming into
# command arguments, splitting dictionaries by key, and functionality from
# the collections.abc.Mapping abstract base class.
#
# NOTE: the vestigial Python 2 compatibility layer (`import six` and the
# six.PY3 branches) was removed; this module already requires Python 3
# (it imports collections.abc), so behavior is unchanged.

import collections
import collections.abc
import itertools
import json

# Result type of split_by_criteria: the matching and non-matching halves.
SplitResultSpec = collections.namedtuple("SplitResultSpec", "included excluded")


def to_env_vars(mapping):  # type: (dict) -> dict
    """Transform a dictionary in a dictionary of env vars.

    Example:
        >>>env_vars = mapping.to_env_vars({'model_dir': '/opt/ml/model', 'batch_size': 25})
        >>>
        >>>print(env_vars)
        {'SM_MODEL_DIR': '/opt/ml/model', 'SM_BATCH_SIZE': '25'}

    Args:
        mapping (dict[str, object]): A Python mapping.

    Returns:
        (dict): Dictionary of env vars.
    """

    def format_key(key):
        """Add a SM_ prefix to the key and upper case it."""
        if key:
            decoded_name = "SM_%s" % str(key).upper()
            return decoded_name
        return ""

    def format_value(_mapping):
        if isinstance(_mapping, bytes):
            # transforms a byte string (b'') in unicode
            return _mapping.decode("latin1")
        if _mapping is None:
            return ""
        if isinstance(_mapping, str):
            return str(_mapping)
        # anything else is serialized as compact, deterministic JSON
        return json.dumps(_mapping, sort_keys=True, separators=(",", ":"), ensure_ascii=True)

    return {format_key(k): format_value(v) for k, v in mapping.items()}


def to_cmd_args(mapping):  # type: (dict) -> list
    """Transform a dictionary in a list of cmd arguments.

    Example:
        >>>args = mapping.to_cmd_args({'model_dir': '/opt/ml/model', 'batch_size': 25})
        >>>
        >>>print(args)
        ['--batch_size', '25', '--model_dir', '/opt/ml/model']

    Args:
        mapping (dict[str, object]): A Python mapping.

    Returns:
        (list): List of cmd arguments.
    """

    sorted_keys = sorted(mapping.keys())

    def arg_name(obj):
        string = _decode(obj)
        if string:
            # single-character names get a single dash (e.g. -v), longer names --
            return "--%s" % string if len(string) > 1 else "-%s" % string
        return ""

    arg_names = [arg_name(argument) for argument in sorted_keys]

    def arg_value(value):
        if hasattr(value, "items"):
            # nested mappings become a deterministic "k=v,k=v" list
            map_items = ["%s=%s" % (k, v) for k, v in sorted(value.items())]
            return ",".join(map_items)
        return _decode(value)

    arg_values = [arg_value(mapping[key]) for key in sorted_keys]

    items = zip(arg_names, arg_values)

    return [item for item in itertools.chain.from_iterable(items)]


def _decode(obj):  # type: (object) -> str
    """Decode an object to a (unicode) string.

    Args:
        obj (bytes or str or anything serializable): Object to be decoded.

    Returns:
        Object decoded as a string ('' for None, latin1 for bytes).
    """
    if obj is None:
        return ""
    if isinstance(obj, bytes):
        # transforms a byte string (b'') in unicode
        return obj.decode("latin1")
    return str(obj)


def split_by_criteria(
    dictionary, keys=None, prefix=None
):  # type: (dict, set or list or tuple, str) -> SplitResultSpec
    """Split a dictionary in two by the provided keys.

    Args:
        dictionary (dict[str, object]): A Python dictionary.
        keys (sequence [str]): A sequence of keys which will be added to the split criteria.
        prefix (str): A prefix which will be added to the split criteria.

    Returns:
        `SplitResultSpec` : A collections.namedtuple with the following attributes:

            * included (dict[str, object]): A dictionary with the keys included in
              the criteria.
            * excluded (dict[str, object]): A dictionary with the keys not included
              in the criteria.
    """
    keys = keys or []
    keys = set(keys)

    included_items = {
        k: dictionary[k]
        for k in dictionary.keys()
        if k in keys or (prefix and k.startswith(prefix))
    }
    excluded_items = {k: dictionary[k] for k in dictionary.keys() if k not in included_items}

    return SplitResultSpec(included=included_items, excluded=excluded_items)


class MappingMixin(collections.abc.Mapping):
    """A mixin class that allows for the creation of a dictionary like object,
    with any built-in function that works with a dictionary. This is used by the
    environment._Env base class.
    """

    def properties(self):  # type: () -> list
        """
        Returns:
            (list[str]) List of public properties.
        """

        _type = type(self)
        return [_property for _property in dir(_type) if self._is_property(_property)]

    def _is_property(self, _property):
        return isinstance(getattr(type(self), _property), property)

    def __getitem__(self, k):
        """Built-in method override."""
        if not self._is_property(k):
            raise KeyError("Trying to access non property %s" % k)
        return getattr(self, k)

    def __len__(self):
        """Built-in method override."""
        return len(self.properties())

    def __iter__(self):
        """Built-in method override."""
        items = {_property: getattr(self, _property) for _property in self.properties()}
        return iter(items)

    def __str__(self):
        """Built-in method override."""
        return str(dict(self))


# --------------------------------------------------------------------------
# /src/sagemaker_training/modules.py:
# --------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the 'License').
# (The module docstring opens in the next segment:)
# """This module contains functionality related to preparing, installing,
# and importing Python modules.
"""This module contains functionality related to preparing, installing,
and importing Python modules.
"""
from __future__ import absolute_import

import importlib
import os
import re
import shlex
import sys
import textwrap

import boto3

from sagemaker_training import environment, errors, files, logging_config, process

logger = logging_config.get_logger()

DEFAULT_MODULE_NAME = "default_user_module_name"
# Env var holding the ARN of an optional AWS CodeArtifact pip repository.
CA_REPOSITORY_ARN_ENV = "CA_REPOSITORY_ARN"


def exists(name):  # type: (str) -> bool
    """Return True if the module exists. Return False otherwise.

    Args:
        name (str): Module name.

    Returns:
        (bool): Boolean indicating if the module exists or not.
    """
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    else:
        return True


def has_requirements(path):  # type: (str) -> bool
    """Check whether a directory contains a requirements.txt file.

    Args:
        path (str): Path to the directory to check for the requirements.txt file.

    Returns:
        (bool): Whether the directory contains a requirements.txt file.
    """
    return os.path.exists(os.path.join(path, "requirements.txt"))


def prepare(path, name):  # type: (str, str) -> None
    """Prepare a Python script (or module) to be imported as a module.

    If the script does not contain a setup.py file, it creates a minimal setup
    (setup.py, setup.cfg and MANIFEST.in) so that pip can install it.

    Args:
        path (str): Path to directory with the script or module.
        name (str): Name of the script or module.
    """
    setup_path = os.path.join(path, "setup.py")
    if not os.path.exists(setup_path):
        data = textwrap.dedent(
            """
            from setuptools import setup
            setup(packages=[''],
                  name="%s",
                  version='1.0.0',
                  include_package_data=True)
            """
            % name
        )

        logger.info("Module %s does not provide a setup.py. \nGenerating setup.py", name)

        files.write_file(setup_path, data)

        data = textwrap.dedent(
            """
            [wheel]
            universal = 1
            """
        )

        logger.info("Generating setup.cfg")

        files.write_file(os.path.join(path, "setup.cfg"), data)

        data = textwrap.dedent(
            """
            recursive-include . *
            recursive-exclude . __pycache__*
            recursive-exclude . *.pyc
            recursive-exclude . *.pyo
            """
        )

        logger.info("Generating MANIFEST.in")

        files.write_file(os.path.join(path, "MANIFEST.in"), data)


def install(path, capture_error=False):  # type: (str, bool) -> None
    """Install a Python module in the executing Python environment.

    Args:
        path (str): Real path location of the Python module.
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
    """
    cmd = "%s -m pip install . " % process.python_executable()

    if has_requirements(path):
        cmd += "-r requirements.txt"
        if os.getenv(CA_REPOSITORY_ARN_ENV):
            index = _get_codeartifact_index()
            cmd += " -i {}".format(index)

    logger.info("Installing module.")

    process.check_error(
        shlex.split(cmd), errors.InstallModuleError, 1, cwd=path, capture_error=capture_error
    )


def install_requirements(path, capture_error=False):  # type: (str, bool) -> None
    """Install dependencies from requirements.txt in the executing Python environment.

    Args:
        path (str): Real path location of the requirements.txt file.
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
    """
    cmd = "{} -m pip install -r requirements.txt".format(process.python_executable())
    if os.getenv(CA_REPOSITORY_ARN_ENV):
        index = _get_codeartifact_index()
        cmd += " -i {}".format(index)

    logger.info("Installing dependencies from requirements.txt")

    process.check_error(
        shlex.split(cmd), errors.InstallRequirementsError, 1, cwd=path, capture_error=capture_error
    )


def import_module(uri, name=DEFAULT_MODULE_NAME):  # type: (str, str) -> module
    """Download, prepare and install a compressed tar file from S3 or provided directory
    as a module.

    SageMaker Python SDK saves the user provided scripts as compressed tar files in S3
    https://github.com/aws/sagemaker-python-sdk.
    This function downloads this compressed file (if provided), transforms it as a module,
    and installs it.

    Args:
        uri (str): The location of the module.
        name (str): Name of the script or module.

    Returns:
        (module): The imported module.

    Raises:
        errors.ImportModuleError: If the module cannot be imported after installation.
    """
    files.download_and_extract(uri, environment.code_dir)

    prepare(environment.code_dir, name)
    install(environment.code_dir)
    try:
        module = importlib.import_module(name)
        # Reload to pick up the freshly installed code if the module
        # was already imported (replaces six.moves.reload_module).
        importlib.reload(module)

        return module
    except Exception as e:  # pylint: disable=broad-except
        # Re-raise as ImportModuleError while preserving the original traceback
        # (replaces six.reraise).
        raise errors.ImportModuleError(e).with_traceback(sys.exc_info()[2])


def _get_codeartifact_index():
    """
    Build the authenticated codeartifact index url based on the arn provided
    via CA_REPOSITORY_ARN environment variable following the form
    `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${DomainName}/${RepositoryName}`
    https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
    https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies

    :return: authenticated codeartifact index url
    """
    repository_arn = os.getenv(CA_REPOSITORY_ARN_ENV)
    arn_regex = (
        "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
        ":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
    )
    m = re.match(arn_regex, repository_arn)
    if not m:
        raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn))
    domain = m.group("domain")
    owner = m.group("account")
    repository = m.group("repository")
    region = m.group("region")

    logger.info(
        "configuring pip to use codeartifact "
        "(domain: %s, domain owner: %s, repository: %s, region: %s)",
        domain,
        owner,
        repository,
        region,
    )
    try:
        client = boto3.client("codeartifact", region_name=region)
        auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
        token = auth_token_response["authorizationToken"]
        endpoint_response = client.get_repository_endpoint(
            domain=domain, domainOwner=owner, repository=repository, format="pypi"
        )
        unauthenticated_index = endpoint_response["repositoryEndpoint"]
        # Embed the auth token into the endpoint URL and point pip at the
        # repository's /simple/ index.
        return re.sub(
            "https://",
            "https://aws:{}@".format(token),
            re.sub(
                "{}/?$".format(repository),
                "{}/simple/".format(repository),
                unauthenticated_index,
            ),
        )
    except Exception as e:  # pylint: disable=broad-except
        logger.error("failed to configure pip to use codeartifact")
        # Chain the original error so the root cause is not lost.
        raise Exception("failed to configure pip to use codeartifact") from e


# --------------------------------------------------------------------------
# /src/sagemaker_training/params.py:
# --------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# (The module docstring opens in the next segment:)
# """This module contains string constants representing environment variables
# and related parameters.
"""This module contains string constants representing environment variables
and related parameters.
"""
from __future__ import absolute_import

SAGEMAKER_PREFIX = "sagemaker_"  # type: str
CURRENT_HOST_ENV = "CURRENT_HOST"  # type: str
USER_PROGRAM_PARAM = "sagemaker_program"  # type: str
USER_PROGRAM_ENV = USER_PROGRAM_PARAM.upper()  # type: str
S3_OUTPUT_LOCATION_PARAM = "sagemaker_s3_output"  # type: str
S3_OUTPUT_LOCATION_ENV = S3_OUTPUT_LOCATION_PARAM.upper()  # type: str
S3_ENDPOINT_URL = "S3_ENDPOINT_URL"  # type: str
TRAINING_JOB_ENV = "TRAINING_JOB_NAME"  # type: str
SUBMIT_DIR_PARAM = "sagemaker_submit_directory"  # type: str
SUBMIT_DIR_ENV = SUBMIT_DIR_PARAM.upper()  # type: str
ENABLE_METRICS_PARAM = "sagemaker_enable_cloudwatch_metrics"  # type: str
ENABLE_METRICS_ENV = ENABLE_METRICS_PARAM.upper()  # type: str
LOG_LEVEL_PARAM = "sagemaker_container_log_level"  # type: str
LOG_LEVEL_ENV = LOG_LEVEL_PARAM.upper()  # type: str
JOB_NAME_PARAM = "sagemaker_job_name"  # type: str
JOB_NAME_ENV = JOB_NAME_PARAM.upper()  # type: str
TUNING_METRIC_PARAM = "_tuning_objective_metric"  # type: str
DEFAULT_MODULE_NAME_PARAM = "default_user_module_name"  # type: str
MPI_ENABLED = "sagemaker_mpi_enabled"  # type: str
PARAMETER_SERVER_ENABLED = "sagemaker_parameter_server_enabled"  # type: str
MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
    "sagemaker_multi_worker_mirrored_strategy_enabled"
)  # type: str
PYTORCH_XLA_MULTI_WORKER_ENABLED = "sagemaker_pytorch_xla_multi_worker_enabled"  # type: str
REGION_NAME_PARAM = "sagemaker_region"  # type: str
REGION_NAME_ENV = REGION_NAME_PARAM.upper()  # type: str
DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"  # type: str
SAGEMAKER_BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT"  # type: str
SAGEMAKER_SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE"  # type: str
FRAMEWORK_TRAINING_MODULE_ENV = "SAGEMAKER_TRAINING_MODULE"  # type: str
# Hyperparameter names that are reserved for the SageMaker platform itself.
SAGEMAKER_HYPERPARAMETERS = (
    USER_PROGRAM_PARAM,
    SUBMIT_DIR_PARAM,
    ENABLE_METRICS_PARAM,
    REGION_NAME_PARAM,
    LOG_LEVEL_PARAM,
    JOB_NAME_PARAM,
    DEFAULT_MODULE_NAME_PARAM,
    TUNING_METRIC_PARAM,
    S3_OUTPUT_LOCATION_PARAM,
)  # type: tuple
# Fix: the following three constants hold hyperparameter *names* (strings),
# not their values; the previous "# type: int" / "# type: list" comments
# were wrong.
MPI_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"  # type: str
MPI_NUM_PROCESSES = "sagemaker_mpi_num_processes"  # type: str
MPI_CUSTOM_OPTIONS = "sagemaker_mpi_custom_mpi_options"  # type: str
SAGEMAKER_NETWORK_INTERFACE_NAME = "sagemaker_network_interface_name"  # type: str
SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = (
    "sagemaker_distributed_dataparallel_custom_mpi_options"
)  # type: str
SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"  # type: str
DISTRIBUTION_INSTANCE_GROUPS = "sagemaker_distribution_instance_groups"  # type: str


# --------------------------------------------------------------------------
# /src/sagemaker_training/pytorch_xla.py:
# --------------------------------------------------------------------------
# Copyright 2018-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
13 | """This module contains functionality related to distributed training using 14 | PT-XLA (PyTorch - Accelerated Linear Algebra).""" 15 | from __future__ import absolute_import 16 | 17 | import os 18 | 19 | from sagemaker_training import ( 20 | _entry_point_type, 21 | environment, 22 | errors, 23 | logging_config, 24 | process, 25 | ) 26 | 27 | 28 | logger = logging_config.get_logger() 29 | 30 | 31 | class PyTorchXLARunner(process.ProcessRunner): 32 | """Responsible for PT-XLA distributed training.""" 33 | 34 | MESH_SERVICE_PORT = 53957 35 | WORKER_PORT = 43857 36 | 37 | def __init__( 38 | self, 39 | user_entry_point, 40 | args, 41 | env_vars, 42 | processes_per_host, 43 | master_hostname, 44 | current_host, 45 | hosts, 46 | num_gpus, 47 | ): 48 | """Initialize a PyTorchXLARunner, which is responsible for distributed 49 | training with PT-XLA. 50 | 51 | Args: 52 | user_entry_point (str): The name of the user entry point. 53 | args ([str]): A list of arguments to include when executing the entry point. 54 | env_vars (dict(str,str)): A dictionary of environment variables. 55 | master_hostname (str): The master hostname. 56 | current_host (str): The current hostname. 57 | hosts ([str]): A list of hosts. 58 | num_gpus (int): The number of GPUs available per host. 
59 | """ 60 | 61 | super(PyTorchXLARunner, self).__init__(user_entry_point, args, env_vars, processes_per_host) 62 | 63 | self._master_hostname = master_hostname 64 | self._current_host = current_host 65 | self._hosts = hosts 66 | self._num_gpus = num_gpus 67 | 68 | self._num_hosts = len(self._hosts) 69 | self._rank = self._hosts.index(self._current_host) 70 | 71 | def _setup(self): # type: () -> None 72 | logger.info("Starting distributed training through PT-XLA Runtime.") 73 | self._check_compatibility() 74 | 75 | # Set NCCL logging to info to debug customer issues 76 | os.environ["NCCL_DEBUG"] = "info" 77 | 78 | # Use `simple` protocol to handle the out-of-order data delivery from EFA 79 | os.environ["NCCL_PROTO"] = "simple" 80 | 81 | # Use GPU RDMA when available (available only in p4d.24xlarge) 82 | os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" 83 | 84 | # Use multiple connections per GPU to better saturate the EFA bandwidth 85 | os.environ["OFI_NCCL_NIC_DUP_CONNS"] = str(self._num_gpus) 86 | 87 | # Set cluster configuration for XLA runtime 88 | os.environ["XRT_HOST_ORDINAL"] = str(self._rank) 89 | os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts) 90 | address = "localservice:{};{}:" + str(self.WORKER_PORT) 91 | os.environ["XRT_WORKERS"] = "|".join( 92 | [address.format(i, host) for i, host in enumerate(self._hosts)] 93 | ) 94 | os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus) 95 | if self._num_hosts > 1: 96 | os.environ[ 97 | "XRT_MESH_SERVICE_ADDRESS" 98 | ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}" 99 | 100 | logger.info("Completed environment setup for distributed training through PT-XLA Runtime.") 101 | 102 | def _create_command(self): 103 | entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point) 104 | 105 | if entrypoint_type is _entry_point_type.PYTHON_PACKAGE: 106 | raise errors.SMTrainingCompilerConfigurationError( 107 | "Distributed Training through PT-XLA is not supported for Python packages. 
" 108 | "Please use a python script as the entry-point" 109 | ) 110 | if entrypoint_type is _entry_point_type.PYTHON_PROGRAM: 111 | return self._pytorch_xla_command() + [self._user_entry_point] + self._args 112 | else: 113 | raise errors.SMTrainingCompilerConfigurationError( 114 | "Distributed Training through PT-XLA is only supported for Python scripts. " 115 | "Please use a python script as the entry-point" 116 | ) 117 | 118 | def _pytorch_xla_command(self): 119 | return self._python_command() + [ 120 | "-m", 121 | "torch_xla.distributed.xla_spawn", 122 | "--num_gpus", 123 | str(self._num_gpus), 124 | ] 125 | 126 | def _check_compatibility(self): 127 | self._check_processor_compatibility() 128 | self._check_for_torch_xla() 129 | self._check_for_sagemaker_integration() 130 | 131 | def _check_for_sagemaker_integration(self): 132 | # pylint: disable=no-self-use 133 | try: 134 | import torch_xla.distributed.xla_spawn # pylint: disable=unused-import # noqa: F401 135 | except ModuleNotFoundError as exception: 136 | raise errors.SMTrainingCompilerConfigurationError( 137 | "Unable to find SageMaker integration code in PT-XLA. " 138 | "AWS SageMaker adds custom code on top of open source " 139 | "PT-XLA to provide platform specific " 140 | "optimizations. These SageMaker specific binaries are" 141 | " shipped as part of our Deep Learning Containers." 142 | " Please refer to " 143 | "https://github.com/aws/deep-learning-containers" 144 | "/blob/master/available_images.md" 145 | ) from exception 146 | 147 | def _check_for_torch_xla(self): 148 | # pylint: disable=no-self-use 149 | try: 150 | import torch_xla # pylint: disable=unused-import # noqa: F401 151 | except ModuleNotFoundError as exception: 152 | raise errors.SMTrainingCompilerConfigurationError( 153 | "Unable to find PT-XLA in the execution environment. " 154 | "This distribution mechanism requires PT-XLA to be available" 155 | " in the execution environment. 
" 156 | "SageMaker Training Compiler provides ready-to-use containers with PT-XLA. " 157 | "Please refer to https://github.com/aws/deep-learning-containers" 158 | "/blob/master/available_images.md " 159 | ) from exception 160 | 161 | def _check_processor_compatibility(self): 162 | if not self._num_gpus > 0: 163 | raise errors.SMTrainingCompilerConfigurationError( 164 | "Distributed training through PT-XLA is only supported for GPUs." 165 | ) 166 | -------------------------------------------------------------------------------- /src/sagemaker_training/record_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: record.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 16 | b'\n\x0crecord.proto\x12\x0b\x61ialgs.data"H\n\rFloat32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x02\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01"H\n\rFloat64Tensor\x12\x12\n\x06values\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01"F\n\x0bInt32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x05\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01",\n\x05\x42ytes\x12\r\n\x05value\x18\x01 \x03(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x02 \x01(\t"\xd3\x01\n\x05Value\x12\x34\n\x0e\x66loat32_tensor\x18\x02 
\x01(\x0b\x32\x1a.aialgs.data.Float32TensorH\x00\x12\x34\n\x0e\x66loat64_tensor\x18\x03 \x01(\x0b\x32\x1a.aialgs.data.Float64TensorH\x00\x12\x30\n\x0cint32_tensor\x18\x07 \x01(\x0b\x32\x18.aialgs.data.Int32TensorH\x00\x12#\n\x05\x62ytes\x18\t \x01(\x0b\x32\x12.aialgs.data.BytesH\x00\x42\x07\n\x05value"\xa9\x02\n\x06Record\x12\x33\n\x08\x66\x65\x61tures\x18\x01 \x03(\x0b\x32!.aialgs.data.Record.FeaturesEntry\x12-\n\x05label\x18\x02 \x03(\x0b\x32\x1e.aialgs.data.Record.LabelEntry\x12\x0b\n\x03uid\x18\x03 \x01(\t\x12\x10\n\x08metadata\x18\x04 \x01(\t\x12\x15\n\rconfiguration\x18\x05 \x01(\t\x1a\x43\n\rFeaturesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x1a@\n\nLabelEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x42\x30\n com.amazonaws.aialgorithms.protoB\x0cRecordProtos' 17 | ) 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "record_pb2", _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | 24 | DESCRIPTOR._options = None 25 | DESCRIPTOR._serialized_options = b"\n com.amazonaws.aialgorithms.protoB\014RecordProtos" 26 | _FLOAT32TENSOR.fields_by_name["values"]._options = None 27 | _FLOAT32TENSOR.fields_by_name["values"]._serialized_options = b"\020\001" 28 | _FLOAT32TENSOR.fields_by_name["keys"]._options = None 29 | _FLOAT32TENSOR.fields_by_name["keys"]._serialized_options = b"\020\001" 30 | _FLOAT32TENSOR.fields_by_name["shape"]._options = None 31 | _FLOAT32TENSOR.fields_by_name["shape"]._serialized_options = b"\020\001" 32 | _FLOAT64TENSOR.fields_by_name["values"]._options = None 33 | _FLOAT64TENSOR.fields_by_name["values"]._serialized_options = b"\020\001" 34 | _FLOAT64TENSOR.fields_by_name["keys"]._options = None 35 | _FLOAT64TENSOR.fields_by_name["keys"]._serialized_options = b"\020\001" 36 | 
_FLOAT64TENSOR.fields_by_name["shape"]._options = None 37 | _FLOAT64TENSOR.fields_by_name["shape"]._serialized_options = b"\020\001" 38 | _INT32TENSOR.fields_by_name["values"]._options = None 39 | _INT32TENSOR.fields_by_name["values"]._serialized_options = b"\020\001" 40 | _INT32TENSOR.fields_by_name["keys"]._options = None 41 | _INT32TENSOR.fields_by_name["keys"]._serialized_options = b"\020\001" 42 | _INT32TENSOR.fields_by_name["shape"]._options = None 43 | _INT32TENSOR.fields_by_name["shape"]._serialized_options = b"\020\001" 44 | _RECORD_FEATURESENTRY._options = None 45 | _RECORD_FEATURESENTRY._serialized_options = b"8\001" 46 | _RECORD_LABELENTRY._options = None 47 | _RECORD_LABELENTRY._serialized_options = b"8\001" 48 | _globals["_FLOAT32TENSOR"]._serialized_start = 29 49 | _globals["_FLOAT32TENSOR"]._serialized_end = 101 50 | _globals["_FLOAT64TENSOR"]._serialized_start = 103 51 | _globals["_FLOAT64TENSOR"]._serialized_end = 175 52 | _globals["_INT32TENSOR"]._serialized_start = 177 53 | _globals["_INT32TENSOR"]._serialized_end = 247 54 | _globals["_BYTES"]._serialized_start = 249 55 | _globals["_BYTES"]._serialized_end = 293 56 | _globals["_VALUE"]._serialized_start = 296 57 | _globals["_VALUE"]._serialized_end = 507 58 | _globals["_RECORD"]._serialized_start = 510 59 | _globals["_RECORD"]._serialized_end = 807 60 | _globals["_RECORD_FEATURESENTRY"]._serialized_start = 674 61 | _globals["_RECORD_FEATURESENTRY"]._serialized_end = 741 62 | _globals["_RECORD_LABELENTRY"]._serialized_start = 743 63 | _globals["_RECORD_LABELENTRY"]._serialized_end = 807 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /src/sagemaker_training/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). 
You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This module contains functionality to get process runners based on the 14 | runner type. 15 | """ 16 | from __future__ import absolute_import 17 | 18 | import enum 19 | 20 | from sagemaker_training import ( 21 | environment, 22 | mpi, 23 | params, 24 | process, 25 | pytorch_xla, 26 | smdataparallel, 27 | torch_distributed, 28 | ) 29 | 30 | 31 | class RunnerType(enum.Enum): 32 | """Enumerated type consisting of valid types of runners.""" 33 | 34 | MPI = "MPI" 35 | Process = "Process" 36 | SMDataParallel = "SMDataParallel" 37 | PyTorchXLA = "PyTorchXLA" 38 | TorchDistributed = "TorchDistributed" 39 | 40 | 41 | ProcessRunnerType = RunnerType.Process 42 | MPIRunnerType = RunnerType.MPI 43 | SMDataParallelRunnerType = RunnerType.SMDataParallel 44 | PyTorchXLARunnerType = RunnerType.PyTorchXLA 45 | TorchDistributedRunnerType = RunnerType.TorchDistributed 46 | 47 | 48 | def get(identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None): 49 | """Get the process runner based on the runner type. 50 | 51 | Args: 52 | identifier (RunnerType or process.ProcessRunner): The type of runner to get. 53 | user_entry_point (str): The name of the user entry point. 54 | args ([str]): A list of arguments to include when executing the entry point. 55 | env_vars (dict(str,str)): A dictionary of environment variables. 56 | extra_opts (dict): A dictionary of extra arguments for MPI. 57 | 58 | Returns: 59 | process.Runner: The process. 
60 | """ 61 | if isinstance(identifier, process.ProcessRunner): 62 | return identifier 63 | else: 64 | return _get_by_runner_type(identifier, user_entry_point, args, env_vars, extra_opts) 65 | 66 | 67 | def _get_by_runner_type( 68 | identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None 69 | ): 70 | env = environment.Environment() 71 | user_entry_point = user_entry_point or env.user_entry_point 72 | args = args or env.to_cmd_args() 73 | env_vars = env_vars or env.to_env_vars() 74 | mpi_args = extra_opts or {} 75 | 76 | # Default to single process for CPU 77 | default_processes_per_host = ( 78 | int(env.num_gpus) 79 | if int(env.num_gpus) > 0 80 | else int(env.num_neurons) 81 | if int(env.num_neurons) > 0 82 | else 1 83 | ) 84 | 85 | processes_per_host = _mpi_param_value( 86 | mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host 87 | ) 88 | 89 | if identifier is RunnerType.SMDataParallel and env.is_master: 90 | custom_mpi_options = _mpi_param_value( 91 | mpi_args, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, "" 92 | ) 93 | return smdataparallel.SMDataParallelRunner( 94 | user_entry_point, 95 | args, 96 | env_vars, 97 | processes_per_host, 98 | env.master_hostname, 99 | env.distribution_hosts, 100 | custom_mpi_options, 101 | env.network_interface_name, 102 | ) 103 | elif identifier is RunnerType.SMDataParallel: 104 | return mpi.WorkerRunner( 105 | user_entry_point, 106 | args, 107 | env_vars, 108 | processes_per_host, 109 | env.master_hostname, 110 | env.current_host, 111 | ) 112 | elif identifier is RunnerType.TorchDistributed: 113 | return torch_distributed.TorchDistributedRunner( 114 | user_entry_point, 115 | args, 116 | env_vars, 117 | processes_per_host, 118 | env.master_hostname, 119 | env.distribution_hosts, 120 | env.current_host, 121 | env.network_interface_name, 122 | instance_type=env.current_instance_type, 123 | ) 124 | elif identifier is RunnerType.MPI and env.is_master: 125 | num_processes = 
_mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES) 126 | custom_mpi_options = _mpi_param_value(mpi_args, env, params.MPI_CUSTOM_OPTIONS, "") 127 | current_instance_type = env.current_instance_type 128 | return mpi.MasterRunner( 129 | user_entry_point, 130 | args, 131 | env_vars, 132 | processes_per_host, 133 | env.master_hostname, 134 | env.distribution_hosts, 135 | custom_mpi_options, 136 | env.network_interface_name, 137 | num_processes=num_processes, 138 | instance_type=current_instance_type, 139 | ) 140 | elif identifier is RunnerType.MPI: 141 | return mpi.WorkerRunner( 142 | user_entry_point, 143 | args, 144 | env_vars, 145 | processes_per_host, 146 | env.master_hostname, 147 | env.current_host, 148 | ) 149 | elif identifier is RunnerType.PyTorchXLA: 150 | return pytorch_xla.PyTorchXLARunner( 151 | user_entry_point, 152 | args, 153 | env_vars, 154 | processes_per_host, 155 | env.master_hostname, 156 | env.current_host, 157 | env.distribution_hosts, 158 | env.num_gpus, 159 | ) 160 | elif identifier is RunnerType.Process: 161 | return process.ProcessRunner(user_entry_point, args, env_vars, processes_per_host) 162 | else: 163 | raise ValueError("Invalid identifier %s" % identifier) 164 | 165 | 166 | def _mpi_param_value(mpi_args, env, param_name, default=None): 167 | return mpi_args.get(param_name) or env.additional_framework_parameters.get(param_name, default) 168 | -------------------------------------------------------------------------------- /src/sagemaker_training/timeout.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. 
This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This module contains custom timeout functionality.""" 14 | from __future__ import absolute_import 15 | 16 | from contextlib import contextmanager 17 | import signal 18 | 19 | 20 | class TimeoutError(Exception): # pylint: disable=redefined-builtin 21 | """Override the Python 3 TimeoutError built-in exception. 22 | 23 | This builtin is being overridden for the purpose of compatibility with Python 2, 24 | since TimeoutError is not a built-in exception in Python 2. 25 | """ 26 | 27 | 28 | @contextmanager 29 | def timeout(seconds=0, minutes=0, hours=0): 30 | """Add a signal-based timeout to any block of code. 31 | If multiple time units are specified, they will be added together to determine time limit. 32 | 33 | Usage: 34 | with timeout(seconds=5): 35 | my_slow_function(...) 36 | 37 | Args: 38 | seconds (int): The time limit, in seconds. 39 | minutes (int): The time limit, in minutes. 40 | hours (int): The time limit, in hours. 41 | """ 42 | 43 | limit = seconds + 60 * minutes + 3600 * hours 44 | 45 | def handler(signum, frame): # pylint: disable=W0613 46 | raise TimeoutError("timed out after {} seconds".format(limit)) 47 | 48 | try: 49 | signal.signal(signal.SIGALRM, handler) 50 | signal.setitimer(signal.ITIMER_REAL, limit) 51 | yield 52 | finally: 53 | signal.alarm(0) 54 | -------------------------------------------------------------------------------- /src/sagemaker_training/torch_distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. 
A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This module contains functionality related to Torch Distributed Elastic Runner. 14 | Refer: https://pytorch.org/docs/stable/elastic/run.html 15 | """ 16 | from __future__ import absolute_import 17 | 18 | import os 19 | 20 | from sagemaker_training import ( 21 | _entry_point_type, 22 | environment, 23 | errors, 24 | logging_config, 25 | process, 26 | SM_EFA_NCCL_INSTANCES, 27 | SM_EFA_RDMA_INSTANCES, 28 | ) 29 | 30 | TORCH_DISTRIBUTED_MODULE = "torchrun" 31 | MASTER_PORT = "7777" 32 | 33 | logger = logging_config.get_logger() 34 | 35 | 36 | class TorchDistributedRunner(process.ProcessRunner): 37 | """Runner responsible for preparing Pytorch distributed data parallel training""" 38 | 39 | def __init__( 40 | self, 41 | user_entry_point, 42 | args, 43 | env_vars, 44 | processes_per_host, 45 | master_hostname, 46 | hosts, 47 | current_host, 48 | network_interface_name, 49 | instance_type="ml.trn1.2xlarge", 50 | ): 51 | """Initialize a Native PT Launcher, which is responsible for executing 52 | the user entry point within a process. 53 | 54 | Args: 55 | user_entry_point (str): The name of the user entry point. 56 | args ([str]): A list of arguments to include when executing the entry point. 57 | env_vars (dict(str,str)): A dictionary of environment variables. 
58 | """ 59 | super(TorchDistributedRunner, self).__init__( 60 | user_entry_point, args, env_vars, processes_per_host 61 | ) 62 | 63 | self._master_hostname = master_hostname 64 | self._hosts = hosts 65 | self._current_host = current_host 66 | self._network_interface_name = network_interface_name 67 | self._instance_type = instance_type 68 | 69 | def _setup(self): 70 | logger.info("Starting distributed training through torchrun") 71 | # EFA settings 72 | if self._instance_type in SM_EFA_NCCL_INSTANCES: 73 | # Enable EFA use 74 | os.environ["FI_PROVIDER"] = "efa" 75 | if self._instance_type in SM_EFA_RDMA_INSTANCES: 76 | # Use EFA's RDMA functionality for one-sided and two-sided transfer 77 | os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" 78 | os.environ["RDMAV_FORK_SAFE"] = "1" 79 | os.environ["NCCL_SOCKET_IFNAME"] = str(self._network_interface_name) 80 | os.environ["NCCL_PROTO"] = "simple" 81 | 82 | def _create_command(self): 83 | """ 84 | Based on the number of hosts, torchrun command differs. 85 | Currently the elasticity feture of torchrun is not yet supported. 86 | """ 87 | self._setup() 88 | entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point) 89 | 90 | if entrypoint_type is _entry_point_type.PYTHON_PACKAGE: 91 | raise errors.ClientError( 92 | "Python packages are not supported for torch_distributed. 
" 93 | "Please use a python script as the entry-point" 94 | ) 95 | 96 | if entrypoint_type is _entry_point_type.PYTHON_PROGRAM: 97 | num_hosts = len(self._hosts) 98 | torchrun_cmd = [] 99 | 100 | # Adding support for neuron_parallel_compile to precompile XLA graphs, 101 | # if environment variable RUN_NEURON_PARALLEL_COMPILE == "1" 102 | # This is an example of the command line output when this flag is set: 103 | # "neuron_parallel_compile torchrun --nnodes 2 --nproc_per_node 32 104 | # --master_addr algo-1 --master_port 7777 --node_rank 0 trn_train.py 105 | # --max_steps 100" 106 | 107 | if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": 108 | torchrun_cmd.append("neuron_parallel_compile") 109 | 110 | node_options = [ 111 | TORCH_DISTRIBUTED_MODULE, 112 | "--nnodes", 113 | str(num_hosts), 114 | "--nproc_per_node", 115 | str(self._processes_per_host), 116 | ] 117 | 118 | torchrun_cmd += node_options 119 | 120 | multinode_options = [ 121 | "--master_addr", 122 | str(self._master_hostname), 123 | "--master_port", 124 | MASTER_PORT, 125 | "--node_rank", 126 | str(self._hosts.index(self._current_host)), 127 | ] 128 | 129 | if num_hosts > 1: 130 | torchrun_cmd += multinode_options 131 | 132 | # match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) 133 | 134 | torchrun_cmd.append(str(self._user_entry_point)) 135 | torchrun_cmd += self._args 136 | return torchrun_cmd 137 | else: 138 | raise errors.ClientError("Unsupported entry point type for torch_distributed") 139 | 140 | def run(self, capture_error=True, wait=True): 141 | """ 142 | Run the process. 143 | 144 | Args: 145 | capture_error (bool): A boolean indicating whether to direct stderr to a stream 146 | that can later be read. Defaults to True. 147 | Returns: 148 | process (subprocess.Popen): The spawned process. 
149 | """ 150 | cmd = self._create_command() 151 | logging_config.log_script_invocation(cmd, self._env_vars) 152 | if wait: 153 | process_spawned = process.check_error( 154 | cmd, 155 | errors.ExecuteUserScriptError, 156 | self._processes_per_host, 157 | capture_error=capture_error, 158 | cwd=environment.code_dir, 159 | ) 160 | else: 161 | process_spawned = process.create( 162 | cmd, 163 | errors.ExecuteUserScriptError, 164 | self._processes_per_host, 165 | capture_error=capture_error, 166 | cwd=environment.code_dir, 167 | ) 168 | return process_spawned 169 | -------------------------------------------------------------------------------- /src/sagemaker_training/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """This module contains the train function, which is the main function 14 | responsible for running training in the container. 
15 | """ 16 | from __future__ import absolute_import 17 | 18 | import importlib 19 | import os 20 | import sys 21 | import traceback 22 | 23 | from sagemaker_training import ( 24 | entry_point, 25 | environment, 26 | errors, 27 | files, 28 | intermediate_output, 29 | logging_config, 30 | params, 31 | runner, 32 | SM_TRAINING_COMPILER_PATHS, 33 | ) 34 | 35 | logger = logging_config.get_logger() 36 | 37 | SUCCESS_CODE = 0 38 | DEFAULT_FAILURE_CODE = 1 39 | 40 | 41 | def _get_valid_failure_exit_code(exit_code): 42 | try: 43 | valid_exit_code = int(exit_code) 44 | except ValueError: 45 | valid_exit_code = DEFAULT_FAILURE_CODE 46 | 47 | return valid_exit_code 48 | 49 | 50 | def _exit_processes(exit_code): # type: (int) -> None 51 | """Exit main thread and child processes. 52 | 53 | For more information: 54 | https://docs.python.org/2/library/os.html#process-management 55 | https://docs.python.org/3/library/os.html#process-management 56 | 57 | Args: 58 | exit_code (int): exit code 59 | """ 60 | if exit_code != 0: 61 | logger.error(f"Encountered exit_code {exit_code}") 62 | sys.exit(exit_code) 63 | 64 | 65 | def train(): 66 | """The main function responsible for running training in the container.""" 67 | intermediate_sync = None 68 | exit_code = SUCCESS_CODE 69 | try: 70 | env = environment.Environment() 71 | 72 | region = os.environ.get("AWS_REGION", os.environ.get(params.REGION_NAME_ENV)) 73 | s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None) 74 | intermediate_sync = intermediate_output.start_sync( 75 | env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url 76 | ) 77 | 78 | if env.framework_module: 79 | framework_name, entry_point_name = env.framework_module.split(":") 80 | 81 | framework = importlib.import_module(framework_name) 82 | 83 | # the logger is configured after importing the framework library, allowing 84 | # the framework to configure logging at import time. 
85 | logging_config.configure_logger(env.log_level) 86 | logger.info("Imported framework %s", framework_name) 87 | entrypoint = getattr(framework, entry_point_name) 88 | entrypoint() 89 | else: 90 | logging_config.configure_logger(env.log_level) 91 | 92 | mpi_enabled = env.additional_framework_parameters.get(params.MPI_ENABLED) 93 | runner_type = ( 94 | runner.RunnerType.MPI 95 | if mpi_enabled and (env.current_instance_group in env.distribution_instance_groups) 96 | else runner.RunnerType.Process 97 | ) 98 | 99 | entry_point.run( 100 | env.module_dir, 101 | env.user_entry_point, 102 | env.to_cmd_args(), 103 | env.to_env_vars(), 104 | runner_type=runner_type, 105 | ) 106 | logger.info("Reporting training SUCCESS") 107 | 108 | files.write_success_file() 109 | except errors.ClientError as e: 110 | failure_msg = str(e) 111 | files.write_failure_file(failure_msg) 112 | logger.error("Reporting training FAILURE") 113 | 114 | logger.error(failure_msg) 115 | 116 | if intermediate_sync: 117 | intermediate_sync.join() 118 | 119 | exit_code = DEFAULT_FAILURE_CODE 120 | except Exception as e: # pylint: disable=broad-except 121 | if any(path in traceback.format_exc() for path in SM_TRAINING_COMPILER_PATHS): 122 | failure_msg = "SMTrainingCompiler Error: \n%s\n%s" % (traceback.format_exc(), str(e)) 123 | else: 124 | failure_msg = "Framework Error: \n%s\n%s" % (traceback.format_exc(), str(e)) 125 | files.write_failure_file(failure_msg) 126 | logger.error("Reporting training FAILURE") 127 | 128 | logger.error(failure_msg) 129 | 130 | error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE) 131 | exit_code = _get_valid_failure_exit_code(error_number) 132 | finally: 133 | if intermediate_sync: 134 | intermediate_sync.join() 135 | _exit_processes(exit_code) 136 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. 
# Quiet the noisy AWS SDK loggers during test runs.
logging.getLogger("boto3").setLevel(logging.INFO)
logging.getLogger("s3transfer").setLevel(logging.INFO)
logging.getLogger("botocore").setLevel(logging.WARN)

DEFAULT_REGION = "us-west-2"


def _write_json(obj, path):  # type: (object, str) -> None
    """Serialize ``obj`` as JSON into the file at ``path``."""
    with open(path, "w") as json_file:
        json.dump(obj, json_file)


@pytest.fixture(autouse=True)
def create_base_path():
    """Hand the test the SageMaker base path, then wipe and rebuild it.

    After each test the whole base path is removed and the standard directory
    layout plus empty config files are recreated, so the following test starts
    from a pristine SageMaker-style filesystem.
    """
    yield str(os.environ[environment.BASE_PATH_ENV])

    shutil.rmtree(os.environ[environment.BASE_PATH_ENV])

    for directory in (
        environment.model_dir,
        environment.input_config_dir,
        environment.code_dir,
        environment.output_data_dir,
    ):
        os.makedirs(directory)

    _write_json({}, environment.hyperparameters_file_dir)
    _write_json({}, environment.input_data_config_file_dir)

    current_host = socket.gethostname()
    _write_json(
        {"current_host": current_host, "hosts": [current_host]},
        environment.resource_config_file_dir,
    )


@pytest.fixture(autouse=True)
def patch_exit_process():
    """Replace trainer._exit_processes so failures raise instead of exiting."""

    def _raise_on_failure(error_code):
        # A truthy exit code means the training run failed; surface it as an
        # exception the test can assert on rather than killing the process.
        if error_code:
            raise ValueError(error_code)

    with patch("sagemaker_training.trainer._exit_processes", _raise_on_failure):
        yield _raise_on_failure


@pytest.fixture(autouse=True)
def fix_protobuf_installation_for_python_2():
    """Create google/__init__.py so protobuf imports under Python 2.

    Python 2 requires an __init__.py at every package level, but protobuf
    doesn't honor that, so the file is created here.
    https://stackoverflow.com/a/45141001
    """
    if sys.version_info.major != 2:
        return

    pip_output = subprocess.check_output("pip show protobuf".split())
    site_packages = re.match(r"[\S\s]*Location: (.*)\s", pip_output).group(1)
    with open(os.path.join(site_packages, "google", "__init__.py"), "w"):
        pass


@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for each test case."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz \ 38 | && tar -xvf Python-$PYTHON_VERSION.tgz \ 39 | && cd Python-$PYTHON_VERSION \ 40 | && ./configure && make && make install \ 41 | && rm -rf ../Python-$PYTHON_VERSION* 42 | 43 | RUN ${PIP} --no-cache-dir install --upgrade pip 44 | 45 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python \ 46 | && ln -s $(which ${PIP}) /usr/bin/pip 47 | 48 | COPY dummy/sagemaker_training.tar.gz /sagemaker_training.tar.gz 49 | 50 | RUN ${PIP} install --no-cache-dir \ 51 | /sagemaker_training.tar.gz 52 | 53 | RUN rm /sagemaker_training.tar.gz 54 | 55 | COPY dummy/train.py /opt/ml/code/train.py 56 | COPY dummy/requirements.txt /opt/ml/code/requirements.txt 57 | 58 | ENV SAGEMAKER_PROGRAM train.py 59 | -------------------------------------------------------------------------------- /test/container/dummy/requirements.txt: -------------------------------------------------------------------------------- 1 | pyfiglet==0.8.post1 2 | -------------------------------------------------------------------------------- /test/container/dummy/train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-training-toolkit/a0073959d4de7e8311b370a1f143e44dcf801d1e/test/container/dummy/train.py -------------------------------------------------------------------------------- /test/container/tensorflow/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM 763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.13.0-gpu-py310-cu118-ubuntu20.04-sagemaker 2 | 3 | COPY dummy/sagemaker_training.tar.gz /sagemaker_training.tar.gz 4 | 5 | RUN ${PIP} install --no-cache-dir \ 6 | /sagemaker_training.tar.gz 7 | 8 | RUN rm /sagemaker_training.tar.gz 9 | 10 | COPY tensorflow/train.py /opt/ml/code/train.py 11 | 12 | ENV SAGEMAKER_PROGRAM train.py 13 | 14 | 
-------------------------------------------------------------------------------- /test/container/tensorflow/train.py: -------------------------------------------------------------------------------- 1 | class XlaRuntimeError(Exception): 2 | """dummy XlaRuntimeError class to throw. Unable to import the actual 3 | XlaRuntimeError class defined in tensorflow/compiler/xla/python/xla_client.py module. 4 | """ 5 | 6 | 7 | raise XlaRuntimeError("Throwing the dummy exception to simulate the error.") 8 | -------------------------------------------------------------------------------- /test/fake_ml_framework.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | import numpy as np 14 | 15 | from sagemaker_training import files 16 | import test 17 | 18 | 19 | class Model(object): 20 | def __init__( 21 | self, weights=None, bias=1, loss=None, optimizer=None, epochs=None, batch_size=None 22 | ): 23 | self.batch_size = batch_size 24 | self.epochs = epochs 25 | self.optimizer = optimizer 26 | self.loss = loss 27 | self.weights = weights 28 | self.bias = bias 29 | 30 | def fit(self, x, y, epochs=None, batch_size=None): 31 | self.weights = (y / x + self.bias).tolist() 32 | self.epochs = epochs 33 | self.batch_size = batch_size 34 | 35 | def save(self, model_dir): 36 | test.write_json(self.__dict__, model_dir) 37 | 38 | @classmethod 39 | def load(cls, model_dir): 40 | clazz = cls() 41 | clazz.__dict__ = files.read_json(model_dir) 42 | return clazz 43 | 44 | def predict(self, data): 45 | return np.asarray(self.weights) * np.asarray(data) 46 | -------------------------------------------------------------------------------- /test/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | -------------------------------------------------------------------------------- /test/functional/simple_framework.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import os 16 | 17 | from sagemaker_training import environment, functions, modules 18 | 19 | 20 | def train(): 21 | training_env = environment.Environment() 22 | 23 | script = modules.import_module(training_env.module_dir, training_env.module_name) 24 | 25 | model = script.train(**functions.matching_args(script.train, training_env)) 26 | 27 | if model: 28 | if hasattr(script, "save"): 29 | script.save(model, training_env.model_dir) 30 | else: 31 | model_file = os.path.join(training_env.model_dir, "saved_model") 32 | model.save(model_file) 33 | -------------------------------------------------------------------------------- /test/functional/test_download_and_import.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
# Minimal setup.py that turns the user script into an installable module.
data = [
    "from distutils.core import setup\n",
    'setup(name="my_test_script", py_modules=["my_test_script"])',
]

SETUP_FILE = test.File("setup.py", data)

USER_SCRIPT_FILE = test.File("my_test_script.py", "def validate(): return True")

# Expected pyfiglet banner for "SageMaker". Dots stand in for spaces so the
# literal survives editors that strip trailing whitespace.
REQUIREMENTS_TXT_ASSERT_STR = """
 ____                   __  __       _.............
/ ___|  __ _  __ _  ___|  \/  | __ _| | _____ _ __.
\___ \ / _` |/ _` |/ _ \ |\/| |/ _` | |/ / _ \ '__|
 ___) | (_| | (_| |  __/ |  | | (_| |   <  __/ |...
|____/ \__,_|\__, |\___|_|  |_|\__,_|_|\_\___|_|...
             |___/.................................
""".replace(  # noqa W605
    ".", " "
).strip()


@pytest.fixture(name="user_module_name")
def uninstall_user_module():
    """Yield the test module name, then best-effort pip-uninstall it."""
    module_name = "my_test_script"
    yield module_name

    try:
        subprocess.check_call(shlex.split("pip uninstall -y --quiet %s" % module_name))
    except subprocess.CalledProcessError:
        # The module may never have been installed; nothing to clean up.
        pass


@pytest.fixture(name="requirements_file")
def uninstall_requirements_file():
    """Yield a requirements.txt file, then best-effort pip-uninstall its package."""
    package = "pyfiglet"
    yield test.File("requirements.txt", package)

    try:
        subprocess.check_call(shlex.split("pip uninstall -y --quiet %s" % package))
    except subprocess.CalledProcessError:
        pass


@pytest.mark.parametrize(
    "user_module",
    [test.UserModule(USER_SCRIPT_FILE).add_file(SETUP_FILE), test.UserModule(USER_SCRIPT_FILE)],
)
def test_import_module(user_module, user_module_name):
    """An uploaded module (with or without setup.py) imports and runs."""
    user_module.upload()

    imported = modules.import_module(user_module.url, user_module_name)

    assert imported.validate()


@pytest.mark.parametrize(
    "user_module",
    [test.UserModule(USER_SCRIPT_FILE).add_file(SETUP_FILE), test.UserModule(USER_SCRIPT_FILE)],
)
def test_import_module_with_s3_script(user_module, user_module_name):
    """A script stored in S3 can be downloaded and imported."""
    user_module.upload()

    imported = modules.import_module(user_module.url, user_module_name)

    assert imported.validate()


@pytest.mark.parametrize(
    "user_module",
    [test.UserModule(USER_SCRIPT_FILE).add_file(SETUP_FILE), test.UserModule(USER_SCRIPT_FILE)],
)
def test_import_module_with_local_script(user_module, user_module_name, tmpdir):
    """A script staged in a local directory can be imported directly."""
    code_dir = str(tmpdir)
    user_module.create_tmp_dir_with_files(code_dir)

    imported = modules.import_module(code_dir, user_module_name)

    assert imported.validate()


# User script whose import depends on the package from requirements.txt.
data = textwrap.dedent(
    """
    from pyfiglet import Figlet

    def say():
        return Figlet().renderText('SageMaker').strip()

    """
)

USER_SCRIPT_WITH_REQUIREMENTS = test.File("my_test_script.py", data)


@pytest.mark.parametrize(
    "user_module",
    [
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS),
    ],
)
def test_import_module_with_s3_script_with_requirements(
    user_module, user_module_name, requirements_file
):
    """requirements.txt shipped alongside an S3 script gets installed."""
    user_module = user_module.add_file(requirements_file).upload()

    imported = modules.import_module(user_module.url, user_module_name)

    assert imported.say() == REQUIREMENTS_TXT_ASSERT_STR


@pytest.mark.parametrize(
    "user_module",
    [
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS),
    ],
)
def test_import_module_with_requirements(user_module, user_module_name, requirements_file):
    """import_module also accepts keyword arguments for uri and name."""
    user_module = user_module.add_file(requirements_file).upload()

    imported = modules.import_module(uri=user_module.url, name=user_module_name)

    assert imported.say() == REQUIREMENTS_TXT_ASSERT_STR


data = ['raise ValueError("this script does not work")']
USER_SCRIPT_WITH_ERROR = test.File("my_test_script.py", data)


def test_import_module_with_s3_script_with_error(user_module_name):
    """A script that raises at import time surfaces an ImportModuleError."""
    user_module = test.UserModule(USER_SCRIPT_WITH_ERROR).add_file(SETUP_FILE).upload()

    with pytest.raises(errors.ImportModuleError):
        modules.import_module(user_module.url, user_module_name)


@pytest.mark.parametrize(
    "user_module",
    [
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
        test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS),
    ],
)
def test_import_module_with_local_tar(user_module, user_module_name, requirements_file):
    """A local tarball is unpacked, its requirements installed, then imported."""
    user_module = user_module.add_file(requirements_file)
    tar_name = user_module.create_tar()

    imported = modules.import_module(tar_name, name=user_module_name)

    assert imported.say() == REQUIREMENTS_TXT_ASSERT_STR

    os.remove(tar_name)
13 | import logging 14 | import os 15 | import shutil 16 | import subprocess 17 | 18 | import pytest 19 | from sagemaker.estimator import Framework 20 | 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | dir_path = os.path.realpath(__file__) 24 | root_dir = os.path.realpath(os.path.join(dir_path, "..", "..", "..")) 25 | source_dir = os.path.realpath(os.path.join(dir_path, "..", "..", "resources", "openmpi")) 26 | 27 | 28 | class CustomEstimator(Framework): 29 | def create_model(self, **kwargs): 30 | raise NotImplementedError("This methos is not supported.") 31 | 32 | 33 | @pytest.mark.skip( 34 | reason="waiting for local mode fix on " "https://github.com/aws/sagemaker-python-sdk/pull/559" 35 | ) 36 | def test_mpi(tmpdir): 37 | estimator = CustomEstimator( 38 | entry_point="launcher.sh", 39 | image_name=build_mpi_image(tmpdir), 40 | role="SageMakerRole", 41 | train_instance_count=2, 42 | source_dir=source_dir, 43 | train_instance_type="local", 44 | hyperparameters={ 45 | "sagemaker_mpi_enabled": True, 46 | "sagemaker_mpi_custom_mpi_options": "-verbose", 47 | "sagemaker_network_interface_name": "eth0", 48 | }, 49 | ) 50 | 51 | estimator.fit() 52 | 53 | 54 | def build_mpi_image(tmpdir): 55 | tmp = str(tmpdir) 56 | 57 | subprocess.check_call(["python", "setup.py", "sdist"], cwd=root_dir) 58 | 59 | for file in os.listdir(os.path.join(root_dir, "dist")): 60 | shutil.copy2(os.path.join(root_dir, "dist", file), tmp) 61 | 62 | shutil.copy2(os.path.join(source_dir, "Dockerfile"), tmp) 63 | 64 | imagename = "openmpi" 65 | subprocess.check_call(["docker", "build", "-t", imagename, "."], cwd=tmp) 66 | 67 | return imagename 68 | -------------------------------------------------------------------------------- /test/integration/local/test_dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
@pytest.fixture(scope="module", autouse=True)
def container():
    """Run the dummy training container; force-remove it when the module ends."""
    try:
        command = (
            "docker run --name sagemaker-training-toolkit-test "
            "sagemaker-training-toolkit-test:dummy train"
        )

        proc = subprocess.Popen(command.split(), stdout=sys.stdout, stderr=subprocess.STDOUT)

        yield proc.pid

    finally:
        # Always clean up the named container, even if the tests failed.
        subprocess.check_call("docker rm -f sagemaker-training-toolkit-test".split())


def test_install_requirements(capsys):
    """The toolkit should pip-install requirements.txt and report SUCCESS."""
    estimator = Estimator(
        image_uri="sagemaker-training-toolkit-test:dummy",
        role="SageMakerRole",
        instance_count=1,
        instance_type="local",
    )

    estimator.fit()

    output = capsys.readouterr().out

    assert "Installing collected packages: pyfiglet" in output
    assert "Successfully installed pyfiglet-0.8.post1" in output
    assert "Reporting training SUCCESS" in output


# TODO: add test_install_requirements_from_codeartifact — same flow as above
# but with the CA_REPOSITORY_ARN environment variable set and a role granted
# CodeArtifact access (repo resource policy + role policy), asserting that the
# "{domain}-{owner}.d.codeartifact.{region}.amazonaws.com/pypi/{repo}/simple/"
# index URL appears in the output before the pyfiglet install lines.
# See https://docs.aws.amazon.com/codeartifact/latest/ug/security-iam.html and
# https://docs.aws.amazon.com/codeartifact/latest/ug/repo-policies.html
13 | from __future__ import absolute_import 14 | 15 | import subprocess 16 | import sys 17 | 18 | import pytest 19 | from sagemaker.estimator import Estimator 20 | 21 | 22 | @pytest.fixture(scope="module", autouse=True) 23 | def container(): 24 | try: 25 | command = ( 26 | "docker run --name sagemaker-training-toolkit-test " 27 | "sagemaker-training-toolkit-test:tensorflow train" 28 | ) 29 | 30 | proc = subprocess.Popen(command.split(), stdout=sys.stdout, stderr=subprocess.STDOUT) 31 | 32 | yield proc.pid 33 | 34 | finally: 35 | subprocess.check_call("docker rm -f sagemaker-training-toolkit-test".split()) 36 | 37 | 38 | def test_tensorflow_exceptions(capsys): 39 | with pytest.raises(Exception): 40 | estimator = Estimator( 41 | image_uri="sagemaker-training-toolkit-test:tensorflow", 42 | role="SageMakerRole", 43 | instance_count=1, 44 | instance_type="local", 45 | ) 46 | 47 | estimator.fit() 48 | stdout = capsys.readouterr().out 49 | assert "XlaRuntimeError" in stdout 50 | -------------------------------------------------------------------------------- /test/resources/openmpi/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | FROM mvsusp/openmpi 14 | 15 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 16 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 17 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 18 | RUN pip install mpi4py==3.0.0 19 | 20 | ENV SAGEMAKER_training=$(ls sagemaker_training-*.tar.gz) 21 | COPY ${SAGEMAKER_training} ${SAGEMAKER_training} 22 | 23 | RUN pip install ${SAGEMAKER_training} 24 | -------------------------------------------------------------------------------- /test/resources/openmpi/Dockerfile.base: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
FROM ubuntu:16.04

# Install basic build/runtime dependencies plus SSH (needed for MPI rank
# startup across containers).
RUN apt-get update && apt-get install -y \
    build-essential \
    openssh-client \
    openssh-server \
    wget \
    python-dev \
    ca-certificates && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Build Open MPI from source with orterun prefixing enabled.
RUN mkdir /tmp/openmpi && \
    cd /tmp/openmpi && \
    wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.3.tar.gz && \
    tar zxf openmpi-3.1.3.tar.gz && \
    cd openmpi-3.1.3 && \
    ./configure --enable-orterun-prefix-by-default && \
    make install all && \
    ldconfig && \
    rm -rf /tmp/openmpi


# Wrap mpirun so it always allows running as root (the container default user).
RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \
    echo '#!/bin/bash' > /usr/local/bin/mpirun && \
    echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \
    chmod a+x /usr/local/bin/mpirun

# Relax hwloc binding and map ranks by slot.
RUN echo "hwloc_base_binding_policy = none" >> /usr/local/etc/openmpi-mca-params.conf && \
    echo "rmaps_base_mapping_policy = slot" >> /usr/local/etc/openmpi-mca-params.conf

ENV LD_LIBRARY_PATH=/usr/local/openmpi/lib:$LD_LIBRARY_PATH

ENV PATH /usr/local/openmpi/bin/:$PATH

# SSH login fix. Otherwise user is kicked off after login
RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd

# Create a password-less SSH key so containers can reach each other.
RUN mkdir -p /root/.ssh/ && \
    mkdir -p /var/run/sshd && \
    ssh-keygen -q -t rsa -N '' -f /root/.ssh/id_rsa && \
    cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys && \
    printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config

# Pin pip to a version known to work with this Python 2 base image.
RUN wget https://bootstrap.pypa.io/get-pip.py && \
    python get-pip.py --disable-pip-version-check --no-cache-dir "pip==18.1"
See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import json 14 | import os 15 | 16 | from mpi4py import MPI 17 | 18 | comm = MPI.COMM_WORLD 19 | size = comm.Get_size() 20 | rank = comm.Get_rank() 21 | 22 | data = {"rank": rank, "size": size} 23 | data = comm.gather(data, root=0) 24 | if rank == 0: 25 | assert data == [{"rank": 0, "size": 2}, {"rank": 1, "size": 2}] 26 | 27 | model = os.path.join(os.environ["SM_MODEL_DIR"], "result.json") 28 | with open(model, "w+") as f: 29 | json.dump(data, f) 30 | else: 31 | assert data is None 32 | -------------------------------------------------------------------------------- /test/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | -------------------------------------------------------------------------------- /test/unit/c/test_gethostname.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. 
A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import json 14 | import os 15 | import shutil 16 | import sys 17 | 18 | import pytest 19 | 20 | import gethostname 21 | from sagemaker_training import errors, process 22 | 23 | OPT_ML = "/opt/ml" 24 | INPUT_CONFIG = "/opt/ml/input/config/" 25 | 26 | 27 | @pytest.fixture() 28 | def opt_ml_input_config(): 29 | if os.path.exists(OPT_ML): 30 | shutil.rmtree(OPT_ML) 31 | 32 | try: 33 | os.makedirs(INPUT_CONFIG) 34 | 35 | yield INPUT_CONFIG 36 | 37 | finally: 38 | shutil.rmtree(OPT_ML) 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "content,value", 43 | [ 44 | [{"channel": "training", "current_host": "algo-5", "File": "pipe"}, "algo-5"], 45 | [{"current_host": "algo-1-thse"}, "algo-1-thse"], 46 | ], 47 | ) 48 | @pytest.mark.xfail( 49 | os.environ.get("IS_CODEBUILD_IMAGE") != "true", 50 | reason="Needs root permissions to create /opt/ml when run locally.", 51 | ) 52 | def test_gethostname_resource_config_set(content, value, opt_ml_input_config): 53 | with open("/opt/ml/input/config/resourceconfig.json", "w") as f: 54 | json.dump(content, f) 55 | 56 | assert gethostname.call(30) 57 | 58 | 59 | @pytest.mark.xfail( 60 | os.environ.get("IS_CODEBUILD_IMAGE") != "true", 61 | reason="Needs root permissions to create /opt/ml when run locally.", 62 | ) 63 | def test_gethostname_with_env_not_set(opt_ml_input_config): 64 | py_cmd = "import gethostname\nassert gethostname.call(30) == 'algo-9'" 65 | 66 | with pytest.raises(errors.ExecuteUserScriptError): 67 | process.check_error([sys.executable, "-c", py_cmd], errors.ExecuteUserScriptError) 68 | 
-------------------------------------------------------------------------------- /test/unit/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | -------------------------------------------------------------------------------- /test/unit/cli/test_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
from mock import patch

from sagemaker_training.cli import train as train_cli


@patch("sagemaker_training.trainer.train")
def test_entry_point(train_mock):
    """The CLI entry point simply delegates to trainer.train()."""
    train_cli.main()
    train_mock.assert_called()
"""This module contains dummy function to raise an exception"""


def dummy_function():
    """Dummy helper that unconditionally raises, used to exercise error paths."""
    message = "raising dummy exception"
    raise Exception(message)
"""This module contains dummy function to raise an exception"""


def dummy_xla_function():
    """Dummy XLA helper that unconditionally raises, used to exercise error paths."""
    message = "raising xla dummy exception"
    raise Exception(message)
import io
import itertools

from mock import Mock, patch
import numpy as np
import pytest
from scipy import sparse
from six import BytesIO

from sagemaker_training import content_types, encoders, errors
from sagemaker_training.record_pb2 import Record
from sagemaker_training.recordio import _read_recordio


# NOTE: several parametrize lists previously contained byte-identical duplicate
# entries (leftovers of py2-era u"" variants); duplicates were removed.
@pytest.mark.parametrize(
    "target",
    ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], {42: {"6": 9.0}}),
)
def test_npy_to_numpy(target):
    """npy_to_numpy round-trips data serialized with np.save."""
    buffer = BytesIO()
    np.save(buffer, target)
    input_data = buffer.getvalue()

    actual = encoders.npy_to_numpy(input_data)

    np.testing.assert_equal(actual, np.array(target))


@pytest.mark.parametrize(
    "target",
    ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], {42: {"6": 9.0}}),
)
def test_array_to_npy(target):
    """array_to_npy serializes both ndarrays and plain Python containers."""
    input_data = np.array(target)

    actual = encoders.array_to_npy(input_data)

    np.testing.assert_equal(np.load(BytesIO(actual), allow_pickle=True), np.array(target))

    actual = encoders.array_to_npy(target)

    np.testing.assert_equal(np.load(BytesIO(actual), allow_pickle=True), np.array(target))


@pytest.mark.parametrize(
    "target, expected",
    [
        ("[42, 6, 9]", np.array([42, 6, 9])),
        ("[42.0, 6.0, 9.0]", np.array([42.0, 6.0, 9.0])),
        ('["42", "6", "9"]', np.array(["42", "6", "9"])),
    ],
)
def test_json_to_numpy(target, expected):
    """json_to_numpy parses JSON arrays, honoring an optional dtype."""
    actual = encoders.json_to_numpy(target)
    np.testing.assert_equal(actual, expected)

    np.testing.assert_equal(encoders.json_to_numpy(target, dtype=int), expected.astype(int))

    np.testing.assert_equal(encoders.json_to_numpy(target, dtype=float), expected.astype(float))


@pytest.mark.parametrize(
    "target, expected",
    [
        ([42, 6, 9], "[42, 6, 9]"),
        ([42.0, 6.0, 9.0], "[42.0, 6.0, 9.0]"),
        (["42", "6", "9"], '["42", "6", "9"]'),
        ({42: {"6": 9.0}}, '{"42": {"6": 9.0}}'),
    ],
)
def test_array_to_json(target, expected):
    """array_to_json serializes both plain containers and ndarrays to JSON."""
    actual = encoders.array_to_json(target)
    np.testing.assert_equal(actual, expected)

    actual = encoders.array_to_json(np.array(target))
    np.testing.assert_equal(actual, expected)


def test_array_to_json_exception():
    """Non-serializable objects (e.g. a lambda) raise TypeError."""
    with pytest.raises(TypeError):
        encoders.array_to_json(lambda x: 3)


@pytest.mark.parametrize(
    "target, expected",
    [
        ("42\n6\n9\n", np.array([42, 6, 9])),
        ("42.0\n6.0\n9.0\n", np.array([42.0, 6.0, 9.0])),
        ('"False,"\n"True."\n"False,"\n', np.array(["False,", "True.", "False,"])),
        ('aaa\n"b""bb"\nccc\n', np.array(["aaa", 'b"bb', "ccc"])),
        ('"a\nb"\nc\n', np.array(["a\nb", "c"])),
    ],
)
def test_csv_to_numpy(target, expected):
    """csv_to_numpy parses CSV text, including quoted and embedded-newline cells."""
    actual = encoders.csv_to_numpy(target)
    np.testing.assert_equal(actual, expected)


def test_csv_to_numpy_error():
    """A non-numeric cell with a numeric dtype raises a ClientError."""
    with pytest.raises(errors.ClientError):
        encoders.csv_to_numpy("a\n", dtype="float")


@pytest.mark.parametrize(
    "target, expected",
    [
        ([42, 6, 9], "42\n6\n9\n"),
        ([42.0, 6.0, 9.0], "42.0\n6.0\n9.0\n"),
        (["42", "6", "9"], "42\n6\n9\n"),
        (["False,", "True.", "False,"], '"False,"\nTrue.\n"False,"\n'),
        (["aaa", 'b"bb', "ccc"], 'aaa\n"b""bb"\nccc\n'),
        (["a\nb", "c"], '"a\nb"\nc\n'),
    ],
)
def test_array_to_csv(target, expected):
    """array_to_csv quotes cells only when needed (commas, quotes, newlines)."""
    actual = encoders.array_to_csv(target)
    np.testing.assert_equal(actual, expected)

    actual = encoders.array_to_csv(np.array(target))
    np.testing.assert_equal(actual, expected)


@pytest.mark.parametrize("content_type", [content_types.JSON, content_types.CSV, content_types.NPY])
def test_encode(content_type):
    """encode dispatches to the encoder registered for the content type."""
    encoder = Mock()
    with patch.dict(encoders.encoders_map, {content_type: encoder}, clear=True):
        encoders.encode(42, content_type)

        encoder.assert_called_once_with(42)


def test_encode_error():
    """Encoding an unregistered content type raises UnsupportedFormatError."""
    with pytest.raises(errors.UnsupportedFormatError):
        encoders.encode(42, content_types.OCTET_STREAM)


def test_decode_error():
    """Decoding an unregistered content type raises UnsupportedFormatError."""
    with pytest.raises(errors.UnsupportedFormatError):
        encoders.decode(42, content_types.OCTET_STREAM)


@pytest.mark.parametrize("content_type", [content_types.JSON, content_types.CSV, content_types.NPY])
def test_decode(content_type):
    """decode dispatches to the decoder registered for the content type."""
    decoder = Mock()
    with patch.dict(encoders._decoders_map, {content_type: decoder}, clear=True):
        encoders.decode(42, content_type)

        decoder.assert_called_once_with(42)


def test_array_to_recordio_dense():
    """Dense arrays round-trip through the recordio-protobuf encoding."""
    array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
    buf = encoders.array_to_recordio_protobuf(np.array(array_data))
    stream = io.BytesIO(buf)

    for record_data, expected in zip(_read_recordio(stream), array_data):
        record = Record()
        record.ParseFromString(record_data)
        assert record.features["values"].float64_tensor.values == expected


def test_sparse_int_write_spmatrix_to_sparse_tensor():
    """An int COO sparse matrix round-trips values, keys, and shape."""
    n = 4
    array_data = [[1.0, 2.0], [10.0, 30.0], [100.0, 200.0, 300.0, 400.0], [1000.0, 2000.0, 3000.0]]
    keys_data = [[0, 1], [1, 2], [0, 1, 2, 3], [0, 2, 3]]

    flatten_data = list(itertools.chain.from_iterable(array_data))
    y_indices = list(itertools.chain.from_iterable(keys_data))
    x_indices = [[i] * len(keys_data[i]) for i in range(len(keys_data))]
    x_indices = list(itertools.chain.from_iterable(x_indices))

    array = sparse.coo_matrix((flatten_data, (x_indices, y_indices)), dtype="int")
    buf = encoders.array_to_recordio_protobuf(array)
    stream = io.BytesIO(buf)

    for record_data, expected_data, expected_keys in zip(
        _read_recordio(stream), array_data, keys_data
    ):
        record = Record()
        record.ParseFromString(record_data)
        assert record.features["values"].int32_tensor.values == expected_data
        assert record.features["values"].int32_tensor.keys == expected_keys
        assert record.features["values"].int32_tensor.shape == [n]


def test_sparse_float32_write_spmatrix_to_sparse_tensor():
    """A float32 COO sparse matrix round-trips values, keys, and shape."""
    n = 4
    array_data = [[1.0, 2.0], [10.0, 30.0], [100.0, 200.0, 300.0, 400.0], [1000.0, 2000.0, 3000.0]]
    keys_data = [[0, 1], [1, 2], [0, 1, 2, 3], [0, 2, 3]]

    flatten_data = list(itertools.chain.from_iterable(array_data))
    y_indices = list(itertools.chain.from_iterable(keys_data))
    x_indices = [[i] * len(keys_data[i]) for i in range(len(keys_data))]
    x_indices = list(itertools.chain.from_iterable(x_indices))

    array = sparse.coo_matrix((flatten_data, (x_indices, y_indices)), dtype="float32")
    buf = encoders.array_to_recordio_protobuf(array)
    stream = io.BytesIO(buf)

    for record_data, expected_data, expected_keys in zip(
        _read_recordio(stream), array_data, keys_data
    ):
        record = Record()
        record.ParseFromString(record_data)
        assert record.features["values"].float32_tensor.values == expected_data
        assert record.features["values"].float32_tensor.keys == expected_keys
        assert record.features["values"].float32_tensor.shape == [n]
from __future__ import absolute_import

import os
import sys

from mock import call, MagicMock, patch, PropertyMock
import pytest

from sagemaker_training import entry_point, environment, errors, process, runner

# Dotted path of the builtin open() for mock patching.
builtins_open = "builtins.open"


@pytest.fixture
def entry_point_type_module():
    # A code directory containing setup.py is treated as a Python package.
    with patch("os.listdir", lambda x: ("setup.py",)):
        yield


@pytest.fixture(autouse=True)
def entry_point_type_script():
    # Default (autouse): empty directory listing, i.e. a plain script entry point.
    with patch("os.listdir", lambda x: ()):
        yield


@pytest.fixture()
def has_requirements():
    # Pretend a requirements.txt exists next to the entry point.
    with patch("os.path.exists", lambda x: x.endswith("requirements.txt")):
        yield


@patch("sagemaker_training.modules.prepare")
@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_module(check_error, prepare, entry_point_type_module):
    """Package entry points are installed with `pip install .`, plus
    `-r requirements.txt` when a requirements file is present."""
    path = "c://sagemaker-pytorch-container"
    entry_point.install("python_module.py", path)

    cmd = [sys.executable, "-m", "pip", "install", "."]
    check_error.assert_called_with(cmd, errors.InstallModuleError, 1, capture_error=False, cwd=path)

    with patch("os.path.exists", return_value=True):
        entry_point.install("python_module.py", path)

        check_error.assert_called_with(
            cmd + ["-r", "requirements.txt"],
            errors.InstallModuleError,
            1,
            cwd=path,
            capture_error=False,
        )


@patch("sagemaker_training.modules.prepare")
@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_script(check_error, prepare, entry_point_type_module, has_requirements):
    """Installing a script entry point completes without raising."""
    path = "c://sagemaker-pytorch-container"
    entry_point.install("train.py", path)

    with patch("os.path.exists", return_value=True):
        # NOTE(review): arguments appear swapped relative to the
        # install(name, dir_path) order used above -- confirm this is intentional.
        entry_point.install(path, "python_module.py")


@patch("sagemaker_training.modules.prepare")
@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_fails(check_error, prepare, entry_point_type_module):
    """A failing install command propagates the ClientError."""
    check_error.side_effect = errors.ClientError()
    with pytest.raises(errors.ClientError):
        entry_point.install("git://aws/container-support", "script")


@patch("sagemaker_training.modules.prepare")
@patch("sys.executable", None)
@patch("sagemaker_training.process.check_error", autospec=True)
def test_install_no_python_executable(
    check_error, prepare, has_requirements, entry_point_type_module
):
    """When sys.executable cannot be resolved, install raises RuntimeError."""
    with pytest.raises(RuntimeError) as e:
        entry_point.install("train.py", "git://aws/container-support")
    assert str(e.value) == "Failed to retrieve the real path for the Python executable binary"


@patch("os.chmod")
@patch("sagemaker_training.process.check_error", autospec=True)
@patch("socket.gethostbyname")
def test_script_entry_point_with_python_package(
    gethostbyname, check_error, chmod, entry_point_type_module
):
    """Shell-script entry points are made executable before running."""
    runner_mock = MagicMock(spec=process.ProcessRunner)

    entry_point.run(
        uri="s3://dummy-uri",
        user_entry_point="train.sh",
        args=["dummy_arg"],
        runner_type=runner_mock,
    )

    # 511 == 0o777: the entry point is chmod'ed rwxrwxrwx.
    chmod.assert_called_with(os.path.join(environment.code_dir, "train.sh"), 511)


@patch("sagemaker_training.files.download_and_extract")
@patch("os.chmod")
@patch("sagemaker_training.process.check_error", autospec=True)
@patch("socket.gethostbyname")
def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract):
    """run() downloads the code, chmods it, and runs with wait/capture flags."""
    runner_mock = MagicMock(spec=process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="launcher.sh",
        args=["42"],
        capture_error=True,
        runner_type=runner_mock,
    )

    download_and_extract.assert_called_with(uri="s3://url", path=environment.code_dir)
    runner_mock.run.assert_called_with(True, True)
    # 511 == 0o777: the entry point is chmod'ed rwxrwxrwx.
    chmod.assert_called_with(os.path.join(environment.code_dir, "launcher.sh"), 511)


@patch("sagemaker_training.files.download_and_extract")
@patch("sagemaker_training.modules.install")
@patch.object(
    environment.Environment, "hosts", return_value=["algo-1", "algo-2"], new_callable=PropertyMock
)
@patch("socket.gethostbyname")
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """run() resolves every host in the cluster before starting."""
    runner_mock = MagicMock(spec=process.ProcessRunner)
    entry_point.run(
        uri="s3://url", user_entry_point="launcher.py", args=["42"], runner_type=runner_mock
    )

    gethostbyname.assert_called_with("algo-2")
    gethostbyname.assert_any_call("algo-1")


@patch("sagemaker_training.files.download_and_extract")
@patch("sagemaker_training.modules.install")
@patch.object(
    environment.Environment, "hosts", return_value=["algo-1", "algo-2"], new_callable=PropertyMock
)
@patch("socket.gethostbyname")
def test_run_skips_hostname_resolution_in_studio_local_mode(
    gethostbyname, install, hosts, download_and_extract
):
    """In Studio local mode, hostname resolution is skipped entirely."""
    os.environ["SM_STUDIO_LOCAL_MODE"] = "True"
    runner_mock = MagicMock(spec=process.ProcessRunner)
    entry_point.run(
        uri="s3://url", user_entry_point="launcher.py", args=["42"], runner_type=runner_mock
    )
    gethostbyname.assert_not_called()
    # Clean up so later tests are not affected by the env var.
    del os.environ["SM_STUDIO_LOCAL_MODE"]


@patch("sagemaker_training.files.download_and_extract")
@patch("sagemaker_training.modules.install")
@patch.object(
    environment.Environment, "hosts", return_value=["algo-1", "algo-2"], new_callable=PropertyMock
)
@patch("socket.gethostbyname")
def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract):
    """run() retries a host until DNS resolution succeeds, then moves on."""
    # First two lookups of algo-1 fail; run() must retry before resolving algo-2.
    gethostbyname.side_effect = [ValueError(), ValueError(), True, True]

    runner_mock = MagicMock(spec=process.ProcessRunner)
    entry_point.run(
        uri="s3://url", user_entry_point="launcher.py", args=["42"], runner_type=runner_mock
    )

    gethostbyname.assert_has_calls([call("algo-1"), call("algo-1"), call("algo-1"), call("algo-2")])


@patch("sagemaker_training.files.download_and_extract")
@patch("os.chmod")
@patch("socket.gethostbyname")
def test_run_module_no_wait(gethostbyname, chmod, download_and_extract):
    """wait=False is forwarded to the runner as run(False, False)."""
    runner_mock = MagicMock(spec=process.ProcessRunner)

    module_name = "default_user_module_name"
    entry_point.run(
        uri="s3://url",
        user_entry_point=module_name,
        args=["42"],
        wait=False,
        runner_type=runner_mock,
    )

    runner_mock.run.assert_called_with(False, False)


@patch("sys.path")
@patch("sagemaker_training.runner.get")
@patch("sagemaker_training.files.download_and_extract")
@patch("os.chmod")
@patch("socket.gethostbyname")
def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """User env vars are merged with PYTHONPATH and passed to the runner."""
    module_name = "default_user_module_name"
    args = ["--some-arg", "42"]
    entry_point.run(
        uri="s3://url", user_entry_point=module_name, args=args, env_vars={"FOO": "BAR"}
    )

    expected_env_vars = {"FOO": "BAR", "PYTHONPATH": ""}
    get_runner.assert_called_with(
        runner.ProcessRunnerType, module_name, args, expected_env_vars, None
    )


@patch("sys.path")
@patch("sagemaker_training.runner.get")
@patch("sagemaker_training.files.download_and_extract")
@patch("os.chmod")
@patch("socket.gethostbyname")
def test_run_module_with_extra_opts(
    gethostbyname, chmod, download_and_extract, get_runner, sys_path
):
    """extra_opts are forwarded untouched to runner.get()."""
    module_name = "default_user_module_name"
    args = ["--some-arg", "42"]
    extra_opts = {"foo": "bar"}

    entry_point.run(uri="s3://url", user_entry_point=module_name, args=args, extra_opts=extra_opts)
    get_runner.assert_called_with(runner.ProcessRunnerType, module_name, args, {}, extra_opts)
from __future__ import absolute_import

from mock import patch
import pytest

from sagemaker_training import _entry_point_type


@pytest.fixture
def entry_point_type_module():
    """Simulate a code directory that contains setup.py (a Python package)."""
    with patch("os.listdir", lambda _path: ("setup.py",)):
        yield


@pytest.fixture(autouse=True)
def entry_point_type_script():
    """Default (autouse): simulate an empty code directory (plain script)."""
    with patch("os.listdir", lambda _path: ()):
        yield


@pytest.fixture()
def has_requirements():
    """Simulate the presence of a requirements.txt file."""
    with patch("os.path.exists", lambda candidate: candidate.endswith("requirements.txt")):
        yield


def test_get_package(entry_point_type_module):
    """A .py entry point alongside setup.py is classified as PYTHON_PACKAGE."""
    assert _entry_point_type.get("bla", "program.py") == _entry_point_type.PYTHON_PACKAGE


def test_get_command(entry_point_type_script):
    """A non-Python entry point is classified as a shell COMMAND."""
    assert _entry_point_type.get("bla", "program.sh") == _entry_point_type.COMMAND


def test_get_program():
    """A .py entry point without setup.py is classified as PYTHON_PROGRAM."""
    assert _entry_point_type.get("bla", "program.py") == _entry_point_type.PYTHON_PROGRAM
from __future__ import absolute_import

from sagemaker_training import errors


def test_install_module_error():
    """InstallModuleError renders exit code, empty message, and command."""
    err = errors.InstallModuleError(["python", "-m", "42"], return_code=42)
    expected = (
        'InstallModuleError:\nExitCode 42\nErrorMessage ""\n'
        "Command \"['python', '-m', '42']\""
    )
    assert str(err) == expected


def test_execute_user_script_error():
    """ExecuteUserScriptError renders exit code, empty message, and command."""
    err = errors.ExecuteUserScriptError(["python", "-m", "42"], return_code=42)
    expected = (
        'ExecuteUserScriptError:\nExitCode 42\nErrorMessage ""\n'
        "Command \"['python', '-m', '42']\""
    )
    assert str(err) == expected


def test_install_module_error_with_output():
    """A string output is included verbatim in the error message."""
    err = errors.InstallModuleError(["python", "-m", "42"], return_code=42, output="42")
    expected = (
        'InstallModuleError:\nExitCode 42\nErrorMessage "42"\n'
        "Command \"['python', '-m', '42']\""
    )
    assert str(err) == expected


def test_execute_user_script_error_with_output():
    """A bytes output is decoded before being included in the message."""
    err = errors.ExecuteUserScriptError(["python", "-m", "42"], return_code=137, output=b"42")
    expected = (
        'ExecuteUserScriptError:\nExitCode 137\nErrorMessage "42"\n'
        "Command \"['python', '-m', '42']\""
    )
    assert str(err) == expected


def test_execute_user_script_error_with_output_and_info():
    """Extra info (e.g. signal name) is rendered between message and command."""
    err = errors.ExecuteUserScriptError(
        ["python", "-m", "42"], return_code=137, output="42", info="SIGKILL"
    )
    expected = (
        'ExecuteUserScriptError:\nExitCode 137\nErrorMessage "42"\n'
        "ExtraInfo \"SIGKILL\"\nCommand \"['python', '-m', '42']\""
    )
    assert str(err) == expected
import itertools
import logging
import os
import tarfile

from mock import mock_open, patch
import pytest
import six

from sagemaker_training import environment, files
import test

# Dotted path of the builtin open(); differs between Python 2 and 3.
builtins_open = "__builtin__.open" if six.PY2 else "builtins.open"

RESOURCE_CONFIG = dict(current_host="algo-1", hosts=["algo-1", "algo-2", "algo-3"])

INPUT_DATA_CONFIG = {
    "train": {
        "ContentType": "trainingContentType",
        "TrainingInputMode": "File",
        "S3DistributionType": "FullyReplicated",
        "RecordWrapperType": "None",
    },
    "validation": {
        "TrainingInputMode": "File",
        "S3DistributionType": "FullyReplicated",
        "RecordWrapperType": "None",
    },
}

USER_HYPERPARAMETERS = dict(batch_size=32, learning_rate=0.001)
SAGEMAKER_HYPERPARAMETERS = {
    "sagemaker_region": "us-west-2",
    "default_user_module_name": "net",
    "sagemaker_job_name": "sagemaker-training-job",
    "sagemaker_program": "main.py",
    "sagemaker_submit_directory": "imagenet",
    "sagemaker_enable_cloudwatch_metrics": True,
    "sagemaker_container_log_level": logging.WARNING,
}

# User and SageMaker hyperparameters merged into a single dict.
ALL_HYPERPARAMETERS = dict(
    itertools.chain(USER_HYPERPARAMETERS.items(), SAGEMAKER_HYPERPARAMETERS.items())
)


def test_read_json():
    """read_json round-trips a dict written to the hyperparameters file."""
    test.write_json(ALL_HYPERPARAMETERS, environment.hyperparameters_file_dir)

    assert files.read_json(environment.hyperparameters_file_dir) == ALL_HYPERPARAMETERS


def test_read_json_throws_exception():
    """Reading a missing JSON file raises IOError."""
    with pytest.raises(IOError):
        files.read_json("non-existent.json")


def test_read_file():
    """read_file returns the raw file contents (JSON quoting included)."""
    test.write_json("test", environment.hyperparameters_file_dir)

    assert files.read_file(environment.hyperparameters_file_dir) == '"test"'


@patch("tempfile.mkdtemp")
@patch("shutil.rmtree")
def test_tmpdir(rmtree, mkdtemp):
    """tmpdir creates a temporary directory and removes it on exit."""
    with files.tmpdir():
        mkdtemp.assert_called()
    rmtree.assert_called()


@patch("tempfile.mkdtemp")
@patch("shutil.rmtree")
def test_tmpdir_with_args(rmtree, mkdtemp):
    """tmpdir forwards suffix/prefix/dir arguments to mkdtemp."""
    with files.tmpdir("suffix", "prefix", "/tmp"):
        mkdtemp.assert_called_with(dir="/tmp", prefix="prefix", suffix="suffix")
    rmtree.assert_called()


@patch(builtins_open, mock_open())
def test_write_file():
    """write_file writes with mode 'w' by default and honors an explicit mode."""
    files.write_file("/tmp/my-file", "42")
    open.assert_called_with("/tmp/my-file", "w")
    open().write.assert_called_with("42")

    files.write_file("/tmp/my-file", "42", "a")
    open.assert_called_with("/tmp/my-file", "a")
    open().write.assert_called_with("42")


@patch(builtins_open, mock_open())
def test_write_success_file():
    """write_success_file creates an empty 'success' marker in output_dir."""
    file_path = os.path.join(environment.output_dir, "success")
    empty_msg = ""
    files.write_success_file()
    open.assert_called_with(file_path, "w")
    open().write.assert_called_with(empty_msg)


@patch(builtins_open, mock_open())
def test_write_failure_file():
    """write_failure_file writes the failure reason to the 'failure' marker."""
    file_path = os.path.join(environment.output_dir, "failure")
    failure_msg = "This is a failure"
    files.write_failure_file(failure_msg)
    open.assert_called_with(file_path, "w")
    open().write.assert_called_with(failure_msg)


@patch("sagemaker_training.files.s3_download")
@patch("os.path.isdir", lambda x: True)
@patch("shutil.rmtree")
@patch("shutil.copytree")
def test_download_and_extract_source_dir(copy, rmtree, s3_download):
    """A local directory URI is copied, not downloaded from S3."""
    uri = environment.channel_path("code")
    files.download_and_extract(uri, environment.code_dir)
    s3_download.assert_not_called()

    rmtree.assert_any_call(environment.code_dir)
    copy.assert_called_with(uri, environment.code_dir)


@patch("sagemaker_training.files.s3_download")
@patch("os.path.isdir", lambda x: False)
@patch("shutil.copy2")
def test_download_and_extract_file(copy, s3_download):
    """A local plain-file URI is copied into code_dir."""
    uri = __file__
    files.download_and_extract(uri, environment.code_dir)

    s3_download.assert_not_called()
    copy.assert_called_with(uri, environment.code_dir)


@patch("sagemaker_training.files.s3_download")
@patch("os.path.isdir", lambda x: False)
@patch("tarfile.TarFile.extractall")
def test_download_and_extract_tar(extractall, s3_download):
    """A local tarball URI is extracted into code_dir."""
    # Creates a real (empty) tar.gz on disk; removed at the end of the test.
    t = tarfile.open(name="test.tar.gz", mode="w:gz")
    t.close()
    uri = t.name
    files.download_and_extract(uri, environment.code_dir)

    s3_download.assert_not_called()
    extractall.assert_called_with(path=environment.code_dir)

    os.remove(uri)
See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import inspect 16 | 17 | import pytest as pytest 18 | 19 | from sagemaker_training import functions 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "fn, expected", 24 | [ 25 | (lambda: None, inspect.ArgSpec([], None, None, None)), 26 | (lambda x, y="y": None, inspect.ArgSpec(["x", "y"], None, None, ("y",))), 27 | (lambda *args: None, inspect.ArgSpec([], "args", None, None)), 28 | (lambda **kwargs: None, inspect.ArgSpec([], None, "kwargs", None)), 29 | (lambda x, y, *args, **kwargs: None, inspect.ArgSpec(["x", "y"], "args", "kwargs", None)), 30 | ], 31 | ) 32 | def test_getargspec(fn, expected): 33 | assert functions.getargspec(fn) == expected 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "fn, env, expected", 38 | [ 39 | (lambda: None, {}, {}), 40 | (lambda x, y="y": None, dict(x="x", y=None, t=3), dict(x="x", y=None)), 41 | (lambda not_in_env_arg: None, dict(x="x", y=None, t=3), {}), 42 | (lambda *args: None, dict(x="x", y=None, t=3), {}), 43 | (lambda *arguments, **keywords: None, dict(x="x", y=None, t=3), dict(x="x", y=None, t=3)), 44 | (lambda **kwargs: None, dict(x="x", y=None, t=3), dict(x="x", y=None, t=3)), 45 | ], 46 | ) 47 | def test_matching_args(fn, env, expected): 48 | assert functions.matching_args(fn, env) == expected 49 | 50 | 51 | def test_error_wrapper(): 52 | assert functions.error_wrapper(lambda x: x * 10, NotImplementedError)(3) == 30 53 | 54 | 55 | def test_error_wrapper_exception(): 56 | with pytest.raises(NotImplementedError) as e: 57 | functions.error_wrapper(lambda x: x, NotImplementedError)(2, 3) 58 | assert isinstance(e.value.args[0], TypeError) 59 | -------------------------------------------------------------------------------- /test/unit/test_intermediate_output.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, 
Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import os 16 | 17 | from inotify_simple import Event, flags 18 | from mock import MagicMock, patch 19 | import pytest 20 | 21 | from sagemaker_training import environment, files, intermediate_output 22 | 23 | REGION = "us-west" 24 | S3_BUCKET = "s3://mybucket/" 25 | 26 | 27 | def test_accept_file_output_no_process(): 28 | intemediate_sync = intermediate_output.start_sync("file://my/favorite/file", REGION) 29 | assert intemediate_sync is None 30 | 31 | 32 | def test_wrong_output(): 33 | with pytest.raises(ValueError) as e: 34 | intermediate_output.start_sync("tcp://my/favorite/url", REGION) 35 | assert "Expecting 's3' scheme" in str(e) 36 | 37 | 38 | @patch("inotify_simple.INotify", MagicMock()) 39 | @patch("multiprocessing.Process.start", MagicMock()) 40 | @patch("multiprocessing.Process.join", MagicMock()) 41 | @patch("boto3.client", MagicMock()) 42 | def test_daemon_process(): 43 | intemediate_sync = intermediate_output.start_sync(S3_BUCKET, REGION) 44 | assert intemediate_sync.daemon is True 45 | 46 | 47 | @patch("boto3.client", MagicMock()) 48 | @patch("shutil.copy2") 49 | @patch("inotify_simple.INotify") 50 | @patch("boto3.s3.transfer.S3Transfer.upload_file") 51 | @patch("multiprocessing.Process") 52 | def test_non_write_ignored(process_mock, upload_file, inotify_mock, copy2): 53 | process = process_mock.return_value 54 | inotify = 
inotify_mock.return_value 55 | 56 | inotify.add_watch.return_value = 1 57 | mask = flags.CREATE 58 | for flag in flags: 59 | if flag is not flags.CLOSE_WRITE and flag is not flags.ISDIR: 60 | mask = mask | flag 61 | inotify.read.return_value = [Event(1, mask, "cookie", "file_name")] 62 | 63 | def watch(): 64 | call = process_mock.call_args 65 | args, kwargs = call 66 | intermediate_output._watch( 67 | kwargs["args"][0], kwargs["args"][1], kwargs["args"][2], kwargs["args"][3] 68 | ) 69 | 70 | process.start.side_effect = watch 71 | 72 | files.write_success_file() 73 | intermediate_output.start_sync(S3_BUCKET, REGION) 74 | 75 | inotify.add_watch.assert_called() 76 | inotify.read.assert_called() 77 | copy2.assert_not_called() 78 | upload_file.assert_not_called() 79 | 80 | 81 | @patch("boto3.client", MagicMock()) 82 | @patch("shutil.copy2") 83 | @patch("inotify_simple.INotify") 84 | @patch("boto3.s3.transfer.S3Transfer.upload_file") 85 | @patch("multiprocessing.Process") 86 | def test_modification_triggers_upload(process_mock, upload_file, inotify_mock, copy2): 87 | process = process_mock.return_value 88 | inotify = inotify_mock.return_value 89 | 90 | inotify.add_watch.return_value = 1 91 | inotify.read.return_value = [Event(1, flags.CLOSE_WRITE, "cookie", "file_name")] 92 | 93 | def watch(): 94 | call = process_mock.call_args 95 | args, kwargs = call 96 | intermediate_output._watch( 97 | kwargs["args"][0], kwargs["args"][1], kwargs["args"][2], kwargs["args"][3] 98 | ) 99 | 100 | process.start.side_effect = watch 101 | 102 | files.write_success_file() 103 | intermediate_output.start_sync(S3_BUCKET, REGION) 104 | 105 | inotify.add_watch.assert_called() 106 | inotify.read.assert_called() 107 | copy2.assert_called() 108 | upload_file.assert_called() 109 | 110 | 111 | @patch("boto3.client", MagicMock()) 112 | @patch("shutil.copy2") 113 | @patch("inotify_simple.INotify") 114 | @patch("boto3.s3.transfer.S3Transfer.upload_file") 115 | @patch("multiprocessing.Process") 116 | 
def test_new_folders_are_watched(process_mock, upload_file, inotify_mock, copy2): 117 | process = process_mock.return_value 118 | inotify = inotify_mock.return_value 119 | 120 | new_dir = "new_dir" 121 | new_dir_path = os.path.join(environment.output_intermediate_dir, new_dir) 122 | inotify.add_watch.return_value = 1 123 | inotify.read.return_value = [Event(1, flags.CREATE | flags.ISDIR, "cookie", new_dir)] 124 | 125 | def watch(): 126 | os.makedirs(new_dir_path) 127 | 128 | call = process_mock.call_args 129 | args, kwargs = call 130 | intermediate_output._watch( 131 | kwargs["args"][0], kwargs["args"][1], kwargs["args"][2], kwargs["args"][3] 132 | ) 133 | 134 | process.start.side_effect = watch 135 | 136 | files.write_success_file() 137 | intermediate_output.start_sync(S3_BUCKET, REGION) 138 | 139 | watch_flags = flags.CLOSE_WRITE | flags.CREATE 140 | inotify.add_watch.assert_any_call(environment.output_intermediate_dir, watch_flags) 141 | inotify.add_watch.assert_any_call(new_dir_path, watch_flags) 142 | inotify.read.assert_called() 143 | copy2.assert_not_called() 144 | upload_file.assert_not_called() 145 | -------------------------------------------------------------------------------- /test/unit/test_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | from mock import MagicMock, patch 16 | import pytest 17 | 18 | from sagemaker_training import mpi, process, pytorch_xla, runner 19 | 20 | USER_SCRIPT = "script" 21 | CMD_ARGS = ["--some-arg", 42] 22 | ENV_VARS = {"FOO": "BAR"} 23 | DEFAULT_PROC_PER_HOST = 1 24 | 25 | NCCL_DEBUG_MPI_OPT = "-X NCCL_DEBUG=WARN" 26 | MPI_OPTS = { 27 | "sagemaker_mpi_num_of_processes_per_host": 2, 28 | "sagemaker_mpi_num_processes": 4, 29 | "sagemaker_mpi_custom_mpi_options": NCCL_DEBUG_MPI_OPT, 30 | } 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "runner_class", [process.ProcessRunner, mpi.MasterRunner, mpi.WorkerRunner] 35 | ) 36 | def test_get_runner_returns_runnner_itself(runner_class): 37 | runner_mock = MagicMock(spec=runner_class) 38 | 39 | assert runner.get(runner_mock) == runner_mock 40 | 41 | 42 | @patch("sagemaker_training.environment.Environment") 43 | def test_get_runner_by_process_returns_runnner(training_env): 44 | test_runner = runner.get(runner.ProcessRunnerType) 45 | 46 | assert isinstance(test_runner, process.ProcessRunner) 47 | training_env().to_cmd_args.assert_called() 48 | training_env().to_env_vars.assert_called() 49 | 50 | 51 | @patch("sagemaker_training.environment.Environment") 52 | def test_get_runner_by_process_with_extra_args(training_env): 53 | test_runner = runner.get(runner.ProcessRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS) 54 | 55 | assert isinstance(test_runner, process.ProcessRunner) 56 | 57 | assert test_runner._user_entry_point == USER_SCRIPT 58 | assert test_runner._args == CMD_ARGS 59 | assert test_runner._env_vars == ENV_VARS 60 | 61 | training_env().to_cmd_args.assert_not_called() 62 | training_env().to_env_vars.assert_not_called() 63 | training_env().user_entry_point.assert_not_called() 64 | 65 | 66 | @patch("sagemaker_training.environment.Environment") 67 | def test_get_runner_by_mpi_returns_runnner(training_env): 68 | training_env().num_gpus = 0 69 | training_env().num_neurons = 0 70 | 71 
| test_runner = runner.get(runner.MPIRunnerType) 72 | 73 | assert isinstance(test_runner, mpi.MasterRunner) 74 | training_env().to_cmd_args.assert_called() 75 | training_env().to_env_vars.assert_called() 76 | 77 | training_env().is_master = False 78 | test_runner = runner.get(runner.MPIRunnerType) 79 | 80 | assert isinstance(test_runner, mpi.WorkerRunner) 81 | training_env().to_cmd_args.assert_called() 82 | training_env().to_env_vars.assert_called() 83 | 84 | 85 | @patch("sagemaker_training.environment.Environment") 86 | def test_runnner_with_default_cpu_processes_per_host(training_env): 87 | training_env().additional_framework_parameters = dict() 88 | training_env().num_gpus = 0 89 | training_env().num_neurons = 0 90 | 91 | test_runner = runner.get(runner.MPIRunnerType) 92 | 93 | assert isinstance(test_runner, mpi.MasterRunner) 94 | assert test_runner._processes_per_host == 1 95 | 96 | 97 | @patch("sagemaker_training.environment.Environment") 98 | def test_runnner_with_default_gpu_processes_per_host(training_env): 99 | training_env().additional_framework_parameters = dict() 100 | training_env().num_gpus = 2 101 | training_env().num_neurons = 0 102 | 103 | test_runner = runner.get(runner.MPIRunnerType) 104 | 105 | assert isinstance(test_runner, mpi.MasterRunner) 106 | assert test_runner._processes_per_host == 2 107 | 108 | 109 | @patch("sagemaker_training.environment.Environment") 110 | def test_runnner_with_default_neuron_cores_per_host(training_env): 111 | training_env().additional_framework_parameters = dict() 112 | training_env().num_gpus = 0 113 | training_env().num_neurons = 2 114 | 115 | test_runner = runner.get(runner.MPIRunnerType) 116 | 117 | assert isinstance(test_runner, mpi.MasterRunner) 118 | assert test_runner._processes_per_host == 2 119 | 120 | 121 | @patch("sagemaker_training.environment.Environment") 122 | def test_get_runner_by_mpi_with_extra_args(training_env): 123 | training_env().num_gpus = 0 124 | training_env().num_neurons = 0 125 | 126 | 
test_runner = runner.get(runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS, MPI_OPTS) 127 | 128 | assert isinstance(test_runner, mpi.MasterRunner) 129 | 130 | assert test_runner._user_entry_point == USER_SCRIPT 131 | assert test_runner._args == CMD_ARGS 132 | assert test_runner._env_vars == ENV_VARS 133 | assert test_runner._processes_per_host == 2 134 | assert test_runner._num_processes == 4 135 | assert test_runner._custom_mpi_options == NCCL_DEBUG_MPI_OPT 136 | 137 | training_env().to_cmd_args.assert_not_called() 138 | training_env().to_env_vars.assert_not_called() 139 | training_env().user_entry_point.assert_not_called() 140 | training_env().additional_framework_parameters.assert_not_called() 141 | 142 | training_env().is_master = False 143 | test_runner = runner.get(runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS) 144 | 145 | assert isinstance(test_runner, mpi.WorkerRunner) 146 | 147 | assert test_runner._user_entry_point == USER_SCRIPT 148 | assert test_runner._args == CMD_ARGS 149 | assert test_runner._env_vars == ENV_VARS 150 | 151 | training_env().to_cmd_args.assert_not_called() 152 | training_env().to_env_vars.assert_not_called() 153 | training_env().user_entry_point.assert_not_called() 154 | 155 | 156 | def test_get_runner_invalid_identifier(): 157 | with pytest.raises(ValueError): 158 | runner.get(42) 159 | 160 | 161 | @patch("sagemaker_training.environment.Environment") 162 | def test_get_runner_by_pt_xla_returns_runnner(training_env): 163 | training_env().num_gpus = 8 164 | 165 | for is_master in [True, False]: 166 | training_env().is_master = is_master 167 | test_runner = runner.get(runner.PyTorchXLARunnerType) 168 | 169 | assert isinstance(test_runner, pytorch_xla.PyTorchXLARunner) 170 | training_env().to_cmd_args.assert_called() 171 | training_env().to_env_vars.assert_called() 172 | -------------------------------------------------------------------------------- /test/unit/test_timeout.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import time 16 | 17 | import pytest 18 | 19 | from sagemaker_training import timeout 20 | 21 | 22 | def test_timeout(): 23 | sec = 2 24 | with pytest.raises(timeout.TimeoutError): 25 | with timeout.timeout(seconds=sec): 26 | print("Waiting and testing timeout, it should happen in {} seconds.".format(sec)) 27 | time.sleep(sec + 1) 28 | -------------------------------------------------------------------------------- /test/unit/test_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | import errno 14 | import logging 15 | 16 | from mock import MagicMock, Mock, patch 17 | 18 | from sagemaker_training import errors, runner, trainer 19 | 20 | 21 | class Environment(Mock): 22 | framework_module = "my_framework:entry_point" 23 | log_level = 20 24 | 25 | def sagemaker_s3_output(self): 26 | return "s3://bucket" 27 | 28 | 29 | class EnvironmentNoIntermediate(Mock): 30 | framework_module = "my_framework:entry_point" 31 | log_level = 20 32 | 33 | def sagemaker_s3_output(self): 34 | return None 35 | 36 | 37 | class ScriptEnvironment(Environment): 38 | framework_module = None 39 | current_instance_group = "Test1" 40 | distribution_instance_groups = ["Test1"] 41 | 42 | def sagemaker_s3_output(self): 43 | return "s3://bucket" 44 | 45 | 46 | @patch("inotify_simple.INotify", MagicMock()) 47 | @patch("multiprocessing.Process.start", MagicMock()) 48 | @patch("multiprocessing.Process.join", MagicMock()) 49 | @patch("boto3.client", MagicMock()) 50 | @patch("importlib.import_module") 51 | @patch("sagemaker_training.environment.Environment", Environment) 52 | def test_train(import_module): 53 | framework = Mock() 54 | import_module.return_value = framework 55 | trainer.train() 56 | import_module.assert_called_with("my_framework") 57 | framework.entry_point.assert_called() 58 | 59 | 60 | @patch("inotify_simple.INotify", MagicMock()) 61 | @patch("multiprocessing.Process.start", MagicMock()) 62 | @patch("multiprocessing.Process.join", MagicMock()) 63 | @patch("boto3.client", MagicMock()) 64 | @patch("importlib.import_module") 65 | @patch("sagemaker_training.environment.Environment", Environment) 66 | @patch("sagemaker_training.trainer._exit_processes") 67 | def test_train_with_success(_exit, import_module): 68 | def success(): 69 | pass 70 | 71 | framework = Mock(entry_point=success) 72 | import_module.return_value = framework 73 | 74 | trainer.train() 75 | 76 | _exit.assert_called_with(trainer.SUCCESS_CODE) 77 | 78 | 79 | @patch("inotify_simple.INotify", 
MagicMock()) 80 | @patch("multiprocessing.Process.start", MagicMock()) 81 | @patch("multiprocessing.Process.join", MagicMock()) 82 | @patch("boto3.client", MagicMock()) 83 | @patch("importlib.import_module") 84 | @patch("sagemaker_training.environment.Environment", Environment) 85 | @patch("sagemaker_training.trainer._exit_processes") 86 | def test_train_fails(_exit, import_module): 87 | def fail(): 88 | raise OSError(errno.ENOENT, "No such file or directory") 89 | 90 | framework = Mock(entry_point=fail) 91 | import_module.return_value = framework 92 | 93 | trainer.train() 94 | 95 | _exit.assert_called_with(errno.ENOENT) 96 | 97 | 98 | @patch("inotify_simple.INotify", MagicMock()) 99 | @patch("boto3.client", MagicMock()) 100 | @patch("importlib.import_module") 101 | @patch("sagemaker_training.environment.Environment", Environment) 102 | @patch("sagemaker_training.trainer._exit_processes") 103 | def test_train_fails_with_no_error_number(_exit, import_module): 104 | def fail(): 105 | raise Exception("No errno defined.") 106 | 107 | framework = Mock(entry_point=fail) 108 | import_module.return_value = framework 109 | 110 | trainer.train() 111 | 112 | _exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE) 113 | 114 | 115 | @patch("inotify_simple.INotify", MagicMock()) 116 | @patch("boto3.client", MagicMock()) 117 | @patch("importlib.import_module") 118 | @patch("sagemaker_training.environment.Environment", Environment) 119 | @patch("sagemaker_training.trainer._exit_processes") 120 | def test_train_fails_with_invalid_error_number(_exit, import_module): 121 | class InvalidErrorNumberExceptionError(Exception): 122 | def __init__(self, *args, **kwargs): # real signature unknown 123 | self.errno = "invalid" 124 | 125 | def fail(): 126 | raise InvalidErrorNumberExceptionError("No such file or directory") 127 | 128 | framework = Mock(entry_point=fail) 129 | import_module.return_value = framework 130 | 131 | trainer.train() 132 | 133 | 
_exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE) 134 | 135 | 136 | @patch("inotify_simple.INotify", MagicMock()) 137 | @patch("boto3.client", MagicMock()) 138 | @patch("importlib.import_module") 139 | @patch("sagemaker_training.environment.Environment", Environment) 140 | @patch("sagemaker_training.trainer._exit_processes") 141 | def test_train_with_client_error(_exit, import_module): 142 | def fail(): 143 | raise errors.ClientError(errno.ENOENT, "No such file or directory") 144 | 145 | framework = Mock(entry_point=fail) 146 | import_module.return_value = framework 147 | 148 | trainer.train() 149 | 150 | _exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE) 151 | 152 | 153 | @patch("inotify_simple.INotify", MagicMock()) 154 | @patch("boto3.client", MagicMock()) 155 | @patch("multiprocessing.Process.start", MagicMock()) 156 | @patch("multiprocessing.Process.join", MagicMock()) 157 | @patch("sagemaker_training.entry_point.run") 158 | @patch("sagemaker_training.environment.Environment", new_callable=ScriptEnvironment) 159 | @patch("sagemaker_training.trainer._exit_processes") 160 | def test_train_script(_exit, training_env, run): 161 | trainer.train() 162 | 163 | env = training_env() 164 | run.assert_called_with( 165 | env.module_dir, 166 | env.user_entry_point, 167 | env.to_cmd_args(), 168 | env.to_env_vars(), 169 | runner_type=runner.RunnerType.MPI, 170 | ) 171 | 172 | _exit.assert_called_with(trainer.SUCCESS_CODE) 173 | 174 | 175 | @patch("importlib.import_module") 176 | @patch("sagemaker_training.intermediate_output.start_sync") 177 | @patch("sagemaker_training.environment.Environment", EnvironmentNoIntermediate) 178 | def test_train_no_intermediate(start_intermediate_folder_sync, import_module): 179 | framework = Mock() 180 | import_module.return_value = framework 181 | trainer.train() 182 | 183 | import_module.assert_called_with("my_framework") 184 | framework.entry_point.assert_called() 185 | start_intermediate_folder_sync.asser_not_called() 186 | 187 | 
188 | @patch("inotify_simple.INotify", MagicMock()) 189 | @patch("multiprocessing.Process.start", MagicMock()) 190 | @patch("multiprocessing.Process.join", MagicMock()) 191 | @patch("boto3.client", MagicMock()) 192 | @patch("importlib.import_module") 193 | @patch("sagemaker_training.environment.Environment", Environment) 194 | @patch("sagemaker_training.trainer._exit_processes") 195 | def test_train_with_smtrainingcompiler_error(_exit, import_module, caplog): 196 | def fail(): 197 | from .dummy.tensorflow.compiler.xla import dummy_xla 198 | 199 | dummy_xla.dummy_xla_function() 200 | 201 | framework = Mock(entry_point=fail) 202 | import_module.return_value = framework 203 | with caplog.at_level(logging.INFO): 204 | trainer.train() 205 | expected_errmsg = "SMTrainingCompiler Error:" 206 | unexpected_errmsg = "Framework Error:" 207 | assert expected_errmsg in caplog.text 208 | assert unexpected_errmsg not in caplog.text 209 | 210 | 211 | @patch("inotify_simple.INotify", MagicMock()) 212 | @patch("boto3.client", MagicMock()) 213 | @patch("importlib.import_module") 214 | @patch("sagemaker_training.environment.Environment", Environment) 215 | @patch("sagemaker_training.trainer._exit_processes") 216 | def test_train_with_framework_error(_exit, import_module, caplog): 217 | def fail(): 218 | from .dummy import dummy 219 | 220 | dummy.dummy_function() 221 | 222 | framework = Mock(entry_point=fail) 223 | import_module.return_value = framework 224 | with caplog.at_level(logging.INFO): 225 | trainer.train() 226 | unexpected_errmsg = "SMTrainingCompiler Error:" 227 | expected_errmsg = "Framework Error:" 228 | assert unexpected_errmsg not in caplog.text 229 | assert expected_errmsg in caplog.text 230 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. 
This configuration file will run the 3 | # test suite on all supported Python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = black-format,flake8,pylint,twine,py38,py39,py310 8 | 9 | skip_missing_interpreters = False 10 | 11 | [flake8] 12 | max-line-length = 120 13 | exclude = 14 | build/ 15 | .git 16 | __pycache__ 17 | .tox 18 | tests/data/ 19 | *venv/ 20 | ./src/sagemaker_training/record_pb2.py 21 | 22 | max-complexity = 15 23 | 24 | ignore = 25 | E203, # whitespace before ':': Black disagrees with and explicitly violates this. 26 | W503 # Ignore line break before binary operator, since Black violates this. 27 | 28 | builtins = FileNotFoundError 29 | 30 | [testenv] 31 | passenv = 32 | AWS_ACCESS_KEY_ID 33 | AWS_SECRET_ACCESS_KEY 34 | AWS_SESSION_TOKEN 35 | AWS_CONTAINER_CREDENTIALS_RELATIVE_URI 36 | AWS_DEFAULT_REGION 37 | 38 | # {posargs} can be passed in by additional arguments specified when invoking tox. 39 | # Can be used to specify which tests to run, e.g.: tox -- -s 40 | commands = 41 | coverage run --rcfile .coveragerc_{envname} --source sagemaker_training -m py.test {posargs} 42 | {env:IGNORE_COVERAGE:} coverage report --rcfile .coveragerc_{envname} 43 | {env:IGNORE_COVERAGE:} coverage html --rcfile .coveragerc_{envname} 44 | 45 | deps = 46 | pytest==6.2.5 47 | coverage 48 | pytest-cov 49 | pytest-xdist 50 | pytest-asyncio 51 | mock 52 | awslogs 53 | sagemaker[local] 54 | numpy 55 | flask 56 | gunicorn 57 | gevent 58 | paramiko==3.4.1 59 | psutil==6.0.0 60 | nest_asyncio 61 | 62 | [testenv:twine] 63 | basepython = python3.9 64 | # twine check was added starting in 1.12.0 65 | # https://github.com/pypa/twine/blob/master/docs/changelog.rst 66 | deps = 67 | twine>=1.12.0 68 | # https://packaging.python.org/guides/making-a-pypi-friendly-readme/#validating-restructuredtext-markup 69 | commands = 70 | python setup.py sdist 71 | twine check dist/*.tar.gz 72 | 73 | [testenv:flake8] 74 | basepython 
= python3.9 75 | deps = 76 | flake8 77 | pep8-naming 78 | flake8-import-order 79 | commands = flake8 --config=.flake8 80 | 81 | [testenv:black-format] 82 | # Used during development (before committing) to format .py files. 83 | basepython = python3.9 84 | deps = black==22.3.0 85 | commands = 86 | black -l 100 ./ 87 | 88 | [testenv:black-check] 89 | # Used by automated build steps to check that all files are properly formatted. 90 | basepython = python3.9 91 | deps = black==22.3.0 92 | commands = 93 | black -l 100 --check ./ 94 | 95 | [testenv:pylint] 96 | basepython = python3.9 97 | skipdist = true 98 | skip_install = true 99 | deps = 100 | pylint==2.3.1 101 | commands = 102 | python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker_training 103 | --------------------------------------------------------------------------------