├── .dockerignore ├── test ├── resources │ ├── boston │ │ ├── data │ │ │ └── empty │ │ └── single_machine_customer_script.py │ ├── data │ │ ├── libsvm │ │ │ ├── train.libsvm │ │ │ ├── libsvm_files │ │ │ │ └── train.libsvm │ │ │ └── train.libsvm.weights │ │ ├── csv │ │ │ ├── train.csv │ │ │ ├── csv_files │ │ │ │ └── train.csv │ │ │ ├── multiple_files │ │ │ │ ├── train_1.csv │ │ │ │ └── train_2.csv │ │ │ ├── train.csv.weights │ │ │ └── weighted_csv_files │ │ │ │ └── train.csv.weights │ │ ├── parquet │ │ │ ├── train.parquet │ │ │ ├── pq_files │ │ │ │ └── train.parquet │ │ │ └── multiple_files │ │ │ │ ├── train_0.parquet │ │ │ │ └── train_1.parquet │ │ └── recordio_protobuf │ │ │ ├── train.pb │ │ │ ├── sparse │ │ │ └── train.pb │ │ │ ├── pb_files │ │ │ └── train.pb │ │ │ ├── single_feature_label.pb │ │ │ └── sparse_edge_cases │ │ │ ├── diagonal.pbr │ │ │ ├── dense_as_sparse.pbr │ │ │ ├── rectangular_sparse.pbr │ │ │ ├── single_value_center.pbr │ │ │ ├── single_value_bot_left.pbr │ │ │ ├── single_value_bot_right.pbr │ │ │ ├── single_value_top_left.pbr │ │ │ └── single_value_top_right.pbr │ ├── models │ │ ├── pickled_model │ │ │ └── xgboost-model │ │ └── saved_booster │ │ │ └── xgboost-model │ ├── abalone │ │ ├── models │ │ │ └── libsvm_pickled │ │ │ │ └── xgboost-model │ │ └── abalone_distributed.py │ └── versions │ │ └── train.py ├── unit │ ├── distributed_gpu │ │ ├── __init__.py │ │ ├── test_dask_data_utils.py │ │ └── test_distributed_gpu_training.py │ ├── algorithm_toolkit │ │ ├── __init__.py │ │ ├── test_metrics.py │ │ ├── test_exceptions.py │ │ └── test_channel_validation.py │ ├── .DS_Store │ ├── algorithm_mode │ │ ├── __init__.py │ │ ├── test_channel_validation.py │ │ ├── test_serve.py │ │ ├── test_train_utils.py │ │ └── test_hyperparameter_validation.py │ ├── test_training.py │ ├── test_handler_service.py │ ├── test_serving_mms.py │ ├── test_encoder.py │ └── test_distributed.py ├── __init__.py ├── utils │ ├── __init__.py │ └── test_utils.py ├── integration │ ├── 
local │ │ ├── __init__.py │ │ ├── test_versions.py │ │ ├── test_boston.py │ │ ├── test_kfold.py │ │ └── test_early_stopping.py │ └── test_metadata_calls.py └── conftest.py ├── src ├── sagemaker_xgboost_container │ ├── dmlc_patch │ │ └── __init__.py │ ├── distributed_gpu │ │ ├── __init__.py │ │ ├── dask_data_utils.py │ │ └── dask_cluster_utils.py │ ├── .DS_Store │ ├── __init__.py │ ├── metrics │ │ └── __init__.py │ ├── constants │ │ ├── __init__.py │ │ ├── xgb_content_types.py │ │ ├── sm_env_constants.py │ │ └── xgb_constants.py │ ├── mms_patch │ │ ├── __init__.py │ │ ├── mms_transformer.py │ │ └── model_server.py │ ├── algorithm_mode │ │ ├── __init__.py │ │ ├── metadata.py │ │ ├── metrics.py │ │ ├── integration.py │ │ ├── inference_errors.py │ │ ├── channel_validation.py │ │ ├── train_utils.py │ │ └── handler_service.py │ ├── training.py │ ├── handler_service.py │ ├── callback.py │ ├── prediction_utils.py │ ├── encoder.py │ ├── serving_mms.py │ └── serving.py └── sagemaker_algorithm_toolkit │ ├── __init__.py │ ├── metrics.py │ ├── exceptions.py │ ├── metadata.py │ └── channel_validation.py ├── MANIFEST.in ├── NOTICE ├── docker ├── configs │ └── dask_configs.yaml ├── 3.0-5 │ ├── resources │ │ └── mms │ │ │ ├── endpoints-1.0.jar │ │ │ ├── config.properties.tmp │ │ │ └── ExecutionParameters.java │ └── final │ │ └── Dockerfile.cpu └── 1.7-1-1 │ ├── resources │ └── mms │ │ ├── endpoints-1.0.jar │ │ ├── config.properties.tmp │ │ └── ExecutionParameters.java │ └── final │ └── Dockerfile.cpu ├── setup.cfg ├── pyproject.toml ├── .coveragerc ├── .gitignore ├── .github └── PULL_REQUEST_TEMPLATE.md ├── CODE_OF_CONDUCT.md ├── test-requirements.txt ├── requirements.txt ├── tox.ini ├── setup.py └── CONTRIBUTING.md /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .tox 3 | -------------------------------------------------------------------------------- /test/resources/boston/data/empty: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/unit/distributed_gpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/unit/algorithm_toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/dmlc_patch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/distributed_gpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt test-requirements.txt 2 | 3 | -------------------------------------------------------------------------------- /test/resources/data/libsvm/train.libsvm: -------------------------------------------------------------------------------- 1 | 1 2:1 2 | 1 2:1 3 | 1 2:1 4 | 1 2:1 5 | 0 4:1 -------------------------------------------------------------------------------- /test/resources/data/libsvm/libsvm_files/train.libsvm: -------------------------------------------------------------------------------- 1 | 1 2:1 2 | 1 2:1 3 | 1 2:1 4 | 1 2:1 5 | 0 4:1 -------------------------------------------------------------------------------- /test/unit/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/unit/.DS_Store 
-------------------------------------------------------------------------------- /test/resources/data/csv/train.csv: -------------------------------------------------------------------------------- 1 | 0,1,0,0,0,0 2 | 0,1,0,0,0,0 3 | 0,1,0,0,0,0 4 | 0,1,0,0,0,0 5 | 1,0,1,0,0,0 -------------------------------------------------------------------------------- /test/resources/data/libsvm/train.libsvm.weights: -------------------------------------------------------------------------------- 1 | 1:0.2 2:1 2 | 1:0.2 2:1 3 | 1:0.2 2:1 4 | 1:0.2 2:1 5 | 0:0.2 4:1 -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | SageMaker XGBoost Container 2 | Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /docker/configs/dask_configs.yaml: -------------------------------------------------------------------------------- 1 | logging: 2 | distributed.comm.tcp: warning 3 | distributed.nanny: critical 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | exclude = src/sagemaker_xgboost_container/dmlc_patch/tracker.py 4 | -------------------------------------------------------------------------------- /test/resources/data/csv/csv_files/train.csv: -------------------------------------------------------------------------------- 1 | 0,1,0,0,0,0 2 | 0,1,0,0,0,0 3 | 0,1,0,0,0,0 4 | 0,1,0,0,0,0 5 | 1,0,1,0,0,0 -------------------------------------------------------------------------------- /test/resources/data/csv/multiple_files/train_1.csv: -------------------------------------------------------------------------------- 1 | 0,1,0,0,0,0 2 | 0,1,0,0,0,0 3 | 0,1,0,0,0,0 4 | 0,1,0,0,0,0 5 | 1,0,1,0,0,0 
-------------------------------------------------------------------------------- /test/resources/data/csv/multiple_files/train_2.csv: -------------------------------------------------------------------------------- 1 | 0,1,0,0,0,0 2 | 0,1,0,0,0,0 3 | 0,1,0,0,0,0 4 | 0,1,0,0,0,0 5 | 1,0,1,0,0,0 -------------------------------------------------------------------------------- /test/resources/data/csv/train.csv.weights: -------------------------------------------------------------------------------- 1 | 0,0.2,1,0,0,0,0 2 | 0,0.2,1,0,0,0,0 3 | 0,0.2,1,0,0,0,0 4 | 0,0.2,1,0,0,0,0 5 | 1,0.2,0,1,0,0,0 6 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/src/sagemaker_xgboost_container/.DS_Store -------------------------------------------------------------------------------- /test/resources/data/parquet/train.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/parquet/train.parquet -------------------------------------------------------------------------------- /docker/3.0-5/resources/mms/endpoints-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/docker/3.0-5/resources/mms/endpoints-1.0.jar -------------------------------------------------------------------------------- /docker/1.7-1-1/resources/mms/endpoints-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/docker/1.7-1-1/resources/mms/endpoints-1.0.jar -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0,<81"] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/train.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/train.pb -------------------------------------------------------------------------------- /test/resources/data/csv/weighted_csv_files/train.csv.weights: -------------------------------------------------------------------------------- 1 | 0,0.2,1,0,0,0,0 2 | 0,0.2,1,0,0,0,0 3 | 0,0.2,1,0,0,0,0 4 | 0,0.2,1,0,0,0,0 5 | 1,0.2,0,1,0,0,0 6 | -------------------------------------------------------------------------------- /test/resources/data/parquet/pq_files/train.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/parquet/pq_files/train.parquet -------------------------------------------------------------------------------- /test/resources/models/pickled_model/xgboost-model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/models/pickled_model/xgboost-model -------------------------------------------------------------------------------- /test/resources/models/saved_booster/xgboost-model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/models/saved_booster/xgboost-model -------------------------------------------------------------------------------- /.coveragerc: 
-------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | omit = 4 | */dmlc_patch/* 5 | 6 | [report] 7 | exclude_lines = 8 | if __name__ == .__main__.: 9 | -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse/train.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse/train.pb -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/pb_files/train.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/pb_files/train.pb -------------------------------------------------------------------------------- /test/resources/abalone/models/libsvm_pickled/xgboost-model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/abalone/models/libsvm_pickled/xgboost-model -------------------------------------------------------------------------------- /test/resources/data/parquet/multiple_files/train_0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/parquet/multiple_files/train_0.parquet -------------------------------------------------------------------------------- /test/resources/data/parquet/multiple_files/train_1.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/parquet/multiple_files/train_1.parquet 
-------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/single_feature_label.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/single_feature_label.pb -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | build/ 3 | src/sagemaker_xgboost_container.egg-info/ 4 | .venv 5 | *~ 6 | .tox 7 | __pycache__ 8 | .coverage* 9 | .mypy_cache/ 10 | .idea/ 11 | .DS_Store 12 | test.parquet -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/diagonal.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/diagonal.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/dense_as_sparse.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/dense_as_sparse.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/rectangular_sparse.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/rectangular_sparse.pbr -------------------------------------------------------------------------------- 
/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_center.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_center.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_bot_left.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_bot_left.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_bot_right.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_bot_right.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_top_left.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_top_left.pbr -------------------------------------------------------------------------------- /test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_top_right.pbr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-xgboost-container/HEAD/test/resources/data/recordio_protobuf/sparse_edge_cases/single_value_top_right.pbr -------------------------------------------------------------------------------- 
/.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | *Issue #, if available:* 2 | 3 | *Description of changes:* 4 | 5 | *Testing:* 6 | 7 | 8 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.1.1 # sagemaker-containers requires flask 1.1.1 2 | black 3 | coverage 4 | docker==6.1.3 # docker 7.0.0 has a breaking change: https://github.com/docker/docker-py/issues/3194#issuecomment-1848950456 5 | flake8 6 | isort 7 | mock 8 | pytest 9 | pytest-cov 10 | pytest-xdist 11 | sagemaker>=1.3.0,<2.0 12 | protobuf>=3.20.0,<=3.20.3 13 | tox 14 | tox-conda 15 | -------------------------------------------------------------------------------- /docker/1.7-1-1/resources/mms/config.properties.tmp: -------------------------------------------------------------------------------- 1 | model_store=$$SAGEMAKER_MMS_MODEL_STORE$$ 2 | load_models=$$SAGEMAKER_MMS_LOAD_MODELS$$ 3 | plugins_path=/tmp/plugins 4 | inference_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ 5 | management_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ 6 | default_workers_per_model=$$SAGEMAKER_NUM_MODEL_WORKERS$$ 7 | 
max_request_size=$$SAGEMAKER_MAX_REQUEST_SIZE$$ 8 | decode_input_request=false 9 | default_service_handler=$$SAGEMAKER_MMS_DEFAULT_HANDLER$$ 10 | job_queue_size=$$SAGEMAKER_MODEL_JOB_QUEUE_SIZE$$ 11 | preload_model=true 12 | -------------------------------------------------------------------------------- /docker/3.0-5/resources/mms/config.properties.tmp: -------------------------------------------------------------------------------- 1 | model_store=$$SAGEMAKER_MMS_MODEL_STORE$$ 2 | load_models=$$SAGEMAKER_MMS_LOAD_MODELS$$ 3 | plugins_path=/tmp/plugins 4 | inference_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ 5 | management_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ 6 | default_workers_per_model=$$SAGEMAKER_NUM_MODEL_WORKERS$$ 7 | max_request_size=$$SAGEMAKER_MAX_REQUEST_SIZE$$ 8 | decode_input_request=false 9 | default_service_handler=$$SAGEMAKER_MMS_DEFAULT_HANDLER$$ 10 | job_queue_size=$$SAGEMAKER_MODEL_JOB_QUEUE_SIZE$$ 11 | preload_model=true 12 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /test/integration/local/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /test/unit/algorithm_mode/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). 
You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /src/sagemaker_algorithm_toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. 
This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/constants/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/mms_patch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/constants/xgb_content_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | X_LIBSVM = "text/x-libsvm" 14 | LIBSVM = "text/libsvm" 15 | X_PARQUET = "application/x-parquet" 16 | X_RECORDIO_PROTOBUF = "application/x-recordio-protobuf" 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.1.1 # sagemaker-containers requires flask 1.1.1 2 | PyYAML==6.0.1 3 | Pillow==9.1.1 4 | boto3==1.17.52 5 | botocore==1.20.52 6 | cryptography==45.0.5 7 | dask[dataframe]==2024.11.2 8 | dask-cuda==24.12.00 9 | cuda-python==12.6.0 10 | gunicorn==23.0.0 11 | itsdangerous==2.0.1 12 | matplotlib==3.9.2 13 | multi-model-server==1.1.2 14 | numpy==2.1.0 15 | pandas==2.2.3 16 | # protobuf==5.27.0 17 | psutil==5.8.0 # sagemaker-containers requires psutil 5.6.7 18 | pynvml==11.4.1 # dask-cuda pynvml>=11.0.0,<12.0.0a0 19 | python-dateutil==2.8.2 20 | retrying==1.3.3 21 | requests==2.32.3 22 | sagemaker-containers==2.8.6.post2 23 | sagemaker-inference==1.5.5 24 | scipy==1.15.0 25 | scikit-learn==1.5.2 26 | urllib3==1.26.5 27 | wheel==0.45.1 28 | jinja2==2.11.3 29 | MarkupSafe==1.1.1 30 | Werkzeug==0.15.6 31 | certifi==2023.7.22 32 | gevent==23.9.1 33 | numba==0.61.0 34 | setuptools<81 35 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/algorithm_mode/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. 
See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import os 14 | 15 | from sagemaker_containers.beta.framework import env 16 | 17 | from sagemaker_xgboost_container.algorithm_mode import serve 18 | 19 | # Pre-load the model in the algorithm mode. 20 | # Otherwise, the model will be loaded when serving the first request per worker. 21 | # When the model is large, the request may timeout. 22 | if os.environ.get("SERVER_SOFTWARE") is not None and env.ServingEnv().module_name is None: 23 | serve.load_model() 24 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/algorithm_mode/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
def initialize(image_uri, hyperparameters, channels, metrics):
    """Assemble the SageMaker algorithm metadata for this container.

    :param image_uri: ECR URI of the container image.
    :param hyperparameters: hyperparameter specification object.
    :param channels: data channel specification object.
    :param metrics: metric definitions object.
    :return: combined metadata built from the training and inference specs.
    """
    # Both request and response content types for hosting/batch are the same.
    content_types = ["text/csv", "text/libsvm"]

    training_instances = metadata.get_cpu_instance_types(metadata.Product.TRAINING)
    hosting_instances = metadata.get_cpu_instance_types(metadata.Product.HOSTING)
    transform_instances = metadata.get_cpu_instance_types(metadata.Product.BATCH_TRANSFORM)

    training_spec = metadata.training_spec(
        hyperparameters, channels, metrics, image_uri, training_instances, True
    )
    inference_spec = metadata.inference_spec(
        image_uri,
        hosting_instances,
        transform_instances,
        list(content_types),
        list(content_types),
    )
    return metadata.generate_metadata(training_spec, inference_spec)
class TestMetrics(unittest.TestCase):
    """Unit tests for sagemaker_algorithm_toolkit.metrics."""

    def test_simple(self):
        """A metric looked up by name logs its value through the format string."""
        metrics = m.Metrics(
            m.Metric(
                name="test mean squared error",
                format_string="test:mse {:.3f}",
                direction=m.Metric.MINIMIZE,
                regex="test:mse ([0-9\\.]+)",
            )
        )
        # Patch the module-level `logging` used by Metric.log so the test can
        # assert on the exact formatted message instead of capturing output.
        with mock.patch("sagemaker_algorithm_toolkit.metrics.logging") as mock_logging:
            metrics["test mean squared error"].log(5.123)
            mock_logging.info.assert_called_once_with("test:mse 5.123")
# Test resource locations, resolved relative to this test file.
path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.join(path, "..", "..", "resources", "versions")
abalone_path = os.path.join(path, "..", "..", "resources", "abalone")
data_dir = os.path.join(abalone_path, "data")


def test_package_version(docker_image, opt_ml):
    """Run the version-pin check script inside the container.

    The script (resources/versions/train.py) asserts the interpreter version
    and pinned package versions; any failed assertion produces an
    output/failure marker in the opt_ml directory.
    """
    version_check_script = "train.py"

    local_mode.train(
        version_check_script,
        data_dir,
        docker_image,
        opt_ml,
        source_dir=script_path,
    )

    assert not local_mode.file_exists(opt_ml, "output/failure"), "Failure happened"
# NOTE(review): these two constants are not referenced by the visible test;
# presumably kept for other/older cases — confirm before removing.
REQUIRED_CHANNEL = "required"
NOT_REQUIRED_CHANNEL = "not_required"


class TestChannelValidation(unittest.TestCase):
    """Tests for the algorithm-mode channel validation setup."""

    def setUp(self):
        # Fresh channel specification for every test.
        self.channels = acv.initialize()

    def test_default_content_type(self):
        """A train channel without an explicit content type defaults to text/libsvm."""
        test_user_channels = {"train": {cv.TRAINING_INPUT_MODE: "File", cv.S3_DIST_TYPE: "FullyReplicated"}}
        self.channels.validate(test_user_channels)
        # validate() mutates the user channel dict in place, filling the default.
        self.assertEqual(test_user_channels["train"][cv.CONTENT_TYPE], "text/libsvm")
38 | basepython = python3 39 | skipsdist = true 40 | skip_install = true 41 | deps = black 42 | commands = 43 | black -l 120 setup.py src/ test/ 44 | 45 | [testenv:black-check] 46 | # Used by automated build steps to check that all files are properly formatted. 47 | basepython = python3 48 | skipsdist = true 49 | skip_install = true 50 | deps = black 51 | commands = 52 | black -l 120 --check setup.py src/ test/ 53 | 54 | [testenv:isort] 55 | deps = isort 56 | commands = 57 | isort setup.py src/ test/ 58 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/constants/sm_env_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
# Names of SageMaker-provided environment variables read by this container.
# TODO: Move these to sagemaker-containers

# Resource related constants
SM_CURRENT_HOST = "SM_CURRENT_HOST"
SM_HOSTS = "SM_HOSTS"
SM_NUM_GPUS = "SM_NUM_GPUS"

# Data related constants
SM_CHANNEL_TRAIN = "SM_CHANNEL_TRAIN"
SM_CHANNEL_VALIDATION = "SM_CHANNEL_VALIDATION"
SM_MODEL_DIR = "SM_MODEL_DIR"

# Training constants
SM_INPUT_TRAINING_CONFIG_FILE = "SM_INPUT_TRAINING_CONFIG_FILE"
SM_INPUT_DATA_CONFIG_FILE = "SM_INPUT_DATA_CONFIG_FILE"
SM_CHECKPOINT_CONFIG_FILE = "SM_CHECKPOINT_CONFIG_FILE"
SM_OUTPUT_DATA_DIR = "SM_OUTPUT_DATA_DIR"

# Inference constants
SAGEMAKER_INFERENCE_ENSEMBLE = "SAGEMAKER_INFERENCE_ENSEMBLE"
SAGEMAKER_INFERENCE_OUTPUT = "SAGEMAKER_INFERENCE_OUTPUT"
SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"
SAGEMAKER_BATCH = "SAGEMAKER_BATCH"

# Multiprocessing related constants
# Value assigned to thread-count env vars to pin one thread per process.
ONE_THREAD_PER_PROCESS = "1"
Audience :: Developers", 27 | "Natural Language :: English", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Programming Language :: Python", 30 | "Programming Language :: Python :: 3.6", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3.10", 35 | ], 36 | install_requires=read("requirements.txt"), 37 | extras_require={"test": read("test-requirements.txt")}, 38 | entry_points={ 39 | "console_scripts": [ 40 | "serve=sagemaker_xgboost_container.serving:serving_entrypoint", 41 | ] 42 | }, 43 | python_requires=">=3.8", 44 | ) 45 | -------------------------------------------------------------------------------- /test/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
def files_exist(opt_ml, files):
    """Assert every expected artifact exists under the opt_ml output tree."""
    missing_template = "file {} was not created"
    for expected in files:
        assert localmode.file_exists(opt_ml, expected), missing_template.format(expected)


def predict_and_assert_response_length(data, content_type):
    """Send *data* to the local endpoint and assert one prediction per input row."""
    response = localmode.request(data, content_type=content_type)
    assert len(response) == len(data)


# Adapted from https://stackoverflow.com/a/45690594
def find_two_open_ports():
    """Return two distinct TCP port numbers that were free at call time.

    Both sockets are held open simultaneously so the OS hands out different
    ephemeral ports; they are closed before returning, so a small race with
    other processes grabbing the ports remains possible.
    """
    sock_one = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock_two = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock_one), closing(sock_two):
        sock_one.bind(("", 0))
        sock_one.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock_two.bind(("", 0))
        sock_two.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock_one.getsockname()[1], sock_two.getsockname()[1]
sys.version_info.major == major and sys.version_info.minor == minor 40 | 41 | 42 | def assert_package_version(package_name, version): 43 | installed_version = pkg_resources.get_distribution(package_name).version 44 | error_message = ( 45 | f"{package_name} requires {version} but {installed_version} is installed." 46 | ) 47 | assert version == installed_version, error_message 48 | 49 | 50 | def parse_requirements(requirements): 51 | for package_equals_version in requirements.split("\n"): 52 | package, version = package_equals_version.split("==") 53 | yield package, version 54 | 55 | 56 | if __name__ == "__main__": 57 | assert_python_version(PYTHON_MAJOR_VERSION, PYTHON_MINOR_VERSION) 58 | for package, version in parse_requirements(REQUIREMENTS): 59 | assert_package_version(package, version) 60 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/algorithm_mode/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
# https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost-tuning.html
def initialize():
    """Build the Metrics collection for every supported xgboost eval metric.

    Each validation-channel metric gets a CloudWatch-compatible regex that
    pulls its value out of the xgboost training log line.
    """
    regex_template = ".*\\[[0-9]+\\].*#011validation-{}:(\\S+)"

    def _build(metric_names, direction):
        # One Metric per xgboost eval-metric name, tagged with its direction.
        return [
            m.Metric(
                name="validation:{}".format(metric_name),
                direction=direction,
                regex=regex_template.format(metric_name),
            )
            for metric_name in metric_names
        ]

    all_metrics = _build(XGB_MAXIMIZE_METRICS, m.Metric.MAXIMIZE) + _build(
        XGB_MINIMIZE_METRICS, m.Metric.MINIMIZE
    )
    return m.Metrics(*all_metrics)
# Test resource locations, resolved relative to this test file.
path = os.path.dirname(os.path.realpath(__file__))
boston_path = os.path.join(path, "..", "..", "resources", "boston")
data_dir = os.path.join(boston_path, "data")


def test_xgboost_boston_single_machine(docker_image, opt_ml):
    """Train the single-machine boston script in local mode and check its artifacts."""

    customer_script = "single_machine_customer_script.py"
    # Hyperparameters are passed dash-separated; the script maps them to xgboost args.
    hyperparameters = {
        "objective": "reg:squarederror",
        "colsample-bytree": 0.3,
        "learning-rate": 0.1,
        "max-depth": 5,
        "reg-alpha": 10,
        "n-estimators": 10,
    }

    local_mode.train(
        customer_script, data_dir, docker_image, opt_ml, hyperparameters=hyperparameters, source_dir=boston_path
    )

    # Artifacts the customer script is expected to write.
    files = ["model/xgb-boston.model", "output/data/cv_results.csv", "output/data/feature-importance-plot.png"]

    assert not local_mode.file_exists(opt_ml, "output/failure"), "Failure happened"

    test_utils.files_exist(opt_ml, files)
# Formatter definitions shared by every logging configuration below.
FORMATTERS = {
    "verbose": {
        "format": "[%(asctime)s:%(levelname)s] %(message)s",
        "datefmt": "%Y-%m-%d:%H:%M:%S",
    },
    "simple": {"format": "[%(levelname)s:%(name)s] %(message)s"},
}

# dictConfig schema: INFO-level root logger writing to stderr via StreamHandler.
CONSOLE_LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": FORMATTERS,
    "handlers": {
        "console": {
            "level": "INFO",
            "formatter": "verbose",
            "class": "logging.StreamHandler",
            "stream": None,
        },
    },
    "root": {
        "handlers": ["console"],
        "level": "INFO",
    },
}


# Named configurations selectable by setup_main_logger.
LOGGING_CONFIGS = {
    "console_only": CONSOLE_LOGGING,
}


def setup_main_logger(name):
    """Configure root logging for the main application and return a named logger.

    :param name: Name of the returned logger.
    """
    logging.config.dictConfig(LOGGING_CONFIGS["console_only"])
    return logging.getLogger(name)
class NoContentInferenceError(errors.BaseInferenceToolkitError):
    """Raised to answer a request with HTTP 204 and an empty body."""

    def __init__(self):
        super().__init__(http.client.NO_CONTENT, "", "")


class UnsupportedMediaTypeInferenceError(errors.BaseInferenceToolkitError):
    """Raised when the request content type is not supported (HTTP 415)."""

    def __init__(self, message):
        super().__init__(http.client.UNSUPPORTED_MEDIA_TYPE, message, message)


class ModelLoadInferenceError(errors.BaseInferenceToolkitError):
    """Raised when the model artifact cannot be loaded (HTTP 500)."""

    def __init__(self, message):
        detail = "Unable to load model: {}".format(message)
        super().__init__(http.client.INTERNAL_SERVER_ERROR, detail, detail)


class BadRequestInferenceError(errors.BaseInferenceToolkitError):
    """Raised when the request payload cannot be evaluated (HTTP 400)."""

    def __init__(self, message):
        detail = "Unable to evaluate payload provided: {}".format(message)
        super().__init__(http.client.BAD_REQUEST, detail, detail)
def mock_training_env(current_host="algo-1", module_dir="s3://my/script", module_name="svm", **kwargs):
    """Return a MagicMock standing in for the sagemaker-containers TrainingEnv."""
    return MagicMock(current_host=current_host, module_dir=module_dir, module_name=module_name, **kwargs)


class TestTraining(unittest.TestCase):
    """Note: The 'train' method has been mocked since this test only checks the training resource setup"""

    @patch("sagemaker_containers.beta.framework.modules.run_module")
    def test_script_mode(self, mock_run_module):
        """A user entry point routes training through run_module (script mode)."""
        env = mock_training_env()
        env.user_entry_point = "dummy_entry_point"
        training.train(env)

        mock_run_module.assert_called_with(
            "s3://my/script", env.to_cmd_args(), env.to_env_vars(), "svm", capture_error=False
        )

    @patch("sagemaker_xgboost_container.training.run_algorithm_mode")
    def test_algorithm_mode(self, mock_algorithm_mode_train):
        """No entry point and no module dir routes training to algorithm mode."""
        env = mock_training_env(module_dir="")
        env.user_entry_point = None
        training.train(env)

        mock_algorithm_mode_train.assert_called_with()
class TestExceptions(unittest.TestCase):
    """Tests for the toolkit exception hierarchy's message formatting.

    Each exception type is checked with and without a `caused_by` exception;
    a cause appends " (caused by <TypeName>)" to the message.
    """

    def test_BaseToolkitError(self):
        e = exc.BaseToolkitError()
        self.assertEqual(e.message, "unknown error occurred")

    def test_BaseToolkitError_ValueError(self):
        e = exc.BaseToolkitError(caused_by=ValueError("abc"))
        self.assertEqual(e.message, "abc (caused by ValueError)")

    def test_UserError(self):
        e = exc.UserError("Test 123")
        self.assertEqual(e.message, "Test 123")

    def test_UserError_ValueError(self):
        e = exc.UserError("Test 123", caused_by=ValueError("abc"))
        self.assertEqual(e.message, "Test 123 (caused by ValueError)")

    def test_AlgorithmError(self):
        e = exc.AlgorithmError("Test 123")
        self.assertEqual(e.message, "Test 123")

    def test_AlgorithmError_ValueError(self):
        e = exc.AlgorithmError("Test 123", caused_by=ValueError("abc"))
        self.assertEqual(e.message, "Test 123 (caused by ValueError)")

    def test_PlatformError(self):
        e = exc.PlatformError("Test 123")
        self.assertEqual(e.message, "Test 123")

    def test_PlatformError_ValueError(self):
        e = exc.PlatformError("Test 123", caused_by=ValueError("abc"))
        self.assertEqual(e.message, "Test 123 (caused by ValueError)")
class Metric(object):
    """A single training/validation metric definition.

    :param name: metric name as emitted to SageMaker (e.g. "validation:rmse").
    :param regex: regex with one capture group extracting the value from logs.
    :param format_string: optional str.format template used by log(); required
        only if log() is called.
    :param tunable: whether the metric can be used as a tuning objective.
    :param direction: Metric.MAXIMIZE or Metric.MINIMIZE; required when tunable.
    :raises exc.AlgorithmError: if tunable is True and direction is None.
    """

    MAXIMIZE = "Maximize"
    MINIMIZE = "Minimize"

    def __init__(self, name, regex, format_string=None, tunable=True, direction=None):
        self.name = name
        self.format_string = format_string
        self.direction = direction
        self.regex = regex
        self.tunable = tunable
        if self.tunable and direction is None:
            raise exc.AlgorithmError("direction must be specified if tunable is True.")

    def log(self, value):
        """Log *value* through this metric's format string.

        :raises exc.AlgorithmError: if no format_string was configured.
            (Previously this failed with an opaque AttributeError on None.)
        """
        if self.format_string is None:
            raise exc.AlgorithmError("format_string must be specified to log metric '{}'.".format(self.name))
        logging.info(self.format_string.format(value))

    def format_tunable(self):
        """Return the SageMaker tuning-objective dict for this metric."""
        return {"MetricName": self.name, "Type": self.direction}

    def format_definition(self):
        """Return the SageMaker metric-definition dict (name + extraction regex)."""
        return {"Name": self.name, "Regex": self.regex}


class Metrics(object):
    """A name-indexed collection of Metric objects."""

    def __init__(self, *metrics):
        # Preserves insertion order (dicts are ordered in Python 3.7+).
        self.metrics = {metric.name: metric for metric in metrics}

    def __getitem__(self, name):
        return self.metrics[name]

    @property
    def names(self):
        """Metric names in insertion order."""
        return list(self.metrics)

    def format_tunable(self):
        """Tuning-objective dicts for the tunable metrics only."""
        return [metric.format_tunable() for metric in self.metrics.values() if metric.tunable]

    def format_definitions(self):
        """Metric-definition dicts for all metrics."""
        return [metric.format_definition() for metric in self.metrics.values()]
def initialize():
    """Build the algorithm-mode data channel specification.

    train and validation channels accept the same content types in both File
    and Pipe mode for sharded and replicated S3 distributions; a "code"
    channel toggles script mode. Default content type is text/libsvm.
    """

    def _data_channel(name, required):
        # train/validation channels share the same content-type matrix.
        channel = cv.Channel(name=name, required=required)
        for content_type in VALID_CONTENT_TYPES:
            channel.add(content_type, cv.Channel.FILE_MODE, cv.Channel.SHARDED)
            channel.add(content_type, cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        for content_type in VALID_PIPED_CONTENT_TYPES:
            channel.add(content_type, cv.Channel.PIPE_MODE, cv.Channel.SHARDED)
            channel.add(content_type, cv.Channel.PIPE_MODE, cv.Channel.REPLICATED)
        return channel

    train_channel = _data_channel("train", required=True)
    validation_channel = _data_channel("validation", required=False)

    # new for script mode/algorithm mode toggle
    code_channel = cv.Channel(name="code", required=False)
    code_channel.add("text/python", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)

    data_channels = cv.Channels(train_channel, validation_channel, code_channel)
    data_channels.set_default_content_type("text/libsvm")

    return data_channels
def get_dataframe_dimensions(dataframe: DataFrame) -> (int, int):
    """Return (row_count, column_count) of a dask DataFrame.

    Note that computing the row count is an expensive operation: the count is
    lazy in dask and forces a pass over the partitions.
    """
    shape = dataframe.shape
    column_count = shape[1]
    row_count = shape[0].compute()
    return row_count, column_count


def create_dask_dmatrix(
    client: Client, features: DataFrame, labels: Series
) -> dxgb.DaskDMatrix:
    """Wrap features/labels in a DaskDMatrix, surfacing failures as AlgorithmError."""
    try:
        return dxgb.DaskDMatrix(client, features, labels)
    except Exception as e:
        raise AlgorithmError(
            f"Failed to create DaskDMatrix with given data. Exception: {e}"
        )
13 | from __future__ import absolute_import 14 | 15 | import json 16 | 17 | import pytest 18 | from mock import MagicMock, patch 19 | 20 | from sagemaker_xgboost_container.algorithm_mode import serve 21 | 22 | 23 | def test_default_execution_parameters(): 24 | execution_parameters_response = serve.execution_parameters() 25 | 26 | parsed_exec_params_response = json.loads(execution_parameters_response.response[0]) 27 | assert parsed_exec_params_response["MaxPayloadInMB"] == 6 28 | assert parsed_exec_params_response["BatchStrategy"] == "MULTI_RECORD" 29 | 30 | 31 | @patch("sagemaker_xgboost_container.algorithm_mode.serve.PARSED_MAX_CONTENT_LENGTH", 19 * 1024**2) 32 | def test_max_execution_parameters(): 33 | execution_parameters_response = serve.execution_parameters() 34 | 35 | parsed_exec_params_response = json.loads(execution_parameters_response.response[0]) 36 | assert parsed_exec_params_response["MaxPayloadInMB"] == 19 37 | assert parsed_exec_params_response["BatchStrategy"] == "MULTI_RECORD" 38 | 39 | 40 | def test_parse_accept(): 41 | mock_request = MagicMock() 42 | mock_request.headers.get.return_value = "application/json;verbose=True" 43 | assert serve._parse_accept(mock_request) == "application/json" 44 | 45 | 46 | def test_parse_accept_default(monkeypatch): 47 | mock_request = MagicMock() 48 | mock_request.headers = {} 49 | monkeypatch.setenv("SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT", "text/csv") 50 | assert serve._parse_accept(mock_request) == "text/csv" 51 | 52 | 53 | def test_parse_accept_incompatible(): 54 | mock_request = MagicMock() 55 | mock_request.headers.get.return_value = "text/libsvm" 56 | with pytest.raises(ValueError): 57 | serve._parse_accept(mock_request) 58 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/distributed_gpu/dask_cluster_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. 
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import socket
from subprocess import Popen

from dask.distributed import Client

from sagemaker_algorithm_toolkit.exceptions import AlgorithmError, PlatformError

# Absolute paths of the Dask executables baked into the container image.
SCHEDULER_EXEC_PATH = "/miniconda3/bin/dask-scheduler"
CUDA_WORKER_EXEC_PATH = "/miniconda3/bin/dask-cuda-worker"

SCHEDULER_CONN_TIMEOUT = "20s"


def start_daemons_in_current_instance(scheduler_address: str, is_scheduler_host: bool):
    """Launch the Dask daemons that belong on this host.

    The scheduler host spawns a ``dask-scheduler`` process; every host then
    spawns a ``dask-cuda-worker`` pointed at the scheduler. Worker startup is
    gated on a successful client connection so workers never start before the
    scheduler is reachable. The spawned processes are intentionally not
    waited on — they run for the lifetime of the training job.

    :param scheduler_address: host:port that the scheduler listens on
    :param is_scheduler_host: whether this instance should run the scheduler
    :raises AlgorithmError: if the scheduler cannot be reached within
        ``SCHEDULER_CONN_TIMEOUT``
    """
    # Dask distributed scheduler API doc: https://docs.dask.org/en/stable/deploying-cli.html
    scheduler_cli_command = [SCHEDULER_EXEC_PATH, "--no-dashboard"]
    scheduler_conn_string = f"tcp://{scheduler_address}"
    # Dask cuda worker API doc: https://docs.rapids.ai/api/dask-cuda/nightly/api.html
    worker_cli_command = [
        CUDA_WORKER_EXEC_PATH,
        scheduler_conn_string,
        "--no-dashboard",
    ]
    if is_scheduler_host:
        Popen(scheduler_cli_command)
    try:
        # Ensure that the scheduler is up before starting workers.
        with Client(scheduler_address, timeout=SCHEDULER_CONN_TIMEOUT):
            Popen(worker_cli_command)
    except TimeoutError as e:
        raise AlgorithmError(
            f"Couldn't connect to scheduler after {SCHEDULER_CONN_TIMEOUT}. Please try re-running the training job."
            f" Exception: {e}"
        )


def get_host_ip(host_name: str) -> str:
    """Resolve ``host_name`` to an IP address.

    :raises PlatformError: if DNS resolution fails — resolution of cluster
        host names is the platform's responsibility, not the user's.
    """
    try:
        host_ip = socket.gethostbyname(host_name)
    except socket.gaierror as e:
        # This shouldn't have happened, and it's not the user's fault.
        raise PlatformError(
            f"Failed hostname resolution for host '{host_name}', exception: {e}"
        )
    return host_ip
import os
from test.utils import local_mode, test_utils

import pytest

path = os.path.dirname(os.path.realpath(__file__))
data_root = os.path.join(path, "..", "..", "resources")


def get_abalone_default_hyperparameters(num_round=50):
    """Baseline abalone hyperparameters with a configurable round count."""
    return {
        "max_depth": "5",
        "eta": "0.2",
        "gamma": "4",
        "min_child_weight": "6",
        "subsample": "0.7",
        "num_round": str(num_round),
    }


@pytest.mark.parametrize(
    "dataset,extra_hps,model_file_count",
    [
        ("abalone", {"objective": "reg:squarederror", "_kfold": "5"}, 5),
        ("abalone-binary", {"objective": "binary:logistic", "_kfold": "5"}, 5),
        ("abalone-multiclass", {"objective": "multi:softprob", "num_class": "4", "_kfold": "5"}, 5),
        ("abalone", {"objective": "reg:squarederror", "_kfold": "5", "_num_cv_round": "2"}, 10),
        (
            "abalone-multiclass",
            {"objective": "multi:softprob", "num_class": "4", "_kfold": "5", "_num_cv_round": "3"},
            15,
        ),
    ],
)
def test_xgboost_abalone_kfold(dataset, extra_hps, model_file_count, docker_image, opt_ml):
    """Train with k-fold CV and verify one model artifact per fold/CV round."""
    all_hps = dict(get_abalone_default_hyperparameters(), **extra_hps)
    data_path = os.path.join(data_root, dataset, "data")

    local_mode.train(
        False,
        data_path,
        docker_image,
        opt_ml,
        hyperparameters=all_hps,
    )

    expected_models = [f"model/xgboost-model-{i}" for i in range(model_file_count)]
    assert not local_mode.file_exists(opt_ml, "output/failure"), "Failure happened"
    test_utils.files_exist(opt_ml, expected_models)
    local_mode.file_exists(opt_ml, "output/data/predictions.csv")
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import os
import unittest
from pathlib import Path

from sagemaker_algorithm_toolkit.exceptions import UserError
from sagemaker_xgboost_container.data_utils import CSV, LIBSVM, PARQUET
from sagemaker_xgboost_container.distributed_gpu.dask_data_utils import read_data


class TestDaskDataUtils(unittest.TestCase):
    """Tests for ``read_data`` over the CSV/parquet fixture directories."""

    # Every fixture file holds 5 rows of 6 columns (label + 5 features).
    NUM_ROWS_IN_EACH_FILE = 5
    NUM_COLS_IN_EACH_FILE = 6

    def setUp(self):
        # Fixture data lives under test/resources/data relative to this file.
        resources = str(Path(os.path.abspath(__file__)).parent.parent.parent)
        self.data_path_csv = os.path.join(resources, "resources", "data", "csv", "csv_files")
        self.data_path_csv_multiple = os.path.join(resources, "resources", "data", "csv", "multiple_files")
        self.data_path_parquet = os.path.join(resources, "resources", "data", "parquet", "multiple_files")

    def test_read_data_csv(self):
        features, labels = read_data(self.data_path_csv, CSV)
        self.assertEqual(features.shape[0].compute(), self.NUM_ROWS_IN_EACH_FILE)
        self.assertEqual(features.shape[1], self.NUM_COLS_IN_EACH_FILE - 1)
        self.assertEqual(len(labels), self.NUM_ROWS_IN_EACH_FILE)

    def test_read_data_csv_malformed_path(self):
        # A trailing slash must not change how the glob resolves.
        features, labels = read_data(self.data_path_csv + "/", CSV)
        self.assertEqual(features.shape[0].compute(), self.NUM_ROWS_IN_EACH_FILE)

    def test_read_data_csv_multiple_files(self):
        features, labels = read_data(self.data_path_csv_multiple, CSV)
        self.assertEqual(features.shape[0].compute(), self.NUM_ROWS_IN_EACH_FILE * 2)

    def test_read_data_parquet(self):
        features, labels = read_data(self.data_path_parquet, PARQUET)
        self.assertEqual(features.shape[0].compute(), self.NUM_ROWS_IN_EACH_FILE * 2)
        self.assertEqual(features.shape[1], self.NUM_COLS_IN_EACH_FILE - 1)
        self.assertEqual(len(labels), self.NUM_ROWS_IN_EACH_FILE * 2)

    def test_read_data_unsupported_content(self):
        # libsvm is not a supported dask input format.
        with self.assertRaises(UserError):
            read_data(self.data_path_parquet, LIBSVM)
13 | from __future__ import absolute_import 14 | 15 | import numpy as np 16 | import pytest 17 | import xgboost as xgb 18 | from mock import patch 19 | from sagemaker_containers.beta.framework import content_types, encoders, errors 20 | 21 | from sagemaker_xgboost_container.handler_service import HandlerService 22 | 23 | handler = HandlerService().DefaultXGBoostUserModuleInferenceHandler() 24 | 25 | 26 | @pytest.fixture(scope="module", name="np_array") 27 | def fixture_np_array(): 28 | return np.ones((2, 2)) 29 | 30 | 31 | class FakeEstimator: 32 | def __init__(self): 33 | pass 34 | 35 | @staticmethod 36 | def predict(input): 37 | return 38 | 39 | 40 | @pytest.mark.parametrize("csv_array", ("42,6,9", "42.0,6.0,9.0")) 41 | def test_input_fn_dmatrix(csv_array): 42 | deserialized_csv_array = handler.default_input_fn(csv_array, content_types.CSV) 43 | assert type(deserialized_csv_array) is xgb.DMatrix 44 | 45 | 46 | def test_input_fn_bad_content_type(): 47 | with pytest.raises(errors.UnsupportedFormatError): 48 | handler.default_input_fn("", "application/not_supported") 49 | 50 | 51 | def test_default_model_fn(): 52 | with pytest.raises(NotImplementedError): 53 | handler.default_model_fn("model_dir") 54 | 55 | 56 | def test_predict_fn(np_array): 57 | mock_estimator = FakeEstimator() 58 | with patch.object(mock_estimator, "predict") as mock: 59 | handler.default_predict_fn(np_array, mock_estimator) 60 | mock.assert_called_once() 61 | 62 | 63 | def test_output_fn_json(np_array): 64 | response = handler.default_output_fn(np_array, content_types.JSON) 65 | assert response == encoders.array_to_json(np_array.tolist()) 66 | 67 | 68 | def test_output_fn_csv(np_array): 69 | response = handler.default_output_fn(np_array, content_types.CSV) 70 | assert response == b"1.0,1.0\n1.0,1.0\n" 71 | 72 | 73 | def test_output_fn_npz(np_array): 74 | response = handler.default_output_fn(np_array, content_types.NPY) 75 | assert response == encoders.array_to_npy(np_array) 76 | 77 | 78 | def 
test_input_fn_bad_accept(): 79 | with pytest.raises(errors.UnsupportedFormatError): 80 | handler.default_output_fn("", "application/not_supported") 81 | -------------------------------------------------------------------------------- /test/unit/distributed_gpu/test_distributed_gpu_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | import unittest 15 | 16 | from sagemaker_algorithm_toolkit import channel_validation as cv 17 | from sagemaker_xgboost_container.distributed_gpu.distributed_gpu_training import ( 18 | INPUT_FORMAT_ERROR_MSG, 19 | NON_GPU_ERROR_MSG, 20 | NOT_REPLICATED_ERROR_MSG, 21 | PIPE_MODE_ERROR_MSG, 22 | validate_gpu_train_configuration, 23 | ) 24 | 25 | 26 | class TestDistributedGPUTraining(unittest.TestCase): 27 | def setUp(self): 28 | self.train_channel_replicated = {"train": {cv.S3_DIST_TYPE: "FullyReplicated"}} 29 | self.train_channel_not_replicated = {"train": {cv.S3_DIST_TYPE: "ShardedByS3Key"}} 30 | self.multi_channel_not_replicated = { 31 | "train": {cv.S3_DIST_TYPE: "FullyReplicated"}, 32 | "valid": {cv.S3_DIST_TYPE: "ShardedByS3Key"}, 33 | } 34 | 35 | def test_conditions_fail_channel_not_replicated_multi_host(self): 36 | exc_list = validate_gpu_train_configuration("gpu_hist", 2, 1, "File", "csv", self.multi_channel_not_replicated) 37 | assert NOT_REPLICATED_ERROR_MSG in exc_list 38 | 39 
| def test_conditions_pass_channel_replicated_multi_host(self): 40 | exc_list = validate_gpu_train_configuration("gpu_hist", 2, 1, "File", "csv", self.train_channel_replicated) 41 | assert not exc_list 42 | 43 | def test_conditions_pass_channel_not_replicated_singlehost(self): 44 | exc_list = validate_gpu_train_configuration("gpu_hist", 1, 1, "File", "csv", self.train_channel_not_replicated) 45 | assert not exc_list 46 | 47 | def test_conditions_fail_not_gpu_instance(self): 48 | exc_list = validate_gpu_train_configuration("gpu_hist", 1, 0, "File", "csv", self.train_channel_replicated) 49 | assert NON_GPU_ERROR_MSG in exc_list 50 | 51 | def test_conditions_fail_non_gpu_tree_method(self): 52 | exc_list = validate_gpu_train_configuration("approx", 1, 1, "File", "csv", self.train_channel_replicated) 53 | assert NON_GPU_ERROR_MSG in exc_list 54 | 55 | def test_conditions_fail_pipe_mode(self): 56 | exc_list = validate_gpu_train_configuration("gpu_hist", 1, 1, "Pipe", "csv", self.train_channel_replicated) 57 | assert PIPE_MODE_ERROR_MSG in exc_list 58 | 59 | def test_conditions_fail_unsupported_format(self): 60 | exc_list = validate_gpu_train_configuration("gpu_hist", 1, 1, "File", "libsvm", self.train_channel_replicated) 61 | assert INPUT_FORMAT_ERROR_MSG in exc_list 62 | 63 | def test_conditions_fail_multiple_checks(self): 64 | exc_list = validate_gpu_train_configuration("approx", 1, 1, "Pipe", "libsvm", self.train_channel_replicated) 65 | assert len(exc_list) == 3 66 | -------------------------------------------------------------------------------- /test/integration/local/test_early_stopping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. 
A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import os 16 | from test.utils import local_mode 17 | 18 | path = os.path.dirname(os.path.realpath(__file__)) 19 | early_stopping_path = os.path.join(path, "..", "..", "resources", "early_stopping") 20 | data_dir = os.path.join(early_stopping_path, "data") 21 | 22 | 23 | def get_default_hyperparameters(num_round=100): 24 | hyperparameters = { 25 | "max_depth": "5", 26 | "eta": "0.2", 27 | "gamma": "4", 28 | "min_child_weight": "6", 29 | "subsample": "0.7", 30 | "num_round": str(num_round), 31 | } 32 | return hyperparameters 33 | 34 | 35 | def test_xgboost_training_single_machine_with_early_stopping(docker_image, opt_ml): 36 | hyperparameters = get_default_hyperparameters(100000) 37 | hyperparameters["save_model_on_termination"] = "true" 38 | 39 | local_mode.train( 40 | False, data_dir, docker_image, opt_ml, hyperparameters=hyperparameters, early_stopping=True, train_time=10 41 | ) 42 | 43 | assert local_mode.file_exists(opt_ml, "model/xgboost-model"), "Model not saved" 44 | 45 | 46 | def test_xgboost_training_single_machine_without_early_stopping(docker_image, opt_ml): 47 | hyperparameters = get_default_hyperparameters(100000) 48 | hyperparameters["save_model_on_termination"] = "false" 49 | 50 | local_mode.train( 51 | False, data_dir, docker_image, opt_ml, hyperparameters=hyperparameters, early_stopping=True, train_time=10 52 | ) 53 | 54 | assert not local_mode.file_exists(opt_ml, "model/xgboost-model"), "Model saved" 55 | 56 | 57 | def test_xgboost_training_multiple_machines_with_early_stopping(docker_image, opt_ml): 58 | 
hyperparameters = get_default_hyperparameters(100000) 59 | hyperparameters["save_model_on_termination"] = "true" 60 | 61 | local_mode.train( 62 | False, data_dir, docker_image, opt_ml, hyperparameters=hyperparameters, cluster_size=2, early_stopping=True 63 | ) 64 | 65 | host1 = local_mode.file_exists(opt_ml, "model/xgboost-model", "algo-1") 66 | host2 = local_mode.file_exists(opt_ml, "model/xgboost-model", "algo-2") 67 | assert host1 or host2, "Model not saved on any host" 68 | assert not (host1 and host2), "Model saved on both hosts" 69 | 70 | 71 | def test_xgboost_training_multiple_machines_without_early_stopping(docker_image, opt_ml): 72 | hyperparameters = get_default_hyperparameters(100000) 73 | hyperparameters["save_model_on_termination"] = "false" 74 | 75 | local_mode.train( 76 | False, data_dir, docker_image, opt_ml, hyperparameters=hyperparameters, cluster_size=2, early_stopping=True 77 | ) 78 | 79 | host1 = local_mode.file_exists(opt_ml, "model/xgboost-model", "algo-1") 80 | host2 = local_mode.file_exists(opt_ml, "model/xgboost-model", "algo-2") 81 | assert not (host1 or host2), "Model saved on some host" 82 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/constants/xgb_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 

# Evaluation metrics where a larger value indicates a better model; used to
# set the optimization direction in tuning/early-stopping metadata.
XGB_MAXIMIZE_METRICS = [
    "accuracy",
    "auc",
    "aucpr",
    "balanced_accuracy",
    "f1",
    "f1_binary",
    "f1_macro",
    "map",
    "ndcg",
    "precision",
    "r2",
    "recall",
    "precision_macro",
    "precision_micro",
    "recall_macro",
    "recall_micro",
]

# Evaluation metrics where a smaller value indicates a better model.
XGB_MINIMIZE_METRICS = [
    "aft-nloglik",
    "cox-nloglik",
    "error",
    "gamma-deviance",
    "gamma-nloglik",
    "interval-regression-accuracy",
    "logloss",
    "mae",
    "mape",
    "merror",
    "mlogloss",
    "mphe",
    "mse",
    "poisson-nloglik",
    "rmse",
    "rmsle",
    "tweedie-nloglik",
]

# Substrings of XGBoost/framework error messages that indicate a problem with
# the user's data or hyperparameters rather than a platform fault.
LOGISTIC_REGRESSION_LABEL_RANGE_ERROR = "label must be in [0,1] for logistic regression"
MULTI_CLASS_LABEL_RANGE_ERROR = "label must be in [0, num_class)"
MULTI_CLASS_F1_BINARY_ERROR = "Target is multiclass but average='binary'"
FEATURE_MISMATCH_ERROR = "feature_names mismatch"
LABEL_PREDICTION_SIZE_MISMATCH = "Check failed: preds.size() == info.labels_.size()"
ONLY_POS_OR_NEG_SAMPLES = "Check failed: !auc_error AUC: the dataset only contains pos or neg samples"
BASE_SCORE_RANGE_ERROR = (
    "Check failed: base_score > 0.0f && base_score < 1.0f base_score must be in (0,1) " "for logistic loss"
)
POISSON_REGRESSION_ERROR = "Check failed: label_correct PoissonRegression: label must be nonnegative"
TWEEDIE_REGRESSION_ERROR = "Check failed: label_correct TweedieRegression: label must be nonnegative"
REG_LAMBDA_ERROR = "Parameter reg_lambda should be greater equal to 0"

# Errors matching any of these substrings are surfaced as customer errors.
CUSTOMER_ERRORS = [
    LOGISTIC_REGRESSION_LABEL_RANGE_ERROR,
    MULTI_CLASS_LABEL_RANGE_ERROR,
    MULTI_CLASS_F1_BINARY_ERROR,
    FEATURE_MISMATCH_ERROR,
    LABEL_PREDICTION_SIZE_MISMATCH,
    ONLY_POS_OR_NEG_SAMPLES,
    BASE_SCORE_RANGE_ERROR,
    POISSON_REGRESSION_ERROR,
    TWEEDIE_REGRESSION_ERROR,
    REG_LAMBDA_ERROR,
]

_SEPARATOR = ":"

# Canonical SageMaker data channel names.
TRAIN_CHANNEL = "train"
VAL_CHANNEL = "validation"

# xgboost objective learning tasks
# https://xgboost.readthedocs.io/en/release_1.0.0/parameter.html#learning-task-parameters
REG_SQUAREDERR = "reg:squarederror"
REG_LOG = "reg:logistic"
REG_GAMMA = "reg:gamma"
REG_ABSOLUTEERR = "reg:absoluteerror"
REG_TWEEDIE = "reg:tweedie"
BINARY_LOG = "binary:logistic"
BINARY_LOGRAW = "binary:logitraw"
BINARY_HINGE = "binary:hinge"
MULTI_SOFTMAX = "multi:softmax"
MULTI_SOFTPROB = "multi:softprob"

# File name of the saved booster artifact.
MODEL_NAME = "xgboost-model"
# The tree_method value that enables GPU training.
GPU_TREE_METHOD = "gpu_hist"

FULLY_REPLICATED = "FullyReplicated"
PIPE_MODE = "Pipe"
import os
import pprint
import unittest
from datetime import datetime

import boto3

from sagemaker_xgboost_container.algorithm_mode import channel_validation as cv
from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv
from sagemaker_xgboost_container.algorithm_mode import metadata
from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod


class TestCreateAlgorithm(unittest.TestCase):
    def test_create_algorithm(self):
        """Register the container as a SageMaker Algorithm and launch a small
        random-search HPO job against it.

        Integration test: requires live AWS credentials (region us-west-2)
        and the TEST_IMAGE_URI / TEST_ALGORITHM_NAME / TEST_ROLE_ARN /
        TEST_OUTPUT_PATH environment variables.
        """
        IMAGE_URI = os.getenv("TEST_IMAGE_URI")
        ALGORITHM_NAME = os.getenv("TEST_ALGORITHM_NAME")
        ROLE_ARN = os.getenv("TEST_ROLE_ARN")
        OUTPUT_PATH = os.getenv("TEST_OUTPUT_PATH")

        if IMAGE_URI is None:
            self.fail("Set TEST_IMAGE_URI environment variable.")
        if ALGORITHM_NAME is None:
            self.fail("Set TEST_ALGORITHM_NAME environment variable.")
        if ROLE_ARN is None:
            self.fail("Set TEST_ROLE_ARN environment variable.")
        if OUTPUT_PATH is None:
            self.fail("Set TEST_OUTPUT_PATH environment variable.")

        # Build the algorithm specification from the container's declared
        # metrics, hyperparameters, and data channels.
        metrics = metrics_mod.initialize()
        hyperparameters = hpv.initialize(metrics)
        channels = cv.initialize()
        md = metadata.initialize(IMAGE_URI, hyperparameters, channels, metrics)

        client = boto3.client("sagemaker", region_name="us-west-2")
        # Best effort: remove a leftover algorithm from a previous run.
        try:
            client.delete_algorithm(AlgorithmName=ALGORITHM_NAME)
        except Exception as e:
            print(e)

        pprint.pprint(md)
        client.create_algorithm(AlgorithmName=ALGORITHM_NAME, **md)

        objective = metrics["validation:error"]
        # Timestamp suffix keeps tuning-job names unique across runs.
        now = datetime.now()
        dt_string = now.strftime("%Y%m%d-%H%M%S")

        client.create_hyper_parameter_tuning_job(
            HyperParameterTuningJobName="test-hpo-" + dt_string,
            HyperParameterTuningJobConfig={
                "Strategy": "Random",
                "ResourceLimits": {"MaxNumberOfTrainingJobs": 6, "MaxParallelTrainingJobs": 2},
                "HyperParameterTuningJobObjective": objective.format_tunable(),
                "ParameterRanges": hyperparameters["alpha"].format_tunable_range(),
            },
            TrainingJobDefinition={
                "AlgorithmSpecification": {"AlgorithmName": ALGORITHM_NAME, "TrainingInputMode": "File"},
                "StaticHyperParameters": {"num_round": "3"},
                "RoleArn": ROLE_ARN,
                "OutputDataConfig": {"S3OutputPath": OUTPUT_PATH},
                "ResourceConfig": {"InstanceType": "ml.m5.xlarge", "InstanceCount": 1, "VolumeSizeInGB": 5},
                "StoppingCondition": {"MaxRuntimeInSeconds": 300},
            },
        )
13 | from __future__ import absolute_import 14 | 15 | import math 16 | import os 17 | import shutil 18 | import tempfile 19 | 20 | import numpy as np 21 | import xgboost as xgb 22 | 23 | from sagemaker_xgboost_container.algorithm_mode import train_utils 24 | 25 | 26 | def test_get_union_metrics(): 27 | a = ["metric_1", "metric_2"] 28 | b = ["metric_1", "metric_3"] 29 | 30 | union = train_utils.get_union_metrics(a, b) 31 | assert len(union) == 3 32 | for metric in union: 33 | assert metric in ["metric_1", "metric_2", "metric_3"] 34 | 35 | 36 | def test_get_eval_metrics_and_feval(): 37 | test_objective = "validation:logloss" 38 | test_evals = ["accuracy", "rmse"] 39 | 40 | test_eval_metrics, test_configured_eval, tuning_metric = train_utils.get_eval_metrics_and_feval( 41 | test_objective, test_evals 42 | ) 43 | 44 | assert len(test_eval_metrics) == 1 45 | for metric in test_eval_metrics: 46 | assert metric in ["logloss"] 47 | 48 | binary_train_data = np.random.rand(10, 2) 49 | binary_train_label = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) 50 | binary_dtrain = xgb.DMatrix(binary_train_data, label=binary_train_label) 51 | binary_preds = np.ones(10) 52 | 53 | custom_metric_results = test_configured_eval(binary_preds, binary_dtrain) 54 | custom_metric_results.sort() 55 | 56 | assert 2 == len(custom_metric_results) 57 | assert ("accuracy", 0.5) == custom_metric_results[0] 58 | assert ("rmse", math.sqrt(0.5)) == custom_metric_results[1] 59 | 60 | 61 | def test_cleanup_dir(): 62 | def setup(file_names): 63 | test_dir = tempfile.mkdtemp() 64 | for file_name in file_names: 65 | test_path = os.path.join(test_dir, file_name) 66 | with open(test_path, "w"): 67 | pass 68 | 69 | return test_dir 70 | 71 | def tearDown(dir): 72 | shutil.rmtree(dir) 73 | 74 | # Test 1: Check if 'xgboost-model' is present after cleanup 75 | model_name = "xgboost-model" 76 | file_names = ["tmp1", "tmp2", "xgboost-model"] 77 | test_dir = setup(file_names) 78 | 79 | train_utils.cleanup_dir(test_dir, 
model_name) 80 | files = os.listdir(test_dir) 81 | 82 | assert len(files) == 1 83 | assert files[0] == model_name 84 | 85 | tearDown(test_dir) 86 | 87 | # Test 2: Check if directory is empty after cleanup 88 | file_names = ["tmp1", "tmp2"] 89 | test_dir = setup(file_names) 90 | 91 | train_utils.cleanup_dir(test_dir, model_name) 92 | files = os.listdir(test_dir) 93 | 94 | assert len(files) == 0 95 | 96 | tearDown(test_dir) 97 | 98 | # Test 3: Check if directory is empty after cleanup 99 | file_names = [] 100 | test_dir = setup(file_names) 101 | 102 | train_utils.cleanup_dir(test_dir, model_name) 103 | files = os.listdir(test_dir) 104 | 105 | assert len(files) == 0 106 | 107 | tearDown(test_dir) 108 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/mms_patch/mms_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import http 16 | import logging 17 | 18 | from sagemaker_inference import content_types, utils 19 | from sagemaker_inference.errors import BaseInferenceToolkitError 20 | from sagemaker_inference.transformer import Transformer 21 | 22 | 23 | class XGBMMSTransformer(Transformer): 24 | def transform(self, data, context): 25 | """Take a request with input data, deserialize it, make a prediction, and return a 26 | serialized response. 27 | 28 | NOTE: This is almost a copy of the original Transformer method, except it does not decode the utf-8 data. 29 | This is done for backwards compatibility. 30 | 31 | See line removed here: 32 | https://github.com/aws/sagemaker-inference-toolkit/blob/master/src/sagemaker_inference/transformer.py#L123 33 | 34 | Args: 35 | data (obj): the request data. 36 | context (obj): metadata on the incoming request data. 37 | Returns: 38 | list[obj]: the serialized prediction result wrapped in a list. 39 | """ 40 | if not self._initialized: 41 | try: 42 | sys_properties = context._system_properties 43 | model_dir = sys_properties.get("model_dir") 44 | self.validate_and_initialize(model_dir) 45 | except Exception as e: 46 | if isinstance(e, BaseInferenceToolkitError): 47 | logging.error("Error loading model: {}".format(e)) 48 | return self.handle_error(context, e.status_code, e.message) 49 | else: 50 | raise e 51 | self._initialized = True 52 | 53 | try: 54 | input_data = data[0].get("body") 55 | 56 | request_processor = context.request_processor[0] 57 | 58 | request_property = request_processor.get_request_properties() 59 | content_type = utils.retrieve_content_type_header(request_property) 60 | accept = request_property.get("Accept") or request_property.get("accept") 61 | 62 | if not accept or accept == content_types.ANY: 63 | accept = self._environment.default_accept 64 | 65 | result = self._transform_fn(self._model, input_data, content_type, accept) 66 | 67 | response = result 68 | 
response_content_type = accept 69 | 70 | if isinstance(result, tuple): 71 | # handles tuple for backwards compatibility 72 | response = result[0] 73 | response_content_type = result[1] 74 | 75 | context.set_response_content_type(0, response_content_type) 76 | return [response] 77 | except Exception as e: 78 | if isinstance(e, BaseInferenceToolkitError): 79 | logging.error(e) 80 | return self.handle_error(context, e.status_code, e.message) 81 | else: 82 | return self.handle_error(context, http.HTTPStatus.BAD_REQUEST, e.message) 83 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check [existing open](https://github.com/aws-samples/sagemaker-xgboost-container/issues), or [recently closed](https://github.com/aws-samples/sagemaker-xgboost-container/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *master* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-samples/sagemaker-xgboost-container/labels/help%20wanted) issues is a great place to start.
45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](https://github.com/aws-samples/sagemaker-xgboost-container/blob/master/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /docker/3.0-5/resources/mms/ExecutionParameters.java: -------------------------------------------------------------------------------- 1 | package software.amazon.ai.mms.plugins.endpoint; 2 | 3 | import com.google.gson.GsonBuilder; 4 | import com.google.gson.annotations.SerializedName; 5 | import java.io.IOException; 6 | import java.nio.charset.StandardCharsets; 7 | import java.util.Properties; 8 | import software.amazon.ai.mms.servingsdk.Context; 9 | import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; 10 | import software.amazon.ai.mms.servingsdk.annotations.Endpoint; 11 | import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; 12 | import software.amazon.ai.mms.servingsdk.http.Request; 13 | import software.amazon.ai.mms.servingsdk.http.Response; 14 | 15 | /** 16 | The modified endpoint source code for the jar used in this container. 
17 | You can create this endpoint by moving it by cloning the MMS repo: 18 | > git clone https://github.com/awslabs/mxnet-model-server.git 19 | 20 | Copy this file into plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoints/ 21 | and then from the plugins directory, run: 22 | 23 | > ./gradlew fJ 24 | 25 | Modify file in plugins/endpoint/resources/META-INF/services/* to specify this file location 26 | 27 | Then build the JAR: 28 | 29 | > ./gradlew build 30 | 31 | The jar should be available in plugins/endpoints/build/libs as endpoints-1.0.jar 32 | **/ 33 | @Endpoint( 34 | urlPattern = "execution-parameters", 35 | endpointType = EndpointTypes.INFERENCE, 36 | description = "Execution parameters endpoint") 37 | public class ExecutionParameters extends ModelServerEndpoint { 38 | 39 | @Override 40 | public void doGet(Request req, Response rsp, Context ctx) throws IOException { 41 | Properties prop = ctx.getConfig(); 42 | // 6 * 1024 * 1024 43 | int maxRequestSize = Integer.parseInt(prop.getProperty("max_request_size", "6291456")); 44 | SagemakerXgboostResponse response = new SagemakerXgboostResponse(); 45 | response.setMaxConcurrentTransforms(Integer.parseInt(prop.getProperty("NUM_WORKERS", "1"))); 46 | response.setBatchStrategy("MULTI_RECORD"); 47 | response.setMaxPayloadInMB(maxRequestSize / (1024 * 1024)); 48 | rsp.getOutputStream() 49 | .write( 50 | new GsonBuilder() 51 | .setPrettyPrinting() 52 | .create() 53 | .toJson(response) 54 | .getBytes(StandardCharsets.UTF_8)); 55 | } 56 | 57 | /** Response for Model server endpoint */ 58 | public static class SagemakerXgboostResponse { 59 | @SerializedName("MaxConcurrentTransforms") 60 | private int maxConcurrentTransforms; 61 | 62 | @SerializedName("BatchStrategy") 63 | private String batchStrategy; 64 | 65 | @SerializedName("MaxPayloadInMB") 66 | private int maxPayloadInMB; 67 | 68 | public SagemakerXgboostResponse() { 69 | maxConcurrentTransforms = 4; 70 | batchStrategy = "MULTI_RECORD"; 71 | 
maxPayloadInMB = 6; 72 | } 73 | 74 | public int getMaxConcurrentTransforms() { 75 | return maxConcurrentTransforms; 76 | } 77 | 78 | public String getBatchStrategy() { 79 | return batchStrategy; 80 | } 81 | 82 | public int getMaxPayloadInMB() { 83 | return maxPayloadInMB; 84 | } 85 | 86 | public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) { 87 | maxConcurrentTransforms = newMaxConcurrentTransforms; 88 | } 89 | 90 | public void setBatchStrategy(String newBatchStrategy) { 91 | batchStrategy = newBatchStrategy; 92 | } 93 | 94 | public void setMaxPayloadInMB(int newMaxPayloadInMB) { 95 | maxPayloadInMB = newMaxPayloadInMB; 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /docker/1.7-1-1/resources/mms/ExecutionParameters.java: -------------------------------------------------------------------------------- 1 | package software.amazon.ai.mms.plugins.endpoint; 2 | 3 | import com.google.gson.GsonBuilder; 4 | import com.google.gson.annotations.SerializedName; 5 | import java.io.IOException; 6 | import java.nio.charset.StandardCharsets; 7 | import java.util.Properties; 8 | import software.amazon.ai.mms.servingsdk.Context; 9 | import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; 10 | import software.amazon.ai.mms.servingsdk.annotations.Endpoint; 11 | import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; 12 | import software.amazon.ai.mms.servingsdk.http.Request; 13 | import software.amazon.ai.mms.servingsdk.http.Response; 14 | 15 | /** 16 | The modified endpoint source code for the jar used in this container. 
17 | You can create this endpoint by moving it by cloning the MMS repo: 18 | > git clone https://github.com/awslabs/mxnet-model-server.git 19 | 20 | Copy this file into plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoints/ 21 | and then from the plugins directory, run: 22 | 23 | > ./gradlew fJ 24 | 25 | Modify file in plugins/endpoint/resources/META-INF/services/* to specify this file location 26 | 27 | Then build the JAR: 28 | 29 | > ./gradlew build 30 | 31 | The jar should be available in plugins/endpoints/build/libs as endpoints-1.0.jar 32 | **/ 33 | @Endpoint( 34 | urlPattern = "execution-parameters", 35 | endpointType = EndpointTypes.INFERENCE, 36 | description = "Execution parameters endpoint") 37 | public class ExecutionParameters extends ModelServerEndpoint { 38 | 39 | @Override 40 | public void doGet(Request req, Response rsp, Context ctx) throws IOException { 41 | Properties prop = ctx.getConfig(); 42 | // 6 * 1024 * 1024 43 | int maxRequestSize = Integer.parseInt(prop.getProperty("max_request_size", "6291456")); 44 | SagemakerXgboostResponse response = new SagemakerXgboostResponse(); 45 | response.setMaxConcurrentTransforms(Integer.parseInt(prop.getProperty("NUM_WORKERS", "1"))); 46 | response.setBatchStrategy("MULTI_RECORD"); 47 | response.setMaxPayloadInMB(maxRequestSize / (1024 * 1024)); 48 | rsp.getOutputStream() 49 | .write( 50 | new GsonBuilder() 51 | .setPrettyPrinting() 52 | .create() 53 | .toJson(response) 54 | .getBytes(StandardCharsets.UTF_8)); 55 | } 56 | 57 | /** Response for Model server endpoint */ 58 | public static class SagemakerXgboostResponse { 59 | @SerializedName("MaxConcurrentTransforms") 60 | private int maxConcurrentTransforms; 61 | 62 | @SerializedName("BatchStrategy") 63 | private String batchStrategy; 64 | 65 | @SerializedName("MaxPayloadInMB") 66 | private int maxPayloadInMB; 67 | 68 | public SagemakerXgboostResponse() { 69 | maxConcurrentTransforms = 4; 70 | batchStrategy = "MULTI_RECORD"; 71 | 
maxPayloadInMB = 6; 72 | } 73 | 74 | public int getMaxConcurrentTransforms() { 75 | return maxConcurrentTransforms; 76 | } 77 | 78 | public String getBatchStrategy() { 79 | return batchStrategy; 80 | } 81 | 82 | public int getMaxPayloadInMB() { 83 | return maxPayloadInMB; 84 | } 85 | 86 | public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) { 87 | maxConcurrentTransforms = newMaxConcurrentTransforms; 88 | } 89 | 90 | public void setBatchStrategy(String newBatchStrategy) { 91 | batchStrategy = newBatchStrategy; 92 | } 93 | 94 | public void setMaxPayloadInMB(int newMaxPayloadInMB) { 95 | maxPayloadInMB = newMaxPayloadInMB; 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /test/unit/algorithm_mode/test_hyperparameter_validation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import unittest 16 | 17 | from sagemaker_algorithm_toolkit import exceptions as exc 18 | from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv 19 | from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod 20 | 21 | metrics = metrics_mod.initialize() 22 | hyperparameters = hpv.initialize(metrics) 23 | 24 | 25 | class TestHyperparameterValidation(unittest.TestCase): 26 | def test_auc_invalid_objective(self): 27 | test_hp = {"eval_metric": "auc"} 28 | 29 | auc_invalid_objectives = [ 30 | "count:poisson", 31 | "reg:gamma", 32 | "reg:logistic", 33 | "reg:squarederror", 34 | "reg:tweedie", 35 | "multi:softmax", 36 | "multi:softprob", 37 | "survival:cox", 38 | ] 39 | 40 | for invalid_objective in auc_invalid_objectives: 41 | test_hp["objective"] = invalid_objective 42 | 43 | with self.assertRaises(exc.UserError): 44 | hyperparameters.validate(test_hp) 45 | 46 | def test_verbosity(self): 47 | test_hp = {"num_round": "1", "verbosity": "0"} 48 | 49 | assert hyperparameters.validate(test_hp) 50 | 51 | test_hp2 = {"num_round": "1", "verbosity": "3"} 52 | 53 | assert hyperparameters.validate(test_hp2) 54 | 55 | test_hp3 = {"num_round": "1", "verbosity": "4"} 56 | 57 | with self.assertRaises(exc.UserError): 58 | hyperparameters.validate(test_hp3) 59 | 60 | def test_num_parallel_tree(self): 61 | test_hp = {"num_round": "5", "num_parallel_tree": "10"} 62 | 63 | assert hyperparameters.validate(test_hp) 64 | 65 | test_hp2 = {"num_round": "5", "num_parallel_tree": "-1"} 66 | 67 | with self.assertRaises(exc.UserError): 68 | hyperparameters.validate(test_hp2) 69 | 70 | test_hp3 = {"num_round": "5", "num_parallel_tree": "0"} 71 | 72 | with self.assertRaises(exc.UserError): 73 | hyperparameters.validate(test_hp3) 74 | 75 | def test_save_model_on_termination(self): 76 | test_hp1 = {"num_round": "5", "save_model_on_termination": "true"} 77 | 78 | assert 
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import warnings


class BaseToolkitError(Exception):
    """Abstract base for all errors that may cause an algorithm to exit/terminate
    unsuccessfully. All direct sub-classes should be kept/maintained in this file.

    The hierarchy distinguishes three failure categories:

    1. AlgorithmError -- an unexpected or unknown failure that cannot be
       avoided by the user and is due to a bug in the algorithm.
    2. UserError -- a failure which can be prevented/avoided by the user
       (e.g. change mini_batch_size).
    3. PlatformError -- a failure due to an environmental requirement not
       being met (e.g. if the /opt/ml/training directory is missing).

    Args: see `Attributes` below.

    Attributes:
        message (str): Description of why this exception was raised.
        caused_by (Exception): The underlying exception that caused this
            exception to be raised. This should be a non-BaseToolkitError.
    """

    def __init__(self, message=None, caused_by=None):
        text = BaseToolkitError._format_exception_message(message, caused_by)
        super(BaseToolkitError, self).__init__(text)
        self.message = text
        self.caused_by = caused_by

    @staticmethod
    def _format_exception_message(message, caused_by):
        """Build the final exception text.

        An explicitly supplied ``message`` wins; otherwise fall back to the
        underlying exception's message, or a generic placeholder when neither
        is available. Whenever an underlying exception exists, its type name
        is appended for context.
        """
        if message:
            text = message
        elif caused_by:
            with warnings.catch_warnings():
                # Reading the legacy ``.message`` attribute can emit a
                # deprecation warning on some exception types; silence it.
                warnings.simplefilter("ignore")
                text = getattr(caused_by, "message", str(caused_by))
        else:
            text = "unknown error occurred"

        if caused_by:
            text += " (caused by {})".format(caused_by.__class__.__name__)

        return text


class AlgorithmError(BaseToolkitError):
    """Exception used to indicate a problem that occurred with the algorithm."""

    def __init__(self, message=None, caused_by=None):
        super().__init__(message, caused_by)


class UserError(BaseToolkitError):
    """Exception used to indicate a problem caused by mis-configuration or other user input."""

    def __init__(self, message=None, caused_by=None):
        super().__init__(message, caused_by)


class PlatformError(BaseToolkitError):
    """Exception used to indicate a problem caused by the underlying platform (e.g. network time-outs)."""

    def __init__(self, message=None, caused_by=None):
        super().__init__(message, caused_by)
python3 -m pip uninstall -y typing && \ 34 | rm /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl 35 | 36 | ############## 37 | # DMLC PATCH # 38 | ############## 39 | # TODO: remove after making contributions back to xgboost for tracker.py 40 | COPY src/sagemaker_xgboost_container/dmlc_patch/tracker.py \ 41 | /miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker/dmlc_tracker/tracker.py 42 | 43 | # Include DMLC python code in PYTHONPATH to use RabitTracker 44 | ENV PYTHONPATH=$PYTHONPATH:/miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker 45 | 46 | ####### 47 | # MMS # 48 | ####### 49 | # Create MMS user directory 50 | RUN useradd -m model-server 51 | RUN mkdir -p /home/model-server/tmp && chown -R model-server /home/model-server 52 | 53 | # Copy MMS configs 54 | COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/config.properties.tmp /home/model-server 55 | ENV XGBOOST_MMS_CONFIG=/home/model-server/config.properties 56 | 57 | # Copy execution parameters endpoint plugin for MMS 58 | RUN mkdir -p /tmp/plugins 59 | COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/endpoints-1.0.jar /tmp/plugins 60 | RUN chmod +x /tmp/plugins/endpoints-1.0.jar 61 | 62 | # Create directory for models 63 | RUN mkdir -p /opt/ml/models 64 | RUN chmod +rwx /opt/ml/models 65 | 66 | # Copy Dask configs 67 | RUN mkdir /etc/dask 68 | COPY docker/configs/dask_configs.yaml /etc/dask/ 69 | 70 | # Required label for multi-model loading 71 | LABEL com.amazonaws.sagemaker.capabilities.multi-models=true 72 | 73 | ##################### 74 | # Required ENV vars # 75 | ##################### 76 | # Set SageMaker training environment variables 77 | ENV SM_INPUT /opt/ml/input 78 | ENV SM_INPUT_TRAINING_CONFIG_FILE $SM_INPUT/config/hyperparameters.json 79 | ENV SM_INPUT_DATA_CONFIG_FILE $SM_INPUT/config/inputdataconfig.json 80 | ENV SM_CHECKPOINT_CONFIG_FILE $SM_INPUT/config/checkpointconfig.json 81 | # See: 
https://github.com/dmlc/xgboost/issues/7982#issuecomment-1379390906 https://github.com/dmlc/xgboost/pull/8257 82 | ENV NCCL_SOCKET_IFNAME eth 83 | 84 | 85 | # Set SageMaker serving environment variables 86 | ENV SM_MODEL_DIR /opt/ml/model 87 | 88 | # Set SageMaker entrypoints 89 | ENV SAGEMAKER_TRAINING_MODULE sagemaker_xgboost_container.training:main 90 | ENV SAGEMAKER_SERVING_MODULE sagemaker_xgboost_container.serving:main 91 | 92 | EXPOSE 8080 93 | ENV TEMP=/home/model-server/tmp 94 | LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true 95 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import json 16 | import logging 17 | import os 18 | import sys 19 | 20 | import sagemaker_containers.beta.framework as framework 21 | from sagemaker_containers import _env 22 | 23 | from sagemaker_xgboost_container.algorithm_mode.train import sagemaker_train 24 | from sagemaker_xgboost_container.constants import sm_env_constants 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def run_algorithm_mode(): 30 | """Run training in algorithm mode, which does not require a user entry point. 
31 | 32 | This parses the following environ elements for training: 33 | 34 | 'SM_INPUT_TRAINING_CONFIG_FILE' 35 | 'SM_INPUT_DATA_CONFIG_FILE' 36 | 'SM_CHANNEL_TRAIN' 37 | 'SM_CHANNEL_VALIDATION' 38 | 'SM_HOSTS' 39 | 'SM_CURRENT_HOST' 40 | 'SM_MODEL_DIR' 41 | 'SM_CHECKPOINT_CONFIG_FILE' 42 | 43 | """ 44 | # TODO: replace with CSDK constants in sagemaker_containers._env 45 | with open(os.getenv(sm_env_constants.SM_INPUT_TRAINING_CONFIG_FILE), "r") as f: 46 | train_config = json.load(f) 47 | with open(os.getenv(sm_env_constants.SM_INPUT_DATA_CONFIG_FILE), "r") as f: 48 | data_config = json.load(f) 49 | 50 | checkpoint_config_file = os.getenv(sm_env_constants.SM_CHECKPOINT_CONFIG_FILE) 51 | if os.path.exists(checkpoint_config_file): 52 | with open(checkpoint_config_file, "r") as f: 53 | checkpoint_config = json.load(f) 54 | else: 55 | checkpoint_config = {} 56 | 57 | train_path = os.environ[sm_env_constants.SM_CHANNEL_TRAIN] 58 | val_path = os.environ.get(sm_env_constants.SM_CHANNEL_VALIDATION) 59 | sm_hosts = json.loads(os.environ[sm_env_constants.SM_HOSTS]) 60 | sm_current_host = os.environ[sm_env_constants.SM_CURRENT_HOST] 61 | 62 | model_dir = os.getenv(sm_env_constants.SM_MODEL_DIR) 63 | 64 | sagemaker_train( 65 | train_config=train_config, 66 | data_config=data_config, 67 | train_path=train_path, 68 | val_path=val_path, 69 | model_dir=model_dir, 70 | sm_hosts=sm_hosts, 71 | sm_current_host=sm_current_host, 72 | checkpoint_config=checkpoint_config, 73 | ) 74 | 75 | 76 | def train(training_environment): 77 | """Run XGBoost training in either algorithm mode or using a user supplied module in local SageMaker environment. 78 | The user supplied module and its dependencies are downloaded from S3. 79 | Training is invoked by calling a "train" function in the user supplied module. 
80 | 81 | Args: 82 | training_environment: training environment object containing environment variables, 83 | training arguments and hyperparameters 84 | """ 85 | if training_environment.user_entry_point is not None: 86 | logger.info("Invoking user training script.") 87 | framework.modules.run_module( 88 | training_environment.module_dir, 89 | training_environment.to_cmd_args(), 90 | training_environment.to_env_vars(), 91 | training_environment.module_name, 92 | capture_error=False, 93 | ) 94 | else: 95 | logger.info("Running XGBoost Sagemaker in algorithm mode") 96 | _env.write_env_vars(training_environment.to_env_vars()) 97 | 98 | run_algorithm_mode() 99 | 100 | 101 | def main(): 102 | train(framework.training_env()) 103 | sys.exit(0) 104 | -------------------------------------------------------------------------------- /test/unit/test_serving_mms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import os 16 | 17 | import pytest 18 | from mock import patch 19 | 20 | from sagemaker_xgboost_container import handler_service as user_module_handler_service 21 | from sagemaker_xgboost_container import serving, serving_mms 22 | from sagemaker_xgboost_container.algorithm_mode import ( 23 | handler_service as algo_handler_service, 24 | ) 25 | 26 | TEST_CONFIG_FILE = "test_dir" 27 | ALGO_HANDLER_SERVICE = algo_handler_service.__name__ 28 | USER_HANDLER_SERVICE = user_module_handler_service.__name__ 29 | TEST_MAX_CONTENT_LEN = 1024 30 | TEST_NUM_CPU = 3 31 | 32 | 33 | @pytest.fixture(autouse=True) 34 | def mock_set_mms_config_file(monkeypatch): 35 | monkeypatch.setenv("XGBOOST_MMS_CONFIG", TEST_CONFIG_FILE) 36 | 37 | 38 | @pytest.fixture(autouse=True) 39 | def mock_set_multi_model_env(monkeypatch): 40 | monkeypatch.setenv("SAGEMAKER_MULTI_MODEL", "true") 41 | 42 | 43 | @patch.dict(os.environ, {"SAGEMAKER_MULTI_MODEL": "True", "XGBOOST_MMS_CONFIG": TEST_CONFIG_FILE}) 44 | @patch("sagemaker_xgboost_container.serving_mms.model_server.start_model_server") 45 | def test_multi_model_algorithm_mode_hosting(start_model_server, mock_set_mms_config_file, mock_set_multi_model_env): 46 | serving.serving_entrypoint() 47 | start_model_server.assert_called_with( 48 | is_multi_model=True, 49 | handler_service="sagemaker_xgboost_container.algorithm_mode.handler_service", 50 | config_file=TEST_CONFIG_FILE, 51 | ) 52 | 53 | 54 | @patch.dict(os.environ, {"SAGEMAKER_MULTI_MODEL": "True", "XGBOOST_MMS_CONFIG": TEST_CONFIG_FILE}) 55 | @patch("sagemaker_xgboost_container.serving_mms.model_server.start_model_server") 56 | @patch("sagemaker_xgboost_container.serving.env.ServingEnv.module_dir") 57 | @patch("sagemaker_xgboost_container.serving.env.ServingEnv.module_name") 58 | @patch("sagemaker_containers.beta.framework.modules.import_module") 59 | def test_multi_model_user_mode_hosting_error( 60 | import_module, user_module_name, 
module_dir, start_model_server, mock_set_mms_config_file, mock_set_multi_model_env 61 | ): 62 | serving.serving_entrypoint() 63 | start_model_server.assert_called_with( 64 | is_multi_model=True, handler_service="sagemaker_xgboost_container.handler_service", config_file=TEST_CONFIG_FILE 65 | ) 66 | 67 | 68 | @patch("sagemaker_xgboost_container.serving_mms.model_server.start_model_server") 69 | @patch("multiprocessing.cpu_count", return_value=TEST_NUM_CPU) 70 | def test_env_var_setting_single_and_multi_model(start_model_server, mock_get_num_cpu): 71 | test_handler_str = "foo" 72 | 73 | with patch.dict("os.environ", {}): 74 | serving_mms._set_mms_configs(True, test_handler_str) 75 | 76 | assert os.environ["SAGEMAKER_NUM_MODEL_WORKERS"] == "1" 77 | assert os.environ["SAGEMAKER_MMS_MODEL_STORE"] == "/" 78 | assert os.environ["SAGEMAKER_MMS_LOAD_MODELS"] == "" 79 | assert os.environ["SAGEMAKER_MAX_REQUEST_SIZE"] == str(serving_mms.DEFAULT_MAX_CONTENT_LEN) 80 | assert os.environ["SAGEMAKER_MMS_DEFAULT_HANDLER"] == test_handler_str 81 | 82 | 83 | @patch("sagemaker_xgboost_container.serving_mms.model_server.start_model_server") 84 | def test_set_max_content_len(start_model_server): 85 | test_handler_str = "foo" 86 | with patch.dict("os.environ", {}): 87 | serving_mms._set_mms_configs(False, test_handler_str) 88 | assert os.environ["SAGEMAKER_MAX_REQUEST_SIZE"] == str(serving_mms.DEFAULT_MAX_CONTENT_LEN) 89 | 90 | with patch.dict("os.environ", {"MAX_CONTENT_LENGTH": str(TEST_MAX_CONTENT_LEN)}): 91 | serving_mms._set_mms_configs(False, test_handler_str) 92 | assert os.environ["SAGEMAKER_MAX_REQUEST_SIZE"] == str(TEST_MAX_CONTENT_LEN) 93 | -------------------------------------------------------------------------------- /src/sagemaker_algorithm_toolkit/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import json 14 | 15 | import boto3 16 | 17 | 18 | def _get_instance_types(region_name="us-east-1", location="US East (N. Virginia)"): 19 | s = boto3.client("pricing", region_name=region_name) 20 | 21 | NAME = "AmazonSageMaker" 22 | FILTERS = [ 23 | {"Type": "TERM_MATCH", "Field": "productFamily", "Value": "ML Instance"}, 24 | {"Type": "TERM_MATCH", "Field": "location", "Value": location}, 25 | ] 26 | results = s.get_products(ServiceCode=NAME, Filters=FILTERS) 27 | 28 | total_results = [] 29 | while results.get("NextToken"): 30 | total_results += results["PriceList"] 31 | results = s.get_products(ServiceCode=NAME, Filters=FILTERS, NextToken=results["NextToken"]) 32 | 33 | instance_types = {} 34 | for result in total_results: 35 | result = json.loads(result) 36 | instance_type = result["product"]["attributes"]["instanceType"] 37 | gpu = result["product"]["attributes"]["gpu"] 38 | 39 | instance_types[instance_type] = int(gpu) 40 | return instance_types 41 | 42 | 43 | class Product: 44 | NOTEBOOK = "Notebook" 45 | TRAINING = "Training" 46 | HOSTING = "Hosting" 47 | BATCH_TRANSFORM = "BatchTransform" 48 | 49 | 50 | def _trim(instance_type_product): 51 | SEPARATOR = "-" # e.g. 
ml.p3.2xlarge-Hosting 52 | return instance_type_product.split(SEPARATOR)[0] 53 | 54 | 55 | def get_cpu_instance_types(product, **kwargs): 56 | results = [] 57 | for instance_type, gpu_count in _get_instance_types(**kwargs).items(): 58 | if gpu_count == 0 and product in instance_type: 59 | results.append(_trim(instance_type)) 60 | return results 61 | 62 | 63 | def get_single_gpu_instance_types(product, **kwargs): 64 | results = [] 65 | for instance_type, gpu_count in _get_instance_types(**kwargs).items(): 66 | if gpu_count == 1 and product in instance_type: 67 | results.append(_trim(instance_type)) 68 | return results 69 | 70 | 71 | def get_multi_gpu_instance_types(product, **kwargs): 72 | results = [] 73 | for instance_type, gpu_count in _get_instance_types(**kwargs).items(): 74 | if gpu_count > 1 and product in instance_type: 75 | results.append(_trim(instance_type)) 76 | return results 77 | 78 | 79 | def training_spec( 80 | hyperparameters, channels, metrics, image_uri, supported_training_instance_types, supports_distributed_training 81 | ): 82 | return { 83 | "TrainingImage": image_uri, 84 | "TrainingChannels": channels.format(), 85 | "SupportedHyperParameters": hyperparameters.format(), 86 | "SupportedTrainingInstanceTypes": supported_training_instance_types, 87 | "SupportsDistributedTraining": supports_distributed_training, 88 | "MetricDefinitions": metrics.format_definitions(), 89 | "SupportedTuningJobObjectiveMetrics": metrics.format_tunable(), 90 | } 91 | 92 | 93 | def inference_spec( 94 | image_uri, 95 | supported_realtime_inference_instance_types, 96 | supported_transform_inference_instance_types, 97 | supported_content_types, 98 | supported_response_mimetypes, 99 | ): 100 | return { 101 | "Containers": [{"Image": image_uri}], 102 | "SupportedTransformInstanceTypes": supported_transform_inference_instance_types, 103 | "SupportedRealtimeInferenceInstanceTypes": supported_realtime_inference_instance_types, 104 | "SupportedContentTypes": 
supported_content_types, 105 | "SupportedResponseMIMETypes": supported_response_mimetypes, 106 | } 107 | 108 | 109 | def generate_metadata(training_spec, inference_spec): 110 | return {"TrainingSpecification": training_spec, "InferenceSpecification": inference_spec} 111 | -------------------------------------------------------------------------------- /test/unit/algorithm_toolkit/test_channel_validation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
import unittest

from sagemaker_algorithm_toolkit import channel_validation as cv
from sagemaker_algorithm_toolkit import exceptions as exc


class TestChannelValidation(unittest.TestCase):
    # Tests for Channel/Channels runtime validation of user-specified channels.

    def test_simple_supported(self):
        # A configuration matching exactly one registered combination passes.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)
        channels.validate(
            {
                "train": {
                    cv.CONTENT_TYPE: "text/csv",
                    cv.TRAINING_INPUT_MODE: "File",
                    cv.S3_DIST_TYPE: "FullyReplicated",
                    "RecordWrapperType": "None",
                }
            }
        )

    def test_default_content_type(self):
        # Missing ContentType is filled in from the configured default,
        # mutating the caller's dict in place.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)
        channels.set_default_content_type("text/csv")
        test_user_channels = {
            "train": {cv.TRAINING_INPUT_MODE: "File", cv.S3_DIST_TYPE: "FullyReplicated", "RecordWrapperType": "None"}
        }
        channels.validate(test_user_channels)
        self.assertEqual(test_user_channels["train"][cv.CONTENT_TYPE], "text/csv")

    def test_simple_not_supported(self):
        # Pipe mode was never registered for this channel, so validation fails.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)
        with self.assertRaises(exc.UserError):
            channels.validate(
                {
                    "train": {
                        cv.CONTENT_TYPE: "text/csv",
                        cv.TRAINING_INPUT_MODE: "Pipe",
                        cv.S3_DIST_TYPE: "FullyReplicated",
                        "RecordWrapperType": "None",
                    }
                }
            )

    def test_simple_extra(self):
        # An unregistered channel name ("extra") is rejected.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)
        with self.assertRaises(exc.UserError):
            channels.validate(
                {
                    "train": {
                        cv.CONTENT_TYPE: "text/csv",
                        cv.TRAINING_INPUT_MODE: "File",
                        cv.S3_DIST_TYPE: "FullyReplicated",
                        "RecordWrapperType": "None",
                    },
                    "extra": {},
                }
            )

    def test_simple_required(self):
        # Omitting a required channel ("train") is rejected.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)
        with self.assertRaises(exc.UserError):
            channels.validate({"sorry": {}})

    def test_simple_format(self):
        # format() emits the CreateAlgorithm API shape for one channel / all channels.
        channel = cv.Channel(name="train", required=True)
        channel.add("text/csv", cv.Channel.FILE_MODE, cv.Channel.REPLICATED)
        channels = cv.Channels(channel)

        result = {
            "Name": "train",
            "Description": "train",
            "IsRequired": True,
            "SupportedContentTypes": ["text/csv"],
            "SupportedInputModes": ["File"],
        }
        self.assertEqual(channel.format(), result)
        self.assertEqual(channels.format(), [result])
--------------------------------------------------------------------------------
/test/resources/boston/single_machine_customer_script.py:
--------------------------------------------------------------------------------
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import, print_function

import argparse
import os

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Data and model checkpoints directories
    parser.add_argument("--objective", type=str, default="reg:squarederror")
    parser.add_argument("--colsample-bytree", type=float, default=0.3)
    parser.add_argument("--learning-rate", type=float, default=0.1)
    parser.add_argument("--max-depth", type=int, default=5)
    parser.add_argument("--reg-alpha", type=int, default=10)
    parser.add_argument("--n-estimators", type=int, default=10)
    # SageMaker injects these directories through environment variables.
    parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])

    args = parser.parse_args()

    # Load the California housing data into pandas data frame (replacement for deprecated Boston dataset)
    california = fetch_california_housing()
    data = pd.DataFrame(california.data)
    data.columns = california.feature_names
    data["PRICE"] = california.target

    # Convert Pandas dataframe to XGBoost DMatrix for better performance (used later).
    X, y = data.iloc[:, :-1], data.iloc[:, -1]
    data_dmatrix = xgb.DMatrix(data=X, label=y)

    # Create train/test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)

    # Create regressor object by using SKLearn API
    xg_reg = xgb.XGBRegressor(
        objective=args.objective,
        colsample_bytree=args.colsample_bytree,
        learning_rate=args.learning_rate,
        max_depth=args.max_depth,
        reg_alpha=args.reg_alpha,
        n_estimators=args.n_estimators,
    )

    # Train and save the model
    xg_reg.fit(X_train, y_train)
    model_path = os.path.join(args.model_dir, "xgb-boston.model")
    xg_reg.get_booster().save_model(model_path)

    # Make predictions and calculate RMSE
    preds = xg_reg.predict(X_test)
    rmse = np.sqrt(mean_squared_error(y_test, preds))
    print("RMSE: %f" % (rmse))

    # We can look at the feature importance and store the graph as an image.
    if not os.path.exists(args.output_data_dir):
        os.makedirs(args.output_data_dir)

    try:
        ax = xgb.plot_importance(xg_reg)
        fig = ax.figure
        fig.set_size_inches(5, 5)
        fig.savefig(os.path.join(args.output_data_dir, "feature-importance-plot.png"))
    except Exception as e:
        # Plotting requires matplotlib; training output should not fail without it.
        print(f"Warning: Could not create feature importance plot: {e}")

    # Finally, let's do a bit of cross-validation by using native XGB functionality
    # (keeping some parameters constant, so that we don't have a huge input list
    # for this simple example).
87 | params = { 88 | "objective": args.objective, 89 | "colsample_bytree": args.colsample_bytree, 90 | "learning_rate": args.learning_rate, 91 | "max_depth": args.max_depth, 92 | "alpha": args.reg_alpha, 93 | } 94 | cv_results = xgb.cv( 95 | dtrain=data_dmatrix, 96 | params=params, 97 | nfold=5, 98 | num_boost_round=50, 99 | early_stopping_rounds=10, 100 | metrics="rmse", 101 | as_pandas=True, 102 | seed=100, 103 | ) 104 | 105 | cv_results.to_csv(os.path.join(args.output_data_dir, "cv_results.csv")) 106 | -------------------------------------------------------------------------------- /docker/3.0-5/final/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | ARG SAGEMAKER_XGBOOST_VERSION=3.0-5 2 | ARG PYTHON_VERSION=3.10 3 | 4 | FROM xgboost-container-base:${SAGEMAKER_XGBOOST_VERSION}-cpu-py3 5 | 6 | ARG SAGEMAKER_XGBOOST_VERSION=3.0-5 7 | 8 | ######################## 9 | # Install dependencies # 10 | ######################## 11 | 12 | # Fix Python 3.10 compatibility for sagemaker-containers 13 | # RUN python3 -c "import sys; sys.path.insert(0, '/miniconda3/lib/python3.10/site-packages'); \ 14 | # import sagemaker_containers._mapping as m; \ 15 | # import collections.abc; \ 16 | # setattr(collections, 'Mapping', collections.abc.Mapping); \ 17 | # exec(open('/miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py').read().replace('collections.Mapping', 'collections.abc.Mapping'))" || \ 18 | # sed -i 's/collections\.Mapping/collections.abc.Mapping/g' /miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py 19 | 20 | 21 | # Install smdebug from source 22 | RUN python3 -m pip install git+https://github.com/awslabs/sagemaker-debugger.git@v1.0.32 23 | 24 | COPY requirements.txt /requirements.txt 25 | RUN python3 -m pip install -r /requirements.txt && rm /requirements.txt 26 | 27 | # Patches 28 | RUN python3 -m pip install --no-cache-dir "protobuf>=3.20.0,<=3.20.3" "fonttools>=4.60.2" 29 | 30 
| # ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 31 | 32 | RUN sed -i 's/collections\.Mapping/collections.abc.Mapping/g' /miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py 33 | 34 | ########################### 35 | # Copy wheel to container # 36 | ########################### 37 | COPY dist/sagemaker_xgboost_container-2.0-py2.py3-none-any.whl /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl 38 | RUN rm -rf /miniconda3/lib/python${PYTHON_VERSION}/site-packages/numpy-1.21.2.dist-info && \ 39 | python3 -m pip install --force-reinstall PyYAML==6.0.1 && \ 40 | python3 -m pip install --no-cache --no-deps /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl && \ 41 | python3 -m pip uninstall -y typing && \ 42 | rm /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl 43 | 44 | ############## 45 | # DMLC PATCH # 46 | ############## 47 | # TODO: remove after making contributions back to xgboost for tracker.py 48 | # COPY src/sagemaker_xgboost_container/dmlc_patch/tracker.py \ 49 | # /miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker/dmlc_tracker/tracker.py 50 | 51 | # # Include DMLC python code in PYTHONPATH to use RabitTracker 52 | # ENV PYTHONPATH=$PYTHONPATH:/miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker 53 | 54 | ####### 55 | # MMS # 56 | ####### 57 | # Create MMS user directory 58 | RUN useradd -m model-server 59 | RUN mkdir -p /home/model-server/tmp && chown -R model-server /home/model-server 60 | 61 | # Copy MMS configs 62 | COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/config.properties.tmp /home/model-server 63 | ENV XGBOOST_MMS_CONFIG=/home/model-server/config.properties 64 | 65 | # Copy execution parameters endpoint plugin for MMS 66 | RUN mkdir -p /tmp/plugins 67 | COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/endpoints-1.0.jar /tmp/plugins 68 | RUN chmod +x /tmp/plugins/endpoints-1.0.jar 69 | 70 | # Create directory for models 71 | RUN mkdir -p 
/opt/ml/models 72 | RUN chmod +rwx /opt/ml/models 73 | 74 | # Copy Dask configs 75 | RUN mkdir /etc/dask 76 | COPY docker/configs/dask_configs.yaml /etc/dask/ 77 | 78 | # Required label for multi-model loading 79 | LABEL com.amazonaws.sagemaker.capabilities.multi-models=true 80 | 81 | ##################### 82 | # Required ENV vars # 83 | ##################### 84 | # Set SageMaker training environment variables 85 | ENV SM_INPUT /opt/ml/input 86 | ENV SM_INPUT_TRAINING_CONFIG_FILE $SM_INPUT/config/hyperparameters.json 87 | ENV SM_INPUT_DATA_CONFIG_FILE $SM_INPUT/config/inputdataconfig.json 88 | ENV SM_CHECKPOINT_CONFIG_FILE $SM_INPUT/config/checkpointconfig.json 89 | # See: https://github.com/dmlc/xgboost/issues/7982#issuecomment-1379390906 https://github.com/dmlc/xgboost/pull/8257 90 | ENV NCCL_SOCKET_IFNAME eth 91 | 92 | 93 | # Set SageMaker serving environment variables 94 | ENV SM_MODEL_DIR /opt/ml/model 95 | 96 | # Set SageMaker entrypoints 97 | ENV SAGEMAKER_TRAINING_MODULE sagemaker_xgboost_container.training:main 98 | ENV SAGEMAKER_SERVING_MODULE sagemaker_xgboost_container.serving:main 99 | 100 | EXPOSE 8080 101 | ENV TEMP=/home/model-server/tmp 102 | LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true 103 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/algorithm_mode/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. 
See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import logging 14 | import os 15 | 16 | from sagemaker_xgboost_container.metrics.custom_metrics import ( 17 | configure_feval, 18 | get_custom_metrics, 19 | ) 20 | 21 | HPO_SEPARATOR = ":" 22 | 23 | 24 | # These are helper functions for parsing the list of metrics to be outputted 25 | def get_union_metrics(metric_a, metric_b): 26 | """Union of metric_a and metric_b 27 | 28 | :param metric_a: list 29 | :param metric_b: list 30 | :return: Union metrics list from metric_a and metric_b 31 | """ 32 | if metric_a is None and metric_b is None: 33 | return None 34 | elif metric_a is None: 35 | return metric_b 36 | elif metric_b is None: 37 | return metric_a 38 | else: 39 | # The order of metric_list need to be consistent among all hosts in distributed training 40 | # So we have metric_list sorted here. 41 | metric_list = sorted(list(set(metric_a).union(metric_b))) 42 | return metric_list 43 | 44 | 45 | def get_eval_metrics_and_feval(tuning_objective_metric_param, eval_metric): 46 | """Return list of default xgb evaluation metrics and list of container defined metrics. 47 | 48 | XGB uses the 'eval_metric' parameter for the evaluation metrics supported by default, and 'feval' as an argument 49 | during training to validate using custom evaluation metrics. The argument 'feval' takes a function as value; the 50 | method returned here will be configured to run for only the metrics the user specifies. 51 | 52 | :param tuning_objective_metric_param: HPO metric 53 | :param eval_metric: list of xgb metrics to output 54 | :return: cleaned list of xgb supported evaluation metrics, method configured with container defined metrics, 55 | and tuning objective metric. 
56 | """ 57 | tuning_objective_metric = None 58 | configured_eval = None 59 | cleaned_eval_metrics = None 60 | 61 | if tuning_objective_metric_param is not None: 62 | tuning_objective_metric_tuple = MetricNameComponents.decode(tuning_objective_metric_param) 63 | tuning_objective_metric = tuning_objective_metric_tuple.metric_name.split(",") 64 | logging.info("Setting up HPO optimized metric to be : {}".format(tuning_objective_metric_tuple.metric_name)) 65 | 66 | union_metrics = get_union_metrics(tuning_objective_metric, eval_metric) 67 | 68 | if union_metrics is not None: 69 | feval_metrics = get_custom_metrics(union_metrics) 70 | if feval_metrics: 71 | configured_eval = configure_feval(feval_metrics) 72 | cleaned_eval_metrics = list(set(union_metrics) - set(feval_metrics)) 73 | else: 74 | cleaned_eval_metrics = union_metrics 75 | 76 | return cleaned_eval_metrics, configured_eval, tuning_objective_metric 77 | 78 | 79 | def cleanup_dir(dir, file_prefix): 80 | """Clean up directory 81 | 82 | This function is used to remove extra files from a directory other than 'file'. 
83 | 84 | :param dir: model directory which needs to be cleaned 85 | :param file_prefix: file name prefix which isn't removed if present 86 | """ 87 | 88 | def _format_path(file_name): 89 | return os.path.join(dir, file_name) 90 | 91 | def _remove(path): 92 | try: 93 | os.remove(path) 94 | except Exception: 95 | pass 96 | 97 | for data_file in os.listdir(dir): 98 | path = _format_path(data_file) 99 | if os.path.isfile(path) and not data_file.startswith(file_prefix): 100 | _remove(path) 101 | 102 | 103 | class MetricNameComponents(object): 104 | def __init__(self, data_segment, metric_name, emission_frequency=None): 105 | self.data_segment = data_segment 106 | self.metric_name = metric_name 107 | self.emission_frequency = emission_frequency 108 | 109 | @classmethod 110 | def decode(cls, tuning_objective_metric): 111 | result = tuning_objective_metric.split(":") 112 | return MetricNameComponents(*result) 113 | 114 | 115 | def _get_bytes_to_mb(num_bytes): 116 | return round(num_bytes / (1024 * 1024), 2) 117 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/handler_service.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import textwrap 16 | 17 | from sagemaker_containers.beta.framework import encoders 18 | from sagemaker_inference import content_types, default_inference_handler 19 | from sagemaker_inference.default_handler_service import DefaultHandlerService 20 | from sagemaker_inference.transformer import Transformer 21 | 22 | from sagemaker_xgboost_container import encoder as xgb_encoders 23 | 24 | 25 | class HandlerService(DefaultHandlerService): 26 | """Handler service that is executed by the model server. 27 | Determines specific default inference handlers to use based on the type MXNet model being used. 28 | This class extends ``DefaultHandlerService``, which define the following: 29 | - The ``handle`` method is invoked for all incoming inference requests to the model server. 30 | - The ``initialize`` method is invoked at model server start up. 31 | Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md 32 | """ 33 | 34 | class DefaultXGBoostUserModuleInferenceHandler(default_inference_handler.DefaultInferenceHandler): 35 | def default_model_fn(self, model_dir): 36 | """Load a model. For XGBoost Framework, a default function to load a model is not provided. 37 | Users should provide customized model_fn() in script. 38 | Args: 39 | model_dir: a directory where model is saved. 40 | Returns: A XGBoost model. 41 | """ 42 | raise NotImplementedError( 43 | textwrap.dedent( 44 | """ 45 | Please provide a model_fn implementation. 46 | See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk 47 | """ 48 | ) 49 | ) 50 | 51 | def default_input_fn(self, input_data, content_type): 52 | """Take request data and de-serializes the data into an object for prediction. 
53 | When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server, 54 | the model server receives two pieces of information: 55 | - The request Content-Type, for example "application/json" 56 | - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size. 57 | The input_fn is responsible to take the request data and pre-process it before prediction. 58 | Args: 59 | input_data (obj): the request data. 60 | content_type (str): the request Content-Type. 61 | Returns: 62 | (obj): data ready for prediction. For XGBoost, this defaults to DMatrix. 63 | """ 64 | return xgb_encoders.decode(input_data, content_type) 65 | 66 | def default_predict_fn(self, input_data, model): 67 | """A default predict_fn for XGBooost Framework. Calls a model on data deserialized in input_fn. 68 | Args: 69 | input_data: input data (DMatrix) for prediction deserialized by input_fn 70 | model: XGBoost model loaded in memory by model_fn 71 | Returns: a prediction 72 | """ 73 | output = model.predict(input_data, validate_features=False) 74 | return output 75 | 76 | def default_output_fn(self, prediction, accept): 77 | """Function responsible to serialize the prediction for the response. 78 | Args: 79 | prediction (obj): prediction returned by predict_fn . 80 | accept (str): accept content-type expected by the client. 
81 | Returns: 82 | encoded response for MMS to return to client 83 | """ 84 | encoded_prediction = encoders.encode(prediction, accept) 85 | if accept == content_types.CSV: 86 | encoded_prediction = encoded_prediction.encode("utf-8") 87 | 88 | return encoded_prediction 89 | 90 | def __init__(self): 91 | transformer = Transformer(default_inference_handler=self.DefaultXGBoostUserModuleInferenceHandler()) 92 | super(HandlerService, self).__init__(transformer=transformer) 93 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/callback.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import xgboost as xgb 5 | 6 | from sagemaker_xgboost_container import checkpointing 7 | from sagemaker_xgboost_container.algorithm_mode import train_utils 8 | from sagemaker_xgboost_container.constants.xgb_constants import ( 9 | MODEL_NAME, 10 | XGB_MAXIMIZE_METRICS, 11 | ) 12 | 13 | # from smdebug.xgboost import Hook 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | # def add_debugging(callbacks, hyperparameters, train_dmatrix, val_dmatrix=None, json_config_path=None): 19 | # """Add a sagemaker debug hook to a list of callbacks. 20 | 21 | # :param callbacks: List of callback functions. 22 | # :param hyperparameters: Dict of hyperparamters. 23 | # Same as `params` in xgb.train(params, dtrain). 24 | # :param train_dmatrix: Training data set. 25 | # :param val_dmatrix: Validation data set. 26 | # :param json_config_path: If specified, this json config will be used 27 | # instead of default config file. 
#     """
#     try:
#         hook = Hook.hook_from_config(json_config_path)
#         hook.hyperparameters = hyperparameters
#         hook.train_data = train_dmatrix
#         if val_dmatrix is not None:
#             hook.validation_data = val_dmatrix
#         callbacks.append(hook)
#         logging.info("Debug hook created from config")
#     except Exception as e:
#         logging.debug("Failed to create debug hook", e)
#         return


def add_sigterm_handler(model_dir, is_master):
    """Stop training and cleanup model directory when SIGTERM is received.

    Model directory is only cleaned if is_master is True. Otherwise program terminates.

    :param model_dir: Directory where model is saved
    :param is_master: True if single node training, or the current node is the master node in distributed training
    """

    def _terminate():
        # Hard exit: skips Python atexit/cleanup so the process dies immediately.
        os._exit(0)

    def _cleanup_files(signo, frame):
        if is_master:
            train_utils.cleanup_dir(model_dir, MODEL_NAME)

        _terminate()

    signal.signal(signal.SIGTERM, _cleanup_files)


def get_callbacks(
    model_dir,
    checkpoint_dir,
    early_stopping_data_name,
    early_stopping_metric,
    early_stopping_rounds,
    save_model_on_termination,
    is_master,
    fold=None,
):
    """Assemble the xgb.train callback list (checkpointing, eval logging, early stopping).

    :param model_dir: Directory where the final/intermediate model is saved.
    :param checkpoint_dir: Directory for training checkpoints (may be falsy to disable).
    :param early_stopping_data_name: Eval set name watched for early stopping.
    :param early_stopping_metric: Metric name watched for early stopping.
    :param early_stopping_rounds: Rounds without improvement before stopping.
    :param save_model_on_termination: "true" to save intermediate models on SIGTERM.
    :param is_master: True on the master node (or single-node training).
    :param fold: Optional cross-validation fold index; namespaces checkpoints/models per fold.
    :return: (xgb_model, iteration, callbacks) — checkpoint to resume from (or None),
        round to resume at, and the callback list.
    """
    if checkpoint_dir and fold is not None:
        # Cross-validation: each fold checkpoints into its own subdirectory.
        checkpoint_dir = os.path.join(checkpoint_dir, f"model-{fold}")

    # Set callbacks
    xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir)
    if xgb_model is not None:
        if fold is not None:
            # NOTE(review): appends the fold suffix to the path returned by
            # load_checkpoint — confirm a file with this suffixed name exists.
            xgb_model = f"{xgb_model}-{fold}"
        logging.info("Checkpoint loaded from %s", xgb_model)
        logging.info("Resuming from iteration %s", iteration)

    callbacks = []
    callbacks.append(xgb.callback.EvaluationMonitor())

    if checkpoint_dir and is_master:
        # NOTE(review): TrainingCheckPoint's `interval` is the save frequency in
        # rounds; passing the resume iteration (0 on a fresh run) looks
        # unintended — confirm against the xgboost callback API.
        save_checkpoint = xgb.callback.TrainingCheckPoint(
            directory=checkpoint_dir,
            interval=iteration,
            name=checkpointing.CHECKPOINT_FILENAME,
        )
        callbacks.append(save_checkpoint)

    logging.info(
        f"CALLBACK_SETUP_DEBUG: save_model_on_termination={save_model_on_termination}, is_master={is_master}"
    )

    if save_model_on_termination == "true" and is_master:
        logging.info("CALLBACK_ADDING: Adding SaveIntermediateModelCallBack on master")
        model_name = f"{MODEL_NAME}-{fold}" if fold is not None else MODEL_NAME
        save_intermediate_model = checkpointing.SaveIntermediateModelCallBack(
            model_dir, model_name, is_master
        )
        callbacks.append(save_intermediate_model)
        add_sigterm_handler(model_dir, is_master)
    else:
        logging.info(
            f"CALLBACK_SKIPPING save_model_on_termination={save_model_on_termination}, is_master={is_master})"
        )

    if early_stopping_data_name and early_stopping_metric and early_stopping_rounds:
        # Whether higher is better depends on the metric (e.g. auc vs rmse).
        maximize = early_stopping_metric in XGB_MAXIMIZE_METRICS
        early_stop = xgb.callback.EarlyStopping(
            rounds=early_stopping_rounds,
            data_name=early_stopping_data_name,
            metric_name=early_stopping_metric,
            maximize=maximize,
            save_best=is_master,
        )
        callbacks.append(early_stop)

    return xgb_model, iteration, callbacks
--------------------------------------------------------------------------------
/src/sagemaker_algorithm_toolkit/channel_validation.py:
--------------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file.
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from sagemaker_algorithm_toolkit import exceptions as exc

# Keys of the per-channel configuration dict that SageMaker passes at runtime.
CONTENT_TYPE = "ContentType"
TRAINING_INPUT_MODE = "TrainingInputMode"
S3_DIST_TYPE = "S3DistributionType"


class Channel(object):
    """Represents a single SageMaker training job channel."""

    FILE_MODE = "File"
    PIPE_MODE = "Pipe"
    AUGMENTED_MODE = "Augmented"

    SHARDED = "ShardedByS3Key"
    REPLICATED = "FullyReplicated"

    def __init__(self, name, required):
        """
        :param name: channel name (e.g. "train").
        :param required: True if the channel must be supplied at runtime.
        """
        self.name = name
        self.required = required
        # Set of supported (content_type, input_mode, s3_distribution_type) tuples.
        self.supported = set()

    def format(self):
        """Format channel for SageMaker's CreateAlgorithm API.

        The supported lists are sorted so the formatted output is deterministic;
        ``list(set(...))`` ordering varies with hash randomization across runs.
        """
        supported_content_types = sorted(set(c[0] for c in self.supported))
        supported_input_modes = sorted(set(c[1] for c in self.supported))
        return {
            "Name": self.name,
            "Description": self.name,
            "IsRequired": self.required,
            "SupportedContentTypes": supported_content_types,
            "SupportedInputModes": supported_input_modes,
        }

    def add(self, content_type, supported_input_mode, supported_s3_data_distribution_type):
        """Add relevant configuration as a supported configuration for the channel."""
        self.supported.add((content_type, supported_input_mode, supported_s3_data_distribution_type))

    def validate(self, value):
        """Validate the provided configuration against the channel's supported configuration.

        :param value: dict carrying CONTENT_TYPE, TRAINING_INPUT_MODE and S3_DIST_TYPE keys.
        :raises exc.UserError: if the combination is not in ``self.supported``.
        """
        if (value[CONTENT_TYPE], value[TRAINING_INPUT_MODE], value[S3_DIST_TYPE]) not in self.supported:
            raise exc.UserError("Channel configuration for '{}' channel is not supported: {}".format(self.name, value))


class Channels(object):
    """Represents a collection of Channels for a SageMaker training job."""

    def __init__(self, *channels):
        self.channels = channels
        self.default_content_type = None

    def set_default_content_type(self, default_content_type):
        """Set the content type assumed for channels that omit one."""
        self.default_content_type = default_content_type

    def format(self):
        """Format channels for SageMaker's CreateAlgorithm API."""
        return [channel.format() for channel in self.channels]

    def validate(self, user_channels):
        """Validate the provided user-specified channels at runtime against the channels' supported configuration.

        Note that this adds default content type for channels if a default exists.

        :param user_channels: dictionary of channels formatted like so
            {
                "channel_name": {
                    "ContentType": <content type>,
                    "TrainingInputMode": <training input mode>,
                    "S3DistributionType": <S3 distribution type>,
                    ...
                },
                ...
            }
        :raises exc.UserError: on a missing required channel, an extraneous channel,
            a missing content type with no default, or an unsupported combination.
        :return: dict of validated channels (with defaulted content types filled in).
        """
        for channel in self.channels:
            if channel.name not in user_channels:
                if channel.required:
                    raise exc.UserError("Missing required channel: {}".format(channel.name))

        name_to_channel = {channel.name: channel for channel in self.channels}
        validated_channels = {}
        for channel, value in user_channels.items():
            try:
                channel_obj = name_to_channel[channel]
            except KeyError:
                raise exc.UserError("Extraneous channel found: {}".format(channel))

            if CONTENT_TYPE not in value:
                if self.default_content_type:
                    value[CONTENT_TYPE] = self.default_content_type
                else:
                    raise exc.UserError("Missing content type for channel: {}".format(channel))

            channel_obj.validate(value)
            validated_channels[channel] = value

        return validated_channels
# --------------------------------------------------------------------------
# /test/resources/abalone/abalone_distributed.py:
# --------------------------------------------------------------------------
# Copyright 2020 Amazon.com, Inc. or its affiliates.
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | from __future__ import print_function 14 | 15 | import argparse 16 | import json 17 | import logging 18 | import os 19 | import pickle as pkl 20 | 21 | import xgboost as xgb 22 | from sagemaker_containers import entry_point 23 | 24 | from sagemaker_xgboost_container import distributed 25 | from sagemaker_xgboost_container.data_utils import get_dmatrix 26 | 27 | 28 | def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master): 29 | """Run xgb train on arguments given with rabit initialized. 30 | 31 | This is our rabit execution function. 32 | 33 | :param args_dict: Argument dictionary used to run xgb.train(). 34 | :param is_master: True if current node is master host in distributed training, 35 | or is running single node training job. 36 | Note that rabit_run will include this argument. 37 | """ 38 | booster = xgb.train(params=params, dtrain=dtrain, evals=evals, num_boost_round=num_boost_round) 39 | 40 | if is_master: 41 | model_location = model_dir + "/xgboost-model" 42 | pkl.dump(booster, open(model_location, "wb")) 43 | logging.info("Stored trained model at {}".format(model_location)) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | 49 | # Hyperparameters are described here. 
50 | parser.add_argument("--max_depth", type=int) 51 | parser.add_argument("--eta", type=float) 52 | parser.add_argument("--gamma", type=int) 53 | parser.add_argument("--min_child_weight", type=int) 54 | parser.add_argument("--subsample", type=float) 55 | parser.add_argument("--objective", type=str) 56 | parser.add_argument("--num_round", type=int) 57 | 58 | # Sagemaker specific arguments. Defaults are set in the environment variables. 59 | parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR")) 60 | parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR")) 61 | parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN")) 62 | parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION")) 63 | parser.add_argument("--sm_hosts", type=str, default=os.environ.get("SM_HOSTS")) 64 | parser.add_argument("--sm_current_host", type=str, default=os.environ.get("SM_CURRENT_HOST")) 65 | 66 | args, _ = parser.parse_known_args() 67 | 68 | # Get SageMaker host information from runtime environment variables 69 | sm_hosts = json.loads(args.sm_hosts) 70 | sm_current_host = args.sm_current_host 71 | 72 | dtrain = get_dmatrix(args.train, "libsvm") 73 | dval = get_dmatrix(args.validation, "libsvm") 74 | watchlist = [(dtrain, "train"), (dval, "validation")] if dval is not None else [(dtrain, "train")] 75 | 76 | train_hp = { 77 | "max_depth": args.max_depth, 78 | "eta": args.eta, 79 | "gamma": args.gamma, 80 | "min_child_weight": args.min_child_weight, 81 | "subsample": args.subsample, 82 | "objective": args.objective, 83 | } 84 | 85 | xgb_train_args = dict( 86 | params=train_hp, 87 | dtrain=dtrain, 88 | evals=watchlist, 89 | num_boost_round=args.num_round, 90 | model_dir=args.model_dir, 91 | ) 92 | 93 | if len(sm_hosts) > 1: 94 | # Wait until all hosts are able to find each other 95 | entry_point._wait_hostname_resolution() 96 | 97 | # Execute training function 
after initializing rabit. 98 | distributed.rabit_run( 99 | exec_fun=_xgb_train, 100 | args=xgb_train_args, 101 | include_in_training=(dtrain is not None), 102 | hosts=sm_hosts, 103 | current_host=sm_current_host, 104 | update_rabit_args=True, 105 | ) 106 | else: 107 | # If single node training, call training method directly. 108 | if dtrain: 109 | xgb_train_args["is_master"] = True 110 | _xgb_train(**xgb_train_args) 111 | else: 112 | raise ValueError("Training channel must have data to train model.") 113 | 114 | 115 | def model_fn(model_dir): 116 | """Deserialize and return fitted model. 117 | 118 | Note that this should have the same name as the serialized model in the _xgb_train method 119 | """ 120 | model_file = "xgboost-model" 121 | booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) 122 | return booster 123 | -------------------------------------------------------------------------------- /test/unit/test_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | import json 14 | import os 15 | import tempfile 16 | from pathlib import Path 17 | 18 | import mock 19 | import pytest 20 | import xgboost as xgb 21 | from mock import Mock, patch 22 | from sagemaker_containers import _content_types, _errors 23 | 24 | from sagemaker_xgboost_container import encoder 25 | 26 | 27 | @pytest.mark.parametrize("target", ("42,6,9", "42.0,6.0,9.0", "42\n6\n9\n", b"42,6,9", b"42.0,6.0,9.0", b"42\n6\n9\n")) 28 | def test_csv_to_dmatrix(target): 29 | actual = encoder.csv_to_dmatrix(target) 30 | assert type(actual) is xgb.DMatrix 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "target", 35 | ( 36 | "1,2,3,12:12:12", 37 | "1,2,3,2019-1-1", 38 | "1,2,3,2019-1-1 12:12:12", 39 | "1,2,3,2019-1-1 12:12:12+00", 40 | "1,2,3,-14 days", 41 | "1,2,3\n1,2,c", 42 | ), 43 | ) 44 | def test_csv_to_dmatrix_error(target): 45 | try: 46 | encoder.csv_to_dmatrix(target) 47 | assert False 48 | except Exception as e: 49 | assert type(e) is ValueError 50 | 51 | 52 | @pytest.mark.parametrize("target", (b"0 0:1 5:1", b"0:1 5:1")) 53 | def test_libsvm_to_dmatrix(target): 54 | temp_libsvm_file = tempfile.NamedTemporaryFile(delete=False) 55 | temp_libsvm_file_name = temp_libsvm_file.name 56 | assert os.path.exists(temp_libsvm_file_name) 57 | 58 | with mock.patch("sagemaker_xgboost_container.encoder.tempfile") as mock_tempfile: 59 | mock_tempfile.NamedTemporaryFile.return_value = temp_libsvm_file 60 | actual = encoder.libsvm_to_dmatrix(target) 61 | 62 | assert type(actual) is xgb.DMatrix 63 | assert not os.path.exists(temp_libsvm_file_name) 64 | 65 | 66 | @pytest.mark.parametrize( 67 | "target", 68 | ( 69 | b"\n#\xd7\xce\x13\x00\x00\x00\n\x11\n\x06values\x12\x07:\x05\n\x03*\x06\t\x00", # 42,6,9 70 | b"\n#\xd7\xce(\x00\x00\x00\n&\n\x06values\x12\x1c\x1a\x1a\n\x18\x00\x00\x00" # 42.0,6.0,9.0 71 | b'\x00\x00\x00E@\x00\x00\x00\x00\x00\x00\x18@\x00\x00\x00\x00\x00\x00"@', 72 | b"\n#\xd7\xce\x19\x00\x00\x00\n\x17\n\x06values\x12\r:\x0b\n\x02\x01\x01\x12" # 0:1 5:1 73 | 
b"\x02\x00\x05\x1a\x01\x06\x00\x00\x00", 74 | ), 75 | ) 76 | def test_recordio_protobuf_to_dmatrix(target): 77 | actual = encoder.recordio_protobuf_to_dmatrix(target) 78 | assert type(actual) is xgb.DMatrix 79 | 80 | 81 | def test_sparse_recordio_protobuf_to_dmatrix(): 82 | current_path = Path(os.path.abspath(__file__)) 83 | data_path = os.path.join(str(current_path.parent.parent), "resources", "data") 84 | files_path = os.path.join(data_path, "recordio_protobuf", "sparse_edge_cases") 85 | 86 | for filename in os.listdir(files_path): 87 | file_path = os.path.join(files_path, filename) 88 | with open(file_path, "rb") as f: 89 | target = f.read() 90 | actual = encoder.recordio_protobuf_to_dmatrix(target) 91 | assert type(actual) is xgb.DMatrix 92 | 93 | 94 | def test_decode_error(): 95 | with pytest.raises(_errors.UnsupportedFormatError): 96 | encoder.decode(42, _content_types.OCTET_STREAM) 97 | 98 | 99 | @pytest.mark.parametrize("content_type", [_content_types.JSON, _content_types.CSV]) 100 | def test_decode(content_type): 101 | decoder = Mock() 102 | with patch.dict(encoder._dmatrix_decoders_map, {content_type: decoder}, clear=True): 103 | encoder.decode(42, content_type) 104 | 105 | decoder.assert_called_once_with(42) 106 | 107 | 108 | @pytest.mark.parametrize("content_type", ["text/csv; charset=UTF-8"]) 109 | def test_decode_with_complex_csv_content_type(content_type): 110 | dmatrix_result = encoder.decode("42.0,6.0,9.0\n42.0,6.0,9.0", content_type) 111 | assert type(dmatrix_result) is xgb.DMatrix 112 | 113 | 114 | def test_encoder_jsonlines_from_json(): 115 | json_response = json.dumps( 116 | { 117 | "predictions": [ 118 | {"predicted_label": 1, "probabilities": [0.4, 0.6]}, 119 | {"predicted_label": 0, "probabilities": [0.9, 0.1]}, 120 | ] 121 | } 122 | ) 123 | expected_jsonlines = ( 124 | b'{"predicted_label": 1, "probabilities": [0.4, 0.6]}\n' 125 | b'{"predicted_label": 0, "probabilities": [0.9, 0.1]}\n' 126 | ) 127 | 128 | jsonlines_response = 
encoder.json_to_jsonlines(json_response) 129 | assert expected_jsonlines == jsonlines_response 130 | 131 | 132 | def test_encoder_jsonlines_from_json_error(): 133 | bad_json_response = json.dumps({"predictions": [], "metadata": []}) 134 | with pytest.raises(ValueError): 135 | encoder.json_to_jsonlines(bad_json_response) 136 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/prediction_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | import logging 14 | import os 15 | 16 | import numpy as np 17 | from scipy import stats 18 | 19 | from sagemaker_algorithm_toolkit import exceptions as exc 20 | 21 | PREDICTIONS_OUTPUT_FILE = "predictions.csv" 22 | EXAMPLE_ROWS_EXCEPTION_COUNT = 100 23 | 24 | 25 | class ValidationPredictionRecorder: 26 | """Helper class to record and store predictions obtained on different train / validation 27 | folds. Predictions are stored in folder specified by SM_OUTPUT_DATA_DIR env variable set by 28 | training platform, and sometimes modified by container code. Additional artefacts at the 29 | end of the training job are stored in output s3 path as output.tar.gz. 30 | 31 | Attributes: 32 | y_true (1d numpy array): Ground truth labels. 33 | num_cv_round (int): number times cross validation procedure will be repeated. 
34 | classification (bool): indicates type of learning problem. 35 | """ 36 | 37 | def __init__(self, y_true: np.ndarray, num_cv_round: int, classification: bool, output_data_dir: str) -> None: 38 | self.y_true = y_true.copy() 39 | num_rows = len(y_true) 40 | self.num_cv_round = num_cv_round 41 | self.y_pred = np.zeros((num_rows, num_cv_round)) 42 | self.y_prob = self.y_pred.copy() if classification else None 43 | self.cv_repeat_counter = np.zeros((num_rows,)).astype(int) 44 | self.classification = classification 45 | self.output_data_dir = output_data_dir 46 | self.pred_ndim_ = None 47 | 48 | def record(self, indices: np.ndarray, predictions: np.ndarray) -> None: 49 | """Record predictions on a single validation fold in-memory. 50 | 51 | :param indices: indicates for which rows the predictions were made. 52 | :param predictions: predictions for rows specified in `indices` variable. 53 | """ 54 | if self.pred_ndim_ is None: 55 | self.pred_ndim_ = predictions.ndim 56 | if self.pred_ndim_ != predictions.ndim: 57 | raise exc.AlgorithmError(f"Expected predictions with ndim={self.pred_ndim_}, got ndim={predictions.ndim}.") 58 | 59 | cv_repeat_idx = self.cv_repeat_counter[indices] 60 | if np.any(cv_repeat_idx == self.num_cv_round): 61 | sample_rows = cv_repeat_idx[cv_repeat_idx == self.num_cv_round] 62 | sample_rows = sample_rows[:EXAMPLE_ROWS_EXCEPTION_COUNT] 63 | raise exc.AlgorithmError( 64 | f"More than {self.num_cv_round} repeated predictions for same row were provided. " 65 | f"Example row indices where this is the case: {sample_rows}." 
66 | ) 67 | 68 | if self.classification: 69 | if predictions.ndim > 1: 70 | labels = np.argmax(predictions, axis=-1) 71 | proba = predictions[np.arange(len(labels)), labels] 72 | else: 73 | labels = 1 * (predictions > 0.5) 74 | proba = predictions 75 | self.y_pred[indices, cv_repeat_idx] = labels 76 | self.y_prob[indices, cv_repeat_idx] = proba 77 | else: 78 | self.y_pred[indices, cv_repeat_idx] = predictions 79 | self.cv_repeat_counter[indices] += 1 80 | 81 | def _aggregate_predictions(self) -> np.ndarray: 82 | if not np.all(self.cv_repeat_counter == self.num_cv_round): 83 | sample_rows = self.cv_repeat_counter[self.cv_repeat_counter != self.num_cv_round] 84 | sample_rows = sample_rows[:EXAMPLE_ROWS_EXCEPTION_COUNT] 85 | raise exc.AlgorithmError( 86 | f"For some rows number of repeated validation set predictions provided is not {self.num_cv_round}. " 87 | f"Example row indices where this is the case: {sample_rows}" 88 | ) 89 | 90 | columns = [self.y_true] 91 | if self.classification: 92 | columns.append(self.y_prob.mean(axis=-1)) 93 | # mode always returns same number of dimensions of output as for input 94 | model_result = stats.mode(self.y_pred, axis=1, keepdims=True) 95 | model_values = model_result.mode 96 | if model_values.ndim > 1: 97 | model_values = model_values[:, 0] 98 | columns.append(model_values) 99 | else: 100 | columns.append(self.y_pred.mean(axis=-1)) 101 | 102 | return np.vstack(columns).T 103 | 104 | def _check_output_path(self) -> None: 105 | if not os.path.exists(self.output_data_dir): 106 | logging.warn(f"Output directory {self.output_data_dir} not found; Creating the output directory.") 107 | os.makedirs(self.output_data_dir) 108 | 109 | def _get_save_path(self) -> str: 110 | return os.path.join(self.output_data_dir, PREDICTIONS_OUTPUT_FILE) 111 | 112 | def save(self) -> None: 113 | """Serialize predictions as .csv file in output data directory.""" 114 | self._check_output_path() 115 | save_path = self._get_save_path() 116 | 117 | 
logging.info(f"Storing predictions on validation set(s) in {save_path}") 118 | np.savetxt(save_path, self._aggregate_predictions(), delimiter=",", fmt="%f") 119 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import cgi 16 | import csv 17 | import io 18 | import json 19 | import logging 20 | import os 21 | import tempfile 22 | from typing import Union 23 | 24 | import mlio 25 | import numpy as np 26 | import xgboost as xgb 27 | from mlio.integ.numpy import as_numpy 28 | from mlio.integ.scipy import to_coo_matrix 29 | from sagemaker_containers import _content_types, _errors 30 | from scipy.sparse import vstack as scipy_vstack 31 | 32 | from sagemaker_xgboost_container.constants import xgb_content_types 33 | 34 | 35 | def _clean_csv_string(csv_string, delimiter): 36 | return ["nan" if x == "" else x for x in csv_string.split(delimiter)] 37 | 38 | 39 | def csv_to_dmatrix(input: Union[str, bytes], dtype=None) -> xgb.DMatrix: 40 | """Convert a CSV object to a DMatrix object. 41 | Args: 42 | input (str/binary): CSV string or binary object(encoded by UTF-8). 43 | Assumes the string has been stripped of leading or trailing newline chars. 
        dtype (dtype, optional): Data type of the resulting array. If None, the dtypes will be determined by the
                                 contents of each column, individually. This argument can only be used to
                                 'upcast' the array. For downcasting, use the .astype(t) method.
    Returns:
        (xgb.DMatrix): XGBoost DataMatrix
    """
    csv_string = input.decode() if isinstance(input, bytes) else input
    # Sniff the delimiter from the first line only (capped at 512 chars).
    sniff_delimiter = csv.Sniffer().sniff(csv_string.split("\n")[0][:512]).delimiter
    # An alphanumeric "delimiter" means sniffing picked up data characters; fall back to comma.
    delimiter = "," if sniff_delimiter.isalnum() else sniff_delimiter
    logging.info("Determined delimiter of CSV input is '{}'".format(delimiter))

    np_payload = np.array(list(map(lambda x: _clean_csv_string(x, delimiter), csv_string.split("\n")))).astype(dtype)
    return xgb.DMatrix(np_payload)


def libsvm_to_dmatrix(string_like):  # type: (bytes) -> xgb.DMatrix
    """Convert a LIBSVM string representation to a DMatrix object.
    Args:
        string_like (bytes): LIBSVM string.
    Returns:
        (xgb.DMatrix): XGBoost DataMatrix
    """
    # DMatrix only reads libsvm from a file path, so spill the payload to a temp
    # file and delete it afterwards (the unit test asserts the file is removed).
    temp_file_location = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as libsvm_file:
            temp_file_location = libsvm_file.name
            libsvm_file.write(string_like)

        dmatrix = xgb.DMatrix(f"{temp_file_location}?format=libsvm")
    finally:
        if temp_file_location and os.path.exists(temp_file_location):
            os.remove(temp_file_location)

    return dmatrix


def recordio_protobuf_to_dmatrix(string_like):  # type: (bytes) -> xgb.DMatrix
    """Convert a RecordIO-Protobuf byte representation to a DMatrix object.
    Args:
        string_like (bytes): RecordIO-Protobuf bytes.
    Returns:
        (xgb.DMatrix): XGBoost DataMatrix
    """
    buf = bytes(string_like)
    dataset = [mlio.InMemoryStore(buf)]
    reader_params = mlio.DataReaderParams(dataset=dataset, batch_size=100)
    reader = mlio.RecordIOProtobufReader(reader_params)

    # Peek the first example to decide between the dense and sparse decode paths.
    is_dense_tensor = type(reader.peek_example()["values"]) is mlio.DenseTensor

    examples = []
    for example in reader:
        # Ignore labels if present
        values = as_numpy(example["values"]) if is_dense_tensor else to_coo_matrix(example["values"])
        examples.append(values)

    data = np.vstack(examples) if is_dense_tensor else scipy_vstack(examples).tocsr()
    dmatrix = xgb.DMatrix(data)
    return dmatrix


# Maps a media content type to the decoder that turns its payload into a DMatrix.
_dmatrix_decoders_map = {
    _content_types.CSV: csv_to_dmatrix,
    xgb_content_types.LIBSVM: libsvm_to_dmatrix,
    xgb_content_types.X_LIBSVM: libsvm_to_dmatrix,
    xgb_content_types.X_RECORDIO_PROTOBUF: recordio_protobuf_to_dmatrix,
}


def json_to_jsonlines(json_data):
    """Convert a json response to jsonlines.

    :param json_data: json data (dict or json string)
    :return: jsonlines encoded response (bytes)
    :raises ValueError: if the JSON object does not have exactly one top-level key.
    """
    resp_dict = json_data if isinstance(json_data, dict) else json.loads(json_data)

    if len(resp_dict.keys()) != 1:
        raise ValueError("JSON response is not compatible for conversion to jsonlines.")

    bio = io.BytesIO()
    for value in resp_dict.values():
        for entry in value:
            bio.write(bytes(json.dumps(entry) + "\n", "UTF-8"))
    return bio.getvalue()


def decode(obj, content_type):
    # type: (np.array or Iterable or int or float, str) -> xgb.DMatrix
    """Decode an object of one of the default content types to a DMatrix object.
    Args:
        obj (object): to be decoded.
        content_type (str): content type to be used.
    Returns:
        (xgb.DMatrix): decoded object.
    Raises:
        _errors.UnsupportedFormatError: if no decoder is registered for the media type.
    """
    try:
        # Strip parameters such as "; charset=UTF-8" before the map lookup.
        media_content_type, _params = cgi.parse_header(content_type)
        decoder = _dmatrix_decoders_map[media_content_type]
        return decoder(obj)
    except KeyError:
        raise _errors.UnsupportedFormatError(media_content_type)
# --------------------------------------------------------------------------
# /test/unit/test_distributed.py:
# --------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys
import time
from multiprocessing import Process, Queue
from test.utils.test_utils import find_two_open_ports

import pytest

from sagemaker_xgboost_container import distributed


def synchronize_fn(host_count, port, master, idx, q):
    # Child-process entry point: join the Rabit ring and exchange {"idx": idx}
    # with the other hosts, reporting the aggregated result through the queue.
    hosts = ["127.0.0.1"] + ["localhost" for _ in range(host_count - 1)]
    current_host = "127.0.0.1" if master else "localhost"
    with distributed.Rabit(hosts, current_host=current_host, port=port, master_host="127.0.0.1") as dr:
        results = dr.synchronize({"idx": idx})
        q.put(results)
    sys.exit(0)


def rabit_run_fn(
    host_count, is_run, first_port, second_port, master, idx, q, max_connect_attempts=None, connect_retry_timeout=60
):
    # Child-process entry point for distributed.rabit_run: when this host is
    # included in training, rabit_run invokes q.put(obj=idx).
    hosts = ["127.0.0.1"] + ["localhost" for _ in range(host_count - 1)]
    current_host = "127.0.0.1" if master else "localhost"
    args_dict = dict(obj=idx)

    distributed.rabit_run(
        q.put,
        args_dict,
        is_run,
        hosts,
        current_host,
        first_port,
        second_port,
        max_connect_attempts=max_connect_attempts,
        connect_retry_timeout=connect_retry_timeout,
        update_rabit_args=False,
    )

    sys.exit(0)


def rabit_run_delay_master(host_count, is_run, first_port, second_port, master, idx, q, max_connect_attempts):
    # Same as rabit_run_fn but the master joins late, exercising worker retries.
    if master:
        time.sleep(10)

    rabit_run_fn(host_count, is_run, first_port, second_port, master, idx, q, max_connect_attempts=max_connect_attempts)


def rabit_run_fail(test_fn, host_count, is_run, first_port, second_port, master, idx, q, max_connect_attempts=None):
    # Expects test_fn to raise; the error message is forwarded through the queue
    # so the parent test can assert on it.
    try:
        test_fn(host_count, is_run, first_port, second_port, master, idx, q, max_connect_attempts=max_connect_attempts)

        raise Exception("This rabit run should fail!")
    except Exception as e:
        q.put("{} {}".format(idx, str(e)))


def test_integration_rabit_synchronize():
    """Five local processes join one Rabit ring and see every host's payload."""
    q = Queue()

    port, _ = find_two_open_ports()
    print(f"test_integration_rabit_synchronize, port={port}")

    host_count = 5
    host_list = range(host_count)
    expected_results = [{"idx": idx} for idx in host_list]

    for idx in host_list:
        p = Process(target=synchronize_fn, args=(host_count, port, idx == 0, idx, q))
        p.start()

    num_responses = 0
    while num_responses < host_count:
        host_aggregated_result = q.get(timeout=30)
        for host_individual_result in host_aggregated_result:
            assert host_individual_result in expected_results
        num_responses += 1


def test_rabit_run_all_hosts_run():
    """rabit_run executes the function on every host when all are included."""
    q = Queue()

    first_port, second_port = find_two_open_ports()

    host_count = 5
    host_list = range(host_count)
    expected_results = [idx for idx in host_list]

    for idx in host_list:
        p = Process(target=rabit_run_fn, args=(host_count, True, first_port, second_port, idx == 0, idx, q))
        p.start()

    num_responses = 0
    while num_responses < host_count:
        response = q.get(timeout=120)
        expected_results.remove(response)
        num_responses += 1

    assert len(expected_results) == 0


def test_rabit_run_exclude_one_host():
    """A host with is_run=False does not execute the function; the rest do."""
    q = Queue()

    first_port, second_port = find_two_open_ports()

    idx_to_exclude = 3

    host_count = 5
    host_list = range(host_count)
    expected_results = [idx for idx in host_list if idx != idx_to_exclude]

    for idx in host_list:
        p = Process(
            target=rabit_run_fn, args=(host_count, idx != idx_to_exclude, first_port, second_port, idx == 0, idx, q)
        )
        p.start()

    num_responses = 0
    while num_responses < host_count - 1:
        response = q.get(timeout=300)
        expected_results.remove(response)
        num_responses += 1

    assert len(expected_results) == 0
def test_rabit_delay_master():
    """Workers keep retrying until the late-joining master comes up."""
    q = Queue()

    first_port, second_port = find_two_open_ports()

    host_count = 5
    host_list = range(host_count)
    expected_results = [idx for idx in host_list]

    for idx in host_list:
        # max_connect_attempts=3 gives workers enough retries to outlast the
        # master's 10s startup delay.
        p = Process(
            target=rabit_run_delay_master, args=(host_count, True, first_port, second_port, idx == 0, idx, q, 3)
        )
        p.start()

    num_responses = 0
    while num_responses < host_count:
        response = q.get(timeout=300)
        expected_results.remove(response)
        num_responses += 1

    assert len(expected_results) == 0


@pytest.mark.parametrize("bad_max_retry_attempts", [0, -1])
def test_rabit_run_fail_bad_max_retry_attempts(bad_max_retry_attempts):
    """Non-positive max_connect_attempts is rejected on every host."""
    q = Queue()

    first_port, second_port = find_two_open_ports()

    host_count = 5
    host_list = range(host_count)

    for idx in host_list:
        p = Process(
            target=rabit_run_fail,
            args=(rabit_run_fn, host_count, True, first_port, second_port, idx == 0, idx, q, bad_max_retry_attempts),
        )
        p.start()

    num_responses = 0
    while num_responses < host_count:
        host_result = q.get(timeout=30)
        assert "max_connect_attempts must be None or an integer greater than 0." in host_result
        num_responses += 1
# --------------------------------------------------------------------------
# /test/conftest.py:
# --------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging
import os
import platform
import shutil
import sys
import tempfile

import boto3
import pytest
from sagemaker import Session

from .utils import local_mode

logger = logging.getLogger(__name__)
# Silence chatty AWS/HTTP loggers down to INFO.
logging.getLogger("boto").setLevel(logging.INFO)
logging.getLogger("boto3").setLevel(logging.INFO)
logging.getLogger("botocore").setLevel(logging.INFO)
logging.getLogger("factory.py").setLevel(logging.INFO)
logging.getLogger("auth.py").setLevel(logging.INFO)
logging.getLogger("connectionpool.py").setLevel(logging.INFO)


dir_path = os.path.dirname(os.path.realpath(__file__))


def pytest_addoption(parser):
    # Command-line knobs for the integration test suite.
    parser.addoption("--build-image", "-D", action="store_true")
    parser.addoption("--build-base-image", "-B", action="store_true")
    parser.addoption("--aws-id")
    parser.addoption("--instance-type", default="local")
    parser.addoption("--install-container-support", "-C", action="store_true")
    parser.addoption("--docker-base-name", default="sk-learn")
    parser.addoption("--region", default="us-west-2")
    parser.addoption("--framework-version", default="1.0-1")
    parser.addoption("--py-version", choices=["2", "3"], default=str(sys.version_info.major))
    parser.addoption("--processor", choices=["cpu"], default="cpu")
    # If not specified, will default to {framework-version}-{processor}-py{py-version}
    parser.addoption("--tag", default=None)


@pytest.fixture(scope="session", name="docker_base_name")
def fixture_docker_base_name(request):
    return request.config.getoption("--docker-base-name")
@pytest.fixture(scope="session", name="region")
def fixture_region(request):
    return request.config.getoption("--region")


@pytest.fixture(scope="session", name="framework_version")
def fixture_framework_version(request):
    return request.config.getoption("--framework-version")


@pytest.fixture(scope="session", name="py_version")
def fixture_py_version(request):
    # Normalizes "2"/"3" into "py2"/"py3".
    return "py{}".format(int(request.config.getoption("--py-version")))


@pytest.fixture(scope="session", name="processor")
def fixture_processor(request):
    return request.config.getoption("--processor")


@pytest.fixture(scope="session", name="tag")
def fixture_tag(request, framework_version, processor, py_version):
    # Explicit --tag wins; otherwise derive it from version/processor/python.
    provided_tag = request.config.getoption("--tag")
    default_tag = "{}-{}-{}".format(framework_version, processor, py_version)
    return provided_tag if provided_tag else default_tag


@pytest.fixture(scope="session", name="docker_image")
def fixture_docker_image(docker_base_name, tag):
    return "{}:{}".format(docker_base_name, tag)


@pytest.fixture
def opt_ml():
    # Temporary /opt/ml-style directory, removed after the test.
    tmp = tempfile.mkdtemp()
    os.mkdir(os.path.join(tmp, "output"))

    # Docker cannot mount Mac OS /var folder properly see
    # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
    opt_ml_dir = "/private{}".format(tmp) if platform.system() == "Darwin" else tmp
    yield opt_ml_dir

    shutil.rmtree(tmp, True)


@pytest.fixture(scope="session", name="install_container_support", autouse=True)
def fixture_install_container_support(request):
    install = request.config.getoption("--install-container-support")
    if install:
        local_mode.install_container_support()


@pytest.fixture(scope="session", name="build_base_image", autouse=True)
def fixture_build_base_image(request, framework_version, py_version, processor, tag, docker_base_name):
    # Builds the base image only when --build-base-image is passed; otherwise
    # the existing tag is reused.
    build_base_image = request.config.getoption("--build-base-image")
    if build_base_image:
        return local_mode.build_base_image(
            framework_name=docker_base_name,
            framework_version=framework_version,
            py_version=py_version,
            base_image_tag=tag,
            processor=processor,
            cwd=os.path.join(dir_path, ".."),
        )

    return tag


@pytest.fixture(scope="session", name="build_image", autouse=True)
def fixture_build_image(request, framework_version, py_version, processor, tag, docker_base_name):
    # Builds the final image only when --build-image is passed.
    build_image = request.config.getoption("--build-image")
    if build_image:
        return local_mode.build_image(
            framework_name=docker_base_name,
            framework_version=framework_version,
            py_version=py_version,
            processor=processor,
            tag=tag,
            cwd=os.path.join(dir_path, ".."),
        )

    return tag


@pytest.fixture(scope="session", name="sagemaker_session")
def fixture_sagemaker_session(region):
    return Session(boto_session=boto3.Session(region_name=region))


@pytest.fixture(name="aws_id", scope="session")
def fixture_aws_id(request):
    return request.config.getoption("--aws-id")


@pytest.fixture(name="instance_type", scope="session")
def fixture_instance_type(request):
    return request.config.getoption("--instance-type")


@pytest.fixture(name="docker_registry", scope="session")
def fixture_docker_registry(aws_id, region):
    return "{}.dkr.ecr.{}.amazonaws.com".format(aws_id, region)


@pytest.fixture(name="ecr_image", scope="session")
def fixture_ecr_image(docker_registry, docker_base_name, tag):
    return "{}/{}:{}".format(docker_registry, docker_base_name, tag)


@pytest.fixture(scope="session", name="dist_cpu_backend", params=["tcp", "gloo"])
def fixture_dist_cpu_backend(request):
    # Parametrized fixture: each dependent test runs once per backend.
    return request.param
# --------------------------------------------------------------------------
# /src/sagemaker_xgboost_container/algorithm_mode/handler_service.py:
# --------------------------------------------------------------------------
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

from sagemaker_inference import content_types, default_inference_handler, encoder
from sagemaker_inference.default_handler_service import DefaultHandlerService

from sagemaker_xgboost_container.algorithm_mode import serve_utils
from sagemaker_xgboost_container.algorithm_mode.inference_errors import (
    BadRequestInferenceError,
    ModelLoadInferenceError,
    NoContentInferenceError,
    UnsupportedMediaTypeInferenceError,
)
from sagemaker_xgboost_container.mms_patch.mms_transformer import XGBMMSTransformer

SAGEMAKER_BATCH = os.getenv("SAGEMAKER_BATCH")


class HandlerService(DefaultHandlerService):
    """Handler service that is executed by the model server.
    Determines specific default inference handlers to use based on the type MXNet model being used.
    This class extends ``DefaultHandlerService``, which define the following:
        - The ``handle`` method is invoked for all incoming inference requests to the model server.
37 | - The ``initialize`` method is invoked at model server start up. 38 | Based on: https://github.com/awslabs/mxnet-model-server/blob/v1.0.8/docs/custom_service.md 39 | """ 40 | 41 | class DefaultXGBoostAlgoModeInferenceHandler(default_inference_handler.DefaultInferenceHandler): 42 | def default_model_fn(self, model_dir): 43 | """Load a model. For XGBoost Framework, a default function to load a model is not provided. 44 | Users should provide customized model_fn() in script. 45 | Args: 46 | model_dir: a directory where model is saved. 47 | Returns: 48 | A XGBoost model. 49 | XGBoost model format type. 50 | """ 51 | try: 52 | booster, format = serve_utils.get_loaded_booster(model_dir, serve_utils.is_ensemble_enabled()) 53 | except Exception as e: 54 | raise ModelLoadInferenceError("Unable to load model: {}".format(str(e))) 55 | return booster, format 56 | 57 | def default_input_fn(self, input_data, input_content_type): 58 | """Take request data and de-serializes the data into an object for prediction. 59 | When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server, 60 | the model server receives two pieces of information: 61 | - The request Content-Type, for example "application/json" 62 | - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size. 63 | The input_fn is responsible to take the request data and pre-process it before prediction. 64 | Args: 65 | input_data (obj): the request data. 66 | input_content_type (str): the request Content-Type. XGBoost accepts CSV, LIBSVM, and RECORDIO-PROTOBUF. 67 | Returns: 68 | (obj): data ready for prediction. For XGBoost, this defaults to DMatrix. 69 | """ 70 | if len(input_data) == 0: 71 | raise NoContentInferenceError() 72 | dtest, content_type = serve_utils.parse_content_data(input_data, input_content_type) 73 | return dtest, content_type 74 | 75 | def default_predict_fn(self, data, model): 76 | """A default predict_fn for XGBooost Framework. 
Calls a model on data deserialized in input_fn. 77 | Args: 78 | data: input data (DMatrix) for prediction deserialized by input_fn and data content type 79 | model: XGBoost model loaded in memory by model_fn, and xgboost model format 80 | Returns: a prediction 81 | """ 82 | booster, model_format = model 83 | dtest, content_type = data 84 | try: 85 | return serve_utils.predict(booster, model_format, dtest, content_type) 86 | except Exception as e: 87 | raise BadRequestInferenceError(str(e)) 88 | 89 | def default_output_fn(self, prediction, accept): 90 | """Return encoded prediction for the response. 91 | Args: 92 | prediction (obj): prediction returned by predict_fn . 93 | accept (str): accept content-type expected by the client. 94 | Returns: 95 | encoded response for MMS to return to client 96 | """ 97 | accept_type = accept.lower() 98 | try: 99 | if accept_type == content_types.CSV or accept_type == "csv": 100 | if SAGEMAKER_BATCH: 101 | return_data = "\n".join(map(str, prediction.tolist())) + "\n" 102 | else: 103 | # FIXME: this is invalid CSV and is only retained for backwards compatibility 104 | return_data = ",".join(map(str, prediction.tolist())) 105 | encoded_prediction = return_data.encode("utf-8") 106 | elif accept_type == content_types.JSON or accept_type == "json": 107 | encoded_prediction = encoder.encode(prediction, accept_type) 108 | else: 109 | raise ValueError( 110 | "{} is not an accepted Accept type. 
Please choose one of the following:" 111 | " ['{}', '{}'].".format(accept, content_types.CSV, content_types.JSON) 112 | ) 113 | except Exception as e: 114 | raise UnsupportedMediaTypeInferenceError( 115 | "Encoding to accept type {} failed with exception: {}".format(accept, e) 116 | ) 117 | return encoded_prediction 118 | 119 | def __init__(self): 120 | transformer = XGBMMSTransformer(default_inference_handler=self.DefaultXGBoostAlgoModeInferenceHandler()) 121 | super(HandlerService, self).__init__(transformer=transformer) 122 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/serving_mms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import logging 16 | import multiprocessing 17 | import os 18 | import subprocess 19 | from math import ceil 20 | 21 | from retrying import retry 22 | from sagemaker_containers.beta.framework import env, modules 23 | 24 | from sagemaker_xgboost_container import handler_service as user_module_handler_service 25 | from sagemaker_xgboost_container.algorithm_mode import ( 26 | handler_service as algo_handler_service, 27 | ) 28 | from sagemaker_xgboost_container.mms_patch import model_server 29 | 30 | ALGO_HANDLER_SERVICE = algo_handler_service.__name__ 31 | USER_HANDLER_SERVICE = user_module_handler_service.__name__ 32 | 33 | PORT = 8080 34 | DEFAULT_MAX_CONTENT_LEN = 6 * 1024**2 35 | MAX_CONTENT_LEN_LIMIT = 20 * 1024**2 36 | MMS_NUM_MODEL_WORKERS_INIT = 1 37 | MMS_MODEL_JOB_QUEUE_SIZE_DEFAULT = 100 38 | 39 | 40 | def get_mms_config_file_path(): 41 | return os.environ["XGBOOST_MMS_CONFIG"] 42 | 43 | 44 | def _retry_if_error(exception): 45 | return isinstance(exception, subprocess.CalledProcessError) 46 | 47 | 48 | @retry(stop_max_delay=1000 * 30, retry_on_exception=_retry_if_error) 49 | def _start_model_server(is_multi_model, handler): 50 | # there's a race condition that causes the model server command to 51 | # sometimes fail with 'bad address'. 
more investigation needed 52 | # retry starting mms until it's ready 53 | logging.info("Trying to set up model server handler: {}".format(handler)) 54 | _set_mms_configs(is_multi_model, handler) 55 | model_server.start_model_server( 56 | handler_service=handler, is_multi_model=is_multi_model, config_file=get_mms_config_file_path() 57 | ) 58 | 59 | 60 | def _is_multi_model_endpoint(): 61 | if "SAGEMAKER_MULTI_MODEL" in os.environ and os.environ["SAGEMAKER_MULTI_MODEL"] == "true": 62 | return True 63 | else: 64 | return False 65 | 66 | 67 | def _set_default_if_not_exist(sagemaker_env_var_name, default_value): 68 | if not os.getenv(sagemaker_env_var_name, None): 69 | os.environ[sagemaker_env_var_name] = str(default_value) 70 | 71 | 72 | def _set_mms_configs(is_multi_model, handler): 73 | """Set environment variables for MMS to parse during server initialization. These env vars are used to 74 | propagate the config.properties file used during MxNet Model Server initialization. 75 | 'SAGEMAKER_MMS_MODEL_STORE' has to be set to the model location during single model inference because MMS 76 | is initialized with the model. In multi-model mode, MMS is started with no models loaded. 77 | Note: Ideally, instead of relying on env vars, this should be written directly to a config file. 
78 | """ 79 | max_content_length = os.getenv("MAX_CONTENT_LENGTH", DEFAULT_MAX_CONTENT_LEN) 80 | if int(max_content_length) > MAX_CONTENT_LEN_LIMIT: 81 | # Cap at 20mb 82 | max_content_length = MAX_CONTENT_LEN_LIMIT 83 | 84 | max_workers = multiprocessing.cpu_count() 85 | max_job_queue_size = 2 * max_workers 86 | 87 | # Max heap size = (max workers + max job queue size) * max payload size * 1.2 (20% buffer) + 128 (base amount) 88 | max_heap_size = ceil((max_workers + max_job_queue_size) * (int(max_content_length) / 1024**2) * 1.2) + 128 89 | 90 | os.environ["SAGEMAKER_MMS_MODEL_STORE"] = "/" 91 | os.environ["SAGEMAKER_MMS_LOAD_MODELS"] = "" 92 | os.environ["SAGEMAKER_MMS_DEFAULT_HANDLER"] = handler 93 | 94 | # Users can define port 95 | _set_default_if_not_exist("SAGEMAKER_BIND_TO_PORT", str(PORT)) 96 | 97 | # Multi Model Server configs, exposed to users as env vars 98 | _set_default_if_not_exist("SAGEMAKER_NUM_MODEL_WORKERS", MMS_NUM_MODEL_WORKERS_INIT) 99 | _set_default_if_not_exist("SAGEMAKER_MODEL_JOB_QUEUE_SIZE", MMS_MODEL_JOB_QUEUE_SIZE_DEFAULT) 100 | _set_default_if_not_exist("SAGEMAKER_MAX_REQUEST_SIZE", max_content_length) 101 | 102 | # JVM configurations for MMS, exposed to users as env vars 103 | _set_default_if_not_exist("SAGEMAKER_MAX_HEAP_SIZE", str(max_heap_size) + "m") 104 | _set_default_if_not_exist("SAGEMAKER_MAX_DIRECT_MEMORY_SIZE", os.environ["SAGEMAKER_MAX_HEAP_SIZE"]) 105 | 106 | disable_container_support_flag = "" 107 | if ( 108 | "SAGEMAKER_DISABLE_CONTAINER_SUPPORT" in os.environ 109 | and os.environ["SAGEMAKER_DISABLE_CONTAINER_SUPPORT"] == "true" 110 | ): 111 | disable_container_support_flag = " -XX:-UseContainerSupport" 112 | 113 | MMS_CONFIG_FILE_PATH = get_mms_config_file_path() 114 | 115 | # TODO: Revert config.properties.tmp to config.properties and add back in vmargs 116 | # set with environment variables after MMS implements parsing environment variables 117 | # for vmargs, update MMS section of final/Dockerfile.cpu to match, and 
remove the 118 | # following code. 119 | try: 120 | with open(MMS_CONFIG_FILE_PATH + ".tmp", "r") as f: 121 | with open(MMS_CONFIG_FILE_PATH, "w+") as g: 122 | g.write( 123 | "vmargs=-XX:-UseLargePages" 124 | + " -XX:+UseParNewGC" 125 | + " -XX:MaxMetaspaceSize=32M" 126 | + " -XX:InitiatingHeapOccupancyPercent=25" 127 | + " -Xms" 128 | + os.environ["SAGEMAKER_MAX_HEAP_SIZE"] 129 | + " -Xmx" 130 | + os.environ["SAGEMAKER_MAX_HEAP_SIZE"] 131 | + " -XX:MaxDirectMemorySize=" 132 | + os.environ["SAGEMAKER_MAX_DIRECT_MEMORY_SIZE"] 133 | + disable_container_support_flag 134 | + "\n" 135 | ) 136 | g.write(f.read()) 137 | except Exception: 138 | pass 139 | 140 | 141 | def start_mxnet_model_server(): 142 | serving_env = env.ServingEnv() 143 | is_multi_model = True 144 | 145 | if serving_env.module_name is None: 146 | logging.info("Starting MXNet server in algorithm mode.") 147 | _start_model_server(is_multi_model, ALGO_HANDLER_SERVICE) 148 | else: 149 | logging.info("Staring MXNet Model Server with user module.") 150 | # Install user module from s3 to import 151 | modules.import_module(serving_env.module_dir, serving_env.module_name) 152 | _start_model_server(is_multi_model, USER_HANDLER_SERVICE) 153 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/serving.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the 'license' file accompanying this file. This file is 10 | # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. 
See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | from __future__ import absolute_import 14 | 15 | import logging 16 | import os 17 | import importlib 18 | 19 | from sagemaker_containers.beta.framework import ( 20 | encoders, 21 | env, 22 | server, 23 | transformer, 24 | worker, 25 | ) 26 | 27 | from sagemaker_algorithm_toolkit import exceptions as exc 28 | from sagemaker_xgboost_container import encoder as xgb_encoders 29 | from sagemaker_xgboost_container.algorithm_mode import serve 30 | from sagemaker_xgboost_container.constants import sm_env_constants 31 | from sagemaker_xgboost_container.serving_mms import start_mxnet_model_server 32 | 33 | logging.basicConfig(format="%(asctime)s %(levelname)s - %(name)s - %(message)s", level=logging.INFO) 34 | logging.getLogger("boto3").setLevel(logging.INFO) 35 | logging.getLogger("s3transfer").setLevel(logging.INFO) 36 | logging.getLogger("botocore").setLevel(logging.WARN) 37 | 38 | logger = logging.getLogger(__name__) 39 | logger.setLevel(logging.DEBUG) 40 | 41 | 42 | def is_multi_model(): 43 | return os.environ.get("SAGEMAKER_MULTI_MODEL") 44 | 45 | 46 | def set_default_serving_env_if_unspecified(): 47 | """Set default values for environment variables if they aren't already specified. 48 | 49 | set "OMP_NUM_THREADS" = sm_env_constants.ONE_THREAD_PER_PROCESS 50 | Single-thread processes by default. Multithreading can introduce significant 51 | performance overhead due to task switching. 52 | """ 53 | env_default_dict = {"OMP_NUM_THREADS": sm_env_constants.ONE_THREAD_PER_PROCESS} 54 | for always_specified_key, default_value in env_default_dict.items(): 55 | try: 56 | # If this does not throw, the user has specified a non-default value. 57 | os.environ[always_specified_key] 58 | except KeyError: 59 | # Key that is always specified is not set in the environment. Set default value. 
60 | os.environ[always_specified_key] = default_value 61 | 62 | 63 | def default_model_fn(model_dir): 64 | """Load a model. For XGBoost Framework, a default function to load a model is not provided. 65 | Users should provide customized model_fn() in script. 66 | Args: 67 | model_dir: a directory where model is saved. 68 | Returns: A XGBoost model. 69 | """ 70 | return transformer.default_model_fn(model_dir) 71 | 72 | 73 | def default_input_fn(input_data, content_type): 74 | """Take request data and de-serializes the data into an object for prediction. 75 | When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server, 76 | the model server receives two pieces of information: 77 | - The request Content-Type, for example "application/json" 78 | - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size. 79 | The input_fn is responsible to take the request data and pre-process it before prediction. 80 | Note: For CSV data, the decoder will error if there are any leading or trailing newline 81 | chars. 82 | Args: 83 | input_data (obj): the request data. 84 | content_type (str): the request Content-Type. 85 | Returns: 86 | (obj): data ready for prediction. For XGBoost, this defaults to DMatrix. 87 | """ 88 | return xgb_encoders.decode(input_data, content_type) 89 | 90 | 91 | def default_predict_fn(input_data, model): 92 | """A default predict_fn for XGBooost Framework. Calls a model on data deserialized in input_fn. 93 | Args: 94 | input_data: input data (Numpy array) for prediction deserialized by input_fn 95 | model: XGBoost model loaded in memory by model_fn 96 | Returns: a prediction 97 | """ 98 | output = model.predict(input_data, validate_features=False) 99 | return output 100 | 101 | 102 | def default_output_fn(prediction, accept): 103 | """Function responsible to serialize the prediction for the response. 104 | Args: 105 | prediction (obj): prediction returned by predict_fn . 
106 | accept (str): accept content-type expected by the client. 107 | Returns: 108 | (worker.Response): a Flask response object with the following args: 109 | * Args: 110 | response: the serialized data to return 111 | accept: the content-type that the data was transformed to. 112 | """ 113 | return worker.Response(encoders.encode(prediction, accept), mimetype=accept) 114 | 115 | 116 | def _user_module_transformer(user_module): 117 | model_fn = getattr(user_module, "model_fn", default_model_fn) 118 | input_fn = getattr(user_module, "input_fn", None) 119 | predict_fn = getattr(user_module, "predict_fn", None) 120 | output_fn = getattr(user_module, "output_fn", None) 121 | transform_fn = getattr(user_module, "transform_fn", None) 122 | 123 | if transform_fn and (input_fn or predict_fn or output_fn): 124 | raise exc.UserError("Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn") 125 | 126 | if transform_fn is not None: 127 | return transformer.Transformer(model_fn=model_fn, transform_fn=transform_fn) 128 | else: 129 | return transformer.Transformer( 130 | model_fn=model_fn, 131 | input_fn=input_fn or default_input_fn, 132 | predict_fn=predict_fn or default_predict_fn, 133 | output_fn=output_fn or default_output_fn, 134 | ) 135 | 136 | 137 | app = None 138 | 139 | 140 | def main(environ, start_response): 141 | global app 142 | if app is None: 143 | serving_env = env.ServingEnv() 144 | if serving_env.module_name is None: 145 | app = serve.ScoringService.csdk_start() 146 | else: 147 | user_module = importlib.import_module(serving_env.module_name) 148 | user_module_transformer = _user_module_transformer(user_module) 149 | user_module_transformer.initialize() 150 | app = worker.Worker( 151 | transform_fn=user_module_transformer.transform, 152 | module_name=serving_env.module_name, 153 | ) 154 | 155 | return app(environ, start_response) 156 | 157 | 158 | def serving_entrypoint(): 159 | """Start Inference Server. 
160 | 161 | NOTE: If the inference server is multi-model, MxNet Model Server will be used as the base server. Otherwise, 162 | GUnicorn is used as the base server. 163 | """ 164 | set_default_serving_env_if_unspecified() 165 | 166 | if is_multi_model(): 167 | start_mxnet_model_server() 168 | else: 169 | server.start(env.ServingEnv().framework_module) 170 | -------------------------------------------------------------------------------- /src/sagemaker_xgboost_container/mms_patch/model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | from __future__ import absolute_import 14 | 15 | import os 16 | import signal 17 | import subprocess 18 | import sys 19 | 20 | import pkg_resources 21 | import psutil 22 | import sagemaker_inference 23 | from retrying import retry 24 | from sagemaker_inference import default_handler_service, environment, logging, utils 25 | from sagemaker_inference.environment import code_dir 26 | 27 | logger = logging.get_logger() 28 | 29 | DEFAULT_HANDLER_SERVICE = default_handler_service.__name__ 30 | MMS_CONFIG_FILE = os.path.join("/etc", "sagemaker-mms.properties") 31 | DEFAULT_MMS_CONFIG_FILE = pkg_resources.resource_filename(sagemaker_inference.__name__, "/etc/default-mms.properties") 32 | DEFAULT_MMS_LOG_FILE = pkg_resources.resource_filename(sagemaker_inference.__name__, "/etc/log4j.properties") 33 | DEFAULT_MMS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models") 34 | DEFAULT_MMS_MODEL_NAME = "model" 35 | 36 | PYTHON_PATH_ENV = "PYTHONPATH" 37 | REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt") 38 | MMS_NAMESPACE = "com.amazonaws.ml.mms.ModelServer" 39 | 40 | 41 | def start_model_server(is_multi_model=False, handler_service=DEFAULT_HANDLER_SERVICE, config_file=None): 42 | """Configure and start the model server. 43 | Args: 44 | is_multi_model (bool): Whether to start MxNet Model Server as single model or multi model. If 45 | handler_service (str): python path pointing to a module that defines a class with the following: 46 | - A ``handle`` method, which is invoked for all incoming inference requests to the model server. 47 | - A ``initialize`` method, which is invoked at model server start up for loading the model. 48 | Defaults to ``sagemaker_inference.default_handler_service``. 
49 | config_file (str): path to user defined MMS properties file 50 | """ 51 | if not config_file: 52 | _create_model_server_config_file() 53 | config_file = MMS_CONFIG_FILE 54 | 55 | if os.path.exists(REQUIREMENTS_PATH): 56 | _install_requirements() 57 | 58 | _set_python_path() 59 | 60 | mxnet_model_server_cmd = [ 61 | "mxnet-model-server", 62 | "--start", 63 | "--mms-config", 64 | config_file, 65 | "--log-config", 66 | DEFAULT_MMS_LOG_FILE, 67 | ] 68 | 69 | if not is_multi_model: 70 | _adapt_to_mms_format(handler_service) 71 | mxnet_model_server_cmd += ["--model-store", DEFAULT_MMS_MODEL_DIRECTORY] 72 | 73 | logger.info(mxnet_model_server_cmd) 74 | subprocess.Popen(mxnet_model_server_cmd) 75 | 76 | mms_process = _retrieve_mms_server_process() 77 | _add_sigterm_handler(mms_process) 78 | _add_sigchild_handler() 79 | mms_process.wait() 80 | 81 | 82 | def _adapt_to_mms_format(handler_service): 83 | """Archive initial model using MMS handler 84 | :param handler_service: 85 | :return: 86 | """ 87 | if not os.path.exists(DEFAULT_MMS_MODEL_DIRECTORY): 88 | os.makedirs(DEFAULT_MMS_MODEL_DIRECTORY) 89 | 90 | model_archiver_cmd = [ 91 | "model-archiver", 92 | "--model-name", 93 | DEFAULT_MMS_MODEL_NAME, 94 | "--handler", 95 | handler_service, 96 | "--model-path", 97 | environment.model_dir, 98 | "--export-path", 99 | DEFAULT_MMS_MODEL_DIRECTORY, 100 | "--archive-format", 101 | "no-archive", 102 | ] 103 | 104 | logger.info(model_archiver_cmd) 105 | subprocess.check_call(model_archiver_cmd) 106 | 107 | 108 | def _set_python_path(): 109 | # MMS handles code execution by appending the export path, provided 110 | # to the model archiver, to the PYTHONPATH env var. 111 | # The code_dir has to be added to the PYTHONPATH otherwise the 112 | # user provided module can not be imported properly. 
113 | code_dir_path = "{}:".format(environment.code_dir) 114 | 115 | if PYTHON_PATH_ENV in os.environ: 116 | os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] 117 | else: 118 | os.environ[PYTHON_PATH_ENV] = code_dir_path 119 | 120 | 121 | def _create_model_server_config_file(): 122 | configuration_properties = _generate_mms_config_properties() 123 | 124 | utils.write_file(MMS_CONFIG_FILE, configuration_properties) 125 | 126 | 127 | def _generate_mms_config_properties(): 128 | env = environment.Environment() 129 | 130 | user_defined_configuration = { 131 | "default_response_timeout": env.model_server_timeout, 132 | "default_workers_per_model": env.model_server_workers, 133 | "inference_address": "http://0.0.0.0:{}".format(env.http_port), 134 | } 135 | 136 | custom_configuration = str() 137 | 138 | for key in user_defined_configuration: 139 | value = user_defined_configuration.get(key) 140 | if value: 141 | custom_configuration += "{}={}\n".format(key, value) 142 | 143 | mms_default_configuration = utils.read_file(DEFAULT_MMS_CONFIG_FILE) 144 | 145 | return mms_default_configuration + custom_configuration 146 | 147 | 148 | def _add_sigterm_handler(mms_process): 149 | def _terminate(signo, frame): 150 | try: 151 | os.kill(mms_process.pid, signal.SIGTERM) 152 | except OSError: 153 | pass 154 | 155 | signal.signal(signal.SIGTERM, _terminate) 156 | 157 | 158 | def _install_requirements(): 159 | logger.info("installing packages from requirements.txt...") 160 | pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH] 161 | 162 | try: 163 | subprocess.check_call(pip_install_cmd) 164 | except subprocess.CalledProcessError: 165 | logger.error("failed to install required packages, exiting") 166 | raise ValueError("failed to install required packages") 167 | 168 | 169 | # retry for 10 seconds 170 | @retry(stop_max_delay=10 * 1000) 171 | def _retrieve_mms_server_process(): 172 | mms_server_processes = list() 173 | 174 | for 
process in psutil.process_iter(): 175 | if MMS_NAMESPACE in process.cmdline(): 176 | mms_server_processes.append(process) 177 | 178 | if not mms_server_processes: 179 | raise Exception("mms model server was unsuccessfully started") 180 | 181 | if len(mms_server_processes) > 1: 182 | raise Exception("multiple mms model servers are not supported") 183 | 184 | return mms_server_processes[0] 185 | 186 | 187 | def _reap_children(signo, frame): 188 | pid = 1 189 | try: 190 | while pid > 0: 191 | pid, status = os.waitpid(-1, os.WNOHANG) 192 | except OSError: 193 | logger.error("Failed to reap children process") 194 | 195 | 196 | def _add_sigchild_handler(): 197 | signal.signal(signal.SIGCHLD, _reap_children) 198 | --------------------------------------------------------------------------------