├── .gitignore
├── pipeline.png
├── README.md
├── converting_exploration_to_pipeline.ipynb
└── preprocess.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | venv/
3 | .env
4 |
--------------------------------------------------------------------------------
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jy2k/Kubeflow-v2-end-to-end/HEAD/pipeline.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kubeflow-vertex-ai
2 |
3 | ## This notebook shows how to go from EDA to a working Vertex AI (Kubeflow) pipeline
4 |
5 | ![pipeline](pipeline.png)
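6 | 
7 | The notebook reads `PROJECT_ID` and `BUCKET` from a local `.env` file (loaded with `python-dotenv`). A minimal example, using placeholder values:
8 | 
9 | ```
10 | PROJECT_ID=your-gcp-project-id
11 | BUCKET=your-gcs-bucket-name
12 | ```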
--------------------------------------------------------------------------------
/converting_exploration_to_pipeline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "2f2ee82c-e800-47e2-8978-bc66d227da73",
6 | "metadata": {},
7 | "source": [
8 | "### //TODO:\n",
9 | "Add component for testing the model after deployment\n",
10 | "\n",
11 | "Add component for model monitoring\n",
12 | "\n",
13 | "Split the preprocess into BQ component and transform component\n",
14 | "\n",
15 | "Add a tensorflow model\n",
16 | "\n",
17 | "Add a pytorch model\n",
18 | "\n",
19 | "Add AutoML model\n",
20 | "\n",
21 | "Add Error handling\n",
22 | "\n",
23 | "Add mondel monitoring"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "id": "0b04cb1f-d96c-4720-a1d8-b2460de481e2",
29 | "metadata": {},
30 | "source": [
31 | "# Create & Deploy Vertex-AI Pipeline w/ Kubeflow"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "id": "3140d27f-f516-49a9-9ff4-811fd263feeb",
37 | "metadata": {},
38 | "source": [
39 | "Install the needed libraries in order to run the code locally"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 1,
45 | "id": "9560fe6f-8a20-476a-aea3-f4c4abaf37dd",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "%%capture\n",
50 | "!pip3 install google-cloud-aiplatform==1.0.0 --upgrade\n",
51 | "!pip3 install kfp google-cloud-pipeline-components==0.1.1 --upgrade\n",
52 | "!pip3 install scikit-learn\n",
53 | "!pip3 install google-cloud-aiplatform --upgrade\n",
54 | "!pip3 install pandas\n",
55 | "!pip3 install python-dotenv"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "617816bc-8237-4c94-b00f-6d62530516fe",
61 | "metadata": {},
62 | "source": [
63 | "Might need to restart kernel after initial installation of the cell above"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 170,
69 | "id": "97a0f438-f584-457e-b469-90d5ec9e9182",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "import uuid\n",
74 | "from kfp import dsl\n",
75 | "import kfp\n",
76 | "from google.cloud import aiplatform\n",
77 | "from kfp.v2.dsl import component\n",
78 | "from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output, OutputPath, component, ClassificationMetrics, Metrics)\n",
79 | "from google_cloud_pipeline_components import aiplatform as gcc_aip\n",
80 | "from typing import NamedTuple"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "f9a5a343-6bcb-45e6-8227-eca403590f91",
86 | "metadata": {},
87 | "source": [
88 | "Getting some preset environment variables save to a local file. Create one of your own by following these instructions: https://stackoverflow.com/a/54028874"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 21,
94 | "id": "f3affb3e-fcbf-4167-abe7-04b619d69cac",
95 | "metadata": {},
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "The dotenv extension is already loaded. To reload it, use:\n",
102 | " %reload_ext dotenv\n",
103 | "kubeflow-demos\n",
104 | "user-group-demo\n",
105 | "gs://user-group-demo/pipeline_root\n"
106 | ]
107 | }
108 | ],
109 | "source": [
110 | "#https://stackoverflow.com/a/54028874\n",
111 | "%load_ext dotenv\n",
112 | "%dotenv\n",
113 | "\n",
114 | "import os\n",
115 | "PROJECT_ID = os.environ['PROJECT_ID']\n",
116 | "BUCKET_NAME = os.environ['BUCKET']\n",
117 | "\n",
118 | "PIPELINE_ROOT = 'gs://{}/pipeline_root'.format(BUCKET_NAME)\n",
119 | "REGION = 'us-central1'\n",
120 | "\n",
121 | "print(PROJECT_ID)\n",
122 | "print(BUCKET_NAME)\n",
123 | "print(PIPELINE_ROOT)"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "id": "60d89156-3085-47e4-9eec-85c7c53f96d7",
129 | "metadata": {},
130 | "source": [
131 | "## 1. Create a component for reading data from BQ into CSV"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 129,
137 | "id": "2298853a-ae42-42f3-8594-4b906552a54b",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "@component(packages_to_install=[\"pandas\", \"google-cloud-aiplatform\", \"google-cloud-bigquery-storage\",\"google-cloud-bigquery\",\"pyarrow\"], output_component_file=\"preprocess.yaml\")\n",
142 | "def preprocess(output_csv_path: OutputPath('CSV')):\n",
143 | " #1\n",
144 | " from google.cloud import bigquery\n",
145 | " import google.auth\n",
146 | " \n",
147 | " creds, project = google.auth.default()\n",
148 | " client = bigquery.Client(project='kubeflow-demos', credentials=creds)\n",
149 | "\n",
150 | " query = \"\"\"\n",
151 | " SELECT * FROM `kubeflow-demos.telco.churn`\n",
152 | " \"\"\"\n",
153 | " print(query)\n",
154 | " \n",
155 | " dataframe = client.query(query).to_dataframe()\n",
156 | " print(dataframe.head())\n",
157 | " \n",
158 | " dataframe.to_csv(output_csv_path)\n",
159 | " print(\"done\")"
160 | ]
161 | },
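{
"cell_type": "markdown",
"id": "reuse-preprocess-yaml-note",
"metadata": {},
"source": [
"The `output_component_file=\"preprocess.yaml\"` argument above also writes the component spec to `preprocess.yaml` (included in this repo). A minimal sketch, assuming that file sits next to the notebook, of loading the spec back with `kfp.components.load_component_from_file` instead of re-defining the component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reuse-preprocess-yaml-code",
"metadata": {},
"outputs": [],
"source": [
"#Sketch: reuse the exported component spec instead of re-defining the component in code\n",
"from kfp import components\n",
"\n",
"preprocess_from_file = components.load_component_from_file('preprocess.yaml')"
]
},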
162 | {
163 | "cell_type": "markdown",
164 | "id": "48fdda2e-1254-4502-9f4f-9fdca5863c7a",
165 | "metadata": {},
166 | "source": [
167 | "## 2. Create a component to train"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 183,
173 | "id": "54350599-3b5a-4e4c-8b01-25030ef103e1",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "@component(packages_to_install=[\"pandas\", \"imbalanced-learn\", \"google-cloud-aiplatform\", \"pyarrow\"])\n",
178 | "def train(wmetrics: Output[ClassificationMetrics], input_csv_path: InputPath('CSV'), saved_model: Output[Model], artifact_uri: OutputPath(str), accuracy: Output[Metrics], model_type: str, project_id: str, bucket: str):\n",
179 | " from google.cloud import aiplatform\n",
180 | " from typing import NamedTuple\n",
181 | " #Train\n",
182 | " import pandas as pd\n",
183 | " df = pd.read_csv(input_csv_path)\n",
184 | " print(len(df))\n",
185 | " \n",
186 | " from sklearn.preprocessing import LabelEncoder\n",
187 | " for c in df.columns:\n",
188 | " if df[c].dtype=='object': #Since we are encoding object datatype to integer/float\n",
189 | " lbl = LabelEncoder()\n",
190 | " lbl.fit(list(df[c].values))\n",
191 | " df[c] = lbl.transform(df[c].values)\n",
192 | " print(df.head()) #To check if properly encoded\n",
193 | " \n",
194 | " X = df[['Contract', 'tenure', 'TechSupport', 'OnlineSecurity', 'TotalCharges', 'PaperlessBilling',\n",
195 | " 'DeviceProtection', 'Dependents', 'OnlineBackup', 'SeniorCitizen', 'MonthlyCharges',\n",
196 | " 'PaymentMethod', 'Partner', 'PhoneService']] #taking only relevant columns\n",
197 | " y = df['Churn']\n",
198 | "\n",
199 | "\n",
200 | " # Scaling all the variables to a range of 0 to 1\n",
201 | " from sklearn.preprocessing import MinMaxScaler\n",
202 | " features = X.columns.values\n",
203 | " scaler = MinMaxScaler(feature_range = (0,1))\n",
204 | " scaler.fit(X)\n",
205 | " X = pd.DataFrame(scaler.transform(X))\n",
206 | " X.columns = features\n",
207 | " \n",
208 | " from sklearn.model_selection import train_test_split\n",
209 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)\n",
210 | "\n",
211 | " #Choose which model to train\n",
212 | " if model_type == 'linear_regression':\n",
213 | " from sklearn.linear_model import LogisticRegression\n",
214 | " model = LogisticRegression()\n",
215 | " \n",
216 | " elif model_type == 'naive_bayes':\n",
217 | " from sklearn.naive_bayes import GaussianNB\n",
218 | " model = GaussianNB()\n",
219 | " \n",
220 | " elif model_type == 'decision_tree':\n",
221 | " from sklearn.tree import DecisionTreeClassifier\n",
222 | " model = DecisionTreeClassifier()\n",
223 | " \n",
224 | " model.fit(X_train, y_train)\n",
225 | " \n",
226 | " #Save the model to disk and also automatically to GCS\n",
227 | " import joblib\n",
228 | " \n",
229 | " joblib.dump(model, os.path.join(saved_model.path.replace(\"saved_model\",\"\"), 'model.joblib'))\n",
230 | " print(\" saved_model.path: \"+ saved_model.path)\n",
231 | " print(\" saved_model.uri: \"+ saved_model.uri)\n",
232 | " with open(artifact_uri, 'w') as f:\n",
233 | " f.write(saved_model.uri.replace(\"saved_model\",\"\"))\n",
234 | " \n",
235 | " print(saved_model.uri)\n",
236 | " \n",
237 | " accuracy.log_metric('accuracy', 71.0)\n",
238 | " \n",
239 | " #Adding roc curve\n",
240 | " from sklearn.metrics import roc_curve\n",
241 | " from sklearn.model_selection import cross_val_predict, train_test_split\n",
242 | "\n",
243 | " y_scores = cross_val_predict(model, X_train, y_train, cv=3, method=\"predict_proba\")\n",
244 | " fpr, tpr, thresholds = roc_curve(\n",
245 | " y_true=y_train, y_score=y_scores[:, 1], pos_label=True\n",
246 | " )\n",
247 | " wmetrics.log_roc_curve(fpr, tpr, thresholds)\n"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "id": "fdfccc2d-7633-4579-9c88-f9b6ef368172",
253 | "metadata": {},
254 | "source": [
255 | "## 3. Eval component"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 184,
261 | "id": "d1ebe330-78f6-4de7-aaea-c989cde694ea",
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "@component()\n",
266 | "def eval(baseline: float, accuracy: Input[Metrics], accuracy2: Input[Metrics], accuracy3: Input[Metrics]) -> bool:\n",
267 | " isBetter = False\n",
268 | " \n",
269 | " print(str(dir(accuracy)))\n",
270 | " new_val = float(accuracy.metadata['accuracy'])\n",
271 | " print(str(new_val))\n",
272 | " \n",
273 | " \n",
274 | " if new_val>baseline:\n",
275 | " isBetter = True\n",
276 | " print(\"isBetter: \"+str(isBetter))\n",
277 | " \n",
278 | " return isBetter"
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "id": "9429e59f-798c-49a6-82e2-7c7854f34c8d",
284 | "metadata": {},
285 | "source": [
286 | "## 4. Predict Endpoint component"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 185,
292 | "id": "70ea05e3-9efc-4fe1-9b31-14321326e00c",
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "#https://github.com/googleapis/python-aiplatform/blob/master/samples/snippets/predict_custom_trained_model_sample.py\n",
297 | "#https://cloud.google.com/ai-platform/prediction/docs/online-predict\n",
298 | "@component(packages_to_install=[\"google-cloud-aiplatform\"])\n",
299 | "def predict_endpoint_test(endpoint_id: Input[Artifact],\n",
300 | " location: str,\n",
301 | " project: str,\n",
302 | " api_endpoint: str = \"us-central1-aiplatform.googleapis.com\"):\n",
303 | " \n",
304 | " from typing import Dict\n",
305 | " from google.cloud import aiplatform\n",
306 | " from google.protobuf import json_format\n",
307 | " from google.protobuf.struct_pb2 import Value\n",
308 | " \n",
309 | " print(endpoint_id)\n",
310 | " endpoint_id = endpoint_id.uri.split('/')[-1]\n",
311 | " print(endpoint_id)\n",
312 | " \n",
313 | " client_options = {\"api_endpoint\": api_endpoint}\n",
314 | " \n",
315 | " client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)\n",
316 | " #https://machinelearningmastery.com/make-predictions-scikit-learn/\n",
317 | " instance_dict = [ 1.74481176, 0.86540763, -1.07296862 ,-2.3015387, -2.06014071, 1.46210794, 0.3190391 , -0.24937038 ,-0.61175641 ,-0.7612069 , -0.38405435, -0.52817175, -0.3224172, 1.62434536]\n",
318 | " \n",
319 | " instance = json_format.ParseDict(instance_dict, Value())\n",
320 | " instances = [instance]\n",
321 | " \n",
322 | " endpoint = client.endpoint_path(\n",
323 | " project=project, location=location, endpoint=endpoint_id\n",
324 | " )\n",
325 | " response = client.predict(\n",
326 | " endpoint=endpoint, instances=instances\n",
327 | " )\n",
328 | " print(\"response\")\n",
329 | " print(\" deployed_model_id:\", response.deployed_model_id)\n",
330 | " \n",
331 | " predictions = response.predictions\n",
332 | " for prediction in predictions:\n",
333 | " print(\" prediction:\" + str(prediction))"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 191,
339 | "id": "847e244b-6007-4e34-8887-877cea41c445",
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "from datetime import datetime\n",
344 | "\n",
345 | "TIMESTAMP = datetime.now().strftime(\"%Y%m%d%H%M%S\")"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 192,
351 | "id": "81684b00-a7df-428d-a9e1-92faf21dd6ff",
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "@kfp.dsl.pipeline(name=\"train-scikit\" + str(uuid.uuid4()))\n",
356 | "def pipeline(\n",
357 | " project: str = PROJECT_ID,\n",
358 | " bucket: str = BUCKET_NAME,\n",
359 | " baseline_accuracy: float = 70.0\n",
360 | "):\n",
361 | " preprocess_task = preprocess()\n",
362 | " \n",
363 | " train_task = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n",
364 | " train_task2 = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n",
365 | " train_task3 = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n",
366 | " \n",
367 | " eval_task = eval(baseline_accuracy, train_task.outputs[\"accuracy\"], train_task2.outputs[\"accuracy\"], train_task3.outputs[\"accuracy\"])\n",
368 | " \n",
369 | " with dsl.Condition(eval_task.output == \"true\", name=\"eval models\"):\n",
370 | " model_upload_op = gcc_aip.ModelUploadOp(\n",
371 | " project=PROJECT_ID,\n",
372 | " display_name=\"model\"+TIMESTAMP, \n",
373 | " # artifact_uri=\"gs://user-group-demo/pipeline_root/141610882258/train-scikitf989f632-b955-4bb1-a72d-0480d1c08627-20210620145355/train_-6780204423378370560/\", # GCS location of model\n",
374 | " artifact_uri=train_task.outputs[\"artifact_uri\"], # GCS location of model\n",
375 | " serving_container_image_uri=\"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest\",\n",
376 | " )\n",
377 | "\n",
378 | " endpoint_create_op = gcc_aip.EndpointCreateOp(\n",
379 | " project=PROJECT_ID,\n",
380 | " display_name=\"pipelines\"+TIMESTAMP,\n",
381 | " )\n",
382 | "\n",
383 | " model_deploy_op = gcc_aip.ModelDeployOp( \n",
384 | " project=PROJECT_ID,\n",
385 | " endpoint=endpoint_create_op.outputs[\"endpoint\"],\n",
386 | " model=model_upload_op.outputs[\"model\"],\n",
387 | " deployed_model_display_name=\"model_display_name\",\n",
388 | " machine_type=\"n1-standard-4\",\n",
389 | " )\n",
390 | " \n",
391 | " predict_task = predict_endpoint_test(project=PROJECT_ID, location=REGION, endpoint_id = model_deploy_op.outputs['endpoint'])"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": 193,
397 | "id": "e3b37e84-3ed5-4c69-87d1-cf19d27a4212",
398 | "metadata": {},
399 | "outputs": [],
400 | "source": [
401 | "from kfp.v2 import compiler\n",
402 | "\n",
403 | "compiler.Compiler().compile(pipeline_func=pipeline, \n",
404 | " package_path=\"dag-\"+TIMESTAMP+\".json\")"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 194,
410 | "id": "37927c38-8ab2-45be-8d27-156de5123c6e",
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "from kfp.v2.google.client import AIPlatformClient\n",
415 | "\n",
416 | "api_client = AIPlatformClient(\n",
417 | " project_id=PROJECT_ID,\n",
418 | " region=REGION,\n",
419 | ")"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": 195,
425 | "id": "4523294d-c363-479a-9198-4995431f67fc",
426 | "metadata": {},
427 | "outputs": [
428 | {
429 | "data": {
430 | "text/html": [
431 | "See the Pipeline job here."
432 | ],
433 | "text/plain": [
434 | ""
435 | ]
436 | },
437 | "metadata": {},
438 | "output_type": "display_data"
439 | }
440 | ],
441 | "source": [
442 | "response = api_client.create_run_from_job_spec(\n",
443 | " \"dag-\"+TIMESTAMP+\".json\",\n",
444 | " pipeline_root=PIPELINE_ROOT,\n",
445 | " parameter_values={},\n",
446 | ")"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "id": "b878ae2a-cbcb-400a-b714-f2d32a70a3f3",
453 | "metadata": {},
454 | "outputs": [],
455 | "source": []
456 | }
457 | ],
458 | "metadata": {
459 | "kernelspec": {
460 | "display_name": "Python 3",
461 | "language": "python",
462 | "name": "python3"
463 | },
464 | "language_info": {
465 | "codemirror_mode": {
466 | "name": "ipython",
467 | "version": 3
468 | },
469 | "file_extension": ".py",
470 | "mimetype": "text/x-python",
471 | "name": "python",
472 | "nbconvert_exporter": "python",
473 | "pygments_lexer": "ipython3",
474 | "version": "3.9.4"
475 | }
476 | },
477 | "nbformat": 4,
478 | "nbformat_minor": 5
479 | }
480 |
--------------------------------------------------------------------------------
/preprocess.yaml:
--------------------------------------------------------------------------------
1 | name: Preprocess
2 | outputs:
3 | - {name: output_csv, type: CSV}
4 | implementation:
5 | container:
6 | image: python:3.7
7 | command:
8 | - sh
9 | - -c
10 | - (PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet --no-warn-script-location
11 | 'pandas' 'google-cloud-aiplatform' 'google-cloud-bigquery-storage' 'google-cloud-bigquery'
12 | 'pyarrow' || PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet
13 | --no-warn-script-location 'pandas' 'google-cloud-aiplatform' 'google-cloud-bigquery-storage'
14 | 'google-cloud-bigquery' 'pyarrow' --user) && "$0" "$@"
15 | - sh
16 | - -ec
17 | - |
18 | program_path=$(mktemp)
19 | printf "%s" "$0" > "$program_path"
20 | python3 -u "$program_path" "$@"
21 | - |2
22 |
23 | import json
24 | import inspect
25 | from typing import *
26 |
27 | # Copyright 2021 The Kubeflow Authors
28 | #
29 | # Licensed under the Apache License, Version 2.0 (the "License");
30 | # you may not use this file except in compliance with the License.
31 | # You may obtain a copy of the License at
32 | #
33 | # http://www.apache.org/licenses/LICENSE-2.0
34 | #
35 | # Unless required by applicable law or agreed to in writing, software
36 | # distributed under the License is distributed on an "AS IS" BASIS,
37 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38 | # See the License for the specific language governing permissions and
39 | # limitations under the License.
40 | """Classes for input/output types in KFP SDK.
41 |
42 | These are only compatible with v2 Pipelines.
43 | """
44 |
45 | import os
46 | from typing import Dict, Generic, List, Optional, Type, TypeVar, Union
47 |
48 |
49 | _GCS_LOCAL_MOUNT_PREFIX = '/gcs/'
50 | _MINIO_LOCAL_MOUNT_PREFIX = '/minio/'
51 | _S3_LOCAL_MOUNT_PREFIX = '/s3/'
52 |
53 |
54 | class Artifact(object):
55 | """Generic Artifact class.
56 |
57 | This class is meant to represent the metadata around an input or output
58 | machine-learning Artifact. Artifacts have URIs, which can either be a location
59 | on disk (or Cloud storage) or some other resource identifier such as
60 | an API resource name.
61 |
62 | Artifacts carry a `metadata` field, which is a dictionary for storing
63 | metadata related to this artifact.
64 | """
65 | TYPE_NAME = 'system.Artifact'
66 |
67 | def __init__(self,
68 | name: Optional[str] = None,
69 | uri: Optional[str] = None,
70 | metadata: Optional[Dict] = None):
71 | """Initializes the Artifact with the given name, URI and metadata."""
72 | self.uri = uri or ''
73 | self.name = name or ''
74 | self.metadata = metadata or {}
75 |
76 | @property
77 | def path(self):
78 | return self._get_path()
79 |
80 | @path.setter
81 | def path(self, path):
82 | self._set_path(path)
83 |
84 | def _get_path(self) -> Optional[str]:
85 | if self.uri.startswith('gs://'):
86 | return _GCS_LOCAL_MOUNT_PREFIX + self.uri[len('gs://'):]
87 | elif self.uri.startswith('minio://'):
88 | return _MINIO_LOCAL_MOUNT_PREFIX + self.uri[len('minio://'):]
89 | elif self.uri.startswith('s3://'):
90 | return _S3_LOCAL_MOUNT_PREFIX + self.uri[len('s3://'):]
91 | return None
92 |
93 | def _set_path(self, path):
94 | if path.startswith(_GCS_LOCAL_MOUNT_PREFIX):
95 | path = 'gs://' + path[len(_GCS_LOCAL_MOUNT_PREFIX):]
96 | elif path.startswith(_MINIO_LOCAL_MOUNT_PREFIX):
97 | path = 'minio://' + path[len(_MINIO_LOCAL_MOUNT_PREFIX):]
98 | elif path.startswith(_S3_LOCAL_MOUNT_PREFIX):
99 | path = 's3://' + path[len(_S3_LOCAL_MOUNT_PREFIX):]
100 | self.uri = path
101 |
102 |
103 | class Model(Artifact):
104 | """An artifact representing an ML Model."""
105 | TYPE_NAME = 'system.Model'
106 |
107 | def __init__(self,
108 | name: Optional[str] = None,
109 | uri: Optional[str] = None,
110 | metadata: Optional[Dict] = None):
111 | super().__init__(uri=uri, name=name, metadata=metadata)
112 |
113 | @property
114 | def framework(self) -> str:
115 | return self._get_framework()
116 |
117 | def _get_framework(self) -> str:
118 | return self.metadata.get('framework', '')
119 |
120 | @framework.setter
121 | def framework(self, framework: str):
122 | self._set_framework(framework)
123 |
124 | def _set_framework(self, framework: str):
125 | self.metadata['framework'] = framework
126 |
127 |
128 | class Dataset(Artifact):
129 | """An artifact representing an ML Dataset."""
130 | TYPE_NAME = 'system.Dataset'
131 |
132 | def __init__(self,
133 | name: Optional[str] = None,
134 | uri: Optional[str] = None,
135 | metadata: Optional[Dict] = None):
136 | super().__init__(uri=uri, name=name, metadata=metadata)
137 |
138 |
139 | class Metrics(Artifact):
140 | """Represent a simple base Artifact type to store key-value scalar metrics."""
141 | TYPE_NAME = 'system.Metrics'
142 |
143 | def __init__(self,
144 | name: Optional[str] = None,
145 | uri: Optional[str] = None,
146 | metadata: Optional[Dict] = None):
147 | super().__init__(uri=uri, name=name, metadata=metadata)
148 |
149 | def log_metric(self, metric: str, value: float):
150 | """Sets a custom scalar metric.
151 |
152 | Args:
153 | metric: Metric key
154 | value: Value of the metric.
155 | """
156 | self.metadata[metric] = value
157 |
158 |
159 | class ClassificationMetrics(Artifact):
160 | """Represents Artifact class to store Classification Metrics."""
161 | TYPE_NAME = 'system.ClassificationMetrics'
162 |
163 | def __init__(self,
164 | name: Optional[str] = None,
165 | uri: Optional[str] = None,
166 | metadata: Optional[Dict] = None):
167 | super().__init__(uri=uri, name=name, metadata=metadata)
168 |
169 | def log_roc_data_point(self, fpr: float, tpr: float, threshold: float):
170 | """Logs a single data point in the ROC Curve.
171 |
172 | Args:
173 | fpr: False positive rate value of the data point.
174 | tpr: True positive rate value of the data point.
175 | threshold: Threshold value for the data point.
176 | """
177 |
178 | roc_reading = {
179 | 'confidenceThreshold': threshold,
180 | 'recall': tpr,
181 | 'falsePositiveRate': fpr
182 | }
183 | if 'confidenceMetrics' not in self.metadata.keys():
184 | self.metadata['confidenceMetrics'] = []
185 |
186 | self.metadata['confidenceMetrics'].append(roc_reading)
187 |
188 | def log_roc_curve(self, fpr: List[float], tpr: List[float],
189 | threshold: List[float]):
190 | """Logs an ROC curve.
191 |
192 | The list length of fpr, tpr and threshold must be the same.
193 |
194 | Args:
195 | fpr: List of false positive rate values.
196 | tpr: List of true positive rate values.
197 | threshold: List of threshold values.
198 | """
199 | if len(fpr) != len(tpr) or len(fpr) != len(threshold) or len(tpr) != len(
200 | threshold):
201 | raise ValueError('Length of fpr, tpr and threshold must be the same. '
202 | 'Got lengths {}, {} and {} respectively.'.format(
203 | len(fpr), len(tpr), len(threshold)))
204 |
205 | for i in range(len(fpr)):
206 | self.log_roc_data_point(fpr=fpr[i], tpr=tpr[i], threshold=threshold[i])
207 |
208 | def set_confusion_matrix_categories(self, categories: List[str]):
209 | """Stores confusion matrix categories.
210 |
211 | Args:
212 | categories: List of strings specifying the categories.
213 | """
214 |
215 | self._categories = []
216 | annotation_specs = []
217 | for category in categories:
218 | annotation_spec = {'displayName': category}
219 | self._categories.append(category)
220 | annotation_specs.append(annotation_spec)
221 |
222 | self._matrix = []
223 | for row in range(len(self._categories)):
224 | self._matrix.append({'row': [0] * len(self._categories)})
225 |
226 | self._confusion_matrix = {}
227 | self._confusion_matrix['annotationSpecs'] = annotation_specs
228 | self._confusion_matrix['rows'] = self._matrix
229 | self.metadata['confusionMatrix'] = self._confusion_matrix
230 |
231 | def log_confusion_matrix_row(self, row_category: str, row: List[float]):
232 | """Logs a confusion matrix row.
233 |
234 | Args:
235 | row_category: Category to which the row belongs.
236 | row: List of integers specifying the values for the row.
237 |
238 | Raises:
239 | ValueError: If row_category is not in the list of categories
240 | set in set_categories call.
241 | """
242 | if row_category not in self._categories:
243 | raise ValueError('Invalid category: {} passed. Expected one of: {}'.\
244 | format(row_category, self._categories))
245 |
246 | if len(row) != len(self._categories):
247 | raise ValueError('Invalid row. Expected size: {} got: {}'.\
248 | format(len(self._categories), len(row)))
249 |
250 | self._matrix[self._categories.index(row_category)] = {'row': row}
251 | self.metadata['confusionMatrix'] = self._confusion_matrix
252 |
253 | def log_confusion_matrix_cell(self, row_category: str, col_category: str,
254 | value: int):
255 | """Logs a cell in the confusion matrix.
256 |
257 | Args:
258 | row_category: String representing the name of the row category.
259 | col_category: String representing the name of the column category.
260 | value: Int value of the cell.
261 |
262 | Raises:
263 | ValueError: If row_category or col_category is not in the list of
264 | categories set in set_categories.
265 | """
266 | if row_category not in self._categories:
267 | raise ValueError('Invalid category: {} passed. Expected one of: {}'.\
268 | format(row_category, self._categories))
269 |
270 | if col_category not in self._categories:
271 | raise ValueError('Invalid category: {} passed. Expected one of: {}'.\
272 | format(row_category, self._categories))
273 |
274 | self._matrix[self._categories.index(row_category)]['row'][
275 | self._categories.index(col_category)] = value
276 | self.metadata['confusionMatrix'] = self._confusion_matrix
277 |
278 | def log_confusion_matrix(self, categories: List[str],
279 | matrix: List[List[int]]):
280 | """Logs a confusion matrix.
281 |
282 | Args:
283 | categories: List of the category names.
284 | matrix: Complete confusion matrix.
285 |
286 | Raises:
287 | ValueError: Length of categories does not match number of rows or columns.
288 | """
289 | self.set_confusion_matrix_categories(categories)
290 |
291 | if len(matrix) != len(categories):
292 | raise ValueError('Invalid matrix: {} passed for categories: {}'.\
293 | format(matrix, categories))
294 |
295 | for index in range(len(categories)):
296 | if len(matrix[index]) != len(categories):
297 | raise ValueError('Invalid matrix: {} passed for categories: {}'.\
298 | format(matrix, categories))
299 |
300 | self.log_confusion_matrix_row(categories[index], matrix[index])
301 |
302 | self.metadata['confusionMatrix'] = self._confusion_matrix
303 |
304 |
305 | class SlicedClassificationMetrics(Artifact):
306 | """Metrics class representing Sliced Classification Metrics.
307 |
308 | Similar to ClassificationMetrics clients using this class are expected to use
309 | log methods of the class to log metrics with the difference being each log
310 | method takes a slice to associate the ClassificationMetrics.
311 |
312 | """
313 |
314 | TYPE_NAME = 'system.SlicedClassificationMetrics'
315 |
316 | def __init__(self,
317 | name: Optional[str] = None,
318 | uri: Optional[str] = None,
319 | metadata: Optional[Dict] = None):
320 | super().__init__(uri=uri, name=name, metadata=metadata)
321 |
322 | def _upsert_classification_metrics_for_slice(self, slice: str):
323 | """Upserts the classification metrics instance for a slice."""
324 | if slice not in self._sliced_metrics:
325 | self._sliced_metrics[slice] = ClassificationMetrics()
326 |
327 | def _update_metadata(self, slice: str):
328 | """Updates metadata to adhere to the metrics schema."""
329 | self.metadata = {}
330 | self.metadata['evaluationSlices'] = []
331 | for slice in self._sliced_metrics.keys():
332 | slice_metrics = {
333 | 'slice': slice,
334 | 'sliceClassificationMetrics': self._sliced_metrics[slice].metadata
335 | }
336 | self.metadata['evaluationSlices'].append(slice_metrics)
337 |
338 | def log_roc_reading(self, slice: str, threshold: float, tpr: float,
339 | fpr: float):
340 | """Logs a single data point in the ROC Curve of a slice.
341 |
342 | Args:
343 | slice: String representing slice label.
344 | threshold: Thresold value for the data point.
345 | tpr: True positive rate value of the data point.
346 | fpr: False positive rate value of the data point.
347 | """
348 |
349 | self._upsert_classification_metrics_for_slice(slice)
350 | self._sliced_metrics[slice].log_roc_reading(threshold, tpr, fpr)
351 | self._update_metadata(slice)
352 |
353 | def load_roc_readings(self, slice: str, readings: List[List[float]]):
354 | """Supports bulk loading ROC Curve readings for a slice.
355 |
356 | Args:
357 | slice: String representing slice label.
358 | readings: A 2-D list providing ROC Curve data points.
359 | The expected order of the data points is: threshold,
360 | true_positive_rate, false_positive_rate.
361 | """
362 | self._upsert_classification_metrics_for_slice(slice)
363 | self._sliced_metrics[slice].load_roc_readings(readings)
364 | self._update_metadata(slice)
365 |
366 | def set_confusion_matrix_categories(self, slice: str, categories: List[str]):
367 | """Stores confusion matrix categories for a slice..
368 |
369 | Categories are stored in the internal metrics_utils.ConfusionMatrix
370 | instance of the slice.
371 |
372 | Args:
373 | slice: String representing slice label.
374 | categories: List of strings specifying the categories.
375 | """
376 | self._upsert_classification_metrics_for_slice(slice)
377 | self._sliced_metrics[slice].set_confusion_matrix_categories(categories)
378 | self._update_metadata(slice)
379 |
380 | def log_confusion_matrix_row(self, slice: str, row_category: str,
381 | row: List[int]):
382 | """Logs a confusion matrix row for a slice.
383 |
384 | Row is updated on the internal metrics_utils.ConfusionMatrix
385 | instance of the slice.
386 |
387 | Args:
388 | slice: String representing slice label.
389 | row_category: Category to which the row belongs.
390 | row: List of integers specifying the values for the row.
391 | """
392 | self._upsert_classification_metrics_for_slice(slice)
393 | self._sliced_metrics[slice].log_confusion_matrix_row(row_category, row)
394 | self._update_metadata(slice)
395 |
396 | def log_confusion_matrix_cell(self, slice: str, row_category: str,
397 | col_category: str, value: int):
398 | """Logs a confusion matrix cell for a slice..
399 |
400 | Cell is updated on the internal metrics_utils.ConfusionMatrix
401 | instance of the slice.
402 |
403 | Args:
404 | slice: String representing slice label.
405 | row_category: String representing the name of the row category.
406 | col_category: String representing the name of the column category.
407 | value: Int value of the cell.
408 | """
409 | self._upsert_classification_metrics_for_slice(slice)
410 | self._sliced_metrics[slice].log_confusion_matrix_cell(
411 | row_category, col_category, value)
412 | self._update_metadata(slice)
413 |
414 | def load_confusion_matrix(self, slice: str, categories: List[str],
415 | matrix: List[List[int]]):
416 | """Supports bulk loading the whole confusion matrix for a slice.
417 |
418 | Args:
419 | slice: String representing slice label.
420 | categories: List of the category names.
421 | matrix: Complete confusion matrix.
422 | """
423 | self._upsert_classification_metrics_for_slice(slice)
424 | self._sliced_metrics[slice].log_confusion_matrix_cell(categories, matrix)
425 | self._update_metadata(slice)
426 |
427 |
428 | T = TypeVar('T')
429 |
430 |
431 | class InputAnnotation():
432 | """Marker type for input artifacts."""
433 | pass
434 |
435 |
436 |
437 | class OutputAnnotation():
438 | """Marker type for output artifacts."""
439 | pass
440 |
441 |
442 | # TODO: Use typing.Annotated instead of this hack.
443 | # With typing.Annotated (Python 3.9+ or typing_extensions package), the
444 | # following would look like:
445 | # Input = typing.Annotated[T, InputAnnotation]
446 | # Output = typing.Annotated[T, OutputAnnotation]
447 |
448 |
449 | # Input represents an Input artifact of type T.
450 | Input = Union[T, InputAnnotation]
451 |
452 | # Output represents an Output artifact of type T.
453 | Output = Union[T, OutputAnnotation]
454 |
455 |
456 | def is_artifact_annotation(typ) -> bool:
457 | if hasattr(typ, '_subs_tree'): # Python 3.6
458 | subs_tree = typ._subs_tree()
459 | return len(subs_tree) == 3 and subs_tree[0] == Union and subs_tree[2] in [InputAnnotation, OutputAnnotation]
460 |
461 | if not hasattr(typ, '__origin__'):
462 | return False
463 |
464 |
465 | if typ.__origin__ != Union and type(typ.__origin__) != type(Union):
466 | return False
467 |
468 |
469 | if not hasattr(typ, '__args__') or len(typ.__args__) != 2:
470 | return False
471 |
472 | if typ.__args__[1] not in [InputAnnotation, OutputAnnotation]:
473 | return False
474 |
475 | return True
476 |
477 | def is_input_artifact(typ) -> bool:
478 | """Returns True if typ is of type Input[T]."""
479 | if not is_artifact_annotation(typ):
480 | return False
481 |
482 | if hasattr(typ, '_subs_tree'): # Python 3.6
483 | subs_tree = typ._subs_tree()
484 | return len(subs_tree) == 3 and subs_tree[2] == InputAnnotation
485 |
486 | return typ.__args__[1] == InputAnnotation
487 |
488 | def is_output_artifact(typ) -> bool:
489 | """Returns True if typ is of type Output[T]."""
490 | if not is_artifact_annotation(typ):
491 | return False
492 |
493 | if hasattr(typ, '_subs_tree'): # Python 3.6
494 | subs_tree = typ._subs_tree()
495 | return len(subs_tree) == 3 and subs_tree[2] == OutputAnnotation
496 |
497 | return typ.__args__[1] == OutputAnnotation
498 |
499 | def get_io_artifact_class(typ):
500 | if not is_artifact_annotation(typ):
501 | return None
502 | if typ == Input or typ == Output:
503 | return None
504 |
505 | if hasattr(typ, '_subs_tree'): # Python 3.6
506 | subs_tree = typ._subs_tree()
507 | if len(subs_tree) != 3:
508 | return None
509 | return subs_tree[1]
510 |
511 | return typ.__args__[0]
512 |
513 | def get_io_artifact_annotation(typ):
514 | if not is_artifact_annotation(typ):
515 | return None
516 |
517 | if hasattr(typ, '_subs_tree'): # Python 3.6
518 | subs_tree = typ._subs_tree()
519 | if len(subs_tree) != 3:
520 | return None
521 | return subs_tree[2]
522 |
523 | return typ.__args__[1]
524 |
525 |
526 |
527 | _SCHEMA_TITLE_TO_TYPE: Dict[str, Artifact] = {
528 | x.TYPE_NAME: x
529 | for x in [Artifact, Model, Dataset, Metrics, ClassificationMetrics]
530 | }
531 |
532 |
533 | def create_runtime_artifact(runtime_artifact: Dict) -> Artifact:
534 | """Creates an Artifact instance from the specified RuntimeArtifact.
535 |
536 | Args:
537 | runtime_artifact: Dictionary representing JSON-encoded RuntimeArtifact.
538 | """
539 | schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '')
540 |
541 | artifact_type = _SCHEMA_TITLE_TO_TYPE.get(schema_title)
542 | if not artifact_type:
543 | artifact_type = Artifact
544 | return artifact_type(
545 | uri=runtime_artifact.get('uri', ''),
546 | name=runtime_artifact.get('name', ''),
547 | metadata=runtime_artifact.get('metadata', {}),
548 | )
549 |
550 | class InputPath:
551 | '''When creating component from function, :class:`.InputPath` should be used as function parameter annotation to tell the system to pass the *data file path* to the function instead of passing the actual data.'''
552 | def __init__(self, type=None):
553 | self.type = type
554 |
555 | class OutputPath:
556 | '''When creating component from function, :class:`.OutputPath` should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a file with the given path instead of returning the data from the function.'''
557 | def __init__(self, type=None):
558 | self.type = type
559 |
560 | class Executor():
561 | """Executor executes v2-based Python function components."""
562 |
563 | def __init__(self, executor_input: Dict, function_to_execute: Callable):
564 | self._func = function_to_execute
565 | self._input = executor_input
566 | self._input_artifacts: Dict[str, Artifact] = {}
567 | self._output_artifacts: Dict[str, Artifact] = {}
568 |
569 | for name, artifacts in self._input.get('inputs', {}).get('artifacts',
570 | {}).items():
571 | artifacts_list = artifacts.get('artifacts')
572 | if artifacts_list:
573 | self._input_artifacts[name] = self._make_input_artifact(
574 | artifacts_list[0])
575 |
576 | for name, artifacts in self._input.get('outputs', {}).get('artifacts',
577 | {}).items():
578 | artifacts_list = artifacts.get('artifacts')
579 | if artifacts_list:
580 | self._output_artifacts[name] = self._make_output_artifact(
581 | artifacts_list[0])
582 |
583 | self._return_annotation = inspect.signature(self._func).return_annotation
584 | self._executor_output = {}
585 |
586 | @classmethod
587 | def _make_input_artifact(cls, runtime_artifact: Dict):
588 | return create_runtime_artifact(runtime_artifact)
589 |
590 | @classmethod
591 | def _make_output_artifact(cls, runtime_artifact: Dict):
592 | import os
593 | artifact = create_runtime_artifact(runtime_artifact)
594 | os.makedirs(os.path.dirname(artifact.path), exist_ok=True)
595 | return artifact
596 |
597 | def _get_input_artifact(self, name: str):
598 | return self._input_artifacts.get(name)
599 |
600 | def _get_output_artifact(self, name: str):
601 | return self._output_artifacts.get(name)
602 |
603 | def _get_input_parameter_value(self, parameter_name: str, parameter_type: Any):
604 | parameter = self._input.get('inputs', {}).get('parameters',
605 | {}).get(parameter_name, None)
606 | if parameter is None:
607 | return None
608 |
609 | if parameter.get('stringValue'):
610 | if parameter_type == str:
611 | return parameter['stringValue']
612 | elif parameter_type == bool:
613 | # Use `.lower()` so it can also handle 'True' and 'False' (resulted from
614 | # `str(True)` and `str(False)`, respectively.
615 | return json.loads(parameter['stringValue'].lower())
616 | else:
617 | return json.loads(parameter['stringValue'])
618 | elif parameter.get('intValue'):
619 | return int(parameter['intValue'])
620 | elif parameter.get('doubleValue'):
621 | return float(parameter['doubleValue'])
622 |
623 | def _get_output_parameter_path(self, parameter_name: str):
624 | parameter_name = self._maybe_strip_path_suffix(parameter_name)
625 | parameter = self._input.get('outputs',
626 | {}).get('parameters',
627 | {}).get(parameter_name, None)
628 | if parameter is None:
629 | return None
630 |
631 | import os
632 | path = parameter.get('outputFile', None)
633 | if path:
634 | os.makedirs(os.path.dirname(path), exist_ok=True)
635 | return path
636 |
637 | def _get_output_artifact_path(self, artifact_name: str):
638 | artifact_name = self._maybe_strip_path_suffix(artifact_name)
639 | output_artifact = self._output_artifacts.get(artifact_name)
640 | if not output_artifact:
641 | raise ValueError(
642 | 'Failed to get output artifact path for artifact name {}'.format(
643 | artifact_name))
644 | return output_artifact.path
645 |
646 | def _get_input_artifact_path(self, artifact_name: str):
647 | artifact_name = self._maybe_strip_path_suffix(artifact_name)
648 | input_artifact = self._input_artifacts.get(artifact_name)
649 | if not input_artifact:
650 | raise ValueError(
651 | 'Failed to get input artifact path for artifact name {}'.format(
652 | artifact_name))
653 | return input_artifact.path
654 |
655 | def _write_output_parameter_value(self, name: str,
656 | value: Union[str, int, float, bool, dict,
657 | list, Dict, List]):
658 | if type(value) == str:
659 | output = {'stringValue': value}
660 | elif type(value) == int:
661 | output = {'intValue': value}
662 | elif type(value) == float:
663 | output = {'doubleValue': value}
664 | else:
665 | # For bool, list, dict, List, Dict, json serialize the value.
666 | output = {'stringValue': json.dumps(value)}
667 |
668 | if not self._executor_output.get('parameters'):
669 | self._executor_output['parameters'] = {}
670 |
671 | self._executor_output['parameters'][name] = output
672 |
673 | def _write_output_artifact_payload(self, name: str, value: Any):
674 | path = self._get_output_artifact_path(name)
675 | with open(path, 'w') as f:
676 | f.write(str(value))
677 |
678 | # TODO: extract to a util
679 | @classmethod
680 | def _get_short_type_name(cls, type_name: str) -> str:
681 | """Extracts the short form type name.
682 |
683 | This method is used for looking up serializer for a given type.
684 |
685 | For example:
686 | typing.List -> List
687 | typing.List[int] -> List
688 | typing.Dict[str, str] -> Dict
689 | List -> List
690 | str -> str
691 |
692 | Args:
693 | type_name: The original type name.
694 |
695 | Returns:
696 | The short form type name or the original name if pattern doesn't match.
697 | """
698 | import re
699 |         match = re.match('(typing\.)?(?P<type>\w+)(?:\[.+\])?', type_name)
700 | if match:
701 | return match.group('type')
702 | else:
703 | return type_name
704 |
705 | # TODO: merge with type_utils.is_parameter_type
706 | @classmethod
707 | def _is_parameter(cls, annotation: Any) -> bool:
708 | if type(annotation) == type:
709 | return annotation in [str, int, float, bool, dict, list]
710 |
711 | # Annotation could be, for instance `typing.Dict[str, str]`, etc.
712 | return cls._get_short_type_name(str(annotation)) in ['Dict', 'List']
713 |
714 | @classmethod
715 | def _is_artifact(cls, annotation: Any) -> bool:
716 | if type(annotation) == type:
717 | return issubclass(annotation, Artifact)
718 | return False
719 |
720 | @classmethod
721 | def _is_named_tuple(cls, annotation: Any) -> bool:
722 | if type(annotation) == type:
723 | return issubclass(annotation, tuple) and hasattr(
724 | annotation, '_fields') and hasattr(annotation, '__annotations__')
725 | return False
726 |
727 | def _handle_single_return_value(self, output_name: str, annotation_type: Any,
728 | return_value: Any):
729 | if self._is_parameter(annotation_type):
730 | if type(return_value) != annotation_type:
731 | raise ValueError(
732 | 'Function `{}` returned value of type {}; want type {}'.format(
733 | self._func.__name__, type(return_value), annotation_type))
734 | self._write_output_parameter_value(output_name, return_value)
735 | elif self._is_artifact(annotation_type):
736 | self._write_output_artifact_payload(output_name, return_value)
737 | else:
738 | raise RuntimeError(
739 | 'Unknown return type: {}. Must be one of `str`, `int`, `float`, or a'
740 | ' subclass of `Artifact`'.format(annotation_type))
741 |
742 | def _write_executor_output(self, func_output: Optional[Any] = None):
743 | if self._output_artifacts:
744 | self._executor_output['artifacts'] = {}
745 |
746 | for name, artifact in self._output_artifacts.items():
747 | runtime_artifact = {
748 | 'name': artifact.name,
749 | 'uri': artifact.uri,
750 | 'metadata': artifact.metadata,
751 | }
752 | artifacts_list = {'artifacts': [runtime_artifact]}
753 |
754 | self._executor_output['artifacts'][name] = artifacts_list
755 |
756 | if func_output is not None:
757 | if self._is_parameter(self._return_annotation) or self._is_artifact(
758 | self._return_annotation):
759 | # Note: single output is named `Output` in component.yaml.
760 | self._handle_single_return_value('Output', self._return_annotation,
761 | func_output)
762 | elif self._is_named_tuple(self._return_annotation):
763 | if len(self._return_annotation._fields) != len(func_output):
764 | raise RuntimeError(
765 | 'Expected {} return values from function `{}`, got {}'.format(
766 | len(self._return_annotation._fields), self._func.__name__,
767 | len(func_output)))
768 | for i in range(len(self._return_annotation._fields)):
769 | field = self._return_annotation._fields[i]
770 | field_type = self._return_annotation.__annotations__[field]
771 | if type(func_output) == tuple:
772 | field_value = func_output[i]
773 | else:
774 | field_value = getattr(func_output, field)
775 | self._handle_single_return_value(field, field_type, field_value)
776 | else:
777 | raise RuntimeError(
778 | 'Unknown return type: {}. Must be one of `str`, `int`, `float`, a'
779 | ' subclass of `Artifact`, or a NamedTuple collection of these types.'
780 | .format(self._return_annotation))
781 |
782 | import os
783 | os.makedirs(
784 | os.path.dirname(self._input['outputs']['outputFile']), exist_ok=True)
785 | with open(self._input['outputs']['outputFile'], 'w') as f:
786 | f.write(json.dumps(self._executor_output))
787 |
788 | def _maybe_strip_path_suffix(self, name) -> str:
789 | if name.endswith('_path'):
790 | name = name[0:-len('_path')]
791 | if name.endswith('_file'):
792 | name = name[0:-len('_file')]
793 | return name
794 |
795 | def execute(self):
796 | annotations = inspect.getfullargspec(self._func).annotations
797 |
798 | # Function arguments.
799 | func_kwargs = {}
800 |
801 | for k, v in annotations.items():
802 | if k == 'return':
803 | continue
804 |
805 | if self._is_parameter(v):
806 | func_kwargs[k] = self._get_input_parameter_value(k, v)
807 |
808 | if is_artifact_annotation(v):
809 | if is_input_artifact(v):
810 | func_kwargs[k] = self._get_input_artifact(k)
811 | if is_output_artifact(v):
812 | func_kwargs[k] = self._get_output_artifact(k)
813 |
814 | elif isinstance(v, OutputPath):
815 | if self._is_parameter(v.type):
816 | func_kwargs[k] = self._get_output_parameter_path(k)
817 | else:
818 | func_kwargs[k] = self._get_output_artifact_path(k)
819 | elif isinstance(v, InputPath):
820 | func_kwargs[k] = self._get_input_artifact_path(k)
821 |
822 | result = self._func(**func_kwargs)
823 | self._write_executor_output(result)
824 |
825 |
826 | def preprocess(output_csv_path: OutputPath('CSV')):
827 | #1
828 | from google.cloud import bigquery
829 | import google.auth
830 |
831 | creds, project = google.auth.default()
832 | client = bigquery.Client(project='kubeflow-demos', credentials=creds)
833 |
834 | query = """
835 | SELECT * FROM `kubeflow-demos.telco.churn`
836 | """
837 | print(query)
838 |
839 | dataframe = client.query(query).to_dataframe()
840 | print(dataframe.head())
841 |
842 | dataframe.to_csv(output_csv_path)
843 | print("done")
844 |
845 |
846 | def executor_main():
847 | import argparse
848 | import json
849 |
850 | parser = argparse.ArgumentParser(description='Process some integers.')
851 | parser.add_argument('--executor_input', type=str)
852 | parser.add_argument('--function_to_execute', type=str)
853 |
854 | args, _ = parser.parse_known_args()
855 | executor_input = json.loads(args.executor_input)
856 | function_to_execute = globals()[args.function_to_execute]
857 |
858 | executor = Executor(executor_input=executor_input,
859 | function_to_execute=function_to_execute)
860 |
861 | executor.execute()
862 |
863 |
864 | if __name__ == '__main__':
865 | executor_main()
866 | args:
867 | - --executor_input
868 | - {executorInput: null}
869 | - --function_to_execute
870 | - preprocess
871 | - --output-csv-output-path
872 | - {outputPath: output_csv}
873 |
--------------------------------------------------------------------------------