├── custom_components ├── __init__.py ├── vertex_uploader.py ├── vertex_deployer.py ├── flower_mobilenet_trainer.py ├── firebase_publisher.py └── flower_densenet_trainer.py ├── figures ├── cloud_build_log.png ├── overall_workflow.png ├── sample_architecture.png └── README.md ├── notebooks ├── README.md ├── Dataset_Prep.ipynb ├── Dual_Deployments_With_AutoML.ipynb └── Custom_Model_TFX.ipynb ├── .gitignore ├── README.md └── LICENSE /custom_components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/cloud_build_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/cloud_build_log.png -------------------------------------------------------------------------------- /figures/overall_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/overall_workflow.png -------------------------------------------------------------------------------- /figures/sample_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/sample_architecture.png -------------------------------------------------------------------------------- /figures/README.md: -------------------------------------------------------------------------------- 1 | The beautiful figures contained in this directory have been prepared by [Chansung Park](https://github.com/deep-diver). 2 | -------------------------------------------------------------------------------- /custom_components/vertex_uploader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom TFX component for importing a model into Vertex AI. 
3 | Author: Sayak Paul 4 | Reference: https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai/blob/main/src/tfx_pipelines/components.py#L74 5 | """ 6 | 7 | import os 8 | import tensorflow as tf 9 | 10 | from tfx.dsl.component.experimental.decorators import component 11 | from tfx.dsl.component.experimental.annotations import Parameter 12 | from tfx.types.standard_artifacts import String 13 | from google.cloud import aiplatform as vertex_ai 14 | from tfx import v1 as tfx 15 | from absl import logging 16 | 17 | 18 | @component 19 | def VertexUploader( 20 | project: Parameter[str], 21 | region: Parameter[str], 22 | model_display_name: Parameter[str], 23 | pushed_model_location: Parameter[str], 24 | serving_image_uri: Parameter[str], 25 | uploaded_model: tfx.dsl.components.OutputArtifact[String], 26 | ): 27 | 28 | vertex_ai.init(project=project, location=region) 29 | 30 | pushed_model_dir = os.path.join( 31 | pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1] 32 | ) 33 | 34 | logging.info(f"Model registry location: {pushed_model_dir}") 35 | 36 | vertex_model = vertex_ai.Model.upload( 37 | display_name=model_display_name, 38 | artifact_uri=pushed_model_dir, 39 | serving_container_image_uri=serving_image_uri, 40 | parameters_schema_uri=None, 41 | instance_schema_uri=None, 42 | explanation_metadata=None, 43 | explanation_parameters=None, 44 | ) 45 | 46 | uploaded_model.set_string_custom_property( 47 | "model_resource_name", str(vertex_model.resource_name) 48 | ) 49 | logging.info(f"Model resource: {str(vertex_model.resource_name)}") 50 |
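# Note: `vertex_ai.Model.upload` blocks until the import completes and returns a
# `Model` resource. The custom property recorded above exposes the Vertex
# resource name to downstream steps, while the companion deployer component
# (below) looks the model up again by its display name via `vertex_ai.Model.list`.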
-------------------------------------------------------------------------------- /custom_components/vertex_deployer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom TFX component for deploying a model to a Vertex AI Endpoint. 3 | Author: Sayak Paul 4 | Reference: https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai/blob/main/build/utils.py#L97 5 | """ 6 | 7 | from tfx.dsl.component.experimental.decorators import component 8 | from tfx.dsl.component.experimental.annotations import Parameter 9 | from google.cloud import aiplatform as vertex_ai 10 | from absl import logging 11 | 12 | 13 | @component 14 | def VertexDeployer( 15 | project: Parameter[str], 16 | region: Parameter[str], 17 | model_display_name: Parameter[str], 18 | deployed_model_display_name: Parameter[str], 19 | ): 20 | 21 | logging.info(f"Endpoint display name: {deployed_model_display_name}") 22 | vertex_ai.init(project=project, location=region) 23 | 24 | # Reuse the endpoint if it already exists; otherwise create a fresh one. 25 | endpoints = vertex_ai.Endpoint.list( 26 | filter=f"display_name={deployed_model_display_name}", order_by="update_time" 27 | ) 28 | if len(endpoints) > 0: 29 | logging.info(f"Endpoint {deployed_model_display_name} already exists.") 30 | endpoint = endpoints[-1] 31 | else: 32 | endpoint = vertex_ai.Endpoint.create(deployed_model_display_name) 33 | 34 | # Deploy the most recently updated model carrying the given display name. 35 | model = vertex_ai.Model.list( 36 | filter=f"display_name={model_display_name}", order_by="update_time" 37 | )[-1] 38 | 39 | # Note: `Endpoint.deploy()` mutates the endpoint in place and returns None. 40 | endpoint.deploy( 41 | model=model, 42 | # Syntax from here: https://git.io/JBQDP 43 | traffic_split={"0": 100}, 44 | machine_type="n1-standard-4", 45 | min_replica_count=1, 46 | max_replica_count=1, 47 | ) 48 | 49 | logging.info(f"Model {model_display_name} deployed to endpoint: {deployed_model_display_name}") 50 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | We expose our notebooks, which are fully runnable on [Google Colab](https://colab.research.google.com/), to facilitate easier learning. Below 2 | you can find short descriptions of the individual notebooks present here. 3 | 4 | ## Notebooks 5 | 6 | | Notebook | Colab | Description | 7 | | :--------: | :-------: | :--------------------------------: | 8 | | `Dataset_Prep.ipynb` | Open In Colab | Prepares dataset for AutoML. | 9 | | `Dual_Deployments_With_AutoML.ipynb` | Open In Colab | Shows how to build a Kubeflow Pipeline to train and deploy two different models using AutoML and Vertex AI. | 10 | | `Model_Tests.ipynb` | Open In Colab | Shows how to test the models trained using the notebook above in a standalone manner. | 11 | | `Custom_Model_TFX.ipynb` | Open In Colab | Shows how to build a TFX pipeline using custom components to train and deploy two different models and run them using Vertex AI. | 12 | 13 | ## References 14 | 15 | * [Kubeflow Pipelines](https://www.kubeflow.org/docs/components/pipelines/overview/pipelines-overview/) 16 | * [AutoML SDKs from Kubeflow](https://google-cloud-pipeline-components.readthedocs.io/en/latest/google_cloud_pipeline_components.aiplatform.html#module-google_cloud_pipeline_components.aiplatform) 17 | * [Vertex AI Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines) 18 | * [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx) 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | --------------------------------------------------------------------------------
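The two trainer modules that follow are plugged into a TFX `Trainer` component through its `module_file` argument. A minimal sketch of that wiring (the `example_gen` component, GCS paths, and step counts are placeholders, not values taken from the notebooks):

```python
from tfx import v1 as tfx

# `example_gen` is assumed to be an ImportExampleGen emitting the GZIP-compressed
# TFRecords that the `_input_fn` in each module expects.
trainer = tfx.components.Trainer(
    module_file="gs://my-bucket/modules/flower_densenet_trainer.py",  # or the MobileNet module
    examples=example_gen.outputs["examples"],
    train_args=tfx.proto.TrainArgs(num_steps=160),
    eval_args=tfx.proto.EvalArgs(num_steps=4),
)
```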
/custom_components/flower_mobilenet_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from absl import logging 3 | from tensorflow import keras 4 | from tfx import v1 as tfx 5 | import tensorflow as tf 6 | 7 | 8 | _IMAGE_FEATURES = { 9 | "image": tf.io.FixedLenFeature([], tf.string), 10 | "class": tf.io.FixedLenFeature([], tf.int64), 11 | "one_hot_class": tf.io.VarLenFeature(tf.float32), 12 | } 13 | 14 | _INPUT_SHAPE = (224, 224, 3) 15 | _TRAIN_BATCH_SIZE = 64 16 | _EVAL_BATCH_SIZE = 64 17 | _EPOCHS = 2 18 | 19 | 20 | def _parse_fn(example): 21 | example = tf.io.parse_single_example(example, _IMAGE_FEATURES) 22 | image = tf.image.decode_jpeg(example["image"], channels=3) 23 | class_label = tf.cast(example["class"], tf.int32) 24 | return image, class_label 25 | 26 | 27 | def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset: 28 | """Generates features and label for training. 29 | 30 | Args: 31 | file_pattern: List of paths or patterns of input tfrecord files. 32 | batch_size: representing the number of consecutive elements of returned 33 | dataset to combine in a single batch. 34 | 35 | Returns: 36 | A dataset that contains (features, indices) tuple where features is a 37 | dictionary of Tensors, and indices is a single Tensor of label indices. 38 | """ 39 | logging.info(f"Reading data from: {file_pattern}") 40 | tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + ".gz") 41 | dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP") 42 | dataset = dataset.map(_parse_fn).batch(batch_size) 43 | return dataset.repeat() 44 | 45 | 46 | def _make_keras_model() -> tf.keras.Model: 47 | """Creates a MobileNetV3-based model for classifying flowers data. 48 | 49 | Returns: 50 | A Keras Model. 51 | """ 52 | inputs = keras.Input(shape=_INPUT_SHAPE) 53 | base_model = keras.applications.MobileNetV3Small( 54 | include_top=False, input_shape=_INPUT_SHAPE, pooling="avg" 55 | ) 56 | base_model.trainable = False 57 | x = keras.applications.mobilenet_v3.preprocess_input(inputs) 58 | x = base_model( 59 | x, training=False 60 | ) # Ensures BatchNorm runs in inference mode in this model 61 | outputs = keras.layers.Dense(5, activation="softmax")(x) 62 | model = keras.Model(inputs, outputs) 63 | 64 | model.compile( 65 | optimizer=keras.optimizers.Adam(), 66 | loss=tf.keras.losses.SparseCategoricalCrossentropy(), 67 | metrics=[keras.metrics.SparseCategoricalAccuracy()], 68 | ) 69 | 70 | model.summary(print_fn=logging.info) 71 | return model 72 | 73 | 74 | # TFX Trainer will call this function. 75 | def run_fn(fn_args: tfx.components.FnArgs): 76 | """Train the model based on given args. 77 | 78 | Args: 79 | fn_args: Holds args used to train the model as name/value pairs. 80 | """ 81 | train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE) 82 | eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE) 83 | 84 | model = _make_keras_model() 85 | model.fit( 86 | train_dataset, 87 | steps_per_epoch=fn_args.train_steps, 88 | validation_data=eval_dataset, 89 | validation_steps=fn_args.eval_steps, 90 | epochs=_EPOCHS, 91 | ) 92 | _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps) 93 | logging.info(f"Validation accuracy: {round(acc * 100, 2)}%") 94 | 95 | # Convert the model. 96 | converter = tf.lite.TFLiteConverter.from_keras_model(model) 97 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 98 | tflite_model = converter.convert() 99 | 100 | # Save the model. 101 | # The result of the training should be saved in the `fn_args.serving_model_dir` directory. 102 | with tf.io.gfile.GFile(fn_args.serving_model_dir + "/model.tflite", "wb") as f: 103 | f.write(tflite_model) 104 |
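# Note on the conversion above: with `tf.lite.Optimize.DEFAULT` and no
# representative dataset, the TFLite converter applies post-training
# dynamic-range quantization (weights stored in 8 bits, activations kept in
# float), which shrinks the MobileNet for on-device use via Firebase ML.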
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dual-Deployments-on-Vertex-AI 2 | 3 | _By [Chansung Park](https://github.com/deep-diver) and Sayak Paul_ 4 | 5 | This project demonstrates a workflow to cover dual model deployment scenarios using [Kubeflow](https://www.kubeflow.org/), 6 | [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx), and [Vertex AI](https://cloud.google.com/vertex-ai). We suggest 7 | reading the accompanying [blog post](https://cloud.google.com/blog/topics/developers-practitioners/dual-deployments-vertex-ai) first 8 | to get an idea and then following along with the code. This project also received the [#TFCommunitySpotlight Award](https://twitter.com/TensorFlow/status/1446611368078086144?s=20). 9 | 10 | ## Motivation 💻 11 | 12 | Let's say you want to allow your users to run an application in both online and offline modes. Your mobile 13 | application would use a TFLite model when network bandwidth or battery is limited, and if sufficient 14 | network coverage/internet bandwidth is available, it would instead use the model deployed on the cloud. This way 15 | your application stays resilient and can ensure high availability. 16 | 17 | Sometimes we also run layered predictions where we first divide a problem into smaller tasks: 18 | 1) predict a coarse yes/no outcome, 19 | 2) depending on the output of 1), run the final model. 20 | 21 | In these cases, 1) takes place on-device and 2) takes place on the cloud to ensure a smooth UX. Furthermore, it's 22 | a good practice to use a mobile-friendly network architecture (such as MobileNet) when considering 23 | mobile deployments. This leads us to the following question: 24 | 25 | _**Can we train two different models within the same deployment pipeline and manage them seamlessly?**_ 26 | 27 | This project is motivated by this question. 28 | 29 | ## AutoML, TFX, etc. 🛠 30 | 31 | Different organizations have people with varied technical backgrounds. We wanted to provide the easiest solution first 32 | and then move on to something that is more customizable. To this end, we leverage [Kubeflow's AutoML SDKs](https://github.com/kubeflow/pipelines/tree/master/components/google-cloud) to build, train, and deploy models with 33 | different production use-cases. With AutoML, developers can delegate a large part of their workflows to the SDKs 34 | and the codebase also stays comparatively smaller. The figure below depicts a sample system architecture for 35 | this scenario: 36 | 37 | ![](figures/sample_architecture.png) 38 | 39 | **Figure developed by Chansung Park.** 40 |
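To make the AutoML path more concrete, here is a heavily condensed sketch of such a pipeline (see `Dual_Deployments_With_AutoML.ipynb` for the real one). It assumes the AutoML image operators from `google-cloud-pipeline-components`, whose exact signatures vary across releases; the project, bucket, and display names below are placeholders:

```python
from kfp.v2 import dsl
from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip


@dsl.pipeline(name="flowers-dual-deployment")
def pipeline(project: str = "my-gcp-project"):
    # Managed dataset created from the CSV prepared in `Dataset_Prep.ipynb`.
    dataset_op = gcc_aip.ImageDatasetCreateOp(
        project=project,
        display_name="flowers",
        gcs_source="gs://my-bucket/flowers_vertex.csv",
        import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
    )
    # Two AutoML trainings off the same dataset: a cloud-hosted model and a
    # mobile (TFLite-exportable) model.
    cloud_model_op = gcc_aip.AutoMLImageTrainingJobRunOp(
        project=project,
        display_name="flowers-cloud",
        prediction_type="classification",
        model_type="CLOUD",
        dataset=dataset_op.outputs["dataset"],
        model_display_name="flowers-cloud",
        budget_milli_node_hours=8000,
    )
    mobile_model_op = gcc_aip.AutoMLImageTrainingJobRunOp(
        project=project,
        display_name="flowers-mobile",
        prediction_type="classification",
        model_type="MOBILE_TF_VERSATILE_1",
        dataset=dataset_op.outputs["dataset"],
        model_display_name="flowers-mobile",
        budget_milli_node_hours=8000,
    )
    # The cloud model goes to a Vertex AI Endpoint; the mobile model is
    # exported as TFLite and published to Firebase by a custom component.
    gcc_aip.ModelDeployOp(model=cloud_model_op.outputs["model"])
```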
The figure below depicts a sample system architecture for 35 | this scenario: 36 | 37 | ![](figures/sample_architecture.png) 38 | 39 | **Figure developed by Chansung Park.** 40 | 41 | But the story does not end here. What if we wanted to have better control over the models to be built, trained, 42 | and deployed? Enter TFX! TFX provides the flexibility of writing custom components and including them inside a 43 | pipeline. This way Machine Learning Engineers can focus on building and training their favorite models and delegate 44 | a part of the heavy lifting to TFX and Vertex AI. On Vertex AI (acting as an orchestrator) this pipeline will look like 45 | so: 46 | 47 | ![](https://i.ibb.co/98Ry74n/Screen-Shot-2021-08-06-at-1-43-35-AM.png) 48 | 49 | ```txt 50 | 🔥 In this project we cover both these situations. 51 | ``` 52 | 53 | ## Code 🆘 54 | 55 | Our code is distributed as Colab Notebooks. But one needs to have a billing-enabled GCP account 56 | (with a few APIs enabled) to successfully run these notebooks. Alternatively one can also use the 57 | notebooks on [Vertex AI Notebooks](https://cloud.google.com/vertex-ai/docs/general/notebooks). Find 58 | all the notebooks and their descriptions here: 59 | [`notebooks`](https://github.com/sayakpaul/Dual-Deployments-on-Vertex-AI/tree/main/notebooks). 60 | 61 | Additionally, you can find the custom TFX components separately here - [`custom_components`](https://github.com/sayakpaul/Dual-Deployments-on-Vertex-AI/tree/main/custom_components). 62 | 63 | ## Acknowledgements 64 | 65 | [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits. Thanks to [Karl Weinmeister](https://twitter.com/kweinmeister?lang=hr) and [Robert Crowe](https://twitter.com/robert_crowe?lang=en) for providing review feedback on this project. 66 | -------------------------------------------------------------------------------- /custom_components/firebase_publisher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom TFX component for Firebase upload. 3 | Author: Chansung Park 4 | """ 5 | 6 | from tfx import types 7 | from tfx.dsl.component.experimental.decorators import component 8 | from tfx.dsl.component.experimental.annotations import Parameter 9 | from tfx import v1 as tfx 10 | from absl import logging 11 | 12 | import firebase_admin 13 | from firebase_admin import ml 14 | from firebase_admin import storage 15 | from firebase_admin import credentials 16 | from google.cloud import storage as gcs_storage 17 | 18 | 19 | @component 20 | def FirebasePublisher( 21 | pushed_model: tfx.dsl.components.InputArtifact[ 22 | tfx.types.standard_artifacts.PushedModel 23 | ], 24 | credential_uri: Parameter[str], 25 | firebase_dest_gcs_bucket: Parameter[str], 26 | model_display_name: Parameter[str], 27 | model_tag: Parameter[str], 28 | ) -> tfx.dsl.components.OutputDict(result=str): 29 | """ 30 | publish trained tflite model to Firebase ML, this component assumes that 31 | trained model and Firebase credential files are stored in GCS locations. 32 | 33 | Args: 34 | pushed_model: The URI of pushed model obtained from previous component (i.e. Pusher) 35 | credential_uri: The URI of Firebase credential. In order to get one, go to Firebase dashboard 36 | and on the Settings page, create a service account and download the service account key file. 37 | Keep this file safe, since it grants administrator access to your project. 
63 | ## Acknowledgements 64 | 65 | [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits. Thanks to [Karl Weinmeister](https://twitter.com/kweinmeister?lang=hr) and [Robert Crowe](https://twitter.com/robert_crowe?lang=en) for providing review feedback on this project. -------------------------------------------------------------------------------- /custom_components/firebase_publisher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom TFX component for Firebase upload. 3 | Author: Chansung Park 4 | """ 5 | 6 | from tfx import types 7 | from tfx.dsl.component.experimental.decorators import component 8 | from tfx.dsl.component.experimental.annotations import Parameter 9 | from tfx import v1 as tfx 10 | from absl import logging 11 | 12 | import firebase_admin 13 | from firebase_admin import ml 14 | from firebase_admin import storage 15 | from firebase_admin import credentials 16 | from google.cloud import storage as gcs_storage 17 | 18 | 19 | @component 20 | def FirebasePublisher( 21 | pushed_model: tfx.dsl.components.InputArtifact[ 22 | tfx.types.standard_artifacts.PushedModel 23 | ], 24 | credential_uri: Parameter[str], 25 | firebase_dest_gcs_bucket: Parameter[str], 26 | model_display_name: Parameter[str], 27 | model_tag: Parameter[str], 28 | ) -> tfx.dsl.components.OutputDict(result=str): 29 | """ 30 | Publishes a trained TFLite model to Firebase ML. This component assumes that 31 | the trained model and the Firebase credential file are stored in GCS locations. 32 | 33 | Args: 34 | pushed_model: The pushed model obtained from the previous component (i.e., Pusher) 35 | credential_uri: The GCS URI of the Firebase credential. In order to get one, go to the Firebase dashboard 36 | and on the Settings page, create a service account and download the service account key file. 37 | Keep this file safe, since it grants administrator access to your project. 38 | firebase_dest_gcs_bucket: GCS bucket where the model is going to be temporarily stored. 39 | In order to create one, go to the Firebase dashboard and on the Storage page, enable Cloud Storage. 40 | Take note of your bucket name. 41 | model_display_name: The name to appear on the Firebase ML dashboard 42 | model_tag: The tag name to appear on the Firebase ML dashboard 43 | """ 44 | 45 | model_uri = f"{pushed_model.uri}/model.tflite" 46 | 47 | assert model_uri.split("://")[0] == "gs" 48 | assert credential_uri.split("://")[0] == "gs" 49 | 50 | # create gcs client instance 51 | gcs_client = gcs_storage.Client() 52 | 53 | # get credential for firebase 54 | credential_gcs_bucket = credential_uri.split("//")[1].split("/")[0] 55 | credential_blob_path = "/".join(credential_uri.split("//")[1].split("/")[1:]) 56 | 57 | bucket = gcs_client.bucket(credential_gcs_bucket) 58 | blob = bucket.blob(credential_blob_path) 59 | blob.download_to_filename("credential.json") 60 | logging.info(f"downloaded credential.json from {credential_uri}") 61 | 62 | # get tflite model file 63 | tflite_gcs_bucket = model_uri.split("//")[1].split("/")[0] 64 | tflite_blob_path = "/".join(model_uri.split("//")[1].split("/")[1:]) 65 | 66 | bucket = gcs_client.bucket(tflite_gcs_bucket) 67 | blob = bucket.blob(tflite_blob_path) 68 | blob.download_to_filename("model.tflite") 69 | logging.info(f"downloaded model.tflite from {model_uri}") 70 | 71 | firebase_admin.initialize_app( 72 | credentials.Certificate("credential.json"), 73 | options={"storageBucket": firebase_dest_gcs_bucket}, 74 | ) 75 | logging.info("firebase_admin app initialization is completed") 76 | 77 | model_list = ml.list_models(list_filter=f"display_name={model_display_name}") 78 | # update 79 | if len(model_list.models) > 0: 80 | # get the first match model 81 | model = model_list.models[0] 82 | source = ml.TFLiteGCSModelSource.from_tflite_model_file("model.tflite") 83 | model.model_format = ml.TFLiteFormat(model_source=source) 84 | 85 | updated_model = ml.update_model(model) 86 | ml.publish_model(updated_model.model_id) 87 | 88 | logging.info("model exists, so update it in Firebase ML") 89 | return {"result": "model updated"} 90 | # create 91 | else: 92 | # load a tflite file and upload it to Cloud Storage 93 | source = ml.TFLiteGCSModelSource.from_tflite_model_file("model.tflite") 94 | 95 | # create the model object 96 | tflite_format = ml.TFLiteFormat(model_source=source) 97 | model = ml.Model( 98 | display_name=model_display_name, 99 | tags=[model_tag], 100 | model_format=tflite_format, 101 | ) 102 | 103 | # Add the model to your Firebase project and publish it 104 | new_model = ml.create_model(model) 105 | ml.publish_model(new_model.model_id) 106 | 107 | logging.info("model doesn't exist, so create one in Firebase ML") 108 | return {"result": "model created"} 109 |
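# Usage sketch (hypothetical values): in the TFX pipeline this component is
# attached after the Pusher that writes the TFLite model, e.g.
#
#     firebase_publisher = FirebasePublisher(
#         pushed_model=pusher.outputs["pushed_model"],
#         credential_uri="gs://my-bucket/credentials/firebase-key.json",
#         firebase_dest_gcs_bucket="my-project.appspot.com",
#         model_display_name="flowers-mobilenet",
#         model_tag="tflite",
#     ).with_id("firebase_publisher")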
-------------------------------------------------------------------------------- /custom_components/flower_densenet_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from absl import logging 3 | from tensorflow import keras 4 | from tfx import v1 as tfx 5 | import tensorflow as tf 6 | 7 | 8 | _IMAGE_FEATURES = { 9 | "image": tf.io.FixedLenFeature([], tf.string), 10 | "class": tf.io.FixedLenFeature([], tf.int64), 11 | "one_hot_class": tf.io.VarLenFeature(tf.float32), 12 | } 13 | 14 | _CONCRETE_INPUT = "numpy_inputs" 15 | _INPUT_SHAPE = (224, 224, 3) 16 | _TRAIN_BATCH_SIZE = 64 17 | _EVAL_BATCH_SIZE = 64 18 | _EPOCHS = 2 19 | 20 | 21 | def _parse_fn(example): 22 | example = tf.io.parse_single_example(example, _IMAGE_FEATURES) 23 | image = tf.image.decode_jpeg(example["image"], channels=3) 24 | class_label = tf.cast(example["class"], tf.int32) 25 | return image, class_label 26 | 27 | 28 | def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset: 29 | """Generates features and label for training. 30 | 31 | Args: 32 | file_pattern: List of paths or patterns of input tfrecord files. 33 | batch_size: representing the number of consecutive elements of returned 34 | dataset to combine in a single batch. 35 | 36 | Returns: 37 | A dataset that contains (features, indices) tuple where features is a 38 | dictionary of Tensors, and indices is a single Tensor of label indices. 39 | """ 40 | logging.info(f"Reading data from: {file_pattern}") 41 | tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + ".gz") 42 | dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP") 43 | dataset = dataset.map(_parse_fn).batch(batch_size) 44 | return dataset.repeat() 45 | 46 | 47 | def _make_keras_model() -> tf.keras.Model: 48 | """Creates a DenseNet121-based model for classifying flowers data. 49 | 50 | Returns: 51 | A Keras Model. 52 | """ 53 | inputs = keras.Input(shape=_INPUT_SHAPE) 54 | base_model = keras.applications.DenseNet121( 55 | include_top=False, input_shape=_INPUT_SHAPE, pooling="avg" 56 | ) 57 | base_model.trainable = False 58 | x = keras.applications.densenet.preprocess_input(inputs) 59 | x = base_model( 60 | x, training=False 61 | ) # Ensures BatchNorm runs in inference mode in this model 62 | outputs = keras.layers.Dense(5, activation="softmax")(x) 63 | model = keras.Model(inputs, outputs) 64 | 65 | model.compile( 66 | optimizer=keras.optimizers.Adam(), 67 | loss=tf.keras.losses.SparseCategoricalCrossentropy(), 68 | metrics=[keras.metrics.SparseCategoricalAccuracy()], 69 | ) 70 | 71 | model.summary(print_fn=logging.info) 72 | return model 73 | 74 | 75 | def _preprocess(bytes_input): 76 | decoded = tf.io.decode_jpeg(bytes_input, channels=3) 77 | resized = tf.image.resize(decoded, size=(224, 224)) 78 | return resized 79 | 80 | 81 | @tf.function(input_signature=[tf.TensorSpec([None], tf.string)]) 82 | def preprocess_fn(bytes_inputs): 83 | decoded_images = tf.map_fn( 84 | _preprocess, bytes_inputs, dtype=tf.float32, back_prop=False 85 | ) 86 | return {_CONCRETE_INPUT: decoded_images} 87 | 88 | 89 | def _model_exporter(model: tf.keras.Model): 90 | m_call = tf.function(model.call).get_concrete_function( 91 | [ 92 | tf.TensorSpec( 93 | shape=[None, 224, 224, 3], dtype=tf.float32, name=_CONCRETE_INPUT 94 | ) 95 | ] 96 | ) 97 | 98 | @tf.function(input_signature=[tf.TensorSpec([None], tf.string)]) 99 | def serving_fn(bytes_inputs): 100 | # This function comes from the Computer Vision book from O'Reilly. 101 | labels = tf.constant( 102 | ["daisy", "dandelion", "roses", "sunflowers", "tulips"], dtype=tf.string 103 | ) 104 | images = preprocess_fn(bytes_inputs) 105 | 106 | probs = m_call(**images) 107 | indices = tf.argmax(probs, axis=1) 108 | pred_source = tf.gather(params=labels, indices=indices) 109 | pred_confidence = tf.reduce_max(probs, axis=1) 110 | return {"label": pred_source, "confidence": pred_confidence} 111 | 112 | return serving_fn 113 | 114 | 115 | # TFX Trainer will call this function. 116 | def run_fn(fn_args: tfx.components.FnArgs): 117 | """Train the model based on given args.
118 | 119 | Args: 120 | fn_args: Holds args used to train the model as name/value pairs. 121 | """ 122 | train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE) 123 | eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE) 124 | 125 | model = _make_keras_model() 126 | model.fit( 127 | train_dataset, 128 | steps_per_epoch=fn_args.train_steps, 129 | validation_data=eval_dataset, 130 | validation_steps=fn_args.eval_steps, 131 | epochs=_EPOCHS, 132 | ) 133 | _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps) 134 | logging.info(f"Validation accuracy: {round(acc * 100, 2)}%") 135 | # The result of the training should be saved in `fn_args.serving_model_dir` 136 | # directory. 137 | tf.saved_model.save( 138 | model, 139 | fn_args.serving_model_dir, 140 | signatures={"serving_default": _model_exporter(model)}, 141 | ) 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /notebooks/Dataset_Prep.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Dataset_Prep", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "display_name": "Python 3", 12 | "name": "python3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "view-in-github", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "\"Open" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "zyOcSJS29gkH" 30 | }, 31 | "source": [ 32 | "In this notebook, we will use the [Flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) and create a `.csv` file out of it so that it can be imported into Vertex AI as a [managed dataset](https://cloud.google.com/vertex-ai/docs/training/using-managed-datasets). \n", 33 | "\n", 34 | "To proceed with the rest of the notebook you'd need a billing-enabled GCP account. " 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "id": "TRIz8jbQ-MUb" 41 | }, 42 | "source": [ 43 | "## Setup" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "metadata": { 49 | "id": "lIYdn1woOS1n" 50 | }, 51 | "source": [ 52 | "!gcloud init" 53 | ], 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "bG-NwjjB-ioI" 61 | }, 62 | "source": [ 63 | "## Download the original dataset and copy over to a GCS Bucket" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "dWUoW87xtuQ8" 70 | }, 71 | "source": [ 72 | "!wget -q https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\n", 73 | "!tar -xf flower_photos.tgz" 74 | ], 75 | "execution_count": null, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "U7nDmdxY-rFr" 82 | }, 83 | "source": [ 84 | "***If you have a spare bucket, then you can use that for this purpose.*** " 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "id": "eqldz8Jkz8se" 91 | }, 92 | "source": [ 93 | "#@title GCS\n", 94 | "#@markdown You should change these values as per your preferences. The copy operation can take ~5 minutes. \n", 95 | "BUCKET_PATH = \"gs://flowers-experimental\" #@param {type:\"string\"}\n", 96 | "REGION = \"us-central1\" #@param {type:\"string\"}\n", 97 | "\n", 98 | "!gsutil mb -l {REGION} {BUCKET_PATH}\n", 99 | "!gsutil -m cp -r flower_photos {BUCKET_PATH}" 100 | ], 101 | "execution_count": null, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "UX7Gw2_h-4Pk" 108 | }, 109 | "source": [ 110 | "Verify if the files were copied over." 
111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "colab": { 117 | "base_uri": "https://localhost:8080/" 118 | }, 119 | "id": "KNfP-pbowVkU", 120 | "outputId": "2ae0abcb-2ee8-4d22-af6b-16b6bc29c824" 121 | }, 122 | "source": [ 123 | "!gsutil ls {BUCKET_PATH}/flower_photos/" 124 | ], 125 | "execution_count": null, 126 | "outputs": [ 127 | { 128 | "output_type": "stream", 129 | "text": [ 130 | "gs://flowers-experimental/flower_photos/LICENSE.txt\n", 131 | "gs://flowers-experimental/flower_photos/daisy/\n", 132 | "gs://flowers-experimental/flower_photos/dandelion/\n", 133 | "gs://flowers-experimental/flower_photos/roses/\n", 134 | "gs://flowers-experimental/flower_photos/sunflowers/\n", 135 | "gs://flowers-experimental/flower_photos/tulips/\n" 136 | ], 137 | "name": "stdout" 138 | } 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": { 144 | "id": "ZPOBNLDS--Vu" 145 | }, 146 | "source": [ 147 | "## Imports" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "metadata": { 153 | "id": "_u3FHpaAw7He" 154 | }, 155 | "source": [ 156 | "import random\n", 157 | "random.seed(666)\n", 158 | "\n", 159 | "from google.cloud import storage\n", 160 | "from pprint import pprint\n", 161 | "import pandas as pd\n", 162 | "import os" 163 | ], 164 | "execution_count": null, 165 | "outputs": [] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "metadata": { 170 | "id": "XMN7RXwt_M20" 171 | }, 172 | "source": [ 173 | "from google.colab import auth\n", 174 | "auth.authenticate_user()" 175 | ], 176 | "execution_count": null, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": { 182 | "id": "utZIHk6p_a5X" 183 | }, 184 | "source": [ 185 | "## Preparing a single `.csv` file\n", 186 | "\n", 187 | "Vertex AI can import datasets in `.jsonl` and `.csv` formats. In this notebook, we will be using `.csv`. Here's the structure that Vertex AI expects ([reference](https://cloud.google.com/vertex-ai/docs/datasets/prepare-image#csv)):\n", 188 | "\n", 189 | "```\n", 190 | "[ML_USE],GCS_FILE_PATH,[LABEL]\n", 191 | "```\n", 192 | "\n", 193 | "`ML_USE` stands for the data split - `training`, `validation`, and `test`. 
" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": { 199 | "id": "d-1fItckAh0u" 200 | }, 201 | "source": [ 202 | "### Derive GCS URIs of the images" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "metadata": { 208 | "id": "0aQEzBF-zE9J", 209 | "colab": { 210 | "base_uri": "https://localhost:8080/" 211 | }, 212 | "outputId": "b9f23b24-2a74-49db-fe41-7f67392c32b1" 213 | }, 214 | "source": [ 215 | "gs_uris = []\n", 216 | "\n", 217 | "storage_client = storage.Client(project=\"fast-ai-exploration\") # Change it accordingly.\n", 218 | "blobs = storage_client.list_blobs(BUCKET_PATH.split(\"/\")[-1])\n", 219 | "\n", 220 | "for blob in blobs:\n", 221 | " if \".txt\" in blob.name.split(\"/\")[-1]:\n", 222 | " continue\n", 223 | " gs_uri = os.path.join(BUCKET_PATH, blob.name)\n", 224 | " gs_uris.append(gs_uri)\n", 225 | "\n", 226 | "pprint(gs_uris[:5])" 227 | ], 228 | "execution_count": null, 229 | "outputs": [ 230 | { 231 | "output_type": "stream", 232 | "text": [ 233 | "['gs://flowers-experimental/flower_photos/daisy/100080576_f52e8ee070_n.jpg',\n", 234 | " 'gs://flowers-experimental/flower_photos/daisy/10140303196_b88d3d6cec.jpg',\n", 235 | " 'gs://flowers-experimental/flower_photos/daisy/10172379554_b296050f82_n.jpg',\n", 236 | " 'gs://flowers-experimental/flower_photos/daisy/10172567486_2748826a8b.jpg',\n", 237 | " 'gs://flowers-experimental/flower_photos/daisy/10172636503_21bededa75_n.jpg']\n" 238 | ], 239 | "name": "stdout" 240 | } 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "id": "E116FS6sAleK" 247 | }, 248 | "source": [ 249 | "### Dataset splitting" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "metadata": { 255 | "colab": { 256 | "base_uri": "https://localhost:8080/" 257 | }, 258 | "id": "1gyb4tbaxMFz", 259 | "outputId": "743ea2ef-683d-48c1-f981-161c3d1eaa9c" 260 | }, 261 | "source": [ 262 | "# Create splits.\n", 263 | "random.shuffle(gs_uris)\n", 264 | "\n", 265 | "i = int(len(gs_uris) * 0.9)\n", 266 | "train_paths = gs_uris[:i]\n", 267 | "test_paths = gs_uris[i:]\n", 268 | "\n", 269 | "i = int(len(train_paths) * 0.05)\n", 270 | "valid_paths = train_paths[:i]\n", 271 | "train_paths = train_paths[i:]\n", 272 | "\n", 273 | "print(len(train_paths), len(valid_paths), len(test_paths))" 274 | ], 275 | "execution_count": null, 276 | "outputs": [ 277 | { 278 | "output_type": "stream", 279 | "text": [ 280 | "3138 165 367\n" 281 | ], 282 | "name": "stdout" 283 | } 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "id": "Xe9C4yYwA1Ih" 290 | }, 291 | "source": [ 292 | "### Utility for deriving the labels and `ML_USE`" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "id": "O4cg8hZS1Q1_" 299 | }, 300 | "source": [ 301 | "def derive_labels(gcs_paths, split=\"training\"):\n", 302 | " labels = []\n", 303 | " for gcs_path in gcs_paths:\n", 304 | " label = gcs_path.split(\"/\")[4]\n", 305 | " labels.append(label)\n", 306 | " return labels, [split] * len(gcs_paths)" 307 | ], 308 | "execution_count": null, 309 | "outputs": [] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": { 314 | "id": "eXdAM-fvA6fv" 315 | }, 316 | "source": [ 317 | "### Prepare the lists" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "metadata": { 323 | "id": "sQDtn-Ra16ZS" 324 | }, 325 | "source": [ 326 | "# File format is referred from: https://cloud.google.com/vertex-ai/docs/datasets/prepare-image#csv\n", 327 | "train_labels, train_use = 
derive_labels(train_paths)\n", 328 | "val_labels, val_use = derive_labels(valid_paths, split=\"validation\")\n", 329 | "test_labels, test_use= derive_labels(test_paths, split=\"test\")" 330 | ], 331 | "execution_count": null, 332 | "outputs": [] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": { 337 | "id": "7CYTVk0XBAKm" 338 | }, 339 | "source": [ 340 | "### Create `.csv` file" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "metadata": { 346 | "id": "lzBXCTox3jnJ" 347 | }, 348 | "source": [ 349 | "gcs_uris = []\n", 350 | "labels = []\n", 351 | "use = []\n", 352 | "\n", 353 | "gcs_uris.extend(train_paths)\n", 354 | "gcs_uris.extend(valid_paths)\n", 355 | "gcs_uris.extend(test_paths)\n", 356 | "\n", 357 | "labels.extend(train_labels)\n", 358 | "labels.extend(val_labels)\n", 359 | "labels.extend(test_labels)\n", 360 | "\n", 361 | "use.extend(train_use)\n", 362 | "use.extend(val_use)\n", 363 | "use.extend(test_use)" 364 | ], 365 | "execution_count": null, 366 | "outputs": [] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "metadata": { 371 | "id": "iKzg-1Rh4EQ8" 372 | }, 373 | "source": [ 374 | "import csv\n", 375 | "\n", 376 | "with open(\"flowers_vertex.csv\", \"w\") as csvfile: \n", 377 | " csvwriter = csv.writer(csvfile)\n", 378 | " \n", 379 | " for ml_use, gcs_uri, label in zip(use, gcs_uris, labels):\n", 380 | " row = [ml_use, gcs_uri, label]\n", 381 | " csvwriter.writerow(row) " 382 | ], 383 | "execution_count": null, 384 | "outputs": [] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "metadata": { 389 | "colab": { 390 | "base_uri": "https://localhost:8080/" 391 | }, 392 | "id": "7vuj-j-A4y_C", 393 | "outputId": "921eac5c-9abe-49ef-a017-efeb03c5e4ae" 394 | }, 395 | "source": [ 396 | "!head -5 flowers_vertex.csv" 397 | ], 398 | "execution_count": null, 399 | "outputs": [ 400 | { 401 | "output_type": "stream", 402 | "text": [ 403 | "training,gs://flowers-experimental/flower_photos/sunflowers/4895721788_f10208ab77_n.jpg,sunflowers\r\n", 404 | "training,gs://flowers-experimental/flower_photos/sunflowers/8202034834_ee0ee91e04_n.jpg,sunflowers\r\n", 405 | "training,gs://flowers-experimental/flower_photos/daisy/19019544592_b64469bf84_n.jpg,daisy\r\n", 406 | "training,gs://flowers-experimental/flower_photos/dandelion/4634716478_1cbcbee7ca.jpg,dandelion\r\n", 407 | "training,gs://flowers-experimental/flower_photos/tulips/12163418275_bd6a1edd61.jpg,tulips\r\n" 408 | ], 409 | "name": "stdout" 410 | } 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "metadata": { 416 | "colab": { 417 | "base_uri": "https://localhost:8080/" 418 | }, 419 | "id": "LSxVlje9BJ_Q", 420 | "outputId": "e045fcb1-8063-41d8-e207-5c1f6d8e8e54" 421 | }, 422 | "source": [ 423 | "!tail -5 flowers_vertex.csv" 424 | ], 425 | "execution_count": null, 426 | "outputs": [ 427 | { 428 | "output_type": "stream", 429 | "text": [ 430 | "test,gs://flowers-experimental/flower_photos/roses/6363951285_a802238d4e.jpg,roses\r\n", 431 | "test,gs://flowers-experimental/flower_photos/dandelion/4571923094_b9cefa9438_n.jpg,dandelion\r\n", 432 | "test,gs://flowers-experimental/flower_photos/roses/2471103806_87ba53d997_n.jpg,roses\r\n", 433 | "test,gs://flowers-experimental/flower_photos/roses/12238827553_cf427bfd51_n.jpg,roses\r\n", 434 | "test,gs://flowers-experimental/flower_photos/roses/3663244576_97f595cf4a.jpg,roses\r\n" 435 | ], 436 | "name": "stdout" 437 | } 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "id": "VgOMd-b-BQqM" 444 | }, 445 | "source": [ 446 | "## Copy 
over to a GCS Bucket" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "metadata": { 452 | "colab": { 453 | "base_uri": "https://localhost:8080/" 454 | }, 455 | "id": "JSTu3ETJBTxP", 456 | "outputId": "7a89ba6b-5568-4441-ceed-db477c02993b" 457 | }, 458 | "source": [ 459 | "!gsutil cp flowers_vertex.csv {BUCKET_PATH}" 460 | ], 461 | "execution_count": null, 462 | "outputs": [ 463 | { 464 | "output_type": "stream", 465 | "text": [ 466 | "Copying file://flowers_vertex.csv [Content-Type=text/csv]...\n", 467 | "-\n", 468 | "Operation completed over 1 objects/334.7 KiB. \n" 469 | ], 470 | "name": "stdout" 471 | } 472 | ] 473 | } 474 | ] 475 | } -------------------------------------------------------------------------------- /notebooks/Dual_Deployments_With_AutoML.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 5, 4 | "metadata": { 5 | "environment": { 6 | "name": "common-cpu.m75", 7 | "type": "gcloud", 8 | "uri": "gcr.io/deeplearning-platform-release/base-cpu:m75" 9 | }, 10 | "kernelspec": { 11 | "display_name": "Python 3", 12 | "language": "python", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "codemirror_mode": { 17 | "name": "ipython", 18 | "version": 3 19 | }, 20 | "file_extension": ".py", 21 | "mimetype": "text/x-python", 22 | "name": "python", 23 | "nbconvert_exporter": "python", 24 | "pygments_lexer": "ipython3", 25 | "version": "3.7.10" 26 | }, 27 | "colab": { 28 | "name": "dual_deployments_with_AutoML.ipynb", 29 | "provenance": [], 30 | "include_colab_link": true 31 | } 32 | }, 33 | "cells": [ 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "view-in-github", 38 | "colab_type": "text" 39 | }, 40 | "source": [ 41 | "\"Open" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "id": "XlRzusV8MDT7" 48 | }, 49 | "source": [ 50 | "This notebook shows how to build a dual deployment pipeline in Vertex AI (one for tflite model export to firebase, normal model to be deployed in Vertex AI Endpoint). 
\n", 52 | "Before proceeding, please make sure you meet the following prerequisites.\n", 53 | "- a billing-enabled GCP account\n", 54 | "- the Vertex AI, Cloud Storage, and Firebase APIs enabled\n", 55 | "- the roles of Vertex AI Custom Code Service Agent, Vertex AI Service Agent, Vertex AI User, AI Platform Notebooks Service Agent, Storage Object Admin, and Storage Object Creator added \n", 56 | "- follow the steps described in the `Before you begin` section of the official [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) documentation to enable Firebase" 57 | ], 58 | "id": "XlRzusV8MDT7" 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "61cd40e1" 64 | }, 65 | "source": [ 66 | "# Setup" 67 | ], 68 | "id": "61cd40e1" 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "0de28066" 74 | }, 75 | "source": [ 76 | "## Install packages\n", 77 | "- `firebase-admin` will be used to publish the TFLite model to Firebase" 78 | ], 79 | "id": "0de28066" 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "196d7830" 85 | }, 86 | "source": [ 87 | "import os\n", 88 | "\n", 89 | "# The Google Cloud Notebook product has specific requirements\n", 90 | "IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists(\"/opt/deeplearning/metadata/env_version\")\n", 91 | "\n", 92 | "# Google Cloud Notebook requires dependencies to be installed with '--user'\n", 93 | "USER_FLAG = \"\"\n", 94 | "if IS_GOOGLE_CLOUD_NOTEBOOK:\n", 95 | " USER_FLAG = \"--user\"" 96 | ], 97 | "id": "196d7830", 98 | "execution_count": null, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "metadata": { 104 | "id": "0c708f00" 105 | }, 106 | "source": [ 107 | "!pip3 install {USER_FLAG} google-cloud-aiplatform --upgrade\n", 108 | "!pip3 install {USER_FLAG} kfp google-cloud-pipeline-components --upgrade\n", 109 | "!pip3 install {USER_FLAG} firebase-admin" 110 | ], 111 | "id": "0c708f00", 112 | "execution_count": null, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "812c875e" 119 | }, 120 | "source": [ 121 | "## Restart Jupyter Notebook programmatically\n", 122 | "- so that the package installations are reflected" 123 | ], 124 | "id": "812c875e" 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "92fb4dd6" 130 | }, 131 | "source": [ 132 | "# Automatically restart kernel after installs\n", 133 | "import os\n", 134 | "\n", 135 | "if not os.getenv(\"IS_TESTING\"):\n", 136 | " # Automatically restart kernel after installs\n", 137 | " import IPython\n", 138 | "\n", 139 | " app = IPython.Application.instance()\n", 140 | " app.kernel.do_shutdown(True)" 141 | ], 142 | "id": "92fb4dd6", 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "5f1f65e9" 150 | }, 151 | "source": [ 152 | "## Check the KFP (Kubeflow Pipelines) version" 153 | ], 154 | "id": "5f1f65e9" 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "55234629", 160 | "outputId": "e5ebdc8d-9ed9-4057-b31c-1d460d4fce4f" 161 | }, 162 | "source": [ 163 | "!python3 -c \"import kfp; print('KFP SDK version: {}'.format(kfp.__version__))\"" 164 | ], 165 | "id": "55234629", 166 | "execution_count": null, 167 | "outputs": [ 168 | { 169 | "output_type": "stream", 170 | "text": [ 171 | "KFP SDK version: 1.6.6\n" 172 | ], 173 | "name": "stdout" 174 | } 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {
180 | "id": "c437adf1" 181 | }, 182 | "source": [ 183 | "## Setup GCP Project ID" 184 | ], 185 | "id": "c437adf1" 186 | }, 187 | { 188 | "cell_type": "code", 189 | "metadata": { 190 | "id": "3a4f1da1", 191 | "outputId": "0d06cedc-bf9a-4b49-87bc-9ecc52a069c9" 192 | }, 193 | "source": [ 194 | "import os\n", 195 | "\n", 196 | "PROJECT_ID = \"grounded-atrium-320207\"\n", 197 | "\n", 198 | "# Get your Google Cloud project ID from gcloud\n", 199 | "if not os.getenv(\"IS_TESTING\"):\n", 200 | " shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null\n", 201 | " PROJECT_ID = shell_output[0]\n", 202 | " print(\"Project ID: \", PROJECT_ID)" 203 | ], 204 | "id": "3a4f1da1", 205 | "execution_count": null, 206 | "outputs": [ 207 | { 208 | "output_type": "stream", 209 | "text": [ 210 | "Project ID: grounded-atrium-320207\n" 211 | ], 212 | "name": "stdout" 213 | } 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": { 219 | "id": "950c18ba" 220 | }, 221 | "source": [ 222 | "## GCP Authentication" 223 | ], 224 | "id": "950c18ba" 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "id": "2e58f1fd" 230 | }, 231 | "source": [ 232 | "import os\n", 233 | "import sys\n", 234 | "\n", 235 | "# If you are running this notebook in Colab, run this cell and follow the\n", 236 | "# instructions to authenticate your GCP account. This provides access to your\n", 237 | "# Cloud Storage bucket and lets you submit training jobs and prediction\n", 238 | "# requests.\n", 239 | "\n", 240 | "# The Google Cloud Notebook product has specific requirements\n", 241 | "IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists(\"/opt/deeplearning/metadata/env_version\")\n", 242 | "\n", 243 | "# If on Google Cloud Notebooks, then don't execute this code\n", 244 | "if not IS_GOOGLE_CLOUD_NOTEBOOK:\n", 245 | " if \"google.colab\" in sys.modules:\n", 246 | " from google.colab import auth as google_auth\n", 247 | "\n", 248 | " google_auth.authenticate_user()\n", 249 | "\n", 250 | " # If you are running this notebook locally, replace the string below with the\n", 251 | " # path to your service account key and run this cell to authenticate your GCP\n", 252 | " # account.\n", 253 | " elif not os.getenv(\"IS_TESTING\"):\n", 254 | " %env GOOGLE_APPLICATION_CREDENTIALS ''" 255 | ], 256 | "id": "2e58f1fd", 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": { 263 | "id": "7221dad0" 264 | }, 265 | "source": [ 266 | "## Setup GCS Bucket name\n", 267 | "- this bucket is where everything is going to stored" 268 | ], 269 | "id": "7221dad0" 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "6d49e084" 275 | }, 276 | "source": [ 277 | "BUCKET_NAME = \"gs://vertexai_dual_example\"\n", 278 | "REGION = \"us-central1\" " 279 | ], 280 | "id": "6d49e084", 281 | "execution_count": null, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": { 287 | "id": "b1690ebb" 288 | }, 289 | "source": [ 290 | "### Create GCS Bucket" 291 | ], 292 | "id": "b1690ebb" 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "f3e85ef1", 298 | "outputId": "2b2854e3-e372-473a-ae0f-97efe8d9521e" 299 | }, 300 | "source": [ 301 | "!gsutil mb -l $REGION $BUCKET_NAME" 302 | ], 303 | "id": "f3e85ef1", 304 | "execution_count": null, 305 | "outputs": [ 306 | { 307 | "output_type": "stream", 308 | "text": [ 309 | "Creating gs://vertexai_dual_example/...\n" 310 | ], 311 | "name": "stdout" 312 | } 313 | ] 314 
| }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "id": "5f1d553a" 319 | }, 320 | "source": [ 321 | "## Set up GCS Path for Pipeline\n", 322 | "- this is where the pipeline runs (i.e. metadata and artifacts) are going to be stored" 323 | ], 324 | "id": "5f1d553a" 325 | }, 326 | { 327 | "cell_type": "code", 328 | "metadata": { 329 | "id": "f0c61275", 330 | "outputId": "12d1b068-f9b6-4105-e151-ec5c935d4c4d" 331 | }, 332 | "source": [ 333 | "PATH=%env PATH\n", 334 | "%env PATH={PATH}:/home/jupyter/.local/bin\n", 335 | "\n", 336 | "USER = \"chansung\"\n", 337 | "PIPELINE_ROOT = \"{}/pipeline_root/{}\".format(BUCKET_NAME, USER)\n", 338 | "PIPELINE_ROOT" 339 | ], 340 | "id": "f0c61275", 341 | "execution_count": null, 342 | "outputs": [ 343 | { 344 | "output_type": "stream", 345 | "text": [ 346 | "env: PATH=/opt/conda/bin:/opt/conda/condabin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games:/home/jupyter/.local/bin\n" 347 | ], 348 | "name": "stdout" 349 | }, 350 | { 351 | "output_type": "execute_result", 352 | "data": { 353 | "text/plain": [ 354 | "'gs://vertexai_dual_example/pipeline_root/chansung'" 355 | ] 356 | }, 357 | "metadata": { 358 | "tags": [] 359 | }, 360 | "execution_count": 8 361 | } 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": { 367 | "id": "6986b40b" 368 | }, 369 | "source": [ 370 | "## Build Pipeline" 371 | ], 372 | "id": "6986b40b" 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": { 377 | "id": "f6ff09f9" 378 | }, 379 | "source": [ 380 | "### Import packages\n", 381 | "- KFP-related modules\n", 382 | "- the client API for AI Platform" 383 | ], 384 | "id": "f6ff09f9" 385 | }, 386 | { 387 | "cell_type": "code", 388 | "metadata": { 389 | "id": "9a75acd2" 390 | }, 391 | "source": [ 392 | "import kfp\n", 393 | "from google.cloud import aiplatform\n", 394 | "from google_cloud_pipeline_components import aiplatform as gcc_aip\n", 395 | "from kfp.v2 import compiler\n", 396 | "from kfp.v2.google.client import AIPlatformClient\n", 397 | "from kfp.v2 import dsl\n", 398 | "from kfp.v2.dsl import component" 399 | ], 400 | "id": "9a75acd2", 401 | "execution_count": null, 402 | "outputs": [] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": { 407 | "id": "e5e21652" 408 | }, 409 | "source": [ 410 | "### Define a pipeline component to publish the TFLite model to Firebase\n", 411 | "0. please follow the steps described in the `Before you begin` section of the official [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) documentation. \n", 412 | "1. download credentials for the Firebase project\n", 413 | "2. download TFLite model file\n", 414 | "3. initialize firebase admin\n", 415 | "4. 
upload and publish TFLite model from local file to Firebase" 416 | ], 417 | "id": "e5e21652" 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "57076987" 423 | }, 424 | "source": [ 425 | "@component(\n", 426 | " packages_to_install=[\"google-cloud-storage\", \"firebase-admin\", \"tensorflow\"]\n", 427 | ")\n", 428 | "def push_to_firebase(\n", 429 | " credential_uri: str,\n", 430 | " model_bucket: str,\n", 431 | " firebase_dest_gcs_bucket: str,\n", 432 | " model_display_name: str,\n", 433 | " model_tag: str\n", 434 | "):\n", 435 | " import firebase_admin\n", 436 | " from firebase_admin import ml\n", 437 | " from firebase_admin import storage\n", 438 | " from firebase_admin import credentials \n", 439 | " from google.cloud import storage as gcs_storage\n", 440 | " \n", 441 | " gcs_client = gcs_storage.Client()\n", 442 | " \n", 443 | " # get credential for firebase \n", 444 | " credential_gcs_bucket = credential_uri.split('//')[1].split('/')[0]\n", 445 | " credential_blob_path = '/'.join(credential_uri.split('//')[1].split('/')[1:])\n", 446 | " \n", 447 | " bucket = gcs_client.bucket(credential_gcs_bucket)\n", 448 | " blob = bucket.blob(credential_blob_path)\n", 449 | " blob.download_to_filename('credential.json')\n", 450 | " \n", 451 | " # get the latest model \n", 452 | " tflite_blobs = gcs_client.get_bucket(model_bucket).list_blobs()\n", 453 | " tflite_blob = sorted(tflite_blobs, reverse=True, key=lambda blob: blob.name.split('/')[-2])[0]\n", 454 | " tflite_blob.download_to_filename('model.tflite') \n", 455 | " \n", 456 | " firebase_admin.initialize_app(\n", 457 | " credentials.Certificate('credential.json'),\n", 458 | " options={\n", 459 | " 'storageBucket': firebase_dest_gcs_bucket\n", 460 | " }\n", 461 | " )\n", 462 | "\n", 463 | " model_list = ml.list_models(list_filter=f'display_name={model_display_name}')\n", 464 | " # update\n", 465 | " if len(model_list.models) > 0:\n", 466 | " # get the first match model\n", 467 | " model = model_list.models[0]\n", 468 | " \n", 469 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file('model.tflite')\n", 470 | " model.model_format = ml.TFLiteFormat(model_source=source)\n", 471 | " \n", 472 | " updated_model = ml.update_model(model)\n", 473 | " ml.publish_model(updated_model.model_id)\n", 474 | " # create\n", 475 | " else: \n", 476 | " # Load a tflite file and upload it to Cloud Storage\n", 477 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file('model.tflite')\n", 478 | "\n", 479 | " # Create the model object\n", 480 | " tflite_format = ml.TFLiteFormat(model_source=source)\n", 481 | " model = ml.Model(\n", 482 | " display_name=model_display_name, # This is the name you use from your app to load the model.\n", 483 | " tags=[model_tag], # Optional tags for easier management.\n", 484 | " model_format=tflite_format)\n", 485 | "\n", 486 | " # Add the model to your Firebase project and publish it\n", 487 | " new_model = ml.create_model(model)\n", 488 | " ml.publish_model(new_model.model_id)" 489 | ], 490 | "id": "57076987", 491 | "execution_count": null, 492 | "outputs": [] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": { 497 | "id": "c9d14770" 498 | }, 499 | "source": [ 500 | "### Define pipeline itself\n", 501 | "1. Create Vertex AI's managed dataset from CSV\n", 502 | "2. Define configs, one for cloud model, and the other one for mobile model\n", 503 | "3. 
Run two workflows in parallel (each workflow is configured appropriately for its target environment)\n", 504 | "    - AutoML training can be configured differently for each target environment depending on `model_type`\n", 505 | "4. For the cloud model, creating an endpoint and deploying the trained model are handled by `ModelDeployOp`\n", 506 | "5. Export the trained mobile model to a GCS bucket\n", 507 | "    - publish the exported mobile model to Firebase through the `push_to_firebase` component\n", 508 | "    " 509 | ], 510 | "id": "c9d14770" 511 | }, 512 | { 513 | "cell_type": "code", 514 | "metadata": { 515 | "id": "22e3d30c" 516 | }, 517 | "source": [ 518 | "@kfp.dsl.pipeline(name=\"cloud-mobile-dual-deployment\")\n", 519 | "def pipeline(project: str = PROJECT_ID):\n", 520 | "    ds_op = gcc_aip.ImageDatasetCreateOp(\n", 521 | "        project=project,\n", 522 | "        display_name=\"flowers-dataset\",\n", 523 | "        gcs_source=\"gs://dataset-meta-gde-csp/flowers_vertex.csv\",\n", 524 | "        import_schema_uri=aiplatform.schema.dataset.ioformat.image.multi_label_classification,\n", 525 | "    )\n", 526 | "\n", 527 | "    configs = [\n", 528 | "        {\n", 529 | "            \"type\": \"CLOUD\",\n", 530 | "            \"model_type\": \"CLOUD\",\n", 531 | "            \"display_name\": \"train-cloud-model\",\n", 532 | "            \"model_display_name\": \"cloud-model\",\n", 533 | "            \"budget_milli_node_hours\": 8000,\n", 534 | "        },\n", 535 | "        {\n", 536 | "            \"type\": \"MOBILE\",\n", 537 | "            \"model_type\": \"MOBILE_TF_VERSATILE_1\",\n", 538 | "            \"display_name\": \"train-mobile-model\",\n", 539 | "            \"model_display_name\": \"mobile-model\",\n", 540 | "            \"budget_milli_node_hours\": 1000,\n", 541 | "        }\n", 542 | "    ]\n", 543 | "\n", 544 | "    with kfp.dsl.ParallelFor(configs) as config:\n", 545 | "        training_job_run_op = gcc_aip.AutoMLImageTrainingJobRunOp(\n", 546 | "            project=project,\n", 547 | "            display_name=config.display_name,\n", 548 | "            prediction_type=\"classification\",\n", 549 | "            multi_label=True,\n", 550 | "            model_type=config.model_type,\n", 551 | "            base_model=None,\n", 552 | "            dataset=ds_op.outputs[\"dataset\"],\n", 553 | "            model_display_name=config.model_display_name,\n", 554 | "            budget_milli_node_hours=config.budget_milli_node_hours,\n", 555 | "        )\n", 556 | "        training_job_run_op.after(ds_op)\n", 557 | "\n", 558 | "        with kfp.dsl.Condition(config.type=='CLOUD'):\n", 559 | "            endpoint_op = gcc_aip.ModelDeployOp(\n", 560 | "                project=project,\n", 561 | "                model=training_job_run_op.outputs[\"model\"]\n", 562 | "            )\n", 563 | "            endpoint_op.after(training_job_run_op)\n", 564 | "    \n", 565 | "        with kfp.dsl.Condition(config.type=='MOBILE'):\n", 566 | "            export_op = gcc_aip.ModelExportOp(\n", 567 | "                project=project,\n", 568 | "                model=training_job_run_op.outputs[\"model\"],\n", 569 | "                # tflite, edgetpu-tflite, tf-saved-model, tf-js, core-ml, custom-trained\n", 570 | "                export_format_id=\"tflite\",\n", 571 | "                artifact_destination=\"gs://output-model-gde-csp/flower-models/\"\n", 572 | "            )\n", 573 | "            export_op.after(training_job_run_op)\n", 574 | "\n", 575 | "            credential_uri=\"gs://firebase-ml-bucket-gde-csp/grounded-atrium-320207-firebase-adminsdk-5n9sn-20dbda9947.json\"\n", 576 | "            model_bucket=\"output-model-gde-csp\"\n", 577 | "            firebase_bucket=\"grounded-atrium-320207.appspot.com\"\n", 578 | "    \n", 579 | "            firebase_op = push_to_firebase(\n", 580 | "                # ordering w.r.t. export_op is enforced via firebase_op.after(export_op) below\n", 581 | "                credential_uri=credential_uri,\n", 582 | "                model_bucket=model_bucket,\n", 583 | "                firebase_dest_gcs_bucket=firebase_bucket,\n", 584 | "                model_display_name=\"custom_model\",\n",
 585 | "                model_tag=\"from_dual_deployment\"\n", 586 | "            )\n", 587 | "            firebase_op.after(export_op)\n", 588 | "    " 589 | ], 590 | "id": "22e3d30c", 591 | "execution_count": null, 592 | "outputs": [] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": { 597 | "id": "59a52203" 598 | }, 599 | "source": [ 600 | "### Compile the pipeline\n", 601 | "- compiling produces a JSON file with the pipeline spec.\n", 602 | "  - this JSON file is all you need to run the pipeline" 603 | ], 604 | "id": "59a52203" 605 | }, 606 | { 607 | "cell_type": "code", 608 | "metadata": { 609 | "id": "6b0cdaef" 610 | }, 611 | "source": [ 612 | "from kfp.v2 import compiler\n", 613 | "\n", 614 | "compiler.Compiler().compile(\n", 615 | "    pipeline_func=pipeline, package_path=\"cloud-mobile-dual-deployment.json\"\n", 616 | ")" 617 | ], 618 | "id": "6b0cdaef", 619 | "execution_count": null, 620 | "outputs": [] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": { 625 | "id": "7ec1f64f" 626 | }, 627 | "source": [ 628 | "## Run the pipeline on Vertex AI Pipelines" 629 | ], 630 | "id": "7ec1f64f" 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "metadata": { 635 | "id": "fcbb6c14" 636 | }, 637 | "source": [ 638 | "### Create a client instance for AI Platform (i.e. Vertex AI)" 639 | ], 640 | "id": "fcbb6c14" 641 | }, 642 | { 643 | "cell_type": "code", 644 | "metadata": { 645 | "id": "cc9a0a61" 646 | }, 647 | "source": [ 648 | "from kfp.v2.google.client import AIPlatformClient\n", 649 | "\n", 650 | "api_client = AIPlatformClient(project_id=PROJECT_ID, region=REGION)" 651 | ], 652 | "id": "cc9a0a61", 653 | "execution_count": null, 654 | "outputs": [] 655 | }, 656 | { 657 | "cell_type": "markdown", 658 | "metadata": { 659 | "id": "07274930" 660 | }, 661 | "source": [ 662 | "### Run the pipeline with the pipeline spec (JSON file)" 663 | ], 664 | "id": "07274930" 665 | }, 666 | { 667 | "cell_type": "code", 668 | "metadata": { 669 | "id": "f485821f", 670 | "outputId": "a76c399e-70ff-4fe2-e476-772cf47e932c" 671 | }, 672 | "source": [ 673 | "response = api_client.create_run_from_job_spec(\n", 674 | "    \"cloud-mobile-dual-deployment.json\",\n", 675 | "    pipeline_root=PIPELINE_ROOT,\n", 676 | "    parameter_values={\"project\": PROJECT_ID},\n", 677 | ")" 678 | ], 679 | "id": "f485821f", 680 | "execution_count": null, 681 | "outputs": [ 682 | { 683 | "output_type": "display_data", 684 | "data": { 685 | "text/html": [ 686 | "See the Pipeline job here." 
687 | ], 688 | "text/plain": [ 689 | "" 690 | ] 691 | }, 692 | "metadata": { 693 | "tags": [] 694 | } 695 | } 696 | ] 697 | } 698 | ] 699 | } -------------------------------------------------------------------------------- /notebooks/Custom_Model_TFX.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Custom_Model_TFX", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm", 10 | "include_colab_link": true 11 | }, 12 | "environment": { 13 | "name": "tf2-gpu.2-4.mnightly-2021-02-02-debian-10-test", 14 | "type": "gcloud", 15 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-4:mnightly-2021-02-02-debian-10-test" 16 | }, 17 | "kernelspec": { 18 | "display_name": "Python 3", 19 | "language": "python", 20 | "name": "python3" 21 | }, 22 | "language_info": { 23 | "codemirror_mode": { 24 | "name": "ipython", 25 | "version": 3 26 | }, 27 | "file_extension": ".py", 28 | "mimetype": "text/x-python", 29 | "name": "python", 30 | "nbconvert_exporter": "python", 31 | "pygments_lexer": "ipython3", 32 | "version": "3.7.4" 33 | } 34 | }, 35 | "cells": [ 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "view-in-github", 40 | "colab_type": "text" 41 | }, 42 | "source": [ 43 | "\"Open" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "mTVp-9PGYFIO" 50 | }, 51 | "source": [ 52 | "In this notebook, we will build two custom models - one for Endpoint deployment and the other for mobile deployment. The mobile-friendly model will first be converted to TFLite and then pushed to Firebase, since Firebase ML provides first-class support for operating TFLite models in production. We will write a TFX pipeline to run their training and export. The entire pipeline will be orchestrated using [Vertex AI Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction). \n", 53 | "\n", 54 | "This notebook assumes you are familiar with the basics of Vertex AI, TFX (especially custom components), and TensorFlow. 
" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "W7gJqmqrsfqh" 61 | }, 62 | "source": [ 63 | "## References\n", 64 | "\n", 65 | "This notebook refers to the following resources and also reuses parts of the code from there: \n", 66 | "* [Simple TFX Pipeline for Vertex Pipelines](https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/gcp/vertex_pipelines_simple.ipynb)\n", 67 | "* [Vertex AI Training with TFX and Vertex Pipelines](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_vertex_training)\n", 68 | "* [Importing models to Vertex AI](https://cloud.google.com/vertex-ai/docs/general/import-model)\n", 69 | "* [Deploying a model using the Vertex AI API](https://cloud.google.com/vertex-ai/docs/predictions/deploy-model-api)\n", 70 | "* [MLOPs with Vertex AI](https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai)\n", 71 | "* [Custom components TFX](https://www.tensorflow.org/tfx/tutorials/tfx/python_function_component)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "D04aKMGWXjOu" 78 | }, 79 | "source": [ 80 | "## Setup" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "I_niUhp_TY1G" 87 | }, 88 | "source": [ 89 | "# Use the latest version of pip.\n", 90 | "%%capture\n", 91 | "!pip install --upgrade pip\n", 92 | "!pip install --upgrade tfx==1.0.0 kfp==1.6.1\n", 93 | "!pip install -q --upgrade google-cloud-aiplatform" 94 | ], 95 | "execution_count": null, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "ZVmgQ6w1oT_Z" 102 | }, 103 | "source": [ 104 | "### ***Please restart runtime before continuing.*** " 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "metadata": { 110 | "id": "mstgsNHWoiXk" 111 | }, 112 | "source": [ 113 | "!gcloud init" 114 | ], 115 | "execution_count": null, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "metadata": { 121 | "id": "Pl8ewjX3oXRx" 122 | }, 123 | "source": [ 124 | "from google.colab import auth\n", 125 | "auth.authenticate_user()" 126 | ], 127 | "execution_count": null, 128 | "outputs": [] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": { 133 | "id": "zqVWpmywXngD" 134 | }, 135 | "source": [ 136 | "## Imports" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "metadata": { 142 | "colab": { 143 | "base_uri": "https://localhost:8080/" 144 | }, 145 | "id": "wptXF0e-UXsT", 146 | "outputId": "1bb97d2e-f3f5-419f-e987-b2847a6adb8a" 147 | }, 148 | "source": [ 149 | "import tensorflow as tf\n", 150 | "print('TensorFlow version: {}'.format(tf.__version__))\n", 151 | "from tfx import v1 as tfx\n", 152 | "print('TFX version: {}'.format(tfx.__version__))\n", 153 | "import kfp\n", 154 | "print('KFP version: {}'.format(kfp.__version__))\n", 155 | "\n", 156 | "from google.cloud import aiplatform as vertex_ai\n", 157 | "import os" 158 | ], 159 | "execution_count": null, 160 | "outputs": [ 161 | { 162 | "output_type": "stream", 163 | "text": [ 164 | "TensorFlow version: 2.5.0\n", 165 | "TFX version: 1.0.0\n", 166 | "KFP version: 1.6.1\n" 167 | ], 168 | "name": "stdout" 169 | } 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "hFYHeepnXxpZ" 176 | }, 177 | "source": [ 178 | "## Environment setup" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "zPVyBrXrW-vu" 185 | }, 186 | "source": [ 187 | "GOOGLE_CLOUD_PROJECT = 'fast-ai-exploration' #@param 
{type:\"string\"}\n", 188 | "GOOGLE_CLOUD_REGION = 'us-central1' #@param {type:\"string\"}\n", 189 | "GCS_BUCKET_NAME = 'vertex-tfx-mlops' #@param {type:\"string\"}\n", 190 | "\n", 191 | "if not (GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_REGION and GCS_BUCKET_NAME):\n", 192 | " from absl import logging\n", 193 | " logging.error('Please set all required parameters.')" 194 | ], 195 | "execution_count": null, 196 | "outputs": [] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "CV-BZSvQq7YY" 202 | }, 203 | "source": [ 204 | "The location of the bucket must be a single region. Also, the bucket needs to be created in a region when [Vertex AI services are available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions). " 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": { 210 | "colab": { 211 | "base_uri": "https://localhost:8080/" 212 | }, 213 | "id": "J65KHrt4X-Fu", 214 | "outputId": "085fe0e1-02fd-47e3-bc6e-ee24ad82b1b2" 215 | }, 216 | "source": [ 217 | "PIPELINE_NAME = 'two-way-vertex-pipelines5'\n", 218 | "\n", 219 | "# Path to various pipeline artifact.\n", 220 | "PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(\n", 221 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n", 222 | "\n", 223 | "# Paths for users' Python module.\n", 224 | "MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(\n", 225 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n", 226 | "\n", 227 | "# Paths for input data.\n", 228 | "DATA_ROOT = 'gs://flowers-public/tfrecords-jpeg-224x224'\n", 229 | "\n", 230 | "# This is the path where your model will be pushed for serving.\n", 231 | "SERVING_MODEL_DIR = 'gs://{}/serving_model/{}'.format(\n", 232 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n", 233 | "\n", 234 | "print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))" 235 | ], 236 | "execution_count": null, 237 | "outputs": [ 238 | { 239 | "output_type": "stream", 240 | "text": [ 241 | "PIPELINE_ROOT: gs://demo-experiments-gde-csp/pipeline_root/two-way-vertex-pipelines5\n" 242 | ], 243 | "name": "stdout" 244 | } 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "id": "dIfd3_x0bfTu" 251 | }, 252 | "source": [ 253 | "There are three things you need for Firebase setup, and you can find a sufficient description for those in [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) official document. \n", 254 | "\n", 255 | "1. You need to obtain credential file for your [Firebase project](https://console.firebase.google.com/).\n", 256 | " - On the [Settings](https://console.firebase.google.com/project/_/settings/serviceaccounts/adminsdk) page, create a service account and download the service account key file. Keep this file safe, since it grants administrator access to your project.\n", 257 | " - Save the credential JSON file in a GCS bucket, and replace the `FIREBASE_CREDENTIAL_PATH` value with it.\n", 258 | "\n", 259 | "\n", 260 | "2. You need to create a GCS bucket where the model is going to be temporarily stored.\n", 261 | " - On the Storage page, enable Cloud Storage. Take note of your bucket name.\n", 262 | " - Replace `FIREBASE_GCS_BUCKET` with the obtained GCS bucket name. It usually has this form `YOUR_GCP_PROJECT_ID.appspot.com`.\n", 263 | "\n", 264 | "\n", 265 | "3. On the Firebase ML page, click Get started if you haven't yet enabled Firebase ML." 
266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "metadata": { 271 | "id": "fWkhLCv_bfTu" 272 | }, 273 | "source": [ 274 | "FIREBASE_CREDENTIAL_PATH = 'gs://credential-csp/gcp-ml-172005-firebase-adminsdk-5gdtb-38c6644f1e.json'\n", 275 | "FIREBASE_GCS_BUCKET = 'gcp-ml-172005.appspot.com'" 276 | ], 277 | "execution_count": null, 278 | "outputs": [] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": { 283 | "id": "kQVpzyftX0y0" 284 | }, 285 | "source": [ 286 | "## Create training modules" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "metadata": { 292 | "id": "AR5pU65m6nAE" 293 | }, 294 | "source": [ 295 | "_trainer_densenet_module_file = 'flower_densenet_trainer.py'\n", 296 | "_trainer_mobilenet_module_file = 'flower_mobilenet_trainer.py'" 297 | ], 298 | "execution_count": null, 299 | "outputs": [] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "source": [ 304 | "`_model_exporter()` utility comes from [Practical Machine Learning for Computer Vision](https://www.oreilly.com/library/view/practical-machine-learning/9781098102357/)." 305 | ], 306 | "metadata": { 307 | "id": "3vUUTikER33Y" 308 | } 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "colab": { 314 | "base_uri": "https://localhost:8080/" 315 | }, 316 | "id": "XqwYrR1GYLvm", 317 | "outputId": "4f8c480a-a4cd-42cc-9d5c-8d5ae822ae22" 318 | }, 319 | "source": [ 320 | "%%writefile {_trainer_densenet_module_file}\n", 321 | "\n", 322 | "from typing import List\n", 323 | "from absl import logging\n", 324 | "from tensorflow import keras\n", 325 | "from tfx import v1 as tfx\n", 326 | "import tensorflow as tf\n", 327 | "\n", 328 | "\n", 329 | "_IMAGE_FEATURES = {\n", 330 | " \"image\": tf.io.FixedLenFeature([], tf.string),\n", 331 | " \"class\": tf.io.FixedLenFeature([], tf.int64),\n", 332 | " \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n", 333 | "}\n", 334 | "\n", 335 | "_CONCRETE_INPUT = \"numpy_inputs\"\n", 336 | "_INPUT_SHAPE = (224, 224, 3)\n", 337 | "_TRAIN_BATCH_SIZE = 64\n", 338 | "_EVAL_BATCH_SIZE = 64\n", 339 | "_EPOCHS = 2\n", 340 | "\n", 341 | "\n", 342 | "def _parse_fn(example):\n", 343 | " example = tf.io.parse_single_example(example, _IMAGE_FEATURES)\n", 344 | " image = tf.image.decode_jpeg(example[\"image\"], channels=3)\n", 345 | " class_label = tf.cast(example[\"class\"], tf.int32)\n", 346 | " return image, class_label\n", 347 | "\n", 348 | "\n", 349 | "def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:\n", 350 | " \"\"\"Generates features and label for training.\n", 351 | "\n", 352 | " Args:\n", 353 | " file_pattern: List of paths or patterns of input tfrecord files.\n", 354 | " batch_size: representing the number of consecutive elements of returned\n", 355 | " dataset to combine in a single batch.\n", 356 | "\n", 357 | " Returns:\n", 358 | " A dataset that contains (features, indices) tuple where features is a\n", 359 | " dictionary of Tensors, and indices is a single Tensor of label indices.\n", 360 | " \"\"\"\n", 361 | " logging.info(f\"Reading data from: {file_pattern}\")\n", 362 | " tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + \".gz\")\n", 363 | " dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n", 364 | " dataset = dataset.map(_parse_fn).batch(batch_size)\n", 365 | " return dataset.repeat()\n", 366 | "\n", 367 | "\n", 368 | "def _make_keras_model() -> tf.keras.Model:\n", 369 | " \"\"\"Creates a DenseNet121-based model for classifying flowers data.\n", 370 | "\n", 371 | " Returns:\n", 
372 | " A Keras Model.\n", 373 | " \"\"\"\n", 374 | " inputs = keras.Input(shape=_INPUT_SHAPE)\n", 375 | " base_model = keras.applications.DenseNet121(\n", 376 | " include_top=False, input_shape=_INPUT_SHAPE, pooling=\"avg\"\n", 377 | " )\n", 378 | " base_model.trainable = False\n", 379 | " x = keras.applications.densenet.preprocess_input(inputs)\n", 380 | " x = base_model(\n", 381 | " x, training=False\n", 382 | " ) # Ensures BatchNorm runs in inference model in this model\n", 383 | " outputs = keras.layers.Dense(5, activation=\"softmax\")(x)\n", 384 | " model = keras.Model(inputs, outputs)\n", 385 | "\n", 386 | " model.compile(\n", 387 | " optimizer=keras.optimizers.Adam(),\n", 388 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", 389 | " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", 390 | " )\n", 391 | "\n", 392 | " model.summary(print_fn=logging.info)\n", 393 | " return model\n", 394 | "\n", 395 | "\n", 396 | "def _preprocess(bytes_input):\n", 397 | " decoded = tf.io.decode_jpeg(bytes_input, channels=3)\n", 398 | " resized = tf.image.resize(decoded, size=(224, 224))\n", 399 | " return resized\n", 400 | "\n", 401 | "\n", 402 | "@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])\n", 403 | "def preprocess_fn(bytes_inputs):\n", 404 | " decoded_images = tf.map_fn(\n", 405 | " _preprocess, bytes_inputs, dtype=tf.float32, back_prop=False\n", 406 | " )\n", 407 | " return {_CONCRETE_INPUT: decoded_images}\n", 408 | "\n", 409 | "\n", 410 | "def _model_exporter(model: tf.keras.Model):\n", 411 | " m_call = tf.function(model.call).get_concrete_function(\n", 412 | " [\n", 413 | " tf.TensorSpec(\n", 414 | " shape=[None, 224, 224, 3], dtype=tf.float32, name=_CONCRETE_INPUT\n", 415 | " )\n", 416 | " ]\n", 417 | " )\n", 418 | "\n", 419 | " @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])\n", 420 | " def serving_fn(bytes_inputs):\n", 421 | " labels = tf.constant(\n", 422 | " [\"daisy\", \"dandelion\", \"roses\", \"sunflowers\", \"tulips\"], dtype=tf.string\n", 423 | " )\n", 424 | " images = preprocess_fn(bytes_inputs)\n", 425 | "\n", 426 | " probs = m_call(**images)\n", 427 | " indices = tf.argmax(probs, axis=1)\n", 428 | " pred_source = tf.gather(params=labels, indices=indices)\n", 429 | " pred_confidence = tf.reduce_max(probs, axis=1)\n", 430 | " return {\"label\": pred_source, \"confidence\": pred_confidence}\n", 431 | "\n", 432 | " return serving_fn\n", 433 | "\n", 434 | "\n", 435 | "# TFX Trainer will call this function.\n", 436 | "def run_fn(fn_args: tfx.components.FnArgs):\n", 437 | " \"\"\"Train the model based on given args.\n", 438 | "\n", 439 | " Args:\n", 440 | " fn_args: Holds args used to train the model as name/value pairs.\n", 441 | " \"\"\"\n", 442 | " train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)\n", 443 | " eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)\n", 444 | "\n", 445 | " model = _make_keras_model()\n", 446 | " model.fit(\n", 447 | " train_dataset,\n", 448 | " steps_per_epoch=fn_args.train_steps,\n", 449 | " validation_data=eval_dataset,\n", 450 | " validation_steps=fn_args.eval_steps,\n", 451 | " epochs=_EPOCHS,\n", 452 | " )\n", 453 | " _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)\n", 454 | " logging.info(f\"Validation accuracy: {round(acc * 100, 2)}%\")\n", 455 | " # The result of the training should be saved in `fn_args.serving_model_dir`\n", 456 | " # directory.\n", 457 | " tf.saved_model.save(\n", 458 | " model,\n", 459 | " 
 460 | "        signatures={\"serving_default\": _model_exporter(model)},\n", 461 | "    )\n" 462 | ], 463 | "execution_count": null, 464 | "outputs": [ 465 | { 466 | "output_type": "stream", 467 | "text": [ 468 | "Overwriting flower_densenet_trainer.py\n" 469 | ], 470 | "name": "stdout" 471 | } 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "colab": { 478 | "base_uri": "https://localhost:8080/" 479 | }, 480 | "id": "tIFDOw5cmqd1", 481 | "outputId": "c26a9017-b14f-4f1a-ce5a-4ef98ba96270" 482 | }, 483 | "source": [ 484 | "%%writefile {_trainer_mobilenet_module_file}\n", 485 | "\n", 486 | "from typing import List\n", 487 | "from absl import logging\n", 488 | "from tensorflow import keras\n", 489 | "from tfx import v1 as tfx\n", 490 | "import tensorflow as tf\n", 491 | "\n", 492 | "\n", 493 | "_IMAGE_FEATURES = {\n", 494 | "    \"image\": tf.io.FixedLenFeature([], tf.string),\n", 495 | "    \"class\": tf.io.FixedLenFeature([], tf.int64),\n", 496 | "    \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n", 497 | "}\n", 498 | "\n", 499 | "_INPUT_SHAPE = (224, 224, 3)\n", 500 | "_TRAIN_BATCH_SIZE = 64\n", 501 | "_EVAL_BATCH_SIZE = 64\n", 502 | "_EPOCHS = 2\n", 503 | "\n", 504 | "\n", 505 | "def _parse_fn(example):\n", 506 | "    example = tf.io.parse_single_example(example, _IMAGE_FEATURES)\n", 507 | "    image = tf.image.decode_jpeg(example[\"image\"], channels=3)\n", 508 | "    class_label = tf.cast(example[\"class\"], tf.int32)\n", 509 | "    return image, class_label\n", 510 | "\n", 511 | "\n", 512 | "def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:\n", 513 | "    \"\"\"Generates features and label for training.\n", 514 | "\n", 515 | "    Args:\n", 516 | "        file_pattern: List of paths or patterns of input tfrecord files.\n", 517 | "        batch_size: representing the number of consecutive elements of returned\n", 518 | "            dataset to combine in a single batch.\n", 519 | "\n", 520 | "    Returns:\n", 521 | "        A dataset that contains (features, indices) tuple where features is a\n", 522 | "        dictionary of Tensors, and indices is a single Tensor of label indices.\n", 523 | "    \"\"\"\n", 524 | "    logging.info(f\"Reading data from: {file_pattern}\")\n", 525 | "    tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + \".gz\")\n", 526 | "    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n", 527 | "    dataset = dataset.map(_parse_fn).batch(batch_size)\n", 528 | "    return dataset.repeat()\n", 529 | "\n", 530 | "\n", 531 | "def _make_keras_model() -> tf.keras.Model:\n", 532 | "    \"\"\"Creates a MobileNetV3-based model for classifying flowers data.\n", 533 | "\n", 534 | "    Returns:\n", 535 | "        A Keras Model.\n", 536 | "    \"\"\"\n", 537 | "    inputs = keras.Input(shape=_INPUT_SHAPE)\n", 538 | "    base_model = keras.applications.MobileNetV3Small(\n", 539 | "        include_top=False, input_shape=_INPUT_SHAPE, pooling=\"avg\"\n", 540 | "    )\n", 541 | "    base_model.trainable = False\n", 542 | "    x = keras.applications.mobilenet_v3.preprocess_input(inputs)\n", 543 | "    x = base_model(\n", 544 | "        x, training=False\n", 545 | "    )  # Ensures BatchNorm runs in inference mode in this model\n", 546 | "    outputs = keras.layers.Dense(5, activation=\"softmax\")(x)\n", 547 | "    model = keras.Model(inputs, outputs)\n", 548 | "\n", 549 | "    model.compile(\n", 550 | "        optimizer=keras.optimizers.Adam(),\n", 551 | "        loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", 552 | "        metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", 553 | "    )\n", 554 | "\n", 555 | "    
model.summary(print_fn=logging.info)\n", 556 | " return model\n", 557 | "\n", 558 | "\n", 559 | "# TFX Trainer will call this function.\n", 560 | "def run_fn(fn_args: tfx.components.FnArgs):\n", 561 | " \"\"\"Train the model based on given args.\n", 562 | "\n", 563 | " Args:\n", 564 | " fn_args: Holds args used to train the model as name/value pairs.\n", 565 | " \"\"\"\n", 566 | " train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)\n", 567 | " eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)\n", 568 | "\n", 569 | " model = _make_keras_model()\n", 570 | " model.fit(\n", 571 | " train_dataset,\n", 572 | " steps_per_epoch=fn_args.train_steps,\n", 573 | " validation_data=eval_dataset,\n", 574 | " validation_steps=fn_args.eval_steps,\n", 575 | " epochs=_EPOCHS,\n", 576 | " )\n", 577 | " _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)\n", 578 | " logging.info(f\"Validation accuracy: {round(acc * 100, 2)}%\")\n", 579 | "\n", 580 | " # Convert the model.\n", 581 | " converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", 582 | " converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", 583 | " tflite_model = converter.convert()\n", 584 | "\n", 585 | " # Save the model.\n", 586 | " # The result of the training should be saved in `fn_args.serving_model_dir` directory.\n", 587 | " with tf.io.gfile.GFile(fn_args.serving_model_dir + \"/model.tflite\", \"wb\") as f:\n", 588 | " f.write(tflite_model)\n" 589 | ], 590 | "execution_count": null, 591 | "outputs": [ 592 | { 593 | "output_type": "stream", 594 | "text": [ 595 | "Overwriting flower_mobilenet_trainer.py\n" 596 | ], 597 | "name": "stdout" 598 | } 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "metadata": { 604 | "colab": { 605 | "base_uri": "https://localhost:8080/" 606 | }, 607 | "id": "Gp7RTxpMoLe9", 608 | "outputId": "1233ba34-5b4a-4a85-8386-122c6e9b84f9" 609 | }, 610 | "source": [ 611 | "!gsutil cp -r *.py {MODULE_ROOT}/\n", 612 | "!gsutil ls -lh {MODULE_ROOT}/" 613 | ], 614 | "execution_count": null, 615 | "outputs": [ 616 | { 617 | "output_type": "stream", 618 | "text": [ 619 | "Copying file://firebase_publisher.py [Content-Type=text/x-python]...\n", 620 | "Copying file://flower_densenet_trainer.py [Content-Type=text/x-python]... \n", 621 | "Copying file://flower_mobilenet_trainer.py [Content-Type=text/x-python]... \n", 622 | "Copying file://vertex_deployer.py [Content-Type=text/x-python]... \n", 623 | "/ [4 files][ 12.5 KiB/ 12.5 KiB] \n", 624 | "==> NOTE: You are performing a sequence of gsutil operations that may\n", 625 | "run significantly faster if you instead use gsutil -m cp ... Please\n", 626 | "see the -m section under \"gsutil help options\" for further information\n", 627 | "about when gsutil -m can be advantageous.\n", 628 | "\n", 629 | "Copying file://vertex_uploader.py [Content-Type=text/x-python]...\n", 630 | "/ [5 files][ 13.8 KiB/ 13.8 KiB] \n", 631 | "Operation completed over 5 objects/13.8 KiB. 
\n", 632 | " 3.16 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/firebase_publisher.py\n", 633 | " 4.53 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/flower_densenet_trainer.py\n", 634 | " 3.37 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/flower_mobilenet_trainer.py\n", 635 | " 1.49 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/vertex_deployer.py\n", 636 | " 1.31 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/vertex_uploader.py\n", 637 | "TOTAL: 5 objects, 14181 bytes (13.85 KiB)\n" 638 | ], 639 | "name": "stdout" 640 | } 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": { 646 | "id": "-qX4M-C6X4LF" 647 | }, 648 | "source": [ 649 | "## Create the pipeline\n", 650 | "\n", 651 | "To create the end-to-end pipeline, we will need to write two custom TFX components:\n", 652 | "\n", 653 | "* One will take the pushed model from `Pusher` and will upload it to Vertex AI. \n", 654 | "* One will deploy the uploaded model to an Endpoint.\n", 655 | "\n", 656 | "We will then need to build a Docker image using these custom components and serve the pipeline using this image. We will use [Cloud Build](https://cloud.google.com/build) in order to build the Docker image. " 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "metadata": { 662 | "id": "SYaypU11e-cG" 663 | }, 664 | "source": [ 665 | "_vertex_uploader_module_file = 'vertex_uploader.py'\n", 666 | "_vertex_deployer_module_file = 'vertex_deployer.py'\n", 667 | "_firebase_publisher_module_file = 'firebase_publisher.py'" 668 | ], 669 | "execution_count": null, 670 | "outputs": [] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "metadata": { 675 | "colab": { 676 | "base_uri": "https://localhost:8080/" 677 | }, 678 | "id": "pBj_WnLw6oRR", 679 | "outputId": "83adae19-5209-4032-bf94-42f1d668e947" 680 | }, 681 | "source": [ 682 | "%%writefile {_vertex_uploader_module_file}\n", 683 | "\n", 684 | "import os\n", 685 | "import tensorflow as tf\n", 686 | "\n", 687 | "from tfx.dsl.component.experimental.decorators import component\n", 688 | "from tfx.dsl.component.experimental.annotations import Parameter\n", 689 | "from tfx.types.standard_artifacts import String\n", 690 | "from google.cloud import aiplatform as vertex_ai\n", 691 | "from tfx import v1 as tfx\n", 692 | "from absl import logging\n", 693 | "\n", 694 | "\n", 695 | "@component\n", 696 | "def VertexUploader(\n", 697 | " project: Parameter[str],\n", 698 | " region: Parameter[str],\n", 699 | " model_display_name: Parameter[str],\n", 700 | " pushed_model_location: Parameter[str],\n", 701 | " serving_image_uri: Parameter[str],\n", 702 | " uploaded_model: tfx.dsl.components.OutputArtifact[String],\n", 703 | "):\n", 704 | "\n", 705 | " vertex_ai.init(project=project, location=region)\n", 706 | "\n", 707 | " pushed_model_dir = os.path.join(\n", 708 | " pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1]\n", 709 | " )\n", 710 | "\n", 711 | " logging.info(f\"Model registry location: {pushed_model_dir}\")\n", 712 | "\n", 713 | " vertex_model = vertex_ai.Model.upload(\n", 714 | " display_name=model_display_name,\n", 715 | " artifact_uri=pushed_model_dir,\n", 716 | " serving_container_image_uri=serving_image_uri,\n", 717 | " parameters_schema_uri=None,\n", 718 | " instance_schema_uri=None,\n", 719 | " 
explanation_metadata=None,\n", 720 | " explanation_parameters=None,\n", 721 | " )\n", 722 | "\n", 723 | " uploaded_model.set_string_custom_property(\n", 724 | " \"model_resource_name\", str(vertex_model.resource_name)\n", 725 | " )\n", 726 | " logging.info(f\"Model resource: {str(vertex_model.resource_name)}\")\n" 727 | ], 728 | "execution_count": null, 729 | "outputs": [ 730 | { 731 | "output_type": "stream", 732 | "text": [ 733 | "Overwriting vertex_uploader.py\n" 734 | ], 735 | "name": "stdout" 736 | } 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "metadata": { 742 | "colab": { 743 | "base_uri": "https://localhost:8080/" 744 | }, 745 | "id": "awNiJiZ1fZsp", 746 | "outputId": "dd732111-6210-47ef-ef63-32bd907749dd" 747 | }, 748 | "source": [ 749 | "%%writefile {_vertex_deployer_module_file}\n", 750 | "\n", 751 | "from tfx.dsl.component.experimental.decorators import component\n", 752 | "from tfx.dsl.component.experimental.annotations import Parameter\n", 753 | "from tfx.types.standard_artifacts import String\n", 754 | "from google.cloud import aiplatform as vertex_ai\n", 755 | "from tfx import v1 as tfx\n", 756 | "from absl import logging\n", 757 | "\n", 758 | "\n", 759 | "@component\n", 760 | "def VertexDeployer(\n", 761 | " project: Parameter[str],\n", 762 | " region: Parameter[str],\n", 763 | " model_display_name: Parameter[str],\n", 764 | " deployed_model_display_name: Parameter[str],\n", 765 | "):\n", 766 | "\n", 767 | " logging.info(f\"Endpoint display: {deployed_model_display_name}\")\n", 768 | " vertex_ai.init(project=project, location=region)\n", 769 | "\n", 770 | " endpoints = vertex_ai.Endpoint.list(\n", 771 | " filter=f\"display_name={deployed_model_display_name}\", order_by=\"update_time\"\n", 772 | " )\n", 773 | "\n", 774 | " if len(endpoints) > 0:\n", 775 | " logging.info(f\"Endpoint {deployed_model_display_name} already exists.\")\n", 776 | " endpoint = endpoints[-1]\n", 777 | " else:\n", 778 | " endpoint = vertex_ai.Endpoint.create(deployed_model_display_name)\n", 779 | "\n", 780 | " model = vertex_ai.Model.list(\n", 781 | " filter=f\"display_name={model_display_name}\", order_by=\"update_time\"\n", 782 | " )[-1]\n", 783 | "\n", 784 | " endpoint = vertex_ai.Endpoint.list(\n", 785 | " filter=f\"display_name={deployed_model_display_name}\", order_by=\"update_time\"\n", 786 | " )[-1]\n", 787 | "\n", 788 | " deployed_model = endpoint.deploy(\n", 789 | " model=model,\n", 790 | " # Syntax from here: https://git.io/JBQDP\n", 791 | " traffic_split={\"0\": 100},\n", 792 | " machine_type=\"n1-standard-4\",\n", 793 | " min_replica_count=1,\n", 794 | " max_replica_count=1,\n", 795 | " )\n", 796 | "\n", 797 | " logging.info(f\"Model deployed to: {deployed_model}\")\n" 798 | ], 799 | "execution_count": null, 800 | "outputs": [ 801 | { 802 | "output_type": "stream", 803 | "text": [ 804 | "Overwriting vertex_deployer.py\n" 805 | ], 806 | "name": "stdout" 807 | } 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "metadata": { 813 | "id": "Bp7GMaUlbfTz", 814 | "outputId": "4f1458c5-04e6-4681-86ab-5c1d7a8047f3" 815 | }, 816 | "source": [ 817 | "%%writefile {_firebase_publisher_module_file}\n", 818 | "\n", 819 | "from tfx import types\n", 820 | "from tfx.dsl.component.experimental.decorators import component\n", 821 | "from tfx.dsl.component.experimental.annotations import Parameter\n", 822 | "from tfx import v1 as tfx\n", 823 | "from absl import logging\n", 824 | "\n", 825 | "import firebase_admin\n", 826 | "from firebase_admin import ml\n", 827 | "from firebase_admin 
import storage\n", 828 | "from firebase_admin import credentials\n", 829 | "from google.cloud import storage as gcs_storage\n", 830 | "\n", 831 | "\n", 832 | "@component\n", 833 | "def FirebasePublisher(\n", 834 | "    pushed_model: tfx.dsl.components.InputArtifact[\n", 835 | "        tfx.types.standard_artifacts.PushedModel\n", 836 | "    ],\n", 837 | "    credential_uri: Parameter[str],\n", 838 | "    firebase_dest_gcs_bucket: Parameter[str],\n", 839 | "    model_display_name: Parameter[str],\n", 840 | "    model_tag: Parameter[str],\n", 841 | ") -> tfx.dsl.components.OutputDict(result=str):\n", 842 | "    model_uri = f\"{pushed_model.uri}/model.tflite\"\n", 843 | "\n", 844 | "    assert model_uri.split(\"://\")[0] == \"gs\"\n", 845 | "    assert credential_uri.split(\"://\")[0] == \"gs\"\n", 846 | "\n", 847 | "    # create gcs client instance\n", 848 | "    gcs_client = gcs_storage.Client()\n", 849 | "\n", 850 | "    # get credential for firebase\n", 851 | "    credential_gcs_bucket = credential_uri.split(\"//\")[1].split(\"/\")[0]\n", 852 | "    credential_blob_path = \"/\".join(credential_uri.split(\"//\")[1].split(\"/\")[1:])\n", 853 | "\n", 854 | "    bucket = gcs_client.bucket(credential_gcs_bucket)\n", 855 | "    blob = bucket.blob(credential_blob_path)\n", 856 | "    blob.download_to_filename(\"credential.json\")\n", 857 | "    logging.info(f\"downloading credential.json from {credential_uri} is complete\")\n", 858 | "\n", 859 | "    # get tflite model file\n", 860 | "    tflite_gcs_bucket = model_uri.split(\"//\")[1].split(\"/\")[0]\n", 861 | "    tflite_blob_path = \"/\".join(model_uri.split(\"//\")[1].split(\"/\")[1:])\n", 862 | "\n", 863 | "    bucket = gcs_client.bucket(tflite_gcs_bucket)\n", 864 | "    blob = bucket.blob(tflite_blob_path)\n", 865 | "    blob.download_to_filename(\"model.tflite\")\n", 866 | "    logging.info(f\"downloading model.tflite from {model_uri} is complete\")\n", 867 | "\n", 868 | "    firebase_admin.initialize_app(\n", 869 | "        credentials.Certificate(\"credential.json\"),\n", 870 | "        options={\"storageBucket\": firebase_dest_gcs_bucket},\n", 871 | "    )\n", 872 | "    logging.info(\"firebase_admin app initialization is complete\")\n", 873 | "\n", 874 | "    model_list = ml.list_models(list_filter=f\"display_name={model_display_name}\")\n", 875 | "    # update\n", 876 | "    if len(model_list.models) > 0:\n", 877 | "        # get the first match model\n", 878 | "        model = model_list.models[0]\n", 879 | "        source = ml.TFLiteGCSModelSource.from_tflite_model_file(\"model.tflite\")\n", 880 | "        model.model_format = ml.TFLiteFormat(model_source=source)\n", 881 | "\n", 882 | "        updated_model = ml.update_model(model)\n", 883 | "        ml.publish_model(updated_model.model_id)\n", 884 | "\n", 885 | "        logging.info(\"model already exists, so update it in Firebase ML\")\n", 886 | "        return {\"result\": \"model updated\"}\n", 887 | "    # create\n", 888 | "    else:\n", 889 | "        # load a tflite file and upload it to Cloud Storage\n", 890 | "        source = ml.TFLiteGCSModelSource.from_tflite_model_file(\"model.tflite\")\n", 891 | "\n", 892 | "        # create the model object\n", 893 | "        tflite_format = ml.TFLiteFormat(model_source=source)\n", 894 | "        model = ml.Model(\n", 895 | "            display_name=model_display_name,\n", 896 | "            tags=[model_tag],\n", 897 | "            model_format=tflite_format,\n", 898 | "        )\n", 899 | "\n", 900 | "        # Add the model to your Firebase project and publish it\n", 901 | "        new_model = ml.create_model(model)\n", 902 | "        ml.publish_model(new_model.model_id)\n", 903 | "\n", 904 | "        logging.info(\"model does not exist, so create one in Firebase ML\")\n", 905 | "        return {\"result\": \"model created\"}\n" 
906 | ], 907 | "execution_count": null, 908 | "outputs": [ 909 | { 910 | "output_type": "stream", 911 | "text": [ 912 | "Overwriting firebase_publisher.py\n" 913 | ], 914 | "name": "stdout" 915 | } 916 | ] 917 | }, 918 | { 919 | "cell_type": "markdown", 920 | "metadata": { 921 | "id": "joRJYojpP8-H" 922 | }, 923 | "source": [ 924 | "Create a package called `custom_components` and copy the modules we just wrote. " 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "metadata": { 930 | "id": "GYiXX9pOOOAX" 931 | }, 932 | "source": [ 933 | "!mkdir -p ./custom_components\n", 934 | "!touch ./custom_components/__init__.py\n", 935 | "!cp -r {_vertex_uploader_module_file} {_vertex_deployer_module_file} {_firebase_publisher_module_file} custom_components" 936 | ], 937 | "execution_count": null, 938 | "outputs": [] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "metadata": { 943 | "colab": { 944 | "base_uri": "https://localhost:8080/" 945 | }, 946 | "id": "UiLiGjR3OOAX", 947 | "outputId": "ec5a7f09-df8c-4151-d3c6-51c20528d0d6" 948 | }, 949 | "source": [ 950 | "!ls -lh custom_components" 951 | ], 952 | "execution_count": null, 953 | "outputs": [ 954 | { 955 | "output_type": "stream", 956 | "text": [ 957 | "total 16K\n", 958 | "-rw-r--r-- 1 jupyter jupyter 0 Aug 5 14:20 __init__.py\n", 959 | "drwxr-xr-x 2 jupyter jupyter 4.0K Aug 5 06:50 __pycache__\n", 960 | "-rw-r--r-- 1 jupyter jupyter 3.2K Aug 5 14:20 firebase_publisher.py\n", 961 | "-rw-r--r-- 1 jupyter jupyter 1.5K Aug 5 14:20 vertex_deployer.py\n", 962 | "-rw-r--r-- 1 jupyter jupyter 1.4K Aug 5 14:20 vertex_uploader.py\n" 963 | ], 964 | "name": "stdout" 965 | } 966 | ] 967 | }, 968 | { 969 | "cell_type": "markdown", 970 | "metadata": { 971 | "id": "DcFpCCSbQDol" 972 | }, 973 | "source": [ 974 | "### `Dockerfile` configuration " 975 | ] 976 | }, 977 | { 978 | "cell_type": "code", 979 | "metadata": { 980 | "colab": { 981 | "base_uri": "https://localhost:8080/" 982 | }, 983 | "id": "h9j867lPfyM8", 984 | "outputId": "c4858955-56ef-444f-d683-e03630c288f7" 985 | }, 986 | "source": [ 987 | "DATASET_DISPLAY_NAME = \"flowers\"\n", 988 | "VERSION = \"tfx-1-0-0\"\n", 989 | "TFX_IMAGE_URI = f\"gcr.io/{GOOGLE_CLOUD_PROJECT}/{DATASET_DISPLAY_NAME}:{VERSION}\"\n", 990 | "print(f\"URI of the custom image: {TFX_IMAGE_URI}\")" 991 | ], 992 | "execution_count": null, 993 | "outputs": [ 994 | { 995 | "output_type": "stream", 996 | "text": [ 997 | "URI of the custom image: gcr.io/gcp-ml-172005/flowers:tfx-1-0-0\n" 998 | ], 999 | "name": "stdout" 1000 | } 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "metadata": { 1006 | "colab": { 1007 | "base_uri": "https://localhost:8080/" 1008 | }, 1009 | "id": "F3YcE0xUgOh8", 1010 | "outputId": "abfa72c3-dfd9-4a8f-a1b6-cf8e94b36f85" 1011 | }, 1012 | "source": [ 1013 | "%%writefile Dockerfile\n", 1014 | "\n", 1015 | "FROM gcr.io/tfx-oss-public/tfx:1.0.0\n", 1016 | "RUN mkdir -p custom_components\n", 1017 | "COPY custom_components/* ./custom_components/\n", 1018 | "RUN pip install --upgrade google-cloud-aiplatform google-cloud-storage firebase-admin" 1019 | ], 1020 | "execution_count": null, 1021 | "outputs": [ 1022 | { 1023 | "output_type": "stream", 1024 | "text": [ 1025 | "Overwriting Dockerfile\n" 1026 | ], 1027 | "name": "stdout" 1028 | } 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "metadata": { 1034 | "id": "hEydNZrHg6_J" 1035 | }, 1036 | "source": [ 1037 | "!gcloud builds submit --tag $TFX_IMAGE_URI . 
--timeout=15m --machine-type=e2-highcpu-8" 1038 | ], 1039 | "execution_count": null, 1040 | "outputs": [] 1041 | }, 1042 | { 1043 | "cell_type": "code", 1044 | "metadata": { 1045 | "id": "sEbNM9CeERX2" 1046 | }, 1047 | "source": [ 1048 | "# Specify training worker configurations. To minimize costs we can even specify two\n", 1049 | "# different configurations: a beefier machine for the Endpoint model and a slightly\n", 1050 | "# less powerful machine for the mobile model (see the sketch right after this cell).\n", 1051 | "TRAINING_JOB_SPEC = {\n", 1052 | "    'project': GOOGLE_CLOUD_PROJECT,\n", 1053 | "    'worker_pool_specs': [{\n", 1054 | "        'machine_spec': {\n", 1055 | "            'machine_type': 'n1-standard-4',\n", 1056 | "            'accelerator_type': 'NVIDIA_TESLA_K80',\n", 1057 | "            'accelerator_count': 1\n", 1058 | "        },\n", 1059 | "        'replica_count': 1,\n", 1060 | "        'container_spec': {\n", 1061 | "            'image_uri': 'gcr.io/tfx-oss-public/tfx:{}'.format(tfx.__version__),\n", 1062 | "        },\n", 1063 | "    }],\n", 1064 | "}" 1065 | ], 1066 | "execution_count": null, 1067 | "outputs": [] 1068 | },
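{ "cell_type": "markdown", "metadata": { "id": "mobileSpecMd" }, "source": [ "As a sketch of the two-configuration idea from the comment above: a lighter, CPU-only worker spec for the mobile model could look like the next cell. `MOBILE_TRAINING_JOB_SPEC` is a hypothetical name and is not wired into the pipeline below; to use it, you would pass it through `TRAINING_ARGS_KEY` in the mobile trainer's `custom_config`." ] }, { "cell_type": "code", "metadata": { "id": "mobileSpecCode" }, "source": [ "# Sketch (not used below): a slimmer worker configuration for the mobile model.\n", "MOBILE_TRAINING_JOB_SPEC = {\n", "    'project': GOOGLE_CLOUD_PROJECT,\n", "    'worker_pool_specs': [{\n", "        'machine_spec': {\n", "            'machine_type': 'n1-standard-4',  # no accelerator for the lighter model\n", "        },\n", "        'replica_count': 1,\n", "        'container_spec': {\n", "            'image_uri': 'gcr.io/tfx-oss-public/tfx:{}'.format(tfx.__version__),\n", "        },\n", "    }],\n", "}" ], "execution_count": null, "outputs": [] }, 1069 | { 1070 | "cell_type": "code", 1071 | "metadata": { 1072 | "id": "MzZdpUjypcsT" 1073 | }, 1074 | "source": [ 1075 | "from datetime import datetime\n", 1076 | "\n", 1077 | "TIMESTAMP = datetime.now().strftime(\"%Y%m%d%H%M%S\")" 1078 | ], 1079 | "execution_count": null, 1080 | "outputs": [] 1081 | }, 1082 | { 1083 | "cell_type": "code", 1084 | "metadata": { 1085 | "id": "ln1cvbcfphA9" 1086 | }, 1087 | "source": [ 1088 | "from custom_components.vertex_uploader import VertexUploader\n", 1089 | "from custom_components.vertex_deployer import VertexDeployer\n", 1090 | "from custom_components.firebase_publisher import FirebasePublisher\n", 1091 | "\n", 1092 | "\n", 1093 | "def _create_pipeline(\n", 1094 | "    pipeline_name: str,\n", 1095 | "    pipeline_root: str,\n", 1096 | "    data_root: str,\n", 1097 | "    densenet_module_file: str,\n", 1098 | "    mobilenet_module_file: str,\n", 1099 | "    serving_model_dir: str,\n", 1100 | "    firebase_crediential_path: str,\n", 1101 | "    firebase_gcs_bucket: str,\n", 1102 | "    project_id: str,\n", 1103 | "    region: str,\n", 1104 | ") -> tfx.dsl.Pipeline:\n", 1105 | "    \"\"\"Creates a three component flowers pipeline with TFX.\"\"\"\n", 1106 | "    # Brings data into the pipeline.\n", 1107 | "    # input_base: gs://flowers-public/tfrecords-jpeg-224x224\n", 1108 | "    example_gen = tfx.components.ImportExampleGen(input_base=data_root)\n", 1109 | "\n", 1110 | "    # Uses user-provided Python function that trains a model.\n", 1111 | "    densenet_trainer = tfx.extensions.google_cloud_ai_platform.Trainer(\n", 1112 | "        module_file=densenet_module_file,\n", 1113 | "        examples=example_gen.outputs[\"examples\"],\n", 1114 | "        train_args=tfx.proto.TrainArgs(num_steps=52),\n", 1115 | "        eval_args=tfx.proto.EvalArgs(num_steps=5),\n", 1116 | "        custom_config={\n", 1117 | "            tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY: True,\n", 1118 | "            tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY: region,\n", 1119 | "            tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: TRAINING_JOB_SPEC,\n", 1120 | "            \"use_gpu\": True,\n", 1121 | "        },\n", 1122 | "    ).with_id(\"densenet_trainer\")\n", 1123 | "\n", 1124 | "    # Pushes the model to a filesystem destination.\n", 1125 | "    pushed_model_location = os.path.join(serving_model_dir, \"densenet\")\n", 1126 | "    densnet_pusher = tfx.components.Pusher(\n", 1127 | "        model=densenet_trainer.outputs[\"model\"],\n", 1128 | "        push_destination=tfx.proto.PushDestination(\n", 1129 | "            filesystem=tfx.proto.PushDestination.Filesystem(\n", 1130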
| " base_directory=pushed_model_location\n", 1131 | " )\n", 1132 | " ),\n", 1133 | " ).with_id(\"densnet_pusher\")\n", 1134 | "\n", 1135 | " # Vertex AI upload.\n", 1136 | " model_display_name = \"densenet_flowers_latest\"\n", 1137 | " uploader = VertexUploader(\n", 1138 | " project=project_id,\n", 1139 | " region=region,\n", 1140 | " model_display_name=model_display_name,\n", 1141 | " pushed_model_location=pushed_model_location,\n", 1142 | " serving_image_uri=\"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest\",\n", 1143 | " ).with_id(\"vertex_uploader\")\n", 1144 | " uploader.add_upstream_node(densnet_pusher)\n", 1145 | "\n", 1146 | " # Create an endpoint.\n", 1147 | " deployer = VertexDeployer(\n", 1148 | " project=project_id,\n", 1149 | " region=region,\n", 1150 | " model_display_name=model_display_name,\n", 1151 | " deployed_model_display_name=model_display_name + \"_\" + TIMESTAMP,\n", 1152 | " ).with_id(\"vertex_deployer\")\n", 1153 | " deployer.add_upstream_node(uploader)\n", 1154 | "\n", 1155 | " # We repeat the steps for the MobileNet model too but this time we won't\n", 1156 | " # be creating an Endpoint. We will first convert the Keras model to TFLite\n", 1157 | " # and then push it to Firebase for better operability. \n", 1158 | " mobilenet_trainer = tfx.extensions.google_cloud_ai_platform.Trainer(\n", 1159 | " module_file=mobilenet_module_file,\n", 1160 | " examples=example_gen.outputs[\"examples\"],\n", 1161 | " train_args=tfx.proto.TrainArgs(num_steps=52),\n", 1162 | " eval_args=tfx.proto.EvalArgs(num_steps=5),\n", 1163 | " custom_config={\n", 1164 | " tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY: True,\n", 1165 | " tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY: region,\n", 1166 | " tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: TRAINING_JOB_SPEC,\n", 1167 | " \"use_gpu\": True,\n", 1168 | " },\n", 1169 | " ).with_id(\"mobilenet_trainer\")\n", 1170 | "\n", 1171 | " pushed_location_mobilenet = os.path.join(serving_model_dir, \"mobilenet\")\n", 1172 | " mobilenet_pusher = tfx.components.Pusher(\n", 1173 | " model=mobilenet_trainer.outputs[\"model\"],\n", 1174 | " push_destination=tfx.proto.PushDestination(\n", 1175 | " filesystem=tfx.proto.PushDestination.Filesystem(\n", 1176 | " base_directory=pushed_location_mobilenet\n", 1177 | " )\n", 1178 | " ),\n", 1179 | " ).with_id(\"mobilenet_pusher\")\n", 1180 | "\n", 1181 | " firebase_publisher = FirebasePublisher(\n", 1182 | " pushed_model=mobilenet_pusher.outputs[\"pushed_model\"],\n", 1183 | " credential_uri=firebase_crediential_path,\n", 1184 | " firebase_dest_gcs_bucket=firebase_gcs_bucket,\n", 1185 | " model_display_name=model_display_name,\n", 1186 | " model_tag=\"mobilenet\",\n", 1187 | " ).with_id(\"firebase_publisher\")\n", 1188 | "\n", 1189 | " # Following components will be included in the pipeline.\n", 1190 | " components = [\n", 1191 | " example_gen,\n", 1192 | " densenet_trainer,\n", 1193 | " densnet_pusher,\n", 1194 | " uploader,\n", 1195 | " deployer,\n", 1196 | " mobilenet_trainer,\n", 1197 | " mobilenet_pusher,\n", 1198 | " firebase_publisher,\n", 1199 | " ]\n", 1200 | "\n", 1201 | " return tfx.dsl.Pipeline(\n", 1202 | " pipeline_name=pipeline_name, pipeline_root=pipeline_root, components=components\n", 1203 | " )\n" 1204 | ], 1205 | "execution_count": null, 1206 | "outputs": [] 1207 | }, 1208 | { 1209 | "cell_type": "markdown", 1210 | "metadata": { 1211 | "id": "IFdlslfOX54z" 1212 | }, 1213 | "source": [ 1214 | "## Compile the pipeline" 1215 | ] 1216 | }, 1217 | { 1218 
| "cell_type": "code", 1219 | "metadata": { 1220 | "id": "-AY5Z2tbsbwE" 1221 | }, 1222 | "source": [ 1223 | "PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'\n", 1224 | "\n", 1225 | "# Important: We need to pass the custom Docker image URI to the\n", 1226 | "# `KubeflowV2DagRunnerConfig` to take effect.\n", 1227 | "runner = tfx.orchestration.experimental.KubeflowV2DagRunner(\n", 1228 | " config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(default_image=TFX_IMAGE_URI),\n", 1229 | " output_filename=PIPELINE_DEFINITION_FILE)\n", 1230 | "\n", 1231 | "_ = runner.run(\n", 1232 | " _create_pipeline(\n", 1233 | " pipeline_name=PIPELINE_NAME,\n", 1234 | " pipeline_root=PIPELINE_ROOT,\n", 1235 | " data_root=DATA_ROOT,\n", 1236 | " densenet_module_file=os.path.join(MODULE_ROOT, _trainer_densenet_module_file),\n", 1237 | " mobilenet_module_file=os.path.join(MODULE_ROOT, _trainer_mobilenet_module_file),\n", 1238 | " serving_model_dir=SERVING_MODEL_DIR,\n", 1239 | " firebase_crediential_path=FIREBASE_CREDENTIAL_PATH,\n", 1240 | " firebase_gcs_bucket=FIREBASE_GCS_BUCKET, \n", 1241 | " project_id=GOOGLE_CLOUD_PROJECT,\n", 1242 | " region=GOOGLE_CLOUD_REGION\n", 1243 | " )\n", 1244 | ")" 1245 | ], 1246 | "execution_count": null, 1247 | "outputs": [] 1248 | }, 1249 | { 1250 | "cell_type": "markdown", 1251 | "metadata": { 1252 | "id": "ocHBJaR_X7x2" 1253 | }, 1254 | "source": [ 1255 | "## Submit the pipeline for execution to Vertex AI\n", 1256 | "\n", 1257 | "Generally, it's a good idea to first do a local run of the end-to-end pipeline before submitting it an online orchestrator. We can use `tfx.orchestration.LocalDagRunner()` for that but for the purposes of this notebook we won't be doing that. " 1258 | ] 1259 | }, 1260 | { 1261 | "cell_type": "code", 1262 | "metadata": { 1263 | "colab": { 1264 | "base_uri": "https://localhost:8080/", 1265 | "height": 51 1266 | }, 1267 | "id": "3elrtDOus83z", 1268 | "outputId": "7eb13f73-602c-44f0-8385-ac8c097905ea" 1269 | }, 1270 | "source": [ 1271 | "from kfp.v2.google import client\n", 1272 | "\n", 1273 | "pipelines_client = client.AIPlatformClient(\n", 1274 | " project_id=GOOGLE_CLOUD_PROJECT,\n", 1275 | " region=GOOGLE_CLOUD_REGION,\n", 1276 | ")\n", 1277 | "\n", 1278 | "_ = pipelines_client.create_run_from_job_spec(PIPELINE_DEFINITION_FILE, enable_caching=True)" 1279 | ], 1280 | "execution_count": null, 1281 | "outputs": [ 1282 | { 1283 | "output_type": "display_data", 1284 | "data": { 1285 | "text/html": [ 1286 | "See the Pipeline job here." 1287 | ], 1288 | "text/plain": [ 1289 | "" 1290 | ] 1291 | }, 1292 | "metadata": { 1293 | "tags": [] 1294 | } 1295 | } 1296 | ] 1297 | }, 1298 | { 1299 | "cell_type": "markdown", 1300 | "metadata": { 1301 | "id": "YX3pFWlhN_jR" 1302 | }, 1303 | "source": [ 1304 | "The pipeline should come out as the following:\n", 1305 | "\n", 1306 | "![](https://i.ibb.co/98Ry74n/Screen-Shot-2021-08-06-at-1-43-35-AM.png)" 1307 | ] 1308 | }, 1309 | { 1310 | "cell_type": "markdown", 1311 | "metadata": { 1312 | "id": "e-757SqrYhf3" 1313 | }, 1314 | "source": [ 1315 | "## Making predictions with the Endpoint\n", 1316 | "\n", 1317 | "Some code is used from [here](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/unofficial/gapic/custom/showcase_custom_image_classification_online.ipynb). 
" 1318 | ] 1319 | }, 1320 | { 1321 | "cell_type": "markdown", 1322 | "metadata": { 1323 | "id": "bclxNBlJ-LNu" 1324 | }, 1325 | "source": [ 1326 | "### Imports and initialization" 1327 | ] 1328 | }, 1329 | { 1330 | "cell_type": "code", 1331 | "metadata": { 1332 | "id": "6XWcRgv98gGr" 1333 | }, 1334 | "source": [ 1335 | "from google.cloud.aiplatform import gapic as aip\n", 1336 | "from google.protobuf import json_format\n", 1337 | "from google.protobuf.json_format import MessageToJson, ParseDict\n", 1338 | "from google.protobuf.struct_pb2 import Struct, Value\n", 1339 | "\n", 1340 | "import base64" 1341 | ], 1342 | "execution_count": null, 1343 | "outputs": [] 1344 | }, 1345 | { 1346 | "cell_type": "code", 1347 | "metadata": { 1348 | "id": "zSAVSxSC9EOC" 1349 | }, 1350 | "source": [ 1351 | "vertex_ai.init(project=GOOGLE_CLOUD_PROJECT, location=GOOGLE_CLOUD_REGION)" 1352 | ], 1353 | "execution_count": null, 1354 | "outputs": [] 1355 | }, 1356 | { 1357 | "cell_type": "markdown", 1358 | "metadata": { 1359 | "id": "zpZWBSs6-Nfz" 1360 | }, 1361 | "source": [ 1362 | "### Programatically retrieve the latest Endpoint macthing a name" 1363 | ] 1364 | }, 1365 | { 1366 | "cell_type": "code", 1367 | "metadata": { 1368 | "colab": { 1369 | "base_uri": "https://localhost:8080/", 1370 | "height": 35 1371 | }, 1372 | "id": "ZXKmceKG8qYW", 1373 | "outputId": "2c65421c-0211-43a2-fa12-cd69dc31cea4" 1374 | }, 1375 | "source": [ 1376 | "model_display_name = \"densenet_flowers_latest\"\n", 1377 | "deployed_model_display_name = model_display_name + \"_\" + TIMESTAMP\n", 1378 | "\n", 1379 | "endpoint = vertex_ai.Endpoint.list(\n", 1380 | " filter=f'display_name={deployed_model_display_name}',\n", 1381 | " order_by=\"update_time\"\n", 1382 | ")[-1]\n", 1383 | "\n", 1384 | "endpoint_id = endpoint.name\n", 1385 | "endpoint_id" 1386 | ], 1387 | "execution_count": null, 1388 | "outputs": [ 1389 | { 1390 | "output_type": "execute_result", 1391 | "data": { 1392 | "text/plain": [ 1393 | "'3904532915999997952'" 1394 | ] 1395 | }, 1396 | "metadata": { 1397 | "tags": [] 1398 | }, 1399 | "execution_count": 19 1400 | } 1401 | ] 1402 | }, 1403 | { 1404 | "cell_type": "markdown", 1405 | "metadata": { 1406 | "id": "uS9x3SNu-R5l" 1407 | }, 1408 | "source": [ 1409 | "### Sample data" 1410 | ] 1411 | }, 1412 | { 1413 | "cell_type": "code", 1414 | "metadata": { 1415 | "colab": { 1416 | "base_uri": "https://localhost:8080/" 1417 | }, 1418 | "id": "wJeyhzeL8-2C", 1419 | "outputId": "ace8e9b8-976b-4de4-e55b-728ff7012730" 1420 | }, 1421 | "source": [ 1422 | "image_path = tf.keras.utils.get_file(\"image.jpg\", \n", 1423 | " \"https://m.economictimes.com/thumb/msid-71307470,width-1201,height-900,resizemode-4,imgsize-1040796/roses.jpg\")\n", 1424 | "bytes = tf.io.read_file(image_path)\n", 1425 | "b64str = base64.b64encode(bytes.numpy()).decode(\"utf-8\")" 1426 | ], 1427 | "execution_count": null, 1428 | "outputs": [ 1429 | { 1430 | "output_type": "stream", 1431 | "text": [ 1432 | "Downloading data from https://m.economictimes.com/thumb/msid-71307470,width-1201,height-900,resizemode-4,imgsize-1040796/roses.jpg\n", 1433 | "229376/221658 [===============================]- ETA: - ETA: - 0s 0us/step\n" 1434 | ], 1435 | "name": "stdout" 1436 | } 1437 | ] 1438 | }, 1439 | { 1440 | "cell_type": "markdown", 1441 | "metadata": { 1442 | "id": "Jmhj5-FU-TmQ" 1443 | }, 1444 | "source": [ 1445 | "### Investigating the input key" 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "metadata": { 1451 | "colab": { 1452 | "base_uri": 
"https://localhost:8080/" 1453 | }, 1454 | "id": "2H6M-Usq9cUr", 1455 | "outputId": "e5412f05-51c9-4bb6-c8b7-07b550784a74" 1456 | }, 1457 | "source": [ 1458 | "pushed_model_location = os.path.join(SERVING_MODEL_DIR, \"densenet\")\n", 1459 | "model_path_to_deploy = os.path.join(\n", 1460 | " pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1]\n", 1461 | ")\n", 1462 | "\n", 1463 | "loaded = tf.saved_model.load(model_path_to_deploy)\n", 1464 | "serving_input = list(\n", 1465 | " loaded.signatures[\"serving_default\"].structured_input_signature[1].keys()\n", 1466 | ")[0]\n", 1467 | "print(\"Serving function input:\", serving_input)" 1468 | ], 1469 | "execution_count": null, 1470 | "outputs": [ 1471 | { 1472 | "output_type": "stream", 1473 | "text": [ 1474 | "Serving function input: bytes_inputs\n" 1475 | ], 1476 | "name": "stdout" 1477 | } 1478 | ] 1479 | }, 1480 | { 1481 | "cell_type": "markdown", 1482 | "metadata": { 1483 | "id": "C8gnuJf1-YkC" 1484 | }, 1485 | "source": [ 1486 | "### Make predictions" 1487 | ] 1488 | }, 1489 | { 1490 | "cell_type": "code", 1491 | "metadata": { 1492 | "colab": { 1493 | "base_uri": "https://localhost:8080/" 1494 | }, 1495 | "id": "UcB7HZat9QGI", 1496 | "outputId": "ceafc8e4-e93f-4607-c39c-160f8356520d" 1497 | }, 1498 | "source": [ 1499 | "def predict_image(image, endpoint, parameters_dict):\n", 1500 | " # The format of each instance should conform to the deployed model's prediction input schema.\n", 1501 | " instances_list = [{serving_input: {\"b64\": image}}]\n", 1502 | " instances = [json_format.ParseDict(s, Value()) for s in instances_list]\n", 1503 | "\n", 1504 | " endpoint = vertex_ai.Endpoint(endpoint)\n", 1505 | " print(endpoint.predict(instances=instances))\n", 1506 | "\n", 1507 | "predict_image(b64str, endpoint_id, None)" 1508 | ], 1509 | "execution_count": null, 1510 | "outputs": [ 1511 | { 1512 | "output_type": "stream", 1513 | "text": [ 1514 | "Prediction(predictions=[{'confidence': 0.669201732, 'label': 'roses'}], deployed_model_id='7558456345704267776', explanations=None)\n" 1515 | ], 1516 | "name": "stdout" 1517 | } 1518 | ] 1519 | } 1520 | ] 1521 | } --------------------------------------------------------------------------------