├── python ├── serving │ ├── logging_lib │ │ ├── __init__.py │ │ ├── flags │ │ │ ├── __init__.py │ │ │ ├── flag_utils.py │ │ │ ├── flag_utils_test.py │ │ │ ├── secret_flag_utils.py │ │ │ └── secret_flag_utils_test.py │ │ ├── cloud_logging_client.py │ │ └── cloud_logging_client_instance.py │ ├── serving_framework │ │ ├── triton │ │ │ ├── requirements.in │ │ │ ├── __init__.py │ │ │ ├── server_health_check.py │ │ │ ├── server_health_check_test.py │ │ │ ├── triton_server_model_runner.py │ │ │ ├── triton_streaming_server_model_runner.py │ │ │ └── triton_server_model_runner_test.py │ │ ├── pip-install.txt │ │ ├── requirements.in │ │ ├── README.md │ │ ├── __init__.py │ │ ├── model_transfer.py │ │ ├── inline_prediction_executor_test.py │ │ ├── inline_prediction_executor.py │ │ ├── model_runner.py │ │ ├── server_gunicorn_test.py │ │ └── server_gunicorn.py │ ├── model_requirements.in │ ├── requirements.in │ ├── model_repository │ │ └── default │ │ │ ├── 1 │ │ │ └── model.py │ │ │ └── config.pbtxt │ ├── __init__.py │ ├── predictor_test.py │ ├── entrypoint.sh │ ├── server_gunicorn.py │ ├── Dockerfile │ ├── vertex_schemata │ │ ├── response.yaml │ │ └── request.yaml │ └── predictor.py └── requirements.txt ├── notebooks ├── README.md └── quick_start_with_model_garden.ipynb ├── CONTRIBUTING.md ├── README.md └── LICENSE /python/serving/logging_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | -r serving/requirements.txt 2 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/requirements.in: -------------------------------------------------------------------------------- 1 | requests~=2.32.3 2 | # TODO: b/375469331 - Enable testing with most current requests-mock release. 3 | requests-mock==1.9.3 4 | tritonclient~=2.56.0 5 | typing-extensions~=4.12.2 -------------------------------------------------------------------------------- /python/serving/model_requirements.in: -------------------------------------------------------------------------------- 1 | torch~=2.9.1 2 | # TODO(b/468062329): Replace with regular pip version when released. 3 | https://github.com/huggingface/transformers/archive/65dc261512cbdb1ee72b88ae5b222f2605aad8e5.tar.gz#egg=transformers -------------------------------------------------------------------------------- /python/serving/serving_framework/pip-install.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for updating pip. 2 | pip==25.0 \ 3 | --hash=sha256:8e0a97f7b4c47ae4a494560da84775e9e2f671d415d8d828e052efefb206b30b \ 4 | --hash=sha256:b6eb97a803356a52b2dd4bb73ba9e65b2ba16caa6bcb25a7497350a4e5859b65 5 | -------------------------------------------------------------------------------- /python/serving/serving_framework/requirements.in: -------------------------------------------------------------------------------- 1 | absl-py>=2.1.0 2 | flask~=3.0.3 3 | grpcio~=1.68.1 4 | grpcio-status~=1.68.1 5 | gunicorn~=23.0.0 6 | jsonschema~=4.23.0 7 | numpy<=2.0.2 # bypassing faulty version restriction in tritonclient 8 | setuptools~=75.6.0 9 | typing-extensions~=4.12.2 10 | -------------------------------------------------------------------------------- /python/serving/requirements.in: -------------------------------------------------------------------------------- 1 | google-cloud-core~=2.4.3 2 | google-cloud-logging~=3.12.1 3 | google-cloud-secret-manager==2.20.2 4 | psutil==7.1.3 5 | # TODO(b/468062329): Replace with regular pip version when released. 6 | https://github.com/huggingface/transformers/archive/65dc261512cbdb1ee72b88ae5b222f2605aad8e5.tar.gz#egg=transformers 7 | scipy~=1.16.3 8 | torch~=2.9.1 9 | -r serving_framework/requirements.in 10 | -r serving_framework/triton/requirements.in 11 | -------------------------------------------------------------------------------- /python/serving/model_repository/default/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "default" 2 | backend: "pytorch" 3 | runtime: "model.py" 4 | input [ 5 | { 6 | name: "input_features__0" 7 | data_type: TYPE_FP32 8 | dims: [-1, -1, 128] 9 | }, 10 | { 11 | name: "attention_mask__1" 12 | data_type: TYPE_BOOL 13 | dims: [-1, -1] 14 | } 15 | ] 16 | output [ 17 | { 18 | name: "tokens__0" 19 | data_type: TYPE_INT64 20 | dims: [-1, -1] 21 | } 22 | ] 23 | -------------------------------------------------------------------------------- /python/serving/serving_framework/README.md: -------------------------------------------------------------------------------- 1 | # Serving framework 2 | 3 | This directory contains a Python library that simplifies the creation of custom 4 | prediction containers for Vertex AI. It provides the framework for building HTTP 5 | servers that meet the platform's 6 | [requirements](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). 7 | 8 | To implement a model-specific HTTP server, frame the custom data handling and 9 | orchestration logic within this framework. 10 | -------------------------------------------------------------------------------- /python/serving/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /python/serving/serving_framework/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # MedASR Notebooks 2 | 3 | * [Quick start with Hugging Face](quick_start_with_hugging_face.ipynb) - This 4 | notebook demonstrates how to quickly get started with MedASR, Google's open 5 | Automatic Speech Recognition (ASR) model fine-tuned for the medical domain. 6 | It shows how to load the MedASR model from Hugging Face and use the 7 | transformers library pipeline to perform speech-to-text transcription on an 8 | example audio file. 9 | 10 | * [Quick start with Model Garden](quick_start_with_model_garden.ipynb) - This 11 | notebook demonstrates how to use MedASR as a service deployed on 12 | [Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/overview). 13 | It shows how to use the service API to perform speech-to-text transcription 14 | in online or batch workflows. 15 | 16 | * [Fine-tune with Hugging Face](fine_tune_with_hugging_face.ipynb) - This 17 | notebook shows how to fine-tune the model locally using Hugging Face 18 | libraries. 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows HAI-DEF's 24 | [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 33 | 34 | For more details, read HAI-DEF's 35 | [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing) 36 | -------------------------------------------------------------------------------- /python/serving/predictor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for MedASR predictor.""" 16 | 17 | import base64 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | 22 | from serving import predictor 23 | 24 | 25 | class PredictorTest(parameterized.TestCase): 26 | 27 | @parameterized.named_parameters( 28 | { 29 | 'testcase_name': 'not_base64', 30 | 'audio_string': 'not_base64', 31 | }, 32 | { 33 | 'testcase_name': 'not_wav', 34 | 'audio_string': base64.b64encode(b'not_wav').decode('utf-8'), 35 | }, 36 | ) 37 | def test_bad_file_loading_raises_error(self, audio_string: str): 38 | with self.assertRaises(ValueError): 39 | predictor.load_wav(audio_string) 40 | 41 | 42 | if __name__ == '__main__': 43 | absltest.main() 44 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/server_health_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """REST-based health check implementation for Tensorflow model servers.""" 16 | 17 | import http 18 | 19 | import requests 20 | from typing_extensions import override 21 | 22 | from serving.serving_framework import server_gunicorn 23 | 24 | 25 | class TritonServerHealthCheck(server_gunicorn.ModelServerHealthCheck): 26 | """Checks the health of the local model server via REST request.""" 27 | 28 | def __init__(self, health_check_port: int): 29 | self._health_check_url = ( 30 | f"http://localhost:{health_check_port}/v2/health/ready" 31 | ) 32 | 33 | @override 34 | def check_health(self) -> bool: 35 | try: 36 | r = requests.get(self._health_check_url) 37 | return r.status_code == http.HTTPStatus.OK.value 38 | except requests.exceptions.ConnectionError: 39 | return False 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MedASR 2 | 3 | MedASR is an automated speech recognition (ASR) model that takes spoken audio as an input and represents it as text. MedASR has been specifically trained for the healthcare domain. This means it has higher out-of-the box performance in healthcare contexts than general ASR models. MedASR is open and can be fine-tuned for even higher performance. 4 | 5 | ## Get started 6 | 7 | * Read our 8 | [developer documentation](https://developers.google.com/health-ai-developer-foundations/medgemma/get-started) 9 | to see the full range of next steps available, including learning more about 10 | the model through its 11 | [model card](https://developers.google.com/health-ai-developer-foundations/medasr/model-card). 12 | 13 | * Explore this repository, which contains [notebooks](./notebooks) for using 14 | the model. 15 | 16 | * Visit the model on 17 | [Hugging Face](https://huggingface.co/models?other=medasr) or 18 | [Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/medasr). 19 | 20 | ## Contributing 21 | 22 | We are open to bug reports, pull requests (PR), and other contributions. See 23 | [CONTRIBUTING](CONTRIBUTING.md) and 24 | [community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) 25 | for details. 26 | 27 | ## License 28 | 29 | While the model is licensed under the 30 | [Health AI Developer Foundations License](https://developers.google.com/health-ai-developer-foundations/terms), 31 | everything in this repository is licensed under the Apache 2.0 license, see 32 | [LICENSE](LICENSE). 33 | -------------------------------------------------------------------------------- /python/serving/serving_framework/model_transfer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""A tool to transfer a model from the Vertex gcs bucket to a local directory. 16 | 17 | Copies a target gcs directory into a local directory using default credentials, 18 | intended to be used to set up tfserving-compatible model directories during 19 | serving framework startup. 20 | 21 | usage: 22 | python3 model_transfer.py --gcs_path="gs://bucket/object" \ 23 | --local_path="/path/to/local/dir" 24 | """ 25 | 26 | from collections.abc import Sequence 27 | 28 | from absl import app 29 | from absl import flags 30 | 31 | from google.cloud.aiplatform.utils import gcs_utils 32 | 33 | 34 | _GCS_PATH = flags.DEFINE_string( 35 | "gcs_path", 36 | None, 37 | "The gcs path to copy from.", 38 | required=True, 39 | ) 40 | _LOCAL_PATH = flags.DEFINE_string( 41 | "local_path", 42 | None, 43 | "The local path to copy to.", 44 | required=True, 45 | ) 46 | 47 | 48 | def main(argv: Sequence[str]) -> None: 49 | if len(argv) > 1: 50 | raise app.UsageError("Too many command-line arguments.") 51 | gcs_utils.download_from_gcs(_GCS_PATH.value, _LOCAL_PATH.value) 52 | 53 | if __name__ == "__main__": 54 | app.run(main) 55 | -------------------------------------------------------------------------------- /python/serving/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script launches the serving framework, run as the entrypoint. 17 | 18 | # Exit if expanding an undefined variable. 19 | set -u 20 | 21 | export MODEL_REST_PORT=8600 22 | export AIP_PREDICT_ROUTE="/v1/audio/transcriptions" 23 | 24 | if [[ -v "AIP_STORAGE_URI" && -n "$AIP_STORAGE_URI" ]]; then 25 | export MODEL_FILES="/model_files" 26 | mkdir "$MODEL_FILES" 27 | gcloud storage cp "$AIP_STORAGE_URI/*" "$MODEL_FILES" --recursive 28 | fi 29 | 30 | echo "Serving framework start, launching model server" 31 | 32 | (/opt/tritonserver/bin/tritonserver \ 33 | --model-repository="/serving/model_repository" \ 34 | --allow-grpc=true \ 35 | --grpc-address=127.0.0.1 \ 36 | --grpc-port=8500 \ 37 | --allow-http=true \ 38 | --http-address=127.0.0.1 \ 39 | --http-port="${MODEL_REST_PORT}" \ 40 | --allow-vertex-ai=false \ 41 | --strict-readiness=true || exit) & 42 | 43 | echo "Launching front end" 44 | 45 | (/server-env/bin/python3.12 -m serving.server_gunicorn --alsologtostderr \ 46 | --verbosity=1 || exit)& 47 | 48 | # Wait for any process to exit 49 | wait -n 50 | 51 | # Exit with status of process that exited first 52 | exit $? 53 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/flag_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utility functions for flags.""" 16 | import os 17 | 18 | 19 | def str_to_bool(val: str) -> bool: 20 | """Converts a string representation of truth. 21 | 22 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; 23 | False values are 'n', 'no', 'f', 'false', 'off', and '0'. 24 | 25 | Args: 26 | val: String to convert to bool. 27 | 28 | Returns: 29 | Boolean result 30 | 31 | Raises: 32 | ValueError if val is anything else. 33 | """ 34 | val = val.strip().lower() 35 | if val in ('y', 'yes', 't', 'true', 'on', '1'): 36 | return True 37 | if val in ('n', 'no', 'f', 'false', 'off', '0'): 38 | return False 39 | raise ValueError(f'invalid truth value {str(val)}') 40 | 41 | 42 | def env_value_to_bool(env_name: str, undefined_value: bool = False) -> bool: 43 | """Converts environmental variable value into boolean value for flag init. 44 | 45 | Args: 46 | env_name: Environmental variable name. 47 | undefined_value: Default value to set undefined values to. 48 | 49 | Returns: 50 | Boolean of environmental variable string value. 51 | 52 | Raises: 53 | ValueError : Environmental variable cannot be parsed to bool. 54 | """ 55 | if env_name not in os.environ: 56 | return undefined_value 57 | env_value = os.environ[env_name].strip() 58 | return str_to_bool(env_value) 59 | -------------------------------------------------------------------------------- /python/serving/model_repository/default/1/model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model file compatible with Triton PyTorch 2.0 backend.""" 17 | 18 | import os 19 | from typing import Optional 20 | 21 | import torch 22 | from transformers.models import lasr 23 | 24 | 25 | # Override dir to hide imports from the Triton backend's loading strategy. 26 | # Without this, the backend can attempt to load from the imports 27 | # instead of MedASRWrapper. 28 | def __dir__(): 29 | return ["MedASRWrapper", "__name__", "__spec__"] 30 | 31 | 32 | class MedASRWrapper(torch.nn.Module): 33 | """Wraps SiglipModel with custom weight loading and return structure.""" 34 | 35 | def __init__(self): 36 | super(MedASRWrapper, self).__init__() 37 | token = None 38 | if os.getenv("AIP_STORAGE_URI"): 39 | # Using model files copied from Vertex GCS bucket. 40 | model_origin = os.getenv("MODEL_FILES") 41 | else: 42 | # Using model files from HF repository. 43 | model_origin = os.getenv("MODEL_ID") 44 | if not model_origin: 45 | raise ValueError( 46 | "No model origin found. MODEL_ID or AIP_STORAGE_URI must be set." 47 | ) 48 | token = os.getenv("HF_TOKEN") # optional for access to non-public models. 49 | self._model = lasr.LasrForCTC.from_pretrained( 50 | model_origin, 51 | token=token, 52 | ) 53 | 54 | def forward( 55 | self, 56 | input_features: Optional[torch.Tensor] = None, 57 | attention_mask: Optional[torch.Tensor] = None, 58 | ): 59 | output = self._model.generate( 60 | input_features=input_features, attention_mask=attention_mask 61 | ) 62 | return output 63 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/flag_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for flag utils.""" 16 | import os 17 | from unittest import mock 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | 22 | from serving.logging_lib.flags import flag_utils 23 | 24 | # const 25 | _UNDEFINED_ENV_VAR_NAME = 'UNDEFINED' 26 | 27 | 28 | class FlagUtilsTest(parameterized.TestCase): 29 | 30 | @parameterized.parameters(['y', ' YES ', 't', 'tRue', 'on', '1']) 31 | def test_str_to_bool_true(self, val): 32 | self.assertTrue(flag_utils.str_to_bool(val)) 33 | 34 | @parameterized.parameters(['n', 'no', 'f', 'FALSE', ' oFf ', '0']) 35 | def test_strtobool_false(self, val): 36 | self.assertFalse(flag_utils.str_to_bool(val)) 37 | 38 | def test_str_to_bool_raises(self): 39 | with self.assertRaises(ValueError): 40 | flag_utils.str_to_bool('ABCD') 41 | 42 | def test_undefined_env_default_true(self): 43 | self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ) 44 | self.assertTrue(flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME, True)) 45 | 46 | def test_undefined_env_no_default(self): 47 | self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ) 48 | self.assertFalse(flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME)) 49 | 50 | @mock.patch.dict(os.environ, {'FOO': ' True '}) 51 | def test_initialized_env(self): 52 | self.assertTrue(flag_utils.env_value_to_bool('FOO')) 53 | 54 | @mock.patch.dict(os.environ, {'BAD_VALUE': 'TrueABSCDEF'}) 55 | def test_bad_initialized_env(self): 56 | with self.assertRaises(ValueError): 57 | flag_utils.env_value_to_bool('BAD_VALUE') 58 | 59 | 60 | if __name__ == '__main__': 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/server_health_check_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import http 16 | import os 17 | from unittest import mock 18 | 19 | import requests 20 | import requests_mock 21 | 22 | from absl.testing import absltest 23 | from serving.serving_framework import server_gunicorn 24 | from serving.serving_framework.triton import server_health_check 25 | 26 | 27 | class ServerHealthCheckTest(absltest.TestCase): 28 | 29 | def setUp(self): 30 | super().setUp() 31 | os.environ["AIP_PREDICT_ROUTE"] = "/fake-predict-route" 32 | os.environ["AIP_HEALTH_ROUTE"] = "/fake-health-route" 33 | 34 | @requests_mock.Mocker() 35 | def test_health_route_pass_check(self, mock_requests): 36 | mock_requests.register_uri( 37 | "GET", 38 | "http://localhost:12345/v2/health/ready", 39 | text="assorted_metadata", 40 | status_code=http.HTTPStatus.OK, 41 | ) 42 | 43 | executor = mock.create_autospec( 44 | server_gunicorn.PredictionExecutor, 45 | instance=True, 46 | ) 47 | 48 | app = server_gunicorn.PredictionApplication( 49 | executor, 50 | health_check=server_health_check.TritonServerHealthCheck( 51 | 12345 52 | ), 53 | ).load() 54 | service = app.test_client() 55 | 56 | response = service.get("/fake-health-route") 57 | 58 | self.assertEqual(response.status_code, http.HTTPStatus.OK) 59 | self.assertEqual(response.text, "ok") 60 | 61 | @requests_mock.Mocker() 62 | def test_health_route_fail_check(self, mock_requests): 63 | mock_requests.register_uri( 64 | "GET", 65 | "http://localhost:12345/v2/health/ready", 66 | exc=requests.exceptions.ConnectionError, 67 | ) 68 | executor = mock.create_autospec( 69 | server_gunicorn.PredictionExecutor, 70 | instance=True, 71 | ) 72 | 73 | app = server_gunicorn.PredictionApplication( 74 | executor, 75 | health_check=server_health_check.TritonServerHealthCheck( 76 | 12345 77 | ), 78 | ).load() 79 | service = app.test_client() 80 | 81 | response = service.get("/fake-health-route") 82 | 83 | self.assertEqual(response.status_code, http.HTTPStatus.SERVICE_UNAVAILABLE) 84 | self.assertEqual(response.text, "not ok") 85 | 86 | 87 | if __name__ == "__main__": 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /python/serving/server_gunicorn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Launcher for the prediction_executor based encoder server. 16 | 17 | Uses the serving framework to create a request server which 18 | performs the logic for requests in separate processes and uses a local NVIDIA 19 | Triton model server to handle the model. 20 | """ 21 | 22 | from collections.abc import Sequence 23 | import os 24 | 25 | from absl import app 26 | from absl import logging 27 | import jsonschema 28 | import yaml 29 | 30 | from serving.serving_framework import inline_prediction_executor 31 | from serving.serving_framework import server_gunicorn 32 | from serving.serving_framework.triton import server_health_check 33 | from serving.serving_framework.triton import triton_server_model_runner 34 | from serving import predictor 35 | 36 | 37 | def main(argv: Sequence[str]) -> None: 38 | if len(argv) > 1: 39 | raise app.UsageError('Too many command-line arguments.') 40 | if 'AIP_HTTP_PORT' not in os.environ: 41 | raise ValueError( 42 | 'The environment variable AIP_HTTP_PORT needs to be specified.' 43 | ) 44 | http_port = int(os.environ.get('AIP_HTTP_PORT')) 45 | options = { 46 | 'bind': f'0.0.0.0:{http_port}', 47 | 'workers': 3, 48 | 'timeout': 240, 49 | } 50 | model_rest_port = int(os.environ.get('MODEL_REST_PORT')) 51 | health_checker = server_health_check.TritonServerHealthCheck(model_rest_port) 52 | # Get schema validators. 53 | local_path = os.path.dirname(__file__) 54 | with open( 55 | os.path.join(local_path, 'vertex_schemata', 'request.yaml'), 'r' 56 | ) as f: 57 | instance_validator = jsonschema.Draft202012Validator(yaml.safe_load(f)) 58 | # with open( 59 | # os.path.join(local_path, 'vertex_schemata', 'response.yaml'), 'r' 60 | # ) as f: 61 | # prediction_validator = jsonschema.Draft202012Validator(yaml.safe_load(f)) 62 | prediction_validator = None 63 | predictor_instance = predictor.MedASRPredictor( 64 | model_source=os.environ.get('MODEL_FILES') 65 | ) 66 | logging.info('Launching gunicorn application.') 67 | server_gunicorn.PredictionApplication( 68 | inline_prediction_executor.InlinePredictionExecutor( 69 | predictor_instance.predict, 70 | triton_server_model_runner.TritonServerModelRunner, 71 | ), 72 | health_check=health_checker, 73 | options=options, 74 | instance_input=False, 75 | input_validator=instance_validator, 76 | prediction_validator=prediction_validator, 77 | ).run() 78 | 79 | 80 | if __name__ == '__main__': 81 | app.run(main) 82 | -------------------------------------------------------------------------------- /python/serving/serving_framework/inline_prediction_executor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Mapping, Set 16 | from unittest import mock 17 | from typing import Any 18 | 19 | import numpy as np 20 | 21 | from absl.testing import absltest 22 | from serving.serving_framework import inline_prediction_executor 23 | from serving.serving_framework import model_runner 24 | 25 | 26 | class DummyModelRunner(model_runner.ModelRunner): 27 | """Dummy model runner for testing.""" 28 | 29 | def run_model_multiple_output( 30 | self, 31 | model_input: Mapping[str, np.ndarray] | np.ndarray, 32 | *, 33 | model_name: str = "default", 34 | model_version: int | None = None, 35 | model_output_keys: Set[str], 36 | parameters: Mapping[str, Any] | None = None, 37 | ) -> Mapping[str, np.ndarray]: 38 | del model_name, model_version, model_output_keys, parameters 39 | return {"output_0": np.ones((1, 2), dtype=np.float32)} 40 | 41 | 42 | class InlinePredictionExecutorTest(absltest.TestCase): 43 | 44 | def test_predict_requires_start(self): 45 | predictor = mock.MagicMock() 46 | executor = inline_prediction_executor.InlinePredictionExecutor( 47 | predictor, DummyModelRunner 48 | ) 49 | with self.assertRaises(RuntimeError): 50 | executor.predict({"placeholder": "input"}) 51 | 52 | def test_execute_catches_predictor_exception(self): 53 | predictor = mock.MagicMock(side_effect=Exception("test error")) 54 | executor = inline_prediction_executor.InlinePredictionExecutor( 55 | predictor, DummyModelRunner 56 | ) 57 | executor.start() 58 | with self.assertRaises(RuntimeError): 59 | executor.execute({"placeholder": "input"}) 60 | 61 | def test_execute_calls_predictor(self): 62 | predictor = mock.MagicMock(return_value={"placeholder": "output"}) 63 | mock_model_runner = mock.create_autospec( 64 | DummyModelRunner, instance=True 65 | ) 66 | mock_model_runner_class = mock.create_autospec( 67 | DummyModelRunner, autospec=True 68 | ) 69 | mock_model_runner_class.return_value = mock_model_runner 70 | executor = inline_prediction_executor.InlinePredictionExecutor( 71 | predictor, mock_model_runner_class 72 | ) 73 | 74 | executor.start() 75 | self.assertEqual( 76 | executor.execute({"placeholder": "input"}), 77 | {"placeholder": "output"}, 78 | ) 79 | mock_model_runner_class.assert_called_once() 80 | predictor.assert_called_once_with( 81 | {"placeholder": "input"}, mock_model_runner 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /python/serving/serving_framework/inline_prediction_executor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A thin shell to fit a predictor function into the PredictionExecutor interface. 16 | 17 | This is a convenience wrapper to allow a predictor function to be used directly 18 | as a PredictionExecutor. 19 | 20 | Intended usage in launching a server: 21 | predictor = Predictor() 22 | executor = InlinePredictionExecutor(predictor.predict) 23 | server_gunicorn.PredictionApplication(executor, ...).run() 24 | """ 25 | 26 | from collections.abc import Callable 27 | from typing import Any 28 | 29 | from typing_extensions import override 30 | 31 | from serving.serving_framework import model_runner 32 | from serving.serving_framework import server_gunicorn 33 | 34 | 35 | class InlinePredictionExecutor(server_gunicorn.PredictionExecutor): 36 | """Provides prediction request execution using an inline function. 37 | 38 | Provides a little framing to simplify the use of predictor functions in the 39 | server worker process. 40 | 41 | If a function call with no setup is insufficient, overriding the start and 42 | predict methods here can be used to provide more complex behavior. Inheritance 43 | directly from PredictionExecutor is also an option. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | predictor: Callable[ 49 | [dict[str, Any], model_runner.ModelRunner], dict[str, Any] 50 | ], 51 | model_runner_source: Callable[[], model_runner.ModelRunner] 52 | ): 53 | self._predictor = predictor 54 | self._model_runner = None 55 | self._model_runner_source = model_runner_source 56 | 57 | @override 58 | def start(self) -> None: 59 | """Starts the executor. 60 | 61 | Called after the Gunicorn worker process has started. Performs any setup 62 | which needs to be done post-fork. 63 | """ 64 | # Safer to instantiate the RPC stub post-fork. 65 | self._model_runner = self._model_runner_source() 66 | 67 | def predict(self, input_json: dict[str, Any]) -> dict[str, Any]: 68 | """Executes the given request payload.""" 69 | if self._model_runner is None: 70 | raise RuntimeError( 71 | "Model runner is not initialized. Please call start() first." 72 | ) 73 | return self._predictor(input_json, self._model_runner) 74 | 75 | @override 76 | def execute(self, input_json: dict[str, Any]) -> dict[str, Any]: 77 | """Executes the given request payload. 78 | 79 | Args: 80 | input_json: The full json prediction request payload. 81 | 82 | Returns: 83 | The json response to the prediction request. 84 | 85 | Raises: 86 | RuntimeError: Prediction failed in an unhandled way. 87 | """ 88 | try: 89 | return self.predict(input_json) 90 | except Exception as e: 91 | raise RuntimeError("Unhandled exception from predictor.") from e 92 | -------------------------------------------------------------------------------- /python/serving/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This is used to build a Docker image that includes the necessary dependencies 16 | # for running MedGemma as a microservice. 17 | 18 | FROM python:3.12-slim AS prep 19 | 20 | # Get pip requirements 21 | COPY ./python/serving/requirements.txt requirements.txt 22 | COPY ./python/serving/serving_framework/pip-install.txt pip-install.txt 23 | RUN python3.12 -m venv --copies /payload/server-env && \ 24 | /payload/server-env/bin/python3.12 -m pip install --require-hashes \ 25 | -r pip-install.txt && \ 26 | /payload/server-env/bin/python3.12 -m pip install --require-hashes \ 27 | -r requirements.txt 28 | 29 | # Set up code directories 30 | COPY ./python/serving /payload/serving 31 | COPY ./LICENSE /payload/LICENSE 32 | RUN chmod a+x /payload/serving/entrypoint.sh 33 | 34 | # Install git and clone source code for mirroring 35 | RUN apt-get update && \ 36 | apt-get install --no-install-recommends -y git && \ 37 | rm -rf /var/lib/apt/lists/* 38 | RUN git clone https://github.com/certifi/python-certifi.git /source-mirror/python-certifi && \ 39 | git clone https://github.com/tqdm/tqdm.git /source-mirror/tqdm && \ 40 | git clone https://git.launchpad.net/launchpadlib /source-mirror/launchpadlib && \ 41 | git clone https://git.launchpad.net/lazr.restfulclient /source-mirror/lazr.restfulclient && \ 42 | git clone https://git.launchpad.net/lazr.uri /source-mirror/lazr.uri && \ 43 | git clone https://git.launchpad.net/wadllib /source-mirror/wadllib && \ 44 | git clone https://gitlab.gnome.org/GNOME/pygobject.git /source-mirror/pygobject && \ 45 | git clone https://git.launchpad.net/python-apt /source-mirror/python-apt 46 | 47 | # ---------------------------------------------------------------------------- 48 | 49 | FROM google/cloud-sdk:stable AS gcloud_builder 50 | 51 | # ---------------------------------------------------------------------------- 52 | 53 | FROM nvcr.io/nvidia/tritonserver:25.03-py3 54 | 55 | WORKDIR / 56 | # Install model dependencies for the triton backend. 57 | COPY ./python/serving/model_requirements.txt model_requirements.txt 58 | RUN python3 -m pip install --require-hashes -r /model_requirements.txt 59 | # Copy in minimal gcloud install 60 | COPY --from=gcloud_builder /usr/lib/google-cloud-sdk /usr/lib/google-cloud-sdk 61 | ENV PATH="/usr/lib/google-cloud-sdk/bin:$PATH" 62 | 63 | # Prevent compatibility-breaking NCCL updates. 64 | RUN apt-mark hold libnccl-dev libnccl2 && \ 65 | # apt upgrade 66 | apt-get update && \ 67 | apt-get -y upgrade && \ 68 | # Cleanup cached files. 69 | rm -rf /var/lib/apt/lists/* 70 | 71 | # Copy the code and python environment from the prep image 72 | COPY --from=prep /payload / 73 | # Copy mirrored source code from the prep image 74 | COPY --from=prep /source-mirror /source-mirror/ 75 | 76 | ENTRYPOINT ["/serving/entrypoint.sh"] 77 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/triton_server_model_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ModelRunner implementation using grpc to NVIDIA Triton model server. 16 | 17 | Uses Triton GRPC client and relies on a model server running locally. 18 | """ 19 | 20 | from collections.abc import Mapping, Set 21 | from typing import Any 22 | 23 | from absl import logging 24 | import numpy as np 25 | from tritonclient import grpc as triton_grpc 26 | from tritonclient import utils as triton_utils 27 | from typing_extensions import override 28 | 29 | from serving.serving_framework import model_runner 30 | 31 | _HOSTPORT = "localhost:8500" 32 | 33 | 34 | class TritonServerModelRunner(model_runner.ModelRunner): 35 | """ModelRunner implementation using grpc to NVIDIA Triton model server.""" 36 | 37 | def __init__(self, client: triton_grpc.InferenceServerClient | None = None): 38 | if client is not None: 39 | self._client = client 40 | else: 41 | self._client = triton_grpc.InferenceServerClient(_HOSTPORT) 42 | 43 | @override 44 | def run_model_multiple_output( 45 | self, 46 | model_input: Mapping[str, np.ndarray] | np.ndarray, 47 | *, 48 | model_name: str = "default", 49 | model_version: int | None = None, 50 | model_output_keys: Set[str], 51 | parameters: Mapping[str, Any] | None = None, 52 | ) -> Mapping[str, np.ndarray]: 53 | """Runs a model on the given input and returns multiple outputs. 54 | 55 | Args: 56 | model_input: An array or map of arrays comprising the input tensors for 57 | the model. A bare array is keyed by "inputs". 58 | model_name: The name of the model to run. 59 | model_version: The version of the model to run. Uses default if None. 60 | model_output_keys: The desired model output keys. 61 | parameters: Additional parameters to pass to the model. 62 | 63 | Returns: 64 | A mapping of model output keys to tensors. 65 | 66 | Raises: 67 | KeyError: If any of the model_output_keys are not found in the model 68 | output. 69 | """ 70 | # If a bare np.ndarray was passed, it will be passed using the default 71 | # input key "inputs". 72 | # If a Mapping was passed, use the keys from that mapping. 73 | if isinstance(model_input, np.ndarray): 74 | logging.debug("Handling bare input tensor.") 75 | input_map = {"inputs": model_input} 76 | else: 77 | input_map = model_input 78 | 79 | model_version = str(model_version) if model_version is not None else "" 80 | 81 | inputs = [] 82 | for key, data in input_map.items(): 83 | input_tensor = triton_grpc.InferInput( 84 | key, data.shape, triton_utils.np_to_triton_dtype(data.dtype) 85 | ) 86 | input_tensor.set_data_from_numpy(data) 87 | inputs.append(input_tensor) 88 | 89 | model_parameters = None if parameters is None else dict(parameters) 90 | 91 | result = self._client.infer( 92 | model_name, inputs, model_version, parameters=model_parameters 93 | ) 94 | assert result is not None # infer never returns None, despite annotation. 95 | 96 | outputs = {key: result.as_numpy(key) for key in model_output_keys} 97 | missing_keys = {key for key in model_output_keys if outputs[key] is None} 98 | if missing_keys: 99 | raise KeyError( 100 | f"Model output keys {missing_keys} not found in model output." 101 | ) 102 | return outputs 103 | -------------------------------------------------------------------------------- /python/serving/vertex_schemata/response.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | title: MedASRResponse 16 | type: object 17 | description: > 18 | The schema for a single Speech-to-Text transcription response. Parameters are aligned 19 | with the core functionality of the OpenAI Transcription API: 20 | https://platform.openai.com/docs/api-reference/audio/json-object 21 | additionalProperties: false 22 | required: 23 | - text 24 | properties: 25 | text: 26 | type: string 27 | description: The transcribed text. 28 | logprobs: 29 | type: array 30 | description: The log probabilities of the tokens in the transcription. 31 | items: 32 | $ref: "#/components/schemas/LogprobDetail" 33 | usage: 34 | description: Token usage statistics for the request. 35 | anyOf: 36 | - $ref: "#/components/schemas/TokenUsage" 37 | - $ref: "#/components/schemas/DurationUsage" 38 | 39 | components: 40 | schemas: 41 | LogprobDetail: 42 | title: LogprobDetail 43 | type: object 44 | description: A single log probability object for a token. 45 | additionalProperties: false 46 | required: 47 | - bytes 48 | - logprob 49 | - token 50 | properties: 51 | bytes: 52 | type: array 53 | description: The bytes of the token (represented as an array of integers). 54 | items: 55 | type: integer 56 | logprob: 57 | type: number 58 | description: The log probability of the token. 59 | token: 60 | type: string 61 | description: The token in the transcription. 62 | 63 | TokenUsage: 64 | title: TokenUsage 65 | type: object 66 | description: Token usage statistics. 67 | additionalProperties: false 68 | required: 69 | - input_tokens 70 | - output_tokens 71 | - total_tokens 72 | - type 73 | - input_token_details 74 | properties: 75 | input_tokens: 76 | type: integer 77 | description: Number of input tokens for this request. 78 | output_tokens: 79 | type: integer 80 | description: Number of output tokens generated. 81 | total_tokens: 82 | type: integer 83 | description: Total number of tokens used (input + output). 84 | type: 85 | type: string 86 | description: The type of the usage object. Always 'tokens' for this variant. 87 | enum: 88 | - tokens 89 | input_token_details: 90 | $ref: "#/components/schemas/InputTokenDetails" 91 | 92 | DurationUsage: 93 | title: DurationUsage 94 | type: object 95 | description: Audio input duration usage statistics. 96 | additionalProperties: false 97 | required: 98 | - seconds 99 | - type 100 | properties: 101 | seconds: 102 | type: number 103 | description: Duration of the input audio in seconds. 104 | type: 105 | type: string 106 | description: The type of the usage object. Always 'duration' for this variant. 107 | enum: 108 | - duration 109 | 110 | InputTokenDetails: 111 | title: InputTokenDetails 112 | type: object 113 | description: Details about the input tokens billed for this request. 114 | additionalProperties: false 115 | required: 116 | - audio_tokens 117 | - text_tokens 118 | properties: 119 | audio_tokens: 120 | type: integer 121 | description: Number of audio tokens billed for this request. 122 | text_tokens: 123 | type: integer 124 | description: Number of text tokens billed for this request. -------------------------------------------------------------------------------- /python/serving/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generates embeddings for text and imaging data.""" 16 | import base64 17 | import binascii 18 | import io 19 | from typing import Any, Mapping 20 | 21 | import numpy as np 22 | from scipy.io import wavfile 23 | import transformers 24 | 25 | from serving.serving_framework import model_runner 26 | from serving.logging_lib import cloud_logging_client 27 | 28 | # Input schema keys 29 | AUDIO_INPUT_KEY = 'file' 30 | 31 | # Model input dictionary mapping 32 | MODEL_INPUT_KEYS = { 33 | 'input_features': 'input_features__0', 34 | 'attention_mask': 'attention_mask__1', 35 | } 36 | 37 | # Model output dictionary keys 38 | TEXT_TRANCRIPT_KEY = 'text' 39 | 40 | 41 | def load_wav(audio_string: str) -> np.ndarray: 42 | """Read a b64 encoded wav file for and convert to mono if needed.""" 43 | try: 44 | audio_bytes = base64.b64decode(audio_string, validate=True) 45 | except (binascii.Error, ValueError) as exp: 46 | raise ValueError('Cannot decode input bytes.') from exp 47 | try: 48 | sample_rate, waveform = wavfile.read(io.BytesIO(audio_bytes)) 49 | except ValueError as exp: 50 | raise ValueError('Invalid wav file.') from exp 51 | cloud_logging_client.debug( 52 | f'WAV file sample rate: {sample_rate}, shape: {waveform.shape},' 53 | f' dtype={waveform.dtype}' 54 | ) 55 | type_info = waveform.dtype 56 | if waveform.ndim > 1: 57 | cloud_logging_client.info('Audio is stereo, converting to mono.') 58 | # Convert to mono by averaging the channels 59 | waveform = waveform.mean(axis=1) 60 | if sample_rate != 16000: 61 | raise ValueError( 62 | f'Sample rate {sample_rate} is not 16000, which is the expected' 63 | ' sample rate for audio.' 64 | ) 65 | # Normalize the waveform to -1, 1 float range. 66 | match type_info.kind: 67 | case 'i': 68 | waveform = waveform/np.iinfo(type_info).max 69 | case 'u': 70 | raise ValueError('Unsigned wav format is not supported.') 71 | case 'f': 72 | pass # already in -1, 1 float range. 73 | return waveform 74 | 75 | 76 | class MedASRPredictor: 77 | """Callable responsible for generating embeddings.""" 78 | 79 | def __init__(self, model_source: str, token: str | None = None): 80 | self._model_source = model_source 81 | self._token = token 82 | self._processor = None 83 | 84 | def predict( 85 | self, 86 | prediction_input: Mapping[str, Any], 87 | model: model_runner.ModelRunner, 88 | ) -> dict[str, Any]: 89 | """Runs inference on provided patches. 90 | 91 | Args: 92 | prediction_input: JSON formatted input for embedding prediction. 93 | model: ModelRunner to handle model step. 94 | 95 | Returns: 96 | JSON formatted output. 97 | """ 98 | 99 | if self._processor is None: 100 | self._processor = transformers.AutoProcessor.from_pretrained( 101 | self._model_source, 102 | token=self._token, 103 | ) 104 | 105 | # build response for each instance. 106 | try: 107 | audio = load_wav(prediction_input[AUDIO_INPUT_KEY]) 108 | cloud_logging_client.debug('Audio loaded.') 109 | processed = self._processor(audio, return_tensors='np') 110 | cloud_logging_client.debug('Model input processed.') 111 | model_input = { 112 | MODEL_INPUT_KEYS[key]: value for key, value in processed.items() 113 | } 114 | tokens = model.run_model(model_input, model_output_key='tokens__0') 115 | cloud_logging_client.debug('Model run completed.') 116 | transcript = self._processor.batch_decode(tokens)[0] 117 | cloud_logging_client.debug('Tokens decoded.') 118 | cloud_logging_client.info('Returning transcripts.') 119 | return{TEXT_TRANCRIPT_KEY: transcript} 120 | except ValueError as exp: 121 | cloud_logging_client.warning('Failed loading wav file.') 122 | return {'error': str(exp)} 123 | -------------------------------------------------------------------------------- /python/serving/vertex_schemata/request.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | title: MedASRRequest 16 | type: object 17 | description: > 18 | The schema for a single Speech-to-Text transcription request. Parameters are aligned 19 | with the core functionality of the OpenAI Transcription API: 20 | https://platform.openai.com/docs/api-reference/audio/createTranscription 21 | additionalProperties: false 22 | required: 23 | - file 24 | properties: 25 | file: 26 | type: string 27 | format: binary 28 | description: > 29 | The audio file object (not file name) to transcribe, in one of these formats: flac, 30 | mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. 31 | title: Audio File 32 | model: 33 | anyOf: 34 | - type: string 35 | - type: "null" 36 | description: > 37 | This field is included for compatibility with the OpenAI Transcription API. It is not 38 | currently used, because the model is implicitly selected based on the endpoint used. 39 | title: Model 40 | chunking_strategy: 41 | anyOf: 42 | - type: string 43 | enum: 44 | - auto 45 | - $ref: "#/components/schemas/ServerVAD" 46 | - type: "null" 47 | description: > 48 | Controls how the audio is cut into chunks. When set to "auto", the server 49 | first normalizes loudness and then uses voice activity detection (VAD) to 50 | choose boundaries. A server_vad object can be provided to tweak VAD detection 51 | parameters manually. If unset, the audio is transcribed as a single block. 52 | title: Chunking Strategy 53 | include: 54 | anyOf: 55 | - type: array 56 | items: 57 | type: string 58 | enum: 59 | - logprobs 60 | - type: "null" 61 | description: > 62 | Additional information to include in the transcription response. logprobs will 63 | return the log probabilities of the tokens in the response to understand the 64 | model's confidence in the transcription. 65 | title: Include 66 | language: 67 | anyOf: 68 | - type: string 69 | - type: "null" 70 | description: > 71 | The language of the input audio. Supplying the input language in ISO-639-1 72 | (e.g. en) format will improve accuracy and latency. 73 | title: Language 74 | prompt: 75 | anyOf: 76 | - type: string 77 | - type: "null" 78 | description: > 79 | An optional text to guide the model's style or continue a previous audio 80 | segment. The prompt should match the audio language. 81 | title: Prompt 82 | temperature: 83 | anyOf: 84 | - type: number 85 | minimum: 0 86 | maximum: 1 87 | - type: "null" 88 | default: 0 89 | description: > 90 | The sampling temperature, between 0 and 1. Higher values like 0.8 will make 91 | the output more random, while lower values like 0.2 will make it more 92 | focused and deterministic. 93 | title: Temperature 94 | 95 | components: 96 | schemas: 97 | ServerVAD: 98 | title: ServerVAD 99 | type: object 100 | description: > 101 | Object to tweak Voice Activity Detection (VAD) parameters manually. 102 | additionalProperties: false 103 | required: 104 | - type 105 | properties: 106 | type: 107 | type: string 108 | description: > 109 | Must be set to 'server_vad' to enable manual chunking using server side VAD. 110 | enum: 111 | - server_vad 112 | prefix_padding_ms: 113 | type: integer 114 | description: Amount of audio to include before the VAD detected speech (in milliseconds). 115 | default: 300 116 | silence_duration_ms: 117 | type: integer 118 | description: > 119 | Duration of silence to detect speech stop (in milliseconds). With shorter values the 120 | model will respond more quickly, but may jump in on short pauses from the user. 121 | default: 200 122 | threshold: 123 | type: number 124 | description: > 125 | Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher 126 | threshold will require louder audio to activate the model, and thus might perform 127 | better in noisy environments. 128 | default: 0.5 129 | minimum: 0.0 130 | maximum: 1.0 -------------------------------------------------------------------------------- /python/serving/serving_framework/model_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Abstract base class for dependency injection of model handling. 16 | 17 | Wraps execution of models on input tensors in an implementation-agnostic 18 | interface. Provides a mixin method for batching model execution. 19 | """ 20 | 21 | import abc 22 | from collections.abc import Mapping, Sequence, Set 23 | from typing import Any 24 | 25 | import numpy as np 26 | 27 | 28 | class ModelRunner(abc.ABC): 29 | """Runs a model with tensor inputs and outputs.""" 30 | 31 | @abc.abstractmethod 32 | def run_model_multiple_output( 33 | self, 34 | model_input: Mapping[str, np.ndarray] | np.ndarray, 35 | *, 36 | model_name: str = "default", 37 | model_version: int | None = None, 38 | model_output_keys: Set[str], 39 | parameters: Mapping[str, Any] | None = None, 40 | ) -> Mapping[str, np.ndarray]: 41 | """Runs a model on the given input and returns multiple outputs. 42 | 43 | Args: 44 | model_input: An array or map of arrays comprising the input tensors for 45 | the model. A bare array is given a default input key. 46 | model_name: The name of the model to run. 47 | model_version: The version of the model to run. Uses default if None. 48 | model_output_keys: The desired model output keys. 49 | parameters: Additional parameters to pass to the model. 50 | 51 | Returns: 52 | A mapping of model output keys to tensors. 53 | """ 54 | 55 | def run_model( 56 | self, 57 | model_input: Mapping[str, np.ndarray] | np.ndarray, 58 | *, 59 | model_name: str = "default", 60 | model_version: int | None = None, 61 | model_output_key: str = "output_0", 62 | parameters: Mapping[str, Any] | None = None, 63 | ) -> np.ndarray: 64 | """Runs a model on the given input. 65 | 66 | Args: 67 | model_input: An array or map of arrays comprising the input tensors for 68 | the model. A bare array is given a default input key. 69 | model_name: The name of the model to run. 70 | model_version: The version of the model to run. Uses default if None. 71 | model_output_key: The key to pull the output from. Defaults to "output_0". 72 | parameters: Additional parameters to pass to the model. 73 | 74 | Returns: 75 | The single output tensor. 76 | """ 77 | return self.run_model_multiple_output( 78 | model_input, 79 | model_name=model_name, 80 | model_version=model_version, 81 | model_output_keys={model_output_key}, 82 | parameters=parameters, 83 | )[model_output_key] 84 | 85 | def batch_model( 86 | self, 87 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray], 88 | *, 89 | model_name: str = "default", 90 | model_version: int | None = None, 91 | model_output_key: str = "output_0", 92 | parameters: Mapping[str, Any] | None = None, 93 | ) -> list[np.ndarray]: 94 | """Runs a model on each of the given inputs. 95 | 96 | Args: 97 | model_inputs: A sequence of arrays or maps of arrays comprising the input 98 | tensors for the model. Bare arrays are given a default input key. 99 | model_name: The name of the model to run. 100 | model_version: The version of the model to run. Uses default if None. 101 | model_output_key: The key to pull the output from. Defaults to "output_0". 102 | parameters: Additional parameters to pass to the model. 103 | 104 | Returns: 105 | A list of the single output tensor from each input. 106 | """ 107 | return [ 108 | self.run_model( 109 | model_input, 110 | model_name=model_name, 111 | model_version=model_version, 112 | model_output_key=model_output_key, 113 | parameters=parameters, 114 | ) 115 | for model_input in model_inputs 116 | ] 117 | 118 | def batch_model_multiple_output( 119 | self, 120 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray], 121 | *, 122 | model_name: str = "default", 123 | model_version: int | None = None, 124 | model_output_keys: Set[str], 125 | parameters: Mapping[str, Any] | None = None, 126 | ) -> list[Mapping[str, np.ndarray]]: 127 | """Runs a model on the given inputs and returns multiple outputs. 128 | 129 | Args: 130 | model_inputs: An array or map of arrays comprising the input tensors for 131 | the model. Bare arrays are given a default input key. 132 | model_name: The name of the model to run. 133 | model_version: The version of the model to run. Uses default if None. 134 | model_output_keys: The desired model output keys. 135 | parameters: Additional parameters to pass to the model. 136 | 137 | Returns: 138 | A list containing the mapping of model output keys to tensors from each 139 | input. 140 | """ 141 | return [ 142 | self.run_model_multiple_output( 143 | model_input, 144 | model_name=model_name, 145 | model_version=model_version, 146 | model_output_keys=model_output_keys, 147 | parameters=parameters, 148 | ) 149 | for model_input in model_inputs 150 | ] 151 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/triton_streaming_server_model_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ModelRunner implementation using streaming async grpc to Triton model server. 16 | 17 | Uses Triton GRPC aio client and relies on a model server running locally. 18 | """ 19 | 20 | import asyncio 21 | from collections.abc import Mapping, Set 22 | from typing import Any 23 | 24 | from absl import logging 25 | import numpy as np 26 | from tritonclient import grpc as triton_grpc 27 | from tritonclient import utils as triton_utils 28 | from tritonclient.grpc import aio as triton_aio 29 | from typing_extensions import override 30 | 31 | from serving.serving_framework import model_runner 32 | 33 | _HOSTPORT = "localhost:8500" 34 | 35 | 36 | async def _queue_request_generator(request_queue: asyncio.Queue): 37 | while True: 38 | request = await request_queue.get() 39 | if request is None: 40 | break 41 | yield request 42 | 43 | 44 | class TritonStreamingServerModelRunner(model_runner.ModelRunner): 45 | """ModelRunner implementation streaming grpc to NVIDIA Triton model server.""" 46 | 47 | def __init__(self): 48 | pass # TODO(bramsterling): set up persistent event loop and rpc connection. 49 | 50 | async def _call(self, model_name, inputs, model_version, parameters): 51 | async with triton_aio.InferenceServerClient(_HOSTPORT) as client: 52 | # Create a queue to hold the stream open after the request is sent. 53 | # Experimentation demonstrated that not doing this results in 54 | # asyncio.CancelledError due to server-side cancellation. 55 | request_queue = asyncio.Queue() 56 | request_queue.put_nowait({ 57 | "model_name": model_name, 58 | "inputs": inputs, 59 | "model_version": model_version, 60 | "parameters": parameters, 61 | }) 62 | 63 | result_iterator = client.stream_infer( 64 | _queue_request_generator(request_queue) 65 | ) 66 | 67 | extract = [] 68 | try: 69 | async for result, error in result_iterator: 70 | # Enqueue signal for the input stream to close 71 | request_queue.put_nowait(None) 72 | if error: 73 | raise error 74 | 75 | extract.append(result) 76 | 77 | # Since there is only one request, the loop only runs once. 78 | 79 | if not extract: 80 | raise RuntimeError( 81 | "Stream from model server closed without yielding a result." 82 | ) 83 | finally: 84 | # Ensure the input stream gets the close signal in any error case. 85 | # If None has already been enqueued, enqueuing it again has no 86 | # consequence. 87 | request_queue.put_nowait(None) 88 | 89 | return extract[0] 90 | 91 | @override 92 | def run_model_multiple_output( 93 | self, 94 | model_input: Mapping[str, np.ndarray] | np.ndarray, 95 | *, 96 | model_name: str = "default", 97 | model_version: int | None = None, 98 | model_output_keys: Set[str], 99 | parameters: Mapping[str, Any] | None = None, 100 | ) -> Mapping[str, np.ndarray]: 101 | """Runs a model on the given input and returns multiple outputs. 102 | 103 | Args: 104 | model_input: An array or map of arrays comprising the input tensors for 105 | the model. A bare array is keyed by "inputs". 106 | model_name: The name of the model to run. 107 | model_version: The version of the model to run. Uses default if None. 108 | model_output_keys: The desired model output keys. 109 | parameters: Additional parameters to pass to the model. 110 | 111 | Returns: 112 | A mapping of model output keys to tensors. 113 | 114 | Raises: 115 | KeyError: If any of the model_output_keys are not found in the model 116 | output. 117 | """ 118 | # If a bare np.ndarray was passed, it will be passed using the default 119 | # input key "inputs". 120 | # If a Mapping was passed, use the keys from that mapping. 121 | if isinstance(model_input, np.ndarray): 122 | logging.debug("Handling bare input tensor.") 123 | input_map = {"inputs": model_input} 124 | else: 125 | input_map = model_input 126 | 127 | model_version = str(model_version) if model_version is not None else "" 128 | 129 | inputs = [] 130 | for key, data in input_map.items(): 131 | input_tensor = triton_grpc.InferInput( 132 | key, data.shape, triton_utils.np_to_triton_dtype(data.dtype) 133 | ) 134 | input_tensor.set_data_from_numpy(data) 135 | inputs.append(input_tensor) 136 | 137 | try: 138 | result = asyncio.run( 139 | self._call(model_name, inputs, model_version, parameters) 140 | ) 141 | except asyncio.exceptions.CancelledError as e: 142 | raise RuntimeError( 143 | "Model server request was cancelled. This may be due to the server" 144 | " shutting down or an internal asyncio error." 145 | ) from e 146 | 147 | assert result is not None # infer never returns None, despite annotation. 148 | 149 | outputs = {key: result.as_numpy(key) for key in model_output_keys} 150 | missing_keys = {key for key in model_output_keys if outputs[key] is None} 151 | if missing_keys: 152 | raise KeyError( 153 | f"Model output keys {missing_keys} not found in model output." 154 | ) 155 | return outputs 156 | -------------------------------------------------------------------------------- /python/serving/serving_framework/server_gunicorn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import http 16 | import io 17 | import os 18 | import subprocess 19 | from unittest import mock 20 | 21 | from absl.testing import absltest 22 | from serving.serving_framework import server_gunicorn 23 | 24 | 25 | class DummyHealthCheck: 26 | 27 | def __init__(self, check_result: bool): 28 | self._check_result = check_result 29 | 30 | def check_health(self): 31 | return self._check_result 32 | 33 | 34 | class ServerGunicornTest(absltest.TestCase): 35 | 36 | def setUp(self): 37 | super().setUp() 38 | os.environ["AIP_PREDICT_ROUTE"] = "/fake-predict-route" 39 | os.environ["AIP_HEALTH_ROUTE"] = "/fake-health-route" 40 | 41 | def test_application_option_default(self): 42 | executor = mock.create_autospec( 43 | server_gunicorn.PredictionExecutor, 44 | instance=True, 45 | ) 46 | 47 | app = server_gunicorn.PredictionApplication(executor, health_check=None) 48 | 49 | self.assertEqual(app.cfg.workers, 1) 50 | 51 | def test_application_option_setting(self): 52 | options = { 53 | "workers": 3, 54 | } 55 | executor = mock.create_autospec( 56 | server_gunicorn.PredictionExecutor, 57 | instance=True, 58 | ) 59 | 60 | app = server_gunicorn.PredictionApplication( 61 | executor, options=options, health_check=None 62 | ) 63 | 64 | self.assertEqual(app.cfg.workers, 3) 65 | 66 | def test_health_route_no_check(self): 67 | 68 | executor = mock.create_autospec( 69 | server_gunicorn.PredictionExecutor, 70 | instance=True, 71 | ) 72 | 73 | app = server_gunicorn.PredictionApplication( 74 | executor, health_check=None 75 | ).load() 76 | service = app.test_client() 77 | 78 | response = service.get("/fake-health-route") 79 | 80 | self.assertEqual(response.status_code, http.HTTPStatus.OK) 81 | self.assertEqual(response.text, "ok") 82 | 83 | def test_predict_route_no_json(self): 84 | executor = mock.create_autospec( 85 | server_gunicorn.PredictionExecutor, 86 | instance=True, 87 | ) 88 | app = server_gunicorn.PredictionApplication( 89 | executor, health_check=None 90 | ).load() 91 | service = app.test_client() 92 | 93 | response = service.post("/fake-predict-route", data="invalid") 94 | 95 | executor.start.assert_called_once() 96 | executor.execute.assert_not_called() 97 | self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST) 98 | self.assertDictEqual({"error": "No JSON body."}, response.get_json()) 99 | 100 | def test_predict_route(self): 101 | executor = mock.create_autospec( 102 | server_gunicorn.PredictionExecutor, 103 | instance=True, 104 | ) 105 | app = server_gunicorn.PredictionApplication( 106 | executor, health_check=None 107 | ).load() 108 | service = app.test_client() 109 | executor.execute.return_value = {"placeholder": "output"} 110 | 111 | response = service.post( 112 | "/fake-predict-route", json={"instances": [{"filler": "filler"}]} 113 | ) 114 | 115 | executor.start.assert_called_once() 116 | executor.execute.assert_called_once_with( 117 | {"instances": [{"filler": "filler"}]} 118 | ) 119 | self.assertEqual(response.status_code, http.HTTPStatus.OK) 120 | self.assertDictEqual({"placeholder": "output"}, response.get_json()) 121 | 122 | def test_subprocess_executor_execute(self): 123 | mock_process = mock.create_autospec(subprocess.Popen, instance=True) 124 | with mock.patch.object( 125 | subprocess, "Popen", autospec=True, return_value=mock_process 126 | ) as mock_popen: 127 | executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"]) 128 | executor.start() 129 | mock_popen.assert_called_once_with( 130 | args=["fake_command"], 131 | stdout=subprocess.PIPE, 132 | stdin=subprocess.PIPE, 133 | ) 134 | mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n') 135 | mock_process.stdin = io.BytesIO() 136 | 137 | response = executor.execute({"meaningless": "filler"}) 138 | 139 | self.assertEqual( 140 | b'{"meaningless": "filler"}\n', mock_process.stdin.getvalue() 141 | ) 142 | self.assertDictEqual({"placeholder": "output"}, response) 143 | 144 | def test_subprocess_executor_execute_error_output_closed(self): 145 | mock_process = mock.create_autospec(subprocess.Popen, instance=True) 146 | with mock.patch.object( 147 | subprocess, "Popen", autospec=True, return_value=mock_process 148 | ) as mock_popen: 149 | executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"]) 150 | executor.start() 151 | 152 | mock_process.stdout = io.BytesIO() # empty output simulates closed pipe. 153 | mock_process.stdin = io.BytesIO() 154 | 155 | with self.assertRaises(RuntimeError) as raised: 156 | executor.execute({"meaningless": "filler"}) 157 | self.assertEqual( 158 | raised.exception.args[0], "Executor process output stream closed." 159 | ) 160 | self.assertEqual(mock_popen.call_count, 2) # executor restarted. 161 | 162 | def test_subprocess_executor_execute_error_input_broken(self): 163 | mock_process = mock.create_autospec(subprocess.Popen, instance=True) 164 | with mock.patch.object( 165 | subprocess, "Popen", autospec=True, return_value=mock_process 166 | ) as mock_popen: 167 | executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"]) 168 | executor.start() 169 | 170 | mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n') 171 | # Simulate broken pipe. 172 | mock_process.stdin = mock.create_autospec(io.BytesIO, instance=True) 173 | mock_process.stdin.write.side_effect = BrokenPipeError 174 | 175 | with self.assertRaises(RuntimeError): 176 | executor.execute({"meaningless": "filler"}) 177 | self.assertEqual(mock_popen.call_count, 2) # executor restarted. 178 | 179 | 180 | if __name__ == "__main__": 181 | absltest.main() 182 | -------------------------------------------------------------------------------- /python/serving/serving_framework/triton/triton_server_model_runner_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from unittest import mock 16 | 17 | import numpy as np 18 | from tritonclient import grpc as triton_grpc 19 | 20 | from absl.testing import absltest 21 | from serving.serving_framework.triton import triton_server_model_runner 22 | from tritonclient.grpc import service_pb2 23 | 24 | 25 | class TritonServerModelRunnerTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | 30 | self._client = mock.create_autospec( 31 | triton_grpc.InferenceServerClient, instance=True 32 | ) 33 | self._client.infer = mock.MagicMock() 34 | self._runner = triton_server_model_runner.TritonServerModelRunner( 35 | client=self._client 36 | ) 37 | 38 | self._output_proto = service_pb2.ModelInferResponse( 39 | model_name="test_model", 40 | model_version="1", 41 | outputs=[ 42 | service_pb2.ModelInferResponse.InferOutputTensor( 43 | name="output_a", 44 | shape=[2, 2], 45 | datatype="FP32", 46 | ), 47 | service_pb2.ModelInferResponse.InferOutputTensor( 48 | name="output_b", 49 | shape=[1, 2], 50 | datatype="INT64", 51 | ), 52 | service_pb2.ModelInferResponse.InferOutputTensor( 53 | name="output_c", 54 | shape=[3, 2], 55 | datatype="FP32", 56 | ), 57 | ], 58 | raw_output_contents=[ 59 | np.ones((2, 2), dtype=np.float32).tobytes(), 60 | np.array([[7] * 2], dtype=np.int64).tobytes(), 61 | np.zeros((3, 2), dtype=np.float32).tobytes(), 62 | ], 63 | ) 64 | 65 | def test_run_map_check_input(self): 66 | """Tests that an input map is passed to the model correctly.""" 67 | input_map = { 68 | "a": np.array([[0.5] * 3] * 3, dtype=np.float32), 69 | "b": np.array([[2] * 2] * 3, dtype=np.int64), 70 | } 71 | self._client.infer.return_value = triton_grpc.InferResult( 72 | self._output_proto 73 | ) 74 | 75 | _ = self._runner.run_model_multiple_output( 76 | input_map, 77 | model_name="test_model", 78 | model_output_keys={"output_a", "output_b", "output_c"}, 79 | ) 80 | 81 | self.assertLen(self._client.infer.call_args_list, 1) 82 | self.assertEqual( 83 | self._client.infer.call_args[0][0], 84 | "test_model", 85 | "Model name passed to model does not match expectation.", 86 | ) 87 | self.assertEqual( 88 | self._client.infer.call_args[0][2], 89 | "", 90 | "Model version passed to model does not match expectation.", 91 | ) 92 | self.assertEqual( 93 | self._client.infer.call_args[0][1][0]._get_tensor(), 94 | service_pb2.ModelInferRequest.InferInputTensor( 95 | name="a", 96 | shape=[3, 3], 97 | datatype="FP32", 98 | ), 99 | msg="Input tensor passed to model does not match expectation.", 100 | ) 101 | self.assertEqual( 102 | self._client.infer.call_args[0][1][1]._get_tensor(), 103 | service_pb2.ModelInferRequest.InferInputTensor( 104 | name="b", 105 | shape=[3, 2], 106 | datatype="INT64", 107 | ), 108 | msg="Input tensor passed to model does not match expectation.", 109 | ) 110 | 111 | def test_run_bare_check_input(self): 112 | """Tests the handling of a bare input tensor passed to the model.""" 113 | input_tensor = np.array([[0.5] * 3] * 3, dtype=np.float32) 114 | self._client.infer.return_value = triton_grpc.InferResult( 115 | self._output_proto 116 | ) 117 | 118 | _ = self._runner.run_model_multiple_output( 119 | input_tensor, 120 | model_name="test_model", 121 | model_output_keys={"output_a", "output_b", "output_c"}, 122 | ) 123 | 124 | self.assertLen(self._client.infer.call_args_list, 1) 125 | self.assertEqual( 126 | self._client.infer.call_args[0][0], 127 | "test_model", 128 | "Model name passed to model does not match expectation.", 129 | ) 130 | self.assertEqual( 131 | self._client.infer.call_args[0][2], 132 | "", 133 | "Model version passed to model does not match expectation.", 134 | ) 135 | self.assertEqual( 136 | self._client.infer.call_args[0][1][0]._get_tensor(), 137 | service_pb2.ModelInferRequest.InferInputTensor( 138 | name="inputs", 139 | shape=[3, 3], 140 | datatype="FP32", 141 | ), 142 | msg="Input tensor passed to model does not match expectation.", 143 | ) 144 | 145 | def test_input_model_version(self): 146 | """Tests that the model version is passed to the model correctly.""" 147 | input_tensor = np.array([[0.5] * 3] * 3, dtype=np.float32) 148 | self._client.infer.return_value = triton_grpc.InferResult( 149 | self._output_proto 150 | ) 151 | 152 | _ = self._runner.run_model_multiple_output( 153 | input_tensor, 154 | model_name="test_model", 155 | model_version=3, 156 | model_output_keys={"output_a", "output_b", "output_c"}, 157 | ) 158 | 159 | self.assertEqual( 160 | self._client.infer.call_args[0][2], 161 | "3", 162 | "Model version passed to model does not match expectation.", 163 | ) 164 | 165 | def test_run_check_output(self): 166 | """Tests that the output is returned correctly.""" 167 | input_map = { 168 | "a": np.array([[0.5] * 3] * 3, dtype=np.float32), 169 | "b": np.array([[2] * 2] * 3, dtype=np.int64), 170 | } 171 | self._client.infer.return_value = triton_grpc.InferResult( 172 | self._output_proto 173 | ) 174 | 175 | result = self._runner.run_model_multiple_output( 176 | input_map, 177 | model_name="test_model", 178 | model_output_keys={"output_a", "output_b"}, 179 | ) 180 | 181 | np.testing.assert_array_equal( 182 | result["output_a"], 183 | np.ones((2, 2), dtype=np.float32), 184 | "Output tensor passed to model does not match expectation.", 185 | ) 186 | np.testing.assert_array_equal( 187 | result["output_b"], 188 | np.array([[7] * 2], dtype=np.int64), 189 | "Output tensor passed to model does not match expectation.", 190 | ) 191 | self.assertLen(result, 2) 192 | 193 | def test_run_check_missing_key(self): 194 | """Tests that the output is returned correctly.""" 195 | input_map = { 196 | "a": np.array([[0.5] * 3] * 3, dtype=np.float32), 197 | "b": np.array([[2] * 2] * 3, dtype=np.int64), 198 | } 199 | self._client.infer.return_value = triton_grpc.InferResult( 200 | self._output_proto 201 | ) 202 | 203 | with self.assertRaises(KeyError): 204 | _ = self._runner.run_model_multiple_output( 205 | input_map, 206 | model_name="test_model", 207 | model_output_keys={"output_a", "output_d"}, 208 | ) 209 | 210 | if __name__ == "__main__": 211 | absltest.main() 212 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/secret_flag_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Util to get env value from secret manager.""" 16 | 17 | import json 18 | import os 19 | import re 20 | import sys 21 | import threading 22 | from typing import Any, Mapping, TypeVar, Union 23 | 24 | import cachetools 25 | import google.api_core 26 | from google.cloud import secretmanager 27 | 28 | from serving.logging_lib.flags import flag_utils 29 | 30 | _T = TypeVar('_T') 31 | 32 | _SECRET_MANAGER_ENV_CONFIG = 'SECRET_MANAGER_ENV_CONFIG' 33 | 34 | _PARSE_SECRET_CONFIG = re.compile( 35 | r'.*?projects/([^/]+?)/secrets/([^/]+?)($|(/|(/versions/([^/]+))))', 36 | re.IGNORECASE, 37 | ) 38 | 39 | # Enable secret manager if not debugging. 40 | _ENABLE_ENV_SECRET_MANAGER = bool( 41 | 'UNITTEST_ON_FORGE' not in os.environ and 'unittest' not in sys.modules 42 | ) 43 | 44 | 45 | # Cache secret metadata to avoid repeated reads when initalizing flags. 46 | _cache = cachetools.LRUCache(maxsize=1) 47 | _cache_lock = threading.Lock() 48 | 49 | 50 | class SecretDecodeError(Exception): 51 | 52 | def __init__(self, msg: str, secret_name: str = '', data: str = ''): 53 | super().__init__(msg) 54 | self._secret_name = secret_name 55 | self._data = data 56 | 57 | 58 | def _init_fork_module_state() -> None: 59 | global _cache 60 | global _cache_lock 61 | _cache = cachetools.LRUCache(maxsize=1) 62 | _cache_lock = threading.Lock() 63 | 64 | 65 | def _get_secret_version( 66 | client: secretmanager.SecretManagerServiceClient, parent: str 67 | ) -> str: 68 | """Returns the greatest version number for a secret. 69 | 70 | Args: 71 | client: SecretManagerServiceClient. 72 | parent: String defining project and name of secret to return version. 73 | 74 | Returns: 75 | greatest version number for secret. 76 | 77 | Raises: 78 | SecretDecodeError: Could not identify version for secret. 79 | """ 80 | version_list = client.list_secret_versions(request={'parent': parent}) 81 | versions_found = [] 82 | for name in [ver.name for ver in version_list]: 83 | match = _PARSE_SECRET_CONFIG.fullmatch(name) 84 | if match is None: 85 | continue 86 | try: 87 | versions_found.append(int(match.groups()[-1])) 88 | except ValueError: 89 | continue 90 | if not versions_found: 91 | raise SecretDecodeError( 92 | f'Could not find version for secret {parent}.', secret_name=parent 93 | ) 94 | return str(max(versions_found)) 95 | 96 | 97 | def _read_secrets(secret_name: str) -> Mapping[str, Any]: 98 | """Returns secret from secret manager. 99 | 100 | Args: 101 | secret_name: Name of secret. 102 | 103 | Returns: 104 | Secret value. 105 | 106 | Raises: 107 | SecretDecodeError: Error retrieving value from secret manager. 108 | """ 109 | if not secret_name: 110 | return {} 111 | match = _PARSE_SECRET_CONFIG.fullmatch(secret_name) 112 | if match is None: 113 | raise SecretDecodeError( 114 | 'incorrectly formatted secret; expecting' 115 | f' [projects/.+/secrets/.+/versions/.+; passed {secret_name}.', 116 | secret_name=secret_name, 117 | ) 118 | project, secret, *_, version = match.groups() 119 | if not _ENABLE_ENV_SECRET_MANAGER: 120 | return {} 121 | with _cache_lock: 122 | cached_val = _cache.get(secret_name) 123 | if cached_val is not None: 124 | return cached_val 125 | with secretmanager.SecretManagerServiceClient() as client: 126 | parent = client.secret_path(project, secret) 127 | try: 128 | if version is None or not version: 129 | version = _get_secret_version(client, parent) 130 | secret = client.access_secret_version( 131 | request={'name': f'{parent}/versions/{version}'} 132 | ) 133 | except google.api_core.exceptions.NotFound as exp: 134 | raise SecretDecodeError( 135 | 'Secret not found.', secret_name=secret_name 136 | ) from exp 137 | except google.api_core.exceptions.PermissionDenied as exp: 138 | raise SecretDecodeError( 139 | 'Permission denied reading secret.', secret_name=secret_name 140 | ) from exp 141 | data = secret.payload.data 142 | if data is None or not data: 143 | return {} 144 | if isinstance(data, bytes): 145 | data = data.decode('utf-8') 146 | try: 147 | value = json.loads(data) 148 | except json.JSONDecodeError as exp: 149 | raise SecretDecodeError( 150 | 'Could not decode secret value.', secret_name=secret_name, data=data 151 | ) from exp 152 | if not isinstance(value, Mapping): 153 | raise SecretDecodeError( 154 | 'Secret value does not define a mapping.', 155 | secret_name=secret_name, 156 | data=data, 157 | ) 158 | _cache[secret_name] = value 159 | return value 160 | 161 | 162 | def get_secret_or_env(name: str, default: _T) -> Union[str, _T]: 163 | """Returns value defined in secret manager, env, or if undefined default. 164 | 165 | Searchs first for variable definition in JSON dict stored within the GCP 166 | secret managner. The GCP secret containing the dict is defined by the 167 | _SECRET_MANAGER_ENV_CONFIG. If the variable is not found in the GCP encoded 168 | secret, or if the _SECRET_MANAGER_ENV_CONFIG is undefined then the container 169 | ENV are searched. If neither define the variable the default value is 170 | returned. 171 | 172 | Args: 173 | name: Name of ENV. 174 | default: Default value to return if name is not defined. 175 | 176 | Returns: 177 | ENV value. 178 | 179 | Raises: 180 | SecretDecodeError: Error retrieving value from secret manager. 181 | """ 182 | secret_name = os.environ.get(_SECRET_MANAGER_ENV_CONFIG) 183 | if secret_name is not None and secret_name: 184 | secret_env = _read_secrets(secret_name) 185 | result = secret_env.get(name) 186 | if result is not None: 187 | return str(result) 188 | return os.environ.get(name, default) 189 | 190 | 191 | def get_bool_secret_or_env( 192 | env_name: str, undefined_value: bool = False 193 | ) -> bool: 194 | """Returns bool variable value into boolean value. 195 | 196 | Args: 197 | env_name: Environmental variable name. 198 | undefined_value: Default value to set undefined values to. 199 | 200 | Returns: 201 | Boolean of environmental variable string value. 202 | 203 | Raises: 204 | SecretDecodeError: Error retrieving value from secret manager. 205 | ValueError: Environmental variable cannot be parsed to bool. 206 | """ 207 | value = get_secret_or_env(env_name, str(undefined_value)) 208 | if value is not None: 209 | return flag_utils.str_to_bool(value) 210 | 211 | 212 | # Interfaces may be used from processes which are forked (gunicorn, 213 | # DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not 214 | # copy threads running within parent processes or re-initalize global/module 215 | # state. This can result in forked modules being executed with invalid global 216 | # state, e.g., acquired locks that will not release or references to invalid 217 | # state. 218 | _init_fork_module_state() 219 | os.register_at_fork(after_in_child=_init_fork_module_state) 220 | -------------------------------------------------------------------------------- /python/serving/serving_framework/server_gunicorn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gunicorn application for passing requests through to the executor command. 16 | 17 | Provides a thin, subject-agnostic request server for Vertex endpoints which 18 | handles requests by piping their JSON bodies to the given executor command 19 | and returning the json output. 20 | """ 21 | 22 | import abc 23 | from collections.abc import Mapping, Sequence 24 | import http 25 | import json 26 | import os 27 | import subprocess 28 | from typing import Any, Optional 29 | 30 | from absl import logging 31 | import flask 32 | from gunicorn.app import base as gunicorn_base 33 | import jsonschema 34 | from typing_extensions import override 35 | 36 | 37 | INSTANCES_KEY = "instances" 38 | 39 | 40 | class PredictionExecutor(abc.ABC): 41 | """Wraps arbitrary implementation of executing a prediction request.""" 42 | 43 | @abc.abstractmethod 44 | def execute(self, input_json: dict[str, Any]) -> dict[str, Any]: 45 | """Executes the given request payload.""" 46 | 47 | def start(self) -> None: 48 | """Starts the executor. 49 | 50 | Called after the Gunicorn worker process has started. Performs any setup 51 | which needs to be done post-fork. 52 | """ 53 | 54 | 55 | class SubprocessPredictionExecutor(PredictionExecutor): 56 | """Provides prediction request execution using a persistent worker subprocess.""" 57 | 58 | def __init__(self, executor_command: Sequence[str]): 59 | """Initializes the executor with a command to start the subprocess.""" 60 | self._executor_command = executor_command 61 | self._executor_process = None 62 | 63 | def _restart(self) -> None: 64 | if self._executor_process is None: 65 | raise RuntimeError("Executor process not started.") 66 | 67 | self._executor_process.terminate() 68 | self.start() 69 | 70 | @override 71 | def start(self): 72 | """Starts the executor process.""" 73 | self._executor_process = subprocess.Popen( 74 | args=self._executor_command, 75 | stdout=subprocess.PIPE, 76 | stdin=subprocess.PIPE, 77 | ) 78 | 79 | @override 80 | def execute(self, input_json: dict[str, Any]) -> dict[str, Any]: 81 | """Uses the executor process to execute a request. 82 | 83 | Args: 84 | input_json: The full json prediction request payload. 85 | 86 | Returns: 87 | The json response to the prediction request. 88 | 89 | Raises: 90 | RuntimeError: Executor is not started or error communicating with the 91 | subprocess. 92 | """ 93 | if self._executor_process is None: 94 | raise RuntimeError("Executor process not started.") 95 | 96 | # Ensure json string is safe to pass through the pipe protocol. 97 | input_str = json.dumps(input_json).replace("\n", "") 98 | 99 | try: 100 | self._executor_process.stdin.write(input_str.encode("utf-8") + b"\n") 101 | self._executor_process.stdin.flush() 102 | except BrokenPipeError as e: 103 | self._restart() 104 | raise RuntimeError("Executor process input stream closed.") from e 105 | exec_result = self._executor_process.stdout.readline() 106 | if not exec_result: 107 | self._restart() 108 | raise RuntimeError("Executor process output stream closed.") 109 | try: 110 | return json.loads(exec_result) 111 | except json.JSONDecodeError as e: 112 | raise RuntimeError("Executor process output not valid json.") from e 113 | 114 | 115 | class ModelServerHealthCheck(abc.ABC): 116 | """Checks the health of the local model server.""" 117 | 118 | @abc.abstractmethod 119 | def check_health(self) -> bool: 120 | """Check the health of the local model server immediately.""" 121 | 122 | 123 | class _PredictRoute: 124 | """A callable for handling a prediction route.""" 125 | 126 | def __init__( 127 | self, 128 | executor: PredictionExecutor, 129 | *, 130 | input_validator: jsonschema.Draft202012Validator | None = None, 131 | instance_input: bool = True, 132 | prediction_validator: jsonschema.Draft202012Validator | None = None, 133 | ): 134 | self._executor = executor 135 | self._input_validator = input_validator 136 | self._instance_input = instance_input 137 | self._prediction_validator = prediction_validator 138 | 139 | def __call__(self) -> tuple[dict[str, Any], int]: 140 | logging.info("predict route hit") 141 | json_body = flask.request.get_json(silent=True) 142 | if json_body is None: 143 | return {"error": "No JSON body."}, http.HTTPStatus.BAD_REQUEST.value 144 | if self._instance_input and INSTANCES_KEY not in json_body: 145 | return { 146 | "error": "No instances field in request." 147 | }, http.HTTPStatus.BAD_REQUEST.value 148 | 149 | if self._input_validator is not None: 150 | try: 151 | if self._instance_input: 152 | for instance in json_body[INSTANCES_KEY]: 153 | self._input_validator.validate(instance) 154 | else: 155 | self._input_validator.validate(json_body) 156 | except jsonschema.exceptions.ValidationError as e: 157 | logging.warning("Input validation failed") 158 | return {"error": str(e)}, http.HTTPStatus.BAD_REQUEST.value 159 | 160 | logging.debug("Dispatching request to executor.") 161 | try: 162 | exec_result = self._executor.execute(json_body) 163 | logging.debug("Executor returned results.") 164 | if self._prediction_validator is not None: 165 | try: 166 | if "predictions" in exec_result: 167 | for result in exec_result["predictions"]: 168 | self._prediction_validator.validate(result) 169 | except jsonschema.exceptions.ValidationError as e: 170 | logging.exception("Response validation failed") 171 | return { 172 | "error": "Internal server error." 173 | }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value 174 | 175 | return (exec_result, http.HTTPStatus.OK.value) 176 | except RuntimeError: 177 | logging.exception("Internal error handling request: Executor failed.") 178 | return { 179 | "error": "Internal server error." 180 | }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value 181 | 182 | 183 | def _create_app( 184 | executor: PredictionExecutor, 185 | health_check: ModelServerHealthCheck | None, 186 | *, 187 | input_validator: jsonschema.Draft202012Validator | None = None, 188 | instance_input: bool = True, 189 | prediction_validator: jsonschema.Draft202012Validator | None = None, 190 | additional_routes: Mapping[str, PredictionExecutor] | None = None, 191 | ) -> flask.Flask: 192 | """Creates a Flask app with the given executor.""" 193 | flask_app = flask.Flask(__name__) 194 | 195 | if ( 196 | "AIP_HEALTH_ROUTE" not in os.environ 197 | or "AIP_PREDICT_ROUTE" not in os.environ 198 | ): 199 | raise ValueError( 200 | "Both of the environment variables AIP_HEALTH_ROUTE and " 201 | "AIP_PREDICT_ROUTE need to be specified." 202 | ) 203 | 204 | def health_route() -> tuple[str, int]: 205 | logging.info("health route hit") 206 | if health_check is not None and not health_check.check_health(): 207 | return "not ok", http.HTTPStatus.SERVICE_UNAVAILABLE.value 208 | return "ok", http.HTTPStatus.OK.value 209 | 210 | health_path = os.environ.get("AIP_HEALTH_ROUTE") 211 | logging.info("health path: %s", health_path) 212 | flask_app.add_url_rule(health_path, view_func=health_route) 213 | 214 | predict_route = os.environ.get("AIP_PREDICT_ROUTE") 215 | logging.info("predict route: %s", predict_route) 216 | flask_app.add_url_rule( 217 | rule=predict_route, 218 | endpoint=predict_route, 219 | view_func=_PredictRoute( 220 | executor, 221 | input_validator=input_validator, 222 | instance_input=instance_input, 223 | prediction_validator=prediction_validator, 224 | ), 225 | methods=["POST"], 226 | ) 227 | 228 | if additional_routes: 229 | for route, executor in additional_routes.items(): 230 | logging.info("additional route: %s", route) 231 | flask_app.add_url_rule( 232 | rule=route, 233 | endpoint=route, 234 | view_func=_PredictRoute( 235 | executor, 236 | instance_input=False, 237 | ), 238 | methods=["POST"], 239 | ) 240 | 241 | flask_app.config["TRAP_BAD_REQUEST_ERRORS"] = True 242 | 243 | return flask_app 244 | 245 | 246 | class PredictionApplication(gunicorn_base.BaseApplication): 247 | """Application to serve predictors on Vertex endpoints using gunicorn.""" 248 | 249 | def __init__( 250 | self, 251 | executor: PredictionExecutor, 252 | *, 253 | health_check: ModelServerHealthCheck | None, 254 | options: Optional[Mapping[str, Any]] = None, 255 | input_validator: jsonschema.Draft202012Validator | None = None, 256 | instance_input: bool = True, 257 | prediction_validator: jsonschema.Draft202012Validator | None = None, 258 | additional_routes: Mapping[str, PredictionExecutor] | None = None, 259 | ): 260 | """Initializes the application. 261 | 262 | Args: 263 | executor: The executor to use for handling prediction requests. 264 | health_check: The health check to use for the health check route. 265 | options: Gunicorn application options. 266 | input_validator: A jsonschema validator to apply to the input instances. 267 | instance_input: Whether the input will contain a list of instances which 268 | can be validated individually if a validator is provided. If False, the 269 | request object structure will not be assumed and the validator will 270 | be applied to the entire request body. 271 | prediction_validator: A jsonschema validator to apply to the predictions. 272 | additional_routes: Additional routes to serve on the server, keyed by 273 | route path. Validation is not applied to these routes. 274 | """ 275 | self.options = options or {} 276 | self.options = dict(self.options) 277 | self.options["preload_app"] = False 278 | self._executor = executor 279 | self.application = _create_app( 280 | self._executor, 281 | health_check, 282 | input_validator=input_validator, 283 | instance_input=instance_input, 284 | prediction_validator=prediction_validator, 285 | additional_routes=additional_routes, 286 | ) 287 | super().__init__() 288 | 289 | def load_config(self): 290 | config = { 291 | key: value 292 | for key, value in self.options.items() 293 | if key in self.cfg.settings and value is not None 294 | } 295 | for key, value in config.items(): 296 | self.cfg.set(key.lower(), value) 297 | 298 | def load(self) -> flask.Flask: 299 | self._executor.start() 300 | return self.application 301 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /python/serving/logging_lib/cloud_logging_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Wrapper for cloud ops structured logging.""" 16 | from __future__ import annotations 17 | 18 | import os 19 | import sys 20 | import threading 21 | from typing import Any, Mapping, Optional, Union 22 | 23 | from absl import flags 24 | from absl import logging 25 | import google.auth 26 | import psutil 27 | 28 | from serving.logging_lib.flags import secret_flag_utils 29 | from serving.logging_lib import cloud_logging_client_instance 30 | 31 | # name of cloud ops log 32 | CLOUD_OPS_LOG_NAME_FLG = flags.DEFINE_string( 33 | 'ops_log_name', 34 | secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_NAME', 'python'), 35 | 'Cloud ops log name to write logs to.', 36 | ) 37 | CLOUD_OPS_LOG_PROJECT_FLG = flags.DEFINE_string( 38 | 'ops_log_project', 39 | secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_PROJECT', None), 40 | 'GCP project name to write log to. Undefined = default', 41 | ) 42 | POD_HOSTNAME_FLG = flags.DEFINE_string( 43 | 'pod_hostname', 44 | secret_flag_utils.get_secret_or_env('HOSTNAME', None), 45 | 'Host name of GKE pod. Set by container ENV. ' 46 | 'Set to mock value in unit test.', 47 | ) 48 | POD_UID_FLG = flags.DEFINE_string( 49 | 'pod_uid', 50 | secret_flag_utils.get_secret_or_env('MY_POD_UID', None), 51 | 'UID of GKE pod. Do not set unless in test.', 52 | ) 53 | ENABLE_STRUCTURED_LOGGING_FLG = flags.DEFINE_boolean( 54 | 'enable_structured_logging', 55 | secret_flag_utils.get_bool_secret_or_env( 56 | 'ENABLE_STRUCTURED_LOGGING', 57 | not secret_flag_utils.get_bool_secret_or_env( 58 | 'DISABLE_STRUCTURED_LOGGING' 59 | ), 60 | ), 61 | 'Enable structured logging.', 62 | ) 63 | ENABLE_LOGGING_FLG = flags.DEFINE_boolean( 64 | 'enable_logging', 65 | secret_flag_utils.get_bool_secret_or_env('ENABLE_LOGGING_FLG', True), 66 | 'Enable logging.', 67 | ) 68 | 69 | _DEBUG_LOGGING_USE_ABSL_LOGGING_FLG = flags.DEFINE_boolean( 70 | 'debug_logging_use_absl_logging', 71 | # Confusing double negative is used to enable external env be a postive 72 | # statement and override the existing default 73 | not secret_flag_utils.get_bool_secret_or_env( 74 | 'ENABLE_CLOUD_LOGGING', 75 | not cloud_logging_client_instance.DEBUG_LOGGING_USE_ABSL_LOGGING, 76 | ), 77 | 'Debug/testing option to logs to absl.logger. Automatically set when ' 78 | 'running unit tests.', 79 | ) 80 | 81 | LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG = flags.DEFINE_boolean( 82 | 'log_all_python_logs_to_cloud', 83 | secret_flag_utils.get_bool_secret_or_env('LOG_ALL_PYTHON_LOGS_TO_CLOUD'), 84 | 'Logs every modules log to Cloud Ops.', 85 | ) 86 | 87 | PER_THREAD_LOG_SIGNATURES_FLG = flags.DEFINE_boolean( 88 | 'per_thread_log_signatures', 89 | secret_flag_utils.get_bool_secret_or_env('PER_THREAD_LOG_SIGNATURES', True), 90 | 'If True Log signatures are not shared are across threads if false ' 91 | 'Process threads share a common log signature', 92 | ) 93 | 94 | 95 | def _are_flags_initialized() -> bool: 96 | """Returns True if flags are initialized.""" 97 | try: 98 | _ = CLOUD_OPS_LOG_PROJECT_FLG.value 99 | return True 100 | except (flags.UnparsedFlagAccessError, AttributeError): 101 | return False 102 | 103 | 104 | def _get_flags() -> Mapping[str, str]: 105 | load_flags = {} 106 | unparsed_flags = [] 107 | for flag_name in flags.FLAGS: 108 | try: 109 | load_flags[flag_name] = flags.FLAGS.__getattr__(flag_name) 110 | except flags.UnparsedFlagAccessError: 111 | unparsed_flags.append(flag_name) 112 | if unparsed_flags: 113 | load_flags['unparsed_flags'] = ', '.join(unparsed_flags) 114 | return load_flags 115 | 116 | 117 | def _default_gcp_project() -> str: 118 | try: 119 | _, project = google.auth.default( 120 | scopes=['https://www.googleapis.com/auth/cloud-platform'] 121 | ) 122 | return project if project is not None else '' 123 | except google.auth.exceptions.DefaultCredentialsError: 124 | return '' 125 | 126 | 127 | class CloudLoggingClient( 128 | cloud_logging_client_instance.CloudLoggingClientInstance 129 | ): 130 | """Wrapper for cloud ops structured logging. 131 | 132 | Automatically adds signature to structured logs to make traceable. 133 | """ 134 | 135 | # lock for log makes access to singleton 136 | # safe across threads. Logging used in main thread and ack_timeout_mon 137 | _singleton_instance: Optional[CloudLoggingClient] = None 138 | _startup_message_logged = False 139 | _singleton_lock = threading.RLock() 140 | 141 | @classmethod 142 | def _init_fork_module_state(cls) -> None: 143 | cls._singleton_instance = None 144 | cls._startup_message_logged = True 145 | cls._singleton_lock = threading.RLock() 146 | 147 | @classmethod 148 | def _fork_shutdown(cls) -> None: 149 | with cls._singleton_lock: 150 | cls._singleton_instance = None 151 | 152 | @classmethod 153 | def _set_absl_skip_frames(cls) -> None: 154 | """Sets absl logging attribution to skip over internal logging frames.""" 155 | logging.ABSLLogger.register_frame_to_skip( 156 | __file__, 157 | function_name='debug', 158 | ) 159 | logging.ABSLLogger.register_frame_to_skip( 160 | __file__, 161 | function_name='timed_debug', 162 | ) 163 | logging.ABSLLogger.register_frame_to_skip( 164 | __file__, 165 | function_name='info', 166 | ) 167 | logging.ABSLLogger.register_frame_to_skip( 168 | __file__, 169 | function_name='warning', 170 | ) 171 | logging.ABSLLogger.register_frame_to_skip( 172 | __file__, 173 | function_name='error', 174 | ) 175 | logging.ABSLLogger.register_frame_to_skip( 176 | __file__, 177 | function_name='critical', 178 | ) 179 | 180 | def __init__(self): 181 | with CloudLoggingClient._singleton_lock: 182 | if not _are_flags_initialized(): 183 | # if flags are not initialize then init logging flags 184 | flags.FLAGS(sys.argv, known_only=True) 185 | if CloudLoggingClient._singleton_instance is not None: 186 | raise cloud_logging_client_instance.CloudLoggerInstanceExceptionError( 187 | 'Singleton already initialized.' 188 | ) 189 | CloudLoggingClient._set_absl_skip_frames() 190 | gcp_project = CLOUD_OPS_LOG_PROJECT_FLG.value 191 | if gcp_project is None or not gcp_project.strip(): 192 | gcp_project = _default_gcp_project() 193 | pod_host_name = ( 194 | '' if POD_HOSTNAME_FLG.value is None else POD_HOSTNAME_FLG.value 195 | ) 196 | pod_uid = '' if POD_UID_FLG.value is None else POD_UID_FLG.value 197 | super().__init__( 198 | log_name=CLOUD_OPS_LOG_NAME_FLG.value, 199 | gcp_project_to_write_logs_to=gcp_project, 200 | gcp_credentials=None, 201 | pod_hostname=pod_host_name, 202 | pod_uid=pod_uid, 203 | enable_structured_logging=ENABLE_STRUCTURED_LOGGING_FLG.value, 204 | use_absl_logging=_DEBUG_LOGGING_USE_ABSL_LOGGING_FLG.value, 205 | log_all_python_logs_to_cloud=LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG.value, 206 | per_thread_log_signatures=PER_THREAD_LOG_SIGNATURES_FLG.value, 207 | enabled=ENABLE_LOGGING_FLG.value, 208 | ) 209 | CloudLoggingClient._singleton_instance = self 210 | 211 | def startup_msg(self) -> None: 212 | """Logs default messages after logger fully initialized.""" 213 | if self.use_absl_logging() or CloudLoggingClient._startup_message_logged: 214 | return 215 | CloudLoggingClient._startup_message_logged = True 216 | pid = os.getpid() 217 | process_name = psutil.Process(pid).name() 218 | self.debug( 219 | 'Container process started.', 220 | {'process_name': process_name, 'process_id': pid}, 221 | ) 222 | self.debug( 223 | 'Container environmental variables.', os.environ 224 | ) # pytype: disable=wrong-arg-types # kwargs-checking 225 | vm = psutil.virtual_memory() 226 | self.debug( 227 | 'Compute instance', 228 | { 229 | 'processors(count)': os.cpu_count(), 230 | 'total_system_mem_(bytes)': vm.total, 231 | 'available_system_mem_(bytes)': vm.available, 232 | }, 233 | ) 234 | self.debug('Initalized flags', _get_flags()) 235 | project_name = self.gcp_project_name if self.gcp_project_name else 'DEFAULT' 236 | self.debug(f'Logging to GCP project: {project_name}') 237 | 238 | @classmethod 239 | def logger(cls, show_startup_msg: bool = True) -> CloudLoggingClient: 240 | if cls._singleton_instance is None: 241 | with cls._singleton_lock: # makes instance creation thread safe. 242 | if cls._singleton_instance is None: 243 | cls._singleton_instance = CloudLoggingClient() 244 | if not show_startup_msg: 245 | cls._startup_message_logged = True 246 | else: 247 | cls._singleton_instance.startup_msg() # pytype: disable=attribute-error 248 | return cls._singleton_instance # pytype: disable=bad-return-type 249 | 250 | 251 | def logger() -> CloudLoggingClient: 252 | return CloudLoggingClient.logger() 253 | 254 | 255 | def do_not_log_startup_msg() -> None: 256 | CloudLoggingClient.logger(show_startup_msg=False) 257 | 258 | 259 | def debug( 260 | msg: str, 261 | *struct: Union[Mapping[str, Any], Exception, None], 262 | stack_frames_back: int = 0, 263 | ) -> None: 264 | """Logs with debug severity. 265 | 266 | Args: 267 | msg: message to log (string). 268 | *struct: zero or more dict or exception to log in structured log. 269 | stack_frames_back: Additional stack frames back to log source_location. 270 | """ 271 | logger().debug(msg, *struct, stack_frames_back=stack_frames_back + 1) 272 | 273 | 274 | def timed_debug( 275 | msg: str, 276 | *struct: Union[Mapping[str, Any], Exception, None], 277 | stack_frames_back: int = 0, 278 | ) -> None: 279 | """Logs with debug severity and elapsed time since last timed debug log. 280 | 281 | Args: 282 | msg: message to log (string). 283 | *struct: zero or more dict or exception to log in structured log. 284 | stack_frames_back: Additional stack frames back to log source_location. 285 | """ 286 | logger().timed_debug(msg, *struct, stack_frames_back=stack_frames_back + 1) 287 | 288 | 289 | def info( 290 | msg: str, 291 | *struct: Union[Mapping[str, Any], Exception, None], 292 | stack_frames_back: int = 0, 293 | ) -> None: 294 | """Logs with info severity. 295 | 296 | Args: 297 | msg: message to log (string). 298 | *struct: zero or more dict or exception to log in structured log. 299 | stack_frames_back: Additional stack frames back to log source_location. 300 | """ 301 | logger().info(msg, *struct, stack_frames_back=stack_frames_back + 1) 302 | 303 | 304 | def warning( 305 | msg: str, 306 | *struct: Union[Mapping[str, Any], Exception, None], 307 | stack_frames_back: int = 0, 308 | ) -> None: 309 | """Logs with warning severity. 310 | 311 | Args: 312 | msg: Message to log (string). 313 | *struct: Zero or more dict or exception to log in structured log. 314 | stack_frames_back: Additional stack frames back to log source_location. 315 | """ 316 | logger().warning(msg, *struct, stack_frames_back=stack_frames_back + 1) 317 | 318 | 319 | def error( 320 | msg: str, 321 | *struct: Union[Mapping[str, Any], Exception, None], 322 | stack_frames_back: int = 0, 323 | ) -> None: 324 | """Logs with error severity. 325 | 326 | Args: 327 | msg: Message to log (string). 328 | *struct: Zero or more dict or exception to log in structured log. 329 | stack_frames_back: Additional stack frames back to log source_location. 330 | """ 331 | logger().error(msg, *struct, stack_frames_back=stack_frames_back + 1) 332 | 333 | 334 | def critical( 335 | msg: str, 336 | *struct: Union[Mapping[str, Any], Exception, None], 337 | stack_frames_back: int = 0, 338 | ) -> None: 339 | """Logs with critical severity. 340 | 341 | Args: 342 | msg: Message to log (string). 343 | *struct: Zero or more dict or exception to log in structured log. 344 | stack_frames_back: Additional stack frames back to log source_location. 345 | """ 346 | logger().critical(msg, *struct, stack_frames_back=stack_frames_back + 1) 347 | 348 | 349 | def clear_log_signature() -> None: 350 | logger().clear_log_signature() 351 | 352 | 353 | def get_log_signature() -> Mapping[str, Any]: 354 | return logger().log_signature 355 | 356 | 357 | def set_log_signature(sig: Mapping[str, Any]) -> None: 358 | logger().log_signature = sig 359 | 360 | 361 | def set_per_thread_log_signatures(val: bool) -> None: 362 | logger().per_thread_log_signatures = val 363 | 364 | 365 | def get_build_version(clip_length: Optional[int] = None) -> str: 366 | if clip_length is not None and clip_length >= 0: 367 | return logger().build_version[:clip_length] 368 | return logger().build_version 369 | 370 | 371 | def set_build_version(build_version: str) -> None: 372 | logger().build_version = build_version 373 | 374 | 375 | def set_log_trace_key(key: str) -> None: 376 | logger().trace_key = key 377 | 378 | 379 | # Logging interfaces are used from processes which are forked (gunicorn, 380 | # DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not 381 | # copy threads running within parent processes or re-initalize global/module 382 | # state. This can result in forked modules being executed with invalid global 383 | # state, e.g., acquired locks that will not release or references to invalid 384 | # state. The cloud logging library utilizes a background thread transporting 385 | # logs to cloud. The background threading is not compatiable with forking and 386 | # will seg-fault (python queue wait). This can be avoided, by stoping and 387 | # the background transport prior to forking and then restarting the transport 388 | # following the fork. 389 | os.register_at_fork( 390 | before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access 391 | after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access 392 | ) 393 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/secret_flag_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for secret_flag_utils.""" 16 | import os 17 | from typing import Any, List 18 | from unittest import mock 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import google.api_core 23 | from google.cloud import secretmanager 24 | 25 | from serving.logging_lib.flags import secret_flag_utils 26 | 27 | _MOCK_SECRET = 'projects/test-project/secrets/test-secret' 28 | 29 | 30 | def _create_mock_secret_value_response( 31 | data: Any, 32 | ) -> secretmanager.AccessSecretVersionResponse: 33 | payload = mock.create_autospec(secretmanager.SecretPayload, instance=True) 34 | payload.data = data 35 | sec_val = mock.create_autospec( 36 | secretmanager.AccessSecretVersionResponse, instance=True 37 | ) 38 | sec_val.payload = payload 39 | return sec_val 40 | 41 | 42 | def _create_mock_secret_version_response( 43 | secret_version_prefix: str, 44 | version_numbers: List[str], 45 | ) -> list[secretmanager.ListSecretVersionsResponse]: 46 | version_list = [] 47 | for version_number in version_numbers: 48 | version = mock.create_autospec( 49 | secretmanager.ListSecretVersionsResponse, instance=True 50 | ) 51 | version.name = f'{secret_version_prefix}{version_number}' 52 | version_list.append(version) 53 | return version_list 54 | 55 | 56 | class SecretFlagUtilsTest(parameterized.TestCase): 57 | 58 | def setUp(self): 59 | super().setUp() 60 | secret_flag_utils._cache.clear() 61 | self.enter_context( 62 | mock.patch( 63 | 'google.auth.default', autospec=True, return_value=('mock', 'mock') 64 | ) 65 | ) 66 | 67 | def tearDown(self): 68 | super().tearDown() 69 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = False 70 | 71 | def test_init_fork_module_state(self): 72 | secret_flag_utils._cache = None 73 | secret_flag_utils._cache_lock = None 74 | secret_flag_utils._init_fork_module_state() 75 | self.assertIsNotNone(secret_flag_utils._cache) 76 | self.assertIsNotNone(secret_flag_utils._cache_lock) 77 | 78 | @parameterized.named_parameters( 79 | dict( 80 | testcase_name='expected', 81 | version_numbers=['1', '2', '3'], 82 | expected_version='3', 83 | ), 84 | dict( 85 | testcase_name='ignore_unexpected_version', 86 | version_numbers=['A', '1'], 87 | expected_version='1', 88 | ), 89 | ) 90 | def test_get_secret_version(self, version_numbers, expected_version): 91 | client = mock.create_autospec( 92 | secretmanager.SecretManagerServiceClient, instance=True 93 | ) 94 | client.list_secret_versions.return_value = ( 95 | _create_mock_secret_version_response( 96 | f'{_MOCK_SECRET}/versions/', version_numbers 97 | ) 98 | ) 99 | version = secret_flag_utils._get_secret_version(client, _MOCK_SECRET) 100 | self.assertEqual(version, expected_version) 101 | 102 | def test_get_secret_version_cannot_find_version(self): 103 | client = mock.create_autospec( 104 | secretmanager.SecretManagerServiceClient, instance=True 105 | ) 106 | client.list_secret_versions.return_value = ( 107 | _create_mock_secret_version_response('', ['NO_MATCH', 'NO_MATCH']) 108 | ) 109 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 110 | secret_flag_utils._get_secret_version(client, _MOCK_SECRET) 111 | 112 | def test_get_secret_version_ignore_bad_response(self): 113 | client = mock.create_autospec( 114 | secretmanager.SecretManagerServiceClient, instance=True 115 | ) 116 | client.list_secret_versions.return_value = [] 117 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 118 | secret_flag_utils._get_secret_version(client, _MOCK_SECRET) 119 | 120 | def test_read_secrets_path_to_version_empty(self): 121 | self.assertEqual(secret_flag_utils._read_secrets(''), {}) 122 | 123 | def test_read_secrets_path_invalid_raises(self): 124 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 125 | secret_flag_utils._read_secrets('no_match') 126 | 127 | def test_read_secrets_path_disabled_returns_empty_result(self): 128 | self.assertEqual( 129 | secret_flag_utils._read_secrets('projects/prj/secrets/sec'), {} 130 | ) 131 | 132 | @parameterized.named_parameters([ 133 | dict( 134 | testcase_name='none', 135 | data=None, 136 | expected_value={}, 137 | ), 138 | dict( 139 | testcase_name='empty', 140 | data='', 141 | expected_value={}, 142 | ), 143 | dict( 144 | testcase_name='empty_dict', 145 | data='{}', 146 | expected_value={}, 147 | ), 148 | dict( 149 | testcase_name='defined_dict_str', 150 | data='{"abc": "123"}', 151 | expected_value={'abc': '123'}, 152 | ), 153 | dict( 154 | testcase_name='defined_dict_bytes', 155 | data='{"abc": "123"}'.encode('utf-8'), 156 | expected_value={'abc': '123'}, 157 | ), 158 | ]) 159 | @mock.patch.object( 160 | secret_flag_utils, '_get_secret_version', autospec=True, return_value='1' 161 | ) 162 | @mock.patch.object( 163 | secretmanager.SecretManagerServiceClient, 164 | 'access_secret_version', 165 | autospec=True, 166 | ) 167 | def test_read_secrets_value( 168 | self, mk_access_secret, mk_get_secret_version, data, expected_value 169 | ): 170 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 171 | project_secret = 'projects/prj/secrets/sec' 172 | secret_version = '1' 173 | mk_get_secret_version.return_value = secret_version 174 | mk_access_secret.return_value = _create_mock_secret_value_response(data) 175 | self.assertEqual( 176 | secret_flag_utils._read_secrets(project_secret), 177 | expected_value, 178 | ) 179 | mk_get_secret_version.assert_called_once_with(mock.ANY, project_secret) 180 | mk_access_secret.assert_called_once_with( 181 | mock.ANY, 182 | request={'name': f'{project_secret}/versions/{secret_version}'}, 183 | ) 184 | 185 | @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True) 186 | @mock.patch.object( 187 | secretmanager.SecretManagerServiceClient, 188 | 'access_secret_version', 189 | autospec=True, 190 | ) 191 | def test_read_secrets_value_does_not_call_version_if_in_path( 192 | self, mk_access_secret, mk_get_secret_version 193 | ): 194 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 195 | project_secret = 'projects/prj/secrets/sec/versions/2' 196 | mk_access_secret.return_value = _create_mock_secret_value_response('') 197 | self.assertEqual( 198 | secret_flag_utils._read_secrets(project_secret), 199 | {}, 200 | ) 201 | mk_get_secret_version.assert_not_called() 202 | mk_access_secret.assert_called_once_with( 203 | mock.ANY, request={'name': project_secret} 204 | ) 205 | 206 | def test_read_secrets_returns_cache_value(self): 207 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 208 | project_secret = 'projects/prj/secrets/sec/versions/3' 209 | expected = 'EXPECTED_VALUE' 210 | secret_flag_utils._cache[project_secret] = expected 211 | self.assertEqual(secret_flag_utils._read_secrets(project_secret), expected) 212 | 213 | @parameterized.named_parameters([ 214 | dict(testcase_name='invalid_json', response='{'), 215 | dict(testcase_name='not_dict', response='[]'), 216 | ]) 217 | @mock.patch.object( 218 | secretmanager.SecretManagerServiceClient, 219 | 'access_secret_version', 220 | autospec=True, 221 | ) 222 | def test_read_secrets_returns_invalid_value_raises( 223 | self, mk_access_secret, response 224 | ): 225 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 226 | project_secret = 'projects/prj/secrets/sec/versions/2' 227 | mk_access_secret.return_value = _create_mock_secret_value_response(response) 228 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 229 | secret_flag_utils._read_secrets(project_secret) 230 | 231 | @parameterized.parameters([ 232 | google.api_core.exceptions.NotFound('mock'), 233 | google.api_core.exceptions.PermissionDenied('mock'), 234 | ]) 235 | @mock.patch.object( 236 | secretmanager.SecretManagerServiceClient, 237 | 'access_secret_version', 238 | autospec=True, 239 | ) 240 | def test_access_secret_version_raises(self, exp, mk_access_secret): 241 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 242 | project_secret = 'projects/prj/secrets/sec/versions/2' 243 | mk_access_secret.side_effect = exp 244 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 245 | secret_flag_utils._read_secrets(project_secret) 246 | 247 | @parameterized.parameters([ 248 | google.api_core.exceptions.NotFound('mock'), 249 | google.api_core.exceptions.PermissionDenied('mock'), 250 | ]) 251 | @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True) 252 | def test_get_secret_version_raises(self, exp, mk_get_secret_version): 253 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 254 | project_secret = 'projects/prj/secrets/sec' 255 | mk_get_secret_version.side_effect = exp 256 | with self.assertRaises(secret_flag_utils.SecretDecodeError): 257 | secret_flag_utils._read_secrets(project_secret) 258 | 259 | @parameterized.named_parameters([ 260 | dict( 261 | testcase_name='defined_env', 262 | env_name='MOCK_ENV', 263 | expected_value='DEFINED_VALUE', 264 | ), 265 | dict( 266 | testcase_name='not_defined_env', 267 | env_name='UNDEFINED_ENV', 268 | expected_value='MOCK_DEFAULT', 269 | ), 270 | ]) 271 | def test_get_secret_or_env_no_config_defined_returns_env_value_or_default( 272 | self, env_name, expected_value 273 | ): 274 | with mock.patch.dict(os.environ, {'MOCK_ENV': 'DEFINED_VALUE'}): 275 | self.assertEqual( 276 | secret_flag_utils.get_secret_or_env(env_name, 'MOCK_DEFAULT'), 277 | expected_value, 278 | ) 279 | 280 | @parameterized.named_parameters([ 281 | dict( 282 | testcase_name='secret_env', 283 | search_env='SECRET_ENV', 284 | expected='SECRET_VALUE', 285 | ), 286 | dict( 287 | testcase_name='not_in_sec', 288 | search_env='NOT_IN_SEC', 289 | expected='ENV_VALUE', 290 | ), 291 | dict( 292 | testcase_name='not_in_env', 293 | search_env='NOT_IN_ENV', 294 | expected='MOCK_DEFAULT', 295 | ), 296 | ]) 297 | @mock.patch.object( 298 | secretmanager.SecretManagerServiceClient, 299 | 'access_secret_version', 300 | autospec=True, 301 | ) 302 | def test_get_secret_or_env_with_config( 303 | self, mk_access_secret, search_env, expected 304 | ): 305 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 306 | project_secret = 'projects/prj/secrets/sec/versions/1' 307 | mk_access_secret.return_value = _create_mock_secret_value_response( 308 | '{"SECRET_ENV": "SECRET_VALUE"}' 309 | ) 310 | with mock.patch.dict( 311 | os.environ, 312 | { 313 | secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret, 314 | 'NOT_IN_SEC': 'ENV_VALUE', 315 | }, 316 | ): 317 | self.assertEqual( 318 | secret_flag_utils.get_secret_or_env(search_env, 'MOCK_DEFAULT'), 319 | expected, 320 | ) 321 | mk_access_secret.assert_called_once_with( 322 | mock.ANY, 323 | request={'name': project_secret}, 324 | ) 325 | 326 | @mock.patch.object( 327 | secretmanager.SecretManagerServiceClient, 328 | 'access_secret_version', 329 | autospec=True, 330 | ) 331 | def test_get_secret_or_env_default_return_value(self, mk_access_secret): 332 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 333 | project_secret = 'projects/prj/secrets/sec/versions/1' 334 | mk_access_secret.return_value = _create_mock_secret_value_response( 335 | '{"MOCK_ENV": "SECRET_VALUE"}' 336 | ) 337 | with mock.patch.dict( 338 | os.environ, 339 | { 340 | secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret, 341 | }, 342 | ): 343 | self.assertIsNone( 344 | secret_flag_utils.get_secret_or_env('NOT_IN_SECRET_OR_ENV', None) 345 | ) 346 | mk_access_secret.assert_called_once_with( 347 | mock.ANY, 348 | request={'name': project_secret}, 349 | ) 350 | 351 | @parameterized.named_parameters([ 352 | dict(testcase_name='FOUND', env_name='MOCK_ENV', expected_value=True), 353 | dict( 354 | testcase_name='DEFAULT_MISSING', 355 | env_name='NOT_FOUND_ENV', 356 | expected_value=False, 357 | ), 358 | ]) 359 | @mock.patch.object( 360 | secretmanager.SecretManagerServiceClient, 361 | 'access_secret_version', 362 | autospec=True, 363 | ) 364 | def test_get_bool_secret_or_env( 365 | self, mk_access_secret, env_name, expected_value 366 | ): 367 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 368 | project_secret = 'projects/prj/secrets/sec/versions/1' 369 | mk_access_secret.return_value = _create_mock_secret_value_response( 370 | '{"MOCK_ENV": "TRUE"}' 371 | ) 372 | with mock.patch.dict( 373 | os.environ, 374 | { 375 | secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret, 376 | }, 377 | ): 378 | self.assertEqual( 379 | secret_flag_utils.get_bool_secret_or_env(env_name), expected_value 380 | ) 381 | mk_access_secret.assert_called_once_with( 382 | mock.ANY, 383 | request={'name': project_secret}, 384 | ) 385 | 386 | @parameterized.named_parameters([ 387 | dict(testcase_name='FOUND', env_name='ENV_FOUND', expected_value=False), 388 | dict( 389 | testcase_name='DEFAULT_MISSING', 390 | env_name='NOT_FOUND_ENV', 391 | expected_value=True, 392 | ), 393 | ]) 394 | @mock.patch.object( 395 | secretmanager.SecretManagerServiceClient, 396 | 'access_secret_version', 397 | autospec=True, 398 | ) 399 | def test_get_bool_secret_or_env_trys_env( 400 | self, mk_access_secret, env_name, expected_value 401 | ): 402 | secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True 403 | project_secret = 'projects/prj/secrets/sec/versions/1' 404 | mk_access_secret.return_value = _create_mock_secret_value_response('{}') 405 | with mock.patch.dict( 406 | os.environ, 407 | { 408 | secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret, 409 | 'ENV_FOUND': 'False', 410 | }, 411 | ): 412 | self.assertEqual( 413 | secret_flag_utils.get_bool_secret_or_env(env_name, True), 414 | expected_value, 415 | ) 416 | mk_access_secret.assert_called_once_with( 417 | mock.ANY, 418 | request={'name': project_secret}, 419 | ) 420 | 421 | 422 | if __name__ == '__main__': 423 | absltest.main() 424 | -------------------------------------------------------------------------------- /notebooks/quick_start_with_model_garden.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "tYba2hfAs0AS" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Copyright 2025 Google LLC\n", 12 | "#\n", 13 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", 14 | "# you may not use this file except in compliance with the License.\n", 15 | "# You may obtain a copy of the License at\n", 16 | "#\n", 17 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 18 | "#\n", 19 | "# Unless required by applicable law or agreed to in writing, software\n", 20 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 21 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 22 | "# See the License for the specific language governing permissions and\n", 23 | "# limitations under the License." 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "kILpWz4Gs5Kh" 30 | }, 31 | "source": [ 32 | "# Quick start with Model Garden - MedASR\n", 33 | "\n", 34 | "\n", 35 | " \n", 40 | " \n", 45 | "
\n", 36 | " \n", 37 | " \"Google
Run in Colab Enterprise\n", 38 | "
\n", 39 | "
\n", 41 | " \n", 42 | " \"GitHub
View on GitHub\n", 43 | "
\n", 44 | "
" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "JRGljGF2lUWd" 52 | }, 53 | "source": [ 54 | "## Overview\n", 55 | "\n", 56 | "This notebook demonstrates how to use MedASR in Vertex AI to transcribe medical audio to text using online inference.\n", 57 | "\n", 58 | "**Online inferences** are synchronous requests that are made to the endpoint deployed from Model Garden and are served with low latency. Online inferences are useful if the model outputs are being used in production. The cost for online inference is based on the time a virtual machine spends waiting in an active state (an endpoint with a deployed model) to handle inference requests.\n", 59 | "\n", 60 | "Vertex AI makes it easy to serve your model and make it accessible to the world. Learn more about [Vertex AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform).\n", 61 | "\n", 62 | "### Objectives\n", 63 | "\n", 64 | "- Deploy MedASR to a Vertex AI Endpoint and get online inferences.\n", 65 | "\n", 66 | "### Costs\n", 67 | "\n", 68 | "This tutorial uses billable components of Google Cloud:\n", 69 | "\n", 70 | "* Vertex AI\n", 71 | "* Cloud Storage\n", 72 | "\n", 73 | "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": { 79 | "id": "emPr6M1fs5Kj" 80 | }, 81 | "source": [ 82 | "## Before you begin" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "id": "ag_zmUlgmJD8", 90 | "cellView": "form" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "# @title Install dependencies and import packages\n", 95 | "\n", 96 | "! pip install -qU --upgrade pip\n", 97 | "! pip install -qU 'google-cloud-aiplatform>=1.101.0' jiwer levenshtein\n", 98 | "\n", 99 | "import base64\n", 100 | "import json\n", 101 | "import os\n", 102 | "\n", 103 | "from google.cloud import aiplatform\n", 104 | "from IPython.display import Audio, display\n", 105 | "\n", 106 | "models, endpoints = {}, {}" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "cellView": "form", 114 | "id": "zMsN9Ep1mJD8" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# @title Set up Google Cloud environment\n", 119 | "\n", 120 | "# @markdown #### Prerequisites\n", 121 | "\n", 122 | "# @markdown 1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.\n", 123 | "\n", 124 | "# @markdown 2. Make sure that either the Compute Engine API is enabled or that you have the [Service Usage Admin](https://cloud.google.com/iam/docs/understanding-roles#serviceusage.serviceUsageAdmin) (`roles/serviceusage.serviceUsageAdmin`) role to enable the API.\n", 125 | "\n", 126 | "# @markdown This section sets the default Google Cloud project and region, enables the Compute Engine API (if not already enabled), and initializes the Vertex AI API.\n", 127 | "\n", 128 | "# Get the default project ID.\n", 129 | "PROJECT_ID = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n", 130 | "\n", 131 | "# Get the default region for launching jobs.\n", 132 | "REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n", 133 | "\n", 134 | "# Enable the Compute Engine API, if not already.\n", 135 | "print(\"Enabling Compute Engine API.\")\n", 136 | "! gcloud services enable compute.googleapis.com\n", 137 | "\n", 138 | "# Initialize Vertex AI API.\n", 139 | "print(\"Initializing Vertex AI API.\")\n", 140 | "aiplatform.init(project=PROJECT_ID, location=REGION)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "source": [ 146 | "# @title Retrieve sample data\n", 147 | "\n", 148 | "# @markdown This notebook uses a sample medical audio file and transcript.\n", 149 | "\n", 150 | "! gcloud storage cp gs://healthai-us/medasr/test_audio.wav test_audio.wav\n", 151 | "with open(\"test_audio.wav\", \"rb\") as f:\n", 152 | " audio_bytes = f.read()\n", 153 | "sample_transcript = \"Exam type CT chest PE protocol period. Indication 54 year old female, shortness of breath, evaluate for PE period. Technique standard protocol period. Findings colon. Pulmonary vasculature colon. The main PA is patent period. There are filling defects in the segmental branches of the right lower lobe comma compatible with acute PE period. No saddle embolus period. Lungs colon. No pneumothorax period. Small bilateral effusions comma right greater than left period. New paragraph. Impression colon. Acute segmental PE right lower lobe period.\"\n", 154 | "display(Audio(audio_bytes, autoplay=False))" 155 | ], 156 | "metadata": { 157 | "id": "4GJMn-QvZ4up", 158 | "cellView": "form" 159 | }, 160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "source": [ 166 | "# @title Define utility functions\n", 167 | "\n", 168 | "# @markdown These functions will be used to evaluate the word error rate (WER) of the generated transcripts.\n", 169 | "\n", 170 | "import re\n", 171 | "import jiwer\n", 172 | "import Levenshtein\n", 173 | "\n", 174 | "def normalize(s: str) -> str:\n", 175 | " s = s.lower()\n", 176 | " s = re.sub(r\"[^ a-z0-9']\", ' ', s)\n", 177 | " s = ' '.join(s.split())\n", 178 | " return s\n", 179 | "\n", 180 | "def _colored(text, color):\n", 181 | " if color == 'red':\n", 182 | " return f\"\\033[91m{text}\\033[0m\"\n", 183 | " elif color == 'green':\n", 184 | " return f\"\\033[92m{text}\\033[0m\"\n", 185 | " return text\n", 186 | "\n", 187 | "def evaluate(\n", 188 | " ref_text: str,\n", 189 | " hyp_text: str,\n", 190 | " delete_color: str = 'red',\n", 191 | " insert_color: str = 'green',\n", 192 | ") -> None:\n", 193 | " print('HYP:', hyp_text)\n", 194 | " normalized_ref = normalize(ref_text)\n", 195 | " normalized_hyp = normalize(hyp_text)\n", 196 | "\n", 197 | " # Calculate word lists early so we can use them for both jiwer and diffs\n", 198 | " ref_words = normalized_ref.split()\n", 199 | " hyp_words = normalized_hyp.split()\n", 200 | "\n", 201 | " # jiwer.process_words expects a list of strings (sentences) or list of list of words\n", 202 | " measures = jiwer.process_words([normalized_ref], [normalized_hyp])\n", 203 | "\n", 204 | " # Calculate edit operations using Levenshtein for the colored diff\n", 205 | " edits = Levenshtein.editops(ref_words, hyp_words)\n", 206 | "\n", 207 | " r = 0 # Index for the reference words for diff building\n", 208 | " diff = ''\n", 209 | "\n", 210 | " for op, i, j in edits:\n", 211 | " # Add matched words before the current edit\n", 212 | " if r < i:\n", 213 | " diff += ' ' + ' '.join(ref_words[r:i])\n", 214 | " r = i # Update reference index for next iteration\n", 215 | "\n", 216 | " if op == 'replace':\n", 217 | " diff += (\n", 218 | " f' {_colored(f\"{{-{ref_words[i]}-}}\", delete_color)}'\n", 219 | " f' {_colored(f\"{{+{hyp_words[j]}+}}\", insert_color)}'\n", 220 | " )\n", 221 | " r += 1 # Advance reference index after replacement\n", 222 | " elif op == 'insert':\n", 223 | " diff += f' {_colored(f\"{{+{hyp_words[j]}+}}\", insert_color)}'\n", 224 | " # Reference index `r` does not advance for an insertion\n", 225 | " elif op == 'delete':\n", 226 | " diff += f' {_colored(f\"{{-{ref_words[i]}-}}\", delete_color)}'\n", 227 | " r += 1 # Advance reference index after deletion\n", 228 | "\n", 229 | " # Add any remaining matched words from the reference\n", 230 | " if r < len(ref_words):\n", 231 | " diff += ' ' + ' '.join(ref_words[r:])\n", 232 | "\n", 233 | " print(\n", 234 | " f'WER: {measures.wer * 100:.2f}%: '\n", 235 | " f'insertions {measures.insertions}, deletions {measures.deletions}, substitutions {measures.substitutions}, '\n", 236 | " f'ref tokens {len(ref_words)}'\n", 237 | " )\n", 238 | " print(diff)" 239 | ], 240 | "metadata": { 241 | "cellView": "form", 242 | "id": "1bz6RCO2c7h4" 243 | }, 244 | "execution_count": null, 245 | "outputs": [] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "id": "Bak-klNTmJD8" 251 | }, 252 | "source": [ 253 | "## Get online inferences" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": { 260 | "id": "tjOWCeGJu-94", 261 | "cellView": "form" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# @title Import deployed model\n", 266 | "\n", 267 | "# @markdown To get [online inferences](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions), you will need a MedASR [Vertex AI Endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment) that has been deployed from Model Garden. If you have not already done so, go to the [MedASR model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/medasr) and click \"Deploy model\" to deploy the model.\n", 268 | "\n", 269 | "# @markdown Note: Endpoints deployed from Model Garden must be [dedicated endpoints](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type).\n", 270 | "\n", 271 | "# @markdown This section gets the Vertex AI Endpoint resource that you deployed from Model Garden to use for online inferences.\n", 272 | "\n", 273 | "# @markdown Fill in the endpoint ID and region below. You can find your deployed endpoint on the [Vertex AI Endpoints page](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).\n", 274 | "\n", 275 | "ENDPOINT_ID = \"\" # @param {type: \"string\", placeholder:\"e.g. 123456789\"}\n", 276 | "ENDPOINT_REGION = \"\" # @param {type: \"string\", placeholder:\"e.g. us-central1\"}\n", 277 | "\n", 278 | "endpoints[\"endpoint\"] = aiplatform.Endpoint(\n", 279 | " endpoint_name=ENDPOINT_ID,\n", 280 | " project=PROJECT_ID,\n", 281 | " location=ENDPOINT_REGION,\n", 282 | ")" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": { 289 | "id": "XhijNE6PnWn7", 290 | "cellView": "form" 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "# @title Run inference using the Vertex AI SDK\n", 295 | "\n", 296 | "# @markdown This section shows how to send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to your Vertex AI endpoint.\n", 297 | "\n", 298 | "# @markdown Click \"Show code\" to see more details.\n", 299 | "\n", 300 | "request = {\n", 301 | " \"file\": base64.b64encode(audio_bytes).decode(\"utf-8\"),\n", 302 | "}\n", 303 | "\n", 304 | "response = endpoints[\"endpoint\"].raw_predict(\n", 305 | " body=json.dumps(request).encode(\"utf-8\"),\n", 306 | " headers={\n", 307 | " \"Content-Type\": \"application/json\",\n", 308 | " },\n", 309 | ")\n", 310 | "generated_transcript = json.loads(response.content)[\"text\"]\n", 311 | "\n", 312 | "print(generated_transcript)\n", 313 | "evaluate(sample_transcript, generated_transcript)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "4yfFlkncxzDF" 320 | }, 321 | "source": [ 322 | "\n", 323 | "## Next steps\n", 324 | "\n", 325 | "Explore the other [notebooks](https://github.com/google-health/medasr/blob/main/notebooks) to learn what else you can do with the model.\n" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": { 331 | "id": "paQNyrzT_mX_" 332 | }, 333 | "source": [ 334 | "## Clean up resources" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": { 341 | "id": "edUIpvZZ_mYA", 342 | "cellView": "form" 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "# @markdown Delete the experiment models and endpoints to recycle the resources\n", 347 | "# @markdown and avoid unnecessary continuous charges that may incur.\n", 348 | "\n", 349 | "# Undeploy model and delete endpoint.\n", 350 | "for endpoint in endpoints.values():\n", 351 | " endpoint.delete(force=True)\n", 352 | "\n", 353 | "# Delete models.\n", 354 | "for model in models.values():\n", 355 | " model.delete()" 356 | ] 357 | } 358 | ], 359 | "metadata": { 360 | "colab": { 361 | "provenance": [], 362 | "toc_visible": true 363 | }, 364 | "kernelspec": { 365 | "display_name": "Python 3", 366 | "name": "python3" 367 | }, 368 | "language_info": { 369 | "name": "python" 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 0 374 | } 375 | -------------------------------------------------------------------------------- /python/serving/logging_lib/cloud_logging_client_instance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Wrapper for cloud ops structured logging. 16 | 17 | Get instance of logger: 18 | logger() -> CloudLoggingClient 19 | 20 | logger().log_signature = dict to append to logs. 21 | """ 22 | import collections 23 | import copy 24 | import enum 25 | import inspect 26 | import logging 27 | import math 28 | import os 29 | import sys 30 | import threading 31 | import time 32 | import traceback 33 | from typing import Any, Mapping, MutableMapping, Optional, Tuple, Union 34 | 35 | from absl import logging as absl_logging 36 | import google.auth 37 | from google.cloud import logging as cloud_logging 38 | 39 | # Debug/testing option to logs to absl.logger. Automatically set when 40 | # running unit tests 41 | DEBUG_LOGGING_USE_ABSL_LOGGING = bool( 42 | 'UNITTEST_ON_FORGE' in os.environ or 'unittest' in sys.modules 43 | ) 44 | 45 | 46 | class _LogSeverity(enum.Enum): 47 | CRITICAL = logging.CRITICAL 48 | ERROR = logging.ERROR 49 | WARNING = logging.WARNING 50 | INFO = logging.INFO 51 | DEBUG = logging.DEBUG 52 | 53 | 54 | MAX_LOG_SIZE = 246000 55 | 56 | 57 | class CloudLoggerInstanceExceptionError(Exception): 58 | pass 59 | 60 | 61 | def _merge_struct( 62 | dict_tuple: Tuple[Union[Mapping[str, Any], Exception, None], ...], 63 | ) -> Optional[MutableMapping[str, str]]: 64 | """Merges a list of dict and ordered dicts. 65 | 66 | * for dict adds item in key sorted order 67 | * preserves order for ordered dicts. 68 | Args: 69 | dict_tuple: dicts and ordered dicts to merge 70 | 71 | Returns: 72 | merged dict. 73 | """ 74 | if not dict_tuple: 75 | return None 76 | return_dict = collections.OrderedDict() 77 | for dt in dict_tuple: 78 | if dt is None: 79 | continue 80 | if isinstance(dt, Exception): 81 | # Log exception text and exception stack trace. 82 | exception_str = str(dt) 83 | if exception_str: 84 | exception_str = f'{exception_str}\n' 85 | return_dict['exception'] = f'{exception_str}{traceback.format_exc()}' 86 | else: 87 | keylist = list(dt) 88 | if not isinstance(dt, collections.OrderedDict): 89 | keylist = sorted(keylist) 90 | for key in keylist: 91 | return_dict[key] = str(dt[key]) 92 | return return_dict 93 | 94 | 95 | def _absl_log(msg: str, severity: _LogSeverity = _LogSeverity.INFO) -> None: 96 | """Logs using absl logging. 97 | 98 | Args: 99 | msg: Message to log. 100 | severity: Severity of message. 101 | """ 102 | if severity == _LogSeverity.DEBUG: 103 | absl_logging.debug(msg) 104 | elif severity == _LogSeverity.WARNING: 105 | absl_logging.warning(msg) 106 | elif severity == _LogSeverity.INFO: 107 | absl_logging.info(msg) 108 | else: 109 | absl_logging.error(msg) 110 | 111 | 112 | def _py_log( 113 | dpas_logger: logging.Logger, 114 | msg: str, 115 | extra: Mapping[str, Any], 116 | severity: _LogSeverity, 117 | ) -> None: 118 | """Logs msg and structured logging using python logger.""" 119 | if severity == _LogSeverity.DEBUG: 120 | dpas_logger.debug(msg, extra=extra) 121 | elif severity == _LogSeverity.WARNING: 122 | dpas_logger.warning(msg, extra=extra) 123 | elif severity == _LogSeverity.INFO: 124 | dpas_logger.info(msg, extra=extra) 125 | elif severity == _LogSeverity.CRITICAL: 126 | dpas_logger.critical(msg, extra=extra) 127 | elif severity == _LogSeverity.ERROR: 128 | dpas_logger.error(msg, extra=extra) 129 | else: 130 | raise CloudLoggerInstanceExceptionError( 131 | f'Unsupported logging severity level; Severity="{severity}"' 132 | ) 133 | 134 | 135 | def _get_source_location_to_log(stack_frames_back: int) -> Mapping[str, Any]: 136 | """Adds Python source location information to cloud structured logs. 137 | 138 | The source location is added by adding (and overwriting if present) a 139 | "source_location" key to the provided additional_parameters. 140 | The value corresponding to that key is a dict mapping: 141 | "file" to the name of the file (str) containing the logging statement, 142 | "function" python function/method calling logging method 143 | "line" to the line number where the log was recorded (int). 144 | 145 | Args: 146 | stack_frames_back: Additional stack frames back to log source_location. 147 | 148 | Returns: 149 | Source location formatted for structured logging. 150 | 151 | Raises: 152 | ValueError: If stack frame cannot be found for specified position. 153 | """ 154 | source_location = {} 155 | current_frame = inspect.currentframe() 156 | for _ in range(stack_frames_back + 1): 157 | if current_frame is None: 158 | raise ValueError('Cannot get stack frame for specified position.') 159 | current_frame = current_frame.f_back 160 | try: 161 | frame_info = inspect.getframeinfo(current_frame) 162 | source_location['source_location'] = dict( 163 | file=frame_info.filename, 164 | function=frame_info.function, 165 | line=frame_info.lineno, 166 | ) 167 | finally: 168 | # https://docs.python.org/3/library/inspect.html 169 | del current_frame # explicitly deleting 170 | return source_location 171 | 172 | 173 | def _add_trace_to_log( 174 | project_id: str, trace_key: str, struct: Mapping[str, Any] 175 | ) -> Mapping[str, Any]: 176 | if not project_id or not trace_key: 177 | return {} 178 | trace_id = struct.get(trace_key, '') 179 | if trace_id: 180 | return {'trace': f'projects/{project_id}/traces/{trace_id}'} 181 | return {} 182 | 183 | 184 | class CloudLoggingClientInstance: 185 | """Wrapper for cloud ops structured logging. 186 | 187 | Automatically adds signature to structured logs to make traceable. 188 | """ 189 | 190 | # global state to prevent duplicate initialization of cloud logging interfaces 191 | # within a process. 192 | _global_lock = threading.Lock() 193 | # Cloud logging handler init at process level. 194 | _cloud_logging_handler: Optional[ 195 | cloud_logging.handlers.CloudLoggingHandler 196 | ] = None 197 | _cloud_logging_handler_init_params = '' 198 | 199 | @classmethod 200 | def _init_fork_module_state(cls) -> None: 201 | cls._global_lock = threading.Lock() 202 | cls._cloud_logging_handler = None 203 | cls._cloud_logging_handler_init_params = '' 204 | 205 | @classmethod 206 | def fork_shutdown(cls) -> None: 207 | with cls._global_lock: 208 | cls._cloud_logging_handler_init_params = '' 209 | handler = cls._cloud_logging_handler 210 | if handler is None: 211 | return 212 | handler.transport.worker.stop() 213 | logging.getLogger().removeHandler(handler) 214 | handler.close() 215 | cls._cloud_logging_handler = None 216 | 217 | @classmethod 218 | def _set_absl_skip_frames(cls) -> None: 219 | """Sets absl logging attribution to skip over internal logging frames.""" 220 | absl_logging.ABSLLogger.register_frame_to_skip( 221 | __file__, 222 | function_name='_log', 223 | ) 224 | absl_logging.ABSLLogger.register_frame_to_skip( 225 | __file__, 226 | function_name='_absl_log', 227 | ) 228 | absl_logging.ABSLLogger.register_frame_to_skip( 229 | __file__, 230 | function_name='debug', 231 | ) 232 | absl_logging.ABSLLogger.register_frame_to_skip( 233 | __file__, 234 | function_name='timed_debug', 235 | ) 236 | absl_logging.ABSLLogger.register_frame_to_skip( 237 | __file__, 238 | function_name='info', 239 | ) 240 | absl_logging.ABSLLogger.register_frame_to_skip( 241 | __file__, 242 | function_name='warning', 243 | ) 244 | absl_logging.ABSLLogger.register_frame_to_skip( 245 | __file__, 246 | function_name='error', 247 | ) 248 | absl_logging.ABSLLogger.register_frame_to_skip( 249 | __file__, 250 | function_name='critical', 251 | ) 252 | 253 | def __init__( 254 | self, 255 | log_name: str = 'python', 256 | gcp_project_to_write_logs_to: str = '', 257 | gcp_credentials: Optional[google.auth.credentials.Credentials] = None, 258 | pod_hostname: str = '', 259 | pod_uid: str = '', 260 | enable_structured_logging: bool = True, 261 | use_absl_logging: bool = DEBUG_LOGGING_USE_ABSL_LOGGING, 262 | log_all_python_logs_to_cloud: bool = False, 263 | enabled: bool = True, 264 | log_error_level: int = _LogSeverity.DEBUG.value, 265 | per_thread_log_signatures: bool = True, 266 | trace_key: str = '', 267 | build_version: str = '', 268 | ): 269 | """Constructor. 270 | 271 | Args: 272 | log_name: Log name to write logs to. 273 | gcp_project_to_write_logs_to: GCP project name to write log to. Undefined 274 | = default. 275 | gcp_credentials: The OAuth2 Credentials to use for this client 276 | (None=default). 277 | pod_hostname: Host name of GKE pod. Should be empty if not running in GKE. 278 | pod_uid: UID of GKE pod. Should be empty if not running in GKE. 279 | enable_structured_logging: Enable structured logging. 280 | use_absl_logging: Send logs to absl logging instead of cloud_logging. 281 | log_all_python_logs_to_cloud: Logs everything to cloud. 282 | enabled: If disabled, logging is not initialized and logging operations 283 | are nops. 284 | log_error_level: Error level at which logger will log. 285 | per_thread_log_signatures: Log signatures reported per thread. 286 | trace_key: Log key value which contains a trace id value. 287 | build_version: Build version to embedd in logs container. 288 | """ 289 | # lock for log makes access to singleton 290 | # safe across threads. Logging used in main thread and ack_timeout_mon 291 | CloudLoggingClientInstance._set_absl_skip_frames() 292 | self._build_version = build_version 293 | self._enabled = enabled 294 | self._trace_key = trace_key 295 | self._log_error_level = log_error_level 296 | self._log_lock = threading.RLock() 297 | self._log_name = log_name.strip() 298 | self._pod_hostname = pod_hostname.strip() 299 | self._pod_uid = pod_uid.strip() 300 | self._per_thread_log_signatures = per_thread_log_signatures 301 | self._thread_local_storage = threading.local() 302 | self._shared_log_signature = self._signature_defaults(0) 303 | self._enable_structured_logging = enable_structured_logging 304 | self._debug_log_time = time.time() 305 | self._gcp_project_name = gcp_project_to_write_logs_to.strip() 306 | self._use_absl_logging = use_absl_logging 307 | self._log_all_python_logs_to_cloud = log_all_python_logs_to_cloud 308 | self._gcp_credentials = gcp_credentials 309 | absl_logging.set_verbosity(absl_logging.INFO) 310 | self._python_logger = self._init_cloud_handler() 311 | 312 | @property 313 | def trace_key(self) -> str: 314 | return self._trace_key 315 | 316 | @trace_key.setter 317 | def trace_key(self, val: str) -> None: 318 | self._trace_key = val 319 | 320 | @property 321 | def per_thread_log_signatures(self) -> bool: 322 | return self._per_thread_log_signatures 323 | 324 | @per_thread_log_signatures.setter 325 | def per_thread_log_signatures(self, val: bool) -> None: 326 | with self._log_lock: 327 | self._per_thread_log_signatures = val 328 | 329 | @property 330 | def python_logger(self) -> logging.Logger: 331 | if ( 332 | self._enabled 333 | and not self._use_absl_logging 334 | and CloudLoggingClientInstance._cloud_logging_handler is None 335 | ): 336 | self._python_logger = self._init_cloud_handler() 337 | return self._python_logger 338 | 339 | def _get_python_logger_name(self) -> Optional[str]: 340 | return None if self._log_all_python_logs_to_cloud else 'DPASLogger' 341 | 342 | def _get_cloud_logging_handler_init_params(self) -> str: 343 | return ( 344 | f'GCP_PROJECT_NAME: {self._gcp_project_name}; LOG_NAME:' 345 | f' {self._log_name}; LOG_ALL: {self._log_all_python_logs_to_cloud}' 346 | ) 347 | 348 | def _init_cloud_handler(self) -> logging.Logger: 349 | """Initializes cloud logging handler and returns python logger.""" 350 | # Instantiates a cloud logging client to generate text logs for cloud 351 | # operations 352 | with CloudLoggingClientInstance._global_lock: 353 | if not self._enabled or self._use_absl_logging: 354 | return logging.getLogger() # Default PY logger 355 | handler_instance_init_params = ( 356 | self._get_cloud_logging_handler_init_params() 357 | ) 358 | if CloudLoggingClientInstance._cloud_logging_handler is not None: 359 | running_handler_init_params = ( 360 | CloudLoggingClientInstance._cloud_logging_handler_init_params 361 | ) 362 | if running_handler_init_params != handler_instance_init_params: 363 | # Call fork_shutdown to shutdown the process's named logging handler. 364 | raise CloudLoggerInstanceExceptionError( 365 | 'Cloud logging handler is running with parameters that do not' 366 | ' match instance defined parameters. Running handler parameters:' 367 | f' {running_handler_init_params}; Instance parameters:' 368 | f' {handler_instance_init_params}' 369 | ) 370 | return logging.getLogger(self._get_python_logger_name()) 371 | log_name = self.log_name 372 | struct_log = {} 373 | struct_log['log_name'] = log_name 374 | struct_log['log_all_python_logs'] = self._log_all_python_logs_to_cloud 375 | try: 376 | # Attach default python & absl logger to also write to named log. 377 | logging_client = cloud_logging.Client( 378 | project=self._gcp_project_name if self._gcp_project_name else None, 379 | credentials=self._gcp_credentials, 380 | ) 381 | logging_client.project = ( 382 | self._gcp_project_name if self._gcp_project_name else None 383 | ) 384 | handler = cloud_logging.handlers.CloudLoggingHandler( 385 | client=logging_client, 386 | name=log_name, 387 | ) 388 | CloudLoggingClientInstance._cloud_logging_handler = handler 389 | CloudLoggingClientInstance._cloud_logging_handler_init_params = ( 390 | handler_instance_init_params 391 | ) 392 | cloud_logging.handlers.setup_logging( 393 | handler, 394 | log_level=logging.DEBUG 395 | if self._log_all_python_logs_to_cloud 396 | else logging.INFO, 397 | ) 398 | dpas_python_logger = logging.getLogger(self._get_python_logger_name()) 399 | dpas_python_logger.setLevel( 400 | logging.DEBUG 401 | ) # pytype: disable=attribute-error 402 | return dpas_python_logger 403 | except google.auth.exceptions.DefaultCredentialsError as exp: 404 | self._use_absl_logging = True 405 | self.error('Error initializing logging.', struct_log, exp) 406 | return logging.getLogger() 407 | except Exception as exp: 408 | self._use_absl_logging = True 409 | self.error('Error unexpected exception.', struct_log, exp) 410 | raise 411 | 412 | def __getstate__(self) -> MutableMapping[str, Any]: 413 | """Returns log state for pickle removes lock.""" 414 | dct = copy.copy(self.__dict__) 415 | del dct['_log_lock'] 416 | del dct['_python_logger'] 417 | del dct['_thread_local_storage'] 418 | return dct 419 | 420 | def __setstate__(self, dct: MutableMapping[str, Any]): 421 | """Un-pickles class and re-creates log lock.""" 422 | self.__dict__ = dct 423 | self._log_lock = threading.RLock() 424 | self._thread_local_storage = threading.local() 425 | # Re-init logging in process. 426 | self._python_logger = self._init_cloud_handler() 427 | 428 | def use_absl_logging(self) -> bool: 429 | return self._use_absl_logging 430 | 431 | @property 432 | def enable_structured_logging(self) -> bool: 433 | return self._enable_structured_logging 434 | 435 | @enable_structured_logging.setter 436 | def enable_structured_logging(self, val: bool) -> None: 437 | with self._log_lock: 438 | self._enable_structured_logging = val 439 | 440 | @property 441 | def enable(self) -> bool: 442 | return self._enabled 443 | 444 | @enable.setter 445 | def enable(self, val: bool) -> None: 446 | with self._log_lock: 447 | self._enabled = val 448 | 449 | @property 450 | def gcp_project_name(self) -> str: 451 | return self._gcp_project_name 452 | 453 | @property 454 | def log_name(self) -> str: 455 | if not self._log_name: 456 | raise ValueError('Undefined Log Name') 457 | return self._log_name 458 | 459 | @property 460 | def hostname(self) -> str: 461 | if not self._pod_hostname: 462 | raise ValueError('POD_HOSTNAME name is not defined.') 463 | return self._pod_hostname 464 | 465 | @property 466 | def build_version(self) -> str: 467 | """Returns build version # for container.""" 468 | return self._build_version 469 | 470 | @build_version.setter 471 | def build_version(self, version: str) -> None: 472 | """Returns build version # for container.""" 473 | self._build_version = version 474 | with self._log_lock: 475 | if self._per_thread_log_signatures: 476 | if not hasattr(self._thread_local_storage, 'signature'): 477 | self._thread_local_storage.signature = self._signature_defaults( 478 | threading.get_native_id() 479 | ) 480 | log_sig = self._thread_local_storage.signature 481 | else: 482 | log_sig = self._shared_log_signature 483 | if not version and 'BUILD_VERSION' in log_sig: 484 | del log_sig['BUILD_VERSION'] 485 | else: 486 | log_sig['BUILD_VERSION'] = str(version) 487 | 488 | @property 489 | def pod_uid(self) -> str: 490 | if not self._pod_uid: 491 | raise ValueError('Undefined POD UID') 492 | return self._pod_uid 493 | 494 | def _get_thread_signature(self) -> MutableMapping[str, Any]: 495 | if not self._per_thread_log_signatures: 496 | return self._shared_log_signature 497 | if not hasattr(self._thread_local_storage, 'signature'): 498 | self._thread_local_storage.signature = self._signature_defaults( 499 | threading.get_native_id() 500 | ) 501 | return self._thread_local_storage.signature 502 | 503 | @property 504 | def log_signature(self) -> MutableMapping[str, Any]: 505 | """Returns log signature. 506 | 507 | Log signature returned may not match what is currently being logged. 508 | if thread is set to log using another threads log signature. 509 | """ 510 | with self._log_lock: 511 | return copy.copy(self._get_thread_signature()) 512 | 513 | def _signature_defaults(self, thread_id: int) -> MutableMapping[str, str]: 514 | """Returns default log signature.""" 515 | log_signature = collections.OrderedDict() 516 | if self._pod_hostname: 517 | log_signature['HOSTNAME'] = str(self._pod_hostname) 518 | if self._pod_uid: 519 | log_signature['POD_UID'] = str(self._pod_uid) 520 | if self.build_version: 521 | log_signature['BUILD_VERSION'] = str(self.build_version) 522 | if self._per_thread_log_signatures: 523 | log_signature['THREAD_ID'] = str(thread_id) 524 | return log_signature 525 | 526 | @log_signature.setter 527 | def log_signature(self, sig: Mapping[str, Any]) -> None: 528 | """Sets log signature. 529 | 530 | Log signature of thread may not be altered if thread is set to log using 531 | another threads log signature. 532 | 533 | Args: 534 | sig: Signature for threads to logs to use. 535 | """ 536 | with self._log_lock: 537 | if self._per_thread_log_signatures: 538 | thread_id = threading.get_native_id() 539 | if not hasattr(self._thread_local_storage, 'log_signature'): 540 | self._thread_local_storage.signature = collections.OrderedDict() 541 | log_sig = self._thread_local_storage.signature 542 | else: 543 | thread_id = 0 544 | log_sig = self._shared_log_signature 545 | log_sig.clear() 546 | if sig is not None: 547 | for key in sorted(sig): 548 | log_sig[str(key)] = str(sig[key]) 549 | log_sig.update(self._signature_defaults(thread_id)) 550 | 551 | def clear_log_signature(self) -> None: 552 | """Clears thread log signature.""" 553 | with self._log_lock: 554 | if self._per_thread_log_signatures: 555 | self._thread_local_storage.signature = self._signature_defaults( 556 | threading.get_native_id() 557 | ) 558 | else: 559 | self._shared_log_signature = self._signature_defaults(0) 560 | 561 | def _clip_struct_log( 562 | self, log: MutableMapping[str, Any], max_log_size: int 563 | ) -> None: 564 | """Clip log if structed log exceeds structured log size limits. 565 | 566 | Clipping logic: 567 | log size = total sum of key + value sizes of log structure 568 | 569 | log['message'] and signature components are not clipped to keep 570 | log message text un-altered and message traceability preserved. 571 | log structure keys are not altered. 572 | 573 | Structured logs exceeding size typically have a massive component which. 574 | First try to clip the log by just clipping the largest clippable value. 575 | If log still exceeds size. Proportionally clip log values. 576 | 577 | Args: 578 | log: Structured log. 579 | max_log_size: Max_size of the log. 580 | 581 | Returns: 582 | None 583 | 584 | Raises: 585 | ValueError if log cannot be clipped to maxsize. 586 | """ 587 | # determine length of total message key + value 588 | total_size = 0 589 | total_value_size = 0 590 | for key, value in log.items(): 591 | total_size += len(key) 592 | total_value_size += len(value) 593 | total_size += total_value_size 594 | exceeds_log = total_size - max_log_size 595 | if exceeds_log <= 0: 596 | return 597 | 598 | # remove keys for log values not being adjusted 599 | # message is not adjust 600 | # message signature is not adjusted 601 | log_keys = set(log) 602 | excluded_key_msg_length = 0 603 | excluded_key_msg_value_length = 0 604 | excluded_keys = list(self._get_thread_signature()) 605 | excluded_keys.append('message') 606 | for excluded_key in excluded_keys: 607 | if excluded_key not in log_keys: 608 | continue 609 | value_len = len(str(log[excluded_key])) 610 | excluded_key_msg_length += len(excluded_key) + value_len 611 | excluded_key_msg_value_length += value_len 612 | log_keys.remove(excluded_key) 613 | if excluded_key_msg_length >= max_log_size: 614 | raise ValueError('Message exceeds logging msg length limit.') 615 | total_value_size -= excluded_key_msg_value_length 616 | 617 | # message exceeded length limits due to components in structure log 618 | self._log( 619 | 'Next log message exceed cloud ops length limit and as clipped.', 620 | severity=_LogSeverity.WARNING, 621 | struct=tuple(), 622 | stack_frames_back=0, 623 | ) 624 | 625 | # First clip largest entry. In most cases a message will have one 626 | # very large tag which causes the length issue first clip the single 627 | # largest entry so its at most not bigger than the second largest entry. 628 | if len(log_keys) > 1: 629 | key_size_list = [] 630 | for key in log_keys: 631 | key_size_list.append((key, len(log[key]))) 632 | key_size_list = sorted(key_size_list, key=lambda x: x[1]) 633 | largest_key = key_size_list[-1][0] 634 | # difference in size between largest and second largest entry 635 | largest_key_size_delta = key_size_list[-1][1] - key_size_list[-2][1] 636 | clip_len = min(largest_key_size_delta, exceeds_log) 637 | if clip_len > 0: 638 | log[largest_key] = log[largest_key][:-clip_len] 639 | # adjust length that needs to be trimmed 640 | exceeds_log -= clip_len 641 | if exceeds_log == 0: 642 | return 643 | # adjust total size of trimmable value component 644 | total_value_size -= clip_len 645 | 646 | # Proportionally clip all tags 647 | new_exceeds_log = exceeds_log 648 | # iterate over a sorted list to make clipping deterministic 649 | for key in sorted(list(log_keys)): 650 | entry_size = len(log[key]) 651 | clip_len = math.ceil(entry_size * exceeds_log / total_value_size) 652 | clip_len = min(min(clip_len, new_exceeds_log), entry_size) 653 | if clip_len > 0: 654 | log[key] = log[key][:-clip_len] 655 | new_exceeds_log -= clip_len 656 | if new_exceeds_log == 0: 657 | return 658 | raise ValueError('Message exceeds logging msg length limit.') 659 | 660 | def _merge_signature( 661 | self, struct: Optional[MutableMapping[str, Any]] 662 | ) -> MutableMapping[str, Any]: 663 | """Adds signature to logging struct. 664 | 665 | Args: 666 | struct: logging struct. 667 | 668 | Returns: 669 | Dict to log 670 | """ 671 | if struct is None: 672 | struct = collections.OrderedDict() 673 | struct.update(self._get_thread_signature()) 674 | return struct 675 | 676 | def _log( 677 | self, 678 | msg: str, 679 | severity: _LogSeverity, 680 | struct: Tuple[Union[Mapping[str, Any], Exception, None], ...], 681 | stack_frames_back: int = 0, 682 | ): 683 | """Posts structured log message, adds current_msg id to log structure. 684 | 685 | Args: 686 | msg: Message to log. 687 | severity: Severity level of message. 688 | struct: Structure to log. 689 | stack_frames_back: Additional stack frames back to log source_location. 690 | """ 691 | if not self._enabled: 692 | return 693 | with self._log_lock: 694 | if severity.value < self._log_error_level: 695 | return 696 | struct = self._merge_signature(_merge_struct(struct)) 697 | if not self.use_absl_logging() and self._enable_structured_logging: 698 | # Log using structured logs 699 | source_location = _get_source_location_to_log(stack_frames_back + 1) 700 | trace = _add_trace_to_log( 701 | self._gcp_project_name, self._trace_key, struct 702 | ) 703 | self._clip_struct_log(struct, MAX_LOG_SIZE) 704 | _py_log( 705 | self.python_logger, 706 | msg, 707 | extra={'json_fields': struct, **source_location, **trace}, 708 | severity=severity, 709 | ) 710 | return 711 | 712 | # Log using unstructured logs. 713 | structure_str = [msg] 714 | for key in struct: 715 | structure_str.append(f'{key}: {struct[key]}') 716 | _absl_log('; '.join(structure_str), severity=severity) 717 | 718 | def debug( 719 | self, 720 | msg: str, 721 | *struct: Union[Mapping[str, Any], Exception, None], 722 | stack_frames_back: int = 0, 723 | ) -> None: 724 | """Logs with debug severity. 725 | 726 | Args: 727 | msg: message to log (string). 728 | *struct: zero or more dict or exception to log in structured log. 729 | stack_frames_back: Additional stack frames back to log source_location. 730 | """ 731 | self._log(msg, _LogSeverity.DEBUG, struct, 1 + stack_frames_back) 732 | 733 | def timed_debug( 734 | self, 735 | msg: str, 736 | *struct: Union[Mapping[str, Any], Exception, None], 737 | stack_frames_back: int = 0, 738 | ) -> None: 739 | """Logs with debug severity and elapsed time since last timed debug log. 740 | 741 | Args: 742 | msg: message to log (string). 743 | *struct: zero or more dict or exception to log in structured log. 744 | stack_frames_back: Additional stack frames back to log source_location. 745 | """ 746 | time_now = time.time() 747 | elapsed_time = '%.3f' % (time_now - self._debug_log_time) 748 | self._debug_log_time = time_now 749 | msg = f'[{elapsed_time}] {msg}' 750 | self._log(msg, _LogSeverity.DEBUG, struct, 1 + stack_frames_back) 751 | 752 | def info( 753 | self, 754 | msg: str, 755 | *struct: Union[Mapping[str, Any], Exception, None], 756 | stack_frames_back: int = 0, 757 | ) -> None: 758 | """Logs with info severity. 759 | 760 | Args: 761 | msg: message to log (string). 762 | *struct: zero or more dict or exception to log in structured log. 763 | stack_frames_back: Additional stack frames back to log source_location. 764 | """ 765 | self._log(msg, _LogSeverity.INFO, struct, 1 + stack_frames_back) 766 | 767 | def warning( 768 | self, 769 | msg: str, 770 | *struct: Union[Mapping[str, Any], Exception, None], 771 | stack_frames_back: int = 0, 772 | ) -> None: 773 | """Logs with warning severity. 774 | 775 | Args: 776 | msg: Message to log (string). 777 | *struct: Zero or more dict or exception to log in structured log. 778 | stack_frames_back: Additional stack frames back to log source_location. 779 | """ 780 | self._log(msg, _LogSeverity.WARNING, struct, 1 + stack_frames_back) 781 | 782 | def error( 783 | self, 784 | msg: str, 785 | *struct: Union[Mapping[str, Any], Exception, None], 786 | stack_frames_back: int = 0, 787 | ) -> None: 788 | """Logs with error severity. 789 | 790 | Args: 791 | msg: Message to log (string). 792 | *struct: Zero or more dict or exception to log in structured log. 793 | stack_frames_back: Additional stack frames back to log source_location. 794 | """ 795 | self._log(msg, _LogSeverity.ERROR, struct, 1 + stack_frames_back) 796 | 797 | def critical( 798 | self, 799 | msg: str, 800 | *struct: Union[Mapping[str, Any], Exception, None], 801 | stack_frames_back: int = 0, 802 | ) -> None: 803 | """Logs with critical severity. 804 | 805 | Args: 806 | msg: Message to log (string). 807 | *struct: Zero or more dict or exception to log in structured log. 808 | stack_frames_back: Additional stack frames back to log source_location. 809 | """ 810 | self._log(msg, _LogSeverity.CRITICAL, struct, 1 + stack_frames_back) 811 | 812 | @property 813 | def log_error_level(self) -> int: 814 | return self._log_error_level 815 | 816 | @log_error_level.setter 817 | def log_error_level(self, level: int) -> None: 818 | with self._log_lock: 819 | self._log_error_level = level 820 | 821 | 822 | # Logging interfaces are used from processes which are forked (gunicorn, 823 | # DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not 824 | # copy threads running within parent processes or re-initialize global/module 825 | # state. This can result in forked modules being executed with invalid global 826 | # state, e.g., acquired locks that will not release or references to invalid 827 | # state. The cloud logging library utilizes a background thread transporting 828 | # logs to cloud. The background threading is not compatible with forking and 829 | # will seg-fault (python queue wait). This can be avoided, by stopping the 830 | # background transport prior to forking and then restarting the transport 831 | # following the fork. 832 | os.register_at_fork( 833 | before=CloudLoggingClientInstance.fork_shutdown, # pylint: disable=protected-access 834 | after_in_child=CloudLoggingClientInstance._init_fork_module_state, # pylint: disable=protected-access 835 | ) 836 | --------------------------------------------------------------------------------