├── custom_components
│   ├── __init__.py
│   ├── vertex_uploader.py
│   ├── vertex_deployer.py
│   ├── flower_mobilenet_trainer.py
│   ├── firebase_publisher.py
│   └── flower_densenet_trainer.py
├── figures
│   ├── cloud_build_log.png
│   ├── overall_workflow.png
│   ├── sample_architecture.png
│   └── README.md
├── notebooks
│   ├── README.md
│   ├── Dataset_Prep.ipynb
│   ├── Dual_Deployments_With_AutoML.ipynb
│   └── Custom_Model_TFX.ipynb
├── .gitignore
├── README.md
└── LICENSE
/custom_components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/cloud_build_log.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/cloud_build_log.png
--------------------------------------------------------------------------------
/figures/overall_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/overall_workflow.png
--------------------------------------------------------------------------------
/figures/sample_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Dual-Deployments-on-Vertex-AI/main/figures/sample_architecture.png
--------------------------------------------------------------------------------
/figures/README.md:
--------------------------------------------------------------------------------
1 | The beautiful figures contained in this directory have been prepared by [Chansung Park](https://github.com/deep-diver).
2 |
--------------------------------------------------------------------------------
/custom_components/vertex_uploader.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom TFX component for importing a model into Vertex AI.
3 | Author: Sayak Paul
4 | Reference: https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai/blob/main/src/tfx_pipelines/components.py#L74
5 | """
6 |
7 | import os
8 | import tensorflow as tf
9 |
10 | from tfx.dsl.component.experimental.decorators import component
11 | from tfx.dsl.component.experimental.annotations import Parameter
12 | from tfx.types.standard_artifacts import String
13 | from google.cloud import aiplatform as vertex_ai
14 | from tfx import v1 as tfx
15 | from absl import logging
16 |
17 |
18 | @component
19 | def VertexUploader(
20 | project: Parameter[str],
21 | region: Parameter[str],
22 | model_display_name: Parameter[str],
23 | pushed_model_location: Parameter[str],
24 | serving_image_uri: Parameter[str],
25 | uploaded_model: tfx.dsl.components.OutputArtifact[String],
26 | ):
27 |
28 | vertex_ai.init(project=project, location=region)
29 |
30 | pushed_model_dir = os.path.join(
31 | pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1]
32 | )
33 |
34 | logging.info(f"Model registry location: {pushed_model_dir}")
35 |
36 | vertex_model = vertex_ai.Model.upload(
37 | display_name=model_display_name,
38 | artifact_uri=pushed_model_dir,
39 | serving_container_image_uri=serving_image_uri,
40 | parameters_schema_uri=None,
41 | instance_schema_uri=None,
42 | explanation_metadata=None,
43 | explanation_parameters=None,
44 | )
45 |
46 | uploaded_model.set_string_custom_property(
47 | "model_resource_name", str(vertex_model.resource_name)
48 | )
49 | logging.info(f"Model resource: {str(vertex_model.resource_name)}")
50 |
--------------------------------------------------------------------------------
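A minimal sketch of how the `VertexUploader` component above might be instantiated inside a TFX pipeline definition. The project ID, region, GCS path, and serving image below are placeholder assumptions, not values taken from the repository.

```python
# Hypothetical wiring of VertexUploader inside a TFX pipeline definition.
# All string values are placeholders; adapt them to your own project.
from custom_components.vertex_uploader import VertexUploader

uploader = VertexUploader(
    project="<your-gcp-project>",
    region="us-central1",
    model_display_name="flower_densenet",
    # Assumed to be the same GCS directory the standard Pusher writes to.
    pushed_model_location="gs://<your-bucket>/pushed-models",
    # A prebuilt TensorFlow prediction container offered by Vertex AI.
    serving_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
)
```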
/custom_components/vertex_deployer.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom TFX component for deploying a model to a Vertex AI Endpoint.
3 | Author: Sayak Paul
4 | Reference: https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai/blob/main/build/utils.py#L97
5 | """
6 |
7 | from tfx.dsl.component.experimental.decorators import component
8 | from tfx.dsl.component.experimental.annotations import Parameter
9 | from tfx.types.standard_artifacts import String
10 | from google.cloud import aiplatform as vertex_ai
11 | from tfx import v1 as tfx
12 | from absl import logging
13 |
14 |
15 | @component
16 | def VertexDeployer(
17 | project: Parameter[str],
18 | region: Parameter[str],
19 | model_display_name: Parameter[str],
20 | deployed_model_display_name: Parameter[str],
21 | ):
22 |
23 | logging.info(f"Endpoint display: {deployed_model_display_name}")
24 | vertex_ai.init(project=project, location=region)
25 |
26 | endpoints = vertex_ai.Endpoint.list(
27 | filter=f"display_name={deployed_model_display_name}", order_by="update_time"
28 | )
29 |
30 | if len(endpoints) > 0:
31 | logging.info(f"Endpoint {deployed_model_display_name} already exists.")
32 | endpoint = endpoints[-1]
33 | else:
34 | endpoint = vertex_ai.Endpoint.create(deployed_model_display_name)
35 |
36 | model = vertex_ai.Model.list(
37 | filter=f"display_name={model_display_name}", order_by="update_time"
38 | )[-1]
39 |
40 | endpoint = vertex_ai.Endpoint.list(
41 | filter=f"display_name={deployed_model_display_name}", order_by="update_time"
42 | )[-1]
43 |
44 | deployed_model = endpoint.deploy(
45 | model=model,
46 | # Syntax from here: https://git.io/JBQDP
47 | traffic_split={"0": 100},
48 | machine_type="n1-standard-4",
49 | min_replica_count=1,
50 | max_replica_count=1,
51 | )
52 |
53 | logging.info(f"Model deployed to: {deployed_model}")
54 |
--------------------------------------------------------------------------------
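Continuing the sketch after the uploader file above, the `VertexDeployer` could be placed downstream so that deployment only starts once the model has been imported into Vertex AI. The display names are again illustrative assumptions.

```python
# Hypothetical continuation of the uploader sketch above
# (`uploader` refers to the VertexUploader instance defined there).
from custom_components.vertex_deployer import VertexDeployer

deployer = VertexDeployer(
    project="<your-gcp-project>",
    region="us-central1",
    model_display_name="flower_densenet",
    deployed_model_display_name="flower_densenet_endpoint",
)
# Enforce ordering: deploy only after the upload step has completed.
deployer.add_upstream_node(uploader)
```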
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | We provide notebooks that are fully runnable on [Google Colab](https://colab.research.google.com/) to facilitate easier learning. Below
2 | you can find short descriptions of the individual notebooks present here.
3 |
4 | ## Notebooks
5 |
6 | | Notebook | Colab | Description |
7 | | :--------: | :-------: | :--------------------------------: |
8 | | `Dataset_Prep.ipynb` | | Prepares dataset for AutoML. |
9 | | `Dual_Deployments_With_AutoML.ipynb` | | Shows how to build a Kubeflow Pipeline to train and deploy two different models using AutoML and Vertex AI. |
10 | | `Model_Tests.ipynb` | | Shows how to test the models trained using the notebook above in a standalone manner. |
11 | | `Custom_Model_TFX.ipynb` | | Shows how to build a TFX pipeline using custom components to train and deploy two different models and run them using Vertex AI. |
12 |
13 | ## References
14 |
15 | * [Kubeflow Pipelines](https://www.kubeflow.org/docs/components/pipelines/overview/pipelines-overview/)
16 | * [AutoML SDKs from Kubeflow](https://google-cloud-pipeline-components.readthedocs.io/en/latest/google_cloud_pipeline_components.aiplatform.html#module-google_cloud_pipeline_components.aiplatform)
17 | * [Vertex AI Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines)
18 | * [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx)
19 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/custom_components/flower_mobilenet_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from absl import logging
3 | from tensorflow import keras
4 | from tfx import v1 as tfx
5 | import tensorflow as tf
6 |
7 |
8 | _IMAGE_FEATURES = {
9 | "image": tf.io.FixedLenFeature([], tf.string),
10 | "class": tf.io.FixedLenFeature([], tf.int64),
11 | "one_hot_class": tf.io.VarLenFeature(tf.float32),
12 | }
13 |
14 | _INPUT_SHAPE = (224, 224, 3)
15 | _TRAIN_BATCH_SIZE = 64
16 | _EVAL_BATCH_SIZE = 64
17 | _EPOCHS = 2
18 |
19 |
20 | def _parse_fn(example):
21 | example = tf.io.parse_single_example(example, _IMAGE_FEATURES)
22 | image = tf.image.decode_jpeg(example["image"], channels=3)
23 | class_label = tf.cast(example["class"], tf.int32)
24 | return image, class_label
25 |
26 |
27 | def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:
28 | """Generates features and label for training.
29 |
30 | Args:
31 | file_pattern: List of paths or patterns of input tfrecord files.
32 | batch_size: representing the number of consecutive elements of returned
33 | dataset to combine in a single batch.
34 |
35 | Returns:
36 |       A dataset that contains (image, label) tuples, where image is a dense
37 |       image Tensor and label is a Tensor of class indices.
38 | """
39 | logging.info(f"Reading data from: {file_pattern}")
40 | tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + ".gz")
41 | dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
42 | dataset = dataset.map(_parse_fn).batch(batch_size)
43 | return dataset.repeat()
44 |
45 |
46 | def _make_keras_model() -> tf.keras.Model:
47 | """Creates a MobileNetV3-based model for classifying flowers data.
48 |
49 | Returns:
50 | A Keras Model.
51 | """
52 | inputs = keras.Input(shape=_INPUT_SHAPE)
53 | base_model = keras.applications.MobileNetV3Small(
54 | include_top=False, input_shape=_INPUT_SHAPE, pooling="avg"
55 | )
56 | base_model.trainable = False
57 | x = keras.applications.mobilenet_v3.preprocess_input(inputs)
58 | x = base_model(
59 | x, training=False
60 |     ) # Ensures BatchNorm runs in inference mode for this model
61 | outputs = keras.layers.Dense(5, activation="softmax")(x)
62 | model = keras.Model(inputs, outputs)
63 |
64 | model.compile(
65 | optimizer=keras.optimizers.Adam(),
66 | loss=tf.keras.losses.SparseCategoricalCrossentropy(),
67 | metrics=[keras.metrics.SparseCategoricalAccuracy()],
68 | )
69 |
70 | model.summary(print_fn=logging.info)
71 | return model
72 |
73 |
74 | # TFX Trainer will call this function.
75 | def run_fn(fn_args: tfx.components.FnArgs):
76 | """Train the model based on given args.
77 |
78 | Args:
79 | fn_args: Holds args used to train the model as name/value pairs.
80 | """
81 | train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)
82 | eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)
83 |
84 | model = _make_keras_model()
85 | model.fit(
86 | train_dataset,
87 | steps_per_epoch=fn_args.train_steps,
88 | validation_data=eval_dataset,
89 | validation_steps=fn_args.eval_steps,
90 | epochs=_EPOCHS,
91 | )
92 | _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)
93 | logging.info(f"Validation accuracy: {round(acc * 100, 2)}%")
94 |
95 | # Convert the model.
96 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
97 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
98 | tflite_model = converter.convert()
99 |
100 | # Save the model.
101 | # The result of the training should be saved in `fn_args.serving_model_dir` directory.
102 | with tf.io.gfile.GFile(fn_args.serving_model_dir + "/model.tflite", "wb") as f:
103 | f.write(tflite_model)
104 |
--------------------------------------------------------------------------------
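A minimal sketch of how this module might be handed to the standard TFX `Trainer` component. The module path, the upstream `example_gen` component, the split names, and the step counts are assumptions for illustration only.

```python
# Hypothetical Trainer configuration pointing at the module above.
from tfx import v1 as tfx

trainer = tfx.components.Trainer(
    module_file="custom_components/flower_mobilenet_trainer.py",  # assumed repository path
    examples=example_gen.outputs["examples"],  # hypothetical upstream ImportExampleGen
    train_args=tfx.proto.TrainArgs(splits=["train"], num_steps=100),
    eval_args=tfx.proto.EvalArgs(splits=["val"], num_steps=10),
)
```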
/README.md:
--------------------------------------------------------------------------------
1 | # Dual-Deployments-on-Vertex-AI
2 |
3 | _By [Chansung Park](https://github.com/deep-diver) and Sayak Paul_
4 |
5 | This project demonstrates a workflow to cover dual model deployment scenarios using [Kubeflow](https://www.kubeflow.org/),
6 | [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx), and [Vertex AI](https://cloud.google.com/vertex-ai). We suggest
7 | reading the accompanying [blog post](https://cloud.google.com/blog/topics/developers-practitioners/dual-deployments-vertex-ai) first
8 | to get an idea and then following along with the code. This project also received the [#TFCommunitySpotlight Award](https://twitter.com/TensorFlow/status/1446611368078086144?s=20).
9 |
10 | ## Motivation 💻
11 |
12 | Let's say you want to allow your users to run an application in both online and offline modes. Your mobile
13 | application would use an on-device TFLite model when network bandwidth/battery is limited, and if sufficient
14 | network coverage/internet bandwidth is available, your application would instead use the online, cloud-hosted model. This way
15 | your application stays resilient and can ensure high availability.
16 |
17 | Sometimes we also do layered predictions where we first divide a problem into smaller tasks:
18 | 1) make a coarse yes/no prediction,
19 | 2) depending on the output of 1), run the final model.
20 |
21 | In these cases, 1) takes place on-device and 2) takes place on the cloud to ensure a smooth UX. Furthermore, it's
22 | a good practice to use a mobile-friendly network architecture (such as MobileNet) when considering
23 | mobile deployments. This leads us to the following question:
24 |
25 | _**Can we train two different models within the same deployment pipeline and manage them seamlessly?**_
26 |
27 | This project is motivated by this question.
28 |
29 | ## AutoML, TFX, etc. 🛠
30 |
31 | Different organizations have people with varied technical backgrounds. We wanted to provide the easiest solution first
32 | and then move on to something that is more customizable. To this end, we leverage [Kubeflow's AutoML SDKs](https://github.com/kubeflow/pipelines/tree/master/components/google-cloud) to build, train, and deploy models with
33 | different production use-cases. With AutoML, the developers can delegate a large part of their workflows to the SDKs
34 | and the codebase also stays comparatively smaller. The figure below depicts a sample system architecture for
35 | this scenario:
36 |
37 | 
38 |
39 | **Figure developed by Chansung Park.**
40 |
41 | But the story does not end here. What if we wanted to have better control over the models to be built, trained,
42 | and deployed? Enter TFX! TFX provides the flexibility of writing custom components and including them inside a
43 | pipeline. This way Machine Learning Engineers can focus on building and training their favorite models and delegate
44 | a part of the heavy lifting to TFX and Vertex AI. On Vertex AI (acting as an orchestrator) this pipeline will look like
45 | so:
46 |
47 | 
48 |
49 | ```txt
50 | 🔥 In this project we cover both these situations.
51 | ```
52 |
53 | ## Code 🆘
54 |
55 | Our code is distributed as Colab Notebooks. But one needs to have a billing-enabled GCP account
56 | (with a few APIs enabled) to successfully run these notebooks. Alternatively one can also use the
57 | notebooks on [Vertex AI Notebooks](https://cloud.google.com/vertex-ai/docs/general/notebooks). Find
58 | all the notebooks and their descriptions here:
59 | [`notebooks`](https://github.com/sayakpaul/Dual-Deployments-on-Vertex-AI/tree/main/notebooks).
60 |
61 | Additionally, you can find the custom TFX components separately here - [`custom_components`](https://github.com/sayakpaul/Dual-Deployments-on-Vertex-AI/tree/main/custom_components).
62 |
63 | ## Acknowledgements
64 |
65 | We are grateful to the [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits. Thanks to [Karl Weinmeister](https://twitter.com/kweinmeister?lang=hr) and [Robert Crowe](https://twitter.com/robert_crowe?lang=en) for providing review feedback on this project.
66 |
--------------------------------------------------------------------------------
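The README above mentions leveraging Kubeflow's AutoML SDKs; a minimal sketch of what the first two steps of such a pipeline could look like is shown below. The project ID and CSV path are placeholders, and the exact op signatures depend on the `google-cloud-pipeline-components` release in use.

```python
# Hypothetical first steps of an AutoML-based Kubeflow pipeline
# (managed dataset creation followed by AutoML image training).
from kfp.v2 import dsl
from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip


@dsl.pipeline(name="flowers-automl-dual-deploy")
def pipeline(project: str = "<your-gcp-project>",
             gcs_csv: str = "gs://<your-bucket>/flowers_vertex.csv"):
    dataset_op = gcc_aip.ImageDatasetCreateOp(
        project=project,
        display_name="flowers",
        gcs_source=gcs_csv,
        import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
    )
    training_op = gcc_aip.AutoMLImageTrainingJobRunOp(
        project=project,
        display_name="flowers-automl",
        prediction_type="classification",
        dataset=dataset_op.outputs["dataset"],
        budget_milli_node_hours=8000,
    )
```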
/custom_components/firebase_publisher.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom TFX component for Firebase upload.
3 | Author: Chansung Park
4 | """
5 |
6 | from tfx import types
7 | from tfx.dsl.component.experimental.decorators import component
8 | from tfx.dsl.component.experimental.annotations import Parameter
9 | from tfx import v1 as tfx
10 | from absl import logging
11 |
12 | import firebase_admin
13 | from firebase_admin import ml
14 | from firebase_admin import storage
15 | from firebase_admin import credentials
16 | from google.cloud import storage as gcs_storage
17 |
18 |
19 | @component
20 | def FirebasePublisher(
21 | pushed_model: tfx.dsl.components.InputArtifact[
22 | tfx.types.standard_artifacts.PushedModel
23 | ],
24 | credential_uri: Parameter[str],
25 | firebase_dest_gcs_bucket: Parameter[str],
26 | model_display_name: Parameter[str],
27 | model_tag: Parameter[str],
28 | ) -> tfx.dsl.components.OutputDict(result=str):
29 |     """
30 |     Publishes a trained TFLite model to Firebase ML. This component assumes that
31 |     the trained model and the Firebase credential file are stored in GCS locations.
32 | 
33 |     Args:
34 |         pushed_model: The URI of the pushed model obtained from the previous component (i.e. Pusher).
35 |         credential_uri: The URI of the Firebase credential. In order to get one, go to the Firebase dashboard
36 |         and, on the Settings page, create a service account and download the service account key file.
37 |         Keep this file safe, since it grants administrator access to your project.
38 |         firebase_dest_gcs_bucket: GCS bucket where the model is going to be temporarily stored.
39 |         In order to create one, go to the Firebase dashboard and, on the Storage page, enable Cloud Storage.
40 |         Take note of your bucket name.
41 |         model_display_name: The name to appear on the Firebase ML dashboard.
42 |         model_tag: The tag name to appear on the Firebase ML dashboard.
43 |     """
44 |
45 | model_uri = f"{pushed_model.uri}/model.tflite"
46 |
47 | assert model_uri.split("://")[0] == "gs"
48 | assert credential_uri.split("://")[0] == "gs"
49 |
50 | # create gcs client instance
51 | gcs_client = gcs_storage.Client()
52 |
53 | # get credential for firebase
54 | credential_gcs_bucket = credential_uri.split("//")[1].split("/")[0]
55 | credential_blob_path = "/".join(credential_uri.split("//")[1].split("/")[1:])
56 |
57 | bucket = gcs_client.bucket(credential_gcs_bucket)
58 | blob = bucket.blob(credential_blob_path)
59 | blob.download_to_filename("credential.json")
60 | logging.info(f"download credential.json from {credential_uri} is completed")
61 |
62 | # get tflite model file
63 | tflite_gcs_bucket = model_uri.split("//")[1].split("/")[0]
64 | tflite_blob_path = "/".join(model_uri.split("//")[1].split("/")[1:])
65 |
66 | bucket = gcs_client.bucket(tflite_gcs_bucket)
67 | blob = bucket.blob(tflite_blob_path)
68 | blob.download_to_filename("model.tflite")
69 | logging.info(f"download model.tflite from {model_uri} is completed")
70 |
71 | firebase_admin.initialize_app(
72 | credentials.Certificate("credential.json"),
73 | options={"storageBucket": firebase_dest_gcs_bucket},
74 | )
75 | logging.info("firebase_admin initialize app is completed")
76 |
77 | model_list = ml.list_models(list_filter=f"display_name={model_display_name}")
78 | # update
79 | if len(model_list.models) > 0:
80 | # get the first match model
81 | model = model_list.models[0]
82 | source = ml.TFLiteGCSModelSource.from_tflite_model_file("model.tflite")
83 | model.model_format = ml.TFLiteFormat(model_source=source)
84 |
85 | updated_model = ml.update_model(model)
86 | ml.publish_model(updated_model.model_id)
87 |
88 |         logging.info("model exists, so update it in Firebase ML")
89 | return {"result": "model updated"}
90 | # create
91 | else:
92 | # load a tflite file and upload it to Cloud Storage
93 | source = ml.TFLiteGCSModelSource.from_tflite_model_file("model.tflite")
94 |
95 | # create the model object
96 | tflite_format = ml.TFLiteFormat(model_source=source)
97 | model = ml.Model(
98 | display_name=model_display_name,
99 | tags=[model_tag],
100 | model_format=tflite_format,
101 | )
102 |
103 | # Add the model to your Firebase project and publish it
104 | new_model = ml.create_model(model)
105 | ml.publish_model(new_model.model_id)
106 |
107 |         logging.info("model does not exist, so create one in Firebase ML")
108 | return {"result": "model created"}
109 |
--------------------------------------------------------------------------------
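A minimal sketch of how `FirebasePublisher` might be wired after the standard Pusher component. The credential path, bucket, and display name are placeholder assumptions.

```python
# Hypothetical usage of FirebasePublisher downstream of a Pusher component.
from custom_components.firebase_publisher import FirebasePublisher

firebase_publisher = FirebasePublisher(
    pushed_model=pusher.outputs["pushed_model"],  # hypothetical upstream Pusher
    credential_uri="gs://<your-bucket>/credentials/firebase-service-account.json",
    firebase_dest_gcs_bucket="<your-project>.appspot.com",
    model_display_name="flower_mobilenet",
    model_tag="tflite",
)
```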
/custom_components/flower_densenet_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from absl import logging
3 | from tensorflow import keras
4 | from tfx import v1 as tfx
5 | import tensorflow as tf
6 |
7 |
8 | _IMAGE_FEATURES = {
9 | "image": tf.io.FixedLenFeature([], tf.string),
10 | "class": tf.io.FixedLenFeature([], tf.int64),
11 | "one_hot_class": tf.io.VarLenFeature(tf.float32),
12 | }
13 |
14 | _CONCRETE_INPUT = "numpy_inputs"
15 | _INPUT_SHAPE = (224, 224, 3)
16 | _TRAIN_BATCH_SIZE = 64
17 | _EVAL_BATCH_SIZE = 64
18 | _EPOCHS = 2
19 |
20 |
21 | def _parse_fn(example):
22 | example = tf.io.parse_single_example(example, _IMAGE_FEATURES)
23 | image = tf.image.decode_jpeg(example["image"], channels=3)
24 | class_label = tf.cast(example["class"], tf.int32)
25 | return image, class_label
26 |
27 |
28 | def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:
29 | """Generates features and label for training.
30 |
31 | Args:
32 | file_pattern: List of paths or patterns of input tfrecord files.
33 | batch_size: representing the number of consecutive elements of returned
34 | dataset to combine in a single batch.
35 |
36 | Returns:
37 |       A dataset that contains (image, label) tuples, where image is a dense
38 |       image Tensor and label is a Tensor of class indices.
39 | """
40 | logging.info(f"Reading data from: {file_pattern}")
41 | tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + ".gz")
42 | dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
43 | dataset = dataset.map(_parse_fn).batch(batch_size)
44 | return dataset.repeat()
45 |
46 |
47 | def _make_keras_model() -> tf.keras.Model:
48 | """Creates a DenseNet121-based model for classifying flowers data.
49 |
50 | Returns:
51 | A Keras Model.
52 | """
53 | inputs = keras.Input(shape=_INPUT_SHAPE)
54 | base_model = keras.applications.DenseNet121(
55 | include_top=False, input_shape=_INPUT_SHAPE, pooling="avg"
56 | )
57 | base_model.trainable = False
58 | x = keras.applications.densenet.preprocess_input(inputs)
59 | x = base_model(
60 | x, training=False
61 |     ) # Ensures BatchNorm runs in inference mode for this model
62 | outputs = keras.layers.Dense(5, activation="softmax")(x)
63 | model = keras.Model(inputs, outputs)
64 |
65 | model.compile(
66 | optimizer=keras.optimizers.Adam(),
67 | loss=tf.keras.losses.SparseCategoricalCrossentropy(),
68 | metrics=[keras.metrics.SparseCategoricalAccuracy()],
69 | )
70 |
71 | model.summary(print_fn=logging.info)
72 | return model
73 |
74 |
75 | def _preprocess(bytes_input):
76 | decoded = tf.io.decode_jpeg(bytes_input, channels=3)
77 | resized = tf.image.resize(decoded, size=(224, 224))
78 | return resized
79 |
80 |
81 | @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
82 | def preprocess_fn(bytes_inputs):
83 | decoded_images = tf.map_fn(
84 | _preprocess, bytes_inputs, dtype=tf.float32, back_prop=False
85 | )
86 | return {_CONCRETE_INPUT: decoded_images}
87 |
88 |
89 | def _model_exporter(model: tf.keras.Model):
90 | m_call = tf.function(model.call).get_concrete_function(
91 | [
92 | tf.TensorSpec(
93 | shape=[None, 224, 224, 3], dtype=tf.float32, name=_CONCRETE_INPUT
94 | )
95 | ]
96 | )
97 |
98 | @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
99 | def serving_fn(bytes_inputs):
100 | # This function comes from the Computer Vision book from O'Reilly.
101 | labels = tf.constant(
102 | ["daisy", "dandelion", "roses", "sunflowers", "tulips"], dtype=tf.string
103 | )
104 | images = preprocess_fn(bytes_inputs)
105 |
106 | probs = m_call(**images)
107 | indices = tf.argmax(probs, axis=1)
108 | pred_source = tf.gather(params=labels, indices=indices)
109 | pred_confidence = tf.reduce_max(probs, axis=1)
110 | return {"label": pred_source, "confidence": pred_confidence}
111 |
112 | return serving_fn
113 |
114 |
115 | # TFX Trainer will call this function.
116 | def run_fn(fn_args: tfx.components.FnArgs):
117 | """Train the model based on given args.
118 |
119 | Args:
120 | fn_args: Holds args used to train the model as name/value pairs.
121 | """
122 | train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)
123 | eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)
124 |
125 | model = _make_keras_model()
126 | model.fit(
127 | train_dataset,
128 | steps_per_epoch=fn_args.train_steps,
129 | validation_data=eval_dataset,
130 | validation_steps=fn_args.eval_steps,
131 | epochs=_EPOCHS,
132 | )
133 | _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)
134 | logging.info(f"Validation accuracy: {round(acc * 100, 2)}%")
135 | # The result of the training should be saved in `fn_args.serving_model_dir`
136 | # directory.
137 | tf.saved_model.save(
138 | model,
139 | fn_args.serving_model_dir,
140 | signatures={"serving_default": _model_exporter(model)},
141 | )
142 |
--------------------------------------------------------------------------------
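Once the exported SavedModel above is deployed behind a Vertex AI Endpoint, its `serving_default` signature expects raw JPEG bytes. A minimal sketch of an online prediction request is given below; the project, region, endpoint ID, and image file are placeholders.

```python
# Hypothetical online prediction against the deployed DenseNet model.
import base64
from google.cloud import aiplatform as vertex_ai

vertex_ai.init(project="<your-gcp-project>", location="us-central1")
endpoint = vertex_ai.Endpoint("<your-endpoint-id>")

with open("sample_flower.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# The serving signature takes a batch of JPEG byte strings under `bytes_inputs`;
# wrapping the payload in {"b64": ...} tells Vertex AI to base64-decode it first.
response = endpoint.predict(instances=[{"bytes_inputs": {"b64": encoded}}])
print(response.predictions)  # e.g. [{"label": "roses", "confidence": 0.93}]
```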
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/notebooks/Dataset_Prep.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Dataset_Prep",
7 | "provenance": [],
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "name": "python3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "view-in-github",
20 | "colab_type": "text"
21 | },
22 | "source": [
23 |         ""
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "id": "zyOcSJS29gkH"
30 | },
31 | "source": [
32 | "In this notebook, we will use the [Flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) and create a `.csv` file out of it so that it can be imported into Vertex AI as a [managed dataset](https://cloud.google.com/vertex-ai/docs/training/using-managed-datasets). \n",
33 | "\n",
34 | "To proceed with the rest of the notebook you'd need a billing-enabled GCP account. "
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {
40 | "id": "TRIz8jbQ-MUb"
41 | },
42 | "source": [
43 | "## Setup"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "metadata": {
49 | "id": "lIYdn1woOS1n"
50 | },
51 | "source": [
52 | "!gcloud init"
53 | ],
54 | "execution_count": null,
55 | "outputs": []
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "id": "bG-NwjjB-ioI"
61 | },
62 | "source": [
63 | "## Download the original dataset and copy over to a GCS Bucket"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "metadata": {
69 | "id": "dWUoW87xtuQ8"
70 | },
71 | "source": [
72 | "!wget -q https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\n",
73 | "!tar -xf flower_photos.tgz"
74 | ],
75 | "execution_count": null,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {
81 | "id": "U7nDmdxY-rFr"
82 | },
83 | "source": [
84 | "***If you have a spare bucket, then you can use that for this purpose.*** "
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "metadata": {
90 | "id": "eqldz8Jkz8se"
91 | },
92 | "source": [
93 | "#@title GCS\n",
94 | "#@markdown You should change these values as per your preferences. The copy operation can take ~5 minutes. \n",
95 | "BUCKET_PATH = \"gs://flowers-experimental\" #@param {type:\"string\"}\n",
96 | "REGION = \"us-central1\" #@param {type:\"string\"}\n",
97 | "\n",
98 | "!gsutil mb -l {REGION} {BUCKET_PATH}\n",
99 | "!gsutil -m cp -r flower_photos {BUCKET_PATH}"
100 | ],
101 | "execution_count": null,
102 | "outputs": []
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {
107 | "id": "UX7Gw2_h-4Pk"
108 | },
109 | "source": [
110 | "Verify if the files were copied over."
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "metadata": {
116 | "colab": {
117 | "base_uri": "https://localhost:8080/"
118 | },
119 | "id": "KNfP-pbowVkU",
120 | "outputId": "2ae0abcb-2ee8-4d22-af6b-16b6bc29c824"
121 | },
122 | "source": [
123 | "!gsutil ls {BUCKET_PATH}/flower_photos/"
124 | ],
125 | "execution_count": null,
126 | "outputs": [
127 | {
128 | "output_type": "stream",
129 | "text": [
130 | "gs://flowers-experimental/flower_photos/LICENSE.txt\n",
131 | "gs://flowers-experimental/flower_photos/daisy/\n",
132 | "gs://flowers-experimental/flower_photos/dandelion/\n",
133 | "gs://flowers-experimental/flower_photos/roses/\n",
134 | "gs://flowers-experimental/flower_photos/sunflowers/\n",
135 | "gs://flowers-experimental/flower_photos/tulips/\n"
136 | ],
137 | "name": "stdout"
138 | }
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {
144 | "id": "ZPOBNLDS--Vu"
145 | },
146 | "source": [
147 | "## Imports"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "metadata": {
153 | "id": "_u3FHpaAw7He"
154 | },
155 | "source": [
156 | "import random\n",
157 | "random.seed(666)\n",
158 | "\n",
159 | "from google.cloud import storage\n",
160 | "from pprint import pprint\n",
161 | "import pandas as pd\n",
162 | "import os"
163 | ],
164 | "execution_count": null,
165 | "outputs": []
166 | },
167 | {
168 | "cell_type": "code",
169 | "metadata": {
170 | "id": "XMN7RXwt_M20"
171 | },
172 | "source": [
173 | "from google.colab import auth\n",
174 | "auth.authenticate_user()"
175 | ],
176 | "execution_count": null,
177 | "outputs": []
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {
182 | "id": "utZIHk6p_a5X"
183 | },
184 | "source": [
185 | "## Preparing a single `.csv` file\n",
186 | "\n",
187 |         "Vertex AI can import datasets in `.jsonl` and `.csv` formats. In this notebook, we will be using `.csv`. Here's the structure that Vertex AI expects ([reference](https://cloud.google.com/vertex-ai/docs/datasets/prepare-image#csv)):\n",
188 | "\n",
189 | "```\n",
190 | "[ML_USE],GCS_FILE_PATH,[LABEL]\n",
191 | "```\n",
192 | "\n",
193 |         "`ML_USE` stands for the data split - `training`, `validation`, and `test`. "
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {
199 | "id": "d-1fItckAh0u"
200 | },
201 | "source": [
202 | "### Derive GCS URIs of the images"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "metadata": {
208 | "id": "0aQEzBF-zE9J",
209 | "colab": {
210 | "base_uri": "https://localhost:8080/"
211 | },
212 | "outputId": "b9f23b24-2a74-49db-fe41-7f67392c32b1"
213 | },
214 | "source": [
215 | "gs_uris = []\n",
216 | "\n",
217 | "storage_client = storage.Client(project=\"fast-ai-exploration\") # Change it accordingly.\n",
218 | "blobs = storage_client.list_blobs(BUCKET_PATH.split(\"/\")[-1])\n",
219 | "\n",
220 | "for blob in blobs:\n",
221 | " if \".txt\" in blob.name.split(\"/\")[-1]:\n",
222 | " continue\n",
223 | " gs_uri = os.path.join(BUCKET_PATH, blob.name)\n",
224 | " gs_uris.append(gs_uri)\n",
225 | "\n",
226 | "pprint(gs_uris[:5])"
227 | ],
228 | "execution_count": null,
229 | "outputs": [
230 | {
231 | "output_type": "stream",
232 | "text": [
233 | "['gs://flowers-experimental/flower_photos/daisy/100080576_f52e8ee070_n.jpg',\n",
234 | " 'gs://flowers-experimental/flower_photos/daisy/10140303196_b88d3d6cec.jpg',\n",
235 | " 'gs://flowers-experimental/flower_photos/daisy/10172379554_b296050f82_n.jpg',\n",
236 | " 'gs://flowers-experimental/flower_photos/daisy/10172567486_2748826a8b.jpg',\n",
237 | " 'gs://flowers-experimental/flower_photos/daisy/10172636503_21bededa75_n.jpg']\n"
238 | ],
239 | "name": "stdout"
240 | }
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {
246 | "id": "E116FS6sAleK"
247 | },
248 | "source": [
249 | "### Dataset splitting"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "metadata": {
255 | "colab": {
256 | "base_uri": "https://localhost:8080/"
257 | },
258 | "id": "1gyb4tbaxMFz",
259 | "outputId": "743ea2ef-683d-48c1-f981-161c3d1eaa9c"
260 | },
261 | "source": [
262 | "# Create splits.\n",
263 | "random.shuffle(gs_uris)\n",
264 | "\n",
265 | "i = int(len(gs_uris) * 0.9)\n",
266 | "train_paths = gs_uris[:i]\n",
267 | "test_paths = gs_uris[i:]\n",
268 | "\n",
269 | "i = int(len(train_paths) * 0.05)\n",
270 | "valid_paths = train_paths[:i]\n",
271 | "train_paths = train_paths[i:]\n",
272 | "\n",
273 | "print(len(train_paths), len(valid_paths), len(test_paths))"
274 | ],
275 | "execution_count": null,
276 | "outputs": [
277 | {
278 | "output_type": "stream",
279 | "text": [
280 | "3138 165 367\n"
281 | ],
282 | "name": "stdout"
283 | }
284 | ]
285 | },
286 | {
287 | "cell_type": "markdown",
288 | "metadata": {
289 | "id": "Xe9C4yYwA1Ih"
290 | },
291 | "source": [
292 | "### Utility for deriving the labels and `ML_USE`"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "metadata": {
298 | "id": "O4cg8hZS1Q1_"
299 | },
300 | "source": [
301 | "def derive_labels(gcs_paths, split=\"training\"):\n",
302 | " labels = []\n",
303 | " for gcs_path in gcs_paths:\n",
304 | " label = gcs_path.split(\"/\")[4]\n",
305 | " labels.append(label)\n",
306 | " return labels, [split] * len(gcs_paths)"
307 | ],
308 | "execution_count": null,
309 | "outputs": []
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {
314 | "id": "eXdAM-fvA6fv"
315 | },
316 | "source": [
317 | "### Prepare the lists"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "metadata": {
323 | "id": "sQDtn-Ra16ZS"
324 | },
325 | "source": [
326 | "# File format is referred from: https://cloud.google.com/vertex-ai/docs/datasets/prepare-image#csv\n",
327 | "train_labels, train_use = derive_labels(train_paths)\n",
328 | "val_labels, val_use = derive_labels(valid_paths, split=\"validation\")\n",
329 | "test_labels, test_use= derive_labels(test_paths, split=\"test\")"
330 | ],
331 | "execution_count": null,
332 | "outputs": []
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {
337 | "id": "7CYTVk0XBAKm"
338 | },
339 | "source": [
340 | "### Create `.csv` file"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "metadata": {
346 | "id": "lzBXCTox3jnJ"
347 | },
348 | "source": [
349 | "gcs_uris = []\n",
350 | "labels = []\n",
351 | "use = []\n",
352 | "\n",
353 | "gcs_uris.extend(train_paths)\n",
354 | "gcs_uris.extend(valid_paths)\n",
355 | "gcs_uris.extend(test_paths)\n",
356 | "\n",
357 | "labels.extend(train_labels)\n",
358 | "labels.extend(val_labels)\n",
359 | "labels.extend(test_labels)\n",
360 | "\n",
361 | "use.extend(train_use)\n",
362 | "use.extend(val_use)\n",
363 | "use.extend(test_use)"
364 | ],
365 | "execution_count": null,
366 | "outputs": []
367 | },
368 | {
369 | "cell_type": "code",
370 | "metadata": {
371 | "id": "iKzg-1Rh4EQ8"
372 | },
373 | "source": [
374 | "import csv\n",
375 | "\n",
376 | "with open(\"flowers_vertex.csv\", \"w\") as csvfile: \n",
377 | " csvwriter = csv.writer(csvfile)\n",
378 | " \n",
379 | " for ml_use, gcs_uri, label in zip(use, gcs_uris, labels):\n",
380 | " row = [ml_use, gcs_uri, label]\n",
381 | " csvwriter.writerow(row) "
382 | ],
383 | "execution_count": null,
384 | "outputs": []
385 | },
386 | {
387 | "cell_type": "code",
388 | "metadata": {
389 | "colab": {
390 | "base_uri": "https://localhost:8080/"
391 | },
392 | "id": "7vuj-j-A4y_C",
393 | "outputId": "921eac5c-9abe-49ef-a017-efeb03c5e4ae"
394 | },
395 | "source": [
396 | "!head -5 flowers_vertex.csv"
397 | ],
398 | "execution_count": null,
399 | "outputs": [
400 | {
401 | "output_type": "stream",
402 | "text": [
403 | "training,gs://flowers-experimental/flower_photos/sunflowers/4895721788_f10208ab77_n.jpg,sunflowers\r\n",
404 | "training,gs://flowers-experimental/flower_photos/sunflowers/8202034834_ee0ee91e04_n.jpg,sunflowers\r\n",
405 | "training,gs://flowers-experimental/flower_photos/daisy/19019544592_b64469bf84_n.jpg,daisy\r\n",
406 | "training,gs://flowers-experimental/flower_photos/dandelion/4634716478_1cbcbee7ca.jpg,dandelion\r\n",
407 | "training,gs://flowers-experimental/flower_photos/tulips/12163418275_bd6a1edd61.jpg,tulips\r\n"
408 | ],
409 | "name": "stdout"
410 | }
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "metadata": {
416 | "colab": {
417 | "base_uri": "https://localhost:8080/"
418 | },
419 | "id": "LSxVlje9BJ_Q",
420 | "outputId": "e045fcb1-8063-41d8-e207-5c1f6d8e8e54"
421 | },
422 | "source": [
423 | "!tail -5 flowers_vertex.csv"
424 | ],
425 | "execution_count": null,
426 | "outputs": [
427 | {
428 | "output_type": "stream",
429 | "text": [
430 | "test,gs://flowers-experimental/flower_photos/roses/6363951285_a802238d4e.jpg,roses\r\n",
431 | "test,gs://flowers-experimental/flower_photos/dandelion/4571923094_b9cefa9438_n.jpg,dandelion\r\n",
432 | "test,gs://flowers-experimental/flower_photos/roses/2471103806_87ba53d997_n.jpg,roses\r\n",
433 | "test,gs://flowers-experimental/flower_photos/roses/12238827553_cf427bfd51_n.jpg,roses\r\n",
434 | "test,gs://flowers-experimental/flower_photos/roses/3663244576_97f595cf4a.jpg,roses\r\n"
435 | ],
436 | "name": "stdout"
437 | }
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {
443 | "id": "VgOMd-b-BQqM"
444 | },
445 | "source": [
446 | "## Copy over to a GCS Bucket"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "metadata": {
452 | "colab": {
453 | "base_uri": "https://localhost:8080/"
454 | },
455 | "id": "JSTu3ETJBTxP",
456 | "outputId": "7a89ba6b-5568-4441-ceed-db477c02993b"
457 | },
458 | "source": [
459 | "!gsutil cp flowers_vertex.csv {BUCKET_PATH}"
460 | ],
461 | "execution_count": null,
462 | "outputs": [
463 | {
464 | "output_type": "stream",
465 | "text": [
466 | "Copying file://flowers_vertex.csv [Content-Type=text/csv]...\n",
467 | "-\n",
468 | "Operation completed over 1 objects/334.7 KiB. \n"
469 | ],
470 | "name": "stdout"
471 | }
472 | ]
473 | }
474 | ]
475 | }
--------------------------------------------------------------------------------
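After the CSV lands in GCS, it can be imported into Vertex AI as a managed image dataset. A minimal sketch is given below; the project ID is a placeholder, and the bucket path mirrors the one used in the notebook above.

```python
# Hypothetical import of the prepared CSV as a Vertex AI managed image dataset.
from google.cloud import aiplatform

aiplatform.init(project="<your-gcp-project>", location="us-central1")

dataset = aiplatform.ImageDataset.create(
    display_name="flowers",
    gcs_source="gs://flowers-experimental/flowers_vertex.csv",
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
)
print(dataset.resource_name)
```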
/notebooks/Dual_Deployments_With_AutoML.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 5,
4 | "metadata": {
5 | "environment": {
6 | "name": "common-cpu.m75",
7 | "type": "gcloud",
8 | "uri": "gcr.io/deeplearning-platform-release/base-cpu:m75"
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "language": "python",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "codemirror_mode": {
17 | "name": "ipython",
18 | "version": 3
19 | },
20 | "file_extension": ".py",
21 | "mimetype": "text/x-python",
22 | "name": "python",
23 | "nbconvert_exporter": "python",
24 | "pygments_lexer": "ipython3",
25 | "version": "3.7.10"
26 | },
27 | "colab": {
28 | "name": "dual_deployments_with_AutoML.ipynb",
29 | "provenance": [],
30 | "include_colab_link": true
31 | }
32 | },
33 | "cells": [
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "id": "view-in-github",
38 | "colab_type": "text"
39 | },
40 | "source": [
41 |         ""
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "id": "XlRzusV8MDT7"
48 | },
49 | "source": [
50 |         "This notebook shows how to build a dual-deployment pipeline on Vertex AI (one branch exports a TFLite model to Firebase, the other deploys a regular model to a Vertex AI Endpoint). \n",
51 | "\n",
52 |         "Before proceeding, please make sure the following prerequisites are met:\n",
53 |         "- a billing-enabled GCP account\n",
54 |         "- enable the Vertex AI, Cloud Storage, and Firebase APIs\n",
55 |         "- add the roles of Vertex AI Custom Code Service Agent, Vertex AI Service Agent, Vertex AI User, AI Platform Notebooks Service Agent, Storage Object Admin, and Storage Object Creator\n",
56 |         "- follow the steps described in the `Before you begin` section of the official [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) documentation to enable Firebase"
57 | ],
58 | "id": "XlRzusV8MDT7"
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {
63 | "id": "61cd40e1"
64 | },
65 | "source": [
66 | "# Setup"
67 | ],
68 | "id": "61cd40e1"
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "id": "0de28066"
74 | },
75 | "source": [
76 | "## Install packages\n",
77 |         "- `firebase-admin` will be used to publish the TFLite model to Firebase"
78 | ],
79 | "id": "0de28066"
80 | },
81 | {
82 | "cell_type": "code",
83 | "metadata": {
84 | "id": "196d7830"
85 | },
86 | "source": [
87 | "import os\n",
88 | "\n",
89 | "# The Google Cloud Notebook product has specific requirements\n",
90 | "IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists(\"/opt/deeplearning/metadata/env_version\")\n",
91 | "\n",
92 | "# Google Cloud Notebook requires dependencies to be installed with '--user'\n",
93 | "USER_FLAG = \"\"\n",
94 | "if IS_GOOGLE_CLOUD_NOTEBOOK:\n",
95 | " USER_FLAG = \"--user\""
96 | ],
97 | "id": "196d7830",
98 | "execution_count": null,
99 | "outputs": []
100 | },
101 | {
102 | "cell_type": "code",
103 | "metadata": {
104 | "id": "0c708f00"
105 | },
106 | "source": [
107 | "!pip3 install {USER_FLAG} google-cloud-aiplatform --upgrade\n",
108 | "!pip3 install {USER_FLAG} kfp google-cloud-pipeline-components --upgrade\n",
109 | "!pip3 install {USER_FLAG} firebase-admin"
110 | ],
111 | "id": "0c708f00",
112 | "execution_count": null,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "812c875e"
119 | },
120 | "source": [
121 |         "## Restart Jupyter Notebook programmatically\n",
122 |         "- in order for the package installations to take effect"
123 | ],
124 | "id": "812c875e"
125 | },
126 | {
127 | "cell_type": "code",
128 | "metadata": {
129 | "id": "92fb4dd6"
130 | },
131 | "source": [
132 | "# Automatically restart kernel after installs\n",
133 | "import os\n",
134 | "\n",
135 | "if not os.getenv(\"IS_TESTING\"):\n",
136 | " # Automatically restart kernel after installs\n",
137 | " import IPython\n",
138 | "\n",
139 | " app = IPython.Application.instance()\n",
140 | " app.kernel.do_shutdown(True)"
141 | ],
142 | "id": "92fb4dd6",
143 | "execution_count": null,
144 | "outputs": []
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "metadata": {
149 | "id": "5f1f65e9"
150 | },
151 | "source": [
152 |         "## Check KFP (Kubeflow Pipelines) version"
153 | ],
154 | "id": "5f1f65e9"
155 | },
156 | {
157 | "cell_type": "code",
158 | "metadata": {
159 | "id": "55234629",
160 | "outputId": "e5ebdc8d-9ed9-4057-b31c-1d460d4fce4f"
161 | },
162 | "source": [
163 | "!python3 -c \"import kfp; print('KFP SDK version: {}'.format(kfp.__version__))\""
164 | ],
165 | "id": "55234629",
166 | "execution_count": null,
167 | "outputs": [
168 | {
169 | "output_type": "stream",
170 | "text": [
171 | "KFP SDK version: 1.6.6\n"
172 | ],
173 | "name": "stdout"
174 | }
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {
180 | "id": "c437adf1"
181 | },
182 | "source": [
183 | "## Setup GCP Project ID"
184 | ],
185 | "id": "c437adf1"
186 | },
187 | {
188 | "cell_type": "code",
189 | "metadata": {
190 | "id": "3a4f1da1",
191 | "outputId": "0d06cedc-bf9a-4b49-87bc-9ecc52a069c9"
192 | },
193 | "source": [
194 | "import os\n",
195 | "\n",
196 | "PROJECT_ID = \"grounded-atrium-320207\"\n",
197 | "\n",
198 | "# Get your Google Cloud project ID from gcloud\n",
199 | "if not os.getenv(\"IS_TESTING\"):\n",
200 | " shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null\n",
201 | " PROJECT_ID = shell_output[0]\n",
202 | " print(\"Project ID: \", PROJECT_ID)"
203 | ],
204 | "id": "3a4f1da1",
205 | "execution_count": null,
206 | "outputs": [
207 | {
208 | "output_type": "stream",
209 | "text": [
210 | "Project ID: grounded-atrium-320207\n"
211 | ],
212 | "name": "stdout"
213 | }
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {
219 | "id": "950c18ba"
220 | },
221 | "source": [
222 | "## GCP Authentication"
223 | ],
224 | "id": "950c18ba"
225 | },
226 | {
227 | "cell_type": "code",
228 | "metadata": {
229 | "id": "2e58f1fd"
230 | },
231 | "source": [
232 | "import os\n",
233 | "import sys\n",
234 | "\n",
235 | "# If you are running this notebook in Colab, run this cell and follow the\n",
236 | "# instructions to authenticate your GCP account. This provides access to your\n",
237 | "# Cloud Storage bucket and lets you submit training jobs and prediction\n",
238 | "# requests.\n",
239 | "\n",
240 | "# The Google Cloud Notebook product has specific requirements\n",
241 | "IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists(\"/opt/deeplearning/metadata/env_version\")\n",
242 | "\n",
243 | "# If on Google Cloud Notebooks, then don't execute this code\n",
244 | "if not IS_GOOGLE_CLOUD_NOTEBOOK:\n",
245 | " if \"google.colab\" in sys.modules:\n",
246 | " from google.colab import auth as google_auth\n",
247 | "\n",
248 | " google_auth.authenticate_user()\n",
249 | "\n",
250 | " # If you are running this notebook locally, replace the string below with the\n",
251 | " # path to your service account key and run this cell to authenticate your GCP\n",
252 | " # account.\n",
253 | " elif not os.getenv(\"IS_TESTING\"):\n",
254 | " %env GOOGLE_APPLICATION_CREDENTIALS ''"
255 | ],
256 | "id": "2e58f1fd",
257 | "execution_count": null,
258 | "outputs": []
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "7221dad0"
264 | },
265 | "source": [
266 | "## Setup GCS Bucket name\n",
267 | "- this bucket is where everything is going to stored"
268 | ],
269 | "id": "7221dad0"
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "id": "6d49e084"
275 | },
276 | "source": [
277 | "BUCKET_NAME = \"gs://vertexai_dual_example\"\n",
278 | "REGION = \"us-central1\" "
279 | ],
280 | "id": "6d49e084",
281 | "execution_count": null,
282 | "outputs": []
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {
287 | "id": "b1690ebb"
288 | },
289 | "source": [
290 | "### Create GCS Bucket"
291 | ],
292 | "id": "b1690ebb"
293 | },
294 | {
295 | "cell_type": "code",
296 | "metadata": {
297 | "id": "f3e85ef1",
298 | "outputId": "2b2854e3-e372-473a-ae0f-97efe8d9521e"
299 | },
300 | "source": [
301 | "!gsutil mb -l $REGION $BUCKET_NAME"
302 | ],
303 | "id": "f3e85ef1",
304 | "execution_count": null,
305 | "outputs": [
306 | {
307 | "output_type": "stream",
308 | "text": [
309 | "Creating gs://vertexai_dual_example/...\n"
310 | ],
311 | "name": "stdout"
312 | }
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {
318 | "id": "5f1d553a"
319 | },
320 | "source": [
321 | "## Setup GCS Path for Pipeline\n",
322 | "- the pipeline runs are going to be stored (i.e. Metadata)"
323 | ],
324 | "id": "5f1d553a"
325 | },
326 | {
327 | "cell_type": "code",
328 | "metadata": {
329 | "id": "f0c61275",
330 | "outputId": "12d1b068-f9b6-4105-e151-ec5c935d4c4d"
331 | },
332 | "source": [
333 | "PATH=%env PATH\n",
334 | "%env PATH={PATH}:/home/jupyter/.local/bin\n",
335 | "\n",
336 | "USER = \"chansung\"\n",
337 | "PIPELINE_ROOT = \"{}/pipeline_root/{}\".format(BUCKET_NAME, USER)\n",
338 | "PIPELINE_ROOT"
339 | ],
340 | "id": "f0c61275",
341 | "execution_count": null,
342 | "outputs": [
343 | {
344 | "output_type": "stream",
345 | "text": [
346 | "env: PATH=/opt/conda/bin:/opt/conda/condabin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games:/home/jupyter/.local/bin\n"
347 | ],
348 | "name": "stdout"
349 | },
350 | {
351 | "output_type": "execute_result",
352 | "data": {
353 | "text/plain": [
354 | "'gs://vertexai_dual_example/pipeline_root/chansung'"
355 | ]
356 | },
357 | "metadata": {
358 | "tags": []
359 | },
360 | "execution_count": 8
361 | }
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "metadata": {
367 | "id": "6986b40b"
368 | },
369 | "source": [
370 | "## Build Pipeline"
371 | ],
372 | "id": "6986b40b"
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {
377 | "id": "f6ff09f9"
378 | },
379 | "source": [
380 | "### Install packages\n",
381 | "- KFP related\n",
382 | "- Client API for AI Platform"
383 | ],
384 | "id": "f6ff09f9"
385 | },
386 | {
387 | "cell_type": "code",
388 | "metadata": {
389 | "id": "9a75acd2"
390 | },
391 | "source": [
392 | "import kfp\n",
393 | "from google.cloud import aiplatform\n",
394 | "from google_cloud_pipeline_components import aiplatform as gcc_aip\n",
395 | "from kfp.v2 import compiler\n",
396 | "from kfp.v2.google.client import AIPlatformClient\n",
397 | "from kfp.v2 import dsl\n",
398 | "from kfp.v2.dsl import component"
399 | ],
400 | "id": "9a75acd2",
401 | "execution_count": null,
402 | "outputs": []
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {
407 | "id": "e5e21652"
408 | },
409 | "source": [
410 | "### Define pipeline component to publish TFLite model to Firebase\n",
411 | "0. please follow the steps described from `Before you begin` section in the official [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) documentation. \n",
412 | "1. download credentials for the Firebase project\n",
413 | "2. download TFLite model file\n",
414 | "3. initialize firebase admin\n",
415 | "4. upload and publish TFLite model from local file to Firebase"
416 | ],
417 | "id": "e5e21652"
418 | },
419 | {
420 | "cell_type": "code",
421 | "metadata": {
422 | "id": "57076987"
423 | },
424 | "source": [
425 | "@component(\n",
426 | " packages_to_install=[\"google-cloud-storage\", \"firebase-admin\", \"tensorflow\"]\n",
427 | ")\n",
428 | "def push_to_firebase(\n",
429 | " credential_uri: str,\n",
430 | " model_bucket: str,\n",
431 | " firebase_dest_gcs_bucket: str,\n",
432 | " model_display_name: str,\n",
433 | " model_tag: str\n",
434 | "):\n",
435 | " import firebase_admin\n",
436 | " from firebase_admin import ml\n",
437 | " from firebase_admin import storage\n",
438 | " from firebase_admin import credentials \n",
439 | " from google.cloud import storage as gcs_storage\n",
440 | " \n",
441 | " gcs_client = gcs_storage.Client()\n",
442 | " \n",
443 | " # get credential for firebase \n",
444 | " credential_gcs_bucket = credential_uri.split('//')[1].split('/')[0]\n",
445 | " credential_blob_path = '/'.join(credential_uri.split('//')[1].split('/')[1:])\n",
446 | " \n",
447 | " bucket = gcs_client.bucket(credential_gcs_bucket)\n",
448 | " blob = bucket.blob(credential_blob_path)\n",
449 | " blob.download_to_filename('credential.json')\n",
450 | " \n",
451 | " # get the latest model \n",
452 | " tflite_blobs = gcs_client.get_bucket(model_bucket).list_blobs()\n",
453 | " tflite_blob = sorted(tflite_blobs, reverse=True, key=lambda blob: blob.name.split('/')[-2])[0]\n",
454 | " tflite_blob.download_to_filename('model.tflite') \n",
455 | " \n",
456 | " firebase_admin.initialize_app(\n",
457 | " credentials.Certificate('credential.json'),\n",
458 | " options={\n",
459 | " 'storageBucket': firebase_dest_gcs_bucket\n",
460 | " }\n",
461 | " )\n",
462 | "\n",
463 | " model_list = ml.list_models(list_filter=f'display_name={model_display_name}')\n",
464 | " # update\n",
465 | " if len(model_list.models) > 0:\n",
466 | " # get the first match model\n",
467 | " model = model_list.models[0]\n",
468 | " \n",
469 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file('model.tflite')\n",
470 | " model.model_format = ml.TFLiteFormat(model_source=source)\n",
471 | " \n",
472 | " updated_model = ml.update_model(model)\n",
473 | " ml.publish_model(updated_model.model_id)\n",
474 | " # create\n",
475 | " else: \n",
476 | " # Load a tflite file and upload it to Cloud Storage\n",
477 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file('model.tflite')\n",
478 | "\n",
479 | " # Create the model object\n",
480 | " tflite_format = ml.TFLiteFormat(model_source=source)\n",
481 | " model = ml.Model(\n",
482 | " display_name=model_display_name, # This is the name you use from your app to load the model.\n",
483 | " tags=[model_tag], # Optional tags for easier management.\n",
484 | " model_format=tflite_format)\n",
485 | "\n",
486 | " # Add the model to your Firebase project and publish it\n",
487 | " new_model = ml.create_model(model)\n",
488 | " ml.publish_model(new_model.model_id)"
489 | ],
490 | "id": "57076987",
491 | "execution_count": null,
492 | "outputs": []
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "metadata": {
497 | "id": "c9d14770"
498 | },
499 | "source": [
500 | "### Define pipeline itself\n",
501 | "1. Create Vertex AI's managed dataset from CSV\n",
502 | "2. Define configs, one for cloud model, and the other one for mobile model\n",
503 | "3. Run parallel processing for two different workflow(each workflow is configured appropriate for each target environment)\n",
504 | " - AutoML training can be configured differently for each target environment depending on `model_type`\n",
505 | "4. Deploying the trained model as well as creating an endpoint is done with `ModelDeployOp` for cloud model\n",
506 | "5. Export the trained mobile model to a GCS bucket\n",
507 | " - publish the exported mobile model to Firebase through push_to_firebase component\n",
508 | " "
509 | ],
510 | "id": "c9d14770"
511 | },
512 | {
513 | "cell_type": "code",
514 | "metadata": {
515 | "id": "22e3d30c"
516 | },
517 | "source": [
518 | "@kfp.dsl.pipeline(name=\"cloud-mobile-dual-deployment\")\n",
519 | "def pipeline(project: str = PROJECT_ID):\n",
520 | " ds_op = gcc_aip.ImageDatasetCreateOp(\n",
521 | " project=project,\n",
522 | " display_name=\"flowers-dataset\",\n",
523 | " gcs_source=\"gs://dataset-meta-gde-csp/flowers_vertex.csv\",\n",
524 | " import_schema_uri=aiplatform.schema.dataset.ioformat.image.multi_label_classification,\n",
525 | " )\n",
526 | "\n",
527 | " configs = [\n",
528 | " {\n",
529 | " \"type\": \"CLOUD\",\n",
530 | " \"model_type\": \"CLOUD\",\n",
531 | " \"display_name\": \"train-cloud-model\",\n",
532 | " \"model_display_name\": \"cloud-model\",\n",
533 | " \"budget_milli_node_hours\": 8000,\n",
534 | " },\n",
535 | " {\n",
536 | " \"type\": \"MOBILE\",\n",
537 | " \"model_type\": \"MOBILE_TF_VERSATILE_1\",\n",
538 | " \"display_name\": \"train-mobile-model\",\n",
539 | " \"model_display_name\": \"mobile-model\",\n",
540 | " \"budget_milli_node_hours\": 1000,\n",
541 | " }\n",
542 | " ]\n",
543 | "\n",
544 | " with kfp.dsl.ParallelFor(configs) as config:\n",
545 | " training_job_run_op = gcc_aip.AutoMLImageTrainingJobRunOp(\n",
546 | " project=project,\n",
547 | " display_name=config.display_name,\n",
548 | " prediction_type=\"classification\",\n",
549 | " multi_label=True,\n",
550 | " model_type=config.model_type,\n",
551 | " base_model=None,\n",
552 | " dataset=ds_op.outputs[\"dataset\"],\n",
553 | " model_display_name=config.model_display_name,\n",
554 | " budget_milli_node_hours=config.budget_milli_node_hours,\n",
555 | " )\n",
556 | " training_job_run_op.after(ds_op)\n",
557 | "\n",
558 | " with kfp.dsl.Condition(config.type=='CLOUD'):\n",
559 | " endpoint_op = gcc_aip.ModelDeployOp(\n",
560 | " project=project,\n",
561 | " model=training_job_run_op.outputs[\"model\"]\n",
562 | " )\n",
563 | " endpoint_op.after(training_job_run_op)\n",
564 | " \n",
565 | " with kfp.dsl.Condition(config.type=='MOBILE'):\n",
566 | " export_op = gcc_aip.ModelExportOp( \n",
567 | " project=project,\n",
568 | " model=training_job_run_op.outputs[\"model\"],\n",
569 | " # tflite, edgetpu-tflite, tf-saved-model, tf-js, core-ml, custom-trained\n",
570 | " export_format_id=\"tflite\",\n",
571 | " artifact_destination=\"gs://output-model-gde-csp/flower-models/\"\n",
572 | " )\n",
573 | " export_op.after(training_job_run_op)\n",
574 | "\n",
575 | " credential_uri=\"gs://firebase-ml-bucket-gde-csp/grounded-atrium-320207-firebase-adminsdk-5n9sn-20dbda9947.json\"\n",
576 | " model_bucket=\"output-model-gde-csp\"\n",
577 | " firebase_bucket=\"grounded-atrium-320207.appspot.com\"\n",
578 | " \n",
579 | " firebase_op = push_to_firebase(\n",
580 | " ins=export_op.outputs['exported_dataset'],\n",
581 | " credential_uri=credential_uri,\n",
582 | " model_bucket=model_bucket,\n",
583 | " firebase_dest_gcs_bucket=firebase_bucket,\n",
584 | " model_display_name=\"custom_model\",\n",
585 | " model_tag=\"from_dual_deployment\"\n",
586 | " )\n",
587 | " firebase_op.after(export_op)\n",
588 | " "
589 | ],
590 | "id": "22e3d30c",
591 | "execution_count": null,
592 | "outputs": []
593 | },
594 | {
595 | "cell_type": "markdown",
596 | "metadata": {
597 | "id": "59a52203"
598 | },
599 | "source": [
600 | "### Compile the pipeline\n",
601 | "- you will get a json file for the pipeline spec after compiling.\n",
602 | " - you will only need this json file to run the pipeline"
603 | ],
604 | "id": "59a52203"
605 | },
606 | {
607 | "cell_type": "code",
608 | "metadata": {
609 | "id": "6b0cdaef"
610 | },
611 | "source": [
612 | "from kfp.v2 import compiler\n",
613 | "\n",
614 | "compiler.Compiler().compile(\n",
615 | " pipeline_func=pipeline, package_path=\"cloud-mobile-dual-deployment.json\"\n",
616 | ")"
617 | ],
618 | "id": "6b0cdaef",
619 | "execution_count": null,
620 | "outputs": []
621 | },
622 | {
623 | "cell_type": "markdown",
624 | "metadata": {
625 | "id": "7ec1f64f"
626 | },
627 | "source": [
628 | "## Run the pipeline on Vertex AI Pipeline"
629 | ],
630 | "id": "7ec1f64f"
631 | },
632 | {
633 | "cell_type": "markdown",
634 | "metadata": {
635 | "id": "fcbb6c14"
636 | },
637 | "source": [
638 | "### Create client instance to AI Platform (which is Vertex AI)"
639 | ],
640 | "id": "fcbb6c14"
641 | },
642 | {
643 | "cell_type": "code",
644 | "metadata": {
645 | "id": "cc9a0a61"
646 | },
647 | "source": [
648 | "from kfp.v2.google.client import AIPlatformClient\n",
649 | "\n",
650 | "api_client = AIPlatformClient(project_id=PROJECT_ID, region=REGION)"
651 | ],
652 | "id": "cc9a0a61",
653 | "execution_count": null,
654 | "outputs": []
655 | },
656 | {
657 | "cell_type": "markdown",
658 | "metadata": {
659 | "id": "07274930"
660 | },
661 | "source": [
662 | "### Run the pipeline with the pipeline spec (json file)"
663 | ],
664 | "id": "07274930"
665 | },
666 | {
667 | "cell_type": "code",
668 | "metadata": {
669 | "id": "f485821f",
670 | "outputId": "a76c399e-70ff-4fe2-e476-772cf47e932c"
671 | },
672 | "source": [
673 | "response = api_client.create_run_from_job_spec(\n",
674 | " \"cloud-mobile-dual-deployment.json\",\n",
675 | " pipeline_root=PIPELINE_ROOT,\n",
676 | " parameter_values={\"project\": PROJECT_ID},\n",
677 | ")"
678 | ],
679 | "id": "f485821f",
680 | "execution_count": null,
681 | "outputs": [
682 | {
683 | "output_type": "display_data",
684 | "data": {
685 | "text/html": [
686 | "See the Pipeline job here."
687 | ],
688 | "text/plain": [
689 | ""
690 | ]
691 | },
692 | "metadata": {
693 | "tags": []
694 | }
695 | }
696 | ]
697 | }
698 | ]
699 | }
--------------------------------------------------------------------------------
/notebooks/Custom_Model_TFX.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Custom_Model_TFX",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "machine_shape": "hm",
10 | "include_colab_link": true
11 | },
12 | "environment": {
13 | "name": "tf2-gpu.2-4.mnightly-2021-02-02-debian-10-test",
14 | "type": "gcloud",
15 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-4:mnightly-2021-02-02-debian-10-test"
16 | },
17 | "kernelspec": {
18 | "display_name": "Python 3",
19 | "language": "python",
20 | "name": "python3"
21 | },
22 | "language_info": {
23 | "codemirror_mode": {
24 | "name": "ipython",
25 | "version": 3
26 | },
27 | "file_extension": ".py",
28 | "mimetype": "text/x-python",
29 | "name": "python",
30 | "nbconvert_exporter": "python",
31 | "pygments_lexer": "ipython3",
32 | "version": "3.7.4"
33 | }
34 | },
35 | "cells": [
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "view-in-github",
40 | "colab_type": "text"
41 | },
42 | "source": [
43 | "
"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {
49 | "id": "mTVp-9PGYFIO"
50 | },
51 | "source": [
52 | "In this notebook, we will build two custom models - one for Endpoint deployment and the other one for mobile deployment. The mobile-friendly model will be first converted to TFLite and will then be pushed to Firebase because it provides enhanced support to operate with TFLite models in production. We will write a TFX pipeline to run their training and export. The entire pipeline will be orchestrated using [Vertex AI Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction). \n",
53 | "\n",
54 | "This notebook assumes you are familiar with the basics of Vertex AI, TFX (especially custom components), and TensorFlow. "
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "id": "W7gJqmqrsfqh"
61 | },
62 | "source": [
63 | "## References\n",
64 | "\n",
65 | "This notebook refers to the following resources and also reuses parts of the code from there: \n",
66 | "* [Simple TFX Pipeline for Vertex Pipelines](https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/gcp/vertex_pipelines_simple.ipynb)\n",
67 | "* [Vertex AI Training with TFX and Vertex Pipelines](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_vertex_training)\n",
68 | "* [Importing models to Vertex AI](https://cloud.google.com/vertex-ai/docs/general/import-model)\n",
69 | "* [Deploying a model using the Vertex AI API](https://cloud.google.com/vertex-ai/docs/predictions/deploy-model-api)\n",
70 | "* [MLOPs with Vertex AI](https://github.com/GoogleCloudPlatform/mlops-with-vertex-ai)\n",
71 | "* [Custom components TFX](https://www.tensorflow.org/tfx/tutorials/tfx/python_function_component)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {
77 | "id": "D04aKMGWXjOu"
78 | },
79 | "source": [
80 | "## Setup"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "metadata": {
86 | "id": "I_niUhp_TY1G"
87 | },
88 | "source": [
89 | "# Use the latest version of pip.\n",
90 | "%%capture\n",
91 | "!pip install --upgrade pip\n",
92 | "!pip install --upgrade tfx==1.0.0 kfp==1.6.1\n",
93 | "!pip install -q --upgrade google-cloud-aiplatform"
94 | ],
95 | "execution_count": null,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "id": "ZVmgQ6w1oT_Z"
102 | },
103 | "source": [
104 | "### ***Please restart runtime before continuing.*** "
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "metadata": {
110 | "id": "mstgsNHWoiXk"
111 | },
112 | "source": [
113 | "!gcloud init"
114 | ],
115 | "execution_count": null,
116 | "outputs": []
117 | },
118 | {
119 | "cell_type": "code",
120 | "metadata": {
121 | "id": "Pl8ewjX3oXRx"
122 | },
123 | "source": [
124 | "from google.colab import auth\n",
125 | "auth.authenticate_user()"
126 | ],
127 | "execution_count": null,
128 | "outputs": []
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {
133 | "id": "zqVWpmywXngD"
134 | },
135 | "source": [
136 | "## Imports"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "metadata": {
142 | "colab": {
143 | "base_uri": "https://localhost:8080/"
144 | },
145 | "id": "wptXF0e-UXsT",
146 | "outputId": "1bb97d2e-f3f5-419f-e987-b2847a6adb8a"
147 | },
148 | "source": [
149 | "import tensorflow as tf\n",
150 | "print('TensorFlow version: {}'.format(tf.__version__))\n",
151 | "from tfx import v1 as tfx\n",
152 | "print('TFX version: {}'.format(tfx.__version__))\n",
153 | "import kfp\n",
154 | "print('KFP version: {}'.format(kfp.__version__))\n",
155 | "\n",
156 | "from google.cloud import aiplatform as vertex_ai\n",
157 | "import os"
158 | ],
159 | "execution_count": null,
160 | "outputs": [
161 | {
162 | "output_type": "stream",
163 | "text": [
164 | "TensorFlow version: 2.5.0\n",
165 | "TFX version: 1.0.0\n",
166 | "KFP version: 1.6.1\n"
167 | ],
168 | "name": "stdout"
169 | }
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "id": "hFYHeepnXxpZ"
176 | },
177 | "source": [
178 | "## Environment setup"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "zPVyBrXrW-vu"
185 | },
186 | "source": [
187 | "GOOGLE_CLOUD_PROJECT = 'fast-ai-exploration' #@param {type:\"string\"}\n",
188 | "GOOGLE_CLOUD_REGION = 'us-central1' #@param {type:\"string\"}\n",
189 | "GCS_BUCKET_NAME = 'vertex-tfx-mlops' #@param {type:\"string\"}\n",
190 | "\n",
191 | "if not (GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_REGION and GCS_BUCKET_NAME):\n",
192 | " from absl import logging\n",
193 | " logging.error('Please set all required parameters.')"
194 | ],
195 | "execution_count": null,
196 | "outputs": []
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {
201 | "id": "CV-BZSvQq7YY"
202 | },
203 | "source": [
204 | "The location of the bucket must be a single region. Also, the bucket needs to be created in a region when [Vertex AI services are available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions). "
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "metadata": {
210 | "colab": {
211 | "base_uri": "https://localhost:8080/"
212 | },
213 | "id": "J65KHrt4X-Fu",
214 | "outputId": "085fe0e1-02fd-47e3-bc6e-ee24ad82b1b2"
215 | },
216 | "source": [
217 | "PIPELINE_NAME = 'two-way-vertex-pipelines5'\n",
218 | "\n",
219 | "# Path to various pipeline artifact.\n",
220 | "PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(\n",
221 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n",
222 | "\n",
223 | "# Paths for users' Python module.\n",
224 | "MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(\n",
225 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n",
226 | "\n",
227 | "# Paths for input data.\n",
228 | "DATA_ROOT = 'gs://flowers-public/tfrecords-jpeg-224x224'\n",
229 | "\n",
230 | "# This is the path where your model will be pushed for serving.\n",
231 | "SERVING_MODEL_DIR = 'gs://{}/serving_model/{}'.format(\n",
232 | " GCS_BUCKET_NAME, PIPELINE_NAME)\n",
233 | "\n",
234 | "print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))"
235 | ],
236 | "execution_count": null,
237 | "outputs": [
238 | {
239 | "output_type": "stream",
240 | "text": [
241 | "PIPELINE_ROOT: gs://demo-experiments-gde-csp/pipeline_root/two-way-vertex-pipelines5\n"
242 | ],
243 | "name": "stdout"
244 | }
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {
250 | "id": "dIfd3_x0bfTu"
251 | },
252 | "source": [
253 | "There are three things you need for Firebase setup, and you can find a sufficient description for those in [Deploy and manage custom models with Firebase Admin SDK](https://firebase.google.com/docs/ml/manage-hosted-models) official document. \n",
254 | "\n",
255 | "1. You need to obtain credential file for your [Firebase project](https://console.firebase.google.com/).\n",
256 | " - On the [Settings](https://console.firebase.google.com/project/_/settings/serviceaccounts/adminsdk) page, create a service account and download the service account key file. Keep this file safe, since it grants administrator access to your project.\n",
257 | " - Save the credential JSON file in a GCS bucket, and replace the `FIREBASE_CREDENTIAL_PATH` value with it.\n",
258 | "\n",
259 | "\n",
260 | "2. You need to create a GCS bucket where the model is going to be temporarily stored.\n",
261 | " - On the Storage page, enable Cloud Storage. Take note of your bucket name.\n",
262 | " - Replace `FIREBASE_GCS_BUCKET` with the obtained GCS bucket name. It usually has this form `YOUR_GCP_PROJECT_ID.appspot.com`.\n",
263 | "\n",
264 | "\n",
265 | "3. On the Firebase ML page, click Get started if you haven't yet enabled Firebase ML."
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "metadata": {
271 | "id": "fWkhLCv_bfTu"
272 | },
273 | "source": [
274 | "FIREBASE_CREDENTIAL_PATH = 'gs://credential-csp/gcp-ml-172005-firebase-adminsdk-5gdtb-38c6644f1e.json'\n",
275 | "FIREBASE_GCS_BUCKET = 'gcp-ml-172005.appspot.com'"
276 | ],
277 | "execution_count": null,
278 | "outputs": []
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {
283 | "id": "kQVpzyftX0y0"
284 | },
285 | "source": [
286 | "## Create training modules"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "metadata": {
292 | "id": "AR5pU65m6nAE"
293 | },
294 | "source": [
295 | "_trainer_densenet_module_file = 'flower_densenet_trainer.py'\n",
296 | "_trainer_mobilenet_module_file = 'flower_mobilenet_trainer.py'"
297 | ],
298 | "execution_count": null,
299 | "outputs": []
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "source": [
304 | "`_model_exporter()` utility comes from [Practical Machine Learning for Computer Vision](https://www.oreilly.com/library/view/practical-machine-learning/9781098102357/)."
305 | ],
306 | "metadata": {
307 | "id": "3vUUTikER33Y"
308 | }
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "colab": {
314 | "base_uri": "https://localhost:8080/"
315 | },
316 | "id": "XqwYrR1GYLvm",
317 | "outputId": "4f8c480a-a4cd-42cc-9d5c-8d5ae822ae22"
318 | },
319 | "source": [
320 | "%%writefile {_trainer_densenet_module_file}\n",
321 | "\n",
322 | "from typing import List\n",
323 | "from absl import logging\n",
324 | "from tensorflow import keras\n",
325 | "from tfx import v1 as tfx\n",
326 | "import tensorflow as tf\n",
327 | "\n",
328 | "\n",
329 | "_IMAGE_FEATURES = {\n",
330 | " \"image\": tf.io.FixedLenFeature([], tf.string),\n",
331 | " \"class\": tf.io.FixedLenFeature([], tf.int64),\n",
332 | " \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n",
333 | "}\n",
334 | "\n",
335 | "_CONCRETE_INPUT = \"numpy_inputs\"\n",
336 | "_INPUT_SHAPE = (224, 224, 3)\n",
337 | "_TRAIN_BATCH_SIZE = 64\n",
338 | "_EVAL_BATCH_SIZE = 64\n",
339 | "_EPOCHS = 2\n",
340 | "\n",
341 | "\n",
342 | "def _parse_fn(example):\n",
343 | " example = tf.io.parse_single_example(example, _IMAGE_FEATURES)\n",
344 | " image = tf.image.decode_jpeg(example[\"image\"], channels=3)\n",
345 | " class_label = tf.cast(example[\"class\"], tf.int32)\n",
346 | " return image, class_label\n",
347 | "\n",
348 | "\n",
349 | "def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:\n",
350 | " \"\"\"Generates features and label for training.\n",
351 | "\n",
352 | " Args:\n",
353 | " file_pattern: List of paths or patterns of input tfrecord files.\n",
354 | " batch_size: representing the number of consecutive elements of returned\n",
355 | " dataset to combine in a single batch.\n",
356 | "\n",
357 | " Returns:\n",
358 | " A dataset that contains (features, indices) tuple where features is a\n",
359 | " dictionary of Tensors, and indices is a single Tensor of label indices.\n",
360 | " \"\"\"\n",
361 | " logging.info(f\"Reading data from: {file_pattern}\")\n",
362 | " tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + \".gz\")\n",
363 | " dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n",
364 | " dataset = dataset.map(_parse_fn).batch(batch_size)\n",
365 | " return dataset.repeat()\n",
366 | "\n",
367 | "\n",
368 | "def _make_keras_model() -> tf.keras.Model:\n",
369 | " \"\"\"Creates a DenseNet121-based model for classifying flowers data.\n",
370 | "\n",
371 | " Returns:\n",
372 | " A Keras Model.\n",
373 | " \"\"\"\n",
374 | " inputs = keras.Input(shape=_INPUT_SHAPE)\n",
375 | " base_model = keras.applications.DenseNet121(\n",
376 | " include_top=False, input_shape=_INPUT_SHAPE, pooling=\"avg\"\n",
377 | " )\n",
378 | " base_model.trainable = False\n",
379 | " x = keras.applications.densenet.preprocess_input(inputs)\n",
380 | " x = base_model(\n",
381 | " x, training=False\n",
382 | " ) # Ensures BatchNorm runs in inference model in this model\n",
383 | " outputs = keras.layers.Dense(5, activation=\"softmax\")(x)\n",
384 | " model = keras.Model(inputs, outputs)\n",
385 | "\n",
386 | " model.compile(\n",
387 | " optimizer=keras.optimizers.Adam(),\n",
388 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
389 | " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
390 | " )\n",
391 | "\n",
392 | " model.summary(print_fn=logging.info)\n",
393 | " return model\n",
394 | "\n",
395 | "\n",
396 | "def _preprocess(bytes_input):\n",
397 | " decoded = tf.io.decode_jpeg(bytes_input, channels=3)\n",
398 | " resized = tf.image.resize(decoded, size=(224, 224))\n",
399 | " return resized\n",
400 | "\n",
401 | "\n",
402 | "@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])\n",
403 | "def preprocess_fn(bytes_inputs):\n",
404 | " decoded_images = tf.map_fn(\n",
405 | " _preprocess, bytes_inputs, dtype=tf.float32, back_prop=False\n",
406 | " )\n",
407 | " return {_CONCRETE_INPUT: decoded_images}\n",
408 | "\n",
409 | "\n",
410 | "def _model_exporter(model: tf.keras.Model):\n",
411 | " m_call = tf.function(model.call).get_concrete_function(\n",
412 | " [\n",
413 | " tf.TensorSpec(\n",
414 | " shape=[None, 224, 224, 3], dtype=tf.float32, name=_CONCRETE_INPUT\n",
415 | " )\n",
416 | " ]\n",
417 | " )\n",
418 | "\n",
419 | " @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])\n",
420 | " def serving_fn(bytes_inputs):\n",
421 | " labels = tf.constant(\n",
422 | " [\"daisy\", \"dandelion\", \"roses\", \"sunflowers\", \"tulips\"], dtype=tf.string\n",
423 | " )\n",
424 | " images = preprocess_fn(bytes_inputs)\n",
425 | "\n",
426 | " probs = m_call(**images)\n",
427 | " indices = tf.argmax(probs, axis=1)\n",
428 | " pred_source = tf.gather(params=labels, indices=indices)\n",
429 | " pred_confidence = tf.reduce_max(probs, axis=1)\n",
430 | " return {\"label\": pred_source, \"confidence\": pred_confidence}\n",
431 | "\n",
432 | " return serving_fn\n",
433 | "\n",
434 | "\n",
435 | "# TFX Trainer will call this function.\n",
436 | "def run_fn(fn_args: tfx.components.FnArgs):\n",
437 | " \"\"\"Train the model based on given args.\n",
438 | "\n",
439 | " Args:\n",
440 | " fn_args: Holds args used to train the model as name/value pairs.\n",
441 | " \"\"\"\n",
442 | " train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)\n",
443 | " eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)\n",
444 | "\n",
445 | " model = _make_keras_model()\n",
446 | " model.fit(\n",
447 | " train_dataset,\n",
448 | " steps_per_epoch=fn_args.train_steps,\n",
449 | " validation_data=eval_dataset,\n",
450 | " validation_steps=fn_args.eval_steps,\n",
451 | " epochs=_EPOCHS,\n",
452 | " )\n",
453 | " _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)\n",
454 | " logging.info(f\"Validation accuracy: {round(acc * 100, 2)}%\")\n",
455 | " # The result of the training should be saved in `fn_args.serving_model_dir`\n",
456 | " # directory.\n",
457 | " tf.saved_model.save(\n",
458 | " model,\n",
459 | " fn_args.serving_model_dir,\n",
460 | " signatures={\"serving_default\": _model_exporter(model)},\n",
461 | " )\n"
462 | ],
463 | "execution_count": null,
464 | "outputs": [
465 | {
466 | "output_type": "stream",
467 | "text": [
468 | "Overwriting flower_densenet_trainer.py\n"
469 | ],
470 | "name": "stdout"
471 | }
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "colab": {
478 | "base_uri": "https://localhost:8080/"
479 | },
480 | "id": "tIFDOw5cmqd1",
481 | "outputId": "c26a9017-b14f-4f1a-ce5a-4ef98ba96270"
482 | },
483 | "source": [
484 | "%%writefile {_trainer_mobilenet_module_file}\n",
485 | "\n",
486 | "from typing import List\n",
487 | "from absl import logging\n",
488 | "from tensorflow import keras\n",
489 | "from tfx import v1 as tfx\n",
490 | "import tensorflow as tf\n",
491 | "\n",
492 | "\n",
493 | "_IMAGE_FEATURES = {\n",
494 | " \"image\": tf.io.FixedLenFeature([], tf.string),\n",
495 | " \"class\": tf.io.FixedLenFeature([], tf.int64),\n",
496 | " \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n",
497 | "}\n",
498 | "\n",
499 | "_INPUT_SHAPE = (224, 224, 3)\n",
500 | "_TRAIN_BATCH_SIZE = 64\n",
501 | "_EVAL_BATCH_SIZE = 64\n",
502 | "_EPOCHS = 2\n",
503 | "\n",
504 | "\n",
505 | "def _parse_fn(example):\n",
506 | " example = tf.io.parse_single_example(example, _IMAGE_FEATURES)\n",
507 | " image = tf.image.decode_jpeg(example[\"image\"], channels=3)\n",
508 | " class_label = tf.cast(example[\"class\"], tf.int32)\n",
509 | " return image, class_label\n",
510 | "\n",
511 | "\n",
512 | "def _input_fn(file_pattern: List[str], batch_size: int) -> tf.data.Dataset:\n",
513 | " \"\"\"Generates features and label for training.\n",
514 | "\n",
515 | " Args:\n",
516 | " file_pattern: List of paths or patterns of input tfrecord files.\n",
517 | " batch_size: representing the number of consecutive elements of returned\n",
518 | " dataset to combine in a single batch.\n",
519 | "\n",
520 | " Returns:\n",
521 | " A dataset that contains (features, indices) tuple where features is a\n",
522 | " dictionary of Tensors, and indices is a single Tensor of label indices.\n",
523 | " \"\"\"\n",
524 | " logging.info(f\"Reading data from: {file_pattern}\")\n",
525 | " tfrecord_filenames = tf.io.gfile.glob(file_pattern[0] + \".gz\")\n",
526 | " dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n",
527 | " dataset = dataset.map(_parse_fn).batch(batch_size)\n",
528 | " return dataset.repeat()\n",
529 | "\n",
530 | "\n",
531 | "def _make_keras_model() -> tf.keras.Model:\n",
532 | " \"\"\"Creates a MobileNetV3-based model for classifying flowers data.\n",
533 | "\n",
534 | " Returns:\n",
535 | " A Keras Model.\n",
536 | " \"\"\"\n",
537 | " inputs = keras.Input(shape=_INPUT_SHAPE)\n",
538 | " base_model = keras.applications.MobileNetV3Small(\n",
539 | " include_top=False, input_shape=_INPUT_SHAPE, pooling=\"avg\"\n",
540 | " )\n",
541 | " base_model.trainable = False\n",
542 | " x = keras.applications.mobilenet_v3.preprocess_input(inputs)\n",
543 | " x = base_model(\n",
544 | " x, training=False\n",
545 | " ) # Ensures BatchNorm runs in inference model in this model\n",
546 | " outputs = keras.layers.Dense(5, activation=\"softmax\")(x)\n",
547 | " model = keras.Model(inputs, outputs)\n",
548 | "\n",
549 | " model.compile(\n",
550 | " optimizer=keras.optimizers.Adam(),\n",
551 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
552 | " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
553 | " )\n",
554 | "\n",
555 | " model.summary(print_fn=logging.info)\n",
556 | " return model\n",
557 | "\n",
558 | "\n",
559 | "# TFX Trainer will call this function.\n",
560 | "def run_fn(fn_args: tfx.components.FnArgs):\n",
561 | " \"\"\"Train the model based on given args.\n",
562 | "\n",
563 | " Args:\n",
564 | " fn_args: Holds args used to train the model as name/value pairs.\n",
565 | " \"\"\"\n",
566 | " train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)\n",
567 | " eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)\n",
568 | "\n",
569 | " model = _make_keras_model()\n",
570 | " model.fit(\n",
571 | " train_dataset,\n",
572 | " steps_per_epoch=fn_args.train_steps,\n",
573 | " validation_data=eval_dataset,\n",
574 | " validation_steps=fn_args.eval_steps,\n",
575 | " epochs=_EPOCHS,\n",
576 | " )\n",
577 | " _, acc = model.evaluate(eval_dataset, steps=fn_args.eval_steps)\n",
578 | " logging.info(f\"Validation accuracy: {round(acc * 100, 2)}%\")\n",
579 | "\n",
580 | " # Convert the model.\n",
581 | " converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
582 | " converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
583 | " tflite_model = converter.convert()\n",
584 | "\n",
585 | " # Save the model.\n",
586 | " # The result of the training should be saved in `fn_args.serving_model_dir` directory.\n",
587 | " with tf.io.gfile.GFile(fn_args.serving_model_dir + \"/model.tflite\", \"wb\") as f:\n",
588 | " f.write(tflite_model)\n"
589 | ],
590 | "execution_count": null,
591 | "outputs": [
592 | {
593 | "output_type": "stream",
594 | "text": [
595 | "Overwriting flower_mobilenet_trainer.py\n"
596 | ],
597 | "name": "stdout"
598 | }
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "metadata": {
604 | "colab": {
605 | "base_uri": "https://localhost:8080/"
606 | },
607 | "id": "Gp7RTxpMoLe9",
608 | "outputId": "1233ba34-5b4a-4a85-8386-122c6e9b84f9"
609 | },
610 | "source": [
611 | "!gsutil cp -r *.py {MODULE_ROOT}/\n",
612 | "!gsutil ls -lh {MODULE_ROOT}/"
613 | ],
614 | "execution_count": null,
615 | "outputs": [
616 | {
617 | "output_type": "stream",
618 | "text": [
619 | "Copying file://firebase_publisher.py [Content-Type=text/x-python]...\n",
620 | "Copying file://flower_densenet_trainer.py [Content-Type=text/x-python]... \n",
621 | "Copying file://flower_mobilenet_trainer.py [Content-Type=text/x-python]... \n",
622 | "Copying file://vertex_deployer.py [Content-Type=text/x-python]... \n",
623 | "/ [4 files][ 12.5 KiB/ 12.5 KiB] \n",
624 | "==> NOTE: You are performing a sequence of gsutil operations that may\n",
625 | "run significantly faster if you instead use gsutil -m cp ... Please\n",
626 | "see the -m section under \"gsutil help options\" for further information\n",
627 | "about when gsutil -m can be advantageous.\n",
628 | "\n",
629 | "Copying file://vertex_uploader.py [Content-Type=text/x-python]...\n",
630 | "/ [5 files][ 13.8 KiB/ 13.8 KiB] \n",
631 | "Operation completed over 5 objects/13.8 KiB. \n",
632 | " 3.16 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/firebase_publisher.py\n",
633 | " 4.53 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/flower_densenet_trainer.py\n",
634 | " 3.37 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/flower_mobilenet_trainer.py\n",
635 | " 1.49 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/vertex_deployer.py\n",
636 | " 1.31 KiB 2021-08-05T16:02:35Z gs://demo-experiments-gde-csp/pipeline_module/two-way-vertex-pipelines5/vertex_uploader.py\n",
637 | "TOTAL: 5 objects, 14181 bytes (13.85 KiB)\n"
638 | ],
639 | "name": "stdout"
640 | }
641 | ]
642 | },
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {
646 | "id": "-qX4M-C6X4LF"
647 | },
648 | "source": [
649 | "## Create the pipeline\n",
650 | "\n",
651 | "To create the end-to-end pipeline, we will need to write two custom TFX components:\n",
652 | "\n",
653 | "* One will take the pushed model from `Pusher` and will upload it to Vertex AI. \n",
654 | "* One will deploy the uploaded model to an Endpoint.\n",
655 | "\n",
656 | "We will then need to build a Docker image using these custom components and serve the pipeline using this image. We will use [Cloud Build](https://cloud.google.com/build) in order to build the Docker image. "
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "metadata": {
662 | "id": "SYaypU11e-cG"
663 | },
664 | "source": [
665 | "_vertex_uploader_module_file = 'vertex_uploader.py'\n",
666 | "_vertex_deployer_module_file = 'vertex_deployer.py'\n",
667 | "_firebase_publisher_module_file = 'firebase_publisher.py'"
668 | ],
669 | "execution_count": null,
670 | "outputs": []
671 | },
672 | {
673 | "cell_type": "code",
674 | "metadata": {
675 | "colab": {
676 | "base_uri": "https://localhost:8080/"
677 | },
678 | "id": "pBj_WnLw6oRR",
679 | "outputId": "83adae19-5209-4032-bf94-42f1d668e947"
680 | },
681 | "source": [
682 | "%%writefile {_vertex_uploader_module_file}\n",
683 | "\n",
684 | "import os\n",
685 | "import tensorflow as tf\n",
686 | "\n",
687 | "from tfx.dsl.component.experimental.decorators import component\n",
688 | "from tfx.dsl.component.experimental.annotations import Parameter\n",
689 | "from tfx.types.standard_artifacts import String\n",
690 | "from google.cloud import aiplatform as vertex_ai\n",
691 | "from tfx import v1 as tfx\n",
692 | "from absl import logging\n",
693 | "\n",
694 | "\n",
695 | "@component\n",
696 | "def VertexUploader(\n",
697 | " project: Parameter[str],\n",
698 | " region: Parameter[str],\n",
699 | " model_display_name: Parameter[str],\n",
700 | " pushed_model_location: Parameter[str],\n",
701 | " serving_image_uri: Parameter[str],\n",
702 | " uploaded_model: tfx.dsl.components.OutputArtifact[String],\n",
703 | "):\n",
704 | "\n",
705 | " vertex_ai.init(project=project, location=region)\n",
706 | "\n",
707 | " pushed_model_dir = os.path.join(\n",
708 | " pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1]\n",
709 | " )\n",
710 | "\n",
711 | " logging.info(f\"Model registry location: {pushed_model_dir}\")\n",
712 | "\n",
713 | " vertex_model = vertex_ai.Model.upload(\n",
714 | " display_name=model_display_name,\n",
715 | " artifact_uri=pushed_model_dir,\n",
716 | " serving_container_image_uri=serving_image_uri,\n",
717 | " parameters_schema_uri=None,\n",
718 | " instance_schema_uri=None,\n",
719 | " explanation_metadata=None,\n",
720 | " explanation_parameters=None,\n",
721 | " )\n",
722 | "\n",
723 | " uploaded_model.set_string_custom_property(\n",
724 | " \"model_resource_name\", str(vertex_model.resource_name)\n",
725 | " )\n",
726 | " logging.info(f\"Model resource: {str(vertex_model.resource_name)}\")\n"
727 | ],
728 | "execution_count": null,
729 | "outputs": [
730 | {
731 | "output_type": "stream",
732 | "text": [
733 | "Overwriting vertex_uploader.py\n"
734 | ],
735 | "name": "stdout"
736 | }
737 | ]
738 | },
739 | {
740 | "cell_type": "code",
741 | "metadata": {
742 | "colab": {
743 | "base_uri": "https://localhost:8080/"
744 | },
745 | "id": "awNiJiZ1fZsp",
746 | "outputId": "dd732111-6210-47ef-ef63-32bd907749dd"
747 | },
748 | "source": [
749 | "%%writefile {_vertex_deployer_module_file}\n",
750 | "\n",
751 | "from tfx.dsl.component.experimental.decorators import component\n",
752 | "from tfx.dsl.component.experimental.annotations import Parameter\n",
753 | "from tfx.types.standard_artifacts import String\n",
754 | "from google.cloud import aiplatform as vertex_ai\n",
755 | "from tfx import v1 as tfx\n",
756 | "from absl import logging\n",
757 | "\n",
758 | "\n",
759 | "@component\n",
760 | "def VertexDeployer(\n",
761 | " project: Parameter[str],\n",
762 | " region: Parameter[str],\n",
763 | " model_display_name: Parameter[str],\n",
764 | " deployed_model_display_name: Parameter[str],\n",
765 | "):\n",
766 | "\n",
767 | " logging.info(f\"Endpoint display: {deployed_model_display_name}\")\n",
768 | " vertex_ai.init(project=project, location=region)\n",
769 | "\n",
770 | " endpoints = vertex_ai.Endpoint.list(\n",
771 | " filter=f\"display_name={deployed_model_display_name}\", order_by=\"update_time\"\n",
772 | " )\n",
773 | "\n",
774 | " if len(endpoints) > 0:\n",
775 | " logging.info(f\"Endpoint {deployed_model_display_name} already exists.\")\n",
776 | " endpoint = endpoints[-1]\n",
777 | " else:\n",
778 | " endpoint = vertex_ai.Endpoint.create(deployed_model_display_name)\n",
779 | "\n",
780 | " model = vertex_ai.Model.list(\n",
781 | " filter=f\"display_name={model_display_name}\", order_by=\"update_time\"\n",
782 | " )[-1]\n",
783 | "\n",
784 | " endpoint = vertex_ai.Endpoint.list(\n",
785 | " filter=f\"display_name={deployed_model_display_name}\", order_by=\"update_time\"\n",
786 | " )[-1]\n",
787 | "\n",
788 | " deployed_model = endpoint.deploy(\n",
789 | " model=model,\n",
790 | " # Syntax from here: https://git.io/JBQDP\n",
791 | " traffic_split={\"0\": 100},\n",
792 | " machine_type=\"n1-standard-4\",\n",
793 | " min_replica_count=1,\n",
794 | " max_replica_count=1,\n",
795 | " )\n",
796 | "\n",
797 | " logging.info(f\"Model deployed to: {deployed_model}\")\n"
798 | ],
799 | "execution_count": null,
800 | "outputs": [
801 | {
802 | "output_type": "stream",
803 | "text": [
804 | "Overwriting vertex_deployer.py\n"
805 | ],
806 | "name": "stdout"
807 | }
808 | ]
809 | },
810 | {
811 | "cell_type": "code",
812 | "metadata": {
813 | "id": "Bp7GMaUlbfTz",
814 | "outputId": "4f1458c5-04e6-4681-86ab-5c1d7a8047f3"
815 | },
816 | "source": [
817 | "%%writefile {_firebase_publisher_module_file}\n",
818 | "\n",
819 | "from tfx import types\n",
820 | "from tfx.dsl.component.experimental.decorators import component\n",
821 | "from tfx.dsl.component.experimental.annotations import Parameter\n",
822 | "from tfx import v1 as tfx\n",
823 | "from absl import logging\n",
824 | "\n",
825 | "import firebase_admin\n",
826 | "from firebase_admin import ml\n",
827 | "from firebase_admin import storage\n",
828 | "from firebase_admin import credentials\n",
829 | "from google.cloud import storage as gcs_storage\n",
830 | "\n",
831 | "\n",
832 | "@component\n",
833 | "def FirebasePublisher(\n",
834 | " pushed_model: tfx.dsl.components.InputArtifact[\n",
835 | " tfx.types.standard_artifacts.PushedModel\n",
836 | " ],\n",
837 | " credential_uri: Parameter[str],\n",
838 | " firebase_dest_gcs_bucket: Parameter[str],\n",
839 | " model_display_name: Parameter[str],\n",
840 | " model_tag: Parameter[str],\n",
841 | ") -> tfx.dsl.components.OutputDict(result=str):\n",
842 | " model_uri = f\"{pushed_model.uri}/model.tflite\"\n",
843 | "\n",
844 | " assert model_uri.split(\"://\")[0] == \"gs\"\n",
845 | " assert credential_uri.split(\"://\")[0] == \"gs\"\n",
846 | "\n",
847 | " # create gcs client instance\n",
848 | " gcs_client = gcs_storage.Client()\n",
849 | "\n",
850 | " # get credential for firebase\n",
851 | " credential_gcs_bucket = credential_uri.split(\"//\")[1].split(\"/\")[0]\n",
852 | " credential_blob_path = \"/\".join(credential_uri.split(\"//\")[1].split(\"/\")[1:])\n",
853 | "\n",
854 | " bucket = gcs_client.bucket(credential_gcs_bucket)\n",
855 | " blob = bucket.blob(credential_blob_path)\n",
856 | " blob.download_to_filename(\"credential.json\")\n",
857 | " logging.info(f\"download credential.json from {credential_uri} is completed\")\n",
858 | "\n",
859 | " # get tflite model file\n",
860 | " tflite_gcs_bucket = model_uri.split(\"//\")[1].split(\"/\")[0]\n",
861 | " tflite_blob_path = \"/\".join(model_uri.split(\"//\")[1].split(\"/\")[1:])\n",
862 | "\n",
863 | " bucket = gcs_client.bucket(tflite_gcs_bucket)\n",
864 | " blob = bucket.blob(tflite_blob_path)\n",
865 | " blob.download_to_filename(\"model.tflite\")\n",
866 | " logging.info(f\"download model.tflite from {model_uri} is completed\")\n",
867 | "\n",
868 | " firebase_admin.initialize_app(\n",
869 | " credentials.Certificate(\"credential.json\"),\n",
870 | " options={\"storageBucket\": firebase_dest_gcs_bucket},\n",
871 | " )\n",
872 | " logging.info(\"firebase_admin initialize app is completed\")\n",
873 | "\n",
874 | " model_list = ml.list_models(list_filter=f\"display_name={model_display_name}\")\n",
875 | " # update\n",
876 | " if len(model_list.models) > 0:\n",
877 | " # get the first match model\n",
878 | " model = model_list.models[0]\n",
879 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file(\"model.tflite\")\n",
880 | " model.model_format = ml.TFLiteFormat(model_source=source)\n",
881 | "\n",
882 | " updated_model = ml.update_model(model)\n",
883 | " ml.publish_model(updated_model.model_id)\n",
884 | "\n",
885 | " logging.info(\"model exists, so update it in FireBase ML\")\n",
886 | " return {\"result\": \"model updated\"}\n",
887 | " # create\n",
888 | " else:\n",
889 | " # load a tflite file and upload it to Cloud Storage\n",
890 | " source = ml.TFLiteGCSModelSource.from_tflite_model_file(\"model.tflite\")\n",
891 | "\n",
892 | " # create the model object\n",
893 | " tflite_format = ml.TFLiteFormat(model_source=source)\n",
894 | " model = ml.Model(\n",
895 | " display_name=model_display_name,\n",
896 | " tags=[model_tag],\n",
897 | " model_format=tflite_format,\n",
898 | " )\n",
899 | "\n",
900 | " # Add the model to your Firebase project and publish it\n",
901 | " new_model = ml.create_model(model)\n",
902 | " ml.publish_model(new_model.model_id)\n",
903 | "\n",
904 | " logging.info(\"model doesn exists, so create one in FireBase ML\")\n",
905 | " return {\"result\": \"model created\"}\n"
906 | ],
907 | "execution_count": null,
908 | "outputs": [
909 | {
910 | "output_type": "stream",
911 | "text": [
912 | "Overwriting firebase_publisher.py\n"
913 | ],
914 | "name": "stdout"
915 | }
916 | ]
917 | },
918 | {
919 | "cell_type": "markdown",
920 | "metadata": {
921 | "id": "joRJYojpP8-H"
922 | },
923 | "source": [
924 | "Create a package called `custom_components` and copy the modules we just wrote. "
925 | ]
926 | },
927 | {
928 | "cell_type": "code",
929 | "metadata": {
930 | "id": "GYiXX9pOOOAX"
931 | },
932 | "source": [
933 | "!mkdir -p ./custom_components\n",
934 | "!touch ./custom_components/__init__.py\n",
935 | "!cp -r {_vertex_uploader_module_file} {_vertex_deployer_module_file} {_firebase_publisher_module_file} custom_components"
936 | ],
937 | "execution_count": null,
938 | "outputs": []
939 | },
940 | {
941 | "cell_type": "code",
942 | "metadata": {
943 | "colab": {
944 | "base_uri": "https://localhost:8080/"
945 | },
946 | "id": "UiLiGjR3OOAX",
947 | "outputId": "ec5a7f09-df8c-4151-d3c6-51c20528d0d6"
948 | },
949 | "source": [
950 | "!ls -lh custom_components"
951 | ],
952 | "execution_count": null,
953 | "outputs": [
954 | {
955 | "output_type": "stream",
956 | "text": [
957 | "total 16K\n",
958 | "-rw-r--r-- 1 jupyter jupyter 0 Aug 5 14:20 __init__.py\n",
959 | "drwxr-xr-x 2 jupyter jupyter 4.0K Aug 5 06:50 __pycache__\n",
960 | "-rw-r--r-- 1 jupyter jupyter 3.2K Aug 5 14:20 firebase_publisher.py\n",
961 | "-rw-r--r-- 1 jupyter jupyter 1.5K Aug 5 14:20 vertex_deployer.py\n",
962 | "-rw-r--r-- 1 jupyter jupyter 1.4K Aug 5 14:20 vertex_uploader.py\n"
963 | ],
964 | "name": "stdout"
965 | }
966 | ]
967 | },
968 | {
969 | "cell_type": "markdown",
970 | "metadata": {
971 | "id": "DcFpCCSbQDol"
972 | },
973 | "source": [
974 | "### `Dockerfile` configuration "
975 | ]
976 | },
977 | {
978 | "cell_type": "code",
979 | "metadata": {
980 | "colab": {
981 | "base_uri": "https://localhost:8080/"
982 | },
983 | "id": "h9j867lPfyM8",
984 | "outputId": "c4858955-56ef-444f-d683-e03630c288f7"
985 | },
986 | "source": [
987 | "DATASET_DISPLAY_NAME = \"flowers\"\n",
988 | "VERSION = \"tfx-1-0-0\"\n",
989 | "TFX_IMAGE_URI = f\"gcr.io/{GOOGLE_CLOUD_PROJECT}/{DATASET_DISPLAY_NAME}:{VERSION}\"\n",
990 | "print(f\"URI of the custom image: {TFX_IMAGE_URI}\")"
991 | ],
992 | "execution_count": null,
993 | "outputs": [
994 | {
995 | "output_type": "stream",
996 | "text": [
997 | "URI of the custom image: gcr.io/gcp-ml-172005/flowers:tfx-1-0-0\n"
998 | ],
999 | "name": "stdout"
1000 | }
1001 | ]
1002 | },
1003 | {
1004 | "cell_type": "code",
1005 | "metadata": {
1006 | "colab": {
1007 | "base_uri": "https://localhost:8080/"
1008 | },
1009 | "id": "F3YcE0xUgOh8",
1010 | "outputId": "abfa72c3-dfd9-4a8f-a1b6-cf8e94b36f85"
1011 | },
1012 | "source": [
1013 | "%%writefile Dockerfile\n",
1014 | "\n",
1015 | "FROM gcr.io/tfx-oss-public/tfx:1.0.0\n",
1016 | "RUN mkdir -p custom_components\n",
1017 | "COPY custom_components/* ./custom_components/\n",
1018 | "RUN pip install --upgrade google-cloud-aiplatform google-cloud-storage firebase-admin"
1019 | ],
1020 | "execution_count": null,
1021 | "outputs": [
1022 | {
1023 | "output_type": "stream",
1024 | "text": [
1025 | "Overwriting Dockerfile\n"
1026 | ],
1027 | "name": "stdout"
1028 | }
1029 | ]
1030 | },
1031 | {
1032 | "cell_type": "code",
1033 | "metadata": {
1034 | "id": "hEydNZrHg6_J"
1035 | },
1036 | "source": [
1037 | "!gcloud builds submit --tag $TFX_IMAGE_URI . --timeout=15m --machine-type=e2-highcpu-8"
1038 | ],
1039 | "execution_count": null,
1040 | "outputs": []
1041 | },
1042 | {
1043 | "cell_type": "code",
1044 | "metadata": {
1045 | "id": "sEbNM9CeERX2"
1046 | },
1047 | "source": [
1048 | "# Specify training worker configurations. To minimize costs we can even specify two\n",
1049 | "# different configurations: a beefier machine for the Endpoint model and slightly less\n",
1050 | "# powerful machine for the mobile model.\n",
1051 | "TRAINING_JOB_SPEC = {\n",
1052 | " 'project': GOOGLE_CLOUD_PROJECT,\n",
1053 | " 'worker_pool_specs': [{\n",
1054 | " 'machine_spec': {\n",
1055 | " 'machine_type': 'n1-standard-4',\n",
1056 | " 'accelerator_type': 'NVIDIA_TESLA_K80',\n",
1057 | " 'accelerator_count': 1\n",
1058 | " },\n",
1059 | " 'replica_count': 1,\n",
1060 | " 'container_spec': {\n",
1061 | " 'image_uri': 'gcr.io/tfx-oss-public/tfx:{}'.format(tfx.__version__),\n",
1062 | " },\n",
1063 | " }],\n",
1064 | "}"
1065 | ],
1066 | "execution_count": null,
1067 | "outputs": []
1068 | },
1069 | {
1070 | "cell_type": "code",
1071 | "metadata": {
1072 | "id": "MzZdpUjypcsT"
1073 | },
1074 | "source": [
1075 | "from datetime import datetime\n",
1076 | "\n",
1077 | "TIMESTAMP = datetime.now().strftime(\"%Y%m%d%H%M%S\")"
1078 | ],
1079 | "execution_count": null,
1080 | "outputs": []
1081 | },
1082 | {
1083 | "cell_type": "code",
1084 | "metadata": {
1085 | "id": "ln1cvbcfphA9"
1086 | },
1087 | "source": [
1088 | "from custom_components.vertex_uploader import VertexUploader\n",
1089 | "from custom_components.vertex_deployer import VertexDeployer\n",
1090 | "from custom_components.firebase_publisher import FirebasePublisher\n",
1091 | "\n",
1092 | "\n",
1093 | "def _create_pipeline(\n",
1094 | " pipeline_name: str,\n",
1095 | " pipeline_root: str,\n",
1096 | " data_root: str,\n",
1097 | " densenet_module_file: str,\n",
1098 | " mobilenet_module_file: str,\n",
1099 | " serving_model_dir: str,\n",
1100 | " firebase_crediential_path: str,\n",
1101 | " firebase_gcs_bucket: str,\n",
1102 | " project_id: str,\n",
1103 | " region: str,\n",
1104 | ") -> tfx.dsl.Pipeline:\n",
1105 | " \"\"\"Creates a three component flowers pipeline with TFX.\"\"\"\n",
1106 | " # Brings data into the pipeline.\n",
1107 | " # input_base: gs://flowers-public/tfrecords-jpeg-224x224\n",
1108 | " example_gen = tfx.components.ImportExampleGen(input_base=data_root)\n",
1109 | "\n",
1110 | " # Uses user-provided Python function that trains a model.\n",
1111 | " densenet_trainer = tfx.extensions.google_cloud_ai_platform.Trainer(\n",
1112 | " module_file=densenet_module_file,\n",
1113 | " examples=example_gen.outputs[\"examples\"],\n",
1114 | " train_args=tfx.proto.TrainArgs(num_steps=52),\n",
1115 | " eval_args=tfx.proto.EvalArgs(num_steps=5),\n",
1116 | " custom_config={\n",
1117 | " tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY: True,\n",
1118 | " tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY: region,\n",
1119 | " tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: TRAINING_JOB_SPEC,\n",
1120 | " \"use_gpu\": True,\n",
1121 | " },\n",
1122 | " ).with_id(\"densenet_trainer\")\n",
1123 | "\n",
1124 | " # Pushes the model to a filesystem destination.\n",
1125 | " pushed_model_location = os.path.join(serving_model_dir, \"densenet\")\n",
1126 | " densnet_pusher = tfx.components.Pusher(\n",
1127 | " model=densenet_trainer.outputs[\"model\"],\n",
1128 | " push_destination=tfx.proto.PushDestination(\n",
1129 | " filesystem=tfx.proto.PushDestination.Filesystem(\n",
1130 | " base_directory=pushed_model_location\n",
1131 | " )\n",
1132 | " ),\n",
1133 | " ).with_id(\"densnet_pusher\")\n",
1134 | "\n",
1135 | " # Vertex AI upload.\n",
1136 | " model_display_name = \"densenet_flowers_latest\"\n",
1137 | " uploader = VertexUploader(\n",
1138 | " project=project_id,\n",
1139 | " region=region,\n",
1140 | " model_display_name=model_display_name,\n",
1141 | " pushed_model_location=pushed_model_location,\n",
1142 | " serving_image_uri=\"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest\",\n",
1143 | " ).with_id(\"vertex_uploader\")\n",
1144 | " uploader.add_upstream_node(densnet_pusher)\n",
1145 | "\n",
1146 | " # Create an endpoint.\n",
1147 | " deployer = VertexDeployer(\n",
1148 | " project=project_id,\n",
1149 | " region=region,\n",
1150 | " model_display_name=model_display_name,\n",
1151 | " deployed_model_display_name=model_display_name + \"_\" + TIMESTAMP,\n",
1152 | " ).with_id(\"vertex_deployer\")\n",
1153 | " deployer.add_upstream_node(uploader)\n",
1154 | "\n",
1155 | "    # We repeat the steps for the MobileNet model, but this time we don't\n",
1156 | "    # create an Endpoint. Instead, we convert the Keras model to TFLite and\n",
1157 | "    # publish it to Firebase so that mobile apps can download and serve it.\n",
1158 | " mobilenet_trainer = tfx.extensions.google_cloud_ai_platform.Trainer(\n",
1159 | " module_file=mobilenet_module_file,\n",
1160 | " examples=example_gen.outputs[\"examples\"],\n",
1161 | " train_args=tfx.proto.TrainArgs(num_steps=52),\n",
1162 | " eval_args=tfx.proto.EvalArgs(num_steps=5),\n",
1163 | " custom_config={\n",
1164 | " tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY: True,\n",
1165 | " tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY: region,\n",
1166 | " tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: TRAINING_JOB_SPEC,\n",
1167 | " \"use_gpu\": True,\n",
1168 | " },\n",
1169 | " ).with_id(\"mobilenet_trainer\")\n",
1170 | "\n",
1171 | " pushed_location_mobilenet = os.path.join(serving_model_dir, \"mobilenet\")\n",
1172 | " mobilenet_pusher = tfx.components.Pusher(\n",
1173 | " model=mobilenet_trainer.outputs[\"model\"],\n",
1174 | " push_destination=tfx.proto.PushDestination(\n",
1175 | " filesystem=tfx.proto.PushDestination.Filesystem(\n",
1176 | " base_directory=pushed_location_mobilenet\n",
1177 | " )\n",
1178 | " ),\n",
1179 | " ).with_id(\"mobilenet_pusher\")\n",
1180 | "\n",
1181 | " firebase_publisher = FirebasePublisher(\n",
1182 | " pushed_model=mobilenet_pusher.outputs[\"pushed_model\"],\n",
1183 | "        credential_uri=firebase_credential_path,\n",
1184 | " firebase_dest_gcs_bucket=firebase_gcs_bucket,\n",
1185 | " model_display_name=model_display_name,\n",
1186 | " model_tag=\"mobilenet\",\n",
1187 | " ).with_id(\"firebase_publisher\")\n",
1188 | "\n",
1189 | "    # The following components will be included in the pipeline.\n",
1190 | " components = [\n",
1191 | " example_gen,\n",
1192 | " densenet_trainer,\n",
1193 | "        densenet_pusher,\n",
1194 | " uploader,\n",
1195 | " deployer,\n",
1196 | " mobilenet_trainer,\n",
1197 | " mobilenet_pusher,\n",
1198 | " firebase_publisher,\n",
1199 | " ]\n",
1200 | "\n",
1201 | " return tfx.dsl.Pipeline(\n",
1202 | " pipeline_name=pipeline_name, pipeline_root=pipeline_root, components=components\n",
1203 | " )\n"
1204 | ],
1205 | "execution_count": null,
1206 | "outputs": []
1207 | },
1208 | {
1209 | "cell_type": "markdown",
1210 | "metadata": {
1211 | "id": "IFdlslfOX54z"
1212 | },
1213 | "source": [
1214 | "## Compile the pipeline"
1215 | ]
1216 | },
1217 | {
1218 | "cell_type": "code",
1219 | "metadata": {
1220 | "id": "-AY5Z2tbsbwE"
1221 | },
1222 | "source": [
1223 | "PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'\n",
1224 | "\n",
1225 | "# Important: we need to pass the custom Docker image URI to\n",
1226 | "# `KubeflowV2DagRunnerConfig` for it to take effect.\n",
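1227 | "# `KubeflowV2DagRunner` compiles the TFX pipeline into a Vertex AI Pipelines\n",
1228 | "# (PipelineSpec) JSON file instead of running it directly.\n",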
1227 | "runner = tfx.orchestration.experimental.KubeflowV2DagRunner(\n",
1228 | " config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(default_image=TFX_IMAGE_URI),\n",
1229 | " output_filename=PIPELINE_DEFINITION_FILE)\n",
1230 | "\n",
1231 | "_ = runner.run(\n",
1232 | " _create_pipeline(\n",
1233 | " pipeline_name=PIPELINE_NAME,\n",
1234 | " pipeline_root=PIPELINE_ROOT,\n",
1235 | " data_root=DATA_ROOT,\n",
1236 | " densenet_module_file=os.path.join(MODULE_ROOT, _trainer_densenet_module_file),\n",
1237 | " mobilenet_module_file=os.path.join(MODULE_ROOT, _trainer_mobilenet_module_file),\n",
1238 | " serving_model_dir=SERVING_MODEL_DIR,\n",
1239 | "        firebase_credential_path=FIREBASE_CREDENTIAL_PATH,\n",
1240 | " firebase_gcs_bucket=FIREBASE_GCS_BUCKET, \n",
1241 | " project_id=GOOGLE_CLOUD_PROJECT,\n",
1242 | " region=GOOGLE_CLOUD_REGION\n",
1243 | " )\n",
1244 | ")"
1245 | ],
1246 | "execution_count": null,
1247 | "outputs": []
1248 | },
1249 | {
1250 | "cell_type": "markdown",
1251 | "metadata": {
1252 | "id": "ocHBJaR_X7x2"
1253 | },
1254 | "source": [
1255 | "## Submit the pipeline for execution to Vertex AI\n",
1256 | "\n",
1257 | "Generally, it's a good idea to first do a local run of the end-to-end pipeline before submitting it to an online orchestrator. We could use `tfx.orchestration.LocalDagRunner()` for that, but we won't execute a local run in this notebook; a commented-out sketch of what it could look like follows below."
1258 | ]
1259 | },
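1260 | {
1261 | "cell_type": "markdown",
1262 | "metadata": {},
1263 | "source": [
1264 | "The next cell is a minimal, commented-out sketch of what such a local run could look like; the local pipeline root in it is only an illustrative placeholder. Note that the cloud-specific components (the Vertex AI `Trainer`, `VertexUploader`, `VertexDeployer`, and `FirebasePublisher`) would still call out to GCP and Firebase, so this mainly serves as a structural smoke test."
1265 | ]
1266 | },
1267 | {
1268 | "cell_type": "code",
1269 | "metadata": {},
1270 | "source": [
1271 | "# A minimal sketch of a local smoke test, kept commented out in this notebook.\n",
1272 | "# local_runner = tfx.orchestration.LocalDagRunner()\n",
1273 | "# local_runner.run(\n",
1274 | "#     _create_pipeline(\n",
1275 | "#         pipeline_name=PIPELINE_NAME + \"-local\",\n",
1276 | "#         pipeline_root=\"/tmp/flowers-pipeline-root\",  # illustrative local scratch path\n",
1277 | "#         data_root=DATA_ROOT,\n",
1278 | "#         densenet_module_file=os.path.join(MODULE_ROOT, _trainer_densenet_module_file),\n",
1279 | "#         mobilenet_module_file=os.path.join(MODULE_ROOT, _trainer_mobilenet_module_file),\n",
1280 | "#         serving_model_dir=SERVING_MODEL_DIR,\n",
1281 | "#         firebase_credential_path=FIREBASE_CREDENTIAL_PATH,\n",
1282 | "#         firebase_gcs_bucket=FIREBASE_GCS_BUCKET,\n",
1283 | "#         project_id=GOOGLE_CLOUD_PROJECT,\n",
1284 | "#         region=GOOGLE_CLOUD_REGION,\n",
1285 | "#     )\n",
1286 | "# )"
1287 | ],
1288 | "execution_count": null,
1289 | "outputs": []
1290 | },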
1260 | {
1261 | "cell_type": "code",
1262 | "metadata": {
1263 | "colab": {
1264 | "base_uri": "https://localhost:8080/",
1265 | "height": 51
1266 | },
1267 | "id": "3elrtDOus83z",
1268 | "outputId": "7eb13f73-602c-44f0-8385-ac8c097905ea"
1269 | },
1270 | "source": [
1271 | "from kfp.v2.google import client\n",
1272 | "\n",
1273 | "pipelines_client = client.AIPlatformClient(\n",
1274 | " project_id=GOOGLE_CLOUD_PROJECT,\n",
1275 | " region=GOOGLE_CLOUD_REGION,\n",
1276 | ")\n",
1277 | "\n",
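1278 | "# Submit the compiled spec; with enable_caching=True, results of unchanged\n",
1279 | "# steps can be reused across pipeline runs.\n",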
1278 | "_ = pipelines_client.create_run_from_job_spec(PIPELINE_DEFINITION_FILE, enable_caching=True)"
1279 | ],
1280 | "execution_count": null,
1281 | "outputs": [
1282 | {
1283 | "output_type": "display_data",
1284 | "data": {
1285 | "text/html": [
1286 | "See the Pipeline job here."
1287 | ],
1288 | "text/plain": [
1289 | ""
1290 | ]
1291 | },
1292 | "metadata": {
1293 | "tags": []
1294 | }
1295 | }
1296 | ]
1297 | },
1298 | {
1299 | "cell_type": "markdown",
1300 | "metadata": {
1301 | "id": "YX3pFWlhN_jR"
1302 | },
1303 | "source": [
1304 | "Once executed, the pipeline graph should look like the following:\n",
1305 | "\n",
1306 | ""
1307 | ]
1308 | },
1309 | {
1310 | "cell_type": "markdown",
1311 | "metadata": {
1312 | "id": "e-757SqrYhf3"
1313 | },
1314 | "source": [
1315 | "## Making predictions with the Endpoint\n",
1316 | "\n",
1317 | "Some of the code in this section is adapted from [here](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/unofficial/gapic/custom/showcase_custom_image_classification_online.ipynb)."
1318 | ]
1319 | },
1320 | {
1321 | "cell_type": "markdown",
1322 | "metadata": {
1323 | "id": "bclxNBlJ-LNu"
1324 | },
1325 | "source": [
1326 | "### Imports and initialization"
1327 | ]
1328 | },
1329 | {
1330 | "cell_type": "code",
1331 | "metadata": {
1332 | "id": "6XWcRgv98gGr"
1333 | },
1334 | "source": [
1335 | "from google.protobuf import json_format\n",
1336 | "from google.protobuf.struct_pb2 import Value\n",
1339 | "\n",
1340 | "import base64"
1341 | ],
1342 | "execution_count": null,
1343 | "outputs": []
1344 | },
1345 | {
1346 | "cell_type": "code",
1347 | "metadata": {
1348 | "id": "zSAVSxSC9EOC"
1349 | },
1350 | "source": [
1351 | "vertex_ai.init(project=GOOGLE_CLOUD_PROJECT, location=GOOGLE_CLOUD_REGION)"
1352 | ],
1353 | "execution_count": null,
1354 | "outputs": []
1355 | },
1356 | {
1357 | "cell_type": "markdown",
1358 | "metadata": {
1359 | "id": "zpZWBSs6-Nfz"
1360 | },
1361 | "source": [
1362 | "### Programmatically retrieve the latest Endpoint matching a name"
1363 | ]
1364 | },
1365 | {
1366 | "cell_type": "code",
1367 | "metadata": {
1368 | "colab": {
1369 | "base_uri": "https://localhost:8080/",
1370 | "height": 35
1371 | },
1372 | "id": "ZXKmceKG8qYW",
1373 | "outputId": "2c65421c-0211-43a2-fa12-cd69dc31cea4"
1374 | },
1375 | "source": [
1376 | "model_display_name = \"densenet_flowers_latest\"\n",
1377 | "deployed_model_display_name = model_display_name + \"_\" + TIMESTAMP\n",
1378 | "\n",
1379 | "endpoint = vertex_ai.Endpoint.list(\n",
1380 | " filter=f'display_name={deployed_model_display_name}',\n",
1381 | " order_by=\"update_time\"\n",
1382 | ")[-1]\n",
1383 | "\n",
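1384 | "# `Endpoint.name` holds just the numeric endpoint ID (the last segment of the resource name).\n",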
1384 | "endpoint_id = endpoint.name\n",
1385 | "endpoint_id"
1386 | ],
1387 | "execution_count": null,
1388 | "outputs": [
1389 | {
1390 | "output_type": "execute_result",
1391 | "data": {
1392 | "text/plain": [
1393 | "'3904532915999997952'"
1394 | ]
1395 | },
1396 | "metadata": {
1397 | "tags": []
1398 | },
1399 | "execution_count": 19
1400 | }
1401 | ]
1402 | },
1403 | {
1404 | "cell_type": "markdown",
1405 | "metadata": {
1406 | "id": "uS9x3SNu-R5l"
1407 | },
1408 | "source": [
1409 | "### Sample data"
1410 | ]
1411 | },
1412 | {
1413 | "cell_type": "code",
1414 | "metadata": {
1415 | "colab": {
1416 | "base_uri": "https://localhost:8080/"
1417 | },
1418 | "id": "wJeyhzeL8-2C",
1419 | "outputId": "ace8e9b8-976b-4de4-e55b-728ff7012730"
1420 | },
1421 | "source": [
1422 | "image_path = tf.keras.utils.get_file(\"image.jpg\", \n",
1423 | " \"https://m.economictimes.com/thumb/msid-71307470,width-1201,height-900,resizemode-4,imgsize-1040796/roses.jpg\")\n",
1424 | "image_bytes = tf.io.read_file(image_path)  # avoid shadowing the built-in `bytes`\n",
1425 | "b64str = base64.b64encode(image_bytes.numpy()).decode(\"utf-8\")"
1426 | ],
1427 | "execution_count": null,
1428 | "outputs": [
1429 | {
1430 | "output_type": "stream",
1431 | "text": [
1432 | "Downloading data from https://m.economictimes.com/thumb/msid-71307470,width-1201,height-900,resizemode-4,imgsize-1040796/roses.jpg\n",
1433 | "229376/221658 [==============================] - 0s 0us/step\n"
1434 | ],
1435 | "name": "stdout"
1436 | }
1437 | ]
1438 | },
1439 | {
1440 | "cell_type": "markdown",
1441 | "metadata": {
1442 | "id": "Jmhj5-FU-TmQ"
1443 | },
1444 | "source": [
1445 | "### Investigating the input key"
1446 | ]
1447 | },
1448 | {
1449 | "cell_type": "code",
1450 | "metadata": {
1451 | "colab": {
1452 | "base_uri": "https://localhost:8080/"
1453 | },
1454 | "id": "2H6M-Usq9cUr",
1455 | "outputId": "e5412f05-51c9-4bb6-c8b7-07b550784a74"
1456 | },
1457 | "source": [
1458 | "pushed_model_location = os.path.join(SERVING_MODEL_DIR, \"densenet\")\n",
1459 | "model_path_to_deploy = os.path.join(\n",
1460 | " pushed_model_location, tf.io.gfile.listdir(pushed_model_location)[-1]\n",
1461 | ")\n",
1462 | "\n",
1463 | "loaded = tf.saved_model.load(model_path_to_deploy)\n",
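1464 | "# Requests must use the serving signature's input key, so we look it up from the SavedModel.\n",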
1464 | "serving_input = list(\n",
1465 | " loaded.signatures[\"serving_default\"].structured_input_signature[1].keys()\n",
1466 | ")[0]\n",
1467 | "print(\"Serving function input:\", serving_input)"
1468 | ],
1469 | "execution_count": null,
1470 | "outputs": [
1471 | {
1472 | "output_type": "stream",
1473 | "text": [
1474 | "Serving function input: bytes_inputs\n"
1475 | ],
1476 | "name": "stdout"
1477 | }
1478 | ]
1479 | },
1480 | {
1481 | "cell_type": "markdown",
1482 | "metadata": {
1483 | "id": "C8gnuJf1-YkC"
1484 | },
1485 | "source": [
1486 | "### Make predictions"
1487 | ]
1488 | },
1489 | {
1490 | "cell_type": "code",
1491 | "metadata": {
1492 | "colab": {
1493 | "base_uri": "https://localhost:8080/"
1494 | },
1495 | "id": "UcB7HZat9QGI",
1496 | "outputId": "ceafc8e4-e93f-4607-c39c-160f8356520d"
1497 | },
1498 | "source": [
1499 | "def predict_image(image, endpoint, parameters_dict):\n",
1500 | " # The format of each instance should conform to the deployed model's prediction input schema.\n",
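1501 | "    # Binary image bytes are sent under the \"b64\" key, per the TF Serving JSON convention.\n",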
1501 | " instances_list = [{serving_input: {\"b64\": image}}]\n",
1502 | " instances = [json_format.ParseDict(s, Value()) for s in instances_list]\n",
1503 | "\n",
1504 | " endpoint = vertex_ai.Endpoint(endpoint)\n",
1505 | " print(endpoint.predict(instances=instances))\n",
1506 | "\n",
1507 | "predict_image(b64str, endpoint_id, None)"
1508 | ],
1509 | "execution_count": null,
1510 | "outputs": [
1511 | {
1512 | "output_type": "stream",
1513 | "text": [
1514 | "Prediction(predictions=[{'confidence': 0.669201732, 'label': 'roses'}], deployed_model_id='7558456345704267776', explanations=None)\n"
1515 | ],
1516 | "name": "stdout"
1517 | }
1518 | ]
1519 | }
1520 | ]
1521 | }
--------------------------------------------------------------------------------