├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── google_cloud_mlflow
├── __init__.py
├── _mlflow_model_gcp_deployment_utils.py
├── _mlflow_models_docker_utils_patch.py
├── mlflow_model_deployment_plugin_for_google_cloud_vertex_ai.py
└── mlflow_model_registry_plugin_for_google_cloud_storage.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
 11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code Reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLflow plugin for Google Cloud Vertex AI
2 |
3 | Note: The plugin is **experimental** and may be changed or removed in the future.
4 |
5 | ## Installation
6 |
7 | ```shell
8 | python3 -m pip install google_cloud_mlflow
9 | ```
10 |
11 | ## Deployment plugin usage
12 |
13 | ### Command-line
14 |
15 | Create deployment
16 |
17 | ```shell
 18 | mlflow deployments create --target google_cloud --name "deployment name" --model-uri "models:/mymodel/mymodelversion" --config destination_image_uri="gcr.io/<project-id>/<image-name>"
19 | ```
20 |
21 | List deployments
22 |
23 | ```shell
24 | mlflow deployments list --target google_cloud
25 | ```
26 |
27 | Get deployment
28 |
29 | ```shell
30 | mlflow deployments get --target google_cloud --name "deployment name"
31 | ```
32 |
33 | Delete deployment
34 |
35 | ```shell
36 | mlflow deployments delete --target google_cloud --name "deployment name"
37 | ```
38 |
39 | Update deployment
40 |
41 | ```shell
 42 | mlflow deployments update --target google_cloud --name "deployment name" --model-uri "models:/mymodel/mymodelversion" --config destination_image_uri="gcr.io/<project-id>/<image-name>"
43 | ```
44 |
45 | Predict
46 |
47 | ```shell
 48 | mlflow deployments predict --target google_cloud --name "deployment name" --input-path "inputs.json" --output-path "outputs.json"
49 | ```
50 |
51 | Get help
52 |
53 | ```shell
54 | mlflow deployments help --target google_cloud
55 | ```
56 |
57 | ### Python
58 |
59 | ```python
60 | from mlflow import deployments
61 | client = deployments.get_deploy_client("google_cloud")
62 |
63 | # Create deployment
64 | model_uri = "models:/mymodel/mymodelversion"
65 | deployment = client.create_deployment(
66 | name="deployment name",
67 | model_uri=model_uri,
68 | # Config is optional
69 | config=dict(
70 | # Deployed model config
71 | machine_type="n1-standard-2",
72 | min_replica_count=None,
73 | max_replica_count=None,
74 | accelerator_type=None,
75 | accelerator_count=None,
76 | service_account=None,
77 | explanation_metadata=None, # JSON string
78 | explanation_parameters=None, # JSON string
79 |
80 | # Model container image building config
81 | destination_image_uri=None,
82 |
83 | # Endpoint config
84 | endpoint_description=None,
85 | endpoint_deploy_timeout=None,
86 |
87 | # Vertex AI config
88 | project=None,
89 | location=None,
90 | encryption_spec_key_name=None,
91 | staging_bucket=None,
92 | )
93 | )
94 |
95 | # List deployments
96 | deployments = client.list_deployments()
97 |
98 | # Get deployment
99 | deployments = client.get_deployment(name="deployment name")
100 |
101 | # Delete deployment
102 | deployment = client.delete_deployment(name="deployment name")
103 |
104 | # Update deployment
105 | deployment = client.create_deployment(
106 | name="deployment name",
107 | model_uri=model_uri,
108 | # Config is optional
109 | config=dict(...),
110 | )
111 |
112 | # Predict
113 | import pandas
114 | df = pandas.DataFrame([
115 | {"a": 1,"b": 2,"c": 3},
116 | {"a": 4,"b": 5,"c": 6}
117 | ])
118 | predictions = client.predict("deployment name", df)
119 | ```
120 |
121 | ## Model Registry plugin usage
122 |
123 | Set the MLflow Model Registry URI to a directory in some Google Cloud Storage bucket, then log models using `mlflow.log_model` as usual.
124 |
125 | ```python
126 | mlflow.set_registry_uri("gs://<bucket-name>/models/")
127 | ```
128 |
--------------------------------------------------------------------------------
/google_cloud_mlflow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ark-kun/google_cloud_mlflow/64620e1fd837734832d9c6d0a2a31f7801319ea3/google_cloud_mlflow/__init__.py
--------------------------------------------------------------------------------
/google_cloud_mlflow/_mlflow_model_gcp_deployment_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The google-cloud-mlflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | """This module provides an API for deploying MLflow models to Google Cloud Vertex AI.
16 |
17 | The `upload_mlflow_model_to_vertex_ai_models` function builds a Docker container
18 | for a given MLflow model and pushes the image to Google Container Registry.
19 | Then it registers the model as a Google Cloud Vertex AI Model.
20 | Once the model is registered, the user can deploy it for serving on Google Cloud
21 | Vertex AI Endpoint using the `deploy_gcp_model_to_new_endpoint` function.
22 | See
23 | [docs](https://cloud.google.com/ai-platform-unified/docs/predictions/deploy-model-api)
24 | for more information.
25 |
26 | Examples::
27 |
28 | # Use MLflow to register the model on Cloud AI Platform
29 | model_uri = "models:/mymodel/mymodelversion" # Replace with your model URI
30 | display_name = "my_mlflow_model" # Replace with the desired model name
31 |
32 | model_name = upload_mlflow_model_to_vertex_ai_models(
33 | model_uri=model_uri,
34 | display_name=display_name,
35 | )
36 |
37 | deploy_model_operation = deploy_vertex_ai_model_to_endpoint(
38 | model_name=model_name,
39 | )
40 | deployed_model = deploy_model_operation.result().deployed_model
41 | """
42 |
43 | import logging
44 | import os
45 | import re
46 | import tempfile
47 | from typing import Any, Dict, Optional
48 | import urllib
49 | import zipfile
50 |
51 | import google
52 | from google.cloud import aiplatform
53 | from mlflow.models import cli, Model
54 | from mlflow.models.model import MLMODEL_FILE_NAME
55 | from mlflow.tracking.artifact_utils import _download_artifact_from_uri
56 | from unittest import mock
57 |
58 | from . import _mlflow_models_docker_utils_patch as docker_utils_patch
59 |
60 |
61 | _logger = logging.getLogger(__name__)
62 |
63 |
def get_fixed_mlflow_source_dir():
    """Downloads and extracts the patched MLflow source tree.

    Fetches the `MLFlow-fixes` branch archive from GitHub, unpacks it into a
    fresh temporary directory, and returns the path to the source root.

    Returns:
        Path to the extracted MLflow source directory (the single
        subdirectory inside the archive, e.g. ``Ark-kun-mlflow-<sha>``).
    """
    # The module-level `import urllib` does NOT guarantee that the
    # `urllib.request` submodule is loaded; import it explicitly so this
    # function works even when no other library has imported it first.
    import urllib.request

    fixed_mlflow_archive_url = (
        "https://github.com/Ark-kun/mlflow/archive/refs/heads/MLFlow-fixes.zip"
    )
    fixed_mlflow_archive_path, _ = urllib.request.urlretrieve(
        url=fixed_mlflow_archive_url
    )
    fixed_mlflow_parent_dir = tempfile.mkdtemp(prefix="mlflow.fixed")
    with zipfile.ZipFile(fixed_mlflow_archive_path, "r") as zip_ref:
        zip_ref.extractall(fixed_mlflow_parent_dir)
    # The archive wraps everything in one subdirectory (e.g.
    # "Ark-kun-mlflow-0ec4c64"), so the real source root is one level deeper.
    subdir = os.listdir(fixed_mlflow_parent_dir)[0]
    fixed_mlflow_dir = os.path.join(fixed_mlflow_parent_dir, subdir)
    return fixed_mlflow_dir
76 |
77 |
def get_pickle_protocol(file_path: str) -> int:
    """Detects the pickle protocol version of a pickled file.

    Args:
        file_path: Path to a file containing pickled data.

    Returns:
        The pickle protocol version, or -1 if no opcode could be read at all.
        For protocol >= 2 the exact version is taken from the PROTO opcode;
        for protocols 0 and 1 the highest per-opcode protocol seen is used.
    """
    import pickletools

    max_proto = -1
    with open(file_path, "rb") as file:
        try:
            for opcode, arg, _ in pickletools.genops(file):
                if opcode.name == "PROTO":
                    # PROTO (present for protocol >= 2) carries the exact
                    # protocol version as its argument.
                    return arg
                # Looking at the opcode.proto is not reliable since unsupported
                # opcodes cannot be parsed by old python versions.
                max_proto = max(max_proto, opcode.proto)
        except Exception:  # pylint: disable=broad-except
            # Best effort by design: opcodes from newer protocols cannot be
            # parsed here; report the highest protocol seen so far instead of
            # failing. (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass
    return max_proto
92 |
93 |
def upload_mlflow_model_to_vertex_ai_models(
    model_uri: str,
    display_name: str,
    destination_image_uri: Optional[str] = None,
    model_options: Optional[Dict[str, Any]] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    encryption_spec_key_name: Optional[str] = None,
    staging_bucket: Optional[str] = None,
) -> str:
    """Builds a container for an MLflow model and registers the model with Google Cloud Vertex AI.

    For the xgboost, sklearn and tensorflow flavors the model is uploaded
    directly using the corresponding Vertex AI prebuilt-container helper.
    Any other flavor falls back to building a custom serving container image
    (with the MLflow scoring server inside) and uploading that.

    Args:
        model_uri: The URI of the MLflow model.
            Format examples:
            * `/Users/me/path/to/local/model`
            * `relative/path/to/local/model`
            * `gs://my_bucket/path/to/model`
            * `runs:/<mlflow_run_id>/run-relative/path/to/model`
            * `models:/<model_name>/<model_version>`
            * `models:/<model_name>/<stage>`
            For more information about supported URI schemes, see [Referencing
            Artifacts](https://www.mlflow.org/docs/latest/concepts.html#artifact-locations).
        display_name: The display name for the Google Cloud Vertex AI Model.
            The name can be up to 128 characters long and can be consist of any UTF-8
            characters.
        destination_image_uri: The full name of the container image that will be
            built with the provided model inside it.
            The format should be `gcr.io/<project_id>/<image_name>:<tag>`.
            Defaults to `gcr.io/<project_id>/mlflow/<sanitized_display_name>`.
            Only used for the custom-container fallback path.
        model_options: A dict of other attributes of the Google Cloud Vertex AI
            Model object, like labels and schema. See
            [Model](https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model).
            NOTE(review): this parameter is not consumed anywhere in this
            function body — confirm whether it should be forwarded to the
            upload calls.
        project: The Google Cloud project where to push the container image
            and register the model. Defaults to the location used by the gcloud CLI.
        location: The Google Cloud location where to push the container image
            and register the model. Defaults to "us-central1".
        encryption_spec_key_name:
            Optional. The Cloud KMS resource identifier of the customer
            managed encryption key used to protect the model. Has the
            form:
            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
            The key needs to be in the same region as where the compute
            resource is created.

            If set, this Model and all sub-resources of this Model will be secured
            by this key.

            Overrides encryption_spec_key_name set in aiplatform.init.
        staging_bucket: Optional. Bucket to stage local model artifacts.
            Overrides staging_bucket set in aiplatform.init.

    Returns:
        The full resource name of the Google Cloud Vertex AI Model.

    Examples::

        # Use MLflow to register the model on Google Cloud Vertex AI
        model_uri = "models:/mymodel/mymodelversion"  # Replace with your model URI
        display_name = "my_mlflow_model"  # Replace with the desired model name

        model_name = upload_mlflow_model_to_vertex_ai_models(
            model_uri=model_uri,
            display_name=display_name,
        )

        deployed_model_id = deploy_vertex_ai_model_to_endpoint(
            model_name=model_name,
        )
    """
    project = project or aiplatform.initializer.global_config.project

    # Download the model artifacts locally so the MLmodel metadata file can be
    # inspected to determine the flavor.
    temp_dir = tempfile.mkdtemp()
    model_dir = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=temp_dir,
    )
    model_metadata = Model.load(os.path.join(model_dir, MLMODEL_FILE_NAME))

    for flavor_name, flavor in model_metadata.flavors.items():
        if flavor_name == "python_function":
            # Every MLflow model also exposes the generic python_function
            # flavor; skip it and look for a more specific one first.
            continue
        if flavor_name == "xgboost":
            model_file_name = flavor["data"]
            full_xgboost_version = flavor["xgb_version"]
            model_file_path = os.path.join(model_dir, model_file_name)
            # TODO: Remove after https://b.corp.google.com/issues/216705259 is fixed
            pickle_protocol = get_pickle_protocol(model_file_path)
            # Vertex Prediction uses Python 3.7 which does not support pickle protocol 5
            if pickle_protocol == 5:
                # Fall through to the custom-container path below.
                _logger.warning("Detected model with pickle protocol version 5 > 4. Prebuilt containers do not support such models.")
                continue
            # TODO: Handle case when the version is not supported by Vertex AI
            # The prebuilt containers are versioned by "major.minor" only.
            vertex_xgboost_version = ".".join(full_xgboost_version.split(".")[0:2])
            vertex_model = aiplatform.Model.upload_xgboost_model_file(
                model_file_path=model_file_path,
                xgboost_version=vertex_xgboost_version,
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name
        if flavor_name == "sklearn":
            model_file_name = flavor["pickled_model"]
            model_file_path = os.path.join(model_dir, model_file_name)
            # TODO: Remove after https://b.corp.google.com/issues/216705259 is fixed
            pickle_protocol = get_pickle_protocol(model_file_path)
            # Vertex Prediction uses Python 3.7 which does not support pickle protocol 5
            if pickle_protocol == 5:
                # Fall through to the custom-container path below.
                _logger.warning("Detected model with pickle protocol version 5 > 4. Prebuilt containers do not support such models.")
                continue
            vertex_model = aiplatform.Model.upload_scikit_learn_model_file(
                model_file_path=model_file_path,
                # TODO: Deduce version from requirements.txt
                # sklearn_version=
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name
        if flavor_name == "tensorflow":
            model_dir_name = flavor["saved_model_dir"]
            model_dir_path = os.path.join(model_dir, model_dir_name)
            vertex_model = aiplatform.Model.upload_tensorflow_saved_model(
                saved_model_dir=model_dir_path,
                # TODO: Deduce version from requirements.txt
                # tensorflow_version=
                display_name=display_name,
                project=project,
                location=location,
                encryption_spec_key_name=encryption_spec_key_name,
                staging_bucket=staging_bucket,
            )
            return vertex_model.resource_name

    # No directly-supported flavor was found (or it was skipped above):
    # build a custom serving container instead.
    _logger.info(
        "Model flavor is not directly supported by Vertex AI. Importing model as a custom-built container"
    )

    if not destination_image_uri:
        # Sanitize the display name into a valid image name component.
        image_name = re.sub("[^-A-Za-z0-9_.]", "_", display_name).lower()
        destination_image_uri = f"gcr.io/{project}/mlflow/{image_name}"
        _logger.info(
            "Destination image URI not set. Building and uploading image to %s",
            destination_image_uri,
        )

    pushed_image_uri_with_digest = _build_serving_image(
        model_uri=model_uri,
        destination_image_uri=destination_image_uri,
        mlflow_source_dir=None,
    )

    uploaded_model = aiplatform.Model.upload(
        # artifact_uri=
        display_name=display_name,
        # description=
        serving_container_image_uri=pushed_image_uri_with_digest,
        # serving_container_command=
        # serving_container_args=
        # Routes exposed by the MLflow scoring server inside the image.
        serving_container_predict_route="/invocations",
        serving_container_health_route="/ping",
        # Setting environment variables to tell the scoring server to properly wrap
        # the responses.
        # See https://github.com/mlflow/mlflow/pull/4611
        # https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#response_requirements
        serving_container_environment_variables={
            "PREDICTIONS_WRAPPER_ATTR_NAME": "predictions",
        },
        serving_container_ports=[8080],
        project=project,
        location=location,
        # Label marks models created by this deployer (used for listing them).
        labels={
            "mlflow_model_vertex_ai_deployer": "mlflow_model_vertex_ai_deployer",
        },
        encryption_spec_key_name=encryption_spec_key_name,
        staging_bucket=staging_bucket,
    )
    return uploaded_model.resource_name
280 |
281 |
def _build_serving_image(
    model_uri: str,
    destination_image_uri: str,
    mlflow_source_dir: Optional[str] = None,
) -> str:
    """Builds and pushes an MLflow serving image for the MLflow model.

    Args:
        model_uri: The URI of the MLflow model.
        destination_image_uri: The full name of the container image that will
            be built with the provided model inside it.
            The format should be `gcr.io/<project_id>/<image_name>:<tag>`.
        mlflow_source_dir: If set, installs MLflow from this directory instead of
            PyPI.
    Returns:
        The image URI exactly as passed in `destination_image_uri`.
        NOTE(review): an earlier docstring claimed the returned URI includes
        the pushed image's hash digest, but this function returns the input
        URI unchanged — confirm whether callers need the digest-qualified
        form.
    """
    _logger.info("Building image. This can take up to 20 minutes")
    flavor_backend = cli._get_flavor_backend(
        model_uri
    )  # pylint:disable=protected-access

    # Temporarily replace MLflow's internal image builder with this package's
    # patched implementation while the image is built.
    # NOTE(review): the patch target path assumes a specific mlflow-internal
    # layout — verify it still exists in the pinned mlflow version.
    with mock.patch(
        "mlflow.models.docker_utils._build_image_from_context",
        new=docker_utils_patch._build_image_from_context
    ):
        flavor_backend.build_image(
            model_uri,
            destination_image_uri,
            install_mlflow=mlflow_source_dir is not None,
            mlflow_home=mlflow_source_dir,
        )
    return destination_image_uri
315 |
316 |
def deploy_vertex_ai_model_to_endpoint(
    model_name: str,
    endpoint_name: Optional[str] = None,
    machine_type: str = "n1-standard-2",
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    endpoint_display_name: Optional[str] = None,
    deployed_model_display_name: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
) -> google.api_core.operation.Operation:
    # pylint: disable=line-too-long
    """Deploys a Google Cloud Vertex AI Model to a Google Cloud Vertex AI Endpoint.

    Args:
        model_name: Full resource name of a Google Cloud Vertex AI Model.
        endpoint_name: Full name of a Google Cloud Vertex AI Endpoint. A new
            endpoint is created if the name is not passed.
        machine_type: The type of the machine. See the [list of machine types
            supported for
            prediction](https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#machine-types).
            Defaults to "n1-standard-2".
        min_replica_count: The minimum number of replicas the DeployedModel
            will always be deployed on. If traffic against it increases, it
            may dynamically be deployed onto more replicas up to
            max_replica_count, and as traffic decreases some of those extra
            replicas may be freed. Defaults to 1.
        max_replica_count: The maximum number of replicas this DeployedModel
            may be deployed on when the traffic against it increases. If the
            requested value is too large, the deployment errors; if deployment
            succeeds, scaling up to that many replicas is guaranteed (barring
            service outages). Traffic beyond what the maximum replicas can
            handle is dropped. Defaults to `min_replica_count`.
        endpoint_display_name: The display name of a newly created Endpoint.
            Can be up to 128 UTF-8 characters long. Defaults to the model's
            display name, truncated.
        deployed_model_display_name: The display name of the DeployedModel.
            If not provided upon creation, the Model's display_name is used.
        project: The Google Cloud project ID. Defaults to the default project.
        location: The Google Cloud region. Defaults to "us-central1".

    Returns:
        google.api_core.operation.Operation:
            An object representing a long-running operation whose result type
            is :class:`google.cloud.aiplatform_v1.types.DeployModelResponse`.
            See
            https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.DeployedModel

    Examples::

        # Use MLflow to register the model on Cloud AI Platform
        model_name = upload_mlflow_model_to_vertex_ai_models(
            model_uri="models:/mymodel/mymodelversion",
            display_name="my_mlflow_model",
        )

        deployed_model_id = deploy_vertex_ai_model_to_endpoint(
            model_name=model_name,
        )
    """
    model = aiplatform.Model(model_name)
    if endpoint_name:
        endpoint = aiplatform.Endpoint(endpoint_name=endpoint_name)
    else:
        # Model.deploy can create the Endpoint automatically, but creating it
        # explicitly lets us attach the deployer label.
        new_endpoint_display_name = endpoint_display_name or model.display_name[:127]
        endpoint = aiplatform.Endpoint.create(
            display_name=new_endpoint_display_name,
            labels={
                "mlflow_model_vertex_ai_deployer": "mlflow_model_vertex_ai_deployer",
            },
            project=project,
            location=location,
        )

    return model.deploy(
        endpoint=endpoint,
        deployed_model_display_name=deployed_model_display_name,
        traffic_percentage=100,
        machine_type=machine_type,
        min_replica_count=min_replica_count,
        max_replica_count=max_replica_count,
    )
412 |
--------------------------------------------------------------------------------
/google_cloud_mlflow/_mlflow_models_docker_utils_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The google-cloud-mlflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | import logging
16 | import subprocess
17 | import shutil
18 | import tempfile
19 | import uuid
20 |
21 |
22 | _logger = logging.getLogger(__name__)
23 |
24 |
def _build_image_from_context(context_dir: str, image_name: str):
    """Builds and pushes a container image from a Docker build context.

    Currently delegates to the Cloud Build client-library implementation.
    """
    builder = _build_image_from_context_using_cloudbuild_client
    return builder(context_dir=context_dir, image_name=image_name)
30 |
31 |
def _build_image_from_context_using_cloudbuild_gcloud(
    context_dir: str,
    image_name: str,
):
    """Builds a container image from a build context via `gcloud builds submit`.

    Streams the gcloud output, logging milestone lines at INFO and the rest
    at DEBUG.

    Args:
        context_dir: Local directory holding the Docker build context.
        image_name: Full destination image name (e.g. ``gcr.io/<project>/<image>``).

    Raises:
        RuntimeError: If the Cloud Build job exits with a non-zero status.
    """
    # Prefixes of gcloud output lines worth surfacing at INFO level.
    info_line_prefixes = (
        "Logs are available at ",
        "Step ",
        "Pushing ",
        "latest: digest: ",
    )
    # Popen as a context manager guarantees the stdout pipe is closed and the
    # process is reaped even if log streaming raises.
    with subprocess.Popen(
        args=[
            "gcloud",
            "builds",
            "submit",
            "--tag",
            image_name,
            "--timeout",
            "1800",
            context_dir,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as build_process:
        for line in iter(build_process.stdout.readline, ""):
            line = line.rstrip()
            # str.startswith accepts a tuple of prefixes.
            if line.startswith(info_line_prefixes):
                _logger.info(line)
            else:
                _logger.debug(line)

        if build_process.wait() != 0:
            raise RuntimeError("Container image build has failed.")
58 |
59 |
def _build_image_from_context_using_cloudbuild_client(
    context_dir: str, image_name: str
):
    """Builds and pushes a container image using the Cloud Build API client.

    The build context directory is archived, staged in the project's
    ``<project>_cloudbuild`` Cloud Storage bucket (the same bucket
    `gcloud builds submit` uses), and submitted as a Cloud Build job that
    builds and pushes the image.

    Args:
        context_dir: Local directory holding the Docker build context.
        image_name: Full destination image name
            (e.g. ``gcr.io/<project>/<image>``).

    Returns:
        The fully-qualified pushed image name including the digest
        (``<name>@sha256:...``).

    Raises:
        RuntimeError: If the Cloud Build job fails.
    """
    import os

    import google
    import google.auth
    from google.cloud import storage
    from google.cloud.devtools import cloudbuild
    # Import duration_pb2 explicitly instead of relying on
    # google.protobuf being reachable as an attribute side effect.
    from google.protobuf import duration_pb2

    _, project_id = google.auth.default()

    storage_client = storage.Client(project=project_id)
    build_client = cloudbuild.CloudBuildClient()

    # Staging the data in GCS
    bucket_name = project_id + "_cloudbuild"

    bucket = storage_client.lookup_bucket(bucket_name)
    # TODO: Throw error if bucket is in different project
    if bucket is None:
        bucket = storage_client.create_bucket(bucket_name)
    blob_name = f"source/{uuid.uuid4().hex}.tgz"

    # tempfile.mktemp is insecure (the returned name can be raced); create a
    # private temporary directory instead and clean it up after the upload.
    staging_dir = tempfile.mkdtemp()
    try:
        context_archive_path = shutil.make_archive(
            base_name=os.path.join(staging_dir, "context"),
            format="gztar",
            root_dir=context_dir,
        )
        bucket.blob(blob_name).upload_from_filename(context_archive_path)
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)

    build_config = cloudbuild.Build(
        source=cloudbuild.Source(
            storage_source=cloudbuild.StorageSource(
                bucket=bucket_name, object_=blob_name
            ),
        ),
        images=[image_name],
        steps=[
            cloudbuild.BuildStep(
                name="gcr.io/cloud-builders/docker",
                args=[
                    "build",
                    "--network",
                    "cloudbuild",
                    "--no-cache",
                    "-t",
                    image_name,
                    ".",
                ],
            ),
        ],
        timeout=duration_pb2.Duration(
            seconds=1800,
        ),
    )
    build_operation = build_client.create_build(
        project_id=project_id, build=build_config
    )
    _logger.info("Submitted Cloud Build job")
    _logger.debug("build_operation.metadata:")
    _logger.debug(build_operation.metadata)
    # Lazy %-formatting avoids building the message when INFO is disabled.
    _logger.info(
        "Logs are available at [%s].", build_operation.metadata.build.log_url
    )

    try:
        result = build_operation.result()
    except Exception as ex:
        _logger.error(
            "MLFlow container image build has failed."
            " See Google Cloud Build logs here: %s",
            build_operation.metadata.build.log_url,
        )
        # RuntimeError is more specific than Exception and is still caught by
        # any caller handling Exception.
        raise RuntimeError("MLFlow container image build has failed.") from ex
    _logger.debug("operation.result")
    _logger.debug(result)

    # Compose the digest-qualified image name from the build results.
    built_image = result.results.images[0]
    image_base_name = built_image.name.split(":")[0]
    image_digest = built_image.digest
    image_name_with_digest = image_base_name + "@" + image_digest
    return image_name_with_digest
140 |
--------------------------------------------------------------------------------
/google_cloud_mlflow/mlflow_model_deployment_plugin_for_google_cloud_vertex_ai.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The google-cloud-mlflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | """MLflow deployment plugin to deploy MLflow models to Google Cloud Vertex AI."""
16 |
17 | __all__ = [
18 | "GoogleCloudVertexAiDeploymentClient",
19 | "target_help",
20 | "run_local",
21 | ]
22 |
23 | import json
24 | import logging
25 | from typing import Any, Dict, List, Optional, Union
26 |
27 | from . import _mlflow_model_gcp_deployment_utils as vertex_utils
28 | from google.protobuf import json_format
29 | import mlflow
30 | from mlflow import deployments
31 | import numpy
32 | import pandas
33 |
34 | # TODO(b/195784726) Remove this workaround once google-cloud-aiplatform conforms
35 | # to the third_party python rules
36 | try:
37 | from google.cloud import aiplatform # pylint:disable=g-import-not-at-top
38 | except ImportError:
39 | from google.cloud.aiplatform import aiplatform # pylint:disable=g-import-not-at-top
40 |
41 |
42 | _logger = logging.getLogger(__name__)
43 |
44 | DEFAULT_MACHINE_TYPE="n1-standard-2"
45 |
46 |
def _resource_to_mlflow_dict(
    resource: aiplatform.base.VertexAiResourceNoun,
) -> Dict[str, Any]:
    """Converts a Vertex AI resource instance to an MLflow deployment dict."""
    # TODO(avolkov): Switch to .to_dict() method when my PR is merged:
    # https://github.com/googleapis/python-aiplatform/pull/588
    result = json_format.MessageToDict(resource._gca_resource._pb)  # pylint: disable=protected-access
    # The MLflow documentation seems to imply that the returned dicts need a
    # "name" attribute holding the MLflow deployment name (which this plugin
    # stores in the Endpoint's display_name), not the internal resource name —
    # so the full resource path is exposed under a separate key.
    result["resource_name"] = resource.resource_name
    result["name"] = resource.display_name
    return result
61 |
62 |
63 | def _data_to_list_of_instances(
64 | data: Union[list, numpy.ndarray, pandas.DataFrame]
65 | ) -> List:
66 | if isinstance(data, pandas.DataFrame):
67 | return data.values.tolist()
68 | if isinstance(data, numpy.ndarray):
69 | return data.tolist()
70 | if isinstance(data, list):
71 | return data
72 | raise TypeError(f"Unsupported data type {type(data)}")
73 |
74 |
class GoogleCloudVertexAiDeploymentClient(deployments.BaseDeploymentClient):
    """The Google Cloud Vertex AI implementation of the BaseDeploymentClient."""

    def create_deployment(
        self,
        name: str,
        model_uri: str,
        flavor: Optional[str] = None,
        config: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Deploys the model.

        Deploys a model to the specified target. By default, this method should
        block until deployment completes (i.e. until it's possible to perform
        inference with the deployment). In the case of conflicts (e.g. if it's
        not possible to create the specified deployment without due to conflict
        with an existing deployment), raises a
        :py:class:`mlflow.exceptions.MlflowException`. See target-specific
        plugin documentation for additional detail on support for asynchronous
        deployment and other configuration.

        Example::

            from mlflow import deployments
            client = deployments.get_deploy_client("google_cloud")
            deployment = client.create_deployment(
                name="deployment name",
                model_uri=...,
                # Config is optional
                config=dict(
                    # Deployed model config
                    machine_type="n1-standard-2",
                    min_replica_count=None,
                    max_replica_count=None,
                    accelerator_type=None,
                    accelerator_count=None,
                    service_account=None,
                    explanation_metadata=None, # JSON string
                    explanation_parameters=None, # JSON string

                    # Model container image building config
                    destination_image_uri=None,

                    # Endpoint config
                    endpoint_description=None,
                    endpoint_deploy_timeout=None,

                    # Vertex AI config
                    project=None,
                    location=None,
                    encryption_spec_key_name=None,
                    staging_bucket=None,
                )
            )

        Args:
            name: Unique name to use for deployment. If another deployment
                exists with the same name, raises a
                :py:class:`mlflow.exceptions.MlflowException`
            model_uri: URI of model to deploy
            flavor: (optional) The MLflow model flavor to deploy.
                If unspecified, the default flavor will be chosen.
            config: (optional) Dict containing updated target-specific configuration for the
                deployment

        Returns:
            A dict corresponding to the created deployment, which must contain the 'name' key.
        """
        config = config or {}

        project = config.get("project")
        location = config.get("location")
        encryption_spec_key_name = config.get("encryption_spec_key_name")
        staging_bucket = config.get("staging_bucket")

        # MLflow deployment names are stored as Endpoint display names, so an
        # existing endpoint with this display name is a name conflict.
        # NOTE(review): list-then-create is not atomic — a concurrent
        # create_deployment with the same name could race; confirm acceptable.
        existing_endpoints = aiplatform.Endpoint.list(
            filter=f'display_name="{name}"',
            project=project,
            location=location,
        )
        if existing_endpoints:
            raise mlflow.exceptions.MlflowException(
                f"Found existing deployment with name {name}: " +
                ",".join(list(endpoint.resource_name for endpoint in existing_endpoints)))

        # Builds a serving container image for the model and registers it as a
        # Vertex AI Model resource.
        model_name = vertex_utils.upload_mlflow_model_to_vertex_ai_models(
            model_uri=model_uri,
            display_name=name,
            destination_image_uri=config.get("destination_image_uri"),
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
            staging_bucket=staging_bucket,
        )
        # The plugin-version label is what list_deployments filters on.
        endpoint = aiplatform.Endpoint.create(
            display_name=name,
            description=config.get("endpoint_description"),
            labels={
                "google_cloud_mlflow_plugin_version": "0-0-1",
            },
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
        )
        endpoint.deploy(
            model=aiplatform.Model(model_name=model_name),
            deployed_model_display_name=name,
            traffic_percentage=100,
            # Need to always specify the machine type to prevent this error:
            # InvalidArgument: 400 'automatic_resources' is not supported for Model
            # The choice between "dedicated_resources" and "automatic_resources"
            # (only supported with AutoML models) is based on the presence of
            # machine_type.
            machine_type=config.get("machine_type", DEFAULT_MACHINE_TYPE),
            min_replica_count=int(config.get("min_replica_count", 1)),
            max_replica_count=int(config.get("max_replica_count", 1)),
            accelerator_type=config.get("accelerator_type"),
            # accelerator_count=0 is normalized to None so that no accelerator
            # resources are requested at all.
            accelerator_count=int(config.get("accelerator_count", 0)) or None,
            service_account=config.get("service_account"),
            # Config values may arrive as strings (e.g. from the CLI), so
            # explanation structures are parsed from JSON here.
            explanation_metadata=(json.loads(config.get("explanation_metadata")) if "explanation_metadata" in config else None),
            explanation_parameters=(json.loads(config.get("explanation_parameters")) if "explanation_parameters" in config else None),
            # json.loads turns the strings "true"/"false" into booleans.
            sync=json.loads(config.get("sync", "true")),
            deploy_request_timeout=int(config.get("endpoint_deploy_timeout", 1800)),
        )

        deployment_dict = _resource_to_mlflow_dict(endpoint)
        deployment_dict["flavor"] = flavor
        return deployment_dict

    def get_deployment(self, name: str) -> Dict[str, Any]:
        """Gets deployment by name.

        Args:
            name: ID of deployment to fetch

        Returns:
            A dictionary describing the specified deployment, throwing a
            py:class:`mlflow.exception.MlflowException` if no deployment exists with the provided ID.
            The dict is guaranteed to contain an 'name' key containing the deployment name.
            The other fields of the returned dictionary and their types may vary across
            deployment targets.
        """
        return _resource_to_mlflow_dict(
            self._get_deployment(deployment_name=name)
        )

    def _get_deployment(self, deployment_name: str) -> aiplatform.Endpoint:
        # Returns the single Endpoint whose display name matches
        # `deployment_name`. Vertex AI does not enforce display-name
        # uniqueness, hence the explicit multiple-match error.
        endpoints = aiplatform.Endpoint.list(filter=f'display_name="{deployment_name}"')
        if len(endpoints) > 1:
            raise mlflow.exceptions.MlflowException(
                f"Found multiple deployments with name {deployment_name}: " +
                ",".join(list(endpoint.resource_name for endpoint in endpoints)))
        if endpoints:
            return endpoints[0]
        raise mlflow.exceptions.MlflowException(
            f"Could not find deployment with name {deployment_name}."
        )

    def list_deployments(self) -> List[Dict[str, Any]]:
        """Lists all deployments.

        Returns:
            A list of dicts corresponding to deployments. Each dict is guaranteed to
            contain a 'name' key containing the deployment name. The other fields of
            the returned dictionary and their types may vary across deployment targets.
        """
        # Only endpoints created by this plugin carry this label (set in
        # create_deployment), so unrelated endpoints are excluded.
        endpoints = aiplatform.Endpoint.list(filter='labels.google_cloud_mlflow_plugin_version:*')
        endpoint_dicts = list(map(_resource_to_mlflow_dict, endpoints))
        return endpoint_dicts

    def delete_deployment(self, name: str) -> None:
        """Deletes the deployment.

        Deletes the deployment with name ``name`` from the specified target.
        Deletion is idempotent (i.e. deletion does not fail if retried on a
        non-existent deployment).

        Args:
            name: The name of deployment to delete
        """
        deployment_name = name
        # Locate deployment endpoint from MLflow deployment list
        # Using Endpoint.delete with force=True will undeploy all models on Endpoint
        # before deleting the Endpoint
        endpoints = aiplatform.Endpoint.list(filter=f'display_name="{deployment_name}"')
        if len(endpoints) > 1:
            raise mlflow.exceptions.MlflowException(
                f"Found multiple deployments with name {deployment_name}: " +
                ",".join(list(endpoint.resource_name for endpoint in endpoints)))
        # Zero matches is not an error: deletion is idempotent.
        if endpoints:
            endpoint = endpoints[0]
            endpoint.delete(force=True)

    def update_deployment(
        self,
        name: str,
        model_uri: Optional[str] = None,
        flavor: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Updates the deployment with the specified name.

        You can update the URI of the model, the flavor of the deployed model
        (in which case the model URI must also be specified), and/or any
        target-specific attributes of the deployment (via `config`). By default,
        this method should block until deployment completes (i.e. until it's
        possible to perform inference with the updated deployment).
        See target-specific plugin documentation for additional detail on support
        for asynchronous deployment and other configuration.

        Args:
            name: Unique name of deployment to update
            model_uri: URI of a new model to deploy.
            flavor: (optional) new model flavor to use for deployment. If provided,
                ``model_uri`` must also be specified. If ``flavor`` is unspecified but
                ``model_uri`` is specified, a default flavor will be chosen and the
                deployment will be updated using that flavor.
            config: (optional) dict containing updated target-specific configuration for the
                deployment

        Returns:
            A dict corresponding to the created deployment, which must contain the 'name' key.
        """
        # NOTE(review): implemented as delete followed by re-create, so the
        # deployment is briefly unavailable and the update is not atomic.
        self.delete_deployment(name=name)
        return self.create_deployment(
            name=name,
            model_uri=model_uri,
            flavor=flavor,
            config=config,
        )

    def predict(
        self,
        deployment_name: str,
        df: Union[List, numpy.ndarray, pandas.DataFrame],
    ) -> Union[pandas.DataFrame, pandas.Series, numpy.ndarray, Dict[str, Any]]:
        """Computes model predictions.

        Compute predictions on the pandas DataFrame ``df`` using the specified
        deployment. Note that the input/output types of this method matches that
        of `mlflow pyfunc predict` (we accept a pandas DataFrame as input and
        return either a pandas DataFrame, pandas Series, or numpy array as output).

        Args:
            deployment_name: Name of deployment to predict against
            df: Pandas DataFrame to use for inference

        Returns:
            A pandas DataFrame, pandas Series, or numpy array

        Example::

            from mlflow import deployments
            import pandas
            df = pandas.DataFrame(
                [
                    {"a": 1,"b": 2,"c": 3},
                    {"a": 4,"b": 5,"c": 6}
                ]
            )
            client = deployments.get_deploy_client("google_cloud")
            client.create_deployment("deployment name", model_uri=...)
            predictions = client.predict("deployment name", df)
        """
        endpoint = self._get_deployment(deployment_name=deployment_name)
        # Input is normalized to a plain list of instances before the call.
        predictions = endpoint.predict(
            instances=_data_to_list_of_instances(df),
        )
        return predictions

    def explain(
        self,
        deployment_name: str,
        df: Union[List, numpy.ndarray, pandas.DataFrame],
    ) -> Union[pandas.DataFrame, pandas.Series, numpy.ndarray, Dict[str, Any]]:
        """Generates explanations of model predictions.

        Generate explanations of model predictions on the specified input pandas Dataframe
        ``df`` for the deployed model. Explanation output formats vary by deployment target,
        and can include details like feature importance for understanding/debugging predictions.

        Args:
            deployment_name: Name of deployment to predict against
            df: Pandas DataFrame to use for inference

        Returns:
            A JSON-able object (pandas dataframe, numpy array, dictionary), or
            an exception if the implementation is not available in deployment target's class

        Example::

            from mlflow import deployments
            import pandas
            df = pandas.DataFrame(
                [
                    {"a": 1,"b": 2,"c": 3},
                    {"a": 4,"b": 5,"c": 6}
                ]
            )
            client = deployments.get_deploy_client("google_cloud")
            client.create_deployment("deployment name", model_uri=...)
            predictions = client.explain("deployment name", df)
        """
        endpoint = self._get_deployment(deployment_name=deployment_name)
        # Input is normalized to a plain list of instances before the call.
        predictions = endpoint.explain(
            instances=_data_to_list_of_instances(df),
        )
        return predictions
383 |
384 |
def run_local(
    name: str,
    model_uri: str,
    flavor: Optional[str] = None,
    config: Optional[Dict[str, str]] = None,
):
    """Deploys the specified model locally, for testing.

    Local deployment is not supported by this plugin yet.

    Args:
        name: Unique name to use for deployment. If another deployment exists
            with the same name, create_deployment will raise a
            :py:class:`mlflow.exceptions.MlflowException`
        model_uri: URI of model to deploy
        flavor: Model flavor to deploy. If unspecified, default flavor is chosen.
        config: Dict containing updated target-specific config for the deployment

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
402 |
403 |
def target_help():
    """Returns the detailed help text for this deployment target.

    The string is displayed when users invoke the
    ``mlflow deployments help -t `` CLI command.
    """
    help_text = """
MLflow deployment plugin to deploy MLflow models to Google Cloud Vertex AI.

Example::

    from mlflow import deployments
    client = deployments.get_deploy_client("google_cloud")
    deployment = client.create_deployment(
        name="deployment name",
        model_uri=...,
        # Config is optional
        config=dict(
            # Deployed model config
            machine_type="n1-standard-2",
            min_replica_count=None,
            max_replica_count=None,
            accelerator_type=None,
            accelerator_count=None,
            service_account=None,
            explanation_metadata=None, # JSON string
            explanation_parameters=None, # JSON string

            # Model container image building config
            destination_image_uri=None,

            # Model deployment config
            sync="true",

            # Endpoint config
            description=None,

            # Vertex AI config
            project=None,
            location=None,
            experiment=None,
            experiment_description=None,
            staging_bucket=None,
        )
    )
"""
    return help_text
451 |
--------------------------------------------------------------------------------
/google_cloud_mlflow/mlflow_model_registry_plugin_for_google_cloud_storage.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The google-cloud-mlflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | """MLflow Model Registry plugin to allow logging models to Google Cloud Storage."""
16 |
17 | import datetime
18 | from typing import Iterator, List, Optional
19 |
20 | import google
21 | from google.cloud import storage
22 | from google.protobuf import json_format
23 | import mlflow
24 | from mlflow.entities import model_registry
25 | from mlflow.entities.model_registry import model_version_stages
26 | from mlflow.protos import databricks_pb2
27 | from mlflow.protos import model_registry_pb2
28 | from mlflow.store.entities import paged_list
29 | from mlflow.utils.search_utils import SearchUtils
30 | from mlflow.utils.validation import (
31 | _validate_registered_model_tag,
32 | _validate_model_version_tag,
33 | _validate_model_name,
34 | _validate_model_version,
35 | _validate_tag_name,
36 | )
37 |
38 |
39 | class GoogleCloudStorageModelRegistry(
40 | mlflow.store.model_registry.abstract_store.AbstractStore
41 | ):
42 | """Class for storing Model Registry metadata."""
43 |
44 | _MODEL_INFO_FILE_NAME: str = "model_info.json"
45 | _MODEL_VERSION_INFO_FILE_NAME: str = "model_version_info.json"
46 | _LAST_MODEL_VERSION_FILE_NAME: str = "last_model_version"
47 |
48 | DELETE_MODEL_VERSIONS_INSTEAD_OF_MARKING_AS_DELETED = False
49 |
50 | def __init__(self, store_uri: str):
51 | if not _validate_store_uri(store_uri):
52 | raise mlflow.exceptions.MlflowException(
53 | f"Bad store_uri format: {store_uri}"
54 | )
55 | store_uri = store_uri.rstrip("/") + "/"
56 | self._base_uri = store_uri
57 |
58 | # CRUD API for RegisteredModel objects
59 |
60 | def _get_model_dir(self, name: str) -> str:
61 | return self._base_uri + name + "/"
62 |
63 | def _get_model_info_file_path(self, name: str) -> str:
64 | return self._get_model_dir(name=name) + self._MODEL_INFO_FILE_NAME
65 |
66 | def _get_model_version_dir(self, name: str, version: str) -> str:
67 | return self._get_model_dir(name=name) + version + "/"
68 |
69 | def _get_model_version_info_file_path(self, name: str, version: str) -> str:
70 | return (
71 | self._get_model_version_dir(name=name, version=version)
72 | + self._MODEL_VERSION_INFO_FILE_NAME
73 | )
74 |
75 | def create_registered_model(
76 | self,
77 | name: str,
78 | tags: List[model_registry.RegisteredModelTag] = None,
79 | description: str = None,
80 | ) -> model_registry.RegisteredModel:
81 | """Creates a new registered model in backend store.
82 |
83 | Args:
84 | name: Name of the new model. This is expected to be unique in the backend store.
85 | tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
86 | instances associated with this registered model.
87 | description: Description of the model.
88 |
89 | Returns:
90 | A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
91 | created in the backend.
92 | """
93 | # TODO(avolkov): Validate that the model does not exist.
94 | _validate_model_name(name)
95 | for tag in tags or []:
96 | _validate_registered_model_tag(tag.key, tag.value)
97 | current_time = datetime.datetime.utcnow()
98 | model = model_registry.RegisteredModel(
99 | name=name,
100 | creation_timestamp=int(current_time.timestamp()),
101 | last_updated_timestamp=int(current_time.timestamp()),
102 | description=description,
103 | tags=tags,
104 | )
105 | self._set_registered_model_proto(
106 | name, model.to_proto(), update_modification_time=False
107 | )
108 | return model
109 |
110 | def update_registered_model(
111 | self,
112 | name: str,
113 | description: str,
114 | ) -> model_registry.RegisteredModel:
115 | """Updates description of the registered model.
116 |
117 | Args:
118 | name: Registered model name.
119 | description: New description.
120 |
121 | Returns:
122 | A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
123 | """
124 | model = self.get_registered_model(name=name)
125 | model_proto = model.to_proto()
126 | model_proto.description = description
127 | self._set_registered_model_proto(name, model_proto)
128 | return model
129 |
    def rename_registered_model(
        self,
        name: str,
        new_name: str,
    ) -> model_registry.RegisteredModel:
        """Renames the registered model.

        Rewrites/copies the model's metadata files under the new model
        directory first, and only deletes the originals after all new files
        have been written.

        Args:
            name: Registered model name.
            new_name: New proposed name.

        Returns:
            A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        _validate_model_name(name)
        _validate_model_name(new_name)
        model = self.get_registered_model(name)
        if new_name == name:
            # Renaming to the same name is a no-op.
            return model
        model_proto = model.to_proto()
        model_proto.name = new_name
        src_dir = self._get_model_dir(name=name)
        dst_dir = self._get_model_dir(name=new_name)
        src_dir_blob = storage.Blob.from_string(uri=src_dir)
        dst_dir_blob = storage.Blob.from_string(uri=dst_dir)
        bucket = src_dir_blob.bucket
        src_path = src_dir_blob.name
        dst_path = dst_dir_blob.name
        blobs: List[storage.Blob] = list(
            storage.Client().list_blobs(
                # Using bucket name as a workaround for
                # https://github.com/googleapis/python-storage/issues/540
                bucket_or_name=bucket.name,
                prefix=src_path,
            )
        )
        blobs_to_delete = []
        for blob in blobs:
            assert blob.name.startswith(src_path)
            new_blob_name = blob.name.replace(src_path, dst_path, 1)
            if blob.name.endswith("/" + self._MODEL_VERSION_INFO_FILE_NAME):
                # Model version info files embed the model name, so they are
                # re-serialized with the new name at the new location rather
                # than copied verbatim.
                blobs_to_delete.append(blob)
                model_version_proto = _json_to_registered_model_version(
                    blob.download_as_text()
                ).to_proto()
                model_version_proto.name = new_name
                self._set_model_version_proto(
                    model_version_proto.name,
                    model_version_proto.version,
                    model_version_proto,
                    update_modification_time=False,
                )
            elif blob.name.endswith("/" + self._LAST_MODEL_VERSION_FILE_NAME):
                # The version-counter file does not contain the model name,
                # so a server-side copy is enough.
                blobs_to_delete.append(blob)
                blob.bucket.copy_blob(
                    blob=blob, destination_bucket=blob.bucket, new_name=new_blob_name
                )
            elif blob.name.endswith("/" + self._MODEL_INFO_FILE_NAME):
                # The model info file is rewritten under the new name below
                # via _set_registered_model_proto; just mark the old one.
                blobs_to_delete.append(blob)

        # _set_registered_model_proto also updates
        # model_proto.last_updated_timestamp
        self._set_registered_model_proto(new_name, model_proto)
        # After we copy and update all model files, we delete the original blobs.
        for blob in blobs_to_delete:
            blob.delete()
        # The blob is not available right after writing.
        # So we cannot use get_registered_model(new_name)
        return model_registry.RegisteredModel.from_proto(model_proto)
199 |
200 | def delete_registered_model(self, name: str) -> None:
201 | """Deletes the registered model.
202 |
203 | Backend raises exception if a registered model with given name does not exist.
204 |
205 | Args:
206 | name: Registered model name.
207 |
208 | Returns:
209 | None
210 | """
211 | _validate_model_name(name)
212 | src_dir = self._get_model_dir(name=name)
213 | src_dir_blob = storage.Blob.from_string(uri=src_dir)
214 | bucket = src_dir_blob.bucket
215 | src_path = src_dir_blob.name
216 | blobs: Iterator[storage.Blob] = storage.Client().list_blobs(
217 | # Using bucket name as a workaround for
218 | # https://github.com/googleapis/python-storage/issues/540
219 | bucket_or_name=bucket.name,
220 | prefix=src_path,
221 | )
222 | for blob in blobs:
223 | blob.delete()
224 |
225 | def list_registered_models(
226 | self,
227 | max_results: Optional[int],
228 | page_token: Optional[str] = None,
229 | ) -> paged_list.PagedList[model_registry.RegisteredModel]:
230 | """Lists all registered models.
231 |
232 | Args:
233 | max_results: Maximum number of registered models desired.
234 | page_token: Token specifying the next page of results. It should be obtained from
235 | a ``list_registered_models`` call.
236 |
237 | Returns:
238 | A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
239 | The pagination token for the next page can be obtained via the ``token`` attribute
240 | of the object.
241 | """
242 | root_dir = storage.Blob.from_string(uri=self._base_uri)
243 | blob_iterator = storage.Client().list_blobs(
244 | # Using bucket name as a workaround for
245 | # https://github.com/googleapis/python-storage/issues/540
246 | bucket_or_name=root_dir.bucket.name,
247 | prefix=root_dir.name,
248 | max_results=max_results,
249 | page_token=page_token,
250 | )
251 | models = [
252 | _json_to_registered_model(blob.download_as_text())
253 | for blob in blob_iterator
254 | if blob.name.endswith(self._MODEL_INFO_FILE_NAME)
255 | ]
256 | return paged_list.PagedList(items=models, token=blob_iterator.next_page_token)
257 |
258 | def search_registered_models(
259 | self,
260 | filter_string: str,
261 | max_results: Optional[int] = None,
262 | order_by: Optional[str] = None,
263 | page_token: Optional[str] = None,
264 | ) -> paged_list.PagedList[model_registry.RegisteredModel]:
265 | """Searches for registered models that satisfy the filter criteria.
266 |
267 | Args:
268 | filter_string: Filter query string, defaults to searching all registered models.
269 | max_results: Maximum number of registered models desired. (Default value = None)
270 | order_by: List of column names with ASC|DESC annotation, to be used for ordering
271 | matching search results. (Default value = None)
272 | page_token: Token specifying the next page of results. It should be obtained from
273 | a ``search_registered_models`` call. (Default value = None)
274 |
275 | Returns:
276 | A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
277 | that satisfy the search expressions. The pagination token for the next page can be
278 | obtained via the ``token`` attribute of the object.
279 | """
280 | del page_token
281 | parsed_filters = SearchUtils.parse_filter_for_registered_models(filter_string)
282 | if order_by:
283 | (
284 | ordering_key,
285 | ordering_is_ascending,
286 | ) = SearchUtils.parse_order_by_for_search_registered_models(order_by)
287 | models = self._list_models()
288 | for parsed_filter in parsed_filters:
289 | if parsed_filter["comparator"] != "=":
290 | raise mlflow.exceptions.MlflowException(
291 | "Model Registry search filter only supports equality(=) "
292 | f"comparator. Input filter string: {filter_string}",
293 | error_code=databricks_pb2.INVALID_PARAMETER_VALUE,
294 | )
295 | # Key validated by `parse_filter_for_models`
296 | key = parsed_filter["key"]
297 | value = parsed_filter["value"]
298 | models = [model for model in models if getattr(model, key, None) == value]
299 | if order_by:
300 | models.sort(
301 | key=lambda x: getattr(x, ordering_key, None),
302 | reversed=not ordering_is_ascending,
303 | )
304 | if max_results:
305 | models = models[0:max_results]
306 | return models
307 |
308 | def get_registered_model(self, name: str) -> model_registry.RegisteredModel:
309 | """Gets registered model instance by name.
310 |
311 | Args:
312 | name: Registered model name.
313 |
314 | Returns:
315 | A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
316 | """
317 | _validate_model_name(name)
318 | model_uri = self._get_model_info_file_path(name=name)
319 | try:
320 | model_json = storage.Blob.from_string(
321 | uri=model_uri,
322 | # Workaround for https://github.com/googleapis/python-storage/issues/540
323 | client=storage.Client(),
324 | ).download_as_text()
325 | model = _json_to_registered_model(model_json)
326 | return model
327 | except google.api_core.exceptions.NotFound:
328 | raise mlflow.exceptions.MlflowException(
329 | message=f'Model "{name}" does not exist',
330 | error_code=databricks_pb2.RESOURCE_DOES_NOT_EXIST,
331 | )
332 |
333 | def _set_registered_model_proto(
334 | self,
335 | name: str,
336 | model_proto: model_registry_pb2.RegisteredModel,
337 | update_modification_time: bool = True,
338 | ) -> None:
339 | """Sets registered model instance.
340 |
341 | Args:
342 | name: Registered model name.
343 | model_proto: A single
344 | :py:class:`mlflow.entities.model_registry_pb2.RegisteredModel` object.
345 | update_modification_time: Whether to update the modification time
346 | """
347 | _validate_model_name(name)
348 | if update_modification_time:
349 | current_time = datetime.datetime.utcnow()
350 | model_proto.last_updated_timestamp = int(current_time.timestamp())
351 | model_json = json_format.MessageToJson(model_proto)
352 | model_uri = self._get_model_info_file_path(name=name)
353 | storage.Blob.from_string(
354 | uri=model_uri,
355 | # Workaround for https://github.com/googleapis/python-storage/issues/540
356 | client=storage.Client(),
357 | ).upload_from_string(data=model_json)
358 |
359 | def _list_models(self) -> List[model_registry.RegisteredModel]:
360 | """Lists all registered models.
361 |
362 | Returns:
363 | A list of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
364 | """
365 | base_dir_blob = storage.Blob.from_string(uri=self._base_uri)
366 | blob_iterator = storage.Client().list_blobs(
367 | # Using bucket name as a workaround for
368 | # https://github.com/googleapis/python-storage/issues/540
369 | bucket_or_name=base_dir_blob.bucket.name,
370 | prefix=base_dir_blob.name,
371 | )
372 | models = [
373 | _json_to_registered_model(blob.download_as_text())
374 | for blob in blob_iterator
375 | if blob.name.endswith(self._MODEL_INFO_FILE_NAME)
376 | ]
377 | return models
378 |
379 | def get_latest_versions(
380 | self, name: str, stages: List[str] = None
381 | ) -> List[model_registry.ModelVersion]:
382 | """Gets the latest model version for each requested stage.
383 |
384 | If no ``stages`` argument is provided, returns the latest version for each stage.
385 |
386 | Args:
387 | name: Registered model name.
388 | stages: List of desired stages. If input list is None, return latest versions for
389 | for 'Staging' and 'Production' stages. (Default value = None)
390 |
391 | Returns:
392 | List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
393 | """
394 | _validate_model_name(name)
395 | stages = stages or model_version_stages.DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS
396 | stages = [model_version_stages.get_canonical_stage(stage) for stage in stages]
397 | latest_versions = {}
398 | model_versions = self._list_model_versions(name=name)
399 | for model_version in model_versions:
400 | stage = model_version.current_stage
401 | if stage not in stages:
402 | continue
403 | if stage not in latest_versions or int(
404 | latest_versions[stage].version
405 | ) < int(model_version.version):
406 | latest_versions[stage] = model_version
407 | latest_versions_list = [latest_versions.get(stage) or None for stage in stages]
408 | return latest_versions_list
409 |
410 | def set_registered_model_tag(
411 | self,
412 | name: str,
413 | tag: model_registry.RegisteredModelTag,
414 | ) -> None:
415 | """Sets a tag for the registered model.
416 |
417 | Args:
418 | name: Registered model name.
419 | tag: py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.
420 |
421 | Returns:
422 | None
423 | """
424 | _validate_registered_model_tag(key=tag.key, value=tag.value)
425 | model = self.get_registered_model(name=name)
426 | model_proto = model.to_proto()
427 | # model.tags is not a proper map.
428 | # Deleting possible previous tag values first.
429 | for existing_tag in model_proto.tags:
430 | if existing_tag.key == tag.key:
431 | model_proto.tags.remove(existing_tag)
432 | model_proto.tags.append(tag)
433 | self._set_registered_model_proto(name, model_proto)
434 |
435 | def delete_registered_model_tag(self, name: str, key: str) -> None:
436 | """Delete a tag associated with the registered model.
437 |
438 | Args:
439 | name: Registered model name.
440 | key: Registered model tag key.
441 |
442 | Returns:
443 | None
444 | """
445 | _validate_tag_name(name=key)
446 | model = self.get_registered_model(name=name)
447 | model_proto = model.to_proto()
448 | for tag in model_proto.tags:
449 | if tag.key == key:
450 | model_proto.tags.remove(tag)
451 | self._set_registered_model_proto(name, model_proto)
452 |
453 | # CRUD API for ModelVersion objects
454 |
    def _increment_last_model_version(self, name: str) -> str:
        """Increments and returns the last model version for a given model.

        The counter is persisted in a dedicated file inside the model directory.

        Args:
            name: Registered model name.

        Returns:
            The version number for the next model version, as a string
            (MLflow model version identifiers are strings).
        """
        _validate_model_name(name)
        model_dir_uri = self._get_model_dir(name=name)
        last_model_version_file_uri = model_dir_uri + self._LAST_MODEL_VERSION_FILE_NAME
        last_model_version_file_blob = storage.Blob.from_string(
            uri=last_model_version_file_uri,
            # Workaround for https://github.com/googleapis/python-storage/issues/540
            client=storage.Client(),
        )
        last_model_version = 0
        try:
            last_model_version = int(last_model_version_file_blob.download_as_text())
        except google.api_core.exceptions.NotFound:
            # First version of this model: the counter file does not exist yet.
            pass
        last_model_version += 1
        last_model_version_file_blob.upload_from_string(
            data=str(last_model_version),
            # Avoiding race condition: the write fails if another writer
            # changed the counter file since we read it.
            if_generation_match=last_model_version_file_blob.generation,
        )
        return str(last_model_version)
484 |
485 | def create_model_version(
486 | self,
487 | name: str,
488 | source: str,
489 | run_id: str = None,
490 | tags: List[model_registry.ModelVersionTag] = None,
491 | run_link: str = None,
492 | description: str = None,
493 | ) -> model_registry.ModelVersion:
494 | """Creates a new model version from given source and run ID.
495 |
496 | Args:
497 | name: Registered model name.
498 | source: Source path where the MLflow model is stored.
499 | run_id: Run ID from MLflow tracking server that generated the model. (Default value = None)
500 | tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
501 | instances associated with this model version. (Default value = None)
502 | run_link: Link to the run from an MLflow tracking server that generated this model. (Default value = None)
503 | description: Description of the version. (Default value = None)
504 |
505 | Returns:
506 | A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
507 | created in the backend.
508 | """
509 | _validate_model_name(name)
510 | version = self._increment_last_model_version(name=name)
511 | current_time = datetime.datetime.utcnow()
512 | model_version = model_registry.ModelVersion(
513 | name=name,
514 | version=version,
515 | creation_timestamp=int(current_time.timestamp()),
516 | last_updated_timestamp=int(current_time.timestamp()),
517 | description=description,
518 | source=source,
519 | run_id=run_id,
520 | tags=tags,
521 | run_link=run_link,
522 | )
523 | self._set_model_version_proto(
524 | name=name, version=version, model_version_proto=model_version.to_proto()
525 | )
526 | return model_version
527 |
528 | def update_model_version(
529 | self,
530 | name: str,
531 | version: str,
532 | description: str,
533 | ) -> model_registry.ModelVersion:
534 | """Updates metadata associated with a model version in backend.
535 |
536 | Args:
537 | name: Registered model name.
538 | version: Registered model version.
539 | description: New model description.
540 |
541 | Returns:
542 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
543 | """
544 | model_version = self.get_model_version(name=name, version=version)
545 | model_version_proto = model_version.to_proto()
546 | model_version_proto.description = description
547 | self._set_model_version_proto(
548 | name=name, version=version, model_version_proto=model_version_proto
549 | )
550 | return model_version
551 |
552 | def transition_model_version_stage(
553 | self,
554 | name: str,
555 | version: str,
556 | stage: str,
557 | archive_existing_versions: bool = False,
558 | ) -> model_registry.ModelVersion:
559 | """Updates model version stage.
560 |
561 | Args:
562 | name: Registered model name.
563 | version: Registered model version.
564 | stage: New desired stage for this model version.
565 | archive_existing_versions: If this flag is set to ``True``, all existing model
566 | versions in the stage will be automically moved to the "archived" stage. Only valid
567 | when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will be raised.
568 |
569 | Returns:
570 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
571 | """
572 | model_version = self.get_model_version(name=name, version=version)
573 | stage = model_version_stages.get_canonical_stage(stage)
574 | model_version_proto = model_version.to_proto()
575 | model_version_proto.current_stage = stage
576 | self._set_model_version_proto(name, version, model_version_proto)
577 | if archive_existing_versions:
578 | all_model_versions = self._list_model_versions(name=name)
579 | all_other_model_versions = list(
580 | filter(
581 | lambda model_version: model_version.name == name, all_model_versions
582 | )
583 | )
584 | all_other_model_versions_in_same_stage = list(
585 | filter(
586 | lambda model_version: model_version.current_stage == stage,
587 | all_other_model_versions,
588 | )
589 | )
590 | for other_model_version in all_other_model_versions_in_same_stage:
591 | other_model_version_proto = other_model_version.to_proto()
592 | other_model_version_proto.current_stage = (
593 | model_version_stages.STAGE_ARCHIVED
594 | )
595 | self._set_model_version_proto(
596 | name=other_model_version.name,
597 | version=other_model_version.version,
598 | model_version_proto=other_model_version_proto,
599 | )
600 | return model_version
601 |
602 | def delete_model_version(self, name: str, version: str) -> None:
603 | """Deletes model version in backend.
604 |
605 | Args:
606 | name: Registered model name.
607 | version: Registered model version.
608 |
609 | Returns:
610 | None
611 | """
612 | # Validating that the version is a proper model version
613 | model_version = self.get_model_version(name=name, version=version)
614 | if self.DELETE_MODEL_VERSIONS_INSTEAD_OF_MARKING_AS_DELETED:
615 | model_version_uri = self._get_model_version_info_file_path(
616 | name=name, version=version
617 | )
618 | storage.Blob.from_string(
619 | uri=model_version_uri,
620 | # Workaround for https://github.com/googleapis/python-storage/issues/540
621 | client=storage.Client(),
622 | ).delete()
623 | else:
624 | model_version_proto = model_version.to_proto()
625 | model_version_proto.current_stage = (
626 | model_version_stages.STAGE_DELETED_INTERNAL
627 | )
628 | self._set_model_version_proto(
629 | name=name,
630 | version=version,
631 | model_version_proto=model_version_proto,
632 | )
633 |
634 | def get_model_version(self, name: str, version: str) -> model_registry.ModelVersion:
635 | """Gets the model version instance by name and version.
636 |
637 | Args:
638 | name: Registered model name.
639 | version: Registered model version.
640 |
641 | Returns:
642 | A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
643 | """
644 | _validate_model_name(name)
645 | _validate_model_version(version)
646 | model_version_uri = self._get_model_version_info_file_path(
647 | name=name, version=version
648 | )
649 | try:
650 | model_version_json = storage.Blob.from_string(
651 | uri=model_version_uri,
652 | # Workaround for https://github.com/googleapis/python-storage/issues/540
653 | client=storage.Client(),
654 | ).download_as_text()
655 | model_version = _json_to_registered_model_version(model_version_json)
656 | return model_version
657 | except google.api_core.exceptions.NotFound:
658 | raise mlflow.exceptions.MlflowException(
659 | message=f'Model "{name}" version "{version}" does not exist',
660 | error_code=databricks_pb2.RESOURCE_DOES_NOT_EXIST,
661 | )
662 |
663 | def _set_model_version_proto(
664 | self,
665 | name: str,
666 | version: str,
667 | model_version_proto: model_registry_pb2.ModelVersion,
668 | update_modification_time: bool = True,
669 | ) -> None:
670 | """Gets the model version instance by name and version.
671 |
672 | Args:
673 | name: Registered model name.
674 | version: Registered model version.
675 | model_version_proto: A
676 | :py:class:`mlflow.entities.model_registry_pb2.ModelVersion` object.
677 | update_modification_time: Whether to update the modification time
678 |
679 | Returns:
680 | None
681 | """
682 | _validate_model_name(name)
683 | _validate_model_version(version)
684 | if update_modification_time:
685 | current_time = datetime.datetime.utcnow()
686 | model_version_proto.last_updated_timestamp = int(current_time.timestamp())
687 | model_version_json = json_format.MessageToJson(model_version_proto)
688 | model_version_uri = self._get_model_version_info_file_path(
689 | name=name, version=version
690 | )
691 | storage.Blob.from_string(
692 | uri=model_version_uri,
693 | # Workaround for https://github.com/googleapis/python-storage/issues/540
694 | client=storage.Client(),
695 | ).upload_from_string(model_version_json)
696 |
697 | def _list_model_versions(
698 | self,
699 | name: Optional[str] = None,
700 | ) -> List[model_registry.ModelVersion]:
701 | """Lists all versions of a registered model.
702 |
703 | Args:
704 | name: Registered model name.
705 |
706 | Returns:
707 | A list of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
708 | """
709 | if name:
710 | _validate_model_name(name)
711 | model_dir_uri = self._get_model_dir(name=name)
712 | else:
713 | model_dir_uri = self._base_uri
714 | model_dir_blob = storage.Blob.from_string(uri=model_dir_uri)
715 | blob_iterator = storage.Client().list_blobs(
716 | # Using bucket name as a workaround for
717 | # https://github.com/googleapis/python-storage/issues/540
718 | bucket_or_name=model_dir_blob.bucket.name,
719 | prefix=model_dir_blob.name,
720 | )
721 | models = [
722 | _json_to_registered_model_version(blob.download_as_text())
723 | for blob in blob_iterator
724 | if blob.name.endswith(self._MODEL_VERSION_INFO_FILE_NAME)
725 | ]
726 | return models
727 |
728 | def get_model_version_download_uri(self, name: str, version: str) -> str:
729 | """Gets the download location in Model Registry for this model version.
730 |
731 | NOTE: For first version of Model Registry, since the models are not copied over to another
732 | location, download URI points to input source path.
733 |
734 | Args:
735 | name: Registered model name.
736 | version: Registered model version.
737 |
738 | Returns:
739 | A single URI location that allows reads for downloading.
740 | """
741 | model_version = self.get_model_version(name, version)
742 | return model_version.source
743 |
744 | def search_model_versions(
745 | self, filter_string: str
746 | ) -> paged_list.PagedList[model_registry.ModelVersion]:
747 | """Searches for model versions in backend that satisfy the filter criteria.
748 |
749 | Args:
750 | filter_string: A filter string expression. Currently supports a single filter
751 | condition either name of model like ``name = 'model_name'`` or ``run_id = '...'``.
752 |
753 | Returns:
754 | PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
755 | objects.
756 | """
757 | parsed_filters = SearchUtils.parse_filter_for_model_versions(filter_string)
758 | model_versions = self._list_model_versions()
759 | model_versions = [
760 | model_version
761 | for model_version in model_versions
762 | if model_version.current_stage is None
763 | or model_version.current_stage in model_version_stages.ALL_STAGES
764 | ]
765 | for parsed_filter in parsed_filters:
766 | if parsed_filter["comparator"] != "=":
767 | raise mlflow.exceptions.MlflowException(
768 | "Model Registry search filter only supports equality(=) "
769 | f"comparator. Input filter string: {filter_string}",
770 | error_code=databricks_pb2.INVALID_PARAMETER_VALUE,
771 | )
772 | # Key validated by `parse_filter_for_model_versions`
773 | key = parsed_filter["key"]
774 | value = parsed_filter["value"]
775 | model_versions = [
776 | model_version
777 | for model_version in model_versions
778 | if getattr(model_version, key, None) == value
779 | ]
780 | return paged_list.PagedList(items=model_versions, token=None)
781 |
782 | def set_model_version_tag(
783 | self, name: str, version: str, tag: model_registry.ModelVersionTag
784 | ) -> None:
785 | """Sets a tag for the model version.
786 |
787 | Args:
788 | name: Registered model name.
789 | version: Registered model version.
790 | tag: py:class:`mlflow.entities.model_registry.ModelVersionTag` instance
791 | to log.
792 |
793 | Returns:
794 | None
795 | """
796 | _validate_model_version_tag(tag.key, tag.value)
797 | model_version = self.get_model_version(name, version)
798 | model_version_proto = model_version.to_proto()
799 | # model.tags is not a proper map.
800 | # Deleting possible previous tag values first.
801 | for existing_tag in model_version_proto.tags:
802 | if existing_tag.key == tag.key:
803 | model_version_proto.tags.remove(existing_tag)
804 | model_version_proto.tags.append(tag)
805 | self._set_model_version_proto(
806 | name=name, version=version, model_version_proto=model_version_proto
807 | )
808 |
809 | def delete_model_version_tag(self, name: str, version: str, key: str) -> None:
810 | """Deletes a tag associated with the model version.
811 |
812 | Args:
813 | name: Registered model name.
814 | version: Registered model version.
815 | key: Tag key.
816 |
817 | Returns:
818 | None
819 | """
820 | _validate_tag_name(key)
821 | model_version = self.get_model_version(name, version)
822 | model_version_proto = model_version.to_proto()
823 | for tag in model_version_proto.tags:
824 | if tag.key == key:
825 | model_version_proto.tags.remove(tag)
826 | self._set_model_version_proto(
827 | name=name, version=version, model_version_proto=model_version_proto
828 | )
829 |
830 |
831 | def _validate_store_uri(store_uri: str) -> bool:
832 | return store_uri.startswith("gs://")
833 |
834 |
def _json_to_registered_model(model_json: str) -> model_registry.RegisteredModel:
    """Parses a JSON string into a RegisteredModel entity."""
    model_proto = model_registry_pb2.RegisteredModel()
    json_format.Parse(model_json, model_proto)
    return model_registry.RegisteredModel.from_proto(model_proto)
840 |
841 |
def _json_to_registered_model_version(
    model_json: str,
) -> model_registry.ModelVersion:
    """Parses a JSON string into a ModelVersion entity."""
    version_proto = model_registry_pb2.ModelVersion()
    json_format.Parse(model_json, version_proto)
    return model_registry.ModelVersion.from_proto(version_proto)
849 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The google-cloud-mlflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | """MLflow Google Cloud Vertex AI integration package."""
16 |
import setuptools

# Use the README as the PyPI long description.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="google_cloud_mlflow",
    version="0.0.6",
    description="MLflow Google Cloud Vertex AI integration package",
    url="https://github.com/Ark-kun/google_cloud_mlflow",
    project_urls={
        # BUGFIX: "Source" previously pointed at pypa/sampleproject — a
        # leftover from the packaging template; point it at this project.
        "Source": "https://github.com/Ark-kun/google_cloud_mlflow",
        "Issues": "https://github.com/Ark-kun/google_cloud_mlflow/issues",
    },
    author="Alexey Volkov",
    author_email="alexey.volkov@ark-kun.com",
    classifiers=[
        "License :: OSI Approved :: Apache Software License",
        "Natural Language :: English",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
    ],
    license="Apache License 2.0",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=setuptools.find_packages(),
    keywords="mlflow, Google Cloud, Vertex AI",
    python_requires=">=3.6",
    install_requires=[
        "google-cloud-aiplatform~=1.7",
        "mlflow~=1.22",
        "google-cloud-build==3.*",
        "google-cloud-storage==1.*",
    ],
    entry_points={
        # Deployment plugin: target is a module exposing the deployment API.
        "mlflow.deployments": [
            "google_cloud=google_cloud_mlflow.mlflow_model_deployment_plugin_for_google_cloud_vertex_ai",
        ],
        # Registry plugin: target is the store class for gs:// URIs.
        "mlflow.model_registry_store": [
            "gs=google_cloud_mlflow.mlflow_model_registry_plugin_for_google_cloud_storage:GoogleCloudStorageModelRegistry",
        ],
    },
)
60 |
--------------------------------------------------------------------------------