├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── google_cloud_mlflow ├── __init__.py ├── _mlflow_model_gcp_deployment_utils.py ├── _mlflow_models_docker_utils_patch.py ├── mlflow_model_deployment_plugin_for_google_cloud_vertex_ai.py └── mlflow_model_registry_plugin_for_google_cloud_storage.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. 
Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLflow plugin for Google Cloud Vertex AI 2 | 3 | Note: The plugin is **experimental** and may be changed or removed in the future. 
4 | 5 | ## Installation 6 | 7 | ```shell 8 | python3 -m pip install google_cloud_mlflow 9 | ``` 10 | 11 | ## Deployment plugin usage 12 | 13 | ### Command-line 14 | 15 | Create deployment 16 | 17 | ```shell 18 | mlflow deployments create --target google_cloud --name "deployment name" --model-uri "models:/mymodel/mymodelversion" --config destination_image_uri="gcr.io//" 19 | ``` 20 | 21 | List deployments 22 | 23 | ```shell 24 | mlflow deployments list --target google_cloud 25 | ``` 26 | 27 | Get deployment 28 | 29 | ```shell 30 | mlflow deployments get --target google_cloud --name "deployment name" 31 | ``` 32 | 33 | Delete deployment 34 | 35 | ```shell 36 | mlflow deployments delete --target google_cloud --name "deployment name" 37 | ``` 38 | 39 | Update deployment 40 | 41 | ```shell 42 | mlflow deployments update --target google_cloud --name "deployment name" --model-uri "models:/mymodel/mymodelversion" --config destination_image_uri="gcr.io//" 43 | ``` 44 | 45 | Predict 46 | 47 | ```shell 48 | mlflow deployments predict --target google_cloud --name "deployment name" --input-path "inputs.json" --output-path "outputs.json 49 | ``` 50 | 51 | Get help 52 | 53 | ```shell 54 | mlflow deployments help --target google_cloud 55 | ``` 56 | 57 | ### Python 58 | 59 | ```python 60 | from mlflow import deployments 61 | client = deployments.get_deploy_client("google_cloud") 62 | 63 | # Create deployment 64 | model_uri = "models:/mymodel/mymodelversion" 65 | deployment = client.create_deployment( 66 | name="deployment name", 67 | model_uri=model_uri, 68 | # Config is optional 69 | config=dict( 70 | # Deployed model config 71 | machine_type="n1-standard-2", 72 | min_replica_count=None, 73 | max_replica_count=None, 74 | accelerator_type=None, 75 | accelerator_count=None, 76 | service_account=None, 77 | explanation_metadata=None, # JSON string 78 | explanation_parameters=None, # JSON string 79 | 80 | # Model container image building config 81 | destination_image_uri=None, 82 | 83 
| # Endpoint config 84 | endpoint_description=None, 85 | endpoint_deploy_timeout=None, 86 | 87 | # Vertex AI config 88 | project=None, 89 | location=None, 90 | encryption_spec_key_name=None, 91 | staging_bucket=None, 92 | ) 93 | ) 94 | 95 | # List deployments 96 | deployments = client.list_deployments() 97 | 98 | # Get deployment 99 | deployments = client.get_deployment(name="deployment name") 100 | 101 | # Delete deployment 102 | deployment = client.delete_deployment(name="deployment name") 103 | 104 | # Update deployment 105 | deployment = client.create_deployment( 106 | name="deployment name", 107 | model_uri=model_uri, 108 | # Config is optional 109 | config=dict(...), 110 | ) 111 | 112 | # Predict 113 | import pandas 114 | df = pandas.DataFrame([ 115 | {"a": 1,"b": 2,"c": 3}, 116 | {"a": 4,"b": 5,"c": 6} 117 | ]) 118 | predictions = client.predict("deployment name", df) 119 | ``` 120 | 121 | ## Model Registry plugin usage 122 | 123 | Set the MLflow Model Registry URI to a directory in some Google Cloud Storage bucket, then log models using `mlflow.log_model` as usual. 124 | 125 | ```python 126 | mlflow.set_registry_uri("gs:///models/") 127 | ``` 128 | -------------------------------------------------------------------------------- /google_cloud_mlflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ark-kun/google_cloud_mlflow/64620e1fd837734832d9c6d0a2a31f7801319ea3/google_cloud_mlflow/__init__.py -------------------------------------------------------------------------------- /google_cloud_mlflow/_mlflow_model_gcp_deployment_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The google-cloud-mlflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """This module provides an API for deploying MLflow models to Google Cloud Vertex AI. 16 | 17 | The `upload_mlflow_model_to_vertex_ai_models` function builds a Docker container 18 | for a given MLflow model and pushes the image to Google Container Registry. 19 | Then it registers the model as a Google Cloud Vertex AI Model. 20 | Once the model is registered, the user can deploy it for serving on Google Cloud 21 | Vertex AI Endpoint using the `deploy_gcp_model_to_new_endpoint` function. 22 | See 23 | [docs](https://cloud.google.com/ai-platform-unified/docs/predictions/deploy-model-api) 24 | for more information. 
def get_fixed_mlflow_source_dir() -> str:
    """Downloads a patched MLflow source tree and returns its local path.

    Fetches a ZIP snapshot of the "MLFlow-fixes" branch from GitHub, extracts
    it into a fresh temporary directory, and returns the path of the single
    top-level directory inside the archive (GitHub archives wrap the
    repository in one "<user>-<repo>-<commit>" subdirectory).

    Returns:
        Path to the extracted MLflow source directory.

    Raises:
        RuntimeError: If the extracted archive does not contain exactly one
            top-level entry.
    """
    # Local import: the module-level `import urllib` does not guarantee that
    # the `urllib.request` submodule has been loaded.
    import urllib.request

    fixed_mlflow_archive_url = (
        "https://github.com/Ark-kun/mlflow/archive/refs/heads/MLFlow-fixes.zip"
    )
    fixed_mlflow_archive_path, _ = urllib.request.urlretrieve(
        url=fixed_mlflow_archive_url
    )
    fixed_mlflow_parent_dir = tempfile.mkdtemp(prefix="mlflow.fixed")
    with zipfile.ZipFile(fixed_mlflow_archive_path, "r") as zip_ref:
        zip_ref.extractall(fixed_mlflow_parent_dir)
    # The archive contains a single subdirectory like "Ark-kun-mlflow-0ec4c64";
    # the real source tree is one level deeper. Fail loudly instead of
    # silently picking an arbitrary entry.
    entries = os.listdir(fixed_mlflow_parent_dir)
    if len(entries) != 1:
        raise RuntimeError(
            "Expected exactly one top-level entry in the MLflow archive, "
            f"got: {entries}"
        )
    return os.path.join(fixed_mlflow_parent_dir, entries[0])
def get_pickle_protocol(file_path: str) -> int:
    """Determines the pickle protocol version used by a pickle file.

    Args:
        file_path: Path to the pickle file to inspect.

    Returns:
        The protocol version from the leading PROTO opcode if present
        (pickle protocol >= 2), otherwise the highest protocol implied by
        the opcodes parsed so far. Returns -1 if nothing could be parsed.
    """
    import pickletools

    max_proto = -1
    with open(file_path, "rb") as file:
        try:
            for opcode, arg, _ in pickletools.genops(file):
                if opcode.name == "PROTO":
                    # Protocol 2+ pickles start with an explicit PROTO opcode
                    # whose argument is the exact protocol version.
                    return arg
                # Looking at opcode.proto alone is not reliable since opcodes
                # from unsupported newer protocols cannot be parsed by old
                # Python versions at all.
                max_proto = max(max_proto, opcode.proto)
        except Exception:
            # Best effort: report the highest protocol seen before parsing
            # failed (was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit).
            pass
    return max_proto
def _pickle_protocol_unsupported(model_file_path: str) -> bool:
    """Returns True (and warns) if the pickled model uses pickle protocol 5.

    Vertex Prediction prebuilt containers run Python 3.7, which cannot load
    pickle protocol 5 payloads.
    TODO: Remove after https://b.corp.google.com/issues/216705259 is fixed.
    """
    if get_pickle_protocol(model_file_path) == 5:
        _logger.warning(
            "Detected model with pickle protocol version 5 > 4. "
            "Prebuilt containers do not support such models."
        )
        return True
    return False


def upload_mlflow_model_to_vertex_ai_models(
    model_uri: str,
    display_name: str,
    destination_image_uri: Optional[str] = None,
    model_options: Optional[Dict[str, Any]] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    encryption_spec_key_name: Optional[str] = None,
    staging_bucket: Optional[str] = None,
) -> str:
    """Builds a container for an MLflow model and registers the model with Google Cloud Vertex AI.

    Models with a flavor natively supported by Vertex AI (xgboost, sklearn,
    tensorflow) are uploaded to a prebuilt serving container. Otherwise a
    custom container image with the MLflow scoring webserver is built and
    pushed, and the model is registered with that image.

    Args:
        model_uri: The URI of the MLflow model, e.g. a local path,
            `gs://bucket/path/to/model`, `runs:/<run_id>/path`, or
            `models:/<name>/<version>`. See [Referencing
            Artifacts](https://www.mlflow.org/docs/latest/concepts.html#artifact-locations).
        display_name: The display name for the Google Cloud Vertex AI Model.
            Up to 128 characters; any UTF-8 characters.
        destination_image_uri: The full name of the container image that will
            be built with the provided model inside it, in the format
            `gcr.io/<repository>/<image-name>:<tag>`.
            Defaults to `gcr.io/<project>/mlflow/<sanitized display_name>`.
        model_options: A dict of other attributes of the Vertex AI Model
            object, like labels and schema.
            NOTE(review): documented but currently unused by this function -
            TODO wire into `aiplatform.Model.upload`.
        project: The Google Cloud project where to push the container image
            and register the model. Defaults to the gcloud CLI default.
        location: The Google Cloud location for the image and model.
            Defaults to "us-central1".
        encryption_spec_key_name: Optional Cloud KMS key resource identifier,
            `projects/<p>/locations/<r>/keyRings/<kr>/cryptoKeys/<k>`, in the
            same region as the compute resource. Overrides the value set in
            `aiplatform.init`.
        staging_bucket: Optional bucket to stage local model artifacts.
            Overrides the value set in `aiplatform.init`.

    Returns:
        The full resource name of the Google Cloud Vertex AI Model.

    Examples::

        model_name = upload_mlflow_model_to_vertex_ai_models(
            model_uri="models:/mymodel/mymodelversion",
            display_name="my_mlflow_model",
        )
    """
    project = project or aiplatform.initializer.global_config.project

    # NOTE(review): the temporary directory is never cleaned up; downloaded
    # artifacts remain on disk for the lifetime of the process.
    temp_dir = tempfile.mkdtemp()
    model_dir = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=temp_dir,
    )
    model_metadata = Model.load(os.path.join(model_dir, MLMODEL_FILE_NAME))

    # Try prebuilt serving containers first; fall through to building a
    # custom image when no directly supported flavor is usable.
    for flavor_name, flavor in model_metadata.flavors.items():
        if flavor_name == "python_function":
            # Generic flavor present on every model; carries nothing a
            # prebuilt container could serve directly.
            continue
        if flavor_name == "xgboost":
            model_file_path = os.path.join(model_dir, flavor["data"])
            if _pickle_protocol_unsupported(model_file_path):
                continue
            # TODO: Handle case when the version is not supported by Vertex AI
            full_xgboost_version = flavor["xgb_version"]
            # Vertex AI only accepts "<major>.<minor>" version strings.
            vertex_xgboost_version = ".".join(full_xgboost_version.split(".")[0:2])
            vertex_model = aiplatform.Model.upload_xgboost_model_file(
                model_file_path=model_file_path,
                xgboost_version=vertex_xgboost_version,
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name
        if flavor_name == "sklearn":
            model_file_path = os.path.join(model_dir, flavor["pickled_model"])
            if _pickle_protocol_unsupported(model_file_path):
                continue
            vertex_model = aiplatform.Model.upload_scikit_learn_model_file(
                model_file_path=model_file_path,
                # TODO: Deduce version from requirements.txt
                # sklearn_version=
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name
        if flavor_name == "tensorflow":
            model_dir_path = os.path.join(model_dir, flavor["saved_model_dir"])
            vertex_model = aiplatform.Model.upload_tensorflow_saved_model(
                saved_model_dir=model_dir_path,
                # TODO: Deduce version from requirements.txt
                # tensorflow_version=
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name

    _logger.info(
        "Model flavor is not directly supported by Vertex AI. Importing model as a custom-built container"
    )

    if not destination_image_uri:
        # Sanitize the display name into a valid image name component.
        image_name = re.sub("[^-A-Za-z0-9_.]", "_", display_name).lower()
        destination_image_uri = f"gcr.io/{project}/mlflow/{image_name}"
        _logger.info(
            "Destination image URI not set. Building and uploading image to %s",
            destination_image_uri,
        )

    pushed_image_uri_with_digest = _build_serving_image(
        model_uri=model_uri,
        destination_image_uri=destination_image_uri,
        mlflow_source_dir=None,
    )

    uploaded_model = aiplatform.Model.upload(
        # artifact_uri=
        display_name=display_name,
        # description=
        serving_container_image_uri=pushed_image_uri_with_digest,
        # serving_container_command=
        # serving_container_args=
        serving_container_predict_route="/invocations",
        serving_container_health_route="/ping",
        # Setting environment variables to tell the scoring server to properly
        # wrap the responses.
        # See https://github.com/mlflow/mlflow/pull/4611
        # https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#response_requirements
        serving_container_environment_variables={
            "PREDICTIONS_WRAPPER_ATTR_NAME": "predictions",
        },
        serving_container_ports=[8080],
        project=project,
        location=location,
        labels={
            "mlflow_model_vertex_ai_deployer": "mlflow_model_vertex_ai_deployer",
        },
        encryption_spec_key_name=encryption_spec_key_name,
        staging_bucket=staging_bucket,
    )
    return uploaded_model.resource_name
def _build_serving_image(
    model_uri: str,
    destination_image_uri: str,
    mlflow_source_dir: Optional[str] = None,
) -> str:
    """Builds and pushes an MLflow serving image for the MLflow model.

    Args:
        model_uri: The URI of the MLflow model.
        destination_image_uri: The full name of the container image that will
            be built with the provided model inside it.
            The format should be `gcr.io/<repository>/<image-name>:<tag>`.
        mlflow_source_dir: If set, installs MLflow from this directory instead of
            PyPI.
    Returns:
        The destination image URI. NOTE(review): the original docstring
        claimed a digest-qualified URI, but the code returns
        `destination_image_uri` unchanged with no digest appended — confirm
        whether a digest lookup was intended.
    """
    _logger.info("Building image. This can take up to 20 minutes")
    flavor_backend = cli._get_flavor_backend(
        model_uri
    )  # pylint:disable=protected-access

    # Temporarily swap MLflow's image-build implementation for this
    # package's patched version while the flavor backend builds the image.
    with mock.patch(
        "mlflow.models.docker_utils._build_image_from_context",
        new=docker_utils_patch._build_image_from_context
    ):
        flavor_backend.build_image(
            model_uri,
            destination_image_uri,
            install_mlflow=mlflow_source_dir is not None,
            mlflow_home=mlflow_source_dir,
        )
    return destination_image_uri
def deploy_vertex_ai_model_to_endpoint(
    model_name: str,
    endpoint_name: Optional[str] = None,
    machine_type: str = "n1-standard-2",
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    endpoint_display_name: Optional[str] = None,
    deployed_model_display_name: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
) -> google.api_core.operation.Operation:
    # pylint: disable=line-too-long
    """Deploys a Google Cloud Vertex AI Model to a Google Cloud Vertex AI Endpoint.

    Args:
        model_name: Full resource name of a Google Cloud Vertex AI Model.
        endpoint_name: Full name of a Google Cloud Vertex AI Endpoint. A new
            endpoint is created if the name is not passed.
        machine_type: The type of the machine. See the [list of machine types
            supported for
            prediction](https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#machine-types).
            Defaults to "n1-standard-2".
        min_replica_count: The minimum number of replicas the DeployedModel
            will always be deployed on. If traffic against it increases, it may
            dynamically be deployed onto more replicas up to max_replica_count,
            and as traffic decreases, some of these extra replicas may be
            freed. If the requested value is too large, the deployment will
            error. Defaults to 1.
        max_replica_count: The maximum number of replicas this DeployedModel
            may be deployed on when the traffic against it increases. If the
            requested value is too large, the deployment will error, but if
            deployment succeeds then the ability to scale the model to that
            many replicas is guaranteed (barring service outages). If traffic
            against the DeployedModel increases beyond what its replicas at
            maximum may handle, a portion of the traffic will be dropped. If
            this value is not provided, no upper bound for scaling under heavy
            traffic will be assumed, though Vertex AI may be unable to scale
            beyond certain replica number. Defaults to `min_replica_count`.
        endpoint_display_name: The display name of the Endpoint. The name can
            be up to 128 characters long and can consist of any UTF-8
            characters. Only used when a new endpoint is created; defaults to
            the model's display name, truncated to fit.
        deployed_model_display_name: The display name of the DeployedModel. If
            not provided upon creation, the Model's display_name is used.
        project: The Google Cloud project ID. Defaults to the default project.
        location: The Google Cloud region. Defaults to "us-central1".

    Returns:
        google.api_core.operation.Operation:
            An object representing a long-running operation.

            The result type for the operation will be
            :class:`google.cloud.aiplatform_v1.types.DeployModelResponse`
            Response message for
            [EndpointService.DeployModel][google.cloud.aiplatform.v1.EndpointService.DeployModel]
            See
            https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.DeployedModel

    Examples::

        # Use MLflow to register the model on Cloud AI Platform
        model_uri = "models:/mymodel/mymodelversion"  # Replace with your model URI
        display_name = "my_mlflow_model"  # Replace with the desired model name

        model_name = upload_mlflow_model_to_vertex_ai_models(
            model_uri=model_uri,
            display_name=display_name,
        )

        deployed_model_id = deploy_vertex_ai_model_to_endpoint(
            model_name=model_name,
        )
    """
    model = aiplatform.Model(model_name)
    if endpoint_name:
        endpoint = aiplatform.Endpoint(endpoint_name=endpoint_name)
    else:
        # Model.deploy can create the Endpoint automatically, but creating it
        # explicitly lets us attach a label to the endpoint.
        if not endpoint_display_name:
            endpoint_display_name = model.display_name[:127]
        endpoint = aiplatform.Endpoint.create(
            display_name=endpoint_display_name,
            labels={
                "mlflow_model_vertex_ai_deployer": "mlflow_model_vertex_ai_deployer",
            },
            project=project,
            location=location,
        )

    return model.deploy(
        endpoint=endpoint,
        deployed_model_display_name=deployed_model_display_name,
        traffic_percentage=100,
        machine_type=machine_type,
        min_replica_count=min_replica_count,
        max_replica_count=max_replica_count,
    )
def _build_image_from_context(context_dir: str, image_name: str):
    """Builds a container image from a Docker build context via Cloud Build.

    Drop-in replacement for ``mlflow.models.docker_utils._build_image_from_context``
    that builds remotely instead of using a local Docker daemon.

    Args:
        context_dir: Local directory containing the Docker build context.
        image_name: Full name of the image to build and push.

    Returns:
        The pushed image name including the ``@sha256:...`` digest.
    """
    return _build_image_from_context_using_cloudbuild_client(
        context_dir=context_dir,
        image_name=image_name,
    )


def _build_image_from_context_using_cloudbuild_gcloud(
    context_dir: str,
    image_name: str,
):
    """Builds an image by shelling out to ``gcloud builds submit``.

    Alternative implementation kept alongside the client-library one.
    Streams the gcloud output, surfacing only the important lines at INFO.

    Args:
        context_dir: Local directory containing the Docker build context.
        image_name: Full name of the image to build and push.

    Raises:
        RuntimeError: If the build process exits with a non-zero status.
    """
    build_process = subprocess.Popen(
        args=[
            "gcloud", "builds", "submit",
            "--tag", image_name,
            "--timeout", "1800",
            context_dir,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    # Only a few well-known prefixes are worth surfacing at INFO level;
    # everything else is demoted to DEBUG to keep the log readable.
    info_line_prefixes = (
        "Logs are available at ",
        "Step ",
        "Pushing ",
        "latest: digest: ",
    )
    for line in iter(build_process.stdout.readline, ""):
        line = line.rstrip()
        if line.startswith(info_line_prefixes):
            _logger.info(line)
        else:
            _logger.debug(line)

    if build_process.wait() != 0:
        raise RuntimeError("Container image build has failed.")


def _build_image_from_context_using_cloudbuild_client(
    context_dir: str, image_name: str
):
    """Builds an image using the Cloud Build client library.

    Archives the build context, stages it in the conventional
    ``<project>_cloudbuild`` GCS bucket, and submits a docker-build job.

    Args:
        context_dir: Local directory containing the Docker build context.
        image_name: Full name of the image to build and push.

    Returns:
        The pushed image name including the ``@sha256:...`` digest.

    Raises:
        RuntimeError: If the Cloud Build job fails.
    """
    import os
    import google
    import google.auth
    from google.cloud import storage
    from google.cloud.devtools import cloudbuild
    from google.protobuf import duration_pb2

    # tempfile.mktemp() is deprecated and race-prone; create a private temp
    # directory and build the archive path inside it instead.
    staging_dir = tempfile.mkdtemp()
    archive_base_name = os.path.join(staging_dir, "docker_context")
    context_archive_path = shutil.make_archive(
        base_name=archive_base_name,
        format="gztar",
        root_dir=context_dir,
    )

    _, project_id = google.auth.default()

    storage_client = storage.Client(project=project_id)
    build_client = cloudbuild.CloudBuildClient()

    # Staging the data in GCS
    bucket_name = project_id + "_cloudbuild"

    bucket = storage_client.lookup_bucket(bucket_name)
    # TODO: Throw error if bucket is in different project
    if bucket is None:
        bucket = storage_client.create_bucket(bucket_name)
    blob_name = f"source/{uuid.uuid4().hex}.tgz"

    bucket.blob(blob_name).upload_from_filename(context_archive_path)

    build_config = cloudbuild.Build(
        source=cloudbuild.Source(
            storage_source=cloudbuild.StorageSource(
                bucket=bucket_name, object_=blob_name
            ),
        ),
        images=[image_name],
        steps=[
            cloudbuild.BuildStep(
                name="gcr.io/cloud-builders/docker",
                args=[
                    "build",
                    "--network",
                    "cloudbuild",
                    "--no-cache",
                    "-t",
                    image_name,
                    ".",
                ],
            ),
        ],
        # Import duration_pb2 explicitly rather than relying on
        # google.protobuf being reachable as an attribute of `google`.
        timeout=duration_pb2.Duration(
            seconds=1800,
        ),
    )
    build_operation = build_client.create_build(
        project_id=project_id, build=build_config
    )
    _logger.info("Submitted Cloud Build job")
    _logger.debug("build_operation.metadata:")
    _logger.debug(build_operation.metadata)
    # Lazy %-formatting so the string is only built when the level is enabled.
    _logger.info(
        "Logs are available at [%s].", build_operation.metadata.build.log_url
    )

    try:
        result = build_operation.result()
    except Exception as ex:
        _logger.error(
            "MLFlow container image build has failed."
            " See Google Cloud Build logs here: %s",
            build_operation.metadata.build.log_url,
        )
        # RuntimeError (an Exception subclass) keeps existing handlers working
        # while being more specific than a bare Exception.
        raise RuntimeError("MLFlow container image build has failed.") from ex
    _logger.debug("operation.result")
    _logger.debug(result)

    built_image = result.results.images[0]
    image_base_name = built_image.name.split(":")[0]
    image_digest = built_image.digest
    return image_base_name + "@" + image_digest
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """MLflow deployment plugin to deploy MLflow models to Google Cloud Vertex AI.""" 16 | 17 | __all__ = [ 18 | "GoogleCloudVertexAiDeploymentClient", 19 | "target_help", 20 | "run_local", 21 | ] 22 | 23 | import json 24 | import logging 25 | from typing import Any, Dict, List, Optional, Union 26 | 27 | from . import _mlflow_model_gcp_deployment_utils as vertex_utils 28 | from google.protobuf import json_format 29 | import mlflow 30 | from mlflow import deployments 31 | import numpy 32 | import pandas 33 | 34 | # TODO(b/195784726) Remove this workaround once google-cloud-aiplatform conforms 35 | # to the third_party python rules 36 | try: 37 | from google.cloud import aiplatform # pylint:disable=g-import-not-at-top 38 | except ImportError: 39 | from google.cloud.aiplatform import aiplatform # pylint:disable=g-import-not-at-top 40 | 41 | 42 | _logger = logging.getLogger(__name__) 43 | 44 | DEFAULT_MACHINE_TYPE="n1-standard-2" 45 | 46 | 47 | def _resource_to_mlflow_dict( 48 | resource: aiplatform.base.VertexAiResourceNoun, 49 | ) -> Dict[str, Any]: 50 | """Converts Vertex AI resource instance to a MLflow dict.""" 51 | # TODO(avolkov): Switch to .to_dict() method when my PR is merged: 52 | # https://github.com/googleapis/python-aiplatform/pull/588 53 | resource_dict = json_format.MessageToDict(resource._gca_resource._pb) # pylint: disable=protected-access 54 | # The MLflow 
documentation seems to imply that the returned dicts 55 | # need to have "name" attribute set to MLflow "deployment name", not the 56 | # internal resource name. 57 | # We put MLflow deployment name into Endpoint's display_name 58 | resource_dict["resource_name"] = resource.resource_name 59 | resource_dict["name"] = resource.display_name 60 | return resource_dict 61 | 62 | 63 | def _data_to_list_of_instances( 64 | data: Union[list, numpy.ndarray, pandas.DataFrame] 65 | ) -> List: 66 | if isinstance(data, pandas.DataFrame): 67 | return data.values.tolist() 68 | if isinstance(data, numpy.ndarray): 69 | return data.tolist() 70 | if isinstance(data, list): 71 | return data 72 | raise TypeError(f"Unsupported data type {type(data)}") 73 | 74 | 75 | class GoogleCloudVertexAiDeploymentClient(deployments.BaseDeploymentClient): 76 | """The Google Cloud Vertex AI implementation of the BaseDeploymentClient.""" 77 | 78 | def create_deployment( 79 | self, 80 | name: str, 81 | model_uri: str, 82 | flavor: Optional[str] = None, 83 | config: Optional[Dict[str, str]] = None, 84 | ) -> Dict[str, Any]: 85 | """Deploys the model. 86 | 87 | Deploys a model to the specified target. By default, this method should 88 | block until deployment completes (i.e. until it's possible to perform 89 | inference with the deployment). In the case of conflicts (e.g. if it's 90 | not possible to create the specified deployment without due to conflict 91 | with an existing deployment), raises a 92 | :py:class:`mlflow.exceptions.MlflowException`. See target-specific 93 | plugin documentation for additional detail on support for asynchronous 94 | deployment and other configuration. 
    def create_deployment(
        self,
        name: str,
        model_uri: str,
        flavor: Optional[str] = None,
        config: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Deploys the model to a new Vertex AI Endpoint.

        Uploads the MLflow model to Vertex AI Models, creates an Endpoint
        whose display name is the deployment name, and deploys the model to
        it with 100% of the traffic. Blocks until deployment completes unless
        ``config["sync"]`` is the JSON string ``"false"``. If a deployment
        with the same name already exists, raises a
        :py:class:`mlflow.exceptions.MlflowException`.

        Example::

            from mlflow import deployments
            client = deployments.get_deploy_client("google_cloud")
            deployment = client.create_deployment(
                name="deployment name",
                model_uri=...,
                # Config is optional
                config=dict(
                    # Deployed model config
                    machine_type="n1-standard-2",
                    min_replica_count=None,
                    max_replica_count=None,
                    accelerator_type=None,
                    accelerator_count=None,
                    service_account=None,
                    explanation_metadata=None,  # JSON string
                    explanation_parameters=None,  # JSON string

                    # Model container image building config
                    destination_image_uri=None,

                    # Endpoint config
                    endpoint_description=None,
                    endpoint_deploy_timeout=None,

                    # Vertex AI config
                    project=None,
                    location=None,
                    encryption_spec_key_name=None,
                    staging_bucket=None,
                )
            )

        Args:
            name: Unique name to use for the deployment. If another deployment
                exists with the same name, raises a
                :py:class:`mlflow.exceptions.MlflowException`.
            model_uri: URI of the model to deploy.
            flavor: (optional) The MLflow model flavor to deploy.
                If unspecified, the default flavor will be chosen.
            config: (optional) Dict containing target-specific configuration
                for the deployment (all values are strings; see Example).

        Returns:
            A dict corresponding to the created deployment; guaranteed to
            contain the 'name' key.
        """
        config = config or {}

        # Vertex AI options shared by the model upload and endpoint calls.
        project = config.get("project")
        location = config.get("location")
        encryption_spec_key_name = config.get("encryption_spec_key_name")
        staging_bucket = config.get("staging_bucket")

        # The deployment name doubles as the Endpoint display name, so an
        # endpoint with the same display name means the deployment exists.
        existing_endpoints = aiplatform.Endpoint.list(
            filter=f'display_name="{name}"',
            project=project,
            location=location,
        )
        if existing_endpoints:
            raise mlflow.exceptions.MlflowException(
                f"Found existing deployment with name {name}: " +
                ",".join(list(endpoint.resource_name for endpoint in existing_endpoints)))

        # Import the MLflow model into Vertex AI Models (this may build and
        # push a custom serving container image).
        model_name = vertex_utils.upload_mlflow_model_to_vertex_ai_models(
            model_uri=model_uri,
            display_name=name,
            destination_image_uri=config.get("destination_image_uri"),
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
            staging_bucket=staging_bucket,
        )
        endpoint = aiplatform.Endpoint.create(
            display_name=name,
            description=config.get("endpoint_description"),
            # Label used by list_deployments to find plugin-created endpoints.
            labels={
                "google_cloud_mlflow_plugin_version": "0-0-1",
            },
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
        )
        endpoint.deploy(
            model=aiplatform.Model(model_name=model_name),
            deployed_model_display_name=name,
            traffic_percentage=100,
            # Need to always specify the machine type to prevent this error:
            # InvalidArgument: 400 'automatic_resources' is not supported for Model
            # The choice between "dedicated_resources" and "automatic_resources"
            # (only supported with AutoML models) is based on the presence of
            # machine_type.
            machine_type=config.get("machine_type", DEFAULT_MACHINE_TYPE),
            min_replica_count=int(config.get("min_replica_count", 1)),
            max_replica_count=int(config.get("max_replica_count", 1)),
            accelerator_type=config.get("accelerator_type"),
            # A count of 0 is coerced to None so no accelerator is requested.
            accelerator_count=int(config.get("accelerator_count", 0)) or None,
            service_account=config.get("service_account"),
            # The explanation_* config values arrive as JSON strings since
            # config values are strings.
            explanation_metadata=(json.loads(config.get("explanation_metadata")) if "explanation_metadata" in config else None),
            explanation_parameters=(json.loads(config.get("explanation_parameters")) if "explanation_parameters" in config else None),
            # "sync" arrives as the JSON string "true"/"false".
            sync=json.loads(config.get("sync", "true")),
            deploy_request_timeout=int(config.get("endpoint_deploy_timeout", 1800)),
        )

        deployment_dict = _resource_to_mlflow_dict(endpoint)
        deployment_dict["flavor"] = flavor
        return deployment_dict
    def get_deployment(self, name: str) -> Dict[str, Any]:
        """Gets the deployment with the given name.

        Args:
            name: Name of the deployment to fetch.

        Returns:
            A dict describing the deployment; guaranteed to contain a 'name'
            key. Raises a :py:class:`mlflow.exceptions.MlflowException` if no
            deployment exists with the provided name.
        """
        return _resource_to_mlflow_dict(
            self._get_deployment(deployment_name=name)
        )

    def _get_deployment(self, deployment_name: str) -> aiplatform.Endpoint:
        """Returns the single Endpoint whose display name is the deployment name.

        Raises:
            mlflow.exceptions.MlflowException: If zero or more than one
                endpoint matches.
        """
        # The deployment name is stored as the Endpoint display name, which
        # Vertex AI does not force to be unique -- hence the multi-match check.
        endpoints = aiplatform.Endpoint.list(filter=f'display_name="{deployment_name}"')
        if len(endpoints) > 1:
            raise mlflow.exceptions.MlflowException(
                f"Found multiple deployments with name {deployment_name}: " +
                ",".join(list(endpoint.resource_name for endpoint in endpoints)))
        if endpoints:
            return endpoints[0]
        raise mlflow.exceptions.MlflowException(
            f"Could not find deployment with name {deployment_name}."
        )

    def list_deployments(self) -> List[Dict[str, Any]]:
        """Lists all deployments created by this plugin.

        Returns:
            A list of dicts corresponding to deployments. Each dict is
            guaranteed to contain a 'name' key.
        """
        # Only endpoints carrying the plugin's version label are returned, so
        # endpoints created by other means are not listed.
        endpoints = aiplatform.Endpoint.list(filter='labels.google_cloud_mlflow_plugin_version:*')
        endpoint_dicts = list(map(_resource_to_mlflow_dict, endpoints))
        return endpoint_dicts
    def delete_deployment(self, name: str) -> None:
        """Deletes the deployment with the given name.

        Deletion is idempotent: nothing happens (and no error is raised) if
        the deployment does not exist.

        Args:
            name: The name of deployment to delete
        """
        deployment_name = name
        # Locate deployment endpoint from MLflow deployment list
        # Using Endpoint.delete with force=True will undeploy all models on Endpoint
        # before deleting the Endpoint
        endpoints = aiplatform.Endpoint.list(filter=f'display_name="{deployment_name}"')
        if len(endpoints) > 1:
            raise mlflow.exceptions.MlflowException(
                f"Found multiple deployments with name {deployment_name}: " +
                ",".join(list(endpoint.resource_name for endpoint in endpoints)))
        if endpoints:
            endpoint = endpoints[0]
            endpoint.delete(force=True)

    def update_deployment(
        self,
        name: str,
        model_uri: Optional[str] = None,
        flavor: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Updates the deployment with the specified name.

        Implemented as delete-then-create: the existing deployment is removed
        and a new one with the same name is created from ``model_uri``.

        NOTE(review): this makes the update non-atomic -- if
        create_deployment fails, the original deployment is already gone, and
        there is downtime between the delete and the new deployment becoming
        ready. Confirm this is acceptable for callers.

        Args:
            name: Unique name of the deployment to update.
            model_uri: URI of a new model to deploy.
            flavor: (optional) new model flavor to use for deployment. If
                provided, ``model_uri`` must also be specified.
            config: (optional) dict containing updated target-specific
                configuration for the deployment.

        Returns:
            A dict corresponding to the created deployment; contains the
            'name' key.
        """
        self.delete_deployment(name=name)
        return self.create_deployment(
            name=name,
            model_uri=model_uri,
            flavor=flavor,
            config=config,
        )

    def predict(
        self,
        deployment_name: str,
        df: Union[List, numpy.ndarray, pandas.DataFrame],
    ) -> Union[pandas.DataFrame, pandas.Series, numpy.ndarray, Dict[str, Any]]:
        """Computes model predictions using the named deployment's endpoint.

        Args:
            deployment_name: Name of deployment to predict against.
            df: A list, numpy array, or pandas DataFrame of input instances.

        Returns:
            The prediction response returned by the Vertex AI endpoint.

        Example::

            from mlflow import deployments
            import pandas
            df = pandas.DataFrame(
                [
                    {"a": 1,"b": 2,"c": 3},
                    {"a": 4,"b": 5,"c": 6}
                ]
            )
            client = deployments.get_deploy_client("google_cloud")
            client.create_deployment("deployment name", model_uri=...)
            predictions = client.predict("deployment name", df)
        """
        endpoint = self._get_deployment(deployment_name=deployment_name)
        predictions = endpoint.predict(
            instances=_data_to_list_of_instances(df),
        )
        return predictions
351 | 352 | Generate explanations of model predictions on the specified input pandas Dataframe 353 | ``df`` for the deployed model. Explanation output formats vary by deployment target, 354 | and can include details like feature importance for understanding/debugging predictions. 355 | 356 | Args: 357 | deployment_name: Name of deployment to predict against 358 | df: Pandas DataFrame to use for inference 359 | 360 | Returns: 361 | A JSON-able object (pandas dataframe, numpy array, dictionary), or 362 | an exception if the implementation is not available in deployment target's class 363 | 364 | Example:: 365 | 366 | from mlflow import deployments 367 | import pandas 368 | df = pandas.DataFrame( 369 | [ 370 | {"a": 1,"b": 2,"c": 3}, 371 | {"a": 4,"b": 5,"c": 6} 372 | ] 373 | ) 374 | client = deployments.get_deploy_client("google_cloud") 375 | client.create_deployment("deployment name", model_uri=...) 376 | predictions = client.explain("deployment name", df) 377 | """ 378 | endpoint = self._get_deployment(deployment_name=deployment_name) 379 | predictions = endpoint.explain( 380 | instances=_data_to_list_of_instances(df), 381 | ) 382 | return predictions 383 | 384 | 385 | def run_local( 386 | name: str, 387 | model_uri: str, 388 | flavor: Optional[str] = None, 389 | config: Optional[Dict[str, str]] = None, 390 | ): 391 | """Deploys the specified model locally, for testing. 392 | 393 | Args: 394 | name: Unique name to use for deployment. If another deployment exists with 395 | the same name, create_deployment will raise a 396 | :py:class:`mlflow.exceptions.MlflowException` 397 | model_uri: URI of model to deploy 398 | flavor: Model flavor to deploy. If unspecified, default flavor is chosen. 399 | config: Dict containing updated target-specific config for the deployment 400 | """ 401 | raise NotImplementedError() 402 | 403 | 404 | def target_help(): 405 | """Returns help string. 
def run_local(
    name: str,
    model_uri: str,
    flavor: Optional[str] = None,
    config: Optional[Dict[str, str]] = None,
):
    """Deploys the specified model locally, for testing.

    Not supported by this plugin.

    Args:
        name: Unique name to use for deployment. If another deployment exists
            with the same name, create_deployment will raise a
            :py:class:`mlflow.exceptions.MlflowException`
        model_uri: URI of model to deploy
        flavor: Model flavor to deploy. If unspecified, default flavor is chosen.
        config: Dict containing updated target-specific config for the deployment

    Raises:
        NotImplementedError: Always; local runs are not implemented.
    """
    raise NotImplementedError()


def target_help():
    """Returns the help string for the Google Cloud deployment target.

    Displayed when users invoke the ``mlflow deployments help -t <target-name>``
    CLI command. The example's config keys mirror exactly what
    ``GoogleCloudVertexAiDeploymentClient.create_deployment`` reads.
    """
    # The config example below was out of sync with create_deployment: it
    # listed `description`, `experiment` and `experiment_description` (never
    # read) and omitted `endpoint_description`, `endpoint_deploy_timeout` and
    # `encryption_spec_key_name` (which are read).
    return """
MLflow deployment plugin to deploy MLflow models to Google Cloud Vertex AI.

Example::

    from mlflow import deployments
    client = deployments.get_deploy_client("google_cloud")
    deployment = client.create_deployment(
        name="deployment name",
        model_uri=...,
        # Config is optional
        config=dict(
            # Deployed model config
            machine_type="n1-standard-2",
            min_replica_count=None,
            max_replica_count=None,
            accelerator_type=None,
            accelerator_count=None,
            service_account=None,
            explanation_metadata=None,  # JSON string
            explanation_parameters=None,  # JSON string

            # Model container image building config
            destination_image_uri=None,

            # Model deployment config
            sync="true",

            # Endpoint config
            endpoint_description=None,
            endpoint_deploy_timeout=None,

            # Vertex AI config
            project=None,
            location=None,
            encryption_spec_key_name=None,
            staging_bucket=None,
        )
    )
"""
    # File names used inside a model's (or model version's) directory in the
    # GCS store layout built by the _get_model_* helpers below.
    _MODEL_INFO_FILE_NAME: str = "model_info.json"
    _MODEL_VERSION_INFO_FILE_NAME: str = "model_version_info.json"
    _LAST_MODEL_VERSION_FILE_NAME: str = "last_model_version"

    # NOTE(review): this flag is not referenced in the visible code;
    # presumably it selects between physically deleting version files and
    # marking them as deleted -- confirm against the version-deletion methods.
    DELETE_MODEL_VERSIONS_INSTEAD_OF_MARKING_AS_DELETED = False

    def __init__(self, store_uri: str):
        """Initializes the store.

        Args:
            store_uri: Base GCS URI under which all registry data is stored.
                Validated by ``_validate_store_uri`` (defined elsewhere in
                this module).

        Raises:
            mlflow.exceptions.MlflowException: If ``store_uri`` has a bad
                format.
        """
        if not _validate_store_uri(store_uri):
            raise mlflow.exceptions.MlflowException(
                f"Bad store_uri format: {store_uri}"
            )
        # Normalize to exactly one trailing slash so that paths can be built
        # by simple string concatenation in the _get_model_* helpers.
        store_uri = store_uri.rstrip("/") + "/"
        self._base_uri = store_uri
    def _get_model_dir(self, name: str) -> str:
        """Returns the GCS directory URI holding all data for model ``name``."""
        return self._base_uri + name + "/"

    def _get_model_info_file_path(self, name: str) -> str:
        """Returns the URI of the model's metadata JSON file."""
        return self._get_model_dir(name=name) + self._MODEL_INFO_FILE_NAME

    def _get_model_version_dir(self, name: str, version: str) -> str:
        """Returns the GCS directory URI for one version of model ``name``."""
        return self._get_model_dir(name=name) + version + "/"

    def _get_model_version_info_file_path(self, name: str, version: str) -> str:
        """Returns the URI of a model version's metadata JSON file."""
        return (
            self._get_model_version_dir(name=name, version=version)
            + self._MODEL_VERSION_INFO_FILE_NAME
        )

    def create_registered_model(
        self,
        name: str,
        tags: Optional[List[model_registry.RegisteredModelTag]] = None,
        description: Optional[str] = None,
    ) -> model_registry.RegisteredModel:
        """Creates a new registered model in backend store.

        Args:
            name: Name of the new model. This is expected to be unique in the backend store.
            tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
                instances associated with this registered model.
            description: Description of the model.

        Returns:
            A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
            created in the backend.
        """
        # TODO(avolkov): Validate that the model does not exist.
        _validate_model_name(name)
        for tag in tags or []:
            _validate_registered_model_tag(tag.key, tag.value)
        current_time = datetime.datetime.utcnow()
        # NOTE(review): utcnow() returns a naive datetime, and .timestamp()
        # on a naive datetime is interpreted in *local* time; also the value
        # is whole seconds while MLflow stores generally use epoch
        # milliseconds -- confirm both against the rest of the store.
        model = model_registry.RegisteredModel(
            name=name,
            creation_timestamp=int(current_time.timestamp()),
            last_updated_timestamp=int(current_time.timestamp()),
            description=description,
            tags=tags,
        )
        # _set_registered_model_proto is defined elsewhere in this module;
        # update_modification_time=False presumably keeps the creation time
        # as the initial last-updated value -- confirm.
        self._set_registered_model_proto(
            name, model.to_proto(), update_modification_time=False
        )
        return model
    def update_registered_model(
        self,
        name: str,
        description: str,
    ) -> model_registry.RegisteredModel:
        """Updates description of the registered model.

        Args:
            name: Registered model name.
            description: New description.

        Returns:
            A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        model = self.get_registered_model(name=name)
        model_proto = model.to_proto()
        model_proto.description = description
        self._set_registered_model_proto(name, model_proto)
        # NOTE(review): `model` was fetched before the description change, so
        # the returned entity still carries the old description; returning
        # model_registry.RegisteredModel.from_proto(model_proto), as
        # rename_registered_model does, looks more correct -- confirm.
        return model

    def rename_registered_model(
        self,
        name: str,
        new_name: str,
    ) -> model_registry.RegisteredModel:
        """Renames the registered model.

        Copies every object under the model's directory to the new name's
        directory (rewriting metadata files that embed the model name), then
        deletes the originals.

        Args:
            name: Registered model name.
            new_name: New proposed name.

        Returns:
            A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        _validate_model_name(name)
        _validate_model_name(new_name)
        model = self.get_registered_model(name)
        # Renaming to the same name is a no-op.
        if new_name == name:
            return model
        model_proto = model.to_proto()
        model_proto.name = new_name
        src_dir = self._get_model_dir(name=name)
        dst_dir = self._get_model_dir(name=new_name)
        src_dir_blob = storage.Blob.from_string(uri=src_dir)
        dst_dir_blob = storage.Blob.from_string(uri=dst_dir)
        bucket = src_dir_blob.bucket
        src_path = src_dir_blob.name
        dst_path = dst_dir_blob.name
        blobs: List[storage.Blob] = list(
            storage.Client().list_blobs(
                # Using bucket name as a workaround for
                # https://github.com/googleapis/python-storage/issues/540
                bucket_or_name=bucket.name,
                prefix=src_path,
            )
        )
        blobs_to_delete = []
        for blob in blobs:
            assert blob.name.startswith(src_path)
            new_blob_name = blob.name.replace(src_path, dst_path, 1)
            if blob.name.endswith("/" + self._MODEL_VERSION_INFO_FILE_NAME):
                # Version metadata embeds the model name, so it is re-written
                # under the new name rather than copied verbatim.
                # (_json_to_registered_model_version and
                # _set_model_version_proto are defined elsewhere in this
                # module.)
                blobs_to_delete.append(blob)
                model_version_proto = _json_to_registered_model_version(
                    blob.download_as_text()
                ).to_proto()
                model_version_proto.name = new_name
                self._set_model_version_proto(
                    model_version_proto.name,
                    model_version_proto.version,
                    model_version_proto,
                    update_modification_time=False,
                )
            elif blob.name.endswith("/" + self._LAST_MODEL_VERSION_FILE_NAME):
                # The version counter file is content-agnostic; a plain copy
                # to the new location suffices.
                blobs_to_delete.append(blob)
                blob.bucket.copy_blob(
                    blob=blob, destination_bucket=blob.bucket, new_name=new_blob_name
                )
            elif blob.name.endswith("/" + self._MODEL_INFO_FILE_NAME):
                # Re-created below via _set_registered_model_proto; only mark
                # the old copy for deletion here.
                blobs_to_delete.append(blob)

        # _set_registered_model_proto also updates
        # model_proto.last_updated_timestamp
        self._set_registered_model_proto(new_name, model_proto)
        # After we copy and update all model files, we delete the original blobs.
        for blob in blobs_to_delete:
            blob.delete()
        # The blob is not available right after writing.
        # So we cannot use get_registered_model(new_name)
        return model_registry.RegisteredModel.from_proto(model_proto)

    def delete_registered_model(self, name: str) -> None:
        """Deletes the registered model.

        Deletes every object stored under the model's directory prefix.

        NOTE(review): the original contract said the backend raises if the
        model does not exist, but an empty listing simply makes this a no-op
        -- confirm intended behavior.

        Args:
            name: Registered model name.

        Returns:
            None
        """
        _validate_model_name(name)
        src_dir = self._get_model_dir(name=name)
        src_dir_blob = storage.Blob.from_string(uri=src_dir)
        bucket = src_dir_blob.bucket
        src_path = src_dir_blob.name
        blobs: Iterator[storage.Blob] = storage.Client().list_blobs(
            # Using bucket name as a workaround for
            # https://github.com/googleapis/python-storage/issues/540
            bucket_or_name=bucket.name,
            prefix=src_path,
        )
        for blob in blobs:
            blob.delete()
241 | """ 242 | root_dir = storage.Blob.from_string(uri=self._base_uri) 243 | blob_iterator = storage.Client().list_blobs( 244 | # Using bucket name as a workaround for 245 | # https://github.com/googleapis/python-storage/issues/540 246 | bucket_or_name=root_dir.bucket.name, 247 | prefix=root_dir.name, 248 | max_results=max_results, 249 | page_token=page_token, 250 | ) 251 | models = [ 252 | _json_to_registered_model(blob.download_as_text()) 253 | for blob in blob_iterator 254 | if blob.name.endswith(self._MODEL_INFO_FILE_NAME) 255 | ] 256 | return paged_list.PagedList(items=models, token=blob_iterator.next_page_token) 257 | 258 | def search_registered_models( 259 | self, 260 | filter_string: str, 261 | max_results: Optional[int] = None, 262 | order_by: Optional[str] = None, 263 | page_token: Optional[str] = None, 264 | ) -> paged_list.PagedList[model_registry.RegisteredModel]: 265 | """Searches for registered models that satisfy the filter criteria. 266 | 267 | Args: 268 | filter_string: Filter query string, defaults to searching all registered models. 269 | max_results: Maximum number of registered models desired. (Default value = None) 270 | order_by: List of column names with ASC|DESC annotation, to be used for ordering 271 | matching search results. (Default value = None) 272 | page_token: Token specifying the next page of results. It should be obtained from 273 | a ``search_registered_models`` call. (Default value = None) 274 | 275 | Returns: 276 | A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects 277 | that satisfy the search expressions. The pagination token for the next page can be 278 | obtained via the ``token`` attribute of the object. 
279 | """ 280 | del page_token 281 | parsed_filters = SearchUtils.parse_filter_for_registered_models(filter_string) 282 | if order_by: 283 | ( 284 | ordering_key, 285 | ordering_is_ascending, 286 | ) = SearchUtils.parse_order_by_for_search_registered_models(order_by) 287 | models = self._list_models() 288 | for parsed_filter in parsed_filters: 289 | if parsed_filter["comparator"] != "=": 290 | raise mlflow.exceptions.MlflowException( 291 | "Model Registry search filter only supports equality(=) " 292 | f"comparator. Input filter string: {filter_string}", 293 | error_code=databricks_pb2.INVALID_PARAMETER_VALUE, 294 | ) 295 | # Key validated by `parse_filter_for_models` 296 | key = parsed_filter["key"] 297 | value = parsed_filter["value"] 298 | models = [model for model in models if getattr(model, key, None) == value] 299 | if order_by: 300 | models.sort( 301 | key=lambda x: getattr(x, ordering_key, None), 302 | reversed=not ordering_is_ascending, 303 | ) 304 | if max_results: 305 | models = models[0:max_results] 306 | return models 307 | 308 | def get_registered_model(self, name: str) -> model_registry.RegisteredModel: 309 | """Gets registered model instance by name. 310 | 311 | Args: 312 | name: Registered model name. 313 | 314 | Returns: 315 | A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object. 
316 | """ 317 | _validate_model_name(name) 318 | model_uri = self._get_model_info_file_path(name=name) 319 | try: 320 | model_json = storage.Blob.from_string( 321 | uri=model_uri, 322 | # Workaround for https://github.com/googleapis/python-storage/issues/540 323 | client=storage.Client(), 324 | ).download_as_text() 325 | model = _json_to_registered_model(model_json) 326 | return model 327 | except google.api_core.exceptions.NotFound: 328 | raise mlflow.exceptions.MlflowException( 329 | message=f'Model "{name}" does not exist', 330 | error_code=databricks_pb2.RESOURCE_DOES_NOT_EXIST, 331 | ) 332 | 333 | def _set_registered_model_proto( 334 | self, 335 | name: str, 336 | model_proto: model_registry_pb2.RegisteredModel, 337 | update_modification_time: bool = True, 338 | ) -> None: 339 | """Sets registered model instance. 340 | 341 | Args: 342 | name: Registered model name. 343 | model_proto: A single 344 | :py:class:`mlflow.entities.model_registry_pb2.RegisteredModel` object. 345 | update_modification_time: Whether to update the modification time 346 | """ 347 | _validate_model_name(name) 348 | if update_modification_time: 349 | current_time = datetime.datetime.utcnow() 350 | model_proto.last_updated_timestamp = int(current_time.timestamp()) 351 | model_json = json_format.MessageToJson(model_proto) 352 | model_uri = self._get_model_info_file_path(name=name) 353 | storage.Blob.from_string( 354 | uri=model_uri, 355 | # Workaround for https://github.com/googleapis/python-storage/issues/540 356 | client=storage.Client(), 357 | ).upload_from_string(data=model_json) 358 | 359 | def _list_models(self) -> List[model_registry.RegisteredModel]: 360 | """Lists all registered models. 361 | 362 | Returns: 363 | A list of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects. 
364 | """ 365 | base_dir_blob = storage.Blob.from_string(uri=self._base_uri) 366 | blob_iterator = storage.Client().list_blobs( 367 | # Using bucket name as a workaround for 368 | # https://github.com/googleapis/python-storage/issues/540 369 | bucket_or_name=base_dir_blob.bucket.name, 370 | prefix=base_dir_blob.name, 371 | ) 372 | models = [ 373 | _json_to_registered_model(blob.download_as_text()) 374 | for blob in blob_iterator 375 | if blob.name.endswith(self._MODEL_INFO_FILE_NAME) 376 | ] 377 | return models 378 | 379 | def get_latest_versions( 380 | self, name: str, stages: List[str] = None 381 | ) -> List[model_registry.ModelVersion]: 382 | """Gets the latest model version for each requested stage. 383 | 384 | If no ``stages`` argument is provided, returns the latest version for each stage. 385 | 386 | Args: 387 | name: Registered model name. 388 | stages: List of desired stages. If input list is None, return latest versions for 389 | for 'Staging' and 'Production' stages. (Default value = None) 390 | 391 | Returns: 392 | List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects. 
393 | """ 394 | _validate_model_name(name) 395 | stages = stages or model_version_stages.DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS 396 | stages = [model_version_stages.get_canonical_stage(stage) for stage in stages] 397 | latest_versions = {} 398 | model_versions = self._list_model_versions(name=name) 399 | for model_version in model_versions: 400 | stage = model_version.current_stage 401 | if stage not in stages: 402 | continue 403 | if stage not in latest_versions or int( 404 | latest_versions[stage].version 405 | ) < int(model_version.version): 406 | latest_versions[stage] = model_version 407 | latest_versions_list = [latest_versions.get(stage) or None for stage in stages] 408 | return latest_versions_list 409 | 410 | def set_registered_model_tag( 411 | self, 412 | name: str, 413 | tag: model_registry.RegisteredModelTag, 414 | ) -> None: 415 | """Sets a tag for the registered model. 416 | 417 | Args: 418 | name: Registered model name. 419 | tag: py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log. 420 | 421 | Returns: 422 | None 423 | """ 424 | _validate_registered_model_tag(key=tag.key, value=tag.value) 425 | model = self.get_registered_model(name=name) 426 | model_proto = model.to_proto() 427 | # model.tags is not a proper map. 428 | # Deleting possible previous tag values first. 429 | for existing_tag in model_proto.tags: 430 | if existing_tag.key == tag.key: 431 | model_proto.tags.remove(existing_tag) 432 | model_proto.tags.append(tag) 433 | self._set_registered_model_proto(name, model_proto) 434 | 435 | def delete_registered_model_tag(self, name: str, key: str) -> None: 436 | """Delete a tag associated with the registered model. 437 | 438 | Args: 439 | name: Registered model name. 440 | key: Registered model tag key. 
441 | 442 | Returns: 443 | None 444 | """ 445 | _validate_tag_name(name=key) 446 | model = self.get_registered_model(name=name) 447 | model_proto = model.to_proto() 448 | for tag in model_proto.tags: 449 | if tag.key == key: 450 | model_proto.tags.remove(tag) 451 | self._set_registered_model_proto(name, model_proto) 452 | 453 | # CRUD API for ModelVersion objects 454 | 455 | def _increment_last_model_version(self, name: str) -> int: 456 | """Increments and returns the last model version for a given model. 457 | 458 | Args: 459 | name: Registered model name. 460 | 461 | Returns: 462 | The version number for the next model version. 463 | """ 464 | _validate_model_name(name) 465 | model_dir_uri = self._get_model_dir(name=name) 466 | last_model_version_file_uri = model_dir_uri + self._LAST_MODEL_VERSION_FILE_NAME 467 | last_model_version_file_blob = storage.Blob.from_string( 468 | uri=last_model_version_file_uri, 469 | # Workaround for https://github.com/googleapis/python-storage/issues/540 470 | client=storage.Client(), 471 | ) 472 | last_model_version = 0 473 | try: 474 | last_model_version = int(last_model_version_file_blob.download_as_text()) 475 | except google.api_core.exceptions.NotFound: 476 | pass 477 | last_model_version += 1 478 | last_model_version_file_blob.upload_from_string( 479 | data=str(last_model_version), 480 | # Avoiding race condition 481 | if_generation_match=last_model_version_file_blob.generation, 482 | ) 483 | return str(last_model_version) 484 | 485 | def create_model_version( 486 | self, 487 | name: str, 488 | source: str, 489 | run_id: str = None, 490 | tags: List[model_registry.ModelVersionTag] = None, 491 | run_link: str = None, 492 | description: str = None, 493 | ) -> model_registry.ModelVersion: 494 | """Creates a new model version from given source and run ID. 495 | 496 | Args: 497 | name: Registered model name. 498 | source: Source path where the MLflow model is stored. 
499 | run_id: Run ID from MLflow tracking server that generated the model. (Default value = None) 500 | tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag` 501 | instances associated with this model version. (Default value = None) 502 | run_link: Link to the run from an MLflow tracking server that generated this model. (Default value = None) 503 | description: Description of the version. (Default value = None) 504 | 505 | Returns: 506 | A single object of :py:class:`mlflow.entities.model_registry.ModelVersion` 507 | created in the backend. 508 | """ 509 | _validate_model_name(name) 510 | version = self._increment_last_model_version(name=name) 511 | current_time = datetime.datetime.utcnow() 512 | model_version = model_registry.ModelVersion( 513 | name=name, 514 | version=version, 515 | creation_timestamp=int(current_time.timestamp()), 516 | last_updated_timestamp=int(current_time.timestamp()), 517 | description=description, 518 | source=source, 519 | run_id=run_id, 520 | tags=tags, 521 | run_link=run_link, 522 | ) 523 | self._set_model_version_proto( 524 | name=name, version=version, model_version_proto=model_version.to_proto() 525 | ) 526 | return model_version 527 | 528 | def update_model_version( 529 | self, 530 | name: str, 531 | version: str, 532 | description: str, 533 | ) -> model_registry.ModelVersion: 534 | """Updates metadata associated with a model version in backend. 535 | 536 | Args: 537 | name: Registered model name. 538 | version: Registered model version. 539 | description: New model description. 540 | 541 | Returns: 542 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object. 
543 | """ 544 | model_version = self.get_model_version(name=name, version=version) 545 | model_version_proto = model_version.to_proto() 546 | model_version_proto.description = description 547 | self._set_model_version_proto( 548 | name=name, version=version, model_version_proto=model_version_proto 549 | ) 550 | return model_version 551 | 552 | def transition_model_version_stage( 553 | self, 554 | name: str, 555 | version: str, 556 | stage: str, 557 | archive_existing_versions: bool = False, 558 | ) -> model_registry.ModelVersion: 559 | """Updates model version stage. 560 | 561 | Args: 562 | name: Registered model name. 563 | version: Registered model version. 564 | stage: New desired stage for this model version. 565 | archive_existing_versions: If this flag is set to ``True``, all existing model 566 | versions in the stage will be automically moved to the "archived" stage. Only valid 567 | when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will be raised. 568 | 569 | Returns: 570 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object. 
571 | """ 572 | model_version = self.get_model_version(name=name, version=version) 573 | stage = model_version_stages.get_canonical_stage(stage) 574 | model_version_proto = model_version.to_proto() 575 | model_version_proto.current_stage = stage 576 | self._set_model_version_proto(name, version, model_version_proto) 577 | if archive_existing_versions: 578 | all_model_versions = self._list_model_versions(name=name) 579 | all_other_model_versions = list( 580 | filter( 581 | lambda model_version: model_version.name == name, all_model_versions 582 | ) 583 | ) 584 | all_other_model_versions_in_same_stage = list( 585 | filter( 586 | lambda model_version: model_version.current_stage == stage, 587 | all_other_model_versions, 588 | ) 589 | ) 590 | for other_model_version in all_other_model_versions_in_same_stage: 591 | other_model_version_proto = other_model_version.to_proto() 592 | other_model_version_proto.current_stage = ( 593 | model_version_stages.STAGE_ARCHIVED 594 | ) 595 | self._set_model_version_proto( 596 | name=other_model_version.name, 597 | version=other_model_version.version, 598 | model_version_proto=other_model_version_proto, 599 | ) 600 | return model_version 601 | 602 | def delete_model_version(self, name: str, version: str) -> None: 603 | """Deletes model version in backend. 604 | 605 | Args: 606 | name: Registered model name. 607 | version: Registered model version. 
608 | 609 | Returns: 610 | None 611 | """ 612 | # Validating that the version is a proper model version 613 | model_version = self.get_model_version(name=name, version=version) 614 | if self.DELETE_MODEL_VERSIONS_INSTEAD_OF_MARKING_AS_DELETED: 615 | model_version_uri = self._get_model_version_info_file_path( 616 | name=name, version=version 617 | ) 618 | storage.Blob.from_string( 619 | uri=model_version_uri, 620 | # Workaround for https://github.com/googleapis/python-storage/issues/540 621 | client=storage.Client(), 622 | ).delete() 623 | else: 624 | model_version_proto = model_version.to_proto() 625 | model_version_proto.current_stage = ( 626 | model_version_stages.STAGE_DELETED_INTERNAL 627 | ) 628 | self._set_model_version_proto( 629 | name=name, 630 | version=version, 631 | model_version_proto=model_version_proto, 632 | ) 633 | 634 | def get_model_version(self, name: str, version: str) -> model_registry.ModelVersion: 635 | """Gets the model version instance by name and version. 636 | 637 | Args: 638 | name: Registered model name. 639 | version: Registered model version. 640 | 641 | Returns: 642 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object. 
643 | """ 644 | _validate_model_name(name) 645 | _validate_model_version(version) 646 | model_version_uri = self._get_model_version_info_file_path( 647 | name=name, version=version 648 | ) 649 | try: 650 | model_version_json = storage.Blob.from_string( 651 | uri=model_version_uri, 652 | # Workaround for https://github.com/googleapis/python-storage/issues/540 653 | client=storage.Client(), 654 | ).download_as_text() 655 | model_version = _json_to_registered_model_version(model_version_json) 656 | return model_version 657 | except google.api_core.exceptions.NotFound: 658 | raise mlflow.exceptions.MlflowException( 659 | message=f'Model "{name}" version "{version}" does not exist', 660 | error_code=databricks_pb2.RESOURCE_DOES_NOT_EXIST, 661 | ) 662 | 663 | def _set_model_version_proto( 664 | self, 665 | name: str, 666 | version: str, 667 | model_version_proto: model_registry_pb2.ModelVersion, 668 | update_modification_time: bool = True, 669 | ) -> None: 670 | """Gets the model version instance by name and version. 671 | 672 | Args: 673 | name: Registered model name. 674 | version: Registered model version. 675 | model_version_proto: A 676 | :py:class:`mlflow.entities.model_registry_pb2.ModelVersion` object. 
677 | update_modification_time: Whether to update the modification time 678 | 679 | Returns: 680 | None 681 | """ 682 | _validate_model_name(name) 683 | _validate_model_version(version) 684 | if update_modification_time: 685 | current_time = datetime.datetime.utcnow() 686 | model_version_proto.last_updated_timestamp = int(current_time.timestamp()) 687 | model_version_json = json_format.MessageToJson(model_version_proto) 688 | model_version_uri = self._get_model_version_info_file_path( 689 | name=name, version=version 690 | ) 691 | storage.Blob.from_string( 692 | uri=model_version_uri, 693 | # Workaround for https://github.com/googleapis/python-storage/issues/540 694 | client=storage.Client(), 695 | ).upload_from_string(model_version_json) 696 | 697 | def _list_model_versions( 698 | self, 699 | name: Optional[str] = None, 700 | ) -> List[model_registry.ModelVersion]: 701 | """Lists all versions of a registered model. 702 | 703 | Args: 704 | name: Registered model name. 705 | 706 | Returns: 707 | A list of :py:class:`mlflow.entities.model_registry.ModelVersion` objects. 708 | """ 709 | if name: 710 | _validate_model_name(name) 711 | model_dir_uri = self._get_model_dir(name=name) 712 | else: 713 | model_dir_uri = self._base_uri 714 | model_dir_blob = storage.Blob.from_string(uri=model_dir_uri) 715 | blob_iterator = storage.Client().list_blobs( 716 | # Using bucket name as a workaround for 717 | # https://github.com/googleapis/python-storage/issues/540 718 | bucket_or_name=model_dir_blob.bucket.name, 719 | prefix=model_dir_blob.name, 720 | ) 721 | models = [ 722 | _json_to_registered_model_version(blob.download_as_text()) 723 | for blob in blob_iterator 724 | if blob.name.endswith(self._MODEL_VERSION_INFO_FILE_NAME) 725 | ] 726 | return models 727 | 728 | def get_model_version_download_uri(self, name: str, version: str) -> str: 729 | """Gets the download location in Model Registry for this model version. 
730 | 731 | NOTE: For first version of Model Registry, since the models are not copied over to another 732 | location, download URI points to input source path. 733 | 734 | Args: 735 | name: Registered model name. 736 | version: Registered model version. 737 | 738 | Returns: 739 | A single URI location that allows reads for downloading. 740 | """ 741 | model_version = self.get_model_version(name, version) 742 | return model_version.source 743 | 744 | def search_model_versions( 745 | self, filter_string: str 746 | ) -> paged_list.PagedList[model_registry.ModelVersion]: 747 | """Searches for model versions in backend that satisfy the filter criteria. 748 | 749 | Args: 750 | filter_string: A filter string expression. Currently supports a single filter 751 | condition either name of model like ``name = 'model_name'`` or ``run_id = '...'``. 752 | 753 | Returns: 754 | PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion` 755 | objects. 756 | """ 757 | parsed_filters = SearchUtils.parse_filter_for_model_versions(filter_string) 758 | model_versions = self._list_model_versions() 759 | model_versions = [ 760 | model_version 761 | for model_version in model_versions 762 | if model_version.current_stage is None 763 | or model_version.current_stage in model_version_stages.ALL_STAGES 764 | ] 765 | for parsed_filter in parsed_filters: 766 | if parsed_filter["comparator"] != "=": 767 | raise mlflow.exceptions.MlflowException( 768 | "Model Registry search filter only supports equality(=) " 769 | f"comparator. 
Input filter string: {filter_string}", 770 | error_code=databricks_pb2.INVALID_PARAMETER_VALUE, 771 | ) 772 | # Key validated by `parse_filter_for_model_versions` 773 | key = parsed_filter["key"] 774 | value = parsed_filter["value"] 775 | model_versions = [ 776 | model_version 777 | for model_version in model_versions 778 | if getattr(model_version, key, None) == value 779 | ] 780 | return paged_list.PagedList(items=model_versions, token=None) 781 | 782 | def set_model_version_tag( 783 | self, name: str, version: str, tag: model_registry.ModelVersionTag 784 | ) -> None: 785 | """Sets a tag for the model version. 786 | 787 | Args: 788 | name: Registered model name. 789 | version: Registered model version. 790 | tag: py:class:`mlflow.entities.model_registry.ModelVersionTag` instance 791 | to log. 792 | 793 | Returns: 794 | None 795 | """ 796 | _validate_model_version_tag(tag.key, tag.value) 797 | model_version = self.get_model_version(name, version) 798 | model_version_proto = model_version.to_proto() 799 | # model.tags is not a proper map. 800 | # Deleting possible previous tag values first. 801 | for existing_tag in model_version_proto.tags: 802 | if existing_tag.key == tag.key: 803 | model_version_proto.tags.remove(existing_tag) 804 | model_version_proto.tags.append(tag) 805 | self._set_model_version_proto( 806 | name=name, version=version, model_version_proto=model_version_proto 807 | ) 808 | 809 | def delete_model_version_tag(self, name: str, version: str, key: str) -> None: 810 | """Deletes a tag associated with the model version. 811 | 812 | Args: 813 | name: Registered model name. 814 | version: Registered model version. 815 | key: Tag key. 
816 | 817 | Returns: 818 | None 819 | """ 820 | _validate_tag_name(key) 821 | model_version = self.get_model_version(name, version) 822 | model_version_proto = model_version.to_proto() 823 | for tag in model_version_proto.tags: 824 | if tag.key == key: 825 | model_version_proto.tags.remove(tag) 826 | self._set_model_version_proto( 827 | name=name, version=version, model_version_proto=model_version_proto 828 | ) 829 | 830 | 831 | def _validate_store_uri(store_uri: str) -> bool: 832 | return store_uri.startswith("gs://") 833 | 834 | 835 | def _json_to_registered_model(model_json: str) -> model_registry.RegisteredModel: 836 | """Converts JSON string to RegisteredModel.""" 837 | model = model_registry_pb2.RegisteredModel() 838 | json_format.Parse(model_json, model) 839 | return model_registry.RegisteredModel.from_proto(model) 840 | 841 | 842 | def _json_to_registered_model_version( 843 | model_json: str, 844 | ) -> model_registry.ModelVersion: 845 | """Converts JSON string to ModelVersion.""" 846 | model_version = model_registry_pb2.ModelVersion() 847 | json_format.Parse(model_json, model_version) 848 | return model_registry.ModelVersion.from_proto(model_version) 849 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The google-cloud-mlflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """MLflow Google Cloud Vertex AI integration package.""" 16 | 17 | import setuptools 18 | 19 | with open("README.md", "r") as fh: 20 | long_description = fh.read() 21 | 22 | setuptools.setup( 23 | name="google_cloud_mlflow", 24 | version="0.0.6", 25 | description="MLflow Google Cloud Vertex AI integration package", 26 | url="https://github.com/Ark-kun/google_cloud_mlflow", 27 | project_urls={ 28 | 'Source': 'https://github.com/pypa/sampleproject/', 29 | 'Issues': 'https://github.com/Ark-kun/google_cloud_mlflow/issues', 30 | }, 31 | author="Alexey Volkov", 32 | author_email="alexey.volkov@ark-kun.com", 33 | classifiers=[ 34 | "License :: OSI Approved :: Apache Software License", 35 | "Natural Language :: English", 36 | "Operating System :: OS Independent", 37 | "Programming Language :: Python :: 3", 38 | ], 39 | license="Apache License 2.0", 40 | long_description=long_description, 41 | long_description_content_type="text/markdown", 42 | packages=setuptools.find_packages(), 43 | keywords='mlflow, Google Cloud, Vertex AI', 44 | python_requires=">=3.6", 45 | install_requires=[ 46 | "google-cloud-aiplatform~=1.7", 47 | "mlflow~=1.22", 48 | "google-cloud-build==3.*", 49 | "google-cloud-storage==1.*", 50 | ], 51 | entry_points={ 52 | "mlflow.deployments": [ 53 | "google_cloud=google_cloud_mlflow.mlflow_model_deployment_plugin_for_google_cloud_vertex_ai", 54 | ], 55 | "mlflow.model_registry_store": [ 56 | "gs=google_cloud_mlflow.mlflow_model_registry_plugin_for_google_cloud_storage:GoogleCloudStorageModelRegistry", 57 | ], 58 | }, 59 | ) 60 | --------------------------------------------------------------------------------