├── .gitignore ├── pipeline.png ├── README.md ├── converting_exploration_to_pipeline.ipynb └── preprocess.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | venv/ 3 | .env 4 | -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jy2k/Kubeflow-v2-end-to-end/HEAD/pipeline.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kubeflow-vertex-ai 2 | 3 | ## This notebook shows how to go from exploratory data analysis (EDA) to a working Vertex AI (Kubeflow) pipeline 4 | 5 | ![Screenshot](pipeline.png) -------------------------------------------------------------------------------- /converting_exploration_to_pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2f2ee82c-e800-47e2-8978-bc66d227da73", 6 | "metadata": {}, 7 | "source": [ 8 | "### //TODO:\n", 9 | "Add component for testing the model after deployment\n", 10 | "\n", 11 | "Add component for model monitoring\n", 12 | "\n", 13 | "Split the preprocess into a BQ component and a transform component\n", 14 | "\n", 15 | "Add a TensorFlow model\n", 16 | "\n", 17 | "Add a PyTorch model\n", 18 | "\n", 19 | "Add an AutoML model\n", 20 | "\n", 21 | "Add error handling\n", 22 | "\n", 23 | "Add model monitoring" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "0b04cb1f-d96c-4720-a1d8-b2460de481e2", 29 | "metadata": {}, 30 | "source": [ 31 | "# Create & Deploy a Vertex AI Pipeline with Kubeflow" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "3140d27f-f516-49a9-9ff4-811fd263feeb", 37 | "metadata": {}, 38 | "source": [ 39 | "Install the libraries needed to run the code locally" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "id": "9560fe6f-8a20-476a-aea3-f4c4abaf37dd", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "%%capture\n", 50 | "!pip3 install google-cloud-aiplatform==1.0.0 --upgrade\n", 51 | "!pip3 install kfp google-cloud-pipeline-components==0.1.1 --upgrade\n", 52 | "!pip3 install scikit-learn\n", 53 | "!pip3 install google-cloud-aiplatform --upgrade\n", 54 | "!pip3 install pandas\n", 55 | "!pip3 install python-dotenv" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "617816bc-8237-4c94-b00f-6d62530516fe", 61 | "metadata": {}, 62 | "source": [ 63 | "You might need to restart the kernel after running the installation cell above" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 170, 69 | "id": "97a0f438-f584-457e-b469-90d5ec9e9182", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "import uuid\n", 74 | "from kfp import dsl\n", 75 | "import kfp\n", 76 | "from google.cloud import aiplatform\n", 77 | "from kfp.v2.dsl import component\n", 78 | "from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output, OutputPath, component, ClassificationMetrics, Metrics)\n", 79 | "from google_cloud_pipeline_components import aiplatform as gcc_aip\n", 80 | "from typing import NamedTuple" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "f9a5a343-6bcb-45e6-8227-eca403590f91", 86 | "metadata": {}, 87 | "source": [ 88 | "Load preset environment variables saved to a local 
file. Create one of your own by following these instructions: https://stackoverflow.com/a/54028874" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 21, 94 | "id": "f3affb3e-fcbf-4167-abe7-04b619d69cac", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "The dotenv extension is already loaded. To reload it, use:\n", 102 | " %reload_ext dotenv\n", 103 | "kubeflow-demos\n", 104 | "user-group-demo\n", 105 | "gs://user-group-demo/pipeline_root\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "#https://stackoverflow.com/a/54028874\n", 111 | "%load_ext dotenv\n", 112 | "%dotenv\n", 113 | "\n", 114 | "import os\n", 115 | "PROJECT_ID = os.environ['PROJECT_ID']\n", 116 | "BUCKET_NAME = os.environ['BUCKET']\n", 117 | "\n", 118 | "PIPELINE_ROOT = 'gs://{}/pipeline_root'.format(BUCKET_NAME)\n", 119 | "REGION = 'us-central1'\n", 120 | "\n", 121 | "print(PROJECT_ID)\n", 122 | "print(BUCKET_NAME)\n", 123 | "print(PIPELINE_ROOT)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "60d89156-3085-47e4-9eec-85c7c53f96d7", 129 | "metadata": {}, 130 | "source": [ 131 | "## 1. Create a component for reading data from BQ into CSV" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 129, 137 | "id": "2298853a-ae42-42f3-8594-4b906552a54b", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "@component(packages_to_install=[\"pandas\", \"google-cloud-aiplatform\", \"google-cloud-bigquery-storage\",\"google-cloud-bigquery\",\"pyarrow\"], output_component_file=\"preprocess.yaml\")\n", 142 | "def preprocess(output_csv_path: OutputPath('CSV')):\n", 143 | " #1\n", 144 | " from google.cloud import bigquery\n", 145 | " import google.auth\n", 146 | " \n", 147 | " creds, project = google.auth.default()\n", 148 | " client = bigquery.Client(project='kubeflow-demos', credentials=creds)\n", 149 | "\n", 150 | " query = \"\"\"\n", 151 | " SELECT * FROM `kubeflow-demos.telco.churn`\n", 152 | " \"\"\"\n", 153 | " print(query)\n", 154 | " \n", 155 | " dataframe = client.query(query).to_dataframe()\n", 156 | " print(dataframe.head())\n", 157 | " \n", 158 | " dataframe.to_csv(output_csv_path)\n", 159 | " print(\"done\")" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "id": "48fdda2e-1254-4502-9f4f-9fdca5863c7a", 165 | "metadata": {}, 166 | "source": [ 167 | "## 2. 
Create a component to train" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 183, 173 | "id": "54350599-3b5a-4e4c-8b01-25030ef103e1", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "@component(packages_to_install=[\"pandas\", \"imbalanced-learn\", \"google-cloud-aiplatform\", \"pyarrow\"])\n", 178 | "def train(wmetrics: Output[ClassificationMetrics], input_csv_path: InputPath('CSV'), saved_model: Output[Model], artifact_uri: OutputPath(str), accuracy: Output[Metrics], model_type: str, project_id: str, bucket: str):\n", 179 | "    from google.cloud import aiplatform\n", 180 | "    from typing import NamedTuple\n", 181 | "    #Train\n", 182 | "    import pandas as pd\n", 183 | "    df = pd.read_csv(input_csv_path)\n", 184 | "    print(len(df))\n", 185 | "    \n", 186 | "    from sklearn.preprocessing import LabelEncoder\n", 187 | "    for c in df.columns:\n", 188 | "        if df[c].dtype=='object': #Since we are encoding object datatype to integer/float\n", 189 | "            lbl = LabelEncoder()\n", 190 | "            lbl.fit(list(df[c].values))\n", 191 | "            df[c] = lbl.transform(df[c].values)\n", 192 | "    print(df.head()) #To check if properly encoded\n", 193 | "    \n", 194 | "    X = df[['Contract', 'tenure', 'TechSupport', 'OnlineSecurity', 'TotalCharges', 'PaperlessBilling',\n", 195 | "            'DeviceProtection', 'Dependents', 'OnlineBackup', 'SeniorCitizen', 'MonthlyCharges',\n", 196 | "            'PaymentMethod', 'Partner', 'PhoneService']] #taking only relevant columns\n", 197 | "    y = df['Churn']\n", 198 | "\n", 199 | "\n", 200 | "    # Scaling all the variables to a range of 0 to 1\n", 201 | "    from sklearn.preprocessing import MinMaxScaler\n", 202 | "    features = X.columns.values\n", 203 | "    scaler = MinMaxScaler(feature_range = (0,1))\n", 204 | "    scaler.fit(X)\n", 205 | "    X = pd.DataFrame(scaler.transform(X))\n", 206 | "    X.columns = features\n", 207 | "    \n", 208 | "    from sklearn.model_selection import train_test_split\n", 209 | "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)\n", 210 | "\n", 211 | "    #Choose which model to train\n", 212 | "    if model_type == 'logistic_regression':\n", 213 | "        from sklearn.linear_model import LogisticRegression\n", 214 | "        model = LogisticRegression()\n", 215 | "    \n", 216 | "    elif model_type == 'naive_bayes':\n", 217 | "        from sklearn.naive_bayes import GaussianNB\n", 218 | "        model = GaussianNB()\n", 219 | "    \n", 220 | "    elif model_type == 'decision_tree':\n", 221 | "        from sklearn.tree import DecisionTreeClassifier\n", 222 | "        model = DecisionTreeClassifier()\n", 223 | "    \n", 224 | "    model.fit(X_train, y_train)\n", 225 | "    \n", 226 | "    #Save the model to disk and also automatically to GCS\n", 227 | "    import joblib\n", 228 | "    import os\n", 229 | "    joblib.dump(model, os.path.join(saved_model.path.replace(\"saved_model\",\"\"), 'model.joblib'))\n", 230 | "    print(\" saved_model.path: \"+ saved_model.path)\n", 231 | "    print(\" saved_model.uri: \"+ saved_model.uri)\n", 232 | "    with open(artifact_uri, 'w') as f:\n", 233 | "        f.write(saved_model.uri.replace(\"saved_model\",\"\"))\n", 234 | "    \n", 235 | "    print(saved_model.uri)\n", 236 | "    \n", 237 | "    accuracy.log_metric('accuracy', model.score(X_test, y_test) * 100) #log accuracy on the held-out test split\n", 238 | "    \n", 239 | "    #Adding roc curve\n", 240 | "    from sklearn.metrics import roc_curve\n", 241 | "    from sklearn.model_selection import cross_val_predict, train_test_split\n", 242 | "\n", 243 | "    y_scores = cross_val_predict(model, X_train, y_train, cv=3, method=\"predict_proba\")\n", 244 | "    fpr, tpr, thresholds = roc_curve(\n", 245 | "        y_true=y_train, 
y_score=y_scores[:, 1], pos_label=True\n", 246 | " )\n", 247 | " wmetrics.log_roc_curve(fpr, tpr, thresholds)\n" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "id": "fdfccc2d-7633-4579-9c88-f9b6ef368172", 253 | "metadata": {}, 254 | "source": [ 255 | "## 3. Eval component" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 184, 261 | "id": "d1ebe330-78f6-4de7-aaea-c989cde694ea", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "@component()\n", 266 | "def eval(baseline: float, accuracy: Input[Metrics], accuracy2: Input[Metrics], accuracy3: Input[Metrics]) -> bool:\n", 267 | " isBetter = False\n", 268 | " \n", 269 | " print(str(dir(accuracy)))\n", 270 | " new_val = float(accuracy.metadata['accuracy'])\n", 271 | " print(str(new_val))\n", 272 | " \n", 273 | " \n", 274 | " if new_val>baseline:\n", 275 | " isBetter = True\n", 276 | " print(\"isBetter: \"+str(isBetter))\n", 277 | " \n", 278 | " return isBetter" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "9429e59f-798c-49a6-82e2-7c7854f34c8d", 284 | "metadata": {}, 285 | "source": [ 286 | "## 4. Predict Endpoint component" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 185, 292 | "id": "70ea05e3-9efc-4fe1-9b31-14321326e00c", 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "#https://github.com/googleapis/python-aiplatform/blob/master/samples/snippets/predict_custom_trained_model_sample.py\n", 297 | "#https://cloud.google.com/ai-platform/prediction/docs/online-predict\n", 298 | "@component(packages_to_install=[\"google-cloud-aiplatform\"])\n", 299 | "def predict_endpoint_test(endpoint_id: Input[Artifact],\n", 300 | " location: str,\n", 301 | " project: str,\n", 302 | " api_endpoint: str = \"us-central1-aiplatform.googleapis.com\"):\n", 303 | " \n", 304 | " from typing import Dict\n", 305 | " from google.cloud import aiplatform\n", 306 | " from google.protobuf import json_format\n", 307 | " from google.protobuf.struct_pb2 import Value\n", 308 | " \n", 309 | " print(endpoint_id)\n", 310 | " endpoint_id = endpoint_id.uri.split('/')[-1]\n", 311 | " print(endpoint_id)\n", 312 | " \n", 313 | " client_options = {\"api_endpoint\": api_endpoint}\n", 314 | " \n", 315 | " client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)\n", 316 | " #https://machinelearningmastery.com/make-predictions-scikit-learn/\n", 317 | " instance_dict = [ 1.74481176, 0.86540763, -1.07296862 ,-2.3015387, -2.06014071, 1.46210794, 0.3190391 , -0.24937038 ,-0.61175641 ,-0.7612069 , -0.38405435, -0.52817175, -0.3224172, 1.62434536]\n", 318 | " \n", 319 | " instance = json_format.ParseDict(instance_dict, Value())\n", 320 | " instances = [instance]\n", 321 | " \n", 322 | " endpoint = client.endpoint_path(\n", 323 | " project=project, location=location, endpoint=endpoint_id\n", 324 | " )\n", 325 | " response = client.predict(\n", 326 | " endpoint=endpoint, instances=instances\n", 327 | " )\n", 328 | " print(\"response\")\n", 329 | " print(\" deployed_model_id:\", response.deployed_model_id)\n", 330 | " \n", 331 | " predictions = response.predictions\n", 332 | " for prediction in predictions:\n", 333 | " print(\" prediction:\" + str(prediction))" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 191, 339 | "id": "847e244b-6007-4e34-8887-877cea41c445", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "from datetime import datetime\n", 344 | "\n", 345 | "TIMESTAMP = 
datetime.now().strftime(\"%Y%m%d%H%M%S\")" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 192, 351 | "id": "81684b00-a7df-428d-a9e1-92faf21dd6ff", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "@kfp.dsl.pipeline(name=\"train-scikit\" + str(uuid.uuid4()))\n", 356 | "def pipeline(\n", 357 | " project: str = PROJECT_ID,\n", 358 | " bucket: str = BUCKET_NAME,\n", 359 | " baseline_accuracy: float = 70.0\n", 360 | "):\n", 361 | " preprocess_task = preprocess()\n", 362 | " \n", 363 | " train_task = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n", 364 | " train_task2 = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n", 365 | " train_task3 = train(preprocess_task.output, model_type='decision_tree', project_id=PROJECT_ID, bucket=BUCKET_NAME)\n", 366 | " \n", 367 | " eval_task = eval(baseline_accuracy, train_task.outputs[\"accuracy\"], train_task2.outputs[\"accuracy\"], train_task3.outputs[\"accuracy\"])\n", 368 | " \n", 369 | " with dsl.Condition(eval_task.output == \"true\", name=\"eval models\"):\n", 370 | " model_upload_op = gcc_aip.ModelUploadOp(\n", 371 | " project=PROJECT_ID,\n", 372 | " display_name=\"model\"+TIMESTAMP, \n", 373 | " # artifact_uri=\"gs://user-group-demo/pipeline_root/141610882258/train-scikitf989f632-b955-4bb1-a72d-0480d1c08627-20210620145355/train_-6780204423378370560/\", # GCS location of model\n", 374 | " artifact_uri=train_task.outputs[\"artifact_uri\"], # GCS location of model\n", 375 | " serving_container_image_uri=\"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest\",\n", 376 | " )\n", 377 | "\n", 378 | " endpoint_create_op = gcc_aip.EndpointCreateOp(\n", 379 | " project=PROJECT_ID,\n", 380 | " display_name=\"pipelines\"+TIMESTAMP,\n", 381 | " )\n", 382 | "\n", 383 | " model_deploy_op = gcc_aip.ModelDeployOp( \n", 384 | " project=PROJECT_ID,\n", 385 | " endpoint=endpoint_create_op.outputs[\"endpoint\"],\n", 386 | " model=model_upload_op.outputs[\"model\"],\n", 387 | " deployed_model_display_name=\"model_display_name\",\n", 388 | " machine_type=\"n1-standard-4\",\n", 389 | " )\n", 390 | " \n", 391 | " predict_task = predict_endpoint_test(project=PROJECT_ID, location=REGION, endpoint_id = model_deploy_op.outputs['endpoint'])" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 193, 397 | "id": "e3b37e84-3ed5-4c69-87d1-cf19d27a4212", 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [ 401 | "from kfp.v2 import compiler\n", 402 | "\n", 403 | "compiler.Compiler().compile(pipeline_func=pipeline, \n", 404 | " package_path=\"dag-\"+TIMESTAMP+\".json\")" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 194, 410 | "id": "37927c38-8ab2-45be-8d27-156de5123c6e", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "from kfp.v2.google.client import AIPlatformClient\n", 415 | "\n", 416 | "api_client = AIPlatformClient(\n", 417 | " project_id=PROJECT_ID,\n", 418 | " region=REGION,\n", 419 | ")" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 195, 425 | "id": "4523294d-c363-479a-9198-4995431f67fc", 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "data": { 430 | "text/html": [ 431 | "See the Pipeline job here." 
432 | ], 433 | "text/plain": [ 434 | "" 435 | ] 436 | }, 437 | "metadata": {}, 438 | "output_type": "display_data" 439 | } 440 | ], 441 | "source": [ 442 | "response = api_client.create_run_from_job_spec(\n", 443 | " \"dag-\"+TIMESTAMP+\".json\",\n", 444 | " pipeline_root=PIPELINE_ROOT,\n", 445 | " parameter_values={},\n", 446 | ")" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "id": "b878ae2a-cbcb-400a-b714-f2d32a70a3f3", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [] 456 | } 457 | ], 458 | "metadata": { 459 | "kernelspec": { 460 | "display_name": "Python 3", 461 | "language": "python", 462 | "name": "python3" 463 | }, 464 | "language_info": { 465 | "codemirror_mode": { 466 | "name": "ipython", 467 | "version": 3 468 | }, 469 | "file_extension": ".py", 470 | "mimetype": "text/x-python", 471 | "name": "python", 472 | "nbconvert_exporter": "python", 473 | "pygments_lexer": "ipython3", 474 | "version": "3.9.4" 475 | } 476 | }, 477 | "nbformat": 4, 478 | "nbformat_minor": 5 479 | } 480 | -------------------------------------------------------------------------------- /preprocess.yaml: -------------------------------------------------------------------------------- 1 | name: Preprocess 2 | outputs: 3 | - {name: output_csv, type: CSV} 4 | implementation: 5 | container: 6 | image: python:3.7 7 | command: 8 | - sh 9 | - -c 10 | - (PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet --no-warn-script-location 11 | 'pandas' 'google-cloud-aiplatform' 'google-cloud-bigquery-storage' 'google-cloud-bigquery' 12 | 'pyarrow' || PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet 13 | --no-warn-script-location 'pandas' 'google-cloud-aiplatform' 'google-cloud-bigquery-storage' 14 | 'google-cloud-bigquery' 'pyarrow' --user) && "$0" "$@" 15 | - sh 16 | - -ec 17 | - | 18 | program_path=$(mktemp) 19 | printf "%s" "$0" > "$program_path" 20 | python3 -u "$program_path" "$@" 21 | - |2 22 | 23 | import json 24 | import inspect 25 | from typing import * 26 | 27 | # Copyright 2021 The Kubeflow Authors 28 | # 29 | # Licensed under the Apache License, Version 2.0 (the "License"); 30 | # you may not use this file except in compliance with the License. 31 | # You may obtain a copy of the License at 32 | # 33 | # http://www.apache.org/licenses/LICENSE-2.0 34 | # 35 | # Unless required by applicable law or agreed to in writing, software 36 | # distributed under the License is distributed on an "AS IS" BASIS, 37 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 38 | # See the License for the specific language governing permissions and 39 | # limitations under the License. 40 | """Classes for input/output types in KFP SDK. 41 | 42 | These are only compatible with v2 Pipelines. 43 | """ 44 | 45 | import os 46 | from typing import Dict, Generic, List, Optional, Type, TypeVar, Union 47 | 48 | 49 | _GCS_LOCAL_MOUNT_PREFIX = '/gcs/' 50 | _MINIO_LOCAL_MOUNT_PREFIX = '/minio/' 51 | _S3_LOCAL_MOUNT_PREFIX = '/s3/' 52 | 53 | 54 | class Artifact(object): 55 | """Generic Artifact class. 56 | 57 | This class is meant to represent the metadata around an input or output 58 | machine-learning Artifact. Artifacts have URIs, which can either be a location 59 | on disk (or Cloud storage) or some other resource identifier such as 60 | an API resource name. 61 | 62 | Artifacts carry a `metadata` field, which is a dictionary for storing 63 | metadata related to this artifact. 
64 | """ 65 | TYPE_NAME = 'system.Artifact' 66 | 67 | def __init__(self, 68 | name: Optional[str] = None, 69 | uri: Optional[str] = None, 70 | metadata: Optional[Dict] = None): 71 | """Initializes the Artifact with the given name, URI and metadata.""" 72 | self.uri = uri or '' 73 | self.name = name or '' 74 | self.metadata = metadata or {} 75 | 76 | @property 77 | def path(self): 78 | return self._get_path() 79 | 80 | @path.setter 81 | def path(self, path): 82 | self._set_path(path) 83 | 84 | def _get_path(self) -> Optional[str]: 85 | if self.uri.startswith('gs://'): 86 | return _GCS_LOCAL_MOUNT_PREFIX + self.uri[len('gs://'):] 87 | elif self.uri.startswith('minio://'): 88 | return _MINIO_LOCAL_MOUNT_PREFIX + self.uri[len('minio://'):] 89 | elif self.uri.startswith('s3://'): 90 | return _S3_LOCAL_MOUNT_PREFIX + self.uri[len('s3://'):] 91 | return None 92 | 93 | def _set_path(self, path): 94 | if path.startswith(_GCS_LOCAL_MOUNT_PREFIX): 95 | path = 'gs://' + path[len(_GCS_LOCAL_MOUNT_PREFIX):] 96 | elif path.startswith(_MINIO_LOCAL_MOUNT_PREFIX): 97 | path = 'minio://' + path[len(_MINIO_LOCAL_MOUNT_PREFIX):] 98 | elif path.startswith(_S3_LOCAL_MOUNT_PREFIX): 99 | path = 's3://' + path[len(_S3_LOCAL_MOUNT_PREFIX):] 100 | self.uri = path 101 | 102 | 103 | class Model(Artifact): 104 | """An artifact representing an ML Model.""" 105 | TYPE_NAME = 'system.Model' 106 | 107 | def __init__(self, 108 | name: Optional[str] = None, 109 | uri: Optional[str] = None, 110 | metadata: Optional[Dict] = None): 111 | super().__init__(uri=uri, name=name, metadata=metadata) 112 | 113 | @property 114 | def framework(self) -> str: 115 | return self._get_framework() 116 | 117 | def _get_framework(self) -> str: 118 | return self.metadata.get('framework', '') 119 | 120 | @framework.setter 121 | def framework(self, framework: str): 122 | self._set_framework(framework) 123 | 124 | def _set_framework(self, framework: str): 125 | self.metadata['framework'] = framework 126 | 127 | 128 | class Dataset(Artifact): 129 | """An artifact representing an ML Dataset.""" 130 | TYPE_NAME = 'system.Dataset' 131 | 132 | def __init__(self, 133 | name: Optional[str] = None, 134 | uri: Optional[str] = None, 135 | metadata: Optional[Dict] = None): 136 | super().__init__(uri=uri, name=name, metadata=metadata) 137 | 138 | 139 | class Metrics(Artifact): 140 | """Represent a simple base Artifact type to store key-value scalar metrics.""" 141 | TYPE_NAME = 'system.Metrics' 142 | 143 | def __init__(self, 144 | name: Optional[str] = None, 145 | uri: Optional[str] = None, 146 | metadata: Optional[Dict] = None): 147 | super().__init__(uri=uri, name=name, metadata=metadata) 148 | 149 | def log_metric(self, metric: str, value: float): 150 | """Sets a custom scalar metric. 151 | 152 | Args: 153 | metric: Metric key 154 | value: Value of the metric. 155 | """ 156 | self.metadata[metric] = value 157 | 158 | 159 | class ClassificationMetrics(Artifact): 160 | """Represents Artifact class to store Classification Metrics.""" 161 | TYPE_NAME = 'system.ClassificationMetrics' 162 | 163 | def __init__(self, 164 | name: Optional[str] = None, 165 | uri: Optional[str] = None, 166 | metadata: Optional[Dict] = None): 167 | super().__init__(uri=uri, name=name, metadata=metadata) 168 | 169 | def log_roc_data_point(self, fpr: float, tpr: float, threshold: float): 170 | """Logs a single data point in the ROC Curve. 171 | 172 | Args: 173 | fpr: False positive rate value of the data point. 174 | tpr: True positive rate value of the data point. 
175 | threshold: Threshold value for the data point. 176 | """ 177 | 178 | roc_reading = { 179 | 'confidenceThreshold': threshold, 180 | 'recall': tpr, 181 | 'falsePositiveRate': fpr 182 | } 183 | if 'confidenceMetrics' not in self.metadata.keys(): 184 | self.metadata['confidenceMetrics'] = [] 185 | 186 | self.metadata['confidenceMetrics'].append(roc_reading) 187 | 188 | def log_roc_curve(self, fpr: List[float], tpr: List[float], 189 | threshold: List[float]): 190 | """Logs an ROC curve. 191 | 192 | The list length of fpr, tpr and threshold must be the same. 193 | 194 | Args: 195 | fpr: List of false positive rate values. 196 | tpr: List of true positive rate values. 197 | threshold: List of threshold values. 198 | """ 199 | if len(fpr) != len(tpr) or len(fpr) != len(threshold) or len(tpr) != len( 200 | threshold): 201 | raise ValueError('Length of fpr, tpr and threshold must be the same. ' 202 | 'Got lengths {}, {} and {} respectively.'.format( 203 | len(fpr), len(tpr), len(threshold))) 204 | 205 | for i in range(len(fpr)): 206 | self.log_roc_data_point(fpr=fpr[i], tpr=tpr[i], threshold=threshold[i]) 207 | 208 | def set_confusion_matrix_categories(self, categories: List[str]): 209 | """Stores confusion matrix categories. 210 | 211 | Args: 212 | categories: List of strings specifying the categories. 213 | """ 214 | 215 | self._categories = [] 216 | annotation_specs = [] 217 | for category in categories: 218 | annotation_spec = {'displayName': category} 219 | self._categories.append(category) 220 | annotation_specs.append(annotation_spec) 221 | 222 | self._matrix = [] 223 | for row in range(len(self._categories)): 224 | self._matrix.append({'row': [0] * len(self._categories)}) 225 | 226 | self._confusion_matrix = {} 227 | self._confusion_matrix['annotationSpecs'] = annotation_specs 228 | self._confusion_matrix['rows'] = self._matrix 229 | self.metadata['confusionMatrix'] = self._confusion_matrix 230 | 231 | def log_confusion_matrix_row(self, row_category: str, row: List[float]): 232 | """Logs a confusion matrix row. 233 | 234 | Args: 235 | row_category: Category to which the row belongs. 236 | row: List of integers specifying the values for the row. 237 | 238 | Raises: 239 | ValueError: If row_category is not in the list of categories 240 | set in set_categories call. 241 | """ 242 | if row_category not in self._categories: 243 | raise ValueError('Invalid category: {} passed. Expected one of: {}'.\ 244 | format(row_category, self._categories)) 245 | 246 | if len(row) != len(self._categories): 247 | raise ValueError('Invalid row. Expected size: {} got: {}'.\ 248 | format(len(self._categories), len(row))) 249 | 250 | self._matrix[self._categories.index(row_category)] = {'row': row} 251 | self.metadata['confusionMatrix'] = self._confusion_matrix 252 | 253 | def log_confusion_matrix_cell(self, row_category: str, col_category: str, 254 | value: int): 255 | """Logs a cell in the confusion matrix. 256 | 257 | Args: 258 | row_category: String representing the name of the row category. 259 | col_category: String representing the name of the column category. 260 | value: Int value of the cell. 261 | 262 | Raises: 263 | ValueError: If row_category or col_category is not in the list of 264 | categories set in set_categories. 265 | """ 266 | if row_category not in self._categories: 267 | raise ValueError('Invalid category: {} passed. 
Expected one of: {}'.\ 268 | format(row_category, self._categories)) 269 | 270 | if col_category not in self._categories: 271 | raise ValueError('Invalid category: {} passed. Expected one of: {}'.\ 272 | format(row_category, self._categories)) 273 | 274 | self._matrix[self._categories.index(row_category)]['row'][ 275 | self._categories.index(col_category)] = value 276 | self.metadata['confusionMatrix'] = self._confusion_matrix 277 | 278 | def log_confusion_matrix(self, categories: List[str], 279 | matrix: List[List[int]]): 280 | """Logs a confusion matrix. 281 | 282 | Args: 283 | categories: List of the category names. 284 | matrix: Complete confusion matrix. 285 | 286 | Raises: 287 | ValueError: Length of categories does not match number of rows or columns. 288 | """ 289 | self.set_confusion_matrix_categories(categories) 290 | 291 | if len(matrix) != len(categories): 292 | raise ValueError('Invalid matrix: {} passed for categories: {}'.\ 293 | format(matrix, categories)) 294 | 295 | for index in range(len(categories)): 296 | if len(matrix[index]) != len(categories): 297 | raise ValueError('Invalid matrix: {} passed for categories: {}'.\ 298 | format(matrix, categories)) 299 | 300 | self.log_confusion_matrix_row(categories[index], matrix[index]) 301 | 302 | self.metadata['confusionMatrix'] = self._confusion_matrix 303 | 304 | 305 | class SlicedClassificationMetrics(Artifact): 306 | """Metrics class representing Sliced Classification Metrics. 307 | 308 | Similar to ClassificationMetrics clients using this class are expected to use 309 | log methods of the class to log metrics with the difference being each log 310 | method takes a slice to associate the ClassificationMetrics. 311 | 312 | """ 313 | 314 | TYPE_NAME = 'system.SlicedClassificationMetrics' 315 | 316 | def __init__(self, 317 | name: Optional[str] = None, 318 | uri: Optional[str] = None, 319 | metadata: Optional[Dict] = None): 320 | super().__init__(uri=uri, name=name, metadata=metadata) 321 | 322 | def _upsert_classification_metrics_for_slice(self, slice: str): 323 | """Upserts the classification metrics instance for a slice.""" 324 | if slice not in self._sliced_metrics: 325 | self._sliced_metrics[slice] = ClassificationMetrics() 326 | 327 | def _update_metadata(self, slice: str): 328 | """Updates metadata to adhere to the metrics schema.""" 329 | self.metadata = {} 330 | self.metadata['evaluationSlices'] = [] 331 | for slice in self._sliced_metrics.keys(): 332 | slice_metrics = { 333 | 'slice': slice, 334 | 'sliceClassificationMetrics': self._sliced_metrics[slice].metadata 335 | } 336 | self.metadata['evaluationSlices'].append(slice_metrics) 337 | 338 | def log_roc_reading(self, slice: str, threshold: float, tpr: float, 339 | fpr: float): 340 | """Logs a single data point in the ROC Curve of a slice. 341 | 342 | Args: 343 | slice: String representing slice label. 344 | threshold: Thresold value for the data point. 345 | tpr: True positive rate value of the data point. 346 | fpr: False positive rate value of the data point. 347 | """ 348 | 349 | self._upsert_classification_metrics_for_slice(slice) 350 | self._sliced_metrics[slice].log_roc_reading(threshold, tpr, fpr) 351 | self._update_metadata(slice) 352 | 353 | def load_roc_readings(self, slice: str, readings: List[List[float]]): 354 | """Supports bulk loading ROC Curve readings for a slice. 355 | 356 | Args: 357 | slice: String representing slice label. 358 | readings: A 2-D list providing ROC Curve data points. 
359 | The expected order of the data points is: threshold, 360 | true_positive_rate, false_positive_rate. 361 | """ 362 | self._upsert_classification_metrics_for_slice(slice) 363 | self._sliced_metrics[slice].load_roc_readings(readings) 364 | self._update_metadata(slice) 365 | 366 | def set_confusion_matrix_categories(self, slice: str, categories: List[str]): 367 | """Stores confusion matrix categories for a slice.. 368 | 369 | Categories are stored in the internal metrics_utils.ConfusionMatrix 370 | instance of the slice. 371 | 372 | Args: 373 | slice: String representing slice label. 374 | categories: List of strings specifying the categories. 375 | """ 376 | self._upsert_classification_metrics_for_slice(slice) 377 | self._sliced_metrics[slice].set_confusion_matrix_categories(categories) 378 | self._update_metadata(slice) 379 | 380 | def log_confusion_matrix_row(self, slice: str, row_category: str, 381 | row: List[int]): 382 | """Logs a confusion matrix row for a slice. 383 | 384 | Row is updated on the internal metrics_utils.ConfusionMatrix 385 | instance of the slice. 386 | 387 | Args: 388 | slice: String representing slice label. 389 | row_category: Category to which the row belongs. 390 | row: List of integers specifying the values for the row. 391 | """ 392 | self._upsert_classification_metrics_for_slice(slice) 393 | self._sliced_metrics[slice].log_confusion_matrix_row(row_category, row) 394 | self._update_metadata(slice) 395 | 396 | def log_confusion_matrix_cell(self, slice: str, row_category: str, 397 | col_category: str, value: int): 398 | """Logs a confusion matrix cell for a slice.. 399 | 400 | Cell is updated on the internal metrics_utils.ConfusionMatrix 401 | instance of the slice. 402 | 403 | Args: 404 | slice: String representing slice label. 405 | row_category: String representing the name of the row category. 406 | col_category: String representing the name of the column category. 407 | value: Int value of the cell. 408 | """ 409 | self._upsert_classification_metrics_for_slice(slice) 410 | self._sliced_metrics[slice].log_confusion_matrix_cell( 411 | row_category, col_category, value) 412 | self._update_metadata(slice) 413 | 414 | def load_confusion_matrix(self, slice: str, categories: List[str], 415 | matrix: List[List[int]]): 416 | """Supports bulk loading the whole confusion matrix for a slice. 417 | 418 | Args: 419 | slice: String representing slice label. 420 | categories: List of the category names. 421 | matrix: Complete confusion matrix. 422 | """ 423 | self._upsert_classification_metrics_for_slice(slice) 424 | self._sliced_metrics[slice].log_confusion_matrix_cell(categories, matrix) 425 | self._update_metadata(slice) 426 | 427 | 428 | T = TypeVar('T') 429 | 430 | 431 | class InputAnnotation(): 432 | """Marker type for input artifacts.""" 433 | pass 434 | 435 | 436 | 437 | class OutputAnnotation(): 438 | """Marker type for output artifacts.""" 439 | pass 440 | 441 | 442 | # TODO: Use typing.Annotated instead of this hack. 443 | # With typing.Annotated (Python 3.9+ or typing_extensions package), the 444 | # following would look like: 445 | # Input = typing.Annotated[T, InputAnnotation] 446 | # Output = typing.Annotated[T, OutputAnnotation] 447 | 448 | 449 | # Input represents an Input artifact of type T. 450 | Input = Union[T, InputAnnotation] 451 | 452 | # Output represents an Output artifact of type T. 
453 | Output = Union[T, OutputAnnotation] 454 | 455 | 456 | def is_artifact_annotation(typ) -> bool: 457 | if hasattr(typ, '_subs_tree'): # Python 3.6 458 | subs_tree = typ._subs_tree() 459 | return len(subs_tree) == 3 and subs_tree[0] == Union and subs_tree[2] in [InputAnnotation, OutputAnnotation] 460 | 461 | if not hasattr(typ, '__origin__'): 462 | return False 463 | 464 | 465 | if typ.__origin__ != Union and type(typ.__origin__) != type(Union): 466 | return False 467 | 468 | 469 | if not hasattr(typ, '__args__') or len(typ.__args__) != 2: 470 | return False 471 | 472 | if typ.__args__[1] not in [InputAnnotation, OutputAnnotation]: 473 | return False 474 | 475 | return True 476 | 477 | def is_input_artifact(typ) -> bool: 478 | """Returns True if typ is of type Input[T].""" 479 | if not is_artifact_annotation(typ): 480 | return False 481 | 482 | if hasattr(typ, '_subs_tree'): # Python 3.6 483 | subs_tree = typ._subs_tree() 484 | return len(subs_tree) == 3 and subs_tree[2] == InputAnnotation 485 | 486 | return typ.__args__[1] == InputAnnotation 487 | 488 | def is_output_artifact(typ) -> bool: 489 | """Returns True if typ is of type Output[T].""" 490 | if not is_artifact_annotation(typ): 491 | return False 492 | 493 | if hasattr(typ, '_subs_tree'): # Python 3.6 494 | subs_tree = typ._subs_tree() 495 | return len(subs_tree) == 3 and subs_tree[2] == OutputAnnotation 496 | 497 | return typ.__args__[1] == OutputAnnotation 498 | 499 | def get_io_artifact_class(typ): 500 | if not is_artifact_annotation(typ): 501 | return None 502 | if typ == Input or typ == Output: 503 | return None 504 | 505 | if hasattr(typ, '_subs_tree'): # Python 3.6 506 | subs_tree = typ._subs_tree() 507 | if len(subs_tree) != 3: 508 | return None 509 | return subs_tree[1] 510 | 511 | return typ.__args__[0] 512 | 513 | def get_io_artifact_annotation(typ): 514 | if not is_artifact_annotation(typ): 515 | return None 516 | 517 | if hasattr(typ, '_subs_tree'): # Python 3.6 518 | subs_tree = typ._subs_tree() 519 | if len(subs_tree) != 3: 520 | return None 521 | return subs_tree[2] 522 | 523 | return typ.__args__[1] 524 | 525 | 526 | 527 | _SCHEMA_TITLE_TO_TYPE: Dict[str, Artifact] = { 528 | x.TYPE_NAME: x 529 | for x in [Artifact, Model, Dataset, Metrics, ClassificationMetrics] 530 | } 531 | 532 | 533 | def create_runtime_artifact(runtime_artifact: Dict) -> Artifact: 534 | """Creates an Artifact instance from the specified RuntimeArtifact. 535 | 536 | Args: 537 | runtime_artifact: Dictionary representing JSON-encoded RuntimeArtifact. 
538 | """ 539 | schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '') 540 | 541 | artifact_type = _SCHEMA_TITLE_TO_TYPE.get(schema_title) 542 | if not artifact_type: 543 | artifact_type = Artifact 544 | return artifact_type( 545 | uri=runtime_artifact.get('uri', ''), 546 | name=runtime_artifact.get('name', ''), 547 | metadata=runtime_artifact.get('metadata', {}), 548 | ) 549 | 550 | class InputPath: 551 | '''When creating component from function, :class:`.InputPath` should be used as function parameter annotation to tell the system to pass the *data file path* to the function instead of passing the actual data.''' 552 | def __init__(self, type=None): 553 | self.type = type 554 | 555 | class OutputPath: 556 | '''When creating component from function, :class:`.OutputPath` should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a file with the given path instead of returning the data from the function.''' 557 | def __init__(self, type=None): 558 | self.type = type 559 | 560 | class Executor(): 561 | """Executor executes v2-based Python function components.""" 562 | 563 | def __init__(self, executor_input: Dict, function_to_execute: Callable): 564 | self._func = function_to_execute 565 | self._input = executor_input 566 | self._input_artifacts: Dict[str, Artifact] = {} 567 | self._output_artifacts: Dict[str, Artifact] = {} 568 | 569 | for name, artifacts in self._input.get('inputs', {}).get('artifacts', 570 | {}).items(): 571 | artifacts_list = artifacts.get('artifacts') 572 | if artifacts_list: 573 | self._input_artifacts[name] = self._make_input_artifact( 574 | artifacts_list[0]) 575 | 576 | for name, artifacts in self._input.get('outputs', {}).get('artifacts', 577 | {}).items(): 578 | artifacts_list = artifacts.get('artifacts') 579 | if artifacts_list: 580 | self._output_artifacts[name] = self._make_output_artifact( 581 | artifacts_list[0]) 582 | 583 | self._return_annotation = inspect.signature(self._func).return_annotation 584 | self._executor_output = {} 585 | 586 | @classmethod 587 | def _make_input_artifact(cls, runtime_artifact: Dict): 588 | return create_runtime_artifact(runtime_artifact) 589 | 590 | @classmethod 591 | def _make_output_artifact(cls, runtime_artifact: Dict): 592 | import os 593 | artifact = create_runtime_artifact(runtime_artifact) 594 | os.makedirs(os.path.dirname(artifact.path), exist_ok=True) 595 | return artifact 596 | 597 | def _get_input_artifact(self, name: str): 598 | return self._input_artifacts.get(name) 599 | 600 | def _get_output_artifact(self, name: str): 601 | return self._output_artifacts.get(name) 602 | 603 | def _get_input_parameter_value(self, parameter_name: str, parameter_type: Any): 604 | parameter = self._input.get('inputs', {}).get('parameters', 605 | {}).get(parameter_name, None) 606 | if parameter is None: 607 | return None 608 | 609 | if parameter.get('stringValue'): 610 | if parameter_type == str: 611 | return parameter['stringValue'] 612 | elif parameter_type == bool: 613 | # Use `.lower()` so it can also handle 'True' and 'False' (resulted from 614 | # `str(True)` and `str(False)`, respectively. 
615 | return json.loads(parameter['stringValue'].lower()) 616 | else: 617 | return json.loads(parameter['stringValue']) 618 | elif parameter.get('intValue'): 619 | return int(parameter['intValue']) 620 | elif parameter.get('doubleValue'): 621 | return float(parameter['doubleValue']) 622 | 623 | def _get_output_parameter_path(self, parameter_name: str): 624 | parameter_name = self._maybe_strip_path_suffix(parameter_name) 625 | parameter = self._input.get('outputs', 626 | {}).get('parameters', 627 | {}).get(parameter_name, None) 628 | if parameter is None: 629 | return None 630 | 631 | import os 632 | path = parameter.get('outputFile', None) 633 | if path: 634 | os.makedirs(os.path.dirname(path), exist_ok=True) 635 | return path 636 | 637 | def _get_output_artifact_path(self, artifact_name: str): 638 | artifact_name = self._maybe_strip_path_suffix(artifact_name) 639 | output_artifact = self._output_artifacts.get(artifact_name) 640 | if not output_artifact: 641 | raise ValueError( 642 | 'Failed to get output artifact path for artifact name {}'.format( 643 | artifact_name)) 644 | return output_artifact.path 645 | 646 | def _get_input_artifact_path(self, artifact_name: str): 647 | artifact_name = self._maybe_strip_path_suffix(artifact_name) 648 | input_artifact = self._input_artifacts.get(artifact_name) 649 | if not input_artifact: 650 | raise ValueError( 651 | 'Failed to get input artifact path for artifact name {}'.format( 652 | artifact_name)) 653 | return input_artifact.path 654 | 655 | def _write_output_parameter_value(self, name: str, 656 | value: Union[str, int, float, bool, dict, 657 | list, Dict, List]): 658 | if type(value) == str: 659 | output = {'stringValue': value} 660 | elif type(value) == int: 661 | output = {'intValue': value} 662 | elif type(value) == float: 663 | output = {'doubleValue': value} 664 | else: 665 | # For bool, list, dict, List, Dict, json serialize the value. 666 | output = {'stringValue': json.dumps(value)} 667 | 668 | if not self._executor_output.get('parameters'): 669 | self._executor_output['parameters'] = {} 670 | 671 | self._executor_output['parameters'][name] = output 672 | 673 | def _write_output_artifact_payload(self, name: str, value: Any): 674 | path = self._get_output_artifact_path(name) 675 | with open(path, 'w') as f: 676 | f.write(str(value)) 677 | 678 | # TODO: extract to a util 679 | @classmethod 680 | def _get_short_type_name(cls, type_name: str) -> str: 681 | """Extracts the short form type name. 682 | 683 | This method is used for looking up serializer for a given type. 684 | 685 | For example: 686 | typing.List -> List 687 | typing.List[int] -> List 688 | typing.Dict[str, str] -> Dict 689 | List -> List 690 | str -> str 691 | 692 | Args: 693 | type_name: The original type name. 694 | 695 | Returns: 696 | The short form type name or the original name if pattern doesn't match. 697 | """ 698 | import re 699 | match = re.match('(typing\.)?(?P\w+)(?:\[.+\])?', type_name) 700 | if match: 701 | return match.group('type') 702 | else: 703 | return type_name 704 | 705 | # TODO: merge with type_utils.is_parameter_type 706 | @classmethod 707 | def _is_parameter(cls, annotation: Any) -> bool: 708 | if type(annotation) == type: 709 | return annotation in [str, int, float, bool, dict, list] 710 | 711 | # Annotation could be, for instance `typing.Dict[str, str]`, etc. 
712 | return cls._get_short_type_name(str(annotation)) in ['Dict', 'List'] 713 | 714 | @classmethod 715 | def _is_artifact(cls, annotation: Any) -> bool: 716 | if type(annotation) == type: 717 | return issubclass(annotation, Artifact) 718 | return False 719 | 720 | @classmethod 721 | def _is_named_tuple(cls, annotation: Any) -> bool: 722 | if type(annotation) == type: 723 | return issubclass(annotation, tuple) and hasattr( 724 | annotation, '_fields') and hasattr(annotation, '__annotations__') 725 | return False 726 | 727 | def _handle_single_return_value(self, output_name: str, annotation_type: Any, 728 | return_value: Any): 729 | if self._is_parameter(annotation_type): 730 | if type(return_value) != annotation_type: 731 | raise ValueError( 732 | 'Function `{}` returned value of type {}; want type {}'.format( 733 | self._func.__name__, type(return_value), annotation_type)) 734 | self._write_output_parameter_value(output_name, return_value) 735 | elif self._is_artifact(annotation_type): 736 | self._write_output_artifact_payload(output_name, return_value) 737 | else: 738 | raise RuntimeError( 739 | 'Unknown return type: {}. Must be one of `str`, `int`, `float`, or a' 740 | ' subclass of `Artifact`'.format(annotation_type)) 741 | 742 | def _write_executor_output(self, func_output: Optional[Any] = None): 743 | if self._output_artifacts: 744 | self._executor_output['artifacts'] = {} 745 | 746 | for name, artifact in self._output_artifacts.items(): 747 | runtime_artifact = { 748 | 'name': artifact.name, 749 | 'uri': artifact.uri, 750 | 'metadata': artifact.metadata, 751 | } 752 | artifacts_list = {'artifacts': [runtime_artifact]} 753 | 754 | self._executor_output['artifacts'][name] = artifacts_list 755 | 756 | if func_output is not None: 757 | if self._is_parameter(self._return_annotation) or self._is_artifact( 758 | self._return_annotation): 759 | # Note: single output is named `Output` in component.yaml. 760 | self._handle_single_return_value('Output', self._return_annotation, 761 | func_output) 762 | elif self._is_named_tuple(self._return_annotation): 763 | if len(self._return_annotation._fields) != len(func_output): 764 | raise RuntimeError( 765 | 'Expected {} return values from function `{}`, got {}'.format( 766 | len(self._return_annotation._fields), self._func.__name__, 767 | len(func_output))) 768 | for i in range(len(self._return_annotation._fields)): 769 | field = self._return_annotation._fields[i] 770 | field_type = self._return_annotation.__annotations__[field] 771 | if type(func_output) == tuple: 772 | field_value = func_output[i] 773 | else: 774 | field_value = getattr(func_output, field) 775 | self._handle_single_return_value(field, field_type, field_value) 776 | else: 777 | raise RuntimeError( 778 | 'Unknown return type: {}. Must be one of `str`, `int`, `float`, a' 779 | ' subclass of `Artifact`, or a NamedTuple collection of these types.' 780 | .format(self._return_annotation)) 781 | 782 | import os 783 | os.makedirs( 784 | os.path.dirname(self._input['outputs']['outputFile']), exist_ok=True) 785 | with open(self._input['outputs']['outputFile'], 'w') as f: 786 | f.write(json.dumps(self._executor_output)) 787 | 788 | def _maybe_strip_path_suffix(self, name) -> str: 789 | if name.endswith('_path'): 790 | name = name[0:-len('_path')] 791 | if name.endswith('_file'): 792 | name = name[0:-len('_file')] 793 | return name 794 | 795 | def execute(self): 796 | annotations = inspect.getfullargspec(self._func).annotations 797 | 798 | # Function arguments. 
799 | func_kwargs = {} 800 | 801 | for k, v in annotations.items(): 802 | if k == 'return': 803 | continue 804 | 805 | if self._is_parameter(v): 806 | func_kwargs[k] = self._get_input_parameter_value(k, v) 807 | 808 | if is_artifact_annotation(v): 809 | if is_input_artifact(v): 810 | func_kwargs[k] = self._get_input_artifact(k) 811 | if is_output_artifact(v): 812 | func_kwargs[k] = self._get_output_artifact(k) 813 | 814 | elif isinstance(v, OutputPath): 815 | if self._is_parameter(v.type): 816 | func_kwargs[k] = self._get_output_parameter_path(k) 817 | else: 818 | func_kwargs[k] = self._get_output_artifact_path(k) 819 | elif isinstance(v, InputPath): 820 | func_kwargs[k] = self._get_input_artifact_path(k) 821 | 822 | result = self._func(**func_kwargs) 823 | self._write_executor_output(result) 824 | 825 | 826 | def preprocess(output_csv_path: OutputPath('CSV')): 827 | #1 828 | from google.cloud import bigquery 829 | import google.auth 830 | 831 | creds, project = google.auth.default() 832 | client = bigquery.Client(project='kubeflow-demos', credentials=creds) 833 | 834 | query = """ 835 | SELECT * FROM `kubeflow-demos.telco.churn` 836 | """ 837 | print(query) 838 | 839 | dataframe = client.query(query).to_dataframe() 840 | print(dataframe.head()) 841 | 842 | dataframe.to_csv(output_csv_path) 843 | print("done") 844 | 845 | 846 | def executor_main(): 847 | import argparse 848 | import json 849 | 850 | parser = argparse.ArgumentParser(description='Process some integers.') 851 | parser.add_argument('--executor_input', type=str) 852 | parser.add_argument('--function_to_execute', type=str) 853 | 854 | args, _ = parser.parse_known_args() 855 | executor_input = json.loads(args.executor_input) 856 | function_to_execute = globals()[args.function_to_execute] 857 | 858 | executor = Executor(executor_input=executor_input, 859 | function_to_execute=function_to_execute) 860 | 861 | executor.execute() 862 | 863 | 864 | if __name__ == '__main__': 865 | executor_main() 866 | args: 867 | - --executor_input 868 | - {executorInput: null} 869 | - --function_to_execute 870 | - preprocess 871 | - --output-csv-output-path 872 | - {outputPath: output_csv} 873 | --------------------------------------------------------------------------------