├── .coveragerc ├── .flake8 ├── .gitignore ├── ATTRIBUTION.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── VERSION ├── mypy.ini ├── pyproject.toml ├── requirements ├── black_requirements.txt ├── flake8_requirements.txt ├── integration_test_requirements.txt ├── mypy_requirements.txt ├── prerelease_test_requirements.txt └── twine_requirements.txt ├── sagemaker_mlflow ├── __init__.py ├── auth.py ├── auth_provider.py ├── exceptions.py ├── mlflow_sagemaker_helpers.py ├── mlflow_sagemaker_registry_store.py ├── mlflow_sagemaker_request_header_provider.py ├── mlflow_sagemaker_store.py └── presigned_url.py ├── setup.cfg ├── setup.py ├── test ├── integration │ ├── README.md │ ├── conftest.py │ ├── tests │ │ ├── test_artifact_logging.py │ │ ├── test_metadata_logging.py │ │ ├── test_model_registry.py │ │ └── test_presigned_url_from_server.py │ └── utils │ │ ├── boto_utils.py │ │ ├── random_utils.py │ │ └── sklearn_utils.py ├── prerelease │ └── test_release_version.py └── unit │ ├── test_auth_boto.py │ ├── test_auth_provider.py │ ├── test_mlflow_sagemaker_helpers.py │ ├── test_mlflow_sagemaker_request_header_provider.py │ ├── test_mlflow_sagemaker_store.py │ ├── test_presigned_url.py │ └── test_version.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | concurrency = threading 3 | omit = test/* 4 | timid = True 5 | disable_warnings = module-not-measured 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | application_import_names = sagemaker_mlflow, test 3 | import-order-style = google 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | sagemaker_mlflow.egg-info 3 | .vscode 4 | 
.idea 5 | .tox 6 | dist 7 | sagemaker_mlflow/__pycache__ 8 | *.pyc 9 | -------------------------------------------------------------------------------- /ATTRIBUTION.md: -------------------------------------------------------------------------------- 1 | # Open Source Software Attribution 2 | 3 | sagemaker-mlflow plugin 4 | Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | 6 | ## Package dependencies 7 | 8 | boto3 9 | mlflow 10 | Copyright 2018 Databricks, Inc. 11 | Licensed Under Apache 2.0 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. 
Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. Because our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), any 'help wanted' issue is a great place to start. 
45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include sagemaker_mlflow *.py 2 | 3 | recursive-include requirements * 4 | 5 | include VERSION 6 | include LICENSE 7 | include README.md 8 | 9 | prune test 10 | 11 | recursive-exclude * __pycache__ 12 | recursive-exclude * *.py[co] 13 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | SageMaker MLflow 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SageMaker MLflow Plugin 2 | 3 | ## What does this Plugin do? 4 | 5 | This plugin adds AWS Signature Version 4 (SigV4) headers to each outgoing request to the Amazon SageMaker with MLflow capability, 6 | determines the URL of the capability to connect to tracking servers, and registers models to the SageMaker Model Registry. 7 | It generates a token with the SigV4 algorithm that the service uses for authentication and authorization 8 | through AWS IAM. 9 | 10 | ## Installation 11 | 12 | To install this plugin, run the following command inside the repository directory: 13 | ``` 14 | pip install . 15 | ``` 16 | 17 | Once the plugin is published, it can be installed with: 18 | ``` 19 | pip install sagemaker-mlflow 20 | ``` 21 | 22 | Either command installs the auth plugin and mlflow. 23 | 24 | To install a specific mlflow version: 25 | 26 | ``` 27 | pip install . 28 | pip install mlflow==2.13 29 | ``` 30 | 31 | ## Development details 32 | 33 | ### setup.py 34 | 35 | `setup.py` contains the primary entry points for the SDK. 36 | `install_requires` installs mlflow. 
37 | `entry_points` contains the MLflow plugin entry points for the SDK. See https://mlflow.org/docs/latest/plugins.html#defining-a-plugin 38 | for more details. 39 | 40 | ### Running tests 41 | 42 | #### Setup 43 | To run tests with tox, first install it: 44 | ``` 45 | pip install tox 46 | ``` 47 | tox runs the tests across multiple environments. To run individual tests in a single 48 | environment, you can use pytest directly instead. 49 | 50 | #### Running format checks 51 | ``` 52 | tox -e flake8,black-check,typing,twine 53 | ``` 54 | 55 | #### Formatting code to comply with format checks 56 | ``` 57 | tox -e black-format 58 | ``` 59 | 60 | #### Running unit tests 61 | ``` 62 | tox --skip-env "black.*|flake8|typing|twine" -- test/unit 63 | ``` 64 | 65 | #### Running integration tests 66 | ``` 67 | tox --skip-env "black.*|flake8|typing|twine" -- test/integration 68 | ``` 69 | 70 | #### Available test environments by default 71 | `tox.ini` supports py39, py310, and py311, with mlflow 2.11.* and 2.12.*. 72 | To add tox environments for other versions of Python or mlflow, modify the 73 | environment configs in `envlist`, as well as `deps` and `depends` in `[testenv]`. 
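### Example: tracking server URL resolution

The "What does this Plugin do?" section says the plugin determines the URL used to reach a tracking server. The sketch below illustrates that step with the standard library only. It follows the logic in `sagemaker_mlflow/mlflow_sagemaker_helpers.py` (environment override first, then a URL derived from the ARN's region), but `resolve_tracking_url` is a hypothetical name used for illustration, the example ARN is a placeholder, and `ValueError` stands in for the plugin's own exception type.

```python
import os


def resolve_tracking_url(tracking_server_arn: str) -> str:
    """Hypothetical helper: derive the endpoint URL from a tracking server ARN."""
    # An explicit endpoint override takes precedence over the ARN.
    custom_endpoint = os.environ.get("SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT", "")
    if custom_endpoint:
        return custom_endpoint

    # ARN layout: arn:partition:service:region:account:resource-type/resource-id
    parts = tracking_server_arn.split(":", 5)
    if len(parts) != 6:
        raise ValueError(f"{tracking_server_arn} is not a valid arn")
    _, partition, service, region, _account, resource = parts
    resource_type, _, resource_id = resource.partition("/")
    if service != "sagemaker" or not resource_type or not resource_id:
        raise ValueError(f"{tracking_server_arn} is not a valid arn")

    # Only the "aws" partition is handled, matching the helper module.
    if partition != "aws":
        raise ValueError(f"Partition {partition} not supported")
    return f"https://{region}.experiments.sagemaker.aws"


if __name__ == "__main__":
    # Placeholder ARN: the account ID and server name are made up.
    print(resolve_tracking_url("arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/demo"))
```

Setting `SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT` bypasses ARN parsing entirely, which can be convenient when pointing the plugin at a non-default endpoint during testing.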
74 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0.dev0 2 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | -------------------------------------------------------------------------------- /requirements/black_requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | -------------------------------------------------------------------------------- /requirements/flake8_requirements.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | -------------------------------------------------------------------------------- /requirements/integration_test_requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | coverage>=5.2,<6.2 3 | mlflow 4 | pytest 5 | pytest-cov 6 | pytest-rerunfailures 7 | pytest-timeout 8 | pytest-xdist 9 | scikit-learn 10 | -------------------------------------------------------------------------------- /requirements/mypy_requirements.txt: -------------------------------------------------------------------------------- 1 | mypy 2 | types-requests 3 | -------------------------------------------------------------------------------- /requirements/prerelease_test_requirements.txt: -------------------------------------------------------------------------------- 1 | # Put mlflow in requirements because mlflow has a version dependency on packaging 2 | mlflow 3 | packaging 4 | pytest 5 
| -------------------------------------------------------------------------------- /requirements/twine_requirements.txt: -------------------------------------------------------------------------------- 1 | twine 2 | setuptools 3 | -------------------------------------------------------------------------------- /sagemaker_mlflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | import importlib_metadata 15 | 16 | __version__ = importlib_metadata.version("sagemaker-mlflow") 17 | -------------------------------------------------------------------------------- /sagemaker_mlflow/auth.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | 14 | import boto3 15 | from requests.auth import AuthBase 16 | from requests.models import PreparedRequest 17 | 18 | from botocore.auth import SigV4Auth 19 | from botocore.awsrequest import AWSRequest 20 | from hashlib import sha256 21 | import functools 22 | 23 | SERVICE_NAME = "sagemaker-mlflow" 24 | PAYLOAD_BUFFER = 1024 * 1024 25 | # Hardcode SHA256 hash for empty string to reduce latency for requests without a body 26 | EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" 27 | 28 | 29 | class AuthBoto(AuthBase): 30 | 31 | def __init__(self, region: str): 32 | """Constructor for Authorization Mechanism 33 | :param region: AWS region, e.g. us-west-2 34 | """ 35 | 36 | session = boto3.Session() 37 | self.creds = session.get_credentials() 38 | self.region = region 39 | self.sigv4 = SigV4Auth(self.creds, SERVICE_NAME, self.region) 40 | 41 | def __call__(self, r: PreparedRequest) -> PreparedRequest: 42 | """Method to return the prepared request 43 | :param r: PreparedRequest Base mlflow request 44 | :return: PreparedRequest Request with SigV4 signed headers 45 | """ 46 | 47 | url = r.url 48 | method = r.method 49 | headers = r.headers 50 | request_body = r.body 51 | connection_header = headers.get("Connection") 52 | 53 | headers["X-Amz-Content-SHA256"] = self.get_request_body_header(request_body) 54 | 55 | # SageMaker Mlflow strips out this header before auth. 56 | # But boto signs every header, whether its name is uppercase or lowercase. 57 | if "Connection" in headers: 58 | connection_header = headers["Connection"] 59 | del headers["Connection"] 60 | 61 | # Mlflow encodes spaces as +, Auth prefers %20 62 | if method in ("GET", "DELETE"): 63 | url = url.replace("+", "%20") 64 | 65 | # Creating a new request with the SigV4 signed headers. 66 | aws_request = AWSRequest(method=method, url=url, data=r.body, headers=headers) 67 | self.sigv4.add_auth(aws_request) 68 | 69 | # Adding back in the connection header. 
70 | final_headers = aws_request.headers 71 | if connection_header is not None: 72 | final_headers["Connection"] = connection_header 73 | final_request = AWSRequest( 74 | method=method, url=url, data=r.body, headers=final_headers 75 | ) 76 | 77 | return final_request.prepare() 78 | 79 | def get_request_body_header(self, request_body: bytes): 80 | """Stripped down version of the botocore method. 81 | :param request_body: request body 82 | :return: hex_checksum 83 | """ 84 | if request_body and hasattr(request_body, "seek"): 85 | position = request_body.tell() 86 | read_chunksize = functools.partial(request_body.read, PAYLOAD_BUFFER) 87 | checksum = sha256() 88 | for chunk in iter(read_chunksize, b""): 89 | checksum.update(chunk) 90 | hex_checksum = checksum.hexdigest() 91 | request_body.seek(position) 92 | return hex_checksum 93 | elif request_body: 94 | # The request serialization has ensured that 95 | # request.body is a bytes() type. 96 | return sha256(request_body).hexdigest() 97 | else: 98 | return EMPTY_SHA256_HASH 99 | -------------------------------------------------------------------------------- /sagemaker_mlflow/auth_provider.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | 14 | from sagemaker_mlflow.auth import AuthBoto 15 | from mlflow import get_tracking_uri 16 | from sagemaker_mlflow.mlflow_sagemaker_helpers import validate_and_parse_arn 17 | 18 | AWS_SIGV4_PLUGIN_NAME = "arn" 19 | 20 | class AuthProvider: 21 | """Entry Point class to using the plugin. mlflow will call get_name 22 | to determine the name of the plugin. get_auth will be called 23 | when creating the request to put a callback class that will 24 | generate the Sig v4 token. 25 | """ 26 | 27 | def get_name(self) -> str: 28 | """Returns the name of the plugin""" 29 | return AWS_SIGV4_PLUGIN_NAME 30 | 31 | def get_auth(self) -> AuthBoto: 32 | """Returns the callback class(AuthBoto) used for generating the SigV4 header. 33 | 34 | Returns: 35 | AuthBoto: Callback Object which will calculate the header just before request submission. 36 | """ 37 | 38 | arn = validate_and_parse_arn(get_tracking_uri()) 39 | return AuthBoto(arn.region) 40 | -------------------------------------------------------------------------------- /sagemaker_mlflow/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | 14 | 15 | class MlflowSageMakerException(Exception): 16 | pass 17 | -------------------------------------------------------------------------------- /sagemaker_mlflow/mlflow_sagemaker_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | from sagemaker_mlflow.exceptions import MlflowSageMakerException 15 | import os 16 | import logging 17 | 18 | class Arn: 19 | 20 | """ Constructor for Arn Object 21 | Args: 22 | tracking_server_arn (str): Tracking Server Arn 23 | """ 24 | def __init__(self, arn: str): 25 | splitted_arn = arn.split(":") 26 | self.partition = splitted_arn[1] 27 | self.service = splitted_arn[2] 28 | self.region = splitted_arn[3] 29 | self.account = splitted_arn[4] 30 | self.resource_type = splitted_arn[5].split("/")[0] 31 | self.resource_id = splitted_arn[5].split("/")[1] 32 | 33 | 34 | def validate_and_parse_arn(tracking_server_arn: str) -> Arn: 35 | """Validates and returns an arn from a string. 
36 | 37 | Args: 38 | tracking_server_arn (str): Tracking Server Arn 39 | Returns: 40 | Arn: Arn Object 41 | """ 42 | arn = Arn(tracking_server_arn) 43 | if ( 44 | arn.service != "sagemaker" 45 | or not arn.resource_type 46 | or not arn.resource_id 47 | ): 48 | raise MlflowSageMakerException(f"{tracking_server_arn} is not a valid arn") 49 | return arn 50 | 51 | def get_tracking_server_url(tracking_server_arn: str) -> str: 52 | """Returns the url used by SageMaker MLflow 53 | 54 | Args: 55 | tracking_server_arn (str): Tracking Server Arn 56 | Returns: 57 | str: Tracking Server URL. 58 | """ 59 | custom_endpoint = os.environ.get("SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT", "") 60 | if custom_endpoint: 61 | logging.info(f"Using custom endpoint {custom_endpoint}") 62 | return custom_endpoint 63 | arn = validate_and_parse_arn(tracking_server_arn) 64 | dns_suffix = get_dns_suffix(arn.partition) 65 | endpoint = f"https://{arn.region}.experiments.sagemaker.{dns_suffix}" 66 | return endpoint 67 | 68 | 69 | def get_dns_suffix(partition: str) -> str: 70 | """Returns a DNS suffix for a partition 71 | 72 | Args: 73 | partition (str): Partition that the tracking server resides in. 
74 | Returns: 75 | str: DNS suffix of the partition 76 | """ 77 | if partition == "aws": 78 | return "aws" 79 | else: 80 | raise MlflowSageMakerException(f"Partition {partition} Not supported.") 81 | -------------------------------------------------------------------------------- /sagemaker_mlflow/mlflow_sagemaker_registry_store.py: -------------------------------------------------------------------------------- 1 | from mlflow.store.model_registry.rest_store import RestStore 2 | from sagemaker_mlflow.mlflow_sagemaker_helpers import get_tracking_server_url 3 | from mlflow.utils import rest_utils 4 | from functools import partial 5 | import os 6 | 7 | class MlflowSageMakerRegistryStore(RestStore): 8 | 9 | store_uri = "" 10 | 11 | def __init__(self, store_uri): 12 | self.store_uri = store_uri 13 | super().__init__(partial(get_host_creds, store_uri)) 14 | 15 | 16 | # Extra environment variables which take precedence for setting the basic/bearer 17 | # auth on http requests. 18 | _TRACKING_USERNAME_ENV_VAR = "MLFLOW_TRACKING_USERNAME" 19 | _TRACKING_PASSWORD_ENV_VAR = "MLFLOW_TRACKING_PASSWORD" 20 | _TRACKING_TOKEN_ENV_VAR = "MLFLOW_TRACKING_TOKEN" 21 | 22 | # sets verify param of 'requests.request' function 23 | # see https://requests.readthedocs.io/en/master/api/ 24 | _TRACKING_INSECURE_TLS_ENV_VAR = "MLFLOW_TRACKING_INSECURE_TLS" 25 | _TRACKING_SERVER_CERT_PATH_ENV_VAR = "MLFLOW_TRACKING_SERVER_CERT_PATH" 26 | 27 | # sets cert param of 'requests.request' function 28 | # see https://requests.readthedocs.io/en/master/api/ 29 | _TRACKING_CLIENT_CERT_PATH_ENV_VAR = "MLFLOW_TRACKING_CLIENT_CERT_PATH" 30 | 31 | 32 | def get_host_creds(store_uri): 33 | """Configuring mlflow's client""" 34 | tracking_server_url = get_tracking_server_url(store_uri) 35 | return rest_utils.MlflowHostCreds( 36 | host=tracking_server_url, 37 | username=os.environ.get(_TRACKING_USERNAME_ENV_VAR), 38 | password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR), 39 | 
token=os.environ.get(_TRACKING_TOKEN_ENV_VAR), 40 | auth="arn", 41 | # aws_sigv4="False", # Auth provider is used instead for now 42 | ignore_tls_verification=os.environ.get(_TRACKING_INSECURE_TLS_ENV_VAR) 43 | == "true", 44 | client_cert_path=os.environ.get(_TRACKING_CLIENT_CERT_PATH_ENV_VAR), 45 | server_cert_path=os.environ.get(_TRACKING_SERVER_CERT_PATH_ENV_VAR), 46 | ) 47 | -------------------------------------------------------------------------------- /sagemaker_mlflow/mlflow_sagemaker_request_header_provider.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | 15 | from mlflow.tracking.request_header.abstract_request_header_provider import RequestHeaderProvider 16 | from mlflow import get_tracking_uri 17 | 18 | class MlflowSageMakerRequestHeaderProvider(RequestHeaderProvider): 19 | """RequestHeaderProvider provided through plugin system""" 20 | 21 | def in_context(self): 22 | """Activates the plugin""" 23 | return True 24 | 25 | def request_headers(self): 26 | """Returns plugin headers used by SageMaker MLflow 27 | 28 | Returns: 29 | dict: Dictionary containing the headers that are needed for routing. 
30 | """ 31 | return { "x-mlflow-sm-tracking-server-arn": get_tracking_uri() } 32 | -------------------------------------------------------------------------------- /sagemaker_mlflow/mlflow_sagemaker_store.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | from mlflow.store.tracking.rest_store import RestStore 15 | from mlflow.utils import rest_utils 16 | import os 17 | from functools import partial 18 | from sagemaker_mlflow.mlflow_sagemaker_helpers import get_tracking_server_url 19 | 20 | # Extra environment variables which take precedence for setting the basic/bearer 21 | # auth on http requests. 
22 | _TRACKING_USERNAME_ENV_VAR = "MLFLOW_TRACKING_USERNAME" 23 | _TRACKING_PASSWORD_ENV_VAR = "MLFLOW_TRACKING_PASSWORD" 24 | _TRACKING_TOKEN_ENV_VAR = "MLFLOW_TRACKING_TOKEN" 25 | 26 | # sets verify param of 'requests.request' function 27 | # see https://requests.readthedocs.io/en/master/api/ 28 | _TRACKING_INSECURE_TLS_ENV_VAR = "MLFLOW_TRACKING_INSECURE_TLS" 29 | _TRACKING_SERVER_CERT_PATH_ENV_VAR = "MLFLOW_TRACKING_SERVER_CERT_PATH" 30 | 31 | # sets cert param of 'requests.request' function 32 | # see https://requests.readthedocs.io/en/master/api/ 33 | _TRACKING_CLIENT_CERT_PATH_ENV_VAR = "MLFLOW_TRACKING_CLIENT_CERT_PATH" 34 | 35 | 36 | class MlflowSageMakerStore(RestStore): 37 | store_uri = "" 38 | 39 | def __init__(self, store_uri, artifact_uri): 40 | self.store_uri = store_uri 41 | super().__init__(partial(get_host_creds, store_uri)) 42 | 43 | 44 | def get_host_creds(store_uri) -> rest_utils.MlflowHostCreds: 45 | """Configuring mlflow's client""" 46 | tracking_server_url = get_tracking_server_url(store_uri) 47 | return rest_utils.MlflowHostCreds( 48 | host=tracking_server_url, 49 | username=os.environ.get(_TRACKING_USERNAME_ENV_VAR), 50 | password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR), 51 | token=os.environ.get(_TRACKING_TOKEN_ENV_VAR), 52 | auth="arn", 53 | # aws_sigv4="False", # Auth provider is used instead for now 54 | ignore_tls_verification=os.environ.get(_TRACKING_INSECURE_TLS_ENV_VAR) 55 | == "true", 56 | client_cert_path=os.environ.get(_TRACKING_CLIENT_CERT_PATH_ENV_VAR), 57 | server_cert_path=os.environ.get(_TRACKING_SERVER_CERT_PATH_ENV_VAR), 58 | ) 59 | -------------------------------------------------------------------------------- /sagemaker_mlflow/presigned_url.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 
You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | import os 15 | 16 | import boto3 17 | import mlflow 18 | from sagemaker_mlflow.mlflow_sagemaker_helpers import validate_and_parse_arn 19 | 20 | 21 | def get_presigned_url(url_expiration_duration=300, session_duration=5000) -> str: 22 | """ Creates a presigned url 23 | 24 | :param url_expiration_duration: First use expiration time of the presigned url 25 | :param session_duration: Session duration of the presigned url 26 | 27 | :returns: Authorized Url 28 | 29 | """ 30 | arn = validate_and_parse_arn(mlflow.get_tracking_uri()) 31 | custom_endpoint = os.environ.get("SAGEMAKER_ENDPOINT_URL", "") 32 | if not custom_endpoint: 33 | sagemaker_client = boto3.client("sagemaker", region_name=arn.region) 34 | else: 35 | sagemaker_client = boto3.client("sagemaker", endpoint_url=custom_endpoint, region_name=arn.region) 36 | 37 | config = { 38 | "TrackingServerName": arn.resource_id, 39 | "ExpiresInSeconds": url_expiration_duration, 40 | "SessionExpirationDurationInSeconds": session_duration 41 | } 42 | response = sagemaker_client.create_presigned_mlflow_tracking_server_url(**config) 43 | return response["AuthorizedUrl"] 44 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | """ 15 | sagemaker-mlflow plugin installation. 16 | """ 17 | import os 18 | 19 | from setuptools import setup, find_packages 20 | 21 | 22 | def read(fname): 23 | """ 24 | Args: 25 | fname: 26 | """ 27 | with open(os.path.join(os.path.dirname(__file__), fname), "r") as f: 28 | contents = f.read() 29 | return contents 30 | 31 | 32 | def read_version(): 33 | return read("VERSION").strip() 34 | 35 | 36 | def read_requirements(filename): 37 | """Reads requirements file which lists package dependencies. 
38 | 39 | Args: 40 | filename: type(str) Relative file path of requirements.txt file 41 | 42 | Returns: 43 | list of dependencies extracted from file 44 | """ 45 | with open(os.path.abspath(filename)) as fp: 46 | deps = [line.strip() for line in fp.readlines()] 47 | return deps 48 | 49 | 50 | test_requirements = read_requirements("requirements/integration_test_requirements.txt") 51 | test_prerelease_requirements = read_requirements("requirements/prerelease_test_requirements.txt") 52 | 53 | setup( 54 | name="sagemaker-mlflow", 55 | packages=find_packages(), 56 | author="Amazon Web Services", 57 | license="Apache License 2.0", 58 | url = 'https://github.com/aws/sagemaker-mlflow', 59 | classifiers=[ 60 | "Development Status :: 5 - Production/Stable", 61 | "Intended Audience :: Developers", 62 | "Natural Language :: English", 63 | "License :: OSI Approved :: Apache Software License", 64 | "Programming Language :: Python", 65 | "Programming Language :: Python :: 3.8", 66 | "Programming Language :: Python :: 3.9", 67 | "Programming Language :: Python :: 3.10", 68 | "Programming Language :: Python :: 3.11", 69 | ], 70 | # Require MLflow as a dependency of the plugin, so that plugin users can 71 | # simply install the plugin and then immediately use it with MLflow 72 | install_requires=["boto3>=1.34", "mlflow>=2.8"], 73 | extras_require={"test": test_requirements, "test_prerelease": test_prerelease_requirements}, 74 | python_requires=">= 3.8", 75 | entry_points={ 76 | "mlflow.tracking_store": "arn=sagemaker_mlflow.mlflow_sagemaker_store:MlflowSageMakerStore", 77 | "mlflow.request_auth_provider": "arn=sagemaker_mlflow.auth_provider:AuthProvider", 78 | "mlflow.request_header_provider": "arn=sagemaker_mlflow.mlflow_sagemaker_request_header_provider:MlflowSageMakerRequestHeaderProvider", 79 | "mlflow.model_registry_store": "arn=sagemaker_mlflow.mlflow_sagemaker_registry_store:MlflowSageMakerRegistryStore" 80 | }, 81 | version=read_version(), 82 | description="AWS Plugin for 
MLflow with SageMaker", 83 | long_description=read("README.md"), 84 | long_description_content_type="text/markdown", 85 | ) 86 | -------------------------------------------------------------------------------- /test/integration/README.md: -------------------------------------------------------------------------------- 1 | # SageMaker MLflow integration tests 2 | 3 | ## Usage 4 | 5 | ### Using a pre-created Tracking Server 6 | 7 | - Create an MLflow tracking server and note down its ARN. 8 | - Set `MLFLOW_TRACKING_SERVER_URI` to the ARN. 9 | - In the `test/integration` directory, run `pytest` (you may need to run `python<version> -m pytest` instead). 10 | -------------------------------------------------------------------------------- /test/integration/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from mlflow import MlflowClient 4 | 5 | from utils.boto_utils import get_default_region, get_account_id 6 | 7 | 8 | """ Default tracking server that a user can create 9 | """ 10 | @pytest.fixture(scope="module") 11 | def tracking_server(): 12 | server_arn = os.environ.get("MLFLOW_TRACKING_SERVER_URI", "") 13 | if not server_arn: 14 | server_name = os.environ.get("MLFLOW_TRACKING_SERVER_NAME", "") 15 | if server_name: 16 | region = get_default_region() 17 | account_id = get_account_id() 18 | # Reconstruct server arn from env variables 19 | server_arn = f"arn:aws:sagemaker:{region}:{account_id}:mlflow-tracking-server/{server_name}" 20 | else: 21 | server_arn = create_tracking_server() 22 | os.environ["MLFLOW_TRACKING_SERVER_URI"] = server_arn 23 | return server_arn 24 | 25 | 26 | """ Mlflow Client Fixture 27 | """ 28 | @pytest.fixture 29 | def mlflow_client() -> MlflowClient: 30 | return MlflowClient() 31 | 32 | 33 | def create_tracking_server() -> str: 34 | # TODO: Implement 35 | return "not implemented" 36 | --------------------------------------------------------------------------------
/test/integration/tests/test_artifact_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import mlflow 5 | import pytest 6 | 7 | from utils.boto_utils import get_file_data_from_s3 8 | from utils.random_utils import generate_uuid, generate_random_float 9 | 10 | """ Test that artifacts are being persisted correctly with an 11 | Amazon S3 based artifact store. This test is only for artifact 12 | stores in which the client goes through S3. 13 | """ 14 | class TestArtifactLogging: 15 | 16 | @pytest.fixture(scope="class") 17 | def setup(self, tracking_server): 18 | # TODO: Verify that tracking server is created 19 | mlflow.set_tracking_uri(tracking_server) 20 | 21 | def test_log_artifact(self, setup): 22 | # Create a random file 23 | file_name = f"{generate_uuid(20)}.txt" 24 | file_contents = "".join([generate_uuid(40) for i in range(5)]) 25 | logging.info(f"Writing {file_contents} to {file_name}") 26 | with open(file_name, "wb") as f: 27 | f.write(bytes(file_contents.encode("utf-8"))) 28 | 29 | current_run = None 30 | with mlflow.start_run(): 31 | mlflow.log_artifact(file_name) 32 | current_run = mlflow.active_run() 33 | 34 | assert current_run is not None 35 | 36 | run_artifact_location = current_run.info.artifact_uri 37 | split_location = run_artifact_location.replace("s3://", "").split("/") 38 | split_location.append(file_name) 39 | bucket = split_location[0] 40 | prefix = "/".join(split_location[1:]) 41 | data = get_file_data_from_s3(bucket, prefix) 42 | data = data.read().decode("ascii") 43 | assert data == file_contents 44 | 45 | os.remove(file_name) 46 | -------------------------------------------------------------------------------- /test/integration/tests/test_metadata_logging.py: -------------------------------------------------------------------------------- 1 | import mlflow 2 | import pytest 3 | 4 | from utils.random_utils import generate_uuid, generate_random_float 5 | 6 | 
TEST_METRIC_NAME = "test_metadata_metric" 7 | 8 | """ Test Metadata modification, ensure that requests get properly routed 9 | and that SigV4 header calculation works with SageMaker Mlflow. 10 | """ 11 | class TestMetadataLogging: 12 | 13 | @pytest.fixture(scope="class") 14 | def setup(self, tracking_server): 15 | # TODO: Verify that tracking server is created 16 | mlflow.set_tracking_uri(tracking_server) 17 | 18 | def test_log_metric(self, setup, mlflow_client): 19 | random_tag = generate_uuid(32) 20 | tags = {"purpose": random_tag} 21 | metric_value = generate_random_float() 22 | 23 | run = mlflow_client.create_run("0", tags=tags) 24 | mlflow_client.log_metric( 25 | run.info.run_id, TEST_METRIC_NAME, metric_value, step=0 26 | ) 27 | metric = list(mlflow_client.get_metric_history(run.info.run_id, TEST_METRIC_NAME))[0] 28 | 29 | assert "s3" in run.info.artifact_uri 30 | assert run.data.tags["purpose"] == tags["purpose"] 31 | assert metric.key == TEST_METRIC_NAME 32 | assert metric.value == metric_value 33 | assert run.info.experiment_id == "0" 34 | -------------------------------------------------------------------------------- /test/integration/tests/test_model_registry.py: -------------------------------------------------------------------------------- 1 | import mlflow 2 | import pytest 3 | 4 | from mlflow.models import infer_signature 5 | 6 | from utils.random_utils import generate_uuid 7 | from utils.sklearn_utils import train_iris_logistic_regression_model 8 | 9 | """ This test makes sure that registering models works. 
10 | """ 11 | class TestModelRegistry: 12 | 13 | @pytest.fixture(scope="class") 14 | def setup(self, tracking_server): 15 | # TODO: Verify that tracking server is created 16 | mlflow.set_tracking_uri(tracking_server) 17 | 18 | def test_model_registry(self, setup, mlflow_client): 19 | # Start an MLflow run 20 | registered_model_name = generate_uuid(20) 21 | X_train, lr = train_iris_logistic_regression_model() 22 | with mlflow.start_run(): 23 | signature = infer_signature(X_train, lr.predict(X_train)) 24 | 25 | # Log/Register the model 26 | mlflow.sklearn.log_model( 27 | sk_model=lr, 28 | artifact_path="iris_model", 29 | signature=signature, 30 | input_example=X_train, 31 | registered_model_name=registered_model_name, 32 | ) 33 | registered_model = mlflow_client.get_registered_model(registered_model_name) 34 | assert registered_model.name == registered_model_name 35 | mlflow_client.delete_registered_model(registered_model_name) 36 | -------------------------------------------------------------------------------- /test/integration/tests/test_presigned_url_from_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import sagemaker_mlflow 4 | import mlflow 5 | import pytest 6 | import requests 7 | 8 | import sagemaker_mlflow.presigned_url 9 | 10 | 11 | """ This test makes sure that getting the presigned url works. 12 | """ 13 | class TestPresignedUrl: 14 | 15 | @pytest.fixture(scope="class") 16 | def setup(self, tracking_server): 17 | # TODO: Verify that tracking server is created 18 | mlflow.set_tracking_uri(tracking_server) 19 | 20 | # Restrict to local environments for now, remove after GA. 
21 | @pytest.mark.skipif(os.environ.get("CODEBUILD_BUILD_ARN", "") != "", reason="Codebuild might not have the right API shape") 22 | def test_presigned_url(self, setup): 23 | presigned_url = sagemaker_mlflow.presigned_url.get_presigned_url() 24 | response = requests.get(url=presigned_url) 25 | assert response.status_code == 200 26 | -------------------------------------------------------------------------------- /test/integration/utils/boto_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import boto3 4 | 5 | 6 | def get_default_region(): 7 | return os.environ.get("REGION", "us-east-2") 8 | 9 | 10 | def get_account_id(): 11 | region = get_default_region() 12 | sts = boto3.client("sts", region_name=region) 13 | return sts.get_caller_identity()["Account"] 14 | 15 | 16 | def get_s3_client(): 17 | return boto3.client("s3") 18 | 19 | 20 | def get_file_data_from_s3(bucket: str, key: str): 21 | s3_client = get_s3_client() 22 | response = s3_client.get_object(Bucket=bucket, Key=key) 23 | return response["Body"] 24 | -------------------------------------------------------------------------------- /test/integration/utils/random_utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import random 3 | 4 | """ Generates a random uuid that can be 64 characters or shorter. 
5 | """ 6 | 7 | 8 | def generate_uuid(num_chars: int) -> str: 9 | if num_chars > 64: 10 | raise Exception("Must be 64 characters or under") 11 | return str(uuid.uuid4())[:num_chars] 12 | 13 | 14 | """ Generates a random floating-point number 15 | """ 16 | 17 | 18 | def generate_random_float() -> float: 19 | return random.uniform(1.0, 20.0) 20 | -------------------------------------------------------------------------------- /test/integration/utils/sklearn_utils.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.linear_model import LogisticRegression 4 | 5 | 6 | def train_iris_logistic_regression_model(): 7 | # Load the Iris dataset 8 | X, y = datasets.load_iris(return_X_y=True) 9 | 10 | # Split the data into training and test sets 11 | X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42) 12 | 13 | # Define the model hyperparameters 14 | params = { 15 | "solver": "lbfgs", 16 | "max_iter": 1000, 17 | "multi_class": "auto", 18 | "random_state": 8888, 19 | } 20 | 21 | # Train the model 22 | lr = LogisticRegression(**params) 23 | lr.fit(X_train, y_train) 24 | return X_train, lr 25 | -------------------------------------------------------------------------------- /test/prerelease/test_release_version.py: -------------------------------------------------------------------------------- 1 | import sagemaker_mlflow 2 | 3 | from packaging.version import Version 4 | 5 | 6 | def test_release_version(): 7 | plugin_version = Version(sagemaker_mlflow.__version__) 8 | assert not plugin_version.is_devrelease, f"sagemaker_mlflow version is dev - {plugin_version}" 9 | assert ( 10 | not plugin_version.is_prerelease 11 | ), f"sagemaker_mlflow version is prerelease - {plugin_version}" 12 | assert ( 13 | not plugin_version.is_postrelease 14 | ), f"sagemaker_mlflow version is postrelease - {plugin_version}" 15 | 
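A note on `generate_uuid` in `test/integration/utils/random_utils.py`: it slices `str(uuid.uuid4())`, which is always 36 characters long, so any `num_chars` from 37 to 64 passes the length check but still yields only 36 characters (hyphens included). The sketch below shows that behaviour and one possible way to honour the full 64-character cap. The name `generate_random_id` is hypothetical and not part of the repository.

```python
import uuid


def generate_random_id(num_chars: int) -> str:
    """Hypothetical variant of generate_uuid that honours lengths up to 64."""
    if num_chars < 0 or num_chars > 64:
        raise ValueError("num_chars must be between 0 and 64")
    # uuid4().hex is 32 hex characters, so two of them cover the 64-character cap.
    pool = uuid.uuid4().hex + uuid.uuid4().hex
    return pool[:num_chars]


# str(uuid.uuid4()) is 32 hex digits plus 4 hyphens, so slicing it
# can never return more than 36 characters.
canonical_length = len(str(uuid.uuid4()))
truncated_length = len(str(uuid.uuid4())[:40])
```

Unlike the original helper, this variant returns hex digits only, which may or may not matter for names that are expected to look like UUIDs.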
-------------------------------------------------------------------------------- /test/unit/test_auth_boto.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, Mock 3 | from botocore.awsrequest import AWSRequest 4 | from requests import PreparedRequest 5 | 6 | from sagemaker_mlflow.auth import AuthBoto, EMPTY_SHA256_HASH 7 | 8 | 9 | class TestAuthBoto(unittest.TestCase): 10 | 11 | @patch("boto3.Session") 12 | def test_init(self, mock_session): 13 | # Arrange 14 | mock_session_instance = mock_session.return_value 15 | mock_get_credentials = mock_session_instance.get_credentials 16 | mock_credentials = Mock() 17 | mock_get_credentials.return_value = mock_credentials 18 | region = "us-west-2" 19 | 20 | # Act 21 | auth_boto = AuthBoto(region) 22 | 23 | # Assert 24 | self.assertEqual(auth_boto.region, region) 25 | self.assertEqual(auth_boto.creds, mock_credentials) 26 | mock_session.assert_called_once() 27 | mock_get_credentials.assert_called_once() 28 | 29 | def test_call(self): 30 | # Arrange 31 | region = "us-west-2" 32 | auth_boto = AuthBoto(region) 33 | 34 | mock_sigv4 = Mock() 35 | auth_boto.sigv4 = mock_sigv4 36 | auth_boto.creds = Mock() 37 | 38 | url = "https://example.com/path" 39 | method = "GET" 40 | header_value = "test-value" 41 | headers = {"Connection": "keep-alive", "x-sagemaker": header_value} 42 | body = None 43 | prepared_request = PreparedRequest() 44 | prepared_request.prepare(url=url, method=method, headers=headers, data=body) 45 | 46 | expected_headers = { 47 | "X-Amz-Content-SHA256": EMPTY_SHA256_HASH, 48 | "Connection": "keep-alive", 49 | "x-sagemaker": header_value, 50 | } 51 | expected_aws_request = AWSRequest( 52 | method=method, 53 | url=url.replace("+", "%20"), 54 | headers=expected_headers, 55 | data=body, 56 | ) 57 | 58 | # Act 59 | result = auth_boto(prepared_request) 60 | 61 | # Assert 62 | for header in result.headers: 63 | self.assertTrue(header 
in expected_headers) 64 | 65 | self.assertEqual(result.body, expected_aws_request.data) 66 | self.assertEqual(result.method, method) 67 | self.assertEqual(result.url, url.replace("+", "%20")) 68 | 69 | def test_get_request_body_header(self): 70 | # Arrange 71 | region = "us-west-2" 72 | auth_boto = AuthBoto(region) 73 | request_body = b"test_body" 74 | expected_hash = ( 75 | "4443c6a8412e6c11f324c870a8366d6ede75e7f9ed12f00c36b88d479df371d6" 76 | ) 77 | 78 | # Act 79 | result = auth_boto.get_request_body_header(request_body) 80 | 81 | # Assert 82 | self.assertEqual(result, expected_hash) 83 | 84 | def test_get_request_body_header_empty(self): 85 | # Arrange 86 | region = "us-west-2" 87 | auth_boto = AuthBoto(region) 88 | request_body = b"" 89 | 90 | # Act 91 | result = auth_boto.get_request_body_header(request_body) 92 | 93 | # Assert 94 | self.assertEqual(result, EMPTY_SHA256_HASH) 95 | 96 | 97 | if __name__ == "__main__": 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /test/unit/test_auth_provider.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock, TestCase 3 | 4 | from sagemaker_mlflow.auth_provider import AuthProvider 5 | import os 6 | 7 | 8 | class AuthProviderTest(TestCase): 9 | def test_auth_provider_returns_correct_name(self): 10 | auth_provider = AuthProvider() 11 | auth_provider_name = auth_provider.get_name() 12 | self.assertEqual(auth_provider_name, "arn") 13 | 14 | @mock.patch.dict( 15 | os.environ, 16 | { 17 | "AWS_ACCESS_KEY_ID": "default_ak", 18 | "AWS_SECRET_ACCESS_KEY": "default_sk", 19 | "AWS_DEFAULT_REGION": "us-east-2", 20 | "AWS_SESSION_TOKEN": "", 21 | "MLFLOW_TRACKING_URI": "arn:aws:sagemaker:us-east-2:000000000000:mlflow-tracking-server/mw" 22 | }, 23 | ) 24 | def test_auth_provider_returns_correct_sigv4(self): 25 | auth_provider = AuthProvider() 26 | result = auth_provider.get_auth() 27 | 28 | 
self.assertEqual(result.region, "us-east-2") 29 | 30 | @mock.patch.dict( 31 | os.environ, 32 | { 33 | "AWS_ACCESS_KEY_ID": "default_ak", 34 | "AWS_SECRET_ACCESS_KEY": "default_sk", 35 | "AWS_DEFAULT_REGION": "us-east-1", 36 | "AWS_SESSION_TOKEN": "dcs", 37 | "MLFLOW_TRACKING_URI": "arn:aws:sagemaker:us-east-2:000000000001:mlflow-tracking-server/mw" 38 | }, 39 | ) 40 | def test_auth_provider_returns_correct_sigv4_session_different_region(self): 41 | auth_provider = AuthProvider() 42 | result = auth_provider.get_auth() 43 | 44 | self.assertEqual(result.region, "us-east-2") 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /test/unit/test_mlflow_sagemaker_helpers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sagemaker_mlflow.mlflow_sagemaker_helpers import validate_and_parse_arn, get_tracking_server_url, get_dns_suffix, Arn 4 | from sagemaker_mlflow.exceptions import MlflowSageMakerException 5 | from unittest import TestCase 6 | from unittest import mock 7 | import os 8 | 9 | 10 | TEST_VALID_ARN = "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/xw" 11 | 12 | 13 | class MlflowSageMakerHelpersTest(TestCase): 14 | 15 | def test_validate_and_parse_arn_happy(self): 16 | arn = TEST_VALID_ARN 17 | result = validate_and_parse_arn(arn) 18 | assert type(result) is Arn 19 | assert result.partition == "aws" 20 | assert result.service == "sagemaker" 21 | assert result.region == "us-west-2" 22 | assert result.account == "000000000000" 23 | assert result.resource_type == "mlflow-tracking-server" 24 | assert result.resource_id == "xw" 25 | 26 | def test_validate_and_parse_arn_invalid_service(self): 27 | arn = "arn:aws:wagemaker:us-west-2:000000000000:mlflow-tracking-server/xw" 28 | with self.assertRaises(MlflowSageMakerException): 29 | validate_and_parse_arn(arn) 30 | 31 | def 
test_validate_and_parse_arn_invalid_arn(self): 32 | arn = "arn:aws:sagemaker:us-west-2mlflow-tracking-server/xw" 33 | with self.assertRaises(Exception): 34 | validate_and_parse_arn(arn) 35 | 36 | def test_get_tracking_server_url_normal(self): 37 | url = get_tracking_server_url(TEST_VALID_ARN) 38 | assert url == "https://us-west-2.experiments.sagemaker.aws" 39 | 40 | @mock.patch.dict( 41 | os.environ, 42 | { 43 | "SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT": "https://a.com" 44 | }, 45 | ) 46 | def test_get_tracking_server_url_custom(self): 47 | url = get_tracking_server_url(TEST_VALID_ARN) 48 | assert url == "https://a.com" 49 | 50 | def test_dns_suffix_happy(self): 51 | suffix = get_dns_suffix("aws") 52 | assert suffix == "aws" 53 | 54 | def test_dns_suffix_invalid(self): 55 | with self.assertRaises(MlflowSageMakerException): 56 | get_dns_suffix("aws-ocean") 57 | 58 | 59 | if __name__ == "__main__": 60 | unittest.main() 61 | -------------------------------------------------------------------------------- /test/unit/test_mlflow_sagemaker_request_header_provider.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock, TestCase 3 | from sagemaker_mlflow.mlflow_sagemaker_request_header_provider import MlflowSageMakerRequestHeaderProvider 4 | import os 5 | 6 | class MlflowSageMakerRequestHeaderProviderTest(TestCase): 7 | 8 | def test_in_context(self): 9 | provider = MlflowSageMakerRequestHeaderProvider() 10 | in_context = provider.in_context() 11 | assert in_context 12 | 13 | @mock.patch.dict( 14 | os.environ, 15 | { 16 | "MLFLOW_TRACKING_URI": "mw" 17 | }, 18 | ) 19 | def test_request_header(self): 20 | provider = MlflowSageMakerRequestHeaderProvider() 21 | header = provider.request_headers() 22 | assert header.get("x-mlflow-sm-tracking-server-arn") == "mw" 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | 
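The helper tests above rely on a malformed ARN raising some exception; in `mlflow_sagemaker_helpers.py` that exception is an `IndexError` from list indexing rather than `MlflowSageMakerException`. Below is a minimal, self-contained sketch of the same parsing and URL derivation with explicit validation. `ParsedArn`, `parse_tracking_server_arn`, and `tracking_server_url` are hypothetical names used for illustration, `ValueError` stands in for the plugin's exception type, and the resource ID is taken as everything after the first `/`, which is an assumption that differs slightly from the original.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParsedArn:
    """Hypothetical stand-in for the plugin's Arn class."""

    partition: str
    service: str
    region: str
    account: str
    resource_type: str
    resource_id: str


def parse_tracking_server_arn(arn: str) -> ParsedArn:
    """Split an ARN into its components, rejecting malformed input explicitly."""
    parts = arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn":
        raise ValueError(f"{arn} is not a valid arn")
    _, partition, service, region, account, resource = parts
    # The resource has the form "<resource_type>/<resource_id>".
    resource_type, separator, resource_id = resource.partition("/")
    if service != "sagemaker" or not separator or not resource_type or not resource_id:
        raise ValueError(f"{arn} is not a valid arn")
    return ParsedArn(partition, service, region, account, resource_type, resource_id)


def tracking_server_url(arn: str) -> str:
    """Build the regional endpoint in the format the helper tests expect."""
    parsed = parse_tracking_server_arn(arn)
    if parsed.partition != "aws":
        raise ValueError(f"Partition {parsed.partition} not supported.")
    return f"https://{parsed.region}.experiments.sagemaker.aws"
```

Raising one well-defined exception type for every malformed input would let a test assert on that type instead of the broad `Exception`.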
-------------------------------------------------------------------------------- /test/unit/test_mlflow_sagemaker_store.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock, TestCase 3 | 4 | from sagemaker_mlflow.mlflow_sagemaker_store import MlflowSageMakerStore, get_host_creds 5 | 6 | 7 | TEST_VALID_ARN = "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/xw" 8 | TEST_VALID_URL = "https://test-site.com" 9 | 10 | class MlflowSageMakerStoreTest(TestCase): 11 | 12 | def test_get_host_creds_happy(self): 13 | arn = TEST_VALID_ARN 14 | mock_func = mock.Mock(return_value='result') 15 | with mock.patch("sagemaker_mlflow.mlflow_sagemaker_store.get_tracking_server_url", mock_func): 16 | result = get_host_creds(arn) 17 | assert result.host == "result" 18 | assert result.auth == "arn" 19 | 20 | def test_MlflowSageMakerStore_Store(self): 21 | test_instance = MlflowSageMakerStore(TEST_VALID_ARN, "") 22 | assert test_instance is not None 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /test/unit/test_presigned_url.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock, TestCase 3 | import os 4 | 5 | from sagemaker_mlflow.presigned_url import get_presigned_url 6 | 7 | TEST_VALID_ARN = "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/xw" 8 | TEST_VALID_URL = "https://test-site.com" 9 | 10 | 11 | @mock.patch.dict( 12 | os.environ, 13 | {"MLFLOW_TRACKING_URI": TEST_VALID_ARN}, 14 | ) 15 | class PresignedUrlTests(TestCase): 16 | 17 | @mock.patch("boto3.client") 18 | def test_presigned_url(self, mock_boto3_client): 19 | mock_client = mock_boto3_client.return_value 20 | mock_response = {"AuthorizedUrl": TEST_VALID_URL} 21 | mock_client.create_presigned_mlflow_tracking_server_url.return_value = ( 22 | 
            mock_response
        )
        function_response = get_presigned_url()
        assert function_response == TEST_VALID_URL

    @mock.patch("boto3.client")
    def test_presigned_url_with_fields(self, mock_boto3_client):
        mock_client = mock_boto3_client.return_value
        mock_response = {"AuthorizedUrl": TEST_VALID_URL}

        create_presigned_api_request = {
            "TrackingServerName": "xw",
            "ExpiresInSeconds": 100,
            "SessionExpirationDurationInSeconds": 200,
        }

        mock_client.create_presigned_mlflow_tracking_server_url.return_value = (
            mock_response
        )
        function_response = get_presigned_url(100, 200)

        mock_client.create_presigned_mlflow_tracking_server_url.assert_called_with(**create_presigned_api_request)

        assert function_response == TEST_VALID_URL


if __name__ == "__main__":
    unittest.main()

--------------------------------------------------------------------------------
/test/unit/test_version.py:
--------------------------------------------------------------------------------
import sagemaker_mlflow


def test_version():
    assert sagemaker_mlflow.__version__

--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
# Tox (http://tox.testrun.org/) is a tool for running tests
# in multiple virtualenvs. This configuration file will run the
# test suite on all supported python versions. To use it, "pip install tox"
# and then run "tox" from this directory.

[tox]
envlist = black-format,flake8,twine,py{39,310,311}-mlflow{28,29,210,211,212}

[flake8]
max-line-length = 120
exclude =
    build/
    .git
    __pycache__
    .tox
    venv/
    env/

max-complexity = 10

ignore =
    C901,
    # whitespace before ':': Black disagrees with and explicitly violates this.
    E203

require-code = True

[doc8]
ignore-path=.tox,sagemaker_mlflow.egg-info

[testenv]
passenv =
    AWS_ACCESS_KEY_ID
    AWS_SECRET_ACCESS_KEY
    AWS_SESSION_TOKEN
    AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
    AWS_DEFAULT_REGION
    CODEBUILD_BUILD_ARN
    MLFLOW_TRACKING_SERVER_URI
    MLFLOW_TRACKING_SERVER_NAME
    REGION
# {posargs} can be passed in by additional arguments specified when invoking tox.
# Can be used to specify which tests to run, e.g.: tox -- -s
commands =
    pytest {posargs}
deps =
    mlflow28: mlflow>=2.8,<2.9
    mlflow29: mlflow>=2.9,<2.10
    mlflow210: mlflow>=2.10,<2.11
    mlflow211: mlflow>=2.11,<2.12
    mlflow212: mlflow>=2.12,<2.13
    .[test]
depends =
    py{39,310,311}-mlflow{28,29,210,211,212}: clean

[testenv:runcoverage]
description = run unit tests with coverage
commands =
    pytest --cov=sagemaker_mlflow --cov-append {posargs}
    {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86

[testenv:flake8]
skipdist = true
skip_install = true
deps =
    -r requirements/flake8_requirements.txt
commands =
    flake8

[testenv:twine]
# https://packaging.python.org/guides/making-a-pypi-friendly-readme/#validating-restructuredtext-markup
skip_install = true
deps =
    -r requirements/twine_requirements.txt
commands =
    python setup.py sdist
    twine check dist/*.tar.gz

[testenv:black-format]
# Used during development (before committing) to format .py files.
skip_install = true
setenv =
    LC_ALL=C.UTF-8
    LANG=C.UTF-8
deps =
    -r requirements/black_requirements.txt
commands =
    black -l 120 ./

[testenv:black-check]
# Used by automated build steps to check that all files are properly formatted.
skip_install = true
setenv =
    LC_ALL=C.UTF-8
    LANG=C.UTF-8
deps =
    -r requirements/black_requirements.txt
commands =
    black --color --check -l 120 ./

[testenv:clean]
skip_install = true
commands =
    coverage erase

[testenv:typing]
# Do not skip installation here, the extras are needed for mypy to get type info
skip_install = false
extras =
    all
deps =
    -r requirements/mypy_requirements.txt
commands =
    mypy sagemaker_mlflow

[testenv:collect-tests]
# this needs to succeed for tests to display in some IDEs
deps = .[test]
commands =
    pytest --collect-only

--------------------------------------------------------------------------------