├── python
├── serving
│ ├── logging_lib
│ │ ├── __init__.py
│ │ ├── flags
│ │ │ ├── __init__.py
│ │ │ ├── flag_utils.py
│ │ │ ├── flag_utils_test.py
│ │ │ ├── secret_flag_utils.py
│ │ │ └── secret_flag_utils_test.py
│ │ ├── cloud_logging_client.py
│ │ └── cloud_logging_client_instance.py
│ ├── serving_framework
│ │ ├── triton
│ │ │ ├── requirements.in
│ │ │ ├── __init__.py
│ │ │ ├── server_health_check.py
│ │ │ ├── server_health_check_test.py
│ │ │ ├── triton_server_model_runner.py
│ │ │ ├── triton_streaming_server_model_runner.py
│ │ │ └── triton_server_model_runner_test.py
│ │ ├── pip-install.txt
│ │ ├── requirements.in
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── model_transfer.py
│ │ ├── inline_prediction_executor_test.py
│ │ ├── inline_prediction_executor.py
│ │ ├── model_runner.py
│ │ ├── server_gunicorn_test.py
│ │ └── server_gunicorn.py
│ ├── model_requirements.in
│ ├── requirements.in
│ ├── model_repository
│ │ └── default
│ │ │ ├── 1
│ │ │ └── model.py
│ │ │ └── config.pbtxt
│ ├── __init__.py
│ ├── predictor_test.py
│ ├── entrypoint.sh
│ ├── server_gunicorn.py
│ ├── Dockerfile
│ ├── vertex_schemata
│ │ ├── response.yaml
│ │ └── request.yaml
│ └── predictor.py
└── requirements.txt
├── notebooks
├── README.md
└── quick_start_with_model_garden.ipynb
├── CONTRIBUTING.md
├── README.md
└── LICENSE
/python/serving/logging_lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/flags/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | -r serving/requirements.txt
2 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/requirements.in:
--------------------------------------------------------------------------------
1 | requests~=2.32.3
2 | # TODO: b/375469331 - Enable testing with most current requests-mock release.
3 | requests-mock==1.9.3
4 | tritonclient~=2.56.0
5 | typing-extensions~=4.12.2
--------------------------------------------------------------------------------
/python/serving/model_requirements.in:
--------------------------------------------------------------------------------
1 | torch~=2.9.1
2 | # TODO(b/468062329): Replace with regular pip version when released.
3 | https://github.com/huggingface/transformers/archive/65dc261512cbdb1ee72b88ae5b222f2605aad8e5.tar.gz#egg=transformers
--------------------------------------------------------------------------------
/python/serving/serving_framework/pip-install.txt:
--------------------------------------------------------------------------------
1 | # Requirements file for updating pip.
2 | pip==25.0 \
3 | --hash=sha256:8e0a97f7b4c47ae4a494560da84775e9e2f671d415d8d828e052efefb206b30b \
4 | --hash=sha256:b6eb97a803356a52b2dd4bb73ba9e65b2ba16caa6bcb25a7497350a4e5859b65
5 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/requirements.in:
--------------------------------------------------------------------------------
1 | absl-py>=2.1.0
2 | flask~=3.0.3
3 | grpcio~=1.68.1
4 | grpcio-status~=1.68.1
5 | gunicorn~=23.0.0
6 | jsonschema~=4.23.0
7 | numpy<=2.0.2 # bypassing faulty version restriction in tritonclient
8 | setuptools~=75.6.0
9 | typing-extensions~=4.12.2
10 |
--------------------------------------------------------------------------------
/python/serving/requirements.in:
--------------------------------------------------------------------------------
1 | google-cloud-core~=2.4.3
2 | google-cloud-logging~=3.12.1
3 | google-cloud-secret-manager==2.20.2
4 | psutil==7.1.3
5 | # TODO(b/468062329): Replace with regular pip version when released.
6 | https://github.com/huggingface/transformers/archive/65dc261512cbdb1ee72b88ae5b222f2605aad8e5.tar.gz#egg=transformers
7 | scipy~=1.16.3
8 | torch~=2.9.1
9 | -r serving_framework/requirements.in
10 | -r serving_framework/triton/requirements.in
11 |
--------------------------------------------------------------------------------
/python/serving/model_repository/default/config.pbtxt:
--------------------------------------------------------------------------------
1 | name: "default"
2 | backend: "pytorch"
3 | runtime: "model.py"
4 | input [
5 | {
6 | name: "input_features__0"
7 | data_type: TYPE_FP32
8 | dims: [-1, -1, 128]
9 | },
10 | {
11 | name: "attention_mask__1"
12 | data_type: TYPE_BOOL
13 | dims: [-1, -1]
14 | }
15 | ]
16 | output [
17 | {
18 | name: "tokens__0"
19 | data_type: TYPE_INT64
20 | dims: [-1, -1]
21 | }
22 | ]
23 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/README.md:
--------------------------------------------------------------------------------
1 | # Serving framework
2 |
3 | This directory contains a Python library that simplifies the creation of custom
4 | prediction containers for Vertex AI. It provides the framework for building HTTP
5 | servers that meet the platform's
6 | [requirements](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements).
7 |
8 | To implement a model-specific HTTP server, frame the custom data handling and
9 | orchestration logic within this framework.
10 |
--------------------------------------------------------------------------------
/python/serving/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2024 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | # MedASR Notebooks
2 |
3 | * [Quick start with Hugging Face](quick_start_with_hugging_face.ipynb) - This
4 | notebook demonstrates how to quickly get started with MedASR, Google's open
5 | Automatic Speech Recognition (ASR) model fine-tuned for the medical domain.
6 | It shows how to load the MedASR model from Hugging Face and use the
7 | transformers library pipeline to perform speech-to-text transcription on an
8 | example audio file.
9 |
10 | * [Quick start with Model Garden](quick_start_with_model_garden.ipynb) - This
11 | notebook demonstrates how to use MedASR as a service deployed on
12 | [Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/overview).
13 | It shows how to use the service API to perform speech-to-text transcription
14 | in online or batch workflows.
15 |
16 | * [Fine-tune with Hugging Face](fine_tune_with_hugging_face.ipynb) - This
17 | notebook shows how to fine-tune the model locally using Hugging Face
18 | libraries.
19 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We would love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows HAI-DEF's
24 | [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
32 | for this purpose.
33 |
34 | For more details, read HAI-DEF's
35 | [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing).
36 |
--------------------------------------------------------------------------------
/python/serving/predictor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for MedASR predictor."""
16 |
17 | import base64
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 |
22 | from serving import predictor
23 |
24 |
class PredictorTest(parameterized.TestCase):
  """Checks that malformed audio payloads are rejected by load_wav."""

  @parameterized.named_parameters(
      ('not_base64', 'not_base64'),
      ('not_wav', base64.b64encode(b'not_wav').decode('utf-8')),
  )
  def test_bad_file_loading_raises_error(self, audio_string: str):
    # Covers a string that is not base64 at all, and valid base64 whose
    # decoded bytes are not a WAV file.
    self.assertRaises(ValueError, predictor.load_wav, audio_string)


if __name__ == '__main__':
  absltest.main()
44 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/server_health_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """REST-based health check implementation for Triton model servers."""
16 |
17 | import http
18 |
19 | import requests
20 | from typing_extensions import override
21 |
22 | from serving.serving_framework import server_gunicorn
23 |
24 |
class TritonServerHealthCheck(server_gunicorn.ModelServerHealthCheck):
  """Checks the health of the local Triton model server via REST request.

  Readiness is probed through the `/v2/health/ready` endpoint on localhost.
  """

  def __init__(self, health_check_port: int, timeout_seconds: float = 10.0):
    """Initializes the health check.

    Args:
      health_check_port: Port of the local model server's HTTP endpoint.
      timeout_seconds: Maximum time to wait for the readiness response before
        reporting the server as unhealthy.
    """
    self._health_check_url = (
        f"http://localhost:{health_check_port}/v2/health/ready"
    )
    self._timeout_seconds = timeout_seconds

  @override
  def check_health(self) -> bool:
    """Returns True only if the readiness endpoint answers with HTTP 200.

    Connection failures and timeouts are reported as unhealthy instead of
    being raised.
    """
    try:
      # Without a timeout, a hung model server would block this call (and
      # whoever is serving the health route) indefinitely.
      r = requests.get(self._health_check_url, timeout=self._timeout_seconds)
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
    ):
      return False
    return r.status_code == http.HTTPStatus.OK.value
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MedASR
2 |
3 | MedASR is an automated speech recognition (ASR) model that takes spoken audio as an input and represents it as text. MedASR has been specifically trained for the healthcare domain. This means it has higher out-of-the-box performance in healthcare contexts than general ASR models. MedASR is open and can be fine-tuned for even higher performance.
4 |
5 | ## Get started
6 |
7 | * Read our
8 | [developer documentation](https://developers.google.com/health-ai-developer-foundations/medasr/get-started)
9 | to see the full range of next steps available, including learning more about
10 | the model through its
11 | [model card](https://developers.google.com/health-ai-developer-foundations/medasr/model-card).
12 |
13 | * Explore this repository, which contains [notebooks](./notebooks) for using
14 | the model.
15 |
16 | * Visit the model on
17 | [Hugging Face](https://huggingface.co/models?other=medasr) or
18 | [Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/medasr).
19 |
20 | ## Contributing
21 |
22 | We are open to bug reports, pull requests (PR), and other contributions. See
23 | [CONTRIBUTING](CONTRIBUTING.md) and
24 | [community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines)
25 | for details.
26 |
27 | ## License
28 |
29 | While the model is licensed under the
30 | [Health AI Developer Foundations License](https://developers.google.com/health-ai-developer-foundations/terms),
31 | everything in this repository is licensed under the Apache 2.0 license, see
32 | [LICENSE](LICENSE).
33 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/model_transfer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""A tool to transfer a model from the Vertex gcs bucket to a local directory.
16 |
17 | Copies a target gcs directory into a local directory using default credentials,
18 | intended to be used to set up tfserving-compatible model directories during
19 | serving framework startup.
20 |
21 | usage:
22 | python3 model_transfer.py --gcs_path="gs://bucket/object" \
23 | --local_path="/path/to/local/dir"
24 | """
25 |
26 | from collections.abc import Sequence
27 |
28 | from absl import app
29 | from absl import flags
30 |
31 | from google.cloud.aiplatform.utils import gcs_utils
32 |
33 |
# Source of the copy: the GCS location to download from. Declared as a
# required flag with no default.
_GCS_PATH = flags.DEFINE_string(
    "gcs_path",
    None,
    "The gcs path to copy from.",
    required=True,
)
# Destination of the copy: the local directory to download into. Declared as
# a required flag with no default.
_LOCAL_PATH = flags.DEFINE_string(
    "local_path",
    None,
    "The local path to copy to.",
    required=True,
)
46 |
47 |
def main(argv: Sequence[str]) -> None:
  """Downloads the --gcs_path location into --local_path.

  Args:
    argv: Command-line arguments; anything beyond the first element is
      rejected.

  Raises:
    app.UsageError: If more than one element is present in argv.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError("Too many command-line arguments.")
  source, destination = _GCS_PATH.value, _LOCAL_PATH.value
  gcs_utils.download_from_gcs(source, destination)


if __name__ == "__main__":
  app.run(main)
55 |
--------------------------------------------------------------------------------
/python/serving/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# This script launches the serving framework, run as the entrypoint.

# Exit if expanding an undefined variable.
set -u

export MODEL_REST_PORT=8600
export AIP_PREDICT_ROUTE="/v1/audio/transcriptions"

# When a model storage URI is provided, copy the model files to local disk and
# publish their location through MODEL_FILES.
if [[ -n "${AIP_STORAGE_URI:-}" ]]; then
  export MODEL_FILES="/model_files"
  # -p keeps a restart from failing when the directory already exists.
  mkdir -p "$MODEL_FILES"
  # Abort early if the copy fails; otherwise the model server would be
  # started against a missing or partial set of model files.
  if ! gcloud storage cp "$AIP_STORAGE_URI/*" "$MODEL_FILES" --recursive; then
    echo "Failed to copy model files from ${AIP_STORAGE_URI}" >&2
    exit 1
  fi
fi

echo "Serving framework start, launching model server"

# The model server listens on the loopback interface only; Vertex AI mode is
# disabled on the model server itself.
(/opt/tritonserver/bin/tritonserver \
  --model-repository="/serving/model_repository" \
  --allow-grpc=true \
  --grpc-address=127.0.0.1 \
  --grpc-port=8500 \
  --allow-http=true \
  --http-address=127.0.0.1 \
  --http-port="${MODEL_REST_PORT}" \
  --allow-vertex-ai=false \
  --strict-readiness=true || exit) &

echo "Launching front end"

(/server-env/bin/python3.12 -m serving.server_gunicorn --alsologtostderr \
  --verbosity=1 || exit)&

# Wait for any process to exit
wait -n

# Exit with status of process that exited first
exit $?
53 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/flags/flag_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utility functions for flags."""
16 | import os
17 |
18 |
def str_to_bool(val: str) -> bool:
  """Parses a textual truth value into a bool.

  Matching ignores case and surrounding whitespace.
  'y', 'yes', 't', 'true', 'on' and '1' map to True;
  'n', 'no', 'f', 'false', 'off' and '0' map to False.

  Args:
    val: String to convert to bool.

  Returns:
    The boolean the string stands for.

  Raises:
    ValueError: If val is none of the recognized spellings.
  """
  truth_table = {
      **dict.fromkeys(('y', 'yes', 't', 'true', 'on', '1'), True),
      **dict.fromkeys(('n', 'no', 'f', 'false', 'off', '0'), False),
  }
  normalized = val.strip().lower()
  parsed = truth_table.get(normalized)
  if parsed is None:
    # The message reports the normalized (stripped, lower-cased) text.
    raise ValueError(f'invalid truth value {normalized}')
  return parsed
40 |
41 |
def env_value_to_bool(env_name: str, undefined_value: bool = False) -> bool:
  """Reads an environment variable and interprets it as a boolean.

  Args:
    env_name: Environmental variable name.
    undefined_value: Value returned when the variable is not set.

  Returns:
    undefined_value if the variable is absent, otherwise the boolean parsed
    from its whitespace-stripped value.

  Raises:
    ValueError: The variable is set but cannot be parsed to bool.
  """
  raw_value = os.environ.get(env_name)
  if raw_value is None:
    return undefined_value
  return str_to_bool(raw_value.strip())
59 |
--------------------------------------------------------------------------------
/python/serving/model_repository/default/1/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Model file compatible with Triton PyTorch 2.0 backend."""
17 |
18 | import os
19 | from typing import Optional
20 |
21 | import torch
22 | from transformers.models import lasr
23 |
24 |
def __dir__():
  """Limits the names this module advertises.

  Overrides dir to hide imports from the Triton backend's loading strategy.
  Without this, the backend can attempt to load from the imports instead of
  MedASRWrapper.
  """
  advertised = ["MedASRWrapper"]
  advertised.extend(("__name__", "__spec__"))
  return advertised
30 |
31 |
class MedASRWrapper(torch.nn.Module):
  """Wraps LasrForCTC with environment-driven weight loading.

  The model origin is chosen from environment variables:
    * If AIP_STORAGE_URI is set, weights are loaded from the local directory
      named by MODEL_FILES.
    * Otherwise weights are loaded from the identifier in MODEL_ID, using
      HF_TOKEN (if set) for access.
  """

  def __init__(self):
    """Loads the wrapped model.

    Raises:
      ValueError: If no usable model origin is configured in the environment.
    """
    super().__init__()
    token = None
    if os.getenv("AIP_STORAGE_URI"):
      # Using model files copied from Vertex GCS bucket.
      model_origin = os.getenv("MODEL_FILES")
      if not model_origin:
        # Fail with a clear message instead of handing None to
        # from_pretrained.
        raise ValueError(
            "AIP_STORAGE_URI is set but MODEL_FILES is not; the local model"
            " directory is unknown."
        )
    else:
      # Using model files from HF repository.
      model_origin = os.getenv("MODEL_ID")
      if not model_origin:
        raise ValueError(
            "No model origin found. MODEL_ID or AIP_STORAGE_URI must be set."
        )
      token = os.getenv("HF_TOKEN")  # optional for access to non-public models.
    self._model = lasr.LasrForCTC.from_pretrained(
        model_origin,
        token=token,
    )

  def forward(
      self,
      input_features: Optional[torch.Tensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
  ):
    """Runs generation on the wrapped model and returns its output unchanged.

    Args:
      input_features: Forwarded to the model's generate() as input_features.
      attention_mask: Forwarded to the model's generate() as attention_mask.

    Returns:
      Whatever the wrapped model's generate() returns.
    """
    return self._model.generate(
        input_features=input_features, attention_mask=attention_mask
    )
63 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/flags/flag_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for flag utils."""
16 | import os
17 | from unittest import mock
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 |
22 | from serving.logging_lib.flags import flag_utils
23 |
24 | # const
25 | _UNDEFINED_ENV_VAR_NAME = 'UNDEFINED'
26 |
27 |
class FlagUtilsTest(parameterized.TestCase):
  """Unit tests for str_to_bool and env_value_to_bool."""

  @parameterized.parameters(['y', ' YES ', 't', 'tRue', 'on', '1'])
  def test_str_to_bool_true(self, text):
    parsed = flag_utils.str_to_bool(text)
    self.assertTrue(parsed)

  @parameterized.parameters(['n', 'no', 'f', 'FALSE', ' oFf ', '0'])
  def test_strtobool_false(self, text):
    parsed = flag_utils.str_to_bool(text)
    self.assertFalse(parsed)

  def test_str_to_bool_raises(self):
    self.assertRaises(ValueError, flag_utils.str_to_bool, 'ABCD')

  def test_undefined_env_default_true(self):
    # Guard against the variable being defined by the test environment.
    self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ)
    parsed = flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME, True)
    self.assertTrue(parsed)

  def test_undefined_env_no_default(self):
    self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ)
    parsed = flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME)
    self.assertFalse(parsed)

  def test_initialized_env(self):
    with mock.patch.dict(os.environ, {'FOO': ' True '}):
      self.assertTrue(flag_utils.env_value_to_bool('FOO'))

  def test_bad_initialized_env(self):
    with mock.patch.dict(os.environ, {'BAD_VALUE': 'TrueABSCDEF'}):
      self.assertRaises(ValueError, flag_utils.env_value_to_bool, 'BAD_VALUE')


if __name__ == '__main__':
  absltest.main()
62 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/server_health_check_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import http
16 | import os
17 | from unittest import mock
18 |
19 | import requests
20 | import requests_mock
21 |
22 | from absl.testing import absltest
23 | from serving.serving_framework import server_gunicorn
24 | from serving.serving_framework.triton import server_health_check
25 |
26 |
class ServerHealthCheckTest(absltest.TestCase):
  """Tests the health route when backed by TritonServerHealthCheck."""

  _MODEL_SERVER_PORT = 12345
  _READY_URL = "http://localhost:12345/v2/health/ready"

  def setUp(self):
    super().setUp()
    # Patch the environment instead of assigning to it so the fake routes are
    # removed again after each test and cannot leak into other tests running
    # in the same process.
    env_patcher = mock.patch.dict(
        os.environ,
        {
            "AIP_PREDICT_ROUTE": "/fake-predict-route",
            "AIP_HEALTH_ROUTE": "/fake-health-route",
        },
    )
    env_patcher.start()
    self.addCleanup(env_patcher.stop)

  def _make_test_client(self):
    """Builds the application with a mocked executor; returns its test client."""
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )
    app = server_gunicorn.PredictionApplication(
        executor,
        health_check=server_health_check.TritonServerHealthCheck(
            self._MODEL_SERVER_PORT
        ),
    ).load()
    return app.test_client()

  @requests_mock.Mocker()
  def test_health_route_pass_check(self, mock_requests):
    mock_requests.register_uri(
        "GET",
        self._READY_URL,
        text="assorted_metadata",
        status_code=http.HTTPStatus.OK,
    )
    service = self._make_test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertEqual(response.text, "ok")

  @requests_mock.Mocker()
  def test_health_route_fail_check(self, mock_requests):
    # An unreachable model server must surface as 503 on the health route.
    mock_requests.register_uri(
        "GET",
        self._READY_URL,
        exc=requests.exceptions.ConnectionError,
    )
    service = self._make_test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.SERVICE_UNAVAILABLE)
    self.assertEqual(response.text, "not ok")


if __name__ == "__main__":
  absltest.main()
89 |
--------------------------------------------------------------------------------
/python/serving/server_gunicorn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Launcher for the prediction_executor based encoder server.
16 |
17 | Uses the serving framework to create a request server which
18 | performs the logic for requests in separate processes and uses a local NVIDIA
19 | Triton model server to handle the model.
20 | """
21 |
22 | from collections.abc import Sequence
23 | import os
24 |
25 | from absl import app
26 | from absl import logging
27 | import jsonschema
28 | import yaml
29 |
30 | from serving.serving_framework import inline_prediction_executor
31 | from serving.serving_framework import server_gunicorn
32 | from serving.serving_framework.triton import server_health_check
33 | from serving.serving_framework.triton import triton_server_model_runner
34 | from serving import predictor
35 |
36 |
def _required_int_env(name: str) -> int:
  """Returns the integer value of the environment variable `name`.

  Raises:
    ValueError: If the variable is unset or is not a valid integer.
  """
  if name not in os.environ:
    raise ValueError(
        f'The environment variable {name} needs to be specified.'
    )
  return int(os.environ[name])


def main(argv: Sequence[str]) -> None:
  """Configures and runs the gunicorn prediction application.

  Reads the serving port from AIP_HTTP_PORT and the model server's REST port
  from MODEL_REST_PORT, loads the request schema validator, and starts the
  application with a Triton-backed model runner and health check.

  Args:
    argv: Command-line arguments; anything beyond the first element is
      rejected.

  Raises:
    app.UsageError: If more than one element is present in argv.
    ValueError: If AIP_HTTP_PORT or MODEL_REST_PORT is unset or not an integer.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  http_port = _required_int_env('AIP_HTTP_PORT')
  options = {
      'bind': f'0.0.0.0:{http_port}',
      'workers': 3,
      'timeout': 240,
  }
  # Validated like AIP_HTTP_PORT so a missing value yields a clear error
  # rather than a TypeError from int(None).
  model_rest_port = _required_int_env('MODEL_REST_PORT')
  health_checker = server_health_check.TritonServerHealthCheck(model_rest_port)
  # Get schema validators.
  request_schema_path = os.path.join(
      os.path.dirname(__file__), 'vertex_schemata', 'request.yaml'
  )
  with open(request_schema_path, 'r', encoding='utf-8') as f:
    instance_validator = jsonschema.Draft202012Validator(yaml.safe_load(f))
  # Response validation is currently disabled. To enable it, build a
  # Draft202012Validator from vertex_schemata/response.yaml the same way as
  # the request validator above.
  prediction_validator = None
  predictor_instance = predictor.MedASRPredictor(
      model_source=os.environ.get('MODEL_FILES')
  )
  logging.info('Launching gunicorn application.')
  server_gunicorn.PredictionApplication(
      inline_prediction_executor.InlinePredictionExecutor(
          predictor_instance.predict,
          triton_server_model_runner.TritonServerModelRunner,
      ),
      health_check=health_checker,
      options=options,
      instance_input=False,
      input_validator=instance_validator,
      prediction_validator=prediction_validator,
  ).run()


if __name__ == '__main__':
  app.run(main)
82 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/inline_prediction_executor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Mapping, Set
16 | from unittest import mock
17 | from typing import Any
18 |
19 | import numpy as np
20 |
21 | from absl.testing import absltest
22 | from serving.serving_framework import inline_prediction_executor
23 | from serving.serving_framework import model_runner
24 |
25 |
class DummyModelRunner(model_runner.ModelRunner):
  """Test double that returns a constant tensor regardless of its arguments."""

  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
      parameters: Mapping[str, Any] | None = None,
  ) -> Mapping[str, np.ndarray]:
    """Returns a fixed (1, 2) float32 tensor of ones under "output_0"."""
    # The keyword arguments have no influence on the canned response.
    del model_name, model_version, model_output_keys, parameters
    canned_tensor = np.ones((1, 2), dtype=np.float32)
    return {"output_0": canned_tensor}
40 |
41 |
class InlinePredictionExecutorTest(absltest.TestCase):
  """Unit tests for InlinePredictionExecutor."""

  def test_predict_requires_start(self):
    """predict() must refuse to run before start() created the model runner."""
    predictor = mock.MagicMock()
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, DummyModelRunner
    )
    with self.assertRaises(RuntimeError):
      executor.predict({"placeholder": "input"})

  def test_execute_catches_predictor_exception(self):
    """execute() converts any predictor exception into a RuntimeError."""
    predictor = mock.MagicMock(side_effect=Exception("test error"))
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, DummyModelRunner
    )
    executor.start()
    with self.assertRaises(RuntimeError):
      executor.execute({"placeholder": "input"})

  def test_execute_calls_predictor(self):
    """execute() forwards the request and the started runner to the predictor."""
    predictor = mock.MagicMock(return_value={"placeholder": "output"})
    mock_model_runner = mock.create_autospec(
        DummyModelRunner, instance=True
    )
    # create_autospec forwards unknown keyword arguments to the mock
    # constructor, so the former `autospec=True` only set a stray attribute.
    mock_model_runner_class = mock.create_autospec(DummyModelRunner)
    mock_model_runner_class.return_value = mock_model_runner
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, mock_model_runner_class
    )

    executor.start()
    self.assertEqual(
        executor.execute({"placeholder": "input"}),
        {"placeholder": "output"},
    )
    mock_model_runner_class.assert_called_once()
    predictor.assert_called_once_with(
        {"placeholder": "input"}, mock_model_runner
    )


if __name__ == "__main__":
  absltest.main()
87 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/inline_prediction_executor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A thin shell to fit a predictor function into the PredictionExecutor interface.
16 |
17 | This is a convenience wrapper to allow a predictor function to be used directly
18 | as a PredictionExecutor.
19 |
20 | Intended usage in launching a server:
21 | predictor = Predictor()
22 | executor = InlinePredictionExecutor(predictor.predict, TritonServerModelRunner)
23 | server_gunicorn.PredictionApplication(executor, ...).run()
24 | """
25 |
26 | from collections.abc import Callable
27 | from typing import Any
28 |
29 | from typing_extensions import override
30 |
31 | from serving.serving_framework import model_runner
32 | from serving.serving_framework import server_gunicorn
33 |
34 |
class InlinePredictionExecutor(server_gunicorn.PredictionExecutor):
  """Adapts a plain predictor callable to the PredictionExecutor interface.

  The executor holds a predictor function and a factory for the model runner
  the predictor needs. The runner is only created in start(), i.e. inside the
  server worker process.

  When a bare function call is not enough, override start and predict in a
  subclass, or inherit from PredictionExecutor directly.
  """

  def __init__(
      self,
      predictor: Callable[
          [dict[str, Any], model_runner.ModelRunner], dict[str, Any]
      ],
      model_runner_source: Callable[[], model_runner.ModelRunner],
  ):
    """Stores the predictor and the runner factory without invoking either.

    Args:
      predictor: Called with the request payload and the model runner; returns
        the response payload.
      model_runner_source: Zero-argument callable producing the model runner.
    """
    self._model_runner_source = model_runner_source
    # Stays None until start() runs; predict() checks for this.
    self._model_runner = None
    self._predictor = predictor

  @override
  def start(self) -> None:
    """Creates the model runner.

    Called after the Gunicorn worker process has started, so that setup which
    must happen post-fork is performed here.
    """
    # Safer to instantiate the RPC stub post-fork.
    runner = self._model_runner_source()
    self._model_runner = runner

  def predict(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Runs the predictor on the payload using the started model runner.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      Whatever the predictor returns for this payload.

    Raises:
      RuntimeError: If start() has not been called yet.
    """
    runner = self._model_runner
    if runner is not None:
      return self._predictor(input_json, runner)
    raise RuntimeError(
        "Model runner is not initialized. Please call start() first."
    )

  @override
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Executes the given request payload.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      The json response to the prediction request.

    Raises:
      RuntimeError: Prediction failed in an unhandled way; the original
        exception is attached as the cause.
    """
    try:
      response = self.predict(input_json)
    except Exception as error:
      raise RuntimeError("Unhandled exception from predictor.") from error
    return response
92 |
--------------------------------------------------------------------------------
/python/serving/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This is used to build a Docker image that includes the necessary dependencies
16 | # for running MedASR as a microservice.
17 |
18 | FROM python:3.12-slim AS prep
19 |
20 | # Get pip requirements
21 | COPY ./python/serving/requirements.txt requirements.txt
22 | COPY ./python/serving/serving_framework/pip-install.txt pip-install.txt
23 | RUN python3.12 -m venv --copies /payload/server-env && \
24 | /payload/server-env/bin/python3.12 -m pip install --require-hashes \
25 | -r pip-install.txt && \
26 | /payload/server-env/bin/python3.12 -m pip install --require-hashes \
27 | -r requirements.txt
28 |
29 | # Set up code directories
30 | COPY ./python/serving /payload/serving
31 | COPY ./LICENSE /payload/LICENSE
32 | RUN chmod a+x /payload/serving/entrypoint.sh
33 |
34 | # Install git and clone source code for mirroring
35 | RUN apt-get update && \
36 | apt-get install --no-install-recommends -y git && \
37 | rm -rf /var/lib/apt/lists/*
38 | RUN git clone https://github.com/certifi/python-certifi.git /source-mirror/python-certifi && \
39 | git clone https://github.com/tqdm/tqdm.git /source-mirror/tqdm && \
40 | git clone https://git.launchpad.net/launchpadlib /source-mirror/launchpadlib && \
41 | git clone https://git.launchpad.net/lazr.restfulclient /source-mirror/lazr.restfulclient && \
42 | git clone https://git.launchpad.net/lazr.uri /source-mirror/lazr.uri && \
43 | git clone https://git.launchpad.net/wadllib /source-mirror/wadllib && \
44 | git clone https://gitlab.gnome.org/GNOME/pygobject.git /source-mirror/pygobject && \
45 | git clone https://git.launchpad.net/python-apt /source-mirror/python-apt
46 |
47 | # ----------------------------------------------------------------------------
48 |
49 | FROM google/cloud-sdk:stable AS gcloud_builder
50 |
51 | # ----------------------------------------------------------------------------
52 |
53 | FROM nvcr.io/nvidia/tritonserver:25.03-py3
54 |
55 | WORKDIR /
56 | # Install model dependencies for the triton backend.
57 | COPY ./python/serving/model_requirements.txt model_requirements.txt
58 | RUN python3 -m pip install --require-hashes -r /model_requirements.txt
59 | # Copy in minimal gcloud install
60 | COPY --from=gcloud_builder /usr/lib/google-cloud-sdk /usr/lib/google-cloud-sdk
61 | ENV PATH="/usr/lib/google-cloud-sdk/bin:$PATH"
62 |
63 | # Prevent compatibility-breaking NCCL updates.
64 | RUN apt-mark hold libnccl-dev libnccl2 && \
65 | # apt upgrade
66 | apt-get update && \
67 | apt-get -y upgrade && \
68 | # Cleanup cached files.
69 | rm -rf /var/lib/apt/lists/*
70 |
71 | # Copy the code and python environment from the prep image
72 | COPY --from=prep /payload /
73 | # Copy mirrored source code from the prep image
74 | COPY --from=prep /source-mirror /source-mirror/
75 |
76 | ENTRYPOINT ["/serving/entrypoint.sh"]
77 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/triton_server_model_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ModelRunner implementation using grpc to NVIDIA Triton model server.
16 |
17 | Uses Triton GRPC client and relies on a model server running locally.
18 | """
19 |
20 | from collections.abc import Mapping, Set
21 | from typing import Any
22 |
23 | from absl import logging
24 | import numpy as np
25 | from tritonclient import grpc as triton_grpc
26 | from tritonclient import utils as triton_utils
27 | from typing_extensions import override
28 |
29 | from serving.serving_framework import model_runner
30 |
# Address of the Triton gRPC endpoint used when no client is injected.
_HOSTPORT = "localhost:8500"


class TritonServerModelRunner(model_runner.ModelRunner):
  """ModelRunner that sends inference requests to Triton over gRPC."""

  def __init__(self, client: triton_grpc.InferenceServerClient | None = None):
    """Uses the given client, or connects to the default local endpoint.

    Args:
      client: Pre-built Triton gRPC client. When None, a client for
        _HOSTPORT is created.
    """
    self._client = (
        triton_grpc.InferenceServerClient(_HOSTPORT)
        if client is None
        else client
    )

  @override
  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
      parameters: Mapping[str, Any] | None = None,
  ) -> Mapping[str, np.ndarray]:
    """Runs one inference request and collects the requested outputs.

    Args:
      model_input: A single array, sent under the input name "inputs", or a
        mapping of input names to arrays.
      model_name: The name of the model to run.
      model_version: The version of the model to run. None is sent as an empty
        version string.
      model_output_keys: The names of the outputs to extract from the result.
      parameters: Additional parameters to pass with the request.

    Returns:
      A mapping from each requested output key to its tensor.

    Raises:
      KeyError: If any requested output key is absent from the result.
    """
    # Normalize the input to a name -> tensor mapping.
    if isinstance(model_input, np.ndarray):
      logging.debug("Handling bare input tensor.")
      named_tensors = {"inputs": model_input}
    else:
      named_tensors = model_input

    version_label = "" if model_version is None else str(model_version)

    infer_inputs = []
    for tensor_name, tensor in named_tensors.items():
      triton_dtype = triton_utils.np_to_triton_dtype(tensor.dtype)
      infer_input = triton_grpc.InferInput(
          tensor_name, tensor.shape, triton_dtype
      )
      infer_input.set_data_from_numpy(tensor)
      infer_inputs.append(infer_input)

    request_parameters = dict(parameters) if parameters is not None else None

    result = self._client.infer(
        model_name, infer_inputs, version_label, parameters=request_parameters
    )
    assert result is not None  # infer never returns None, despite annotation.

    # as_numpy yields None for an output name the result does not contain.
    outputs = {}
    missing_keys = set()
    for key in model_output_keys:
      tensor = result.as_numpy(key)
      outputs[key] = tensor
      if tensor is None:
        missing_keys.add(key)
    if missing_keys:
      raise KeyError(
          f"Model output keys {missing_keys} not found in model output."
      )
    return outputs
103 |
--------------------------------------------------------------------------------
/python/serving/vertex_schemata/response.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | title: MedASRResponse
16 | type: object
17 | description: >
18 | The schema for a single Speech-to-Text transcription response. Parameters are aligned
19 | with the core functionality of the OpenAI Transcription API:
20 | https://platform.openai.com/docs/api-reference/audio/json-object
21 | additionalProperties: false
22 | required:
23 | - text
24 | properties:
25 | text:
26 | type: string
27 | description: The transcribed text.
28 | logprobs:
29 | type: array
30 | description: The log probabilities of the tokens in the transcription.
31 | items:
32 | $ref: "#/components/schemas/LogprobDetail"
33 | usage:
34 | description: Token usage statistics for the request.
35 | anyOf:
36 | - $ref: "#/components/schemas/TokenUsage"
37 | - $ref: "#/components/schemas/DurationUsage"
38 |
39 | components:
40 | schemas:
41 | LogprobDetail:
42 | title: LogprobDetail
43 | type: object
44 | description: A single log probability object for a token.
45 | additionalProperties: false
46 | required:
47 | - bytes
48 | - logprob
49 | - token
50 | properties:
51 | bytes:
52 | type: array
53 | description: The bytes of the token (represented as an array of integers).
54 | items:
55 | type: integer
56 | logprob:
57 | type: number
58 | description: The log probability of the token.
59 | token:
60 | type: string
61 | description: The token in the transcription.
62 |
63 | TokenUsage:
64 | title: TokenUsage
65 | type: object
66 | description: Token usage statistics.
67 | additionalProperties: false
68 | required:
69 | - input_tokens
70 | - output_tokens
71 | - total_tokens
72 | - type
73 | - input_token_details
74 | properties:
75 | input_tokens:
76 | type: integer
77 | description: Number of input tokens for this request.
78 | output_tokens:
79 | type: integer
80 | description: Number of output tokens generated.
81 | total_tokens:
82 | type: integer
83 | description: Total number of tokens used (input + output).
84 | type:
85 | type: string
86 | description: The type of the usage object. Always 'tokens' for this variant.
87 | enum:
88 | - tokens
89 | input_token_details:
90 | $ref: "#/components/schemas/InputTokenDetails"
91 |
92 | DurationUsage:
93 | title: DurationUsage
94 | type: object
95 | description: Audio input duration usage statistics.
96 | additionalProperties: false
97 | required:
98 | - seconds
99 | - type
100 | properties:
101 | seconds:
102 | type: number
103 | description: Duration of the input audio in seconds.
104 | type:
105 | type: string
106 | description: The type of the usage object. Always 'duration' for this variant.
107 | enum:
108 | - duration
109 |
110 | InputTokenDetails:
111 | title: InputTokenDetails
112 | type: object
113 | description: Details about the input tokens billed for this request.
114 | additionalProperties: false
115 | required:
116 | - audio_tokens
117 | - text_tokens
118 | properties:
119 | audio_tokens:
120 | type: integer
121 | description: Number of audio tokens billed for this request.
122 | text_tokens:
123 | type: integer
124 | description: Number of text tokens billed for this request.
--------------------------------------------------------------------------------
/python/serving/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Transcribes base64-encoded WAV audio into text."""
16 | import base64
17 | import binascii
18 | import io
19 | from typing import Any, Mapping
20 |
21 | import numpy as np
22 | from scipy.io import wavfile
23 | import transformers
24 |
25 | from serving.serving_framework import model_runner
26 | from serving.logging_lib import cloud_logging_client
27 |
# Input schema keys
# Request field holding the base64-encoded audio payload.
AUDIO_INPUT_KEY = 'file'

# Model input dictionary mapping
# Maps each tensor name produced by the processor to the input tensor name
# used when calling the served model.
MODEL_INPUT_KEYS = {
    'input_features': 'input_features__0',
    'attention_mask': 'attention_mask__1',
}

# Model output dictionary keys
# Response field holding the decoded transcript.
# NOTE(review): "TRANCRIPT" is a misspelling of "TRANSCRIPT"; renaming it
# requires updating every reference to this constant.
TEXT_TRANCRIPT_KEY = 'text'
39 |
40 |
def load_wav(audio_string: str) -> np.ndarray:
  """Decodes a base64-encoded WAV file into a mono waveform.

  Multi-channel audio is averaged across channels. Integer samples are divided
  by the maximum value of their integer type; float samples are returned
  unchanged.

  Args:
    audio_string: Base64 text containing the bytes of a WAV file.

  Returns:
    The waveform samples.

  Raises:
    ValueError: If the input is not valid base64, cannot be read as a WAV
      file, is not sampled at 16000 Hz, or uses an unsigned sample type.
  """
  try:
    raw_bytes = base64.b64decode(audio_string, validate=True)
  except (binascii.Error, ValueError) as exp:
    raise ValueError('Cannot decode input bytes.') from exp

  wav_buffer = io.BytesIO(raw_bytes)
  try:
    sample_rate, waveform = wavfile.read(wav_buffer)
  except ValueError as exp:
    raise ValueError('Invalid wav file.') from exp
  cloud_logging_client.debug(
      f'WAV file sample rate: {sample_rate}, shape: {waveform.shape},'
      f' dtype={waveform.dtype}'
  )

  # Remember the on-disk sample type: averaging channels changes the dtype,
  # but scaling below must use the original integer range.
  source_dtype = waveform.dtype
  if waveform.ndim > 1:
    cloud_logging_client.info('Audio is stereo, converting to mono.')
    # Convert to mono by averaging the channels
    waveform = waveform.mean(axis=1)

  if sample_rate != 16000:
    raise ValueError(
        f'Sample rate {sample_rate} is not 16000, which is the expected'
        ' sample rate for audio.'
    )

  # Normalize the waveform to -1, 1 float range.
  sample_kind = source_dtype.kind
  if sample_kind == 'u':
    raise ValueError('Unsigned wav format is not supported.')
  if sample_kind == 'i':
    return waveform / np.iinfo(source_dtype).max
  # Float samples are already in -1, 1 float range.
  return waveform
74 |
75 |
class MedASRPredictor:
  """Turns an audio transcription request into a text transcript."""

  def __init__(self, model_source: str, token: str | None = None):
    """Records processor settings; the processor itself is loaded lazily.

    Args:
      model_source: Location handed to AutoProcessor.from_pretrained.
      token: Optional token handed to AutoProcessor.from_pretrained.
    """
    self._processor = None
    self._token = token
    self._model_source = model_source

  def predict(
      self,
      prediction_input: Mapping[str, Any],
      model: model_runner.ModelRunner,
  ) -> dict[str, Any]:
    """Transcribes the audio carried by a single request.

    Args:
      prediction_input: JSON formatted request; the audio is read from its
        'file' entry as a base64-encoded WAV file.
      model: ModelRunner to handle model step.

    Returns:
      JSON formatted output: the transcript under 'text', or an 'error'
      message when a ValueError occurred while handling the request.
    """
    # Load the processor on first use only.
    if self._processor is None:
      self._processor = transformers.AutoProcessor.from_pretrained(
          self._model_source,
          token=self._token,
      )
    processor = self._processor

    try:
      waveform = load_wav(prediction_input[AUDIO_INPUT_KEY])
      cloud_logging_client.debug('Audio loaded.')
      features = processor(waveform, return_tensors='np')
      cloud_logging_client.debug('Model input processed.')
      # Rename the processor outputs to the input names the model expects.
      model_input = {}
      for feature_name, tensor in features.items():
        model_input[MODEL_INPUT_KEYS[feature_name]] = tensor
      tokens = model.run_model(model_input, model_output_key='tokens__0')
      cloud_logging_client.debug('Model run completed.')
      decoded = processor.batch_decode(tokens)
      transcript = decoded[0]
      cloud_logging_client.debug('Tokens decoded.')
      cloud_logging_client.info('Returning transcripts.')
    except ValueError as exp:
      # NOTE: every ValueError raised above, not only those from load_wav, is
      # reported with this warning and returned as an error payload.
      cloud_logging_client.warning('Failed loading wav file.')
      return {'error': str(exp)}
    return {TEXT_TRANCRIPT_KEY: transcript}
123 |
--------------------------------------------------------------------------------
/python/serving/vertex_schemata/request.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | title: MedASRRequest
16 | type: object
17 | description: >
18 | The schema for a single Speech-to-Text transcription request. Parameters are aligned
19 | with the core functionality of the OpenAI Transcription API:
20 | https://platform.openai.com/docs/api-reference/audio/createTranscription
21 | additionalProperties: false
22 | required:
23 | - file
24 | properties:
25 | file:
26 | type: string
27 | format: binary
28 | description: >
29 | The audio to transcribe, as a base64-encoded string. The server currently decodes
30 | WAV files sampled at 16 kHz only.
31 | title: Audio File
32 | model:
33 | anyOf:
34 | - type: string
35 | - type: "null"
36 | description: >
37 | This field is included for compatibility with the OpenAI Transcription API. It is not
38 | currently used, because the model is implicitly selected based on the endpoint used.
39 | title: Model
40 | chunking_strategy:
41 | anyOf:
42 | - type: string
43 | enum:
44 | - auto
45 | - $ref: "#/components/schemas/ServerVAD"
46 | - type: "null"
47 | description: >
48 | Controls how the audio is cut into chunks. When set to "auto", the server
49 | first normalizes loudness and then uses voice activity detection (VAD) to
50 | choose boundaries. A server_vad object can be provided to tweak VAD detection
51 | parameters manually. If unset, the audio is transcribed as a single block.
52 | title: Chunking Strategy
53 | include:
54 | anyOf:
55 | - type: array
56 | items:
57 | type: string
58 | enum:
59 | - logprobs
60 | - type: "null"
61 | description: >
62 | Additional information to include in the transcription response. logprobs will
63 | return the log probabilities of the tokens in the response to understand the
64 | model's confidence in the transcription.
65 | title: Include
66 | language:
67 | anyOf:
68 | - type: string
69 | - type: "null"
70 | description: >
71 | The language of the input audio. Supplying the input language in ISO-639-1
72 | (e.g. en) format will improve accuracy and latency.
73 | title: Language
74 | prompt:
75 | anyOf:
76 | - type: string
77 | - type: "null"
78 | description: >
79 | An optional text to guide the model's style or continue a previous audio
80 | segment. The prompt should match the audio language.
81 | title: Prompt
82 | temperature:
83 | anyOf:
84 | - type: number
85 | minimum: 0
86 | maximum: 1
87 | - type: "null"
88 | default: 0
89 | description: >
90 | The sampling temperature, between 0 and 1. Higher values like 0.8 will make
91 | the output more random, while lower values like 0.2 will make it more
92 | focused and deterministic.
93 | title: Temperature
94 |
95 | components:
96 | schemas:
97 | ServerVAD:
98 | title: ServerVAD
99 | type: object
100 | description: >
101 | Object to tweak Voice Activity Detection (VAD) parameters manually.
102 | additionalProperties: false
103 | required:
104 | - type
105 | properties:
106 | type:
107 | type: string
108 | description: >
109 | Must be set to 'server_vad' to enable manual chunking using server side VAD.
110 | enum:
111 | - server_vad
112 | prefix_padding_ms:
113 | type: integer
114 | description: Amount of audio to include before the VAD detected speech (in milliseconds).
115 | default: 300
116 | silence_duration_ms:
117 | type: integer
118 | description: >
119 | Duration of silence to detect speech stop (in milliseconds). With shorter values the
120 | model will respond more quickly, but may jump in on short pauses from the user.
121 | default: 200
122 | threshold:
123 | type: number
124 | description: >
125 | Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher
126 | threshold will require louder audio to activate the model, and thus might perform
127 | better in noisy environments.
128 | default: 0.5
129 | minimum: 0.0
130 | maximum: 1.0
--------------------------------------------------------------------------------
/python/serving/serving_framework/model_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Abstract base class for dependency injection of model handling.
16 |
17 | Wraps execution of models on input tensors in an implementation-agnostic
18 | interface. Provides a mixin method for batching model execution.
19 | """
20 |
21 | import abc
22 | from collections.abc import Mapping, Sequence, Set
23 | from typing import Any
24 |
25 | import numpy as np
26 |
27 |
28 | class ModelRunner(abc.ABC):
29 | """Runs a model with tensor inputs and outputs."""
30 |
31 | @abc.abstractmethod
32 | def run_model_multiple_output(
33 | self,
34 | model_input: Mapping[str, np.ndarray] | np.ndarray,
35 | *,
36 | model_name: str = "default",
37 | model_version: int | None = None,
38 | model_output_keys: Set[str],
39 | parameters: Mapping[str, Any] | None = None,
40 | ) -> Mapping[str, np.ndarray]:
41 | """Runs a model on the given input and returns multiple outputs.
42 |
43 | Args:
44 | model_input: An array or map of arrays comprising the input tensors for
45 | the model. A bare array is given a default input key.
46 | model_name: The name of the model to run.
47 | model_version: The version of the model to run. Uses default if None.
48 | model_output_keys: The desired model output keys.
49 | parameters: Additional parameters to pass to the model.
50 |
51 | Returns:
52 | A mapping of model output keys to tensors.
53 | """
54 |
55 | def run_model(
56 | self,
57 | model_input: Mapping[str, np.ndarray] | np.ndarray,
58 | *,
59 | model_name: str = "default",
60 | model_version: int | None = None,
61 | model_output_key: str = "output_0",
62 | parameters: Mapping[str, Any] | None = None,
63 | ) -> np.ndarray:
64 | """Runs a model on the given input.
65 |
66 | Args:
67 | model_input: An array or map of arrays comprising the input tensors for
68 | the model. A bare array is given a default input key.
69 | model_name: The name of the model to run.
70 | model_version: The version of the model to run. Uses default if None.
71 | model_output_key: The key to pull the output from. Defaults to "output_0".
72 | parameters: Additional parameters to pass to the model.
73 |
74 | Returns:
75 | The single output tensor.
76 | """
77 | return self.run_model_multiple_output(
78 | model_input,
79 | model_name=model_name,
80 | model_version=model_version,
81 | model_output_keys={model_output_key},
82 | parameters=parameters,
83 | )[model_output_key]
84 |
85 | def batch_model(
86 | self,
87 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray],
88 | *,
89 | model_name: str = "default",
90 | model_version: int | None = None,
91 | model_output_key: str = "output_0",
92 | parameters: Mapping[str, Any] | None = None,
93 | ) -> list[np.ndarray]:
94 | """Runs a model on each of the given inputs.
95 |
96 | Args:
97 | model_inputs: A sequence of arrays or maps of arrays comprising the input
98 | tensors for the model. Bare arrays are given a default input key.
99 | model_name: The name of the model to run.
100 | model_version: The version of the model to run. Uses default if None.
101 | model_output_key: The key to pull the output from. Defaults to "output_0".
102 | parameters: Additional parameters to pass to the model.
103 |
104 | Returns:
105 | A list of the single output tensor from each input.
106 | """
107 | return [
108 | self.run_model(
109 | model_input,
110 | model_name=model_name,
111 | model_version=model_version,
112 | model_output_key=model_output_key,
113 | parameters=parameters,
114 | )
115 | for model_input in model_inputs
116 | ]
117 |
118 | def batch_model_multiple_output(
119 | self,
120 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray],
121 | *,
122 | model_name: str = "default",
123 | model_version: int | None = None,
124 | model_output_keys: Set[str],
125 | parameters: Mapping[str, Any] | None = None,
126 | ) -> list[Mapping[str, np.ndarray]]:
127 | """Runs a model on the given inputs and returns multiple outputs.
128 |
129 | Args:
130 | model_inputs: An array or map of arrays comprising the input tensors for
131 | the model. Bare arrays are given a default input key.
132 | model_name: The name of the model to run.
133 | model_version: The version of the model to run. Uses default if None.
134 | model_output_keys: The desired model output keys.
135 | parameters: Additional parameters to pass to the model.
136 |
137 | Returns:
138 | A list containing the mapping of model output keys to tensors from each
139 | input.
140 | """
141 | return [
142 | self.run_model_multiple_output(
143 | model_input,
144 | model_name=model_name,
145 | model_version=model_version,
146 | model_output_keys=model_output_keys,
147 | parameters=parameters,
148 | )
149 | for model_input in model_inputs
150 | ]
151 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/triton_streaming_server_model_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ModelRunner implementation using streaming async grpc to Triton model server.
16 |
17 | Uses Triton GRPC aio client and relies on a model server running locally.
18 | """
19 |
20 | import asyncio
21 | from collections.abc import Mapping, Set
22 | from typing import Any
23 |
24 | from absl import logging
25 | import numpy as np
26 | from tritonclient import grpc as triton_grpc
27 | from tritonclient import utils as triton_utils
28 | from tritonclient.grpc import aio as triton_aio
29 | from typing_extensions import override
30 |
31 | from serving.serving_framework import model_runner
32 |
# Host and port handed to the Triton aio gRPC client; the model server is
# expected to be running locally.
_HOSTPORT = "localhost:8500"
34 |
35 |
36 | async def _queue_request_generator(request_queue: asyncio.Queue):
37 | while True:
38 | request = await request_queue.get()
39 | if request is None:
40 | break
41 | yield request
42 |
43 |
class TritonStreamingServerModelRunner(model_runner.ModelRunner):
  """ModelRunner implementation streaming grpc to NVIDIA Triton model server."""

  def __init__(self):
    """Initializes the runner; currently holds no state."""
    pass  # TODO(bramsterling): set up persistent event loop and rpc connection.

  async def _call(self, model_name, inputs, model_version, parameters):
    """Sends one inference request over a stream and returns its first result.

    A new client is opened against `_HOSTPORT` for every call and closed when
    the call finishes.

    Args:
      model_name: The name of the model to run.
      inputs: The input tensors to send in the request.
      model_version: The model version to send in the request.
      parameters: Additional parameters to send in the request.

    Returns:
      The first result yielded by the response stream.

    Raises:
      RuntimeError: If the stream ends without yielding any result.
      Exception: Whatever error object the stream yields alongside a result
        slot is raised as-is.
    """
    async with triton_aio.InferenceServerClient(_HOSTPORT) as client:
      # Create a queue to hold the stream open after the request is sent.
      # Experimentation demonstrated that not doing this results in
      # asyncio.CancelledError due to server-side cancellation.
      request_queue = asyncio.Queue()
      request_queue.put_nowait({
          "model_name": model_name,
          "inputs": inputs,
          "model_version": model_version,
          "parameters": parameters,
      })

      result_iterator = client.stream_infer(
          _queue_request_generator(request_queue)
      )

      extract = []
      try:
        async for result, error in result_iterator:
          # Enqueue signal for the input stream to close
          request_queue.put_nowait(None)
          if error:
            raise error

          extract.append(result)

          # Since there is only one request, the loop only runs once.

        if not extract:
          raise RuntimeError(
              "Stream from model server closed without yielding a result."
          )
      finally:
        # Ensure the input stream gets the close signal in any error case.
        # If None has already been enqueued, enqueuing it again has no
        # consequence.
        request_queue.put_nowait(None)

      return extract[0]

  @override
  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
      parameters: Mapping[str, Any] | None = None,
  ) -> Mapping[str, np.ndarray]:
    """Runs a model on the given input and returns multiple outputs.

    Each call drives `_call` to completion with `asyncio.run`, so it blocks
    until the model server has answered.

    Args:
      model_input: An array or map of arrays comprising the input tensors for
        the model. A bare array is keyed by "inputs".
      model_name: The name of the model to run.
      model_version: The version of the model to run. Uses default if None.
      model_output_keys: The desired model output keys.
      parameters: Additional parameters to pass to the model.

    Returns:
      A mapping of model output keys to tensors.

    Raises:
      KeyError: If any of the model_output_keys are not found in the model
        output.
      RuntimeError: If the request is cancelled, or the stream closes without
        yielding a result.
    """
    # If a bare np.ndarray was passed, it will be passed using the default
    # input key "inputs".
    # If a Mapping was passed, use the keys from that mapping.
    if isinstance(model_input, np.ndarray):
      logging.debug("Handling bare input tensor.")
      input_map = {"inputs": model_input}
    else:
      input_map = model_input

    # The request carries the version as a string; empty means unspecified.
    model_version = str(model_version) if model_version is not None else ""

    # Wrap every numpy array in a Triton input tensor of matching shape/dtype.
    inputs = []
    for key, data in input_map.items():
      input_tensor = triton_grpc.InferInput(
          key, data.shape, triton_utils.np_to_triton_dtype(data.dtype)
      )
      input_tensor.set_data_from_numpy(data)
      inputs.append(input_tensor)

    try:
      result = asyncio.run(
          self._call(model_name, inputs, model_version, parameters)
      )
    except asyncio.exceptions.CancelledError as e:
      raise RuntimeError(
          "Model server request was cancelled. This may be due to the server"
          " shutting down or an internal asyncio error."
      ) from e

    assert result is not None  # infer never returns None, despite annotation.

    # A key whose lookup yields None is treated as absent from the output.
    outputs = {key: result.as_numpy(key) for key in model_output_keys}
    missing_keys = {key for key in model_output_keys if outputs[key] is None}
    if missing_keys:
      raise KeyError(
          f"Model output keys {missing_keys} not found in model output."
      )
    return outputs
156 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/server_gunicorn_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import http
16 | import io
17 | import os
18 | import subprocess
19 | from unittest import mock
20 |
21 | from absl.testing import absltest
22 | from serving.serving_framework import server_gunicorn
23 |
24 |
class DummyHealthCheck:
  """Health check stand-in that always reports a preconfigured result."""

  def __init__(self, check_result: bool):
    """Stores the value every later `check_health` call should return."""
    self._check_result = check_result

  def check_health(self):
    """Returns the result supplied at construction time."""
    outcome = self._check_result
    return outcome
32 |
33 |
class ServerGunicornTest(absltest.TestCase):
  """Tests for PredictionApplication and SubprocessPredictionExecutor."""

  def setUp(self):
    super().setUp()
    # Patch the route variables only for the duration of each test so the
    # fake values do not leak into the rest of the process; the original
    # environment is restored on cleanup.
    env_patcher = mock.patch.dict(
        os.environ,
        {
            "AIP_PREDICT_ROUTE": "/fake-predict-route",
            "AIP_HEALTH_ROUTE": "/fake-health-route",
        },
    )
    env_patcher.start()
    self.addCleanup(env_patcher.stop)

  def test_application_option_default(self):
    """Tests that the worker count is 1 when no options are given."""
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(executor, health_check=None)

    self.assertEqual(app.cfg.workers, 1)

  def test_application_option_setting(self):
    """Tests that a workers option is reflected in the app config."""
    options = {
        "workers": 3,
    }
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor, options=options, health_check=None
    )

    self.assertEqual(app.cfg.workers, 3)

  def test_health_route_no_check(self):
    """Tests that the health route answers "ok" without a health check."""
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertEqual(response.text, "ok")

  def test_predict_route_no_json(self):
    """Tests that a non-JSON body is rejected before reaching the executor."""
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )
    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()

    response = service.post("/fake-predict-route", data="invalid")

    executor.start.assert_called_once()
    executor.execute.assert_not_called()
    self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST)
    self.assertDictEqual({"error": "No JSON body."}, response.get_json())

  def test_predict_route(self):
    """Tests that a JSON request is executed and its output returned."""
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )
    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()
    executor.execute.return_value = {"placeholder": "output"}

    response = service.post(
        "/fake-predict-route", json={"instances": [{"filler": "filler"}]}
    )

    executor.start.assert_called_once()
    executor.execute.assert_called_once_with(
        {"instances": [{"filler": "filler"}]}
    )
    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertDictEqual({"placeholder": "output"}, response.get_json())

  def test_subprocess_executor_execute(self):
    """Tests the request/response round trip over the subprocess pipes."""
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()
      mock_popen.assert_called_once_with(
          args=["fake_command"],
          stdout=subprocess.PIPE,
          stdin=subprocess.PIPE,
      )
      mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n')
      mock_process.stdin = io.BytesIO()

      response = executor.execute({"meaningless": "filler"})

      self.assertEqual(
          b'{"meaningless": "filler"}\n', mock_process.stdin.getvalue()
      )
      self.assertDictEqual({"placeholder": "output"}, response)

  def test_subprocess_executor_execute_error_output_closed(self):
    """Tests that a closed output stream raises and restarts the process."""
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()

      mock_process.stdout = io.BytesIO()  # empty output simulates closed pipe.
      mock_process.stdin = io.BytesIO()

      with self.assertRaises(RuntimeError) as raised:
        executor.execute({"meaningless": "filler"})
      self.assertEqual(
          raised.exception.args[0], "Executor process output stream closed."
      )
      self.assertEqual(mock_popen.call_count, 2)  # executor restarted.

  def test_subprocess_executor_execute_error_input_broken(self):
    """Tests that a broken input pipe raises and restarts the process."""
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()

      mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n')
      # Simulate broken pipe.
      mock_process.stdin = mock.create_autospec(io.BytesIO, instance=True)
      mock_process.stdin.write.side_effect = BrokenPipeError

      with self.assertRaises(RuntimeError):
        executor.execute({"meaningless": "filler"})
      self.assertEqual(mock_popen.call_count, 2)  # executor restarted.
178 |
179 |
# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
  absltest.main()
182 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/triton/triton_server_model_runner_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from unittest import mock
16 |
17 | import numpy as np
18 | from tritonclient import grpc as triton_grpc
19 |
20 | from absl.testing import absltest
21 | from serving.serving_framework.triton import triton_server_model_runner
22 | from tritonclient.grpc import service_pb2
23 |
24 |
class TritonServerModelRunnerTest(absltest.TestCase):
  """Exercises TritonServerModelRunner against a mocked Triton gRPC client."""

  def setUp(self):
    super().setUp()

    self._client = mock.create_autospec(
        triton_grpc.InferenceServerClient, instance=True
    )
    self._client.infer = mock.MagicMock()
    self._runner = triton_server_model_runner.TritonServerModelRunner(
        client=self._client
    )

    # Canned server response carrying three output tensors.
    tensor_type = service_pb2.ModelInferResponse.InferOutputTensor
    output_a = tensor_type(name="output_a", shape=[2, 2], datatype="FP32")
    output_b = tensor_type(name="output_b", shape=[1, 2], datatype="INT64")
    output_c = tensor_type(name="output_c", shape=[3, 2], datatype="FP32")
    self._output_proto = service_pb2.ModelInferResponse(
        model_name="test_model",
        model_version="1",
        outputs=[output_a, output_b, output_c],
        raw_output_contents=[
            np.ones((2, 2), dtype=np.float32).tobytes(),
            np.array([[7] * 2], dtype=np.int64).tobytes(),
            np.zeros((3, 2), dtype=np.float32).tobytes(),
        ],
    )

  def _stub_infer_result(self):
    """Makes the mocked client's infer call return the canned response."""
    self._client.infer.return_value = triton_grpc.InferResult(
        self._output_proto
    )

  def _map_input(self):
    """Builds the two-tensor input map shared by several tests."""
    return {
        "a": np.array([[0.5] * 3] * 3, dtype=np.float32),
        "b": np.array([[2] * 2] * 3, dtype=np.int64),
    }

  def _bare_input(self):
    """Builds the single bare input tensor shared by several tests."""
    return np.array([[0.5] * 3] * 3, dtype=np.float32)

  def test_run_map_check_input(self):
    """Tests that an input map is passed to the model correctly."""
    self._stub_infer_result()

    _ = self._runner.run_model_multiple_output(
        self._map_input(),
        model_name="test_model",
        model_output_keys={"output_a", "output_b", "output_c"},
    )

    self.assertLen(self._client.infer.call_args_list, 1)
    positional = self._client.infer.call_args[0]
    self.assertEqual(
        positional[0],
        "test_model",
        "Model name passed to model does not match expectation.",
    )
    self.assertEqual(
        positional[2],
        "",
        "Model version passed to model does not match expectation.",
    )
    sent_tensors = positional[1]
    self.assertEqual(
        sent_tensors[0]._get_tensor(),
        service_pb2.ModelInferRequest.InferInputTensor(
            name="a",
            shape=[3, 3],
            datatype="FP32",
        ),
        msg="Input tensor passed to model does not match expectation.",
    )
    self.assertEqual(
        sent_tensors[1]._get_tensor(),
        service_pb2.ModelInferRequest.InferInputTensor(
            name="b",
            shape=[3, 2],
            datatype="INT64",
        ),
        msg="Input tensor passed to model does not match expectation.",
    )

  def test_run_bare_check_input(self):
    """Tests the handling of a bare input tensor passed to the model."""
    self._stub_infer_result()

    _ = self._runner.run_model_multiple_output(
        self._bare_input(),
        model_name="test_model",
        model_output_keys={"output_a", "output_b", "output_c"},
    )

    self.assertLen(self._client.infer.call_args_list, 1)
    positional = self._client.infer.call_args[0]
    self.assertEqual(
        positional[0],
        "test_model",
        "Model name passed to model does not match expectation.",
    )
    self.assertEqual(
        positional[2],
        "",
        "Model version passed to model does not match expectation.",
    )
    self.assertEqual(
        positional[1][0]._get_tensor(),
        service_pb2.ModelInferRequest.InferInputTensor(
            name="inputs",
            shape=[3, 3],
            datatype="FP32",
        ),
        msg="Input tensor passed to model does not match expectation.",
    )

  def test_input_model_version(self):
    """Tests that the model version is passed to the model correctly."""
    self._stub_infer_result()

    _ = self._runner.run_model_multiple_output(
        self._bare_input(),
        model_name="test_model",
        model_version=3,
        model_output_keys={"output_a", "output_b", "output_c"},
    )

    sent_version = self._client.infer.call_args[0][2]
    self.assertEqual(
        sent_version,
        "3",
        "Model version passed to model does not match expectation.",
    )

  def test_run_check_output(self):
    """Tests that the output is returned correctly."""
    self._stub_infer_result()

    result = self._runner.run_model_multiple_output(
        self._map_input(),
        model_name="test_model",
        model_output_keys={"output_a", "output_b"},
    )

    np.testing.assert_array_equal(
        result["output_a"],
        np.ones((2, 2), dtype=np.float32),
        "Output tensor passed to model does not match expectation.",
    )
    np.testing.assert_array_equal(
        result["output_b"],
        np.array([[7] * 2], dtype=np.int64),
        "Output tensor passed to model does not match expectation.",
    )
    self.assertLen(result, 2)

  def test_run_check_missing_key(self):
    """Tests that requesting an absent output key raises KeyError."""
    self._stub_infer_result()

    with self.assertRaises(KeyError):
      _ = self._runner.run_model_multiple_output(
          self._map_input(),
          model_name="test_model",
          model_output_keys={"output_a", "output_d"},
      )
209 |
# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
  absltest.main()
212 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/flags/secret_flag_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Util to get env value from secret manager."""
16 |
17 | import json
18 | import os
19 | import re
20 | import sys
21 | import threading
22 | from typing import Any, Mapping, TypeVar, Union
23 |
24 | import cachetools
25 | import google.api_core
26 | from google.cloud import secretmanager
27 |
28 | from serving.logging_lib.flags import flag_utils
29 |
# Type of the caller-supplied default returned when a value is undefined.
_T = TypeVar('_T')

# Name of the environment variable that holds the secret path to read.
_SECRET_MANAGER_ENV_CONFIG = 'SECRET_MANAGER_ENV_CONFIG'

# Parses `projects/<project>/secrets/<secret>`, optionally followed by `/` or
# `/versions/<version>`; any leading text before `projects/` is tolerated.
# Group 1 is the project, group 2 the secret, and the last group the version
# (None when no version segment is present). Matching ignores case.
_PARSE_SECRET_CONFIG = re.compile(
    r'.*?projects/([^/]+?)/secrets/([^/]+?)($|(/|(/versions/([^/]+))))',
    re.IGNORECASE,
)

# Enable secret manager if not debugging: it is disabled when the
# UNITTEST_ON_FORGE environment variable is set or the `unittest` module has
# already been imported.
_ENABLE_ENV_SECRET_MANAGER = bool(
    'UNITTEST_ON_FORGE' not in os.environ and 'unittest' not in sys.modules
)


# Cache secret metadata to avoid repeated reads when initializing flags.
_cache = cachetools.LRUCache(maxsize=1)
_cache_lock = threading.Lock()
48 |
49 |
class SecretDecodeError(Exception):
  """Error raised when a secret cannot be located, read, or decoded.

  Only `msg` becomes the exception message; the secret name and the raw data
  are retained on private attributes.
  """

  def __init__(self, msg: str, secret_name: str = '', data: str = ''):
    """Records the message together with the offending secret's details.

    Args:
      msg: Human readable description of the failure.
      secret_name: Name of the secret involved, if known.
      data: Raw secret data that failed to decode, if any.
    """
    self._secret_name, self._data = secret_name, data
    super().__init__(msg)
56 |
57 |
def _init_fork_module_state() -> None:
  """Replaces the module-level secret cache and its lock with fresh objects."""
  global _cache, _cache_lock
  _cache_lock = threading.Lock()
  _cache = cachetools.LRUCache(maxsize=1)
63 |
64 |
def _get_secret_version(
    client: secretmanager.SecretManagerServiceClient, parent: str
) -> str:
  """Returns the greatest version number for a secret.

  Listed versions whose name cannot be parsed, or whose version segment is
  missing or not an integer, are ignored.

  Args:
    client: SecretManagerServiceClient.
    parent: String defining project and name of secret to return version.

  Returns:
    greatest version number for secret.

  Raises:
    SecretDecodeError: Could not identify version for secret.
  """
  versions_found = []
  for secret_version in client.list_secret_versions(
      request={'parent': parent}
  ):
    match = _PARSE_SECRET_CONFIG.fullmatch(secret_version.name)
    if match is None:
      continue
    try:
      versions_found.append(int(match.groups()[-1]))
    except (TypeError, ValueError):
      # TypeError: the name matched without a version segment, so the group
      # is None. ValueError: the version segment is not an integer.
      continue
  if not versions_found:
    raise SecretDecodeError(
        f'Could not find version for secret {parent}.', secret_name=parent
    )
  return str(max(versions_found))
95 |
96 |
def _read_secrets(secret_name: str) -> Mapping[str, Any]:
  """Reads and decodes the JSON mapping stored in a secret.

  The result of a successful read is cached under `secret_name`; the cache is
  consulted, and the read performed, while holding the module cache lock.

  Args:
    secret_name: Name of secret.

  Returns:
    The decoded mapping. Empty when `secret_name` is empty, when the secret
    manager is disabled, or when the secret holds no data.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
  """
  if not secret_name:
    return {}
  parsed = _PARSE_SECRET_CONFIG.fullmatch(secret_name)
  if parsed is None:
    raise SecretDecodeError(
        'incorrectly formatted secret; expecting'
        f' [projects/.+/secrets/.+/versions/.+; passed {secret_name}.',
        secret_name=secret_name,
    )
  # The name is validated even when the secret manager is disabled.
  if not _ENABLE_ENV_SECRET_MANAGER:
    return {}
  captured = parsed.groups()
  project, secret_id, version = captured[0], captured[1], captured[-1]
  with _cache_lock:
    cached = _cache.get(secret_name)
    if cached is not None:
      return cached
    with secretmanager.SecretManagerServiceClient() as client:
      parent = client.secret_path(project, secret_id)
      try:
        if not version:
          # No version given in the name: use the greatest listed version.
          version = _get_secret_version(client, parent)
        response = client.access_secret_version(
            request={'name': f'{parent}/versions/{version}'}
        )
      except google.api_core.exceptions.NotFound as exp:
        raise SecretDecodeError(
            'Secret not found.', secret_name=secret_name
        ) from exp
      except google.api_core.exceptions.PermissionDenied as exp:
        raise SecretDecodeError(
            'Permission denied reading secret.', secret_name=secret_name
        ) from exp
      data = response.payload.data
    if not data:
      return {}
    if isinstance(data, bytes):
      data = data.decode('utf-8')
    try:
      decoded = json.loads(data)
    except json.JSONDecodeError as exp:
      raise SecretDecodeError(
          'Could not decode secret value.', secret_name=secret_name, data=data
      ) from exp
    if not isinstance(decoded, Mapping):
      raise SecretDecodeError(
          'Secret value does not define a mapping.',
          secret_name=secret_name,
          data=data,
      )
    _cache[secret_name] = decoded
    return decoded
160 |
161 |
def get_secret_or_env(name: str, default: _T) -> Union[str, _T]:
  """Returns value defined in secret manager, env, or if undefined default.

  Searches first for the variable in the JSON dict stored in the GCP secret
  manager. The secret holding that dict is named by the environment variable
  _SECRET_MANAGER_ENV_CONFIG. If the variable is not found in the secret, or
  if _SECRET_MANAGER_ENV_CONFIG is unset or empty, the process environment is
  searched instead. If neither defines the variable, the default value is
  returned.

  Args:
    name: Name of ENV.
    default: Default value to return if name is not defined.

  Returns:
    The secret's value converted to a string, otherwise the environment value,
    otherwise `default`.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
  """
  configured_secret = os.environ.get(_SECRET_MANAGER_ENV_CONFIG)
  if configured_secret:
    secret_value = _read_secrets(configured_secret).get(name)
    if secret_value is not None:
      return str(secret_value)
  return os.environ.get(name, default)
189 |
190 |
def get_bool_secret_or_env(
    env_name: str, undefined_value: bool = False
) -> bool:
  """Returns the boolean value of a secret or environment variable.

  The variable is looked up with `get_secret_or_env`, using the string form of
  `undefined_value` as the default, so the lookup always yields a string to
  parse.

  Args:
    env_name: Environmental variable name.
    undefined_value: Default value to set undefined values to.

  Returns:
    Boolean of environmental variable string value.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
    ValueError: Environmental variable cannot be parsed to bool.
  """
  # The default is a string, so the lookup never returns None; parse the
  # result unconditionally so every path returns a bool.
  value = get_secret_or_env(env_name, str(undefined_value))
  return flag_utils.str_to_bool(value)
210 |
211 |
# Interfaces may be used from processes which are forked (gunicorn,
# DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not
# copy threads running within parent processes or re-initialize global/module
# state. This can result in forked modules being executed with invalid global
# state, e.g., acquired locks that will not release or references to invalid
# state.
# Initialize once at import, then again in every forked child.
_init_fork_module_state()
os.register_at_fork(after_in_child=_init_fork_module_state)
220 |
--------------------------------------------------------------------------------
/python/serving/serving_framework/server_gunicorn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Gunicorn application for passing requests through to the executor command.
16 |
17 | Provides a thin, subject-agnostic request server for Vertex endpoints which
18 | handles requests by piping their JSON bodies to the given executor command
19 | and returning the json output.
20 | """
21 |
22 | import abc
23 | from collections.abc import Mapping, Sequence
24 | import http
25 | import json
26 | import os
27 | import subprocess
28 | from typing import Any, Optional
29 |
30 | from absl import logging
31 | import flask
32 | from gunicorn.app import base as gunicorn_base
33 | import jsonschema
34 | from typing_extensions import override
35 |
36 |
# Key of the request JSON body under which the prediction instances are found.
INSTANCES_KEY = "instances"
38 |
39 |
class PredictionExecutor(abc.ABC):
  """Abstract interface for whatever serves a prediction request."""

  def start(self) -> None:
    """Prepares the executor for use; does nothing by default.

    Called after the Gunicorn worker process has started, so subclasses can
    override it to perform any setup which needs to be done post-fork.
    """
    return None

  @abc.abstractmethod
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Handles one request payload and returns the response payload."""
53 |
54 |
class SubprocessPredictionExecutor(PredictionExecutor):
  """Provides prediction request execution using a persistent worker subprocess."""

  # Seconds to wait for a terminated worker to exit before killing it.
  _TERMINATE_TIMEOUT_SECONDS = 5.0

  def __init__(self, executor_command: Sequence[str]):
    """Initializes the executor with a command to start the subprocess."""
    self._executor_command = executor_command
    self._executor_process = None

  def _restart(self) -> None:
    """Stops the current worker process and starts a replacement.

    The old process is waited on after termination so that it is reaped
    instead of lingering as a defunct child; if it does not exit within the
    timeout it is killed.

    Raises:
      RuntimeError: The executor process was never started.
    """
    if self._executor_process is None:
      raise RuntimeError("Executor process not started.")

    old_process = self._executor_process
    old_process.terminate()
    try:
      old_process.wait(timeout=self._TERMINATE_TIMEOUT_SECONDS)
    except subprocess.TimeoutExpired:
      old_process.kill()
      old_process.wait()
    self.start()

  @override
  def start(self):
    """Starts the executor process with piped stdin and stdout."""
    self._executor_process = subprocess.Popen(
        args=self._executor_command,
        stdout=subprocess.PIPE,
        stdin=subprocess.PIPE,
    )

  @override
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Uses the executor process to execute a request.

    The request is written to the process as a single line of JSON and the
    response is read back as a single line of JSON. When either pipe turns
    out to be closed, the process is restarted before the error is raised.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      The json response to the prediction request.

    Raises:
      RuntimeError: Executor is not started or error communicating with the
        subprocess.
    """
    if self._executor_process is None:
      raise RuntimeError("Executor process not started.")

    # Ensure json string is safe to pass through the pipe protocol.
    input_str = json.dumps(input_json).replace("\n", "")

    try:
      self._executor_process.stdin.write(input_str.encode("utf-8") + b"\n")
      self._executor_process.stdin.flush()
    except BrokenPipeError as e:
      self._restart()
      raise RuntimeError("Executor process input stream closed.") from e
    exec_result = self._executor_process.stdout.readline()
    if not exec_result:
      # An empty read means the process closed its output stream.
      self._restart()
      raise RuntimeError("Executor process output stream closed.")
    try:
      return json.loads(exec_result)
    except json.JSONDecodeError as e:
      raise RuntimeError("Executor process output not valid json.") from e
113 |
114 |
class ModelServerHealthCheck(abc.ABC):
  """Checks the health of the local model server."""

  @abc.abstractmethod
  def check_health(self) -> bool:
    """Check the health of the local model server immediately.

    Returns:
      True if the model server is healthy, False otherwise.
    """
121 |
122 |
class _PredictRoute:
  """A callable for handling a prediction route.

  Reads the json body of the current Flask request, optionally validates it,
  dispatches it to the executor and optionally validates the predictions in
  the response.
  """

  def __init__(
      self,
      executor: PredictionExecutor,
      *,
      input_validator: jsonschema.Draft202012Validator | None = None,
      instance_input: bool = True,
      prediction_validator: jsonschema.Draft202012Validator | None = None,
  ):
    """Initializes the route handler.

    Args:
      executor: The executor which handles the request payload.
      input_validator: Validator applied to each instance when instance_input
        is True, otherwise to the whole request body. No input validation is
        done if None.
      instance_input: Whether the request body must be an object containing an
        instances field.
      prediction_validator: Validator applied to each entry of the
        "predictions" field of the executor result, if that field is present.
        No response validation is done if None.
    """
    self._executor = executor
    self._input_validator = input_validator
    self._instance_input = instance_input
    self._prediction_validator = prediction_validator

  def __call__(self) -> tuple[dict[str, Any], int]:
    """Handles the current Flask request.

    Returns:
      A (response body, HTTP status code) pair.
    """
    logging.info("predict route hit")
    json_body = flask.request.get_json(silent=True)
    if json_body is None:
      return {"error": "No JSON body."}, http.HTTPStatus.BAD_REQUEST.value
    # A valid json body is not necessarily an object (e.g. `5` or `[1]`), so
    # check the type before looking for the instances field.
    if self._instance_input and (
        not isinstance(json_body, dict) or INSTANCES_KEY not in json_body
    ):
      return {
          "error": "No instances field in request."
      }, http.HTTPStatus.BAD_REQUEST.value

    if self._input_validator is not None:
      try:
        if self._instance_input:
          instances = json_body[INSTANCES_KEY]
          if not isinstance(instances, list):
            # Per-instance validation requires a list to iterate over.
            logging.warning("Input validation failed")
            return {
                "error": "Instances field must be a list."
            }, http.HTTPStatus.BAD_REQUEST.value
          for instance in instances:
            self._input_validator.validate(instance)
        else:
          self._input_validator.validate(json_body)
      except jsonschema.exceptions.ValidationError as e:
        logging.warning("Input validation failed")
        return {"error": str(e)}, http.HTTPStatus.BAD_REQUEST.value

    logging.debug("Dispatching request to executor.")
    try:
      exec_result = self._executor.execute(json_body)
      logging.debug("Executor returned results.")
      if self._prediction_validator is not None:
        try:
          if "predictions" in exec_result:
            for result in exec_result["predictions"]:
              self._prediction_validator.validate(result)
        except jsonschema.exceptions.ValidationError:
          # Details are logged but not returned: an invalid response is a
          # server-side fault, not a client error.
          logging.exception("Response validation failed")
          return {
              "error": "Internal server error."
          }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value

      return (exec_result, http.HTTPStatus.OK.value)
    except RuntimeError:
      logging.exception("Internal error handling request: Executor failed.")
      return {
          "error": "Internal server error."
      }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value
181 |
182 |
def _create_app(
    executor: PredictionExecutor,
    health_check: ModelServerHealthCheck | None,
    *,
    input_validator: jsonschema.Draft202012Validator | None = None,
    instance_input: bool = True,
    prediction_validator: jsonschema.Draft202012Validator | None = None,
    additional_routes: Mapping[str, PredictionExecutor] | None = None,
) -> flask.Flask:
  """Builds the Flask app serving the health, predict and additional routes.

  The health and predict paths are read from the AIP_HEALTH_ROUTE and
  AIP_PREDICT_ROUTE environment variables.

  Args:
    executor: Executor backing the main predict route.
    health_check: Health check consulted by the health route. If None the
      health route always reports ok.
    input_validator: Input validator for the main predict route.
    instance_input: Whether the main predict route expects an instances field.
    prediction_validator: Prediction validator for the main predict route.
    additional_routes: Extra POST routes keyed by path, each served by its own
      executor without validation.

  Returns:
    The configured Flask application.

  Raises:
    ValueError: If either route environment variable is not set.
  """
  app = flask.Flask(__name__)

  required_env_vars = ("AIP_HEALTH_ROUTE", "AIP_PREDICT_ROUTE")
  if not all(name in os.environ for name in required_env_vars):
    raise ValueError(
        "Both of the environment variables AIP_HEALTH_ROUTE and "
        "AIP_PREDICT_ROUTE need to be specified."
    )

  def health_route() -> tuple[str, int]:
    logging.info("health route hit")
    if health_check is None or health_check.check_health():
      return "ok", http.HTTPStatus.OK.value
    return "not ok", http.HTTPStatus.SERVICE_UNAVAILABLE.value

  health_path = os.environ["AIP_HEALTH_ROUTE"]
  logging.info("health path: %s", health_path)
  app.add_url_rule(health_path, view_func=health_route)

  predict_path = os.environ["AIP_PREDICT_ROUTE"]
  logging.info("predict route: %s", predict_path)
  predict_view = _PredictRoute(
      executor,
      input_validator=input_validator,
      instance_input=instance_input,
      prediction_validator=prediction_validator,
  )
  app.add_url_rule(
      rule=predict_path,
      endpoint=predict_path,
      view_func=predict_view,
      methods=["POST"],
  )

  # Additional routes take the raw request body and skip validation.
  for extra_path, extra_executor in (additional_routes or {}).items():
    logging.info("additional route: %s", extra_path)
    extra_view = _PredictRoute(extra_executor, instance_input=False)
    app.add_url_rule(
        rule=extra_path,
        endpoint=extra_path,
        view_func=extra_view,
        methods=["POST"],
    )

  app.config["TRAP_BAD_REQUEST_ERRORS"] = True

  return app
244 |
245 |
class PredictionApplication(gunicorn_base.BaseApplication):
  """Gunicorn application which serves predictors on Vertex endpoints."""

  def __init__(
      self,
      executor: PredictionExecutor,
      *,
      health_check: ModelServerHealthCheck | None,
      options: Optional[Mapping[str, Any]] = None,
      input_validator: jsonschema.Draft202012Validator | None = None,
      instance_input: bool = True,
      prediction_validator: jsonschema.Draft202012Validator | None = None,
      additional_routes: Mapping[str, PredictionExecutor] | None = None,
  ):
    """Sets up the Flask app and the gunicorn options.

    Args:
      executor: Executor which handles prediction requests.
      health_check: Health check backing the health check route.
      options: Gunicorn application options. The mapping is copied, and
        "preload_app" is always forced to False.
      input_validator: A jsonschema validator applied to the input instances.
      instance_input: Whether the input contains a list of instances which can
        be validated one by one if a validator is given. If False, no request
        structure is assumed and the validator is applied to the whole request
        body.
      prediction_validator: A jsonschema validator applied to the predictions.
      additional_routes: Extra routes to serve, keyed by route path. No
        validation is applied to these routes.
    """
    # Work on a copy so that the caller's mapping is never modified.
    gunicorn_options = dict(options or {})
    gunicorn_options["preload_app"] = False
    self.options = gunicorn_options
    self._executor = executor
    self.application = _create_app(
        executor,
        health_check,
        input_validator=input_validator,
        instance_input=instance_input,
        prediction_validator=prediction_validator,
        additional_routes=additional_routes,
    )
    # Options and application are assigned before the base class initializer
    # runs, as in the original ordering.
    super().__init__()

  def load_config(self):
    """Applies every recognized, non-None option to the gunicorn config."""
    for key, value in self.options.items():
      if value is not None and key in self.cfg.settings:
        self.cfg.set(key.lower(), value)

  def load(self) -> flask.Flask:
    """Starts the executor and returns the Flask application to serve."""
    self._executor.start()
    return self.application
301 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/cloud_logging_client.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Wrapper for cloud ops structured logging."""
16 | from __future__ import annotations
17 |
18 | import os
19 | import sys
20 | import threading
21 | from typing import Any, Mapping, Optional, Union
22 |
23 | from absl import flags
24 | from absl import logging
25 | import google.auth
26 | import psutil
27 |
28 | from serving.logging_lib.flags import secret_flag_utils
29 | from serving.logging_lib import cloud_logging_client_instance
30 |
# Module flags. Each default is resolved through secret_flag_utils using the
# name given as its first argument.

# Name of the Cloud Ops log to write to.
CLOUD_OPS_LOG_NAME_FLG = flags.DEFINE_string(
    'ops_log_name',
    secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_NAME', 'python'),
    'Cloud ops log name to write logs to.',
)
CLOUD_OPS_LOG_PROJECT_FLG = flags.DEFINE_string(
    'ops_log_project',
    secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_PROJECT', None),
    'GCP project name to write log to. Undefined = default',
)
POD_HOSTNAME_FLG = flags.DEFINE_string(
    'pod_hostname',
    secret_flag_utils.get_secret_or_env('HOSTNAME', None),
    'Host name of GKE pod. Set by container ENV. '
    'Set to mock value in unit test.',
)
POD_UID_FLG = flags.DEFINE_string(
    'pod_uid',
    secret_flag_utils.get_secret_or_env('MY_POD_UID', None),
    'UID of GKE pod. Do not set unless in test.',
)
ENABLE_STRUCTURED_LOGGING_FLG = flags.DEFINE_boolean(
    'enable_structured_logging',
    # The fallback passed for ENABLE_STRUCTURED_LOGGING is the negation of
    # DISABLE_STRUCTURED_LOGGING.
    secret_flag_utils.get_bool_secret_or_env(
        'ENABLE_STRUCTURED_LOGGING',
        not secret_flag_utils.get_bool_secret_or_env(
            'DISABLE_STRUCTURED_LOGGING'
        ),
    ),
    'Enable structured logging.',
)
ENABLE_LOGGING_FLG = flags.DEFINE_boolean(
    'enable_logging',
    secret_flag_utils.get_bool_secret_or_env('ENABLE_LOGGING_FLG', True),
    'Enable logging.',
)

_DEBUG_LOGGING_USE_ABSL_LOGGING_FLG = flags.DEFINE_boolean(
    'debug_logging_use_absl_logging',
    # The double negation lets the external ENABLE_CLOUD_LOGGING setting be
    # phrased as a positive statement while still overriding the existing
    # default.
    not secret_flag_utils.get_bool_secret_or_env(
        'ENABLE_CLOUD_LOGGING',
        not cloud_logging_client_instance.DEBUG_LOGGING_USE_ABSL_LOGGING,
    ),
    'Debug/testing option to logs to absl.logger. Automatically set when '
    'running unit tests.',
)

LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG = flags.DEFINE_boolean(
    'log_all_python_logs_to_cloud',
    secret_flag_utils.get_bool_secret_or_env('LOG_ALL_PYTHON_LOGS_TO_CLOUD'),
    'Logs every modules log to Cloud Ops.',
)

PER_THREAD_LOG_SIGNATURES_FLG = flags.DEFINE_boolean(
    'per_thread_log_signatures',
    secret_flag_utils.get_bool_secret_or_env('PER_THREAD_LOG_SIGNATURES', True),
    'If True Log signatures are not shared are across threads if false '
    'Process threads share a common log signature',
)
93 |
94 |
def _are_flags_initialized() -> bool:
  """Reports whether the value of a module flag can be read.

  Returns:
    False if reading the flag raises UnparsedFlagAccessError or
    AttributeError, True otherwise.
  """
  try:
    _ = CLOUD_OPS_LOG_PROJECT_FLG.value
  except (flags.UnparsedFlagAccessError, AttributeError):
    initialized = False
  else:
    initialized = True
  return initialized
102 |
103 |
def _get_flags() -> Mapping[str, str]:
  """Collects the values of all registered flags.

  Returns:
    Mapping of flag name to flag value. Names of flags whose value cannot be
    read because they are unparsed are joined into a comma separated string
    under the 'unparsed_flags' key.
  """
  flag_values = {}
  unreadable_names = []
  for name in flags.FLAGS:
    try:
      value = flags.FLAGS.__getattr__(name)
    except flags.UnparsedFlagAccessError:
      unreadable_names.append(name)
    else:
      flag_values[name] = value
  if unreadable_names:
    flag_values['unparsed_flags'] = ', '.join(unreadable_names)
  return flag_values
115 |
116 |
def _default_gcp_project() -> str:
  """Returns the project reported by google.auth.default().

  Returns:
    The project, or an empty string if none is reported or the default
    credentials cannot be determined.
  """
  scopes = ['https://www.googleapis.com/auth/cloud-platform']
  try:
    _, project = google.auth.default(scopes=scopes)
  except google.auth.exceptions.DefaultCredentialsError:
    return ''
  if project is None:
    return ''
  return project
125 |
126 |
class CloudLoggingClient(
    cloud_logging_client_instance.CloudLoggingClientInstance
):
  """Singleton wrapper for cloud ops structured logging.

  Automatically adds signature to structured logs to make traceable. The
  instance is configured from the module's flags and is obtained through
  `CloudLoggingClient.logger()`.
  """

  # The lock makes access to the singleton safe across threads. Logging is
  # used in the main thread and ack_timeout_mon.
  _singleton_instance: Optional[CloudLoggingClient] = None
  _startup_message_logged = False
  _singleton_lock = threading.RLock()

  @classmethod
  def _init_fork_module_state(cls) -> None:
    """Resets the class state; registered to run in the child after a fork."""
    cls._singleton_lock = threading.RLock()
    cls._singleton_instance = None
    # Marked as already logged so that a forked child does not log the
    # startup message.
    cls._startup_message_logged = True

  @classmethod
  def _fork_shutdown(cls) -> None:
    """Drops the singleton; registered to run before a fork."""
    with cls._singleton_lock:
      cls._singleton_instance = None

  @classmethod
  def _set_absl_skip_frames(cls) -> None:
    """Sets absl logging attribution to skip over internal logging frames."""
    wrapper_names = (
        'debug',
        'timed_debug',
        'info',
        'warning',
        'error',
        'critical',
    )
    for wrapper_name in wrapper_names:
      logging.ABSLLogger.register_frame_to_skip(
          __file__,
          function_name=wrapper_name,
      )

  def __init__(self):
    """Creates the singleton instance from the module flags.

    Raises:
      cloud_logging_client_instance.CloudLoggerInstanceExceptionError: The
        singleton has already been initialized.
    """
    with CloudLoggingClient._singleton_lock:
      if not _are_flags_initialized():
        # Flags have not been initialized yet; initialize the logging flags.
        flags.FLAGS(sys.argv, known_only=True)
      if CloudLoggingClient._singleton_instance is not None:
        raise cloud_logging_client_instance.CloudLoggerInstanceExceptionError(
            'Singleton already initialized.'
        )
      CloudLoggingClient._set_absl_skip_frames()
      project = CLOUD_OPS_LOG_PROJECT_FLG.value
      if not (project and project.strip()):
        # No project configured; fall back to the default one.
        project = _default_gcp_project()
      hostname = POD_HOSTNAME_FLG.value
      if hostname is None:
        hostname = ''
      uid = POD_UID_FLG.value
      if uid is None:
        uid = ''
      super().__init__(
          log_name=CLOUD_OPS_LOG_NAME_FLG.value,
          gcp_project_to_write_logs_to=project,
          gcp_credentials=None,
          pod_hostname=hostname,
          pod_uid=uid,
          enable_structured_logging=ENABLE_STRUCTURED_LOGGING_FLG.value,
          use_absl_logging=_DEBUG_LOGGING_USE_ABSL_LOGGING_FLG.value,
          log_all_python_logs_to_cloud=LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG.value,
          per_thread_log_signatures=PER_THREAD_LOG_SIGNATURES_FLG.value,
          enabled=ENABLE_LOGGING_FLG.value,
      )
      CloudLoggingClient._singleton_instance = self

  def startup_msg(self) -> None:
    """Logs default messages after logger fully initialized.

    Does nothing when logging through absl or when the startup message has
    already been logged.
    """
    if self.use_absl_logging() or CloudLoggingClient._startup_message_logged:
      return
    CloudLoggingClient._startup_message_logged = True
    process_id = os.getpid()
    process_info = {
        'process_name': psutil.Process(process_id).name(),
        'process_id': process_id,
    }
    self.debug('Container process started.', process_info)
    self.debug(
        'Container environmental variables.', os.environ
    )  # pytype: disable=wrong-arg-types  # kwargs-checking
    memory = psutil.virtual_memory()
    compute_info = {
        'processors(count)': os.cpu_count(),
        'total_system_mem_(bytes)': memory.total,
        'available_system_mem_(bytes)': memory.available,
    }
    self.debug('Compute instance', compute_info)
    self.debug('Initalized flags', _get_flags())
    project_name = self.gcp_project_name or 'DEFAULT'
    self.debug(f'Logging to GCP project: {project_name}')

  @classmethod
  def logger(cls, show_startup_msg: bool = True) -> CloudLoggingClient:
    """Returns the singleton, creating it on first use.

    Args:
      show_startup_msg: Only consulted when the singleton is created by this
        call. If True the startup message is logged, otherwise it is marked as
        already logged.

    Returns:
      The singleton CloudLoggingClient.
    """
    if cls._singleton_instance is not None:
      return cls._singleton_instance  # pytype: disable=bad-return-type
    with cls._singleton_lock:  # Makes instance creation thread safe.
      # Re-check under the lock; another thread may have created the instance.
      if cls._singleton_instance is None:
        cls._singleton_instance = CloudLoggingClient()
        if show_startup_msg:
          cls._singleton_instance.startup_msg()  # pytype: disable=attribute-error
        else:
          cls._startup_message_logged = True
    return cls._singleton_instance  # pytype: disable=bad-return-type
249 |
250 |
def logger() -> CloudLoggingClient:
  """Returns the singleton CloudLoggingClient, creating it on first use."""
  client = CloudLoggingClient.logger()
  return client
253 |
254 |
def do_not_log_startup_msg() -> None:
  """Creates the logger singleton, if needed, without the startup message.

  Has no effect on the startup message if the singleton already exists.
  """
  _ = CloudLoggingClient.logger(show_startup_msg=False)
257 |
258 |
def debug(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes a debug severity log through the singleton logger.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.debug(msg, *struct, stack_frames_back=stack_frames_back + 1)
272 |
273 |
def timed_debug(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes a debug log with the elapsed time since the last timed debug log.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.timed_debug(msg, *struct, stack_frames_back=stack_frames_back + 1)
287 |
288 |
def info(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes an info severity log through the singleton logger.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.info(msg, *struct, stack_frames_back=stack_frames_back + 1)
302 |
303 |
def warning(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes a warning severity log through the singleton logger.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.warning(msg, *struct, stack_frames_back=stack_frames_back + 1)
317 |
318 |
def error(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes an error severity log through the singleton logger.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.error(msg, *struct, stack_frames_back=stack_frames_back + 1)
332 |
333 |
def critical(
    msg: str,
    *struct: Union[Mapping[str, Any], Exception, None],
    stack_frames_back: int = 0,
) -> None:
  """Writes a critical severity log through the singleton logger.

  Args:
    msg: Message to log (string).
    *struct: Zero or more dict or exception to log in structured log.
    stack_frames_back: Additional stack frames back to log source_location.
  """
  client = logger()
  # One extra frame is added for this wrapper function itself.
  client.critical(msg, *struct, stack_frames_back=stack_frames_back + 1)
347 |
348 |
def clear_log_signature() -> None:
  """Clears the log signature of the singleton logger."""
  client = logger()
  client.clear_log_signature()
351 |
352 |
def get_log_signature() -> Mapping[str, Any]:
  """Returns the log signature of the singleton logger."""
  client = logger()
  return client.log_signature
355 |
356 |
def set_log_signature(sig: Mapping[str, Any]) -> None:
  """Sets the log signature of the singleton logger."""
  client = logger()
  client.log_signature = sig
359 |
360 |
def set_per_thread_log_signatures(val: bool) -> None:
  """Sets per_thread_log_signatures on the singleton logger."""
  client = logger()
  client.per_thread_log_signatures = val
363 |
364 |
def get_build_version(clip_length: Optional[int] = None) -> str:
  """Returns the build version of the singleton logger.

  Args:
    clip_length: If not None and non-negative, the returned version is
      truncated to at most this many characters.

  Returns:
    The build version, optionally truncated.
  """
  version = logger().build_version
  if clip_length is None or clip_length < 0:
    return version
  return version[:clip_length]
369 |
370 |
def set_build_version(build_version: str) -> None:
  """Sets the build version on the singleton logger."""
  client = logger()
  client.build_version = build_version
373 |
374 |
def set_log_trace_key(key: str) -> None:
  """Assigns key to the trace_key attribute of the logger from logger().

  Args:
    key: New value for the logger's trace_key attribute.
  """
  client = logger()
  client.trace_key = key
377 |
378 |
# Logging interfaces are used from processes which are forked (gunicorn,
# DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not
# copy threads running within parent processes or re-initialize global/module
# state. This can result in forked modules being executed with invalid global
# state, e.g., acquired locks that will not release or references to invalid
# state. The cloud logging library utilizes a background thread transporting
# logs to cloud. The background threading is not compatible with forking and
# will seg-fault (python queue wait). This can be avoided by stopping the
# background transport prior to forking and then restarting the transport
# following the fork.
#
# Registration: _fork_shutdown runs in the parent before each fork, and
# _init_fork_module_state runs in the child after the fork.
os.register_at_fork(
    before=CloudLoggingClient._fork_shutdown,  # pylint: disable=protected-access
    after_in_child=CloudLoggingClient._init_fork_module_state,  # pylint: disable=protected-access
)
393 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/flags/secret_flag_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for secret_flag_utils."""
16 | import os
17 | from typing import Any, List
18 | from unittest import mock
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import google.api_core
23 | from google.cloud import secretmanager
24 |
25 | from serving.logging_lib.flags import secret_flag_utils
26 |
27 | _MOCK_SECRET = 'projects/test-project/secrets/test-secret'
28 |
29 |
def _create_mock_secret_value_response(
    data: Any,
) -> secretmanager.AccessSecretVersionResponse:
  """Builds an autospec'd AccessSecretVersionResponse carrying data.

  Args:
    data: Value to expose as response.payload.data.

  Returns:
    Mock response whose payload is an autospec'd SecretPayload holding data.
  """
  response = mock.create_autospec(
      secretmanager.AccessSecretVersionResponse, instance=True
  )
  response.payload = mock.create_autospec(
      secretmanager.SecretPayload, instance=True
  )
  response.payload.data = data
  return response
40 |
41 |
def _create_mock_secret_version_response(
    secret_version_prefix: str,
    version_numbers: List[str],
) -> list[secretmanager.ListSecretVersionsResponse]:
  """Builds one autospec'd version mock per entry in version_numbers.

  Args:
    secret_version_prefix: Text prepended to each version number to form the
      mock's name.
    version_numbers: Version identifiers, one mock is created for each.

  Returns:
    List of mocks, in input order, named prefix + version number.
  """

  def _named_version(number: str):
    mocked = mock.create_autospec(
        secretmanager.ListSecretVersionsResponse, instance=True
    )
    mocked.name = f'{secret_version_prefix}{number}'
    return mocked

  return [_named_version(number) for number in version_numbers]
54 |
55 |
class SecretFlagUtilsTest(parameterized.TestCase):
  """Tests for secret_flag_utils secret reading and environment fallbacks.

  Note on argument order: in tests that stack mock.patch decorators, the
  patch closest to the function supplies the first mock argument.
  """

  def setUp(self):
    """Clears the module secret cache and patches google.auth.default."""
    super().setUp()
    # Start every test with an empty secret cache.
    secret_flag_utils._cache.clear()
    # Patch credential discovery for the lifetime of the test.
    self.enter_context(
        mock.patch(
            'google.auth.default', autospec=True, return_value=('mock', 'mock')
        )
    )

  def tearDown(self):
    """Resets the secret manager enable flag that individual tests turn on."""
    super().tearDown()
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = False

  def test_init_fork_module_state(self):
    """_init_fork_module_state re-creates the cache and its lock."""
    secret_flag_utils._cache = None
    secret_flag_utils._cache_lock = None
    secret_flag_utils._init_fork_module_state()
    self.assertIsNotNone(secret_flag_utils._cache)
    self.assertIsNotNone(secret_flag_utils._cache_lock)

  @parameterized.named_parameters(
      dict(
          testcase_name='expected',
          version_numbers=['1', '2', '3'],
          expected_version='3',
      ),
      dict(
          testcase_name='ignore_unexpected_version',
          version_numbers=['A', '1'],
          expected_version='1',
      ),
  )
  def test_get_secret_version(self, version_numbers, expected_version):
    """_get_secret_version returns the expected version for listed names."""
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = (
        _create_mock_secret_version_response(
            f'{_MOCK_SECRET}/versions/', version_numbers
        )
    )
    version = secret_flag_utils._get_secret_version(client, _MOCK_SECRET)
    self.assertEqual(version, expected_version)

  def test_get_secret_version_cannot_find_version(self):
    """Version names that do not match the secret raise SecretDecodeError."""
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = (
        _create_mock_secret_version_response('', ['NO_MATCH', 'NO_MATCH'])
    )
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._get_secret_version(client, _MOCK_SECRET)

  def test_get_secret_version_ignore_bad_response(self):
    """An empty version listing raises SecretDecodeError."""
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = []
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._get_secret_version(client, _MOCK_SECRET)

  def test_read_secrets_path_to_version_empty(self):
    """An empty secret path yields an empty dict."""
    self.assertEqual(secret_flag_utils._read_secrets(''), {})

  def test_read_secrets_path_invalid_raises(self):
    """A path that is not a secret resource name raises SecretDecodeError."""
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets('no_match')

  def test_read_secrets_path_disabled_returns_empty_result(self):
    """A valid path yields {} when this test has not enabled the manager."""
    # NOTE(review): relies on _ENABLE_ENV_SECRET_MANAGER being False here
    # (tearDown resets it; the module default is not visible) - confirm.
    self.assertEqual(
        secret_flag_utils._read_secrets('projects/prj/secrets/sec'), {}
    )

  @parameterized.named_parameters([
      dict(
          testcase_name='none',
          data=None,
          expected_value={},
      ),
      dict(
          testcase_name='empty',
          data='',
          expected_value={},
      ),
      dict(
          testcase_name='empty_dict',
          data='{}',
          expected_value={},
      ),
      dict(
          testcase_name='defined_dict_str',
          data='{"abc": "123"}',
          expected_value={'abc': '123'},
      ),
      dict(
          testcase_name='defined_dict_bytes',
          data='{"abc": "123"}'.encode('utf-8'),
          expected_value={'abc': '123'},
      ),
  ])
  @mock.patch.object(
      secret_flag_utils, '_get_secret_version', autospec=True, return_value='1'
  )
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_value(
      self, mk_access_secret, mk_get_secret_version, data, expected_value
  ):
    """Payload data is decoded to a dict; version lookup is used for the path."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec'
    secret_version = '1'
    mk_get_secret_version.return_value = secret_version
    mk_access_secret.return_value = _create_mock_secret_value_response(data)
    self.assertEqual(
        secret_flag_utils._read_secrets(project_secret),
        expected_value,
    )
    # The path has no version, so the version is looked up and appended.
    mk_get_secret_version.assert_called_once_with(mock.ANY, project_secret)
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': f'{project_secret}/versions/{secret_version}'},
    )

  @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True)
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_value_does_not_call_version_if_in_path(
      self, mk_access_secret, mk_get_secret_version
  ):
    """A path that already names a version skips the version lookup."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.return_value = _create_mock_secret_value_response('')
    self.assertEqual(
        secret_flag_utils._read_secrets(project_secret),
        {},
    )
    mk_get_secret_version.assert_not_called()
    mk_access_secret.assert_called_once_with(
        mock.ANY, request={'name': project_secret}
    )

  def test_read_secrets_returns_cache_value(self):
    """A value already stored in the cache for the path is returned as is."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/3'
    expected = 'EXPECTED_VALUE'
    secret_flag_utils._cache[project_secret] = expected
    self.assertEqual(secret_flag_utils._read_secrets(project_secret), expected)

  @parameterized.named_parameters([
      dict(testcase_name='invalid_json', response='{'),
      dict(testcase_name='not_dict', response='[]'),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_returns_invalid_value_raises(
      self, mk_access_secret, response
  ):
    """Malformed JSON or a non-dict payload raises SecretDecodeError."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.return_value = _create_mock_secret_value_response(response)
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.parameters([
      google.api_core.exceptions.NotFound('mock'),
      google.api_core.exceptions.PermissionDenied('mock'),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_access_secret_version_raises(self, exp, mk_access_secret):
    """API errors from access_secret_version surface as SecretDecodeError."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.side_effect = exp
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.parameters([
      google.api_core.exceptions.NotFound('mock'),
      google.api_core.exceptions.PermissionDenied('mock'),
  ])
  @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True)
  def test_get_secret_version_raises(self, exp, mk_get_secret_version):
    """API errors from the version lookup surface as SecretDecodeError."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec'
    mk_get_secret_version.side_effect = exp
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.named_parameters([
      dict(
          testcase_name='defined_env',
          env_name='MOCK_ENV',
          expected_value='DEFINED_VALUE',
      ),
      dict(
          testcase_name='not_defined_env',
          env_name='UNDEFINED_ENV',
          expected_value='MOCK_DEFAULT',
      ),
  ])
  def test_get_secret_or_env_no_config_defined_returns_env_value_or_default(
      self, env_name, expected_value
  ):
    """Without a secret config, the env value or the default is returned."""
    with mock.patch.dict(os.environ, {'MOCK_ENV': 'DEFINED_VALUE'}):
      self.assertEqual(
          secret_flag_utils.get_secret_or_env(env_name, 'MOCK_DEFAULT'),
          expected_value,
      )

  @parameterized.named_parameters([
      dict(
          testcase_name='secret_env',
          search_env='SECRET_ENV',
          expected='SECRET_VALUE',
      ),
      dict(
          testcase_name='not_in_sec',
          search_env='NOT_IN_SEC',
          expected='ENV_VALUE',
      ),
      dict(
          testcase_name='not_in_env',
          search_env='NOT_IN_ENV',
          expected='MOCK_DEFAULT',
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_secret_or_env_with_config(
      self, mk_access_secret, search_env, expected
  ):
    """Lookup order exercised here: secret value, then env, then default."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"SECRET_ENV": "SECRET_VALUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
            'NOT_IN_SEC': 'ENV_VALUE',
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_secret_or_env(search_env, 'MOCK_DEFAULT'),
          expected,
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_secret_or_env_default_return_value(self, mk_access_secret):
    """A None default is returned when the name is in neither secret nor env."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"MOCK_ENV": "SECRET_VALUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
        },
    ):
      self.assertIsNone(
          secret_flag_utils.get_secret_or_env('NOT_IN_SECRET_OR_ENV', None)
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @parameterized.named_parameters([
      dict(testcase_name='FOUND', env_name='MOCK_ENV', expected_value=True),
      dict(
          testcase_name='DEFAULT_MISSING',
          env_name='NOT_FOUND_ENV',
          expected_value=False,
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_bool_secret_or_env(
      self, mk_access_secret, env_name, expected_value
  ):
    """A "TRUE" secret parses to True; a missing name yields False here."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"MOCK_ENV": "TRUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_bool_secret_or_env(env_name), expected_value
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @parameterized.named_parameters([
      dict(testcase_name='FOUND', env_name='ENV_FOUND', expected_value=False),
      dict(
          testcase_name='DEFAULT_MISSING',
          env_name='NOT_FOUND_ENV',
          expected_value=True,
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_bool_secret_or_env_trys_env(
      self, mk_access_secret, env_name, expected_value
  ):
    """With an empty secret, the env value is used, else the True default."""
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response('{}')
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
            'ENV_FOUND': 'False',
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_bool_secret_or_env(env_name, True),
          expected_value,
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )
421 |
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
424 |
--------------------------------------------------------------------------------
/notebooks/quick_start_with_model_garden.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "tYba2hfAs0AS"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Copyright 2025 Google LLC\n",
12 | "#\n",
13 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
14 | "# you may not use this file except in compliance with the License.\n",
15 | "# You may obtain a copy of the License at\n",
16 | "#\n",
17 | "# https://www.apache.org/licenses/LICENSE-2.0\n",
18 | "#\n",
19 | "# Unless required by applicable law or agreed to in writing, software\n",
20 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
21 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
22 | "# See the License for the specific language governing permissions and\n",
23 | "# limitations under the License."
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "id": "kILpWz4Gs5Kh"
30 | },
31 | "source": [
32 | "# Quick start with Model Garden - MedASR\n",
33 | "\n",
34 | "
\n",
35 | " \n",
36 | " \n",
37 | "  Run in Colab Enterprise\n",
38 | " \n",
39 | " | \n",
40 | " \n",
41 | " \n",
42 | "  View on GitHub\n",
43 | " \n",
44 | " | \n",
45 | "
"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {
51 | "id": "JRGljGF2lUWd"
52 | },
53 | "source": [
54 | "## Overview\n",
55 | "\n",
56 | "This notebook demonstrates how to use MedASR in Vertex AI to transcribe medical audio to text using online inference.\n",
57 | "\n",
58 | "**Online inferences** are synchronous requests that are made to the endpoint deployed from Model Garden and are served with low latency. Online inferences are useful if the model outputs are being used in production. The cost for online inference is based on the time a virtual machine spends waiting in an active state (an endpoint with a deployed model) to handle inference requests.\n",
59 | "\n",
60 | "Vertex AI makes it easy to serve your model and make it accessible to the world. Learn more about [Vertex AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform).\n",
61 | "\n",
62 | "### Objectives\n",
63 | "\n",
64 | "- Deploy MedASR to a Vertex AI Endpoint and get online inferences.\n",
65 | "\n",
66 | "### Costs\n",
67 | "\n",
68 | "This tutorial uses billable components of Google Cloud:\n",
69 | "\n",
70 | "* Vertex AI\n",
71 | "* Cloud Storage\n",
72 | "\n",
73 | "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {
79 | "id": "emPr6M1fs5Kj"
80 | },
81 | "source": [
82 | "## Before you begin"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "id": "ag_zmUlgmJD8",
90 | "cellView": "form"
91 | },
92 | "outputs": [],
93 | "source": [
94 | "# @title Install dependencies and import packages\n",
95 | "\n",
96 | "! pip install -qU --upgrade pip\n",
97 | "! pip install -qU 'google-cloud-aiplatform>=1.101.0' jiwer levenshtein\n",
98 | "\n",
99 | "import base64\n",
100 | "import json\n",
101 | "import os\n",
102 | "\n",
103 | "from google.cloud import aiplatform\n",
104 | "from IPython.display import Audio, display\n",
105 | "\n",
106 | "models, endpoints = {}, {}"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "cellView": "form",
114 | "id": "zMsN9Ep1mJD8"
115 | },
116 | "outputs": [],
117 | "source": [
118 | "# @title Set up Google Cloud environment\n",
119 | "\n",
120 | "# @markdown #### Prerequisites\n",
121 | "\n",
122 | "# @markdown 1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.\n",
123 | "\n",
124 | "# @markdown 2. Make sure that either the Compute Engine API is enabled or that you have the [Service Usage Admin](https://cloud.google.com/iam/docs/understanding-roles#serviceusage.serviceUsageAdmin) (`roles/serviceusage.serviceUsageAdmin`) role to enable the API.\n",
125 | "\n",
126 | "# @markdown This section sets the default Google Cloud project and region, enables the Compute Engine API (if not already enabled), and initializes the Vertex AI API.\n",
127 | "\n",
128 | "# Get the default project ID.\n",
129 | "PROJECT_ID = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n",
130 | "\n",
131 | "# Get the default region for launching jobs.\n",
132 | "REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n",
133 | "\n",
134 | "# Enable the Compute Engine API, if not already.\n",
135 | "print(\"Enabling Compute Engine API.\")\n",
136 | "! gcloud services enable compute.googleapis.com\n",
137 | "\n",
138 | "# Initialize Vertex AI API.\n",
139 | "print(\"Initializing Vertex AI API.\")\n",
140 | "aiplatform.init(project=PROJECT_ID, location=REGION)"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "source": [
146 | "# @title Retrieve sample data\n",
147 | "\n",
148 | "# @markdown This notebook uses a sample medical audio file and transcript.\n",
149 | "\n",
150 | "! gcloud storage cp gs://healthai-us/medasr/test_audio.wav test_audio.wav\n",
151 | "with open(\"test_audio.wav\", \"rb\") as f:\n",
152 | " audio_bytes = f.read()\n",
153 | "sample_transcript = \"Exam type CT chest PE protocol period. Indication 54 year old female, shortness of breath, evaluate for PE period. Technique standard protocol period. Findings colon. Pulmonary vasculature colon. The main PA is patent period. There are filling defects in the segmental branches of the right lower lobe comma compatible with acute PE period. No saddle embolus period. Lungs colon. No pneumothorax period. Small bilateral effusions comma right greater than left period. New paragraph. Impression colon. Acute segmental PE right lower lobe period.\"\n",
154 | "display(Audio(audio_bytes, autoplay=False))"
155 | ],
156 | "metadata": {
157 | "id": "4GJMn-QvZ4up",
158 | "cellView": "form"
159 | },
160 | "execution_count": null,
161 | "outputs": []
162 | },
163 | {
164 | "cell_type": "code",
165 | "source": [
166 | "# @title Define utility functions\n",
167 | "\n",
168 | "# @markdown These functions will be used to evaluate the word error rate (WER) of the generated transcripts.\n",
169 | "\n",
170 | "import re\n",
171 | "import jiwer\n",
172 | "import Levenshtein\n",
173 | "\n",
174 | "def normalize(s: str) -> str:\n",
175 | " s = s.lower()\n",
176 | " s = re.sub(r\"[^ a-z0-9']\", ' ', s)\n",
177 | " s = ' '.join(s.split())\n",
178 | " return s\n",
179 | "\n",
180 | "def _colored(text, color):\n",
181 | " if color == 'red':\n",
182 | " return f\"\\033[91m{text}\\033[0m\"\n",
183 | " elif color == 'green':\n",
184 | " return f\"\\033[92m{text}\\033[0m\"\n",
185 | " return text\n",
186 | "\n",
187 | "def evaluate(\n",
188 | " ref_text: str,\n",
189 | " hyp_text: str,\n",
190 | " delete_color: str = 'red',\n",
191 | " insert_color: str = 'green',\n",
192 | ") -> None:\n",
193 | " print('HYP:', hyp_text)\n",
194 | " normalized_ref = normalize(ref_text)\n",
195 | " normalized_hyp = normalize(hyp_text)\n",
196 | "\n",
197 | " # Calculate word lists early so we can use them for both jiwer and diffs\n",
198 | " ref_words = normalized_ref.split()\n",
199 | " hyp_words = normalized_hyp.split()\n",
200 | "\n",
201 | " # jiwer.process_words expects a list of strings (sentences) or list of list of words\n",
202 | " measures = jiwer.process_words([normalized_ref], [normalized_hyp])\n",
203 | "\n",
204 | " # Calculate edit operations using Levenshtein for the colored diff\n",
205 | " edits = Levenshtein.editops(ref_words, hyp_words)\n",
206 | "\n",
207 | " r = 0 # Index for the reference words for diff building\n",
208 | " diff = ''\n",
209 | "\n",
210 | " for op, i, j in edits:\n",
211 | " # Add matched words before the current edit\n",
212 | " if r < i:\n",
213 | " diff += ' ' + ' '.join(ref_words[r:i])\n",
214 | " r = i # Update reference index for next iteration\n",
215 | "\n",
216 | " if op == 'replace':\n",
217 | " diff += (\n",
218 | " f' {_colored(f\"{{-{ref_words[i]}-}}\", delete_color)}'\n",
219 | " f' {_colored(f\"{{+{hyp_words[j]}+}}\", insert_color)}'\n",
220 | " )\n",
221 | " r += 1 # Advance reference index after replacement\n",
222 | " elif op == 'insert':\n",
223 | " diff += f' {_colored(f\"{{+{hyp_words[j]}+}}\", insert_color)}'\n",
224 | " # Reference index `r` does not advance for an insertion\n",
225 | " elif op == 'delete':\n",
226 | " diff += f' {_colored(f\"{{-{ref_words[i]}-}}\", delete_color)}'\n",
227 | " r += 1 # Advance reference index after deletion\n",
228 | "\n",
229 | " # Add any remaining matched words from the reference\n",
230 | " if r < len(ref_words):\n",
231 | " diff += ' ' + ' '.join(ref_words[r:])\n",
232 | "\n",
233 | " print(\n",
234 | " f'WER: {measures.wer * 100:.2f}%: '\n",
235 | " f'insertions {measures.insertions}, deletions {measures.deletions}, substitutions {measures.substitutions}, '\n",
236 | " f'ref tokens {len(ref_words)}'\n",
237 | " )\n",
238 | " print(diff)"
239 | ],
240 | "metadata": {
241 | "cellView": "form",
242 | "id": "1bz6RCO2c7h4"
243 | },
244 | "execution_count": null,
245 | "outputs": []
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {
250 | "id": "Bak-klNTmJD8"
251 | },
252 | "source": [
253 | "## Get online inferences"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {
260 | "id": "tjOWCeGJu-94",
261 | "cellView": "form"
262 | },
263 | "outputs": [],
264 | "source": [
265 | "# @title Import deployed model\n",
266 | "\n",
267 | "# @markdown To get [online inferences](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions), you will need a MedASR [Vertex AI Endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment) that has been deployed from Model Garden. If you have not already done so, go to the [MedASR model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/medasr) and click \"Deploy model\" to deploy the model.\n",
268 | "\n",
269 | "# @markdown Note: Endpoints deployed from Model Garden must be [dedicated endpoints](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type).\n",
270 | "\n",
271 | "# @markdown This section gets the Vertex AI Endpoint resource that you deployed from Model Garden to use for online inferences.\n",
272 | "\n",
273 | "# @markdown Fill in the endpoint ID and region below. You can find your deployed endpoint on the [Vertex AI Endpoints page](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).\n",
274 | "\n",
275 | "ENDPOINT_ID = \"\" # @param {type: \"string\", placeholder:\"e.g. 123456789\"}\n",
276 | "ENDPOINT_REGION = \"\" # @param {type: \"string\", placeholder:\"e.g. us-central1\"}\n",
277 | "\n",
278 | "endpoints[\"endpoint\"] = aiplatform.Endpoint(\n",
279 | " endpoint_name=ENDPOINT_ID,\n",
280 | " project=PROJECT_ID,\n",
281 | " location=ENDPOINT_REGION,\n",
282 | ")"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": null,
288 | "metadata": {
289 | "id": "XhijNE6PnWn7",
290 | "cellView": "form"
291 | },
292 | "outputs": [],
293 | "source": [
294 | "# @title Run inference using the Vertex AI SDK\n",
295 | "\n",
296 | "# @markdown This section shows how to send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to your Vertex AI endpoint.\n",
297 | "\n",
298 | "# @markdown Click \"Show code\" to see more details.\n",
299 | "\n",
300 | "request = {\n",
301 | " \"file\": base64.b64encode(audio_bytes).decode(\"utf-8\"),\n",
302 | "}\n",
303 | "\n",
304 | "response = endpoints[\"endpoint\"].raw_predict(\n",
305 | " body=json.dumps(request).encode(\"utf-8\"),\n",
306 | " headers={\n",
307 | " \"Content-Type\": \"application/json\",\n",
308 | " },\n",
309 | ")\n",
310 | "generated_transcript = json.loads(response.content)[\"text\"]\n",
311 | "\n",
312 | "print(generated_transcript)\n",
313 | "evaluate(sample_transcript, generated_transcript)"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {
319 | "id": "4yfFlkncxzDF"
320 | },
321 | "source": [
322 | "\n",
323 | "## Next steps\n",
324 | "\n",
325 | "Explore the other [notebooks](https://github.com/google-health/medasr/blob/main/notebooks) to learn what else you can do with the model.\n"
326 | ]
327 | },
328 | {
329 | "cell_type": "markdown",
330 | "metadata": {
331 | "id": "paQNyrzT_mX_"
332 | },
333 | "source": [
334 | "## Clean up resources"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {
341 | "id": "edUIpvZZ_mYA",
342 | "cellView": "form"
343 | },
344 | "outputs": [],
345 | "source": [
346 | "# @markdown Delete the experiment models and endpoints to recycle the resources\n",
347 | "# @markdown and avoid unnecessary continuous charges that may be incurred.\n",
348 | "\n",
349 | "# Undeploy model and delete endpoint.\n",
350 | "for endpoint in endpoints.values():\n",
351 | " endpoint.delete(force=True)\n",
352 | "\n",
353 | "# Delete models.\n",
354 | "for model in models.values():\n",
355 | " model.delete()"
356 | ]
357 | }
358 | ],
359 | "metadata": {
360 | "colab": {
361 | "provenance": [],
362 | "toc_visible": true
363 | },
364 | "kernelspec": {
365 | "display_name": "Python 3",
366 | "name": "python3"
367 | },
368 | "language_info": {
369 | "name": "python"
370 | }
371 | },
372 | "nbformat": 4,
373 | "nbformat_minor": 0
374 | }
375 |
--------------------------------------------------------------------------------
/python/serving/logging_lib/cloud_logging_client_instance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Wrapper for cloud ops structured logging.
16 |
17 | Get instance of logger:
18 | logger() -> CloudLoggingClient
19 |
20 | logger().log_signature = dict to append to logs.
21 | """
22 | import collections
23 | import copy
24 | import enum
25 | import inspect
26 | import logging
27 | import math
28 | import os
29 | import sys
30 | import threading
31 | import time
32 | import traceback
33 | from typing import Any, Mapping, MutableMapping, Optional, Tuple, Union
34 |
35 | from absl import logging as absl_logging
36 | import google.auth
37 | from google.cloud import logging as cloud_logging
38 |
# Debug/testing option to send logs to absl logging. Automatically set when
# running unit tests, i.e. when the UNITTEST_ON_FORGE environment variable is
# present or the unittest module has already been imported.
DEBUG_LOGGING_USE_ABSL_LOGGING = bool(
    'UNITTEST_ON_FORGE' in os.environ or 'unittest' in sys.modules
)
44 |
45 |
class _LogSeverity(enum.Enum):
  """Log severities, valued by the matching standard `logging` level constants."""

  CRITICAL = logging.CRITICAL
  ERROR = logging.ERROR
  WARNING = logging.WARNING
  INFO = logging.INFO
  DEBUG = logging.DEBUG
52 |
53 |
# NOTE(review): presumably the maximum size allowed for a single log entry;
# the unit (bytes vs. characters) is not visible here - confirm against usage.
MAX_LOG_SIZE = 246000
55 |
56 |
class CloudLoggerInstanceExceptionError(Exception):
  """Error raised by this module's logging helpers.

  For example, _py_log raises it when given an unsupported severity.
  """

  pass
59 |
60 |
def _merge_struct(
    dict_tuple: Tuple[Union[Mapping[str, Any], Exception, None], ...],
) -> Optional[MutableMapping[str, str]]:
  """Merges mappings and exceptions into one ordered structured-log dict.

  * Plain mappings contribute their items in sorted key order.
  * OrderedDicts contribute their items in their own order.
  * Exceptions are stored under the 'exception' key as the exception text
    followed by that exception's formatted stack trace.
  * None entries are skipped. Later entries overwrite earlier duplicate keys.

  Args:
    dict_tuple: Dicts, ordered dicts, exceptions, or None values to merge.

  Returns:
    Merged ordered dict with every value converted to str, or None when
    dict_tuple is empty.
  """
  if not dict_tuple:
    return None
  return_dict = collections.OrderedDict()
  for dt in dict_tuple:
    if dt is None:
      continue
    if isinstance(dt, Exception):
      exception_str = str(dt)
      if exception_str:
        exception_str = f'{exception_str}\n'
      # Format the trace of the exception that was passed in.
      # traceback.format_exc() reports whatever exception is currently being
      # handled, which is 'NoneType: None' outside an except block and the
      # wrong trace inside a handler for a different exception.
      stack_trace = ''.join(
          traceback.format_exception(type(dt), dt, dt.__traceback__)
      )
      return_dict['exception'] = f'{exception_str}{stack_trace}'
      continue
    # OrderedDict keeps caller-defined order; other mappings are key sorted.
    if isinstance(dt, collections.OrderedDict):
      keylist = list(dt)
    else:
      keylist = sorted(dt)
    for key in keylist:
      return_dict[key] = str(dt[key])
  return return_dict
93 |
94 |
def _absl_log(msg: str, severity: _LogSeverity = _LogSeverity.INFO) -> None:
  """Writes a message through absl logging.

  DEBUG, WARNING and INFO use the absl function of the same name; any other
  severity is written with absl's error level.

  Args:
    msg: Message to log.
    severity: Severity of message.
  """
  emitters = {
      _LogSeverity.DEBUG: absl_logging.debug,
      _LogSeverity.WARNING: absl_logging.warning,
      _LogSeverity.INFO: absl_logging.info,
  }
  # The absl function is invoked from this frame on purpose: this function's
  # name is registered with absl as a frame to skip when attributing logs.
  emitters.get(severity, absl_logging.error)(msg)
110 |
111 |
def _py_log(
    dpas_logger: logging.Logger,
    msg: str,
    extra: Mapping[str, Any],
    severity: _LogSeverity,
) -> None:
  """Writes a message and its structured payload using a python logger.

  Args:
    dpas_logger: Python logger that receives the message.
    msg: Message to log.
    extra: Passed unchanged as the logger call's `extra` argument.
    severity: Severity of message; selects the logger method that is called.

  Raises:
    CloudLoggerInstanceExceptionError: If severity is not a supported level.
  """
  method_names = {
      _LogSeverity.DEBUG: 'debug',
      _LogSeverity.WARNING: 'warning',
      _LogSeverity.INFO: 'info',
      _LogSeverity.CRITICAL: 'critical',
      _LogSeverity.ERROR: 'error',
  }
  method_name = method_names.get(severity)
  if method_name is None:
    raise CloudLoggerInstanceExceptionError(
        f'Unsupported logging severity level; Severity="{severity}"'
    )
  getattr(dpas_logger, method_name)(msg, extra=extra)
133 |
134 |
135 | def _get_source_location_to_log(stack_frames_back: int) -> Mapping[str, Any]:
136 | """Adds Python source location information to cloud structured logs.
137 |
138 | The source location is added by adding (and overwriting if present) a
139 | "source_location" key to the provided additional_parameters.
140 | The value corresponding to that key is a dict mapping:
141 | "file" to the name of the file (str) containing the logging statement,
142 | "function" python function/method calling logging method
143 | "line" to the line number where the log was recorded (int).
144 |
145 | Args:
146 | stack_frames_back: Additional stack frames back to log source_location.
147 |
148 | Returns:
149 | Source location formatted for structured logging.
150 |
151 | Raises:
152 | ValueError: If stack frame cannot be found for specified position.
153 | """
154 | source_location = {}
155 | current_frame = inspect.currentframe()
156 | for _ in range(stack_frames_back + 1):
157 | if current_frame is None:
158 | raise ValueError('Cannot get stack frame for specified position.')
159 | current_frame = current_frame.f_back
160 | try:
161 | frame_info = inspect.getframeinfo(current_frame)
162 | source_location['source_location'] = dict(
163 | file=frame_info.filename,
164 | function=frame_info.function,
165 | line=frame_info.lineno,
166 | )
167 | finally:
168 | # https://docs.python.org/3/library/inspect.html
169 | del current_frame # explicitly deleting
170 | return source_location
171 |
172 |
173 | def _add_trace_to_log(
174 | project_id: str, trace_key: str, struct: Mapping[str, Any]
175 | ) -> Mapping[str, Any]:
176 | if not project_id or not trace_key:
177 | return {}
178 | trace_id = struct.get(trace_key, '')
179 | if trace_id:
180 | return {'trace': f'projects/{project_id}/traces/{trace_id}'}
181 | return {}
182 |
183 |
class CloudLoggingClientInstance:
  """Wrapper for cloud ops structured logging.

  Automatically adds a signature to structured logs to make them traceable.
  The signature holds the pod hostname, pod uid and build version (each only
  when defined), the thread id (when per-thread signatures are enabled) and
  any key/value pairs set through `log_signature`.

  The cloud logging handler is process level state shared by all instances; it
  is stored on the class and guarded by `_global_lock`.
  """

  # Global state to prevent duplicate initialization of cloud logging
  # interfaces within a process.
  _global_lock = threading.Lock()
  # Cloud logging handler; initialized at process level.
  _cloud_logging_handler: Optional[
      cloud_logging.handlers.CloudLoggingHandler
  ] = None
  # Description of the parameters the running handler was created with.
  _cloud_logging_handler_init_params = ''

  @classmethod
  def _init_fork_module_state(cls) -> None:
    """Resets the class level handler state to its initial values.

    Registered at module level to run in the child process after a fork. A
    new lock is created rather than reusing the existing lock object.
    """
    cls._global_lock = threading.Lock()
    cls._cloud_logging_handler = None
    cls._cloud_logging_handler_init_params = ''

  @classmethod
  def fork_shutdown(cls) -> None:
    """Shuts down the process's cloud logging handler, if one is running.

    Stops the handler's transport worker, detaches the handler from the root
    python logger and closes it. Registered at module level to run before a
    fork.
    """
    with cls._global_lock:
      cls._cloud_logging_handler_init_params = ''
      handler = cls._cloud_logging_handler
      if handler is None:
        return
      handler.transport.worker.stop()
      logging.getLogger().removeHandler(handler)
      handler.close()
      cls._cloud_logging_handler = None

  @classmethod
  def _set_absl_skip_frames(cls) -> None:
    """Sets absl logging attribution to skip over internal logging frames."""
    internal_function_names = (
        '_log',
        '_absl_log',
        'debug',
        'timed_debug',
        'info',
        'warning',
        'error',
        'critical',
    )
    for function_name in internal_function_names:
      absl_logging.ABSLLogger.register_frame_to_skip(
          __file__,
          function_name=function_name,
      )

  def __init__(
      self,
      log_name: str = 'python',
      gcp_project_to_write_logs_to: str = '',
      gcp_credentials: Optional[google.auth.credentials.Credentials] = None,
      pod_hostname: str = '',
      pod_uid: str = '',
      enable_structured_logging: bool = True,
      use_absl_logging: bool = DEBUG_LOGGING_USE_ABSL_LOGGING,
      log_all_python_logs_to_cloud: bool = False,
      enabled: bool = True,
      log_error_level: int = _LogSeverity.DEBUG.value,
      per_thread_log_signatures: bool = True,
      trace_key: str = '',
      build_version: str = '',
  ):
    """Constructor.

    Args:
      log_name: Log name to write logs to.
      gcp_project_to_write_logs_to: GCP project name to write log to. Undefined
        = default.
      gcp_credentials: The OAuth2 Credentials to use for this client
        (None=default).
      pod_hostname: Host name of GKE pod. Should be empty if not running in GKE.
      pod_uid: UID of GKE pod. Should be empty if not running in GKE.
      enable_structured_logging: Enable structured logging.
      use_absl_logging: Send logs to absl logging instead of cloud_logging.
      log_all_python_logs_to_cloud: Logs everything to cloud.
      enabled: If disabled, logging is not initialized and logging operations
        are nops.
      log_error_level: Error level at which logger will log; messages with a
        lower severity value are dropped.
      per_thread_log_signatures: Log signatures reported per thread.
      trace_key: Log key value which contains a trace id value.
      build_version: Build version to embed in the log signature.

    Raises:
      ValueError: If cloud logging has to be initialized and log_name is empty.
      CloudLoggerInstanceExceptionError: If a cloud logging handler is already
        running in the process with different parameters.
    """
    CloudLoggingClientInstance._set_absl_skip_frames()
    self._build_version = build_version
    self._enabled = enabled
    self._trace_key = trace_key
    self._log_error_level = log_error_level
    # Lock for log; makes access to a shared instance safe across threads.
    # Re-entrant because clipping an oversized log emits a log of its own.
    self._log_lock = threading.RLock()
    self._log_name = log_name.strip()
    self._pod_hostname = pod_hostname.strip()
    self._pod_uid = pod_uid.strip()
    self._per_thread_log_signatures = per_thread_log_signatures
    self._thread_local_storage = threading.local()
    # Reads the hostname, pod uid, build version and per-thread setting
    # assigned above.
    self._shared_log_signature = self._signature_defaults(0)
    self._enable_structured_logging = enable_structured_logging
    self._debug_log_time = time.time()
    self._gcp_project_name = gcp_project_to_write_logs_to.strip()
    self._use_absl_logging = use_absl_logging
    self._log_all_python_logs_to_cloud = log_all_python_logs_to_cloud
    self._gcp_credentials = gcp_credentials
    absl_logging.set_verbosity(absl_logging.INFO)
    # Must run last; initialization reads the attributes set above and logs
    # through this instance if it fails.
    self._python_logger = self._init_cloud_handler()

  @property
  def trace_key(self) -> str:
    """Returns the log key whose value holds a trace id."""
    return self._trace_key

  @trace_key.setter
  def trace_key(self, val: str) -> None:
    """Sets the log key whose value holds a trace id."""
    self._trace_key = val

  @property
  def per_thread_log_signatures(self) -> bool:
    """Returns True if each thread logs with its own signature."""
    return self._per_thread_log_signatures

  @per_thread_log_signatures.setter
  def per_thread_log_signatures(self, val: bool) -> None:
    """Sets whether each thread logs with its own signature."""
    with self._log_lock:
      self._per_thread_log_signatures = val

  @property
  def python_logger(self) -> logging.Logger:
    """Returns the python logger; re-initializes cloud logging if needed.

    The cloud handler is re-created when the instance logs to cloud and no
    process level handler exists, e.g. after `fork_shutdown` removed it.
    """
    cloud_handler_missing = (
        CloudLoggingClientInstance._cloud_logging_handler is None
    )
    if self._enabled and not self._use_absl_logging and cloud_handler_missing:
      self._python_logger = self._init_cloud_handler()
    return self._python_logger

  def _get_python_logger_name(self) -> Optional[str]:
    """Returns the python logger name; None selects the root logger."""
    if self._log_all_python_logs_to_cloud:
      return None
    return 'DPASLogger'

  def _get_cloud_logging_handler_init_params(self) -> str:
    """Returns a description of the parameters that define the handler."""
    return (
        f'GCP_PROJECT_NAME: {self._gcp_project_name}; LOG_NAME:'
        f' {self._log_name}; LOG_ALL: {self._log_all_python_logs_to_cloud}'
    )

  def _init_cloud_handler(self) -> logging.Logger:
    """Initializes cloud logging handler and returns python logger.

    Returns:
      The root python logger if the instance is disabled, logs to absl, or
      default credentials could not be found; otherwise the logger named by
      `_get_python_logger_name`.

    Raises:
      CloudLoggerInstanceExceptionError: A handler is already running in the
        process and was created with different parameters.
      ValueError: Log name is empty.
      Exception: Any unexpected error raised while creating the handler is
        logged and re-raised.
    """
    with CloudLoggingClientInstance._global_lock:
      if not self._enabled or self._use_absl_logging:
        return logging.getLogger()  # Default PY logger
      handler_instance_init_params = (
          self._get_cloud_logging_handler_init_params()
      )
      if CloudLoggingClientInstance._cloud_logging_handler is not None:
        running_handler_init_params = (
            CloudLoggingClientInstance._cloud_logging_handler_init_params
        )
        if running_handler_init_params != handler_instance_init_params:
          # Call fork_shutdown to shutdown the process's named logging handler.
          raise CloudLoggerInstanceExceptionError(
              'Cloud logging handler is running with parameters that do not'
              ' match instance defined parameters. Running handler parameters:'
              f' {running_handler_init_params}; Instance parameters:'
              f' {handler_instance_init_params}'
          )
        return logging.getLogger(self._get_python_logger_name())
      log_name = self.log_name
      struct_log = {
          'log_name': log_name,
          'log_all_python_logs': self._log_all_python_logs_to_cloud,
      }
      project = self._gcp_project_name or None
      try:
        # Instantiates a cloud logging client and attaches the default python
        # & absl logger to also write to the named log.
        logging_client = cloud_logging.Client(
            project=project,
            credentials=self._gcp_credentials,
        )
        logging_client.project = project
        handler = cloud_logging.handlers.CloudLoggingHandler(
            client=logging_client,
            name=log_name,
        )
        CloudLoggingClientInstance._cloud_logging_handler = handler
        CloudLoggingClientInstance._cloud_logging_handler_init_params = (
            handler_instance_init_params
        )
        cloud_logging.handlers.setup_logging(
            handler,
            log_level=logging.DEBUG
            if self._log_all_python_logs_to_cloud
            else logging.INFO,
        )
        dpas_python_logger = logging.getLogger(self._get_python_logger_name())
        dpas_python_logger.setLevel(
            logging.DEBUG
        )  # pytype: disable=attribute-error
        return dpas_python_logger
      except google.auth.exceptions.DefaultCredentialsError as exp:
        # Switching to absl first keeps the error log below off the python
        # logger path while the global lock is held.
        self._use_absl_logging = True
        self.error('Error initializing logging.', struct_log, exp)
        return logging.getLogger()
      except Exception as exp:
        self._use_absl_logging = True
        self.error('Error unexpected exception.', struct_log, exp)
        raise

  def __getstate__(self) -> MutableMapping[str, Any]:
    """Returns instance state for pickle without lock, logger, thread state."""
    state = copy.copy(self.__dict__)
    for transient in ('_log_lock', '_python_logger', '_thread_local_storage'):
      del state[transient]
    return state

  def __setstate__(self, dct: MutableMapping[str, Any]):
    """Un-pickles class; re-creates lock, thread state and python logger."""
    self.__dict__ = dct
    self._log_lock = threading.RLock()
    self._thread_local_storage = threading.local()
    # Re-init logging in process.
    self._python_logger = self._init_cloud_handler()

  def use_absl_logging(self) -> bool:
    """Returns True if logs are sent to absl logging."""
    return self._use_absl_logging

  @property
  def enable_structured_logging(self) -> bool:
    """Returns True if structured logging is enabled."""
    return self._enable_structured_logging

  @enable_structured_logging.setter
  def enable_structured_logging(self, val: bool) -> None:
    """Enables or disables structured logging."""
    with self._log_lock:
      self._enable_structured_logging = val

  @property
  def enable(self) -> bool:
    """Returns True if logging is enabled."""
    return self._enabled

  @enable.setter
  def enable(self, val: bool) -> None:
    """Enables or disables logging; logging calls are nops when disabled."""
    with self._log_lock:
      self._enabled = val

  @property
  def gcp_project_name(self) -> str:
    """Returns name of the GCP project logs are written to ('' = default)."""
    return self._gcp_project_name

  @property
  def log_name(self) -> str:
    """Returns log name; raises ValueError if the log name is empty."""
    if not self._log_name:
      raise ValueError('Undefined Log Name')
    return self._log_name

  @property
  def hostname(self) -> str:
    """Returns pod hostname; raises ValueError if the hostname is empty."""
    if not self._pod_hostname:
      raise ValueError('POD_HOSTNAME name is not defined.')
    return self._pod_hostname

  @property
  def build_version(self) -> str:
    """Returns build version # for container."""
    return self._build_version

  @build_version.setter
  def build_version(self, version: str) -> None:
    """Sets build version # for container and updates the log signature.

    Only the signature in use by the calling thread (or the shared signature
    when per-thread signatures are disabled) is updated. An empty version
    removes BUILD_VERSION from that signature.

    Args:
      version: Build version; empty to clear.
    """
    self._build_version = version
    with self._log_lock:
      log_sig = self._get_thread_signature()
      if version:
        log_sig['BUILD_VERSION'] = str(version)
      else:
        # pop() tolerates a signature that never had a build version; an
        # empty BUILD_VERSION entry must not be added in that case.
        log_sig.pop('BUILD_VERSION', None)

  @property
  def pod_uid(self) -> str:
    """Returns pod uid; raises ValueError if the pod uid is empty."""
    if not self._pod_uid:
      raise ValueError('Undefined POD UID')
    return self._pod_uid

  def _get_thread_signature(self) -> MutableMapping[str, Any]:
    """Returns the signature mapping (not a copy) used by the calling thread.

    With per-thread signatures enabled, the calling thread's signature is
    created from the defaults on first use.
    """
    if not self._per_thread_log_signatures:
      return self._shared_log_signature
    if not hasattr(self._thread_local_storage, 'signature'):
      self._thread_local_storage.signature = self._signature_defaults(
          threading.get_native_id()
      )
    return self._thread_local_storage.signature

  @property
  def log_signature(self) -> MutableMapping[str, Any]:
    """Returns a copy of the log signature.

    Log signature returned may not match what is currently being logged
    if thread is set to log using another thread's log signature.
    """
    with self._log_lock:
      return copy.copy(self._get_thread_signature())

  @log_signature.setter
  def log_signature(self, sig: Mapping[str, Any]) -> None:
    """Sets log signature.

    Replaces the signature with the given key/value pairs (stringified, in
    sorted key order) followed by the default signature entries, which
    override pairs that use the same key.

    Log signature of thread may not be altered if thread is set to log using
    another thread's log signature.

    Args:
      sig: Signature for the thread's logs to use; None keeps only defaults.
    """
    with self._log_lock:
      if self._per_thread_log_signatures:
        thread_id = threading.get_native_id()
        # The thread's signature is stored under the 'signature' attribute.
        if not hasattr(self._thread_local_storage, 'signature'):
          self._thread_local_storage.signature = collections.OrderedDict()
        log_sig = self._thread_local_storage.signature
      else:
        thread_id = 0
        log_sig = self._shared_log_signature
      log_sig.clear()
      if sig is not None:
        for key in sorted(sig):
          log_sig[str(key)] = str(sig[key])
      log_sig.update(self._signature_defaults(thread_id))

  def _signature_defaults(self, thread_id: int) -> MutableMapping[str, str]:
    """Returns default log signature.

    Args:
      thread_id: Value reported as THREAD_ID when per-thread signatures are
        enabled.

    Returns:
      Ordered mapping holding HOSTNAME, POD_UID and BUILD_VERSION (each only
      if defined) and THREAD_ID (only with per-thread signatures).
    """
    log_signature = collections.OrderedDict()
    if self._pod_hostname:
      log_signature['HOSTNAME'] = str(self._pod_hostname)
    if self._pod_uid:
      log_signature['POD_UID'] = str(self._pod_uid)
    if self.build_version:
      log_signature['BUILD_VERSION'] = str(self.build_version)
    if self._per_thread_log_signatures:
      log_signature['THREAD_ID'] = str(thread_id)
    return log_signature

  def clear_log_signature(self) -> None:
    """Resets the thread (or shared) log signature to the defaults."""
    with self._log_lock:
      if self._per_thread_log_signatures:
        self._thread_local_storage.signature = self._signature_defaults(
            threading.get_native_id()
        )
      else:
        self._shared_log_signature = self._signature_defaults(0)

  def _clip_struct_log(
      self, log: MutableMapping[str, Any], max_log_size: int
  ) -> None:
    """Clips log in place if structured log exceeds the size limit.

    Clipping logic:
      log size = total sum of key + value sizes of log structure

      log['message'] and signature components are not clipped to keep
      log message text un-altered and message traceability preserved.
      Log structure keys are not altered.

      Structured logs exceeding the size limit typically have one massive
      component. First try to clip the log by clipping just the largest
      clippable value down to the size of the second largest. If the log
      still exceeds the limit, proportionally clip all clippable values.

    A warning is logged before any value is clipped. Values are expected to
    be strings.

    Args:
      log: Structured log.
      max_log_size: Max size of the log.

    Raises:
      ValueError: If log cannot be clipped to max_log_size.
    """
    # Determine length of total message: keys + values.
    total_key_size = sum(len(key) for key in log)
    clippable_value_size = sum(len(value) for value in log.values())
    exceeds_log = total_key_size + clippable_value_size - max_log_size
    if exceeds_log <= 0:
      return

    # Set aside the entries which are never adjusted: the signature
    # components and the message.
    clippable_keys = set(log)
    protected_size = 0
    for protected_key in (*self._get_thread_signature(), 'message'):
      if protected_key not in clippable_keys:
        continue
      value_len = len(str(log[protected_key]))
      protected_size += len(protected_key) + value_len
      clippable_value_size -= value_len
      clippable_keys.remove(protected_key)
    if protected_size >= max_log_size:
      raise ValueError('Message exceeds logging msg length limit.')

    # Message exceeded length limits due to components in structured log.
    self._log(
        'Next log message exceed cloud ops length limit and as clipped.',
        severity=_LogSeverity.WARNING,
        struct=tuple(),
        stack_frames_back=0,
    )

    # First clip largest entry. In most cases a message will have one very
    # large tag which causes the length issue; clip that single largest entry
    # so it is at most as big as the second largest entry.
    if len(clippable_keys) > 1:
      by_size = sorted(
          ((key, len(log[key])) for key in clippable_keys),
          key=lambda key_and_size: key_and_size[1],
      )
      (_, runner_up_size), (largest_key, largest_size) = by_size[-2:]
      clip_len = min(largest_size - runner_up_size, exceeds_log)
      if clip_len > 0:
        log[largest_key] = log[largest_key][:-clip_len]
        # Adjust length that still needs to be trimmed.
        exceeds_log -= clip_len
        if exceeds_log == 0:
          return
        # Adjust total size of trimmable value component.
        clippable_value_size -= clip_len

    if clippable_value_size <= 0:
      # No clippable text is left; the size is due to keys and protected
      # entries. Also prevents a division by zero below.
      raise ValueError('Message exceeds logging msg length limit.')

    # Proportionally clip all tags; iterate over a sorted list to make
    # clipping deterministic.
    remaining = exceeds_log
    for key in sorted(clippable_keys):
      entry_size = len(log[key])
      clip_len = math.ceil(entry_size * exceeds_log / clippable_value_size)
      clip_len = min(clip_len, remaining, entry_size)
      if clip_len > 0:
        log[key] = log[key][:-clip_len]
        remaining -= clip_len
        if remaining == 0:
          return
    raise ValueError('Message exceeds logging msg length limit.')

  def _merge_signature(
      self, struct: Optional[MutableMapping[str, Any]]
  ) -> MutableMapping[str, Any]:
    """Adds signature to logging struct.

    The signature entries are written into struct itself and override entries
    that use the same key.

    Args:
      struct: Logging struct; None starts from an empty ordered dict.

    Returns:
      Dict to log.
    """
    if struct is None:
      struct = collections.OrderedDict()
    struct.update(self._get_thread_signature())
    return struct

  def _log(
      self,
      msg: str,
      severity: _LogSeverity,
      struct: Tuple[Union[Mapping[str, Any], Exception, None], ...],
      stack_frames_back: int = 0,
  ) -> None:
    """Posts structured log message, adds the log signature to log structure.

    Nothing is logged if the instance is disabled or severity is below the
    configured error level. Structured logs are written with the python
    logger; if absl logging is selected or structured logging is disabled,
    the message and structure are joined into one line and written with absl.

    Args:
      msg: Message to log.
      severity: Severity level of message.
      struct: Structure to log.
      stack_frames_back: Additional stack frames back to log source_location.

    Raises:
      ValueError: If the structured log cannot be clipped to the size limit.
    """
    if not self._enabled:
      return
    with self._log_lock:
      if severity.value < self._log_error_level:
        return
      struct = self._merge_signature(_merge_struct(struct))
      structured = (
          not self.use_absl_logging() and self._enable_structured_logging
      )
      if structured:
        # Log using structured logs. The source location lookup has to be
        # called from this frame; the +1 accounts for this method.
        source_location = _get_source_location_to_log(stack_frames_back + 1)
        trace = _add_trace_to_log(
            self._gcp_project_name, self._trace_key, struct
        )
        self._clip_struct_log(struct, MAX_LOG_SIZE)
        _py_log(
            self.python_logger,
            msg,
            extra={'json_fields': struct, **source_location, **trace},
            severity=severity,
        )
        return

      # Log using unstructured logs.
      parts = [msg]
      parts.extend(f'{key}: {value}' for key, value in struct.items())
      _absl_log('; '.join(parts), severity=severity)

  def debug(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with debug severity.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    self._log(msg, _LogSeverity.DEBUG, struct, 1 + stack_frames_back)

  def timed_debug(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with debug severity and elapsed time since last timed debug log.

    The elapsed time in seconds (three decimals) is prefixed to the message;
    the first call measures from the construction of the instance.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    time_now = time.time()
    elapsed_time = f'{time_now - self._debug_log_time:.3f}'
    self._debug_log_time = time_now
    self._log(
        f'[{elapsed_time}] {msg}',
        _LogSeverity.DEBUG,
        struct,
        1 + stack_frames_back,
    )

  def info(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with info severity.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    self._log(msg, _LogSeverity.INFO, struct, 1 + stack_frames_back)

  def warning(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with warning severity.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    self._log(msg, _LogSeverity.WARNING, struct, 1 + stack_frames_back)

  def error(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with error severity.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    self._log(msg, _LogSeverity.ERROR, struct, 1 + stack_frames_back)

  def critical(
      self,
      msg: str,
      *struct: Union[Mapping[str, Any], Exception, None],
      stack_frames_back: int = 0,
  ) -> None:
    """Logs with critical severity.

    Args:
      msg: Message to log (string).
      *struct: Zero or more dict or exception to log in structured log.
      stack_frames_back: Additional stack frames back to log source_location.
    """
    self._log(msg, _LogSeverity.CRITICAL, struct, 1 + stack_frames_back)

  @property
  def log_error_level(self) -> int:
    """Returns the minimum severity value that is logged."""
    return self._log_error_level

  @log_error_level.setter
  def log_error_level(self, level: int) -> None:
    """Sets the minimum severity value that is logged."""
    with self._log_lock:
      self._log_error_level = level
820 |
821 |
# Logging interfaces are used from processes which are forked (gunicorn,
# DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not
# copy threads running within parent processes or re-initialize global/module
# state. This can result in forked modules being executed with invalid global
# state, e.g., acquired locks that will not release or references to invalid
# state. The cloud logging library utilizes a background thread transporting
# logs to cloud. The background threading is not compatible with forking and
# will seg-fault (python queue wait). This is avoided by stopping the
# background transport and removing the handler before the fork
# (fork_shutdown) and by resetting the class level handler state in the child
# (_init_fork_module_state). The handler is created again the next time a
# logger instance that logs to cloud accesses its python_logger property.
os.register_at_fork(
    before=CloudLoggingClientInstance.fork_shutdown,  # pylint: disable=protected-access
    after_in_child=CloudLoggingClientInstance._init_fork_module_state,  # pylint: disable=protected-access
)
836 |
--------------------------------------------------------------------------------