├── python ├── serving │ ├── logging_lib │ │ ├── __init__.py │ │ ├── flags │ │ │ ├── __init__.py │ │ │ ├── flag_utils.py │ │ │ ├── flag_utils_test.py │ │ │ ├── secret_flag_utils.py │ │ │ └── secret_flag_utils_test.py │ │ └── cloud_logging_client.py │ ├── serving_framework │ │ ├── tensorflow │ │ │ ├── __init__.py │ │ │ ├── requirements.in │ │ │ ├── inline_model_runner.py │ │ │ ├── server_model_runner.py │ │ │ ├── inline_model_runner_test.py │ │ │ └── server_model_runner_test.py │ │ ├── pip-install.txt │ │ ├── requirements.in │ │ ├── README.md │ │ ├── __init__.py │ │ ├── model_transfer.py │ │ ├── inline_prediction_executor_test.py │ │ ├── inline_prediction_executor.py │ │ ├── model_runner.py │ │ ├── server_gunicorn.py │ │ └── server_gunicorn_test.py │ ├── testdata │ │ ├── google.jpg │ │ ├── test.dcm │ │ ├── dcm_frame_1.jpg │ │ └── multiframe_camelyon_challenge_image.dcm │ ├── __init__.py │ ├── data_models │ │ ├── __init__.py │ │ ├── patch_coordinate.py │ │ ├── patch_coordinate_test.py │ │ ├── embedding_request.py │ │ ├── embedding_request_test.py │ │ ├── embedding_response.py │ │ └── embedding_response_test.py │ ├── test_utils │ │ ├── __init__.py │ │ ├── test_files.py │ │ └── pete_mock.py │ ├── requirements.in │ ├── abstract_pete_predictor.py │ ├── pete_errors_test.py │ ├── pete_flags_test.py │ ├── entrypoint.sh │ ├── Dockerfile │ ├── pete_error_mapping_test.py │ ├── pete_errors.py │ ├── pete_error_mapping.py │ ├── pete_logging.py │ ├── README.md │ ├── pete_test_util.py │ ├── pete_flags.py │ ├── server_gunicorn.py │ ├── pete_prediction_executor.py │ ├── vertex_schemata │ │ ├── prediction.yaml │ │ └── instance.yaml │ └── pete2_e2e_test.py └── requirements.txt ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── 1-bug.md │ └── 2-feature-request.md └── PULL_REQUEST_TEMPLATE │ └── pull_request_template.md ├── CONTRIBUTING.md ├── notebooks ├── README.md └── quick_start_with_hugging_face.ipynb ├── README.md └── LICENSE /python/serving/logging_lib/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | -r serving/requirements.txt 2 | -------------------------------------------------------------------------------- /python/serving/serving_framework/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/serving/testdata/google.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/path-foundation/HEAD/python/serving/testdata/google.jpg -------------------------------------------------------------------------------- /python/serving/testdata/test.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/path-foundation/HEAD/python/serving/testdata/test.dcm -------------------------------------------------------------------------------- /python/serving/testdata/dcm_frame_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/path-foundation/HEAD/python/serving/testdata/dcm_frame_1.jpg -------------------------------------------------------------------------------- /python/serving/serving_framework/tensorflow/requirements.in: -------------------------------------------------------------------------------- 1 | tensorflow~=2.18.0 2 | tensorflow-serving-api~=2.18.0 3 | tensorflow-io-gcs-filesystem>=0.23.1 
-------------------------------------------------------------------------------- /python/serving/testdata/multiframe_camelyon_challenge_image.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/path-foundation/HEAD/python/serving/testdata/multiframe_camelyon_challenge_image.dcm -------------------------------------------------------------------------------- /python/serving/serving_framework/pip-install.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for updating pip. 2 | pip==25.0 \ 3 | --hash=sha256:8e0a97f7b4c47ae4a494560da84775e9e2f671d415d8d828e052efefb206b30b \ 4 | --hash=sha256:b6eb97a803356a52b2dd4bb73ba9e65b2ba16caa6bcb25a7497350a4e5859b65 5 | -------------------------------------------------------------------------------- /python/serving/serving_framework/requirements.in: -------------------------------------------------------------------------------- 1 | absl-py>=2.1.0 2 | flask~=3.0.3 3 | grpcio~=1.68.1 4 | grpcio-status~=1.68.1 5 | gunicorn~=23.0.0 6 | numpy<=2.0.2 # bypassing faulty version restriction in tritonclient 7 | requests~=2.32.3 8 | # TODO: b/375469331 - Enable testing with most current requests-mock release. 9 | requests-mock==1.9.3 10 | setuptools~=75.6.0 11 | typing-extensions~=4.12.2 12 | -------------------------------------------------------------------------------- /python/serving/serving_framework/README.md: -------------------------------------------------------------------------------- 1 | # Serving framework 2 | 3 | This directory contains a Python library that simplifies the creation of custom 4 | prediction containers for Vertex AI. It provides the framework for building HTTP 5 | servers that meet the platform's 6 | [requirements](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). 
7 | 8 | To implement a model-specific HTTP server, frame the custom data handling and 9 | orchestration logic within this framework. 10 | -------------------------------------------------------------------------------- /python/serving/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /python/serving/data_models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | -------------------------------------------------------------------------------- /python/serving/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /python/serving/serving_framework/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # Allow users to create issues that don't follow the templates since they don't cover all use cases 2 | blank_issues_enabled: true 3 | 4 | # Redirect users to other channels for general support or security issues 5 | contact_links: 6 | - name: Community Support 7 | url: https://github.com/google-health/$REPO_NAME/discussions 8 | about: Please ask and answer questions here. 9 | - name: Security Bug Reporting 10 | url: https://g.co/vulnz 11 | about: > 12 | To report a security issue, please use https://g.co/vulnz. The Google Security Team will 13 | respond within 5 working days of your report on https://g.co/vulnz. -------------------------------------------------------------------------------- /python/serving/requirements.in: -------------------------------------------------------------------------------- 1 | absl-py~=2.1.0 2 | ez-wsi-dicomweb~=6.0.8 3 | google-cloud-aiplatform~=1.70.0 4 | google-cloud-logging~=3.11.3 5 | google-cloud-secret-manager~=2.20.2 6 | jsonschema~=4.23.0 7 | numpy~=2.0.2 8 | pyyaml~=6.0.2 9 | redis~=5.1.1 10 | async-timeout~=4.0.3 # Required for older python compatibility. 11 | requests~=2.32.3 12 | mock~=5.1.0 13 | 14 | # requests mock > 1.9.3 adds a RLock which stops the mock from handling 15 | # simultaneous requests from different threads. 16 | # TODO: b/375469331 - Enable testing with most current requests-mock release. 
17 | requests-mock==1.9.3 18 | 19 | tensorflow~=2.18.0 20 | 21 | -r serving_framework/requirements.in 22 | -r serving_framework/tensorflow/requirements.in 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a bug report to help us improve 4 | title: 'Bug: ' 5 | labels: 'bug', 'needs triage' 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the overall issue and situation 10 | 11 | Provide a clear summary of what the issue is about, the area of the project you 12 | found it in, and what you were trying to do. 13 | 14 | ## Expected behavior 15 | 16 | Provide a clear and concise description of what you expected to happen 17 | 18 | ## Actual behavior 19 | 20 | Provide a clear and concise description of what actually happened. 21 | 22 | ## Steps to reproduce the issue 23 | 24 | Provide a sequence of steps we can use to reproduce the issue. 25 | 26 | 1. 27 | 2. 28 | 3. 29 | 30 | ## Any additional content 31 | 32 | Describe your environment or any other set up details that might help us 33 | reproduce the issue. 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea or improvement 4 | title: 'Request: ' 5 | labels: 'enhancement', 'needs triage' 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the overall idea and motivation 10 | 11 | Provide a clear summary of the idea and what use cases it's addressing. 12 | 13 | ## Related to an issue? 14 | 15 | Is this addressing a known / documented issue? If so, which one? 16 | 17 | ## Possible solutions and alternatives 18 | 19 | Do you already have an idea of how the solution should work? If so, document 20 | that here. 
21 | 22 | Also, if there are alternatives, please document those as well. 23 | 24 | ## Priority and timeline considerations 25 | 26 | Is this time sensitive? Is it a nice to have? Please describe what priority you 27 | feel this should have and why. We'll take this into advisement as we go through 28 | our internal prioritization process. 29 | 30 | ## Additional context 31 | 32 | Is there anything else to consider that wasn't covered by the above? 33 | 34 | Would you like to contribute to the project and work on this request? 35 | -------------------------------------------------------------------------------- /python/serving/abstract_pete_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Defines abstract pathology embedding prediction interface.""" 16 | import abc 17 | from typing import Any, Mapping 18 | 19 | from serving.serving_framework import model_runner 20 | 21 | 22 | class AbstractPetePredictor(metaclass=abc.ABCMeta): 23 | 24 | @abc.abstractmethod 25 | def predict( 26 | self, 27 | prediction_input: Mapping[str, Any], 28 | model: model_runner.ModelRunner, 29 | ) -> Mapping[str, Any]: 30 | """Returns embeddings for embedding prediction requests.""" 31 | -------------------------------------------------------------------------------- /python/serving/test_utils/test_files.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """test_util for pete.""" 16 | 17 | import os 18 | 19 | 20 | TEST_STORE_PATH = 'projects/project_name/locations/us-west1/datasets/dataset_name/dicomStores/dicom_store_name' 21 | 22 | 23 | def testdata_path(*args: str) -> str: 24 | base_path = [os.path.join(os.path.dirname(os.path.dirname(__file__)), 'testdata')] 25 | base_path.extend(args) 26 | return os.path.join(*base_path) 27 | 28 | 29 | def test_multi_frame_dicom_instance_path() -> str: 30 | return testdata_path('multiframe_camelyon_challenge_image.dcm') 31 | -------------------------------------------------------------------------------- /python/serving/pete_errors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for pete errors.""" 16 | 17 | import inspect 18 | import sys 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from serving import pete_errors 23 | 24 | 25 | class PeteErrorsTest(parameterized.TestCase): 26 | 27 | def test_all_errors_use_base_class(self): 28 | for _, cls in inspect.getmembers( 29 | sys.modules[pete_errors.__name__], inspect.isclass 30 | ): 31 | if cls.__name__ == 'InternalBugError': 32 | continue 33 | self.assertTrue(issubclass(cls, pete_errors.PeteError)) 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /python/serving/pete_flags_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for pete flags.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from serving import pete_flags 20 | 21 | 22 | class PeteFlagsTest(absltest.TestCase): 23 | 24 | def test_default_load_multi_string_returns_none(self): 25 | self.assertIsNone(pete_flags._load_multi_string(None)) 26 | 27 | def test_load_multi_string_parse_json_string(self): 28 | self.assertEqual(pete_flags._load_multi_string('["a", "b"]'), ['a', 'b']) 29 | 30 | def test_bad_json_returns_value(self): 31 | self.assertEqual(pete_flags._load_multi_string('/abc'), '/abc') 32 | 33 | 34 | if __name__ == '__main__': 35 | absltest.main() 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows HAI-DEF's 24 | [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. 
We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 33 | 34 | For more details, read HAI-DEF's 35 | [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing) 36 | -------------------------------------------------------------------------------- /python/serving/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script launches the serving framework, run as the entrypoint. 17 | 18 | # Exit if any command fails or if expanding an undefined variable. 
19 | set -eu 20 | 21 | export MODEL_REST_PORT=8600 22 | export LOCAL_MODEL_PATH=/model 23 | 24 | echo "Serving framework start, launching model server" 25 | 26 | /server-env/bin/python3.12 -m serving.serving_framework.model_transfer \ 27 | --gcs_path="${AIP_STORAGE_URI}" \ 28 | --local_path="${LOCAL_MODEL_PATH}/1" 29 | /usr/bin/tensorflow_model_server \ 30 | --xla_cpu_compilation_enabled=true \ 31 | --port=8500 \ 32 | --rest_api_port="${MODEL_REST_PORT}" \ 33 | --model_name=default \ 34 | --model_base_path="${LOCAL_MODEL_PATH}" & 35 | 36 | echo "Launching front end" 37 | 38 | /server-env/bin/python3.12 -m serving.server_gunicorn --alsologtostderr \ 39 | --verbosity=1 & 40 | 41 | # Wait for any process to exit 42 | wait -n 43 | 44 | # Exit with status of process that exited first 45 | exit $? 46 | -------------------------------------------------------------------------------- /python/serving/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This is used to build a Docker image that includes the necessary dependencies 16 | # for running the Path Foundation as a microservice. 
17 | 18 | FROM tensorflow/serving:2.18.0 19 | 20 | COPY ./python/serving /serving 21 | COPY ./LICENSE /LICENSE 22 | RUN chmod a+x /serving/entrypoint.sh 23 | 24 | # Install python3.12 and git 25 | RUN apt-get update && \ 26 | apt-get -y upgrade && \ 27 | apt-get install -y software-properties-common && \ 28 | add-apt-repository ppa:deadsnakes/ppa && \ 29 | apt-get install -y git python3.12 python3.12-venv 30 | 31 | # Get pypi requirements 32 | RUN python3.12 -m venv /server-env && \ 33 | /server-env/bin/python3.12 -m pip install --require-hashes \ 34 | -r /serving/serving_framework/pip-install.txt && \ 35 | /server-env/bin/python3.12 -m pip install --require-hashes \ 36 | -r /serving/requirements.txt 37 | 38 | # Clone python-certifi to meet MPL 2.0 License terms for source code mirroring 39 | RUN git clone https://github.com/certifi/python-certifi.git 40 | 41 | ENTRYPOINT ["/serving/entrypoint.sh"] 42 | -------------------------------------------------------------------------------- /python/serving/pete_error_mapping_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for pete error mapping.""" 16 | 17 | import inspect 18 | import sys 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | from serving import pete_error_mapping 24 | from serving import pete_errors 25 | from serving.data_models import embedding_response 26 | 27 | 28 | class PeteErrorMappingTest(parameterized.TestCase): 29 | 30 | def test_all_errors_use_base_class(self): 31 | for _, cls in inspect.getmembers( 32 | sys.modules[pete_errors.__name__], inspect.isclass 33 | ): 34 | if cls.__name__ == 'InternalBugError': 35 | continue 36 | if cls is not pete_errors.PeteError: 37 | self.assertIn(cls, pete_error_mapping._ERROR_MAPPINGS.keys()) 38 | 39 | def test_error_mapping(self): 40 | self.assertEqual( 41 | embedding_response.ErrorCode.INVALID_RESPONSE_ERROR, 42 | pete_error_mapping.get_error_code( 43 | pete_errors.InvalidResponseError('Potato') 44 | ), 45 | ) 46 | 47 | 48 | if __name__ == '__main__': 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /python/serving/serving_framework/model_transfer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""A tool to transfer a model from the Vertex gcs bucket to a local directory. 
16 | 17 | Copies a target gcs directory into a local directory using default credentials, 18 | intended to be used to set up tfserving-compatible model directories during 19 | serving framework startup. 20 | 21 | usage: 22 | python3 model_transfer.py --gcs_path="gs://bucket/object" \ 23 | --local_path="/path/to/local/dir" 24 | """ 25 | 26 | from collections.abc import Sequence 27 | 28 | from absl import app 29 | from absl import flags 30 | 31 | from google.cloud.aiplatform.utils import gcs_utils 32 | 33 | 34 | _GCS_PATH = flags.DEFINE_string( 35 | "gcs_path", 36 | None, 37 | "The gcs path to copy from.", 38 | required=True, 39 | ) 40 | _LOCAL_PATH = flags.DEFINE_string( 41 | "local_path", 42 | None, 43 | "The local path to copy to.", 44 | required=True, 45 | ) 46 | 47 | 48 | def main(argv: Sequence[str]) -> None: 49 | if len(argv) > 1: 50 | raise app.UsageError("Too many command-line arguments.") 51 | gcs_utils.download_from_gcs(_GCS_PATH.value, _LOCAL_PATH.value) 52 | 53 | if __name__ == "__main__": 54 | app.run(main) 55 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Path Foundation Notebooks 2 | 3 | * [Quick start with Hugging Face](quick_start_with_hugging_face.ipynb) - 4 | Example of encoding a pathology image patch into an embedding vector by 5 | running the model locally from Hugging Face. 6 | 7 | * [Quick start with Vertex Model Garden](quick_start_with_model_garden.ipynb) - 8 | Example of serving the model on 9 | [Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/overview) 10 | and using Vertex AI APIs to encode pathology image patches to embeddings in 11 | online or batch workflows. 12 | 13 | * [Train a data efficient classifier - GCS version](train_data_efficient_classifier_gcs.ipynb) - 14 | Example of using the generated embeddings to train a custom classifier with 15 | less data and compute. 
This version shows how to use the data as files in 16 | [Cloud Storage (GCS)](https://cloud.google.com/storage). 17 | 18 | * [Train a data efficient classifier - DICOMWeb version](train_data_efficient_classifier_dicom.ipynb) - 19 | Example of using the generated embeddings to train a custom classifier with 20 | less data and compute. This version shows how to use the data as DICOM 21 | objects in 22 | [Cloud DICOM store](https://cloud.google.com/healthcare-api/docs/how-tos/dicom). 23 | 24 | * [Simplify client code with EZ-WSI](simplify_client_code_with_ez_wsi.ipynb) - 25 | Instructions how to utilize 26 | [EZ-WSI DicomWeb library](https://github.com/GoogleCloudPlatform/EZ-WSI-DICOMweb) 27 | to simplify client side code for working with DICOM data and generating 28 | embeddings from a variety of data sources including Cloud DICOM store, GCS, 29 | and locally stored files or in-memory data representations. 30 | 31 | * [Fine-tune data efficient classifier](fine_tune_data_efficient_classifier.ipynb) 32 | Example of fine-tuning the weights of the pathology embedding model to 33 | classify pathology image patches as an alternative to a linear classifier. -------------------------------------------------------------------------------- /python/serving/data_models/patch_coordinate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Shared dataclasses across requests and responses for Pete.""" 16 | 17 | import dataclasses 18 | 19 | from serving import pete_errors 20 | from serving import pete_flags 21 | 22 | 23 | @dataclasses.dataclass(frozen=True) 24 | class PatchCoordinate: 25 | """A coordinate of a patch.""" 26 | 27 | x_origin: int 28 | y_origin: int 29 | width: int 30 | height: int 31 | 32 | def __post_init__(self): 33 | if ( 34 | self.width != pete_flags.ENDPOINT_INPUT_WIDTH_FLAG.value 35 | or self.height != pete_flags.ENDPOINT_INPUT_HEIGHT_FLAG.value 36 | ): 37 | raise pete_errors.PatchDimensionsDoNotMatchEndpointInputDimensionsError( 38 | 'Patch coordinate width and height must be' 39 | f' {pete_flags.ENDPOINT_INPUT_WIDTH_FLAG.value}x{pete_flags.ENDPOINT_INPUT_HEIGHT_FLAG.value}.' 40 | ) 41 | 42 | 43 | def create_patch_coordinate( 44 | x_origin: int, 45 | y_origin: int, 46 | width: int = -1, 47 | height: int = -1, 48 | ) -> PatchCoordinate: 49 | """Creates a patch coordinate.""" 50 | if width == -1: 51 | width = pete_flags.ENDPOINT_INPUT_WIDTH_FLAG.value 52 | if height == -1: 53 | height = pete_flags.ENDPOINT_INPUT_HEIGHT_FLAG.value 54 | return PatchCoordinate( 55 | x_origin=x_origin, 56 | y_origin=y_origin, 57 | width=width, 58 | height=height, 59 | ) 60 | -------------------------------------------------------------------------------- /python/serving/data_models/patch_coordinate_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Tests for patch coordinate."""

from absl.testing import absltest
from absl.testing import parameterized
from serving import pete_errors
from serving.data_models import patch_coordinate


class PatchCoordinateTest(parameterized.TestCase):
  """Tests for PatchCoordinate construction and defaulted dimensions."""

  def setUp(self):
    super().setUp()

    # Coordinate built with default (flag-derived) width/height.
    self._patch_coordinate_zero_dimensions = (
        patch_coordinate.create_patch_coordinate(
            x_origin=1,
            y_origin=3,
        )
    )
    # NOTE(review): 224x224 assumes the endpoint input width/height flags are
    # at their defaults in the test environment — confirm against pete_flags.
    self._patch_coordinate_zero_dimensions_dict = {
        'x_origin': 1,
        'y_origin': 3,
        'width': 224,
        'height': 224,
    }

  def test_dicom_embedding_patch_coordinate_invalid_dimensions(self):
    # Width/height differing from the endpoint input dimensions must raise.
    with self.assertRaises(
        pete_errors.PatchDimensionsDoNotMatchEndpointInputDimensionsError
    ):
      _ = patch_coordinate.PatchCoordinate(
          x_origin=1,
          y_origin=3,
          width=11,
          height=10,
      )

  def test_dicom_embedding_patch_coordinate_default_dimensions(self):
    parameters = self._patch_coordinate_zero_dimensions

    # Compare all dataclass fields at once against the expected dict.
    self.assertEqual(
        parameters.__dict__, self._patch_coordinate_zero_dimensions_dict
    )


if __name__ == '__main__':
  absltest.main()
11 | 12 | Fixes #[issue number] 13 | 14 | Choose one: (Bug fix | Feature | Documentation | Testing | Code health | Other) 15 | 16 | # How Has This Been Tested? 17 | 18 | Replace this with a description of the tests that you ran to verify your 19 | changes. If executing the existing test suite without customization, simply 20 | paste the command line used. 21 | 22 | ``` 23 | $ python -m unittest discover ... 24 | ``` 25 | 26 | # Checklist: 27 | 28 | 29 | 30 | 31 | 32 | - [ ] I have read and acknowledged Google's Open Source 33 | [Code of conduct](https://opensource.google/conduct). 34 | - [ ] I have read the 35 | [Contributing](https://github.com/google-health/path-foundation/blob/master/CONTRIBUTING.md) 36 | page, and I either signed the Google 37 | [Individual CLA](https://cla.developers.google.com/about/google-individual) 38 | or am covered by my company's 39 | [Corporate CLA](https://cla.developers.google.com/about/google-corporate). 40 | - [ ] I have discussed my proposed solution with code owners in the linked 41 | issue(s) and we have agreed upon the general approach. 42 | - [ ] I have made any needed documentation changes, or noted in the linked 43 | issue(s) that documentation elsewhere needs updating. 44 | - [ ] I have added tests, or I have ensured existing tests cover the changes 45 | - [ ] I have followed 46 | [Google's Python Style Guide](https://google.github.io/styleguide/pyguide.html) 47 | and ran `pylint` over the affected code. 48 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/flag_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Utility functions for flags."""
import os


def str_to_bool(val: str) -> bool:
  """Converts a string representation of truth.

  True values are 'y', 'yes', 't', 'true', 'on', and '1';
  False values are 'n', 'no', 'f', 'false', 'off', and '0'.

  Args:
    val: String to convert to bool.

  Returns:
    Boolean result

  Raises:
    ValueError if val is anything else.
  """
  # Comparison is whitespace- and case-insensitive.
  normalized = val.strip().lower()
  if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
    return True
  if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
    return False
  raise ValueError(f'invalid truth value {normalized}')


def env_value_to_bool(env_name: str, undefined_value: bool = False) -> bool:
  """Converts environmental variable value into boolean value for flag init.

  Args:
    env_name: Environmental variable name.
    undefined_value: Default value to set undefined values to.

  Returns:
    Boolean of environmental variable string value.

  Raises:
    ValueError : Environmental variable cannot be parsed to bool.
  """
  raw_value = os.environ.get(env_name)
  if raw_value is None:
    # Variable not set; fall back to the caller-supplied default.
    return undefined_value
  return str_to_bool(raw_value.strip())
31 | 32 | ## Contributing 33 | 34 | We are open to bug reports, pull requests (PR), and other contributions. See 35 | [CONTRIBUTING](CONTRIBUTING.md) and 36 | [community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines) 37 | for details. 38 | 39 | ## License 40 | 41 | While the model is licensed under the 42 | [Health AI Developer Foundations License](https://developers.google.com/health-ai-developer-foundations/terms), 43 | everything in this repository is licensed under the Apache 2.0 license, see 44 | [LICENSE](LICENSE). 45 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/flag_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""Tests for flag utils."""
import os
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized

from serving.logging_lib.flags import flag_utils

# Name of an environmental variable assumed to be absent from the test
# environment.
_UNDEFINED_ENV_VAR_NAME = 'UNDEFINED'


class FlagUtilsTest(parameterized.TestCase):
  """Unit tests for flag_utils.str_to_bool and flag_utils.env_value_to_bool."""

  @parameterized.parameters(['y', ' YES ', 't', 'tRue', 'on', '1'])
  def test_str_to_bool_true(self, val):
    self.assertTrue(flag_utils.str_to_bool(val))

  # Renamed from test_strtobool_false for naming consistency with
  # test_str_to_bool_true.
  @parameterized.parameters(['n', 'no', 'f', 'FALSE', ' oFf ', '0'])
  def test_str_to_bool_false(self, val):
    self.assertFalse(flag_utils.str_to_bool(val))

  def test_str_to_bool_raises(self):
    # Unrecognized truth strings raise ValueError.
    with self.assertRaises(ValueError):
      flag_utils.str_to_bool('ABCD')

  def test_undefined_env_default_true(self):
    self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ)
    self.assertTrue(flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME, True))

  def test_undefined_env_no_default(self):
    self.assertNotIn(_UNDEFINED_ENV_VAR_NAME, os.environ)
    self.assertFalse(flag_utils.env_value_to_bool(_UNDEFINED_ENV_VAR_NAME))

  @mock.patch.dict(os.environ, {'FOO': ' True '})
  def test_initialized_env(self):
    self.assertTrue(flag_utils.env_value_to_bool('FOO'))

  @mock.patch.dict(os.environ, {'BAD_VALUE': 'TrueABSCDEF'})
  def test_bad_initialized_env(self):
    with self.assertRaises(ValueError):
      flag_utils.env_value_to_bool('BAD_VALUE')


if __name__ == '__main__':
  absltest.main()
"""Implements ModelRunner running a provided checkpoint in-process.

Runs the serving_default signature of the provided checkpoint.

Does not implement multi-model support.
"""

from collections.abc import Mapping, Set

import numpy as np
import tensorflow as tf
from typing_extensions import override

from serving.serving_framework import model_runner


class InlineModelRunner(model_runner.ModelRunner):
  """ModelRunner implementation using in-process tensorflow."""

  def __init__(self, model: tf.train.Checkpoint):
    # The loaded checkpoint whose "serving_default" signature is invoked for
    # every request.
    self._model = model

  @override
  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
  ) -> Mapping[str, np.ndarray]:
    """Runs a model on the given input and returns multiple outputs.

    Args:
      model_input: An array or map of arrays comprising the input tensors for
        the model.
      model_name: The name of the model to run. Must be "default"; this runner
        holds exactly one model.
      model_version: The version of the model to run. Must be None; this
        runner holds exactly one model.
      model_output_keys: The desired model output keys.

    Returns:
      A mapping of model output keys to tensors.

    Raises:
      NotImplementedError: A non-default model name or version was requested.
    """
    # Single-model runner: reject any request for a different model/version.
    if model_name != "default" or model_version is not None:
      raise NotImplementedError(
          "InlineModelRunner does not support multiple models."
      )
    del model_name, model_version
    # Convert the numpy input(s) to tensorflow tensors before invoking the
    # signature.
    if isinstance(model_input, np.ndarray):
      tensor_input = tf.convert_to_tensor(model_input)
    else:
      tensor_input = {
          k: tf.convert_to_tensor(v) for k, v in model_input.items()
      }

    # NOTE(review): a dict input is passed as a single positional argument to
    # the signature; this assumes the exported signature accepts that calling
    # convention — confirm against the SavedModel's serving_default signature.
    result = self._model.signatures["serving_default"](tensor_input)
    # Materialize only the requested outputs as numpy arrays.
    return {k: result[k].numpy() for k in model_output_keys}
class PeteError(Exception):
  """Base error class for Pete Errors.

  Carries an optional API-safe description that can be echoed in API
  responses instead of the internal message.
  """

  def __init__(self, message: str = '', api_description: str = ''):
    """Errors with optional alternative descriptions for API echoing.

    Args:
      message: Internal error message; when empty, api_description is used as
        the exception message instead.
      api_description: Optional alternative description for API responses.
    """
    super().__init__(message or api_description)
    self._api_description = api_description

  @property
  def api_description(self) -> str:
    """Returns the API description of the error, falling back to str(self)."""
    return self._api_description or str(self)


class InstancesNotConcatenatedError(PeteError):
  """Raised when request instances are not concatenated."""


class InvalidRequestFieldError(PeteError):
  """Raised for an invalid field in a request."""


class InvalidResponseError(PeteError):
  """Raised for an invalid response."""


class InvalidCredentialsError(PeteError):
  """Raised for invalid credentials."""


class LevelNotFoundError(PeteError):
  """Raised when a level is not found."""


class TooManyPatchesError(PeteError):
  """Raised when a request contains too many patches."""


class EzWsiStateError(PeteError):
  """Raised for an invalid EZ-WSI state."""


class GcsImagePathFormatError(PeteError):
  """Raised for a malformed GCS image path."""


class ImageError(PeteError):
  """Raised for a general image error."""


class PatchOutsideOfImageDimensionsError(PeteError):
  """Raised when a patch falls outside the image dimensions."""


class HttpError(PeteError):
  """Raised for an HTTP error."""


class InvalidIccProfileTransformError(PeteError):
  """Raised for an invalid ICC profile transform."""


class ImageDimensionError(PeteError):
  """Raised for an image dimension error."""


class DicomTiledFullError(PeteError):
  """Raised for a DICOM TILED_FULL error."""


class DicomPathError(PeteError):
  """Raised for a malformed DICOM path."""


class DicomError(PeteError):
  """Raised for a general DICOM error."""


class DicomImageDownsamplingTooLargeError(PeteError):
  """Raised when DICOM image downsampling is too large."""


class UnapprovedDicomStoreError(PeteError):
  """Raised for an unapproved DICOM store."""


class UnapprovedGcsBucketError(PeteError):
  """Raised for an unapproved GCS bucket."""


class PatchDimensionsDoNotMatchEndpointInputDimensionsError(PeteError):
  """Raised when patch dimensions do not match endpoint input dimensions."""
from collections.abc import Mapping, Set
from unittest import mock

import numpy as np

from absl.testing import absltest
from serving.serving_framework import inline_prediction_executor
from serving.serving_framework import model_runner


class DummyModelRunner(model_runner.ModelRunner):
  """Dummy model runner for testing."""

  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
  ) -> Mapping[str, np.ndarray]:
    """Returns a fixed output regardless of input."""
    del model_name, model_version, model_output_keys
    return {"output_0": np.ones((1, 2), dtype=np.float32)}


class InlinePredictionExecutorTest(absltest.TestCase):
  """Tests for InlinePredictionExecutor lifecycle and delegation."""

  def test_predict_requires_start(self):
    # predict() before start() must fail: the model runner is not built yet.
    predictor = mock.MagicMock()
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, DummyModelRunner
    )
    with self.assertRaises(RuntimeError):
      executor.predict({"placeholder": "input"})

  def test_execute_catches_predictor_exception(self):
    # execute() wraps arbitrary predictor failures in RuntimeError.
    predictor = mock.MagicMock(side_effect=Exception("test error"))
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, DummyModelRunner
    )
    executor.start()
    with self.assertRaises(RuntimeError):
      executor.execute({"placeholder": "input"})

  def test_execute_calls_predictor(self):
    # execute() forwards the payload and the runner built by start().
    predictor = mock.MagicMock(return_value={"placeholder": "output"})
    mock_model_runner = mock.create_autospec(
        DummyModelRunner, instance=True
    )
    # Fix: create_autospec() has no `autospec` parameter; the previous
    # autospec=True was silently forwarded as mock configuration (setting a
    # meaningless `autospec` attribute on the mock), so it is removed.
    mock_model_runner_class = mock.create_autospec(DummyModelRunner)
    mock_model_runner_class.return_value = mock_model_runner
    executor = inline_prediction_executor.InlinePredictionExecutor(
        predictor, mock_model_runner_class
    )

    executor.start()
    self.assertEqual(
        executor.execute({"placeholder": "input"}),
        {"placeholder": "output"},
    )
    mock_model_runner_class.assert_called_once()
    predictor.assert_called_once_with(
        {"placeholder": "input"}, mock_model_runner
    )


if __name__ == "__main__":
  absltest.main()
"""A thin shell to fit a predictor function into the PredictionExecutor interface.

This is a convenience wrapper to allow a predictor function to be used directly
as a PredictionExecutor.

Intended usage in launching a server:
  predictor = Predictor()
  executor = InlinePredictionExecutor(predictor.predict)
  server_gunicorn.PredictionApplication(executor, ...).run()
"""

from collections.abc import Callable
from typing import Any

from typing_extensions import override

from serving.serving_framework import model_runner
from serving.serving_framework import server_gunicorn


class InlinePredictionExecutor(server_gunicorn.PredictionExecutor):
  """Adapts a plain predictor callable to the PredictionExecutor interface.

  Wraps a predictor function and a model-runner factory so the pair can be
  driven by the server worker process.

  For behavior beyond a simple function call, override start() and predict(),
  or inherit from PredictionExecutor directly.
  """

  def __init__(
      self,
      predictor: Callable[
          [dict[str, Any], model_runner.ModelRunner], dict[str, Any]
      ],
      model_runner_source: Callable[[], model_runner.ModelRunner]
  ):
    self._predictor = predictor
    self._model_runner_source = model_runner_source
    # Built lazily in start(), after the worker process has forked.
    self._model_runner = None

  @override
  def start(self) -> None:
    """Builds the model runner once the Gunicorn worker has started.

    Performs any setup which needs to be done post-fork.
    """
    # Instantiating the RPC stub pre-fork is unsafe; defer it to here.
    self._model_runner = self._model_runner_source()

  def predict(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Runs the predictor against the given request payload."""
    runner = self._model_runner
    if runner is None:
      raise RuntimeError(
          "Model runner is not initialized. Please call start() first."
      )
    return self._predictor(input_json, runner)

  @override
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Executes the given request payload.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      The json response to the prediction request.

    Raises:
      RuntimeError: Prediction failed in an unhandled way.
    """
    try:
      return self.predict(input_json)
    except Exception as err:
      raise RuntimeError("Unhandled exception from predictor.") from err
"""Mappings between errors in python and error_codes returned in API responses."""

from serving import pete_errors
from serving.data_models import embedding_response


# Maps each concrete PeteError subclass to the ErrorCode reported in API
# responses. Keyed by exact class; see get_error_code.
_ERROR_MAPPINGS = {
    pete_errors.InvalidRequestFieldError: (
        embedding_response.ErrorCode.INVALID_REQUEST_FIELD_ERROR
    ),
    pete_errors.InvalidResponseError: (
        embedding_response.ErrorCode.INVALID_RESPONSE_ERROR
    ),
    pete_errors.InstancesNotConcatenatedError: (
        embedding_response.ErrorCode.INSTANCES_NOT_CONCATENATED_ERROR
    ),
    pete_errors.InvalidCredentialsError: (
        embedding_response.ErrorCode.INVALID_CREDENTIALS_ERROR
    ),
    pete_errors.TooManyPatchesError: (
        embedding_response.ErrorCode.TOO_MANY_PATCHES_ERROR
    ),
    pete_errors.LevelNotFoundError: (
        embedding_response.ErrorCode.LEVEL_NOT_FOUND_ERROR
    ),
    pete_errors.EzWsiStateError: (
        embedding_response.ErrorCode.EZ_WSI_STATE_ERROR
    ),
    pete_errors.PatchOutsideOfImageDimensionsError: (
        embedding_response.ErrorCode.PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR
    ),
    pete_errors.ImageError: embedding_response.ErrorCode.IMAGE_ERROR,
    pete_errors.HttpError: embedding_response.ErrorCode.HTTP_ERROR,
    pete_errors.InvalidIccProfileTransformError: (
        embedding_response.ErrorCode.INVALID_ICC_PROFILE_TRANSFORM_ERROR
    ),
    pete_errors.ImageDimensionError: (
        embedding_response.ErrorCode.IMAGE_DIMENSION_ERROR
    ),
    pete_errors.DicomTiledFullError: (
        embedding_response.ErrorCode.DICOM_TILED_FULL_ERROR
    ),
    pete_errors.DicomError: embedding_response.ErrorCode.DICOM_ERROR,
    pete_errors.DicomImageDownsamplingTooLargeError: (
        embedding_response.ErrorCode.DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR
    ),
    pete_errors.DicomPathError: embedding_response.ErrorCode.DICOM_PATH_ERROR,
    pete_errors.GcsImagePathFormatError: (
        embedding_response.ErrorCode.GCS_IMAGE_PATH_FORMAT_ERROR
    ),
    pete_errors.UnapprovedDicomStoreError: (
        embedding_response.ErrorCode.UNAPPROVED_DICOM_STORE_ERROR
    ),
    pete_errors.UnapprovedGcsBucketError: (
        embedding_response.ErrorCode.UNAPPROVED_GCS_BUCKET_ERROR
    ),
    pete_errors.PatchDimensionsDoNotMatchEndpointInputDimensionsError: (
        embedding_response.ErrorCode.PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR
    ),
}


def get_error_code(
    error: pete_errors.PeteError,
) -> embedding_response.ErrorCode:
  """Maps PeteErrors to ERROR_CODES.

  Args:
    error: The PeteError instance to translate into an API error code.

  Returns:
    The ErrorCode corresponding to the error's concrete type.
  """
  # Lookup is by exact type: a PeteError subclass missing from
  # _ERROR_MAPPINGS (including subclasses of mapped classes) raises KeyError.
  # NOTE(review): keep _ERROR_MAPPINGS in sync whenever a new error class is
  # added to pete_errors.
  return _ERROR_MAPPINGS[type(error)]
@dataclasses.dataclass(frozen=True)
class EmbeddingInstanceV1:
  """A DICOM instance in a V1 Embedding Request as described in the schema file."""

  dicom_web_store_url: str
  dicom_study_uid: str
  dicom_series_uid: str
  bearer_token: str
  # Opaque EZ-WSI state, either serialized (str) or as a structured mapping.
  ez_wsi_state: Union[str, Mapping[str, Any]]
  instance_uids: List[str]
  patch_coordinates: List[patch_coordinate.PatchCoordinate]


@dataclasses.dataclass(frozen=True)
class DicomImageV2:
  """A DICOM image instance in a V2 Embedding Request as described in the schema file."""

  series_path: str
  bearer_token: str
  # Free-form extension parameters carried through from the request.
  extensions: Mapping[str, Any]
  instance_uids: List[str]
  patch_coordinates: List[patch_coordinate.PatchCoordinate]


@dataclasses.dataclass(frozen=True)
class GcsImageV2:
  """A GCS image instance in a V2 Embedding Request as described in the schema file."""

  image_file_uri: str
  bearer_token: str
  extensions: Mapping[str, Any]
  patch_coordinates: List[patch_coordinate.PatchCoordinate]


@dataclasses.dataclass(frozen=True)
class EmbeddedImageV2:
  """An inline image instance in a V2 Embedding Request as described in the schema file."""

  # NOTE(review): presumably base64-encoded image data given the str type —
  # confirm against the instance schema.
  image_bytes: str
  extensions: Mapping[str, Any]
  patch_coordinates: List[patch_coordinate.PatchCoordinate]


# Any of the supported V2 instance payload forms.
EmbeddingInstanceV2 = Union[DicomImageV2, GcsImageV2, EmbeddedImageV2]


@dataclasses.dataclass(frozen=True)
class EmbeddingParameters:
  """Prediction parameters in a V1 Embedding Request as described in the schema file."""

  model_size: str
  model_kind: str


@dataclasses.dataclass(frozen=True)
class EmbeddingRequestV1:
  """A DICOM Embedding Request is a single parameter and list of instances."""

  parameters: EmbeddingParameters
  instances: List[EmbeddingInstanceV1]


@dataclasses.dataclass(frozen=True)
class EmbeddingRequestV2:
  """A DICOM Embedding Request is a list of instances."""

  instances: List[EmbeddingInstanceV2]
"""Initializes Pete Cloud Logging."""

from typing import Any, Mapping, Optional
import uuid

import ez_wsi_dicomweb.ez_wsi_logging_factory

from serving import pete_flags
from serving.logging_lib import cloud_logging_client


def _set_log_signature() -> None:
  """Attaches a fresh trace id (and optional log name) to the log signature."""
  # A new UUID per call ties together all entries logged under this signature.
  log_signature = {'pathology_embedding_trace_id': str(uuid.uuid4())}
  endpoint_log_name = pete_flags.ENDPOINT_LOG_NAME_FLAG.value
  if endpoint_log_name:
    log_signature.update({'endpoint_log_name': endpoint_log_name})
  cloud_logging_client.set_log_signature(log_signature)
  cloud_logging_client.set_log_trace_key('pathology_embedding_trace_id')


def init_application_logging() -> None:
  """Initializes cloud logging for application startup."""
  _set_log_signature()


def init_embedding_request_logging() -> None:
  """Initializes cloud logging for an embedding request."""
  # Suppress the startup message for per-request (re)initialization.
  cloud_logging_client.do_not_log_startup_msg()
  _set_log_signature()


class _EZWSICloudLoggingInterface(
    ez_wsi_dicomweb.ez_wsi_logging_factory.AbstractLoggingInterface
):
  """EZ-WSI Cloud Logging Interface.

  Forwards EZ-WSI log calls to cloud_logging_client at the matching severity,
  appending this interface's signature to each entry.
  """

  def __init__(self, signature: Optional[Mapping[str, Any]]):
    # Structured signature appended to every forwarded log entry.
    self._signature = signature

  def debug(
      self,
      msg: str,
      *args: ez_wsi_dicomweb.ez_wsi_logging_factory.OptionalStructureElements,
  ) -> None:
    # stack_frames_back=1 attributes the entry to our caller, not this wrapper.
    cloud_logging_client.debug(msg, *args, self._signature, stack_frames_back=1)

  def info(
      self,
      msg: str,
      *args: ez_wsi_dicomweb.ez_wsi_logging_factory.OptionalStructureElements,
  ) -> None:
    cloud_logging_client.info(msg, *args, self._signature, stack_frames_back=1)

  def warning(
      self,
      msg: str,
      *args: ez_wsi_dicomweb.ez_wsi_logging_factory.OptionalStructureElements,
  ) -> None:
    cloud_logging_client.warning(
        msg, *args, self._signature, stack_frames_back=1
    )

  def error(
      self,
      msg: str,
      *args: ez_wsi_dicomweb.ez_wsi_logging_factory.OptionalStructureElements,
  ) -> None:
    cloud_logging_client.error(msg, *args, self._signature, stack_frames_back=1)

  def critical(
      self,
      msg: str,
      *args: ez_wsi_dicomweb.ez_wsi_logging_factory.OptionalStructureElements,
  ) -> None:
    cloud_logging_client.critical(
        msg, *args, self._signature, stack_frames_back=1
    )


class EZWSILoggingInterfaceFactory(
    ez_wsi_dicomweb.ez_wsi_logging_factory.AbstractLoggingInterfaceFactory
):
  """EZ-WSI Cloud Logging Interface Factory."""

  def __init__(self, signature: Mapping[str, Any]):
    self._signature = signature

  def create_logger(
      self, signature: Optional[Mapping[str, Any]] = None
  ) -> ez_wsi_dicomweb.ez_wsi_logging_factory.AbstractLoggingInterface:
    """Creates a logger combining the given signature with the factory's.

    The factory's signature keys take precedence over per-logger keys.
    """
    signature = {} if signature is None else dict(signature)
    signature.update(self._signature)
    return _EZWSICloudLoggingInterface(signature)
14 | This allows you to access the model for real-time predictions via the REST 15 | [Application Programming Interface (API)](https://developers.google.com/health-ai-developer-foundations/cxr-foundation/serving-api). 16 | 17 | * **Batch predictions**: Use the container to run large-scale 18 | [Vertex AI batch prediction jobs](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions) 19 | to process many inputs at once. 20 | 21 | Note: PETE is an acronym used throughout the code that stands for Pathology 22 | Encoder Tech Engine. 23 | 24 | ## Description of select files and folders 25 | 26 | * [`serving_framework/`](./serving_framework): A library for 27 | implementing Vertex AI-compatible HTTP servers. 28 | 29 | * [`vertex_schemata/`](./vertex_schemata): Folder containing YAML files that 30 | define the 31 | [PredictSchemata](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata) 32 | for Vertex AI endpoints. 33 | 34 | * [`abstract_pete_predictor.py`](./abstract_pete_predictor.py): Defines 35 | `AbstractPetePredictor`, an abstract base class that provides a blueprint 36 | for PETE predictor classes. Subclasses must implement the predict method to 37 | provide concrete prediction logic. 38 | 39 | * [`Dockerfile`](./Dockerfile): Defines the Docker image for serving the 40 | model. 41 | 42 | * [`entrypoint.sh`](./entrypoint.sh): A bash script that is used as the Docker 43 | entrypoint. It sets up the necessary environment variables, copies the 44 | TensorFlow [SavedModel(s)](https://www.tensorflow.org/guide/saved_model) 45 | locally and launches the TensorFlow server and the frontend HTTP server. 46 | 47 | * [`pete_error_mapping.py`](./pete_error_mapping.py): Defines mappings between 48 | errors in Python and error codes returned in API responses. 49 | 50 | * [`pete_errors.py`](./pete_errors.py): Defines error classes. It includes the 51 | base class `PeteError` and various specific error classes that inherit from 52 | it. 

* [`pete_flags.py`](./pete_flags.py): Defines flags, set via environment
    variables, that configure the container.

* [`pete_icc_profile_cache.py`](./pete_icc_profile_cache.py): Enables
    [ICC profile](https://en.wikipedia.org/wiki/ICC_profile) caching using
    [Redis](https://redis.io) or
    [Cloud Storage](https://cloud.google.com/storage).

* [`pete_prediction_executor.py`](./pete_prediction_executor.py): Defines the
    prediction executor for PETE. It includes the main function that runs the
    prediction executor loop.

* [`pete_predictor_v2.py`](./pete_predictor_v2.py): Defines the PETE predictor
    v2 class. It includes the predict function that runs inference on provided
    patches.

* [`requirements.txt`](./requirements.txt): Lists the required Python
    packages.

* [`server_gunicorn.py`](./server_gunicorn.py): Creates the HTTP server that
    launches the prediction executor.

## Dependencies

* [`data_processing/`](../data_processing): A library for data
    retrieval and processing.
14 | 15 | """Test utilities for pete.""" 16 | 17 | from __future__ import annotations 18 | 19 | import dataclasses 20 | import math 21 | import time 22 | from typing import Any, List, Optional 23 | 24 | 25 | @dataclasses.dataclass 26 | class _MockRedisData: 27 | value: bytes 28 | expire_time: float = -1.0 29 | 30 | 31 | class MockRedisClient: 32 | """Redis Client Mock.""" 33 | 34 | def __init__(self, host: str, port: int): 35 | self._host = host 36 | self._port = port 37 | self._mock_data_dict = {} 38 | self._pipeline_results = None 39 | 40 | def __enter__(self) -> MockRedisClient: 41 | return self 42 | 43 | def __exit__(self, exc_type, exc_val, exc_tb): 44 | return 45 | 46 | def __len__(self) -> int: 47 | return len(self._mock_data_dict) 48 | 49 | def clear(self) -> None: 50 | # not actual redis client method used to clear mock data from 51 | # mock redis dictionary. 52 | self._mock_data_dict = {} 53 | 54 | def pipeline(self) -> MockRedisClient: 55 | self._pipeline_results = [] 56 | return self 57 | 58 | def execute(self) -> List[Any]: 59 | if self._pipeline_results is None: 60 | raise ValueError('Pipeline not initialized.') 61 | return self._pipeline_results 62 | 63 | def _handle_pipeline_result(self, result: Any) -> Any: 64 | if self._pipeline_results is not None: 65 | self._pipeline_results.append(result) 66 | return result 67 | 68 | def incr(self, key: str) -> int: 69 | new_mk_data = _MockRedisData(int(0).to_bytes(1, 'little')) 70 | key_entry = self._mock_data_dict.get(key, new_mk_data) 71 | if key_entry.expire_time >= 0 and key_entry.expire_time <= time.time(): 72 | key_entry = new_mk_data 73 | new_value = int.from_bytes(key_entry.value, byteorder='little') + 1 74 | key_entry.value = new_value.to_bytes( 75 | int(math.ceil(math.log2(new_value+1) / 8.0)), 'little') 76 | self._mock_data_dict[key] = key_entry 77 | return self._handle_pipeline_result(new_value) 78 | 79 | def expire(self, key: str, seconds: int, nx: bool = False): 80 | try: 81 | # if nx is true 
only set expiration if expiration time is not set. 82 | if nx and self._mock_data_dict[key].expire_time != -1: 83 | return 84 | self._mock_data_dict[key].expire_time = time.time() + seconds 85 | except KeyError: 86 | pass 87 | 88 | def get(self, key: str) -> Optional[bytes]: 89 | key_entry = self._mock_data_dict.get(key) 90 | if key_entry is None: 91 | return self._handle_pipeline_result(None) 92 | if key_entry.expire_time >= 0 and key_entry.expire_time <= time.time(): 93 | return self._handle_pipeline_result(None) 94 | return self._handle_pipeline_result(key_entry.value) 95 | 96 | def set(self, key: str, value: bytes, nx: bool = False, ex: int = -1) -> bool: 97 | if nx: 98 | key_entry = self._mock_data_dict.get(key) 99 | if key_entry is not None: 100 | return self._handle_pipeline_result(False) 101 | self._mock_data_dict[key] = _MockRedisData( 102 | value, ex + time.time() if ex != -1 else -1 103 | ) 104 | return self._handle_pipeline_result(True) 105 | -------------------------------------------------------------------------------- /python/serving/test_utils/pete_mock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Mock Embedding Endpoint."""

from collections.abc import Set
import http
import io
import json
from typing import List, Mapping, Optional, Sequence

import numpy as np
import requests
import requests_mock

from serving.serving_framework import model_runner
from serving import abstract_pete_predictor


class _MockModel(model_runner.ModelRunner):
  """Stand-in model producing per-patch mean channel values."""

  def mock_model(self, data: np.ndarray) -> np.ndarray:
    # Embedding surrogate: mean over the two spatial axes of each patch.
    return np.mean(data, axis=(1, 2))

  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
  ) -> Mapping[str, np.ndarray]:
    raise NotImplementedError("Not implemented.")

  def run_model(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_key: str = "output_0",
  ) -> np.ndarray:
    if isinstance(model_input, np.ndarray):
      return self.mock_model(model_input)
    raise ValueError("Model input must be a numpy array.")

  def batch_model(
      self,
      model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray],
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_key: str = "output_0",
  ) -> List[np.ndarray]:
    if not isinstance(model_inputs[0], np.ndarray):
      raise ValueError("Model input must be a Sequence of numpy array.")
    return [self.mock_model(batch_item) for batch_item in model_inputs]


class EndpointMock:
  """Mocks the Pathology Embedding Endpoint."""

  def __init__(
      self,
      mock_request: requests_mock.Mocker,
      mock_endpoint_url: str,
      pete_endpoint: abstract_pete_predictor.AbstractPetePredictor,
  ):
    self._mock_endpoint_url = mock_endpoint_url
    # Register this instance as the URL matcher for the mocked session.
    mock_request.add_matcher(self._handle_request)
    self._pete_endpoint = pete_endpoint
    self._mock_model_runner = _MockModel()

  def _handle_request(
      self, request: requests.Request
  ) -> Optional[requests.Response]:
    """Handles a request for the mock embedding endpoint.

    Args:
      request: The request to handle.

    Returns:
      None if request not handled otherwise mock V1 embedding response.
      Mock embedding is mean channel value per patch.
    """
    if not request.url.startswith(self._mock_endpoint_url):
      return None
    prediction = self._pete_endpoint.predict(
        request.json(), self._mock_model_runner
    )
    response = requests.Response()
    response.status_code = http.HTTPStatus.OK
    response.raw = io.BytesIO(json.dumps(prediction).encode("utf-8"))
    return response
14 | 15 | """Pete flags.""" 16 | 17 | import json 18 | import os 19 | import sys 20 | from typing import List, Optional, Union 21 | 22 | from absl import flags 23 | 24 | 25 | def _load_multi_string(val: Optional[str]) -> Optional[Union[List[str], str]]: 26 | if val is None: 27 | return None 28 | try: 29 | return json.loads(val) 30 | except json.decoder.JSONDecodeError: 31 | return val 32 | 33 | 34 | ENDPOINT_LOG_NAME_FLAG = flags.DEFINE_string( 35 | 'endpoint_log_name', 36 | os.environ.get('ENDPOINT_LOG_NAME', ''), 37 | 'Optional name write in endpoint logs to easily identify endpoints.', 38 | ) 39 | 40 | # If true and Redis host is defined stores ICC Profile bytes in redis. 41 | ICC_PROFILE_CACHE_GCS_BUCKET_FLAG = flags.DEFINE_string( 42 | 'icc_profile_cache_gcs_bucket', 43 | os.environ.get('ICC_PROFILE_CACHE_GCS_BUCKET', ''), 44 | 'Name of gcs bucket to cache icc profile to.', 45 | ) 46 | 47 | ICC_PROFILE_CACHE_REDIS_IP_FLAG = flags.DEFINE_string( 48 | 'icc_profile_cache_redis_ip', 49 | os.environ.get('ICC_PROFILE_CACHE_REDIS_IP', ''), 50 | 'IP address of REDIS server to cache cache icc profile to.', 51 | ) 52 | 53 | ICC_PROFILE_CACHE_REDIS_PORT_FLAG = flags.DEFINE_integer( 54 | 'icc_profile_cache_redis_port', 55 | int(os.environ.get('ICC_PROFILE_CACHE_REDIS_PORT', '6379')), 56 | 'Port of REDIS server to cache cache icc profile to.', 57 | ) 58 | 59 | # If true and Redis host is defined stores ICC Profile bytes in redis. 60 | STORE_ICC_PROFILE_BYTES_IN_REDIS_FLAG = flags.DEFINE_bool( 61 | 'store_icc_profile_bytes_in_redis', 62 | bool(os.environ.get('STORE_ICC_PROFILE_BYTES_IN_REDIS', False)), 63 | 'bool cache icc profile bytes in redis', 64 | ) 65 | 66 | # If true and Redis host is defined stores ICC Profile bytes in redis. 
67 | IS_DEBUGGING_FLAG = flags.DEFINE_bool( 68 | 'is_debugging', 69 | bool( 70 | os.environ.get( 71 | 'IS_DEBUGGING', 72 | 'UNITTEST_ON_FORGE' in os.environ or 'unittest' in sys.modules, 73 | ) 74 | ), 75 | 'internal flag for unit tests detects if running in debugger.', 76 | ) 77 | 78 | APPROVED_GCS_SOURCE_LIST_FLAG = flags.DEFINE_multi_string( 79 | 'approved_gcs_source_list', 80 | _load_multi_string(os.environ.get('APPROVED_GCS_SOURCE_LIST', None)), 81 | 'List of GCS buckets endpoints can read from; all are allowed if' 82 | ' undefined.', 83 | ) 84 | 85 | 86 | APPROVED_DICOM_STORE_SOURCE_LIST_FLAG = flags.DEFINE_multi_string( 87 | 'approved_dicom_store_source_list', 88 | _load_multi_string( 89 | os.environ.get('APPROVED_DICOM_STORE_SOURCE_LIST', None) 90 | ), 91 | 'List of DICOM stores endpoint can read from; all are allowed if' 92 | ' undefined.', 93 | ) 94 | 95 | ENDPOINT_INPUT_WIDTH_FLAG = flags.DEFINE_integer( 96 | 'endpoint_input_width', 97 | int( 98 | os.environ.get('ENDPOINT_INPUT_WIDTH', 224) 99 | ), 100 | 'Width in pixels of input image to endpoint.', 101 | ) 102 | 103 | ENDPOINT_INPUT_HEIGHT_FLAG = flags.DEFINE_integer( 104 | 'endpoint_input_height', 105 | int( 106 | os.environ.get('ENDPOINT_INPUT_HEIGHT', 224) 107 | ), 108 | 'Height in pixels of input image to endpoint.', 109 | ) 110 | -------------------------------------------------------------------------------- /python/serving/server_gunicorn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Launcher for the pete_prediction_executor based encoder server.

Uses the serving framework to create a request server which
performs the logic for requests in separate processes and uses a local TFserving
instance to handle the model.
"""

from collections.abc import Sequence
import contextlib
import os
from typing import Any, Mapping, Dict

from absl import app
from absl import logging
from typing_extensions import override

from serving.serving_framework import inline_prediction_executor
from serving.serving_framework import model_runner
from serving.serving_framework import server_gunicorn
from serving.serving_framework.tensorflow import server_model_runner
from serving import pete_error_mapping
from serving import pete_errors
from serving import pete_logging
from serving import pete_predictor_v2
from serving.data_models import embedding_response
from serving.logging_lib import cloud_logging_client


class PredictionExecutor(inline_prediction_executor.InlinePredictionExecutor):
  """Provides prediction request execution using an inline function."""

  def __init__(self):
    self._pete_predictor = pete_predictor_v2.PetePredictor()
    self._exitstack = contextlib.ExitStack()
    super().__init__(self._run_request, server_model_runner.ServerModelRunner)

  def _run_request(
      self,
      request_json: Mapping[str, Any],
      model_runner: model_runner.ModelRunner,
  ) -> Dict[str, Any]:
    """Runs a single json request using provided components.

    Args:
      request_json: Parsed JSON request body.
      model_runner: Runner used to execute the embedding model.

    Returns:
      JSON-serializable response dict; known PeteErrors are converted to
      structured error responses.

    Raises:
      Exception: Unexpected errors are logged and re-raised.
    """
    pete_logging.init_embedding_request_logging()
    try:
      return dict(self._pete_predictor.predict(request_json, model_runner))
    except pete_errors.PeteError as err:
      return dict(
          embedding_response.prediction_error_response_v2(
              pete_error_mapping.get_error_code(err)
          )
      )
    except Exception as err:
      cloud_logging_client.error(
          'Unexpected exception raised while processing request.', err
      )
      raise

  @override
  def start(self):
    pete_logging.init_application_logging()
    # Enter the predictor's context for the lifetime of the server process.
    self._exitstack.enter_context(self._pete_predictor)
    super().start()


def main(argv: Sequence[str]) -> None:
  """Validates the environment and launches the gunicorn application."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if 'AIP_HTTP_PORT' not in os.environ:
    raise ValueError(
        'The environment variable AIP_HTTP_PORT needs to be specified.'
    )
  # Bug fix: MODEL_REST_PORT was previously unchecked; if unset,
  # int(os.environ.get(...)) raised an opaque TypeError instead of a
  # clear configuration error like the AIP_HTTP_PORT check above.
  if 'MODEL_REST_PORT' not in os.environ:
    raise ValueError(
        'The environment variable MODEL_REST_PORT needs to be specified.'
    )
  http_port = int(os.environ['AIP_HTTP_PORT'])
  options = {
      'bind': f'0.0.0.0:{http_port}',
      'workers': 3,
      'timeout': 120,
  }
  health_checker = server_gunicorn.ModelServerHealthCheck(
      health_check_port=int(os.environ['MODEL_REST_PORT']),
      model_name='default',
  )
  logging.info('Launching gunicorn application.')
  server_gunicorn.PredictionApplication(
      PredictionExecutor(),
      health_check=health_checker,
      options=options,
  ).run()


if __name__ == '__main__':
  app.run(main)
"""Executable to carry out pathology encoding request glue code.

A subprocess which handles a piped in pathology encoder endpoint request json
body and returns the response json body to stdout. Depends on a local TFserving
instance to provide the encoder model.
"""

from collections.abc import Sequence
import json
import sys
import time
from typing import Any, Mapping

from absl import app

from serving.serving_framework.tensorflow import server_model_runner
from serving import abstract_pete_predictor
from serving import pete_error_mapping
from serving import pete_errors
from serving import pete_logging
from serving import pete_predictor_v2
from serving.data_models import embedding_response
from serving.logging_lib import cloud_logging_client


def _run_request(
    request_str: str,
    predictor: abstract_pete_predictor.AbstractPetePredictor,
    model_runner: server_model_runner.ServerModelRunner,
) -> Mapping[str, Any]:
  """Runs a single json request using provided components.

  Args:
    request_str: Raw JSON request line read from stdin.
    predictor: Predictor implementing the embedding logic.
    model_runner: Runner forwarding model calls to local TFserving.

  Returns:
    JSON-serializable response mapping; PeteErrors become error responses.

  Raises:
    Exception: Unexpected errors are logged and re-raised.
  """
  try:
    try:
      request_json = json.loads(request_str)
    except json.JSONDecodeError as exp:
      cloud_logging_client.error(
          'Failed to parse request JSON.',
          exp,
      )
      raise pete_errors.InvalidRequestFieldError(
          'Failed to parse request json.'
      ) from exp
    return predictor.predict(request_json, model_runner)
  except pete_errors.PeteError as err:
    return embedding_response.prediction_error_response_v2(
        pete_error_mapping.get_error_code(err)
    )
  except Exception as err:
    cloud_logging_client.error(
        'Unexpected exception raised while processing request.', err
    )
    raise


def main(argv: Sequence[str]) -> None:
  """Runs the prediction executor loop over stdin/stdout."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  pete_logging.init_application_logging()

  try:
    with pete_predictor_v2.PetePredictor() as predictor:
      model_runner = server_model_runner.ServerModelRunner()

      cloud_logging_client.info('Starting pete prediction executor loop.')
      while True:
        pete_logging.init_embedding_request_logging()
        cloud_logging_client.debug('Waiting for request.')
        try:
          request_str = sys.stdin.readline()
        except EOFError:
          cloud_logging_client.debug('EOF on input, exiting.')
          return
        if not request_str:
          # Bug fix: readline() signals EOF by returning '' rather than
          # raising EOFError; without this check the loop would spin
          # forever converting empty reads into error responses.
          cloud_logging_client.debug('EOF on input, exiting.')
          return
        start_time = time.time()
        cloud_logging_client.debug('Received request.')
        result_json = _run_request(request_str, predictor, model_runner)
        cloud_logging_client.debug('Returning result from executor.')
        try:
          json.dump(result_json, sys.stdout)
          sys.stdout.write('\n')
          sys.stdout.flush()
        except BrokenPipeError:
          cloud_logging_client.debug('Pipe broken, exiting.')
          return
        elapsed = time.time() - start_time
        cloud_logging_client.info(f'Finished handling request ({elapsed} sec).')
  except Exception as exp:
    cloud_logging_client.error('Unhandled exception in executor.', exp)
    raise


if __name__ == '__main__':
  app.run(main)
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements ModelRunner by forwarding to a TFserving instance.

Relies on the model being served by a TFserving instance running on localhost
unless a stub configured otherwise is provided.
"""

from collections.abc import Mapping, Set

from absl import logging
import grpc
import numpy as np
from typing_extensions import override

# pylint: disable = g-direct-tensorflow-import
# Importing protos should be safe anyway?
from serving.serving_framework import model_runner
from tensorflow.python.framework import tensor_util
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


# Default TFserving gRPC address; entrypoint presumably starts TFserving on
# this port — TODO confirm against entrypoint.sh.
_HOSTPORT = "localhost:8500"


class ServerModelRunner(model_runner.ModelRunner):
  """ModelRunner implementation using grpc to TFserving."""

  def __init__(
      self,
      stub: prediction_service_pb2_grpc.PredictionServiceStub | None = None,
  ):
    """Initializes the instance, with a local connection by default.

    Args:
      stub: A stub to use for the connection. If not provided, a default
        connection to localhost is established. This argument is intended for
        testing.
    """
    if stub is not None:
      self._stub = stub
      return
    # Local TCP credentials: this channel only ever targets localhost.
    credentials = grpc.local_channel_credentials(
        grpc.LocalConnectionType.LOCAL_TCP
    )
    channel = grpc.secure_channel(_HOSTPORT, credentials)
    self._stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

  @override
  def run_model_multiple_output(
      self,
      model_input: Mapping[str, np.ndarray] | np.ndarray,
      *,
      model_name: str = "default",
      model_version: int | None = None,
      model_output_keys: Set[str],
  ) -> Mapping[str, np.ndarray]:
    """Runs a model on the given input and returns multiple outputs.

    Args:
      model_input: An array or map of arrays comprising the input tensors for
        the model. A bare array is keyed by "inputs".
      model_name: The name of the model to run.
      model_version: The version of the model to run. Uses default if None.
      model_output_keys: The desired model output keys.

    Returns:
      A mapping of model output keys to tensors.

    Raises:
      KeyError: If any of the model_output_keys are not found in the model
        output.
    """
    # If a bare np.ndarray was passed, it will be passed using the default
    # input key "inputs".
    # If a Mapping was passed, use the keys from that mapping.
    if isinstance(model_input, np.ndarray):
      logging.debug("Handling bare input tensor.")
      input_map = {"inputs": tensor_util.make_tensor_proto(model_input)}
    else:
      logging.debug("Handling input tensor map.")
      input_map = {
          k: tensor_util.make_tensor_proto(v) for k, v in model_input.items()
      }

    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    if model_version is not None:
      request.model_spec.version.value = model_version
    for key, data in input_map.items():
      request.inputs[key].CopyFrom(data)

    logging.debug("Calling PredictionService.Predict")
    # 60s deadline guards against a hung TFserving process — presumably
    # sufficient for the largest supported batch; confirm before changing.
    result = self._stub.Predict(request, timeout=60).outputs
    logging.debug("PredictionService.Predict returned.")
    # Check for expected keys in the result.
    result_keys = set(result.keys())
    missing_keys = model_output_keys - result_keys
    if missing_keys:
      raise KeyError(
          f"Model output keys {missing_keys} not found in model output. "
          f"Available keys: {result_keys}"
      )
    return {k: tensor_util.MakeNdarray(result[k]) for k in model_output_keys}
14 | 15 | title: PathologyEmbeddingResponse 16 | type: object 17 | description: The generated model results (i.e. computed vector representation of the input data). 18 | oneOf: 19 | - type: object 20 | required: [result] 21 | - type: object 22 | required: [error] 23 | properties: 24 | result: 25 | $ref: '#/components/result' 26 | error: 27 | $ref: '#/components/error' 28 | 29 | components: 30 | error: 31 | type: object 32 | description: The error response if an exception occurred while processing the request. 33 | required: 34 | - code 35 | properties: 36 | code: 37 | type: string 38 | description: > 39 | List of known error codes. 40 | enum: 41 | - TOO_MANY_PATCHES_ERROR 42 | - INVALID_CREDENTIALS_ERROR 43 | - PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR 44 | - INSTANCES_NOT_CONCATENATED_ERROR 45 | - INVALID_REQUEST_FIELD_ERROR 46 | - INVALID_RESPONSE_ERROR 47 | - LEVEL_NOT_FOUND_ERROR 48 | - EZ_WSIDE_STATE_ERROR 49 | - IMAGE_ERROR 50 | - HTTP_ERROR 51 | - INVALID_ICC_PROFILE_TRANSFORM_ERROR 52 | - IMAGE_DIMENSION_ERROR 53 | - DICOM_TILED_FULL_ERROR 54 | - DICOM_ERROR 55 | - DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR 56 | - PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR 57 | - DICOM_PATH_ERROR 58 | - GCS_IMAGE_PATH_FORMAT_ERROR 59 | - UNAPPROVED_DICOM_STORE_ERROR 60 | - UNAPPROVED_GCS_BUCKET_ERROR 61 | description: 62 | type: string 63 | description: A human-readable explanation of the error. 64 | maxLength: 1024 65 | 66 | result: 67 | type: object 68 | required: 69 | - patch_embeddings 70 | properties: 71 | patch_embeddings: 72 | type: array 73 | minItems: 1 74 | description: > 75 | The patch coordinates and embedding response. 
76 | items: 77 | $ref: '#/components/patch_embedding' 78 | 79 | patch_embedding: 80 | type: object 81 | required: 82 | - patch_coordinate 83 | - embedding_vector 84 | properties: 85 | patch_coordinate: 86 | $ref: '#/components/patch_coordinate' 87 | embedding_vector: 88 | type: array 89 | description: > 90 | The 384 or 768 dimension embedding result generated from the input image & patches. 91 | items: 92 | $ref: '#/components/embedding_value' 93 | 94 | embedding_value: 95 | type: number 96 | format: float 97 | description: Single embedding value. 98 | 99 | patch_coordinate: 100 | type: object 101 | required: 102 | - x_origin 103 | - y_origin 104 | properties: 105 | x_origin: 106 | type: integer 107 | format: int64 108 | minimum: 0 109 | description: > 110 | The upper leftmost x coordinate of the rectangular patch section. 111 | y_origin: 112 | type: integer 113 | format: int64 114 | minimum: 0 115 | description: > 116 | The upper leftmost y coordinate of the rectangular patch section. 117 | width: 118 | type: integer 119 | format: int64 120 | minimum: 1 121 | description: > 122 | The width of the rectangular patch section extending rightward from the x_origin. 123 | If the underlying model doesn't support custom patch sizes, this value is ignored. 124 | For more details on the default patch size and whether custom patch sizes are supported, 125 | please consult the API specification. 126 | height: 127 | type: integer 128 | format: int64 129 | minimum: 1 130 | description: > 131 | The height of the rectangular patch section. The rectangle extends down from the y_origin. 132 | If the underlying model doesn't support custom patch sizes, this value is ignored. 133 | For more details on the default patch size and whether custom patch sizes are supported, 134 | please consult the API specification. 
-------------------------------------------------------------------------------- /python/serving/vertex_schemata/instance.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | title: PathologyEmbeddingRequest 16 | type: object 17 | required: 18 | - patch_coordinates 19 | oneOf: 20 | - type: object 21 | required: [dicom_path] 22 | - type: object 23 | required: [image_file_uri] 24 | - type: object 25 | required: [raw_image_bytes] 26 | not: 27 | required: [bearer_token] 28 | properties: 29 | dicom_path: 30 | $ref: '#/components/dicom_path' 31 | image_file_uri: 32 | type: string 33 | pattern: ^gs://.+/[^/]+\.[^/]+$ 34 | description: > 35 | The path to an image file in a Google Cloud Storage bucket. Provide the URI in this format: 36 | gs://{BUCKET-NAME}/{OPTIONAL-FOLDER-HIERARCHY}/{FILE-NAME}.{FILE-TYPE} 37 | raw_image_bytes: 38 | type: string 39 | format: byte 40 | description: > 41 | Input data as a base64-encoded string. Refer to the API specification for details. 42 | bearer_token: 43 | type: string 44 | description: > 45 | The token to access the Cloud DICOM Store or Cloud Storage bucket where the images are stored. 46 | patch_coordinates: 47 | type: array 48 | minItems: 1 49 | description: An array of patch coordinates. 
50 | items: 51 | $ref: '#/components/patch_coordinate' 52 | extensions: 53 | type: object 54 | description: > 55 | An optional dictionary to enable flexible communication between the client and server. Refer 56 | to [extensions](../README.md#extensions) for the list of supported keys and their purposes. 57 | properties: 58 | key: 59 | type: string 60 | description: > 61 | A unique key to identify the extension. 62 | value: 63 | type: object 64 | description: > 65 | The value for the given extension as an embedded json. 66 | additionalProperties: true 67 | 68 | components: 69 | dicom_path: 70 | type: object 71 | required: 72 | - series_path 73 | - instance_uids 74 | properties: 75 | series_path: 76 | type: string 77 | pattern: ^https://.+/studies/[0-9\.]{1,64}/series/[0-9\.]{1,64}$ 78 | description: > 79 | The path to a DICOM Series in a DICOMWeb Store. Provide the URI in this format: 80 | https://{DICOMWEB-STORE-URI}/studies/{STUDY-UID}/series/{SERIES-UID} 81 | instance_uids: 82 | type: array 83 | minItems: 1 84 | description: > 85 | A list of unique identifiers for DICOM SOP Instances that contain the image pixels 86 | corresponding to the specified coordinates. All SOP Instances listed must have the same 87 | pixel spacing. 88 | items: 89 | type: string 90 | pattern: ^[0-9\.]{1,64}$ 91 | 92 | patch_coordinate: 93 | type: object 94 | required: 95 | - x_origin 96 | - y_origin 97 | properties: 98 | x_origin: 99 | type: integer 100 | format: int64 101 | minimum: 0 102 | description: > 103 | The upper leftmost x coordinate of the rectangular patch section. 104 | y_origin: 105 | type: integer 106 | format: int64 107 | minimum: 0 108 | description: > 109 | The upper leftmost y coordinate of the rectangular patch section. 110 | width: 111 | type: integer 112 | format: int64 113 | minimum: 1 114 | description: > 115 | The width of the rectangular patch section extending rightward from the x_origin. 
116 | If the underlying model doesn't support custom patch sizes, this value is ignored. 117 | For more details on the default patch size and whether custom patch sizes are supported, 118 | please consult the API specification. 119 | height: 120 | type: integer 121 | format: int64 122 | minimum: 1 123 | description: > 124 | The height of the rectangular patch section. The rectangle extends down from the y_origin. 125 | If the underlying model doesn't support custom patch sizes, this value is ignored. 126 | For more details on the default patch size and whether custom patch sizes are supported, 127 | please consult the API specification. 128 | -------------------------------------------------------------------------------- /python/serving/pete2_e2e_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """E2E tests for ez_wsi_dicomweb -> Pete and back.""" 16 | 17 | from collections.abc import Sequence 18 | import os 19 | import shutil 20 | 21 | from absl.testing import absltest 22 | from ez_wsi_dicomweb import credential_factory 23 | from ez_wsi_dicomweb import dicom_slide 24 | from ez_wsi_dicomweb import dicom_web_interface 25 | from ez_wsi_dicomweb import gcs_image 26 | from ez_wsi_dicomweb import local_image 27 | from ez_wsi_dicomweb import patch_embedding 28 | from ez_wsi_dicomweb import patch_embedding_endpoints 29 | from ez_wsi_dicomweb.ml_toolkit import dicom_path 30 | import numpy as np 31 | import pydicom 32 | import requests_mock 33 | 34 | from serving import pete_predictor_v2 35 | from serving.test_utils import pete_mock 36 | from serving.test_utils import test_files 37 | from ez_wsi_dicomweb.test_utils.dicom_store_mock import dicom_store_mock 38 | from ez_wsi_dicomweb.test_utils.gcs_mock import gcs_mock 39 | 40 | 41 | def _round(embeddings: Sequence[float], decimals: int = 3) -> Sequence[float]: 42 | return [round(e, decimals) for e in embeddings] 43 | 44 | 45 | class EzWsiPete2E2eTest(absltest.TestCase): 46 | 47 | @requests_mock.Mocker() 48 | def test_ez_wsi_dicom_embeddings(self, mock_request): 49 | instance = pydicom.dcmread( 50 | test_files.test_multi_frame_dicom_instance_path() 51 | ) 52 | series_path = dicom_path.FromString( 53 | f'{test_files.TEST_STORE_PATH}/dicomWeb/studies/{instance.StudyInstanceUID}/series/{instance.SeriesInstanceUID}' 54 | ) 55 | store_path = str(series_path.GetStorePath()) 56 | with dicom_store_mock.MockDicomStores( 57 | store_path, mock_request=mock_request 58 | ) as mock_store: 59 | mock_store[store_path].add_instance(instance) 60 | slide = dicom_slide.DicomSlide( 61 | dicom_web_interface.DicomWebInterface( 62 | credential_factory.DefaultCredentialFactory() 63 | ), 64 | series_path, 65 | ) 66 | endpoint = patch_embedding_endpoints.V2PatchEmbeddingEndpoint() 67 | with pete_predictor_v2.PetePredictor() as 
predictor: 68 | pete_mock.EndpointMock(mock_request, endpoint.end_point_url, predictor) 69 | patch = slide.get_patch(slide.native_level, 0, 0, 224, 224) 70 | embedding = patch_embedding.get_patch_embedding(endpoint, patch) 71 | self.assertEqual(_round(embedding.tolist(), 3), [0.775, 0.716, 0.827]) 72 | 73 | @requests_mock.Mocker() 74 | def test_ez_wsi_gcs_embeddings(self, mock_request): 75 | temp_dir = self.create_tempdir() 76 | shutil.copyfile( 77 | test_files.testdata_path('dcm_frame_1.jpg'), 78 | os.path.join(temp_dir, 'test_image.jpg'), 79 | ) 80 | with gcs_mock.GcsMock({'test_bucket': temp_dir}): 81 | image = gcs_image.GcsImage( 82 | 'gs://test_bucket/test_image.jpg', 83 | image_dimensions=gcs_image.ImageDimensions(224, 224), 84 | credential_factory=credential_factory.NoAuthCredentialsFactory(), 85 | ) 86 | endpoint = patch_embedding_endpoints.V2PatchEmbeddingEndpoint( 87 | credential_factory=credential_factory.NoAuthCredentialsFactory() 88 | ) 89 | with pete_predictor_v2.PetePredictor() as predictor: 90 | pete_mock.EndpointMock(mock_request, endpoint.end_point_url, predictor) 91 | patch = image.get_patch(0, 0, 224, 224) 92 | embedding = patch_embedding.get_patch_embedding(endpoint, patch) 93 | self.assertEqual(_round(embedding.tolist(), 3), [0.776, 0.713, 0.826]) 94 | 95 | @requests_mock.Mocker() 96 | def test_ez_wsi_local_image(self, mock_request): 97 | mem = np.zeros((224, 224), dtype=np.uint8) 98 | endpoint = patch_embedding_endpoints.V2PatchEmbeddingEndpoint( 99 | credential_factory=credential_factory.NoAuthCredentialsFactory() 100 | ) 101 | with pete_predictor_v2.PetePredictor() as predictor: 102 | pete_mock.EndpointMock(mock_request, endpoint.end_point_url, predictor) 103 | image = local_image.LocalImage(mem) 104 | patch = image.get_patch(0, 0, 224, 224) 105 | embedding = patch_embedding.get_patch_embedding(endpoint, patch) 106 | self.assertEqual(_round(embedding.tolist(), 3), [0.0, 0.0, 0.0]) 107 | 108 | 109 | if __name__ == '__main__': 110 | 
absltest.main() 111 | -------------------------------------------------------------------------------- /python/serving/data_models/embedding_request_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests fpr dicom embedding request.""" 16 | 17 | from absl import flags 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from serving.data_models import embedding_request 21 | from serving.data_models import patch_coordinate 22 | 23 | 24 | # Necessary to avoid flag parsing errors during unit tests. 
def setUpModule():
  # Parse absl flags once for the test process; absl raises
  # UnparsedFlagAccessError if a flag value is read before parsing.
  flags.FLAGS(['./program'])


class DicomEmbeddingRequestTest(parameterized.TestCase):
  """Verifies request dataclasses expose exactly the expected attributes.

  Each test compares a dataclass instance's __dict__ against a hand-built
  dict, so field names and values below must mirror the dataclass
  definitions in serving/data_models/embedding_request.py exactly.
  """

  def setUp(self):
    super().setUp()

    # Placeholder values only need to be structurally plausible; they are
    # never sent anywhere.
    self._dicom_embedding_parameters = embedding_request.EmbeddingParameters(
        model_size='SMALL', model_kind='LOW_PIXEL_SPACING'
    )
    self._instance_uids_1 = ['1.2.3.4', '2.3.4.5']
    self._patch_coordinate_1 = patch_coordinate.PatchCoordinate(
        x_origin=1,
        y_origin=1,
        width=224,
        height=224,
    )
    self._instance_uids_2 = ['3.4.5.6', '4.5.6.7']
    self._patch_coordinate_2 = patch_coordinate.PatchCoordinate(
        x_origin=2,
        y_origin=3,
        width=224,
        height=224,
    )

    self._dicom_embedding_instance = embedding_request.EmbeddingInstanceV1(
        dicom_web_store_url='potato',
        dicom_study_uid='10.11.12.13',
        dicom_series_uid='11.12.13.14',
        bearer_token='xx_pear_xx',
        ez_wsi_state={'hello': 'world'},
        instance_uids=self._instance_uids_1,
        patch_coordinates=[
            self._patch_coordinate_1,
            self._patch_coordinate_2,
        ],
    )

    self._dicom_embedding_instance_2 = embedding_request.EmbeddingInstanceV1(
        dicom_web_store_url='potato_2',
        dicom_study_uid='12.13.14.15',
        dicom_series_uid='13.14.15.16',
        bearer_token='xx_pineapple_xx',
        ez_wsi_state={'hello': 'goodbye'},
        instance_uids=self._instance_uids_2,
        patch_coordinates=[
            self._patch_coordinate_2,
            self._patch_coordinate_1,
        ],
    )

    # Expected __dict__ contents for the instances constructed above.
    self._dicom_embedding_parameters_dict = {
        'model_size': 'SMALL',
        'model_kind': 'LOW_PIXEL_SPACING',
    }
    self._dicom_embedding_instance_dict = {
        'dicom_web_store_url': 'potato',
        'dicom_study_uid': '10.11.12.13',
        'dicom_series_uid': '11.12.13.14',
        'bearer_token': 'xx_pear_xx',
        'ez_wsi_state': {'hello': 'world'},
        'instance_uids': ['1.2.3.4', '2.3.4.5'],
        'patch_coordinates': [
            self._patch_coordinate_1,
            self._patch_coordinate_2,
        ],
    }
    self._dicom_embedding_instance_2_dict = {
        'dicom_web_store_url': 'potato_2',
        'dicom_study_uid': '12.13.14.15',
        'dicom_series_uid': '13.14.15.16',
        'bearer_token': 'xx_pineapple_xx',
        'ez_wsi_state': {'hello': 'goodbye'},
        'instance_uids': ['3.4.5.6', '4.5.6.7'],
        'patch_coordinates': [
            self._patch_coordinate_2,
            self._patch_coordinate_1,
        ],
    }

    self._dicom_embedding_request_dict = {
        'parameters': self._dicom_embedding_parameters,
        'instances': [
            self._dicom_embedding_instance,
            self._dicom_embedding_instance_2,
        ],
    }

  def test_dicom_embedding_parameters(self):
    parameters = self._dicom_embedding_parameters

    self.assertEqual(parameters.__dict__, self._dicom_embedding_parameters_dict)

  def test_dicom_embedding_instance(self):
    parameters = self._dicom_embedding_instance

    self.assertEqual(parameters.__dict__, self._dicom_embedding_instance_dict)

  def test_dicom_embedding_request(self):
    parameters = embedding_request.EmbeddingRequestV1(
        parameters=self._dicom_embedding_parameters,
        instances=[
            self._dicom_embedding_instance,
            self._dicom_embedding_instance_2,
        ],
    )

    self.assertEqual(parameters.__dict__, self._dicom_embedding_request_dict)


if __name__ == '__main__':
  absltest.main()
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Abstract base class for dependency injection of model handling. 16 | 17 | Wraps execution of models on input tensors in an implementation-agnostic 18 | interface. Provides a mixin method for batching model execution. 19 | """ 20 | 21 | import abc 22 | from collections.abc import Mapping, Sequence, Set 23 | 24 | import numpy as np 25 | 26 | 27 | class ModelRunner(abc.ABC): 28 | """Runs a model with tensor inputs and outputs.""" 29 | 30 | @abc.abstractmethod 31 | def run_model_multiple_output( 32 | self, 33 | model_input: Mapping[str, np.ndarray] | np.ndarray, 34 | *, 35 | model_name: str = "default", 36 | model_version: int | None = None, 37 | model_output_keys: Set[str], 38 | ) -> Mapping[str, np.ndarray]: 39 | """Runs a model on the given input and returns multiple outputs. 40 | 41 | Args: 42 | model_input: An array or map of arrays comprising the input tensors for 43 | the model. A bare array is given a default input key. 44 | model_name: The name of the model to run. 45 | model_version: The version of the model to run. Uses default if None. 46 | model_output_keys: The desired model output keys. 47 | 48 | Returns: 49 | A mapping of model output keys to tensors. 50 | """ 51 | 52 | def run_model( 53 | self, 54 | model_input: Mapping[str, np.ndarray] | np.ndarray, 55 | *, 56 | model_name: str = "default", 57 | model_version: int | None = None, 58 | model_output_key: str = "output_0", 59 | ) -> np.ndarray: 60 | """Runs a model on the given input. 
61 | 62 | Args: 63 | model_input: An array or map of arrays comprising the input tensors for 64 | the model. A bare array is given a default input key. 65 | model_name: The name of the model to run. 66 | model_version: The version of the model to run. Uses default if None. 67 | model_output_key: The key to pull the output from. Defaults to "output_0". 68 | 69 | Returns: 70 | The single output tensor. 71 | """ 72 | return self.run_model_multiple_output( 73 | model_input, 74 | model_name=model_name, 75 | model_version=model_version, 76 | model_output_keys={model_output_key}, 77 | )[model_output_key] 78 | 79 | def batch_model( 80 | self, 81 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray], 82 | *, 83 | model_name: str = "default", 84 | model_version: int | None = None, 85 | model_output_key: str = "output_0", 86 | ) -> list[np.ndarray]: 87 | """Runs a model on each of the given inputs. 88 | 89 | Args: 90 | model_inputs: A sequence of arrays or maps of arrays comprising the input 91 | tensors for the model. Bare arrays are given a default input key. 92 | model_name: The name of the model to run. 93 | model_version: The version of the model to run. Uses default if None. 94 | model_output_key: The key to pull the output from. Defaults to "output_0". 95 | 96 | Returns: 97 | A list of the single output tensor from each input. 98 | """ 99 | return [ 100 | self.run_model( 101 | model_input, 102 | model_name=model_name, 103 | model_version=model_version, 104 | model_output_key=model_output_key, 105 | ) 106 | for model_input in model_inputs 107 | ] 108 | 109 | def batch_model_multiple_output( 110 | self, 111 | model_inputs: Sequence[Mapping[str, np.ndarray]] | Sequence[np.ndarray], 112 | *, 113 | model_name: str = "default", 114 | model_version: int | None = None, 115 | model_output_keys: Set[str], 116 | ) -> list[Mapping[str, np.ndarray]]: 117 | """Runs a model on the given inputs and returns multiple outputs. 
118 | 119 | Args: 120 | model_inputs: An array or map of arrays comprising the input tensors for 121 | the model. Bare arrays are given a default input key. 122 | model_name: The name of the model to run. 123 | model_version: The version of the model to run. Uses default if None. 124 | model_output_keys: The desired model output keys. 125 | 126 | Returns: 127 | A list containing the mapping of model output keys to tensors from each 128 | input. 129 | """ 130 | return [ 131 | self.run_model_multiple_output( 132 | model_input, 133 | model_name=model_name, 134 | model_version=model_version, 135 | model_output_keys=model_output_keys, 136 | ) 137 | for model_input in model_inputs 138 | ] 139 | -------------------------------------------------------------------------------- /python/serving/data_models/embedding_response.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Response dataclasses for Pete.""" 16 | 17 | import dataclasses 18 | import enum 19 | from typing import Any, List, Mapping, Optional, Sequence 20 | 21 | from ez_wsi_dicomweb import patch_embedding_endpoints 22 | 23 | from serving import pete_errors 24 | from serving.data_models import patch_coordinate 25 | 26 | _MAX_ERROR_DESCRIPTION_LENGTH = 1024 27 | 28 | 29 | class ErrorCode(enum.Enum): 30 | """The error codes for PeteErrorResponse mapped from PeteErrors.""" 31 | 32 | TOO_MANY_PATCHES_ERROR = 'TOO_MANY_PATCHES_ERROR' 33 | INVALID_CREDENTIALS_ERROR = ( 34 | patch_embedding_endpoints.EndpointJsonKeys.INVALID_CREDENTIALS 35 | ) 36 | PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR = ( 37 | 'PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR' 38 | ) 39 | INSTANCES_NOT_CONCATENATED_ERROR = 'INSTANCES_NOT_CONCATENATED_ERROR' 40 | INVALID_REQUEST_FIELD_ERROR = 'INVALID_REQUEST_FIELD_ERROR' 41 | INVALID_RESPONSE_ERROR = 'INVALID_RESPONSE_ERROR' 42 | LEVEL_NOT_FOUND_ERROR = 'LEVEL_NOT_FOUND_ERROR' 43 | EZ_WSI_STATE_ERROR = 'EZ_WSI_STATE_ERROR' 44 | IMAGE_ERROR = 'IMAGE_ERROR' 45 | HTTP_ERROR = 'HTTP_ERROR' 46 | INVALID_ICC_PROFILE_TRANSFORM_ERROR = 'INVALID_ICC_PROFILE_TRANSFORM_ERROR' 47 | IMAGE_DIMENSION_ERROR = 'IMAGE_DIMENSION_ERROR' 48 | DICOM_TILED_FULL_ERROR = 'DICOM_TILED_FULL_ERROR' 49 | DICOM_ERROR = 'DICOM_ERROR' 50 | DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR = ( 51 | 'DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR' 52 | ) 53 | PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR = ( 54 | 'PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR' 55 | ) 56 | DICOM_PATH_ERROR = 'DICOM_PATH_ERROR' 57 | GCS_IMAGE_PATH_FORMAT_ERROR = 'GCS_IMAGE_PATH_FORMAT_ERROR' 58 | UNAPPROVED_DICOM_STORE_ERROR = 'UNAPPROVED_DICOM_STORE_ERROR' 59 | UNAPPROVED_GCS_BUCKET_ERROR = 'UNAPPROVED_GCS_BUCKET_ERROR' 60 | 61 | 62 | @dataclasses.dataclass(frozen=True) 63 | class PeteErrorResponse: 64 | """The response when Pete is unable to successfully complete a request.""" 65 | 66 | 
error_code: ErrorCode 67 | 68 | 69 | @dataclasses.dataclass(frozen=True) 70 | class PatchEmbeddingV1: 71 | """A List of embeddings, instance uids, and patch coordinate.""" 72 | 73 | embeddings: List[float] 74 | patch_coordinate: patch_coordinate.PatchCoordinate 75 | 76 | 77 | @dataclasses.dataclass(frozen=True) 78 | class PatchEmbeddingV2: 79 | """A List of embeddings, instance uids, and patch coordinate.""" 80 | 81 | embedding_vector: List[float] 82 | patch_coordinate: patch_coordinate.PatchCoordinate 83 | 84 | 85 | @dataclasses.dataclass(frozen=True) 86 | class EmbeddingResultV1: 87 | """The response when Pete is able to successfully complete a request.""" 88 | 89 | dicom_study_uid: str 90 | dicom_series_uid: str 91 | instance_uids: List[str] 92 | patch_embeddings: List[PatchEmbeddingV1] 93 | 94 | 95 | @dataclasses.dataclass(frozen=True) 96 | class EmbeddingResponseV1: 97 | """An instance in a Embedding Response as described in the schema file.""" 98 | 99 | model_version: str 100 | error_response: Optional[PeteErrorResponse] 101 | embedding_result: List[EmbeddingResultV1] 102 | 103 | def __post_init__(self): 104 | if self.error_response is None and self.embedding_result is None: 105 | raise pete_errors.InvalidResponseError( 106 | 'At least one of error_response or embedding_result must be set.' 
107 | ) 108 | 109 | 110 | def embedding_instance_response_v2( 111 | results: Sequence[PatchEmbeddingV2], 112 | ) -> Mapping[str, Any]: 113 | """Returns a JSON-serializable embedding instance responses.""" 114 | return { 115 | patch_embedding_endpoints.EndpointJsonKeys.RESULT: { 116 | patch_embedding_endpoints.EndpointJsonKeys.PATCH_EMBEDDINGS: [ 117 | dataclasses.asdict(patch_embedding) for patch_embedding in results 118 | ] 119 | }, 120 | } 121 | 122 | 123 | def instance_error_response_v2( 124 | error_code: ErrorCode, description: str = '' 125 | ) -> Mapping[str, Any]: 126 | error = { 127 | patch_embedding_endpoints.EndpointJsonKeys.ERROR_CODE: error_code.value 128 | } 129 | if description: 130 | error[patch_embedding_endpoints.EndpointJsonKeys.ERROR_CODE_DESCRIPTION] = ( 131 | description[:_MAX_ERROR_DESCRIPTION_LENGTH] 132 | ) 133 | return { 134 | patch_embedding_endpoints.EndpointJsonKeys.ERROR: error, 135 | } 136 | 137 | 138 | def prediction_error_response_v2(error_code: ErrorCode) -> Mapping[str, Any]: 139 | return { 140 | patch_embedding_endpoints.EndpointJsonKeys.VERTEXAI_ERROR: ( 141 | error_code.value 142 | ) 143 | } 144 | -------------------------------------------------------------------------------- /notebooks/quick_start_with_hugging_face.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "QW75OIoNtN1C" 7 | }, 8 | "source": [ 9 | "~~~\n", 10 | "Copyright 2024 Google LLC\n", 11 | "\n", 12 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 13 | "you may not use this file except in compliance with the License.\n", 14 | "You may obtain a copy of the License at\n", 15 | "\n", 16 | " https://www.apache.org/licenses/LICENSE-2.0\n", 17 | "\n", 18 | "Unless required by applicable law or agreed to in writing, software\n", 19 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 20 | "WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 21 | "See the License for the specific language governing permissions and\n", 22 | "limitations under the License.\n", 23 | "~~~\n", 24 | "# Quick start with Hugging Face\n", 25 | "\n", 26 | "\u003ctable\u003e\u003ctbody\u003e\u003ctr\u003e\n", 27 | " \u003ctd style=\"text-align: center\"\u003e\n", 28 | " \u003ca href=\"https://colab.research.google.com/github/google-health/path-foundation/blob/master/notebooks/quick_start_with_hugging_face.ipynb\"\u003e\n", 29 | " \u003cimg alt=\"Google Colab logo\" src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" width=\"32px\"\u003e\u003cbr\u003e Run in Google Colab\n", 30 | " \u003c/a\u003e\n", 31 | " \u003c/td\u003e\n", 32 | " \u003ctd style=\"text-align: center\"\u003e\n", 33 | " \u003ca href=\"https://github.com/google-health/path-foundation/blob/master/notebooks/quick_start_with_hugging_face.ipynb\"\u003e\n", 34 | " \u003cimg alt=\"GitHub logo\" src=\"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\" width=\"32px\"\u003e\u003cbr\u003e View on GitHub\n", 35 | " \u003c/a\u003e\n", 36 | " \u003c/td\u003e\n", 37 | " \u003ctd style=\"text-align: center\"\u003e\n", 38 | " \u003ca href=\"https://huggingface.co/google/path-foundation\"\u003e\n", 39 | " \u003cimg alt=\"Hugging Face logo\" src=\"https://huggingface.co/front/assets/huggingface_logo-noborder.svg\" width=\"32px\"\u003e\u003cbr\u003e View on Hugging Face\n", 40 | " \u003c/a\u003e\n", 41 | " \u003c/td\u003e\n", 42 | "\u003c/tr\u003e\u003c/tbody\u003e\u003c/table\u003e\n", 43 | "\n", 44 | "This Colab notebook provides a basic demonstration of the Path Foundation encoder. Given a pathology image patch, this encoder generates a machine learning representation called an \"embedding\". Learn more about embeddings and their benefits on [this page](https://developers.google.com/health-ai-developer-foundations/path-foundation)." 
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "umPIpHAIqGtJ" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# @title Authenticate with Hugging Face\n", 56 | "from huggingface_hub import notebook_login\n", 57 | "from huggingface_hub.utils import HfFolder\n", 58 | "\n", 59 | "if HfFolder.get_token() is None:\n", 60 | " from huggingface_hub import notebook_login\n", 61 | " notebook_login()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "id": "VgMA-MlAEkup" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "# @title Load test image from Hugging Face\n", 73 | "\n", 74 | "from PIL import Image as PILImage\n", 75 | "from IPython.display import display\n", 76 | "from huggingface_hub import hf_hub_download\n", 77 | "\n", 78 | "# Download the test image from Hugging Face Hub\n", 79 | "hf_hub_download(repo_id=\"google/path-foundation\", filename='Test.png', local_dir='.')\n", 80 | "\n", 81 | "# Open the image, crop it, convert it to RGB format, and display it.\n", 82 | "img = PILImage.open(\"Test.png\").crop((0, 0, 224, 224)).convert('RGB')\n", 83 | "display(img)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "collapsed": true, 91 | "id": "48iQhBXdSsMG" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "# @title Compute Embeddings\n", 96 | "from huggingface_hub import from_pretrained_keras\n", 97 | "import matplotlib.pyplot as plt\n", 98 | "import tensorflow as tf\n", 99 | "import numpy as np\n", 100 | "\n", 101 | "# Convert the image to a Tensor and scale to [0, 1]\n", 102 | "tensor = tf.cast(tf.expand_dims(np.array(img), axis=0), tf.float32) / 255.0\n", 103 | "\n", 104 | "# Load the model directly from Hugging Face\n", 105 | "loaded_model = from_pretrained_keras(\"google/path-foundation\")\n", 106 | "\n", 107 | "# Call inference\n", 108 | "infer = loaded_model.signatures[\"serving_default\"]\n", 109 
| "embeddings = infer(tf.constant(tensor))\n", 110 | "\n", 111 | "# Extract the embedding vector\n", 112 | "embedding_vector = embeddings['output_0'].numpy().flatten()\n", 113 | "print(\"Size of embedding vector:\", len(embedding_vector))\n", 114 | "\n", 115 | "# Plot the embedding vector\n", 116 | "plt.figure(figsize=(12, 4))\n", 117 | "plt.plot(embedding_vector)\n", 118 | "plt.title('Embedding Vector')\n", 119 | "plt.xlabel('Index')\n", 120 | "plt.ylabel('Value')\n", 121 | "plt.grid(True)\n", 122 | "plt.show()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": { 128 | "id": "l1WK4BYyoyQj" 129 | }, 130 | "source": [ 131 | "# Next steps\n", 132 | "\n", 133 | " Explore the other [notebooks](https://github.com/google-health/path-foundation/blob/master/notebooks)." 134 | ] 135 | } 136 | ], 137 | "metadata": { 138 | "accelerator": "GPU", 139 | "colab": { 140 | "gpuType": "T4", 141 | "private_outputs": true 142 | }, 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "name": "python" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 0 153 | } 154 | -------------------------------------------------------------------------------- /python/serving/serving_framework/server_gunicorn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gunicorn application for passing requests through to the executor command.

Provides a thin, subject-agnostic request server for Vertex endpoints which
handles requests by piping their JSON bodies to the given executor command
and returning the json output.
"""

import abc
from collections.abc import Mapping, Sequence
import http
import json
import os
import subprocess
from typing import Any, Optional

from absl import logging
import flask
from gunicorn.app import base as gunicorn_base
import requests
from typing_extensions import override


class PredictionExecutor(abc.ABC):
  """Wraps arbitrary implementation of executing a prediction request."""

  @abc.abstractmethod
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Executes the given request payload."""

  def start(self) -> None:
    """Starts the executor.

    Called after the Gunicorn worker process has started. Performs any setup
    which needs to be done post-fork.
    """


class SubprocessPredictionExecutor(PredictionExecutor):
  """Provides prediction request execution using a persistent worker subprocess."""

  def __init__(self, executor_command: Sequence[str]):
    """Initializes the executor with a command to start the subprocess.

    Args:
      executor_command: argv-style command used to launch the worker process.
    """
    self._executor_command = executor_command
    # Created lazily by start(); None means the executor is not running.
    self._executor_process: Optional[subprocess.Popen] = None

  def _restart(self) -> None:
    """Replaces a dead or wedged worker subprocess with a fresh one.

    Raises:
      RuntimeError: start() has never been called.
    """
    if self._executor_process is None:
      raise RuntimeError("Executor process not started.")

    self._executor_process.terminate()
    # Reap the terminated child so it does not linger as a zombie; escalate
    # to SIGKILL if it ignores the polite termination request.
    try:
      self._executor_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
      self._executor_process.kill()
      self._executor_process.wait()
    self.start()

  @override
  def start(self):
    """Starts the executor process."""
    self._executor_process = subprocess.Popen(
        args=self._executor_command,
        stdout=subprocess.PIPE,
        stdin=subprocess.PIPE,
    )

  @override
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Uses the executor process to execute a request.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      The json response to the prediction request.

    Raises:
      RuntimeError: Executor is not started or error communicating with the
        subprocess.
    """
    if self._executor_process is None:
      raise RuntimeError("Executor process not started.")

    # The pipe protocol is one request/response per line, so the payload must
    # not contain embedded newlines. Compact json.dumps output contains none,
    # but strip defensively to keep the framing intact.
    input_str = json.dumps(input_json).replace("\n", "")

    try:
      self._executor_process.stdin.write(input_str.encode("utf-8") + b"\n")
      self._executor_process.stdin.flush()
    except BrokenPipeError as e:
      self._restart()
      raise RuntimeError("Executor process input stream closed.") from e
    exec_result = self._executor_process.stdout.readline()
    if not exec_result:
      # An empty read means the worker closed stdout (likely crashed);
      # restart so the next request gets a healthy process.
      self._restart()
      raise RuntimeError("Executor process output stream closed.")
    try:
      return json.loads(exec_result)
    except json.JSONDecodeError as e:
      raise RuntimeError("Executor process output not valid json.") from e


class ModelServerHealthCheck:
  """Checks the health of the local model server via REST request."""

  def __init__(self, health_check_port: int, model_name: str):
    """Builds the status URL for a local TF-Serving style model server."""
    self._health_check_url = (
        f"http://localhost:{health_check_port}/v1/models/{model_name}"
    )

  def check_health(self) -> bool:
    """Returns True iff the model server answers the status request with 200."""
    try:
      r = requests.get(self._health_check_url)
      return r.status_code == http.HTTPStatus.OK.value
    except requests.exceptions.ConnectionError:
      # Server not accepting connections yet (or crashed) -> unhealthy.
      return False


def _create_app(
    executor: PredictionExecutor,
    health_check: ModelServerHealthCheck | None,
) -> flask.Flask:
  """Creates a Flask app with the given executor.

  Args:
    executor: Executor used to serve prediction requests.
    health_check: Optional model-server health probe; when None the health
      route always reports healthy.

  Returns:
    The configured Flask application.

  Raises:
    ValueError: The Vertex AIP route environment variables are not set.
  """
  flask_app = flask.Flask(__name__)

  if (
      "AIP_HEALTH_ROUTE" not in os.environ
      or "AIP_PREDICT_ROUTE" not in os.environ
  ):
    raise ValueError(
        "Both of the environment variables AIP_HEALTH_ROUTE and "
        "AIP_PREDICT_ROUTE need to be specified."
    )

  def health_route() -> tuple[str, int]:
    logging.info("health route hit")
    if health_check is not None and not health_check.check_health():
      return "not ok", http.HTTPStatus.SERVICE_UNAVAILABLE.value
    return "ok", http.HTTPStatus.OK.value

  # Presence was validated above, so index directly instead of .get().
  health_path = os.environ["AIP_HEALTH_ROUTE"]
  logging.info("health path: %s", health_path)
  flask_app.add_url_rule(health_path, view_func=health_route)

  def predict() -> tuple[dict[str, Any], int]:
    logging.info("predict route hit")
    if flask.request.get_json(silent=True) is None:
      return {"error": "No JSON body."}, http.HTTPStatus.BAD_REQUEST.value

    logging.debug("Dispatching request to executor.")
    try:
      exec_result = executor.execute(flask.request.get_json())
      logging.debug("Executor returned results.")
      return (exec_result, http.HTTPStatus.OK.value)
    except RuntimeError:
      # Executor failures are internal; do not leak details to the caller.
      logging.exception("Internal error handling request: Executor failed.")
      return {
          "error": "Internal server error."
      }, http.HTTPStatus.INTERNAL_SERVER_ERROR.value

  predict_route = os.environ["AIP_PREDICT_ROUTE"]
  logging.info("predict route: %s", predict_route)
  flask_app.add_url_rule(predict_route, view_func=predict, methods=["POST"])

  flask_app.config["TRAP_BAD_REQUEST_ERRORS"] = True

  return flask_app


class PredictionApplication(gunicorn_base.BaseApplication):
  """Application to serve predictors on Vertex endpoints using gunicorn."""

  def __init__(
      self,
      executor: PredictionExecutor,
      *,
      health_check: ModelServerHealthCheck | None,
      options: Optional[Mapping[str, Any]] = None,
  ):
    """Builds the gunicorn application.

    Args:
      executor: Executor used to serve prediction requests.
      health_check: Optional model-server health probe for the health route.
      options: Optional gunicorn configuration overrides.
    """
    self.options = dict(options or {})
    # The executor subprocess must be started post-fork in each worker, so
    # the app cannot be preloaded in the master process.
    self.options["preload_app"] = False
    self._executor = executor
    self.application = _create_app(self._executor, health_check)
    super().__init__()

  def load_config(self):
    """Applies recognized, non-None options to the gunicorn config."""
    config = {
        key: value
        for key, value in self.options.items()
        if key in self.cfg.settings and value is not None
    }
    for key, value in config.items():
      self.cfg.set(key.lower(), value)

  def load(self) -> flask.Flask:
    """Starts the executor in the worker process and returns the app."""
    self._executor.start()
    return self.application
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util to get env value from secret manager."""

import json
import os
import re
import sys
import threading
from typing import Any, Mapping, TypeVar, Union

import cachetools
import google.api_core
from google.cloud import secretmanager

from serving.logging_lib.flags import flag_utils

_T = TypeVar('_T')

# Env var holding the resource name of the secret that stores the JSON
# env-override mapping.
_SECRET_MANAGER_ENV_CONFIG = 'SECRET_MANAGER_ENV_CONFIG'

# Parses project, secret name, and optional version out of a secret resource
# name, e.g. projects/<p>/secrets/<s>/versions/<v>.
_PARSE_SECRET_CONFIG = re.compile(
    r'.*?projects/([^/]+?)/secrets/([^/]+?)($|(/|(/versions/([^/]+))))',
    re.IGNORECASE,
)

# Enable secret manager if not debugging (i.e. not running under unit tests).
_ENABLE_ENV_SECRET_MANAGER = bool(
    'UNITTEST_ON_FORGE' not in os.environ and 'unittest' not in sys.modules
)


# Cache secret metadata to avoid repeated reads when initializing flags.
_cache = cachetools.LRUCache(maxsize=1)
_cache_lock = threading.Lock()


class SecretDecodeError(Exception):
  """Raised when a secret cannot be located, read, or parsed."""

  def __init__(self, msg: str, secret_name: str = '', data: str = ''):
    super().__init__(msg)
    self._secret_name = secret_name
    self._data = data


def _init_fork_module_state() -> None:
  """Re-initializes module globals; safe to run in a freshly forked child."""
  global _cache
  global _cache_lock
  _cache = cachetools.LRUCache(maxsize=1)
  _cache_lock = threading.Lock()


def _get_secret_version(
    client: secretmanager.SecretManagerServiceClient, parent: str
) -> str:
  """Returns the greatest version number for a secret.

  Args:
    client: SecretManagerServiceClient.
    parent: String defining project and name of secret to return version.

  Returns:
    greatest version number for secret.

  Raises:
    SecretDecodeError: Could not identify version for secret.
  """
  version_list = client.list_secret_versions(request={'parent': parent})
  versions_found = []
  for name in [ver.name for ver in version_list]:
    match = _PARSE_SECRET_CONFIG.fullmatch(name)
    if match is None:
      continue
    try:
      versions_found.append(int(match.groups()[-1]))
    except ValueError:
      # Non-numeric version component (e.g. "latest"); skip it.
      continue
  if not versions_found:
    raise SecretDecodeError(
        f'Could not find version for secret {parent}.', secret_name=parent
    )
  return str(max(versions_found))


def _read_secrets(secret_name: str) -> Mapping[str, Any]:
  """Returns secret from secret manager.

  Args:
    secret_name: Name of secret.

  Returns:
    Secret value.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
  """
  if not secret_name:
    return {}
  match = _PARSE_SECRET_CONFIG.fullmatch(secret_name)
  if match is None:
    raise SecretDecodeError(
        'incorrectly formatted secret; expecting'
        f' [projects/.+/secrets/.+/versions/.+; passed {secret_name}.',
        secret_name=secret_name,
    )
  project, secret, *_, version = match.groups()
  if not _ENABLE_ENV_SECRET_MANAGER:
    return {}
  with _cache_lock:
    cached_val = _cache.get(secret_name)
    if cached_val is not None:
      return cached_val
    with secretmanager.SecretManagerServiceClient() as client:
      parent = client.secret_path(project, secret)
      try:
        if version is None or not version:
          # No explicit version in the resource name; use the newest one.
          version = _get_secret_version(client, parent)
        # Named secret_response (not `secret`) to avoid shadowing the
        # secret-name capture group unpacked above.
        secret_response = client.access_secret_version(
            request={'name': f'{parent}/versions/{version}'}
        )
      except google.api_core.exceptions.NotFound as exp:
        raise SecretDecodeError(
            'Secret not found.', secret_name=secret_name
        ) from exp
      except google.api_core.exceptions.PermissionDenied as exp:
        raise SecretDecodeError(
            'Permission denied reading secret.', secret_name=secret_name
        ) from exp
    data = secret_response.payload.data
    if data is None or not data:
      return {}
    if isinstance(data, bytes):
      data = data.decode('utf-8')
    try:
      value = json.loads(data)
    except json.JSONDecodeError as exp:
      raise SecretDecodeError(
          'Could not decode secret value.', secret_name=secret_name, data=data
      ) from exp
    if not isinstance(value, Mapping):
      raise SecretDecodeError(
          'Secret value does not define a mapping.',
          secret_name=secret_name,
          data=data,
      )
    _cache[secret_name] = value
    return value


def get_secret_or_env(name: str, default: _T) -> Union[str, _T]:
  """Returns value defined in secret manager, env, or if undefined default.

  Searches first for variable definition in JSON dict stored within the GCP
  secret manager. The GCP secret containing the dict is defined by the
  _SECRET_MANAGER_ENV_CONFIG. If the variable is not found in the GCP encoded
  secret, or if the _SECRET_MANAGER_ENV_CONFIG is undefined then the container
  ENV are searched. If neither define the variable the default value is
  returned.

  Args:
    name: Name of ENV.
    default: Default value to return if name is not defined.

  Returns:
    ENV value.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
  """
  secret_name = os.environ.get(_SECRET_MANAGER_ENV_CONFIG)
  if secret_name is not None and secret_name:
    secret_env = _read_secrets(secret_name)
    result = secret_env.get(name)
    if result is not None:
      return str(result)
  return os.environ.get(name, default)


def get_bool_secret_or_env(
    env_name: str, undefined_value: bool = False
) -> bool:
  """Returns bool variable value into boolean value.

  Args:
    env_name: Environmental variable name.
    undefined_value: Default value to set undefined values to.

  Returns:
    Boolean of environmental variable string value.

  Raises:
    SecretDecodeError: Error retrieving value from secret manager.
    ValueError: Environmental variable cannot be parsed to bool.
  """
  # get_secret_or_env is given a str default, so it can never return None
  # here; the previous `if value is not None` guard was dead code and gave
  # this bool-annotated function an implicit-None fall-through path.
  value = get_secret_or_env(env_name, str(undefined_value))
  return flag_utils.str_to_bool(value)


# Interfaces may be used from processes which are forked (gunicorn,
# DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not
# copy threads running within parent processes or re-initialize global/module
# state.
This can result in forked modules being executed with invalid global 216 | # state, e.g., acquired locks that will not release or references to invalid 217 | # state. 218 | _init_fork_module_state() 219 | os.register_at_fork(after_in_child=_init_fork_module_state) 220 | -------------------------------------------------------------------------------- /python/serving/data_models/embedding_response_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for embedding response.""" 16 | 17 | from absl import flags 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from ez_wsi_dicomweb import patch_embedding_endpoints 21 | from serving import pete_errors 22 | from serving.data_models import embedding_response 23 | from serving.data_models import patch_coordinate 24 | 25 | _EndpointJsonKeys = patch_embedding_endpoints.EndpointJsonKeys 26 | 27 | 28 | # Necessary to avoid flag parsing errors during unit tests. 
def setUpModule():
  # Parse an empty command line so absl flags are initialized before any
  # test touches flag values.
  flags.FLAGS(['./program'])


class DicomEmbeddingResponseTest(parameterized.TestCase):
  """Tests V1/V2 embedding response models and their serializers."""

  def setUp(self):
    super().setUp()

    self._model_version = 'model_version'

    # Two distinct patch coordinates / instance-uid lists so result ordering
    # in the assertions below is meaningful.
    self._instance_uids = ['eggplant', 'basil']
    self._patch_coordinate = patch_coordinate.PatchCoordinate(
        x_origin=1,
        y_origin=3,
        width=224,
        height=224,
    )
    self._instance_uids_2 = ['tomato', 'tomatillo']
    self._patch_coordinate_2 = patch_coordinate.PatchCoordinate(
        x_origin=2,
        y_origin=3,
        width=224,
        height=224,
    )
    self._pete_error = embedding_response.PeteErrorResponse(
        error_code=embedding_response.ErrorCode.TOO_MANY_PATCHES_ERROR
    )
    # 384-element vectors with distinct offsets to tell the two items apart.
    self.dicom_embedding_result_item_1 = embedding_response.PatchEmbeddingV1(
        embeddings=[i + 1 for i in range(384)],
        patch_coordinate=self._patch_coordinate,
    )
    self.dicom_embedding_result_item_2 = embedding_response.PatchEmbeddingV1(
        embeddings=[i + 13 for i in range(384)],
        patch_coordinate=self._patch_coordinate_2,
    )
    self.dicom_embedding_embedding_result_1 = (
        embedding_response.EmbeddingResultV1(
            dicom_study_uid='potato',
            dicom_series_uid='tomato',
            instance_uids=self._instance_uids,
            patch_embeddings=[
                self.dicom_embedding_result_item_1,
                self.dicom_embedding_result_item_2,
            ],
        )
    )
    self.dicom_embedding_embedding_result_2 = (
        embedding_response.EmbeddingResultV1(
            dicom_study_uid='grape',
            dicom_series_uid='berry',
            instance_uids=self._instance_uids_2,
            patch_embeddings=[
                self.dicom_embedding_result_item_2,
                self.dicom_embedding_result_item_1,
            ],
        )
    )
    self.dicom_embedding_response = embedding_response.EmbeddingResponseV1(
        model_version=self._model_version,
        error_response=None,
        embedding_result=[
            self.dicom_embedding_embedding_result_1,
            self.dicom_embedding_embedding_result_2,
        ],
    )
    self.dicom_embedding_error_response = (
        embedding_response.EmbeddingResponseV1(
            model_version=self._model_version,
            error_response=self._pete_error,
            embedding_result=None,
        )
    )

    self._patch_coordinate_dict = {
        'x_origin': 1,
        'y_origin': 3,
        'width': 224,
        'height': 224,
    }

    # Expected attribute dicts mirrored from the constructions above.
    self._error_response_dict = {
        'model_version': self._model_version,
        'error_response': self._pete_error,
        'embedding_result': None,
    }
    self._response_dict = {
        'model_version': self._model_version,
        'error_response': None,
        'embedding_result': [
            self.dicom_embedding_embedding_result_1,
            self.dicom_embedding_embedding_result_2,
        ],
    }

  def test_pete_error_response(self):
    # The response object should expose exactly the constructor fields.
    parameters = self.dicom_embedding_error_response

    self.assertEqual(parameters.__dict__, self._error_response_dict)

  def test_dicom_embedding_result_item(self):
    parameters = self.dicom_embedding_response

    self.assertEqual(parameters.__dict__, self._response_dict)

  def test_dicom_embedding_response_fails(self):
    # A response must carry either results or an error; neither is invalid.
    with self.assertRaises(pete_errors.InvalidResponseError):
      embedding_response.EmbeddingResponseV1(
          model_version=self._model_version,
          error_response=None,
          embedding_result=None,
      )

  def test_embedding_instance_response_v2(self):
    # Duplicating the same embedding twice exercises list serialization.
    embedding = embedding_response.PatchEmbeddingV2(
        [1, 2, 3, 4], patch_coordinate.PatchCoordinate(0, 10, 224, 224)
    )
    self.assertEqual(
        embedding_response.embedding_instance_response_v2([embedding] * 2),
        {
            'result': {
                'patch_embeddings': [
                    {
                        'embedding_vector': [1, 2, 3, 4],
                        'patch_coordinate': {
                            'x_origin': 0,
                            'y_origin': 10,
                            'width': 224,
                            'height': 224,
                        },
                    },
                    {
                        'embedding_vector': [1, 2, 3, 4],
                        'patch_coordinate': {
                            'x_origin': 0,
                            'y_origin': 10,
                            'width': 224,
                            'height': 224,
                        },
                    },
                ]
            },
        },
    )

  @parameterized.named_parameters([
      dict(
          testcase_name='code_only',
          description='',
          expected={
              _EndpointJsonKeys.ERROR: {
                  _EndpointJsonKeys.ERROR_CODE: 'TOO_MANY_PATCHES_ERROR',
              },
          },
      ),
      dict(
          testcase_name='code_and_description',
          description='foo',
          expected={
              _EndpointJsonKeys.ERROR: {
                  _EndpointJsonKeys.ERROR_CODE: 'TOO_MANY_PATCHES_ERROR',
                  _EndpointJsonKeys.ERROR_CODE_DESCRIPTION: 'foo',
              },
          },
      ),
  ])
  def test_instance_error_response_v2(self, description, expected):
    # An empty description should be omitted from the serialized error.
    self.assertEqual(
        embedding_response.instance_error_response_v2(
            embedding_response.ErrorCode.TOO_MANY_PATCHES_ERROR,
            description=description,
        ),
        expected,
    )

  def test_prediction_error_response_v2(self):
    self.assertEqual(
        embedding_response.prediction_error_response_v2(
            embedding_response.ErrorCode.TOO_MANY_PATCHES_ERROR
        ),
        {_EndpointJsonKeys.VERTEXAI_ERROR: 'TOO_MANY_PATCHES_ERROR'},
    )


if __name__ == '__main__':
  absltest.main()
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the gunicorn prediction server and subprocess executor."""

import http
import io
import os
import subprocess
from unittest import mock

import requests
import requests_mock

from absl.testing import absltest
from serving.serving_framework import server_gunicorn


class DummyHealthCheck:
  """Minimal health-check stub returning a fixed result.

  NOTE(review): not referenced by the tests below — verify whether it is
  still needed or can be removed.
  """

  def __init__(self, check_result: bool):
    self._check_result = check_result

  def check_health(self):
    return self._check_result


class ServerGunicornTest(absltest.TestCase):
  """Tests PredictionApplication routing and SubprocessPredictionExecutor."""

  def setUp(self):
    super().setUp()
    # The Flask app factory requires both Vertex AIP routes to be set.
    os.environ["AIP_PREDICT_ROUTE"] = "/fake-predict-route"
    os.environ["AIP_HEALTH_ROUTE"] = "/fake-health-route"

  def test_application_option_default(self):
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(executor, health_check=None)

    self.assertEqual(app.cfg.workers, 1)

  def test_application_option_setting(self):
    options = {
        "workers": 3,
    }
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor, options=options, health_check=None
    )

    self.assertEqual(app.cfg.workers, 3)

  def test_health_route_no_check(self):
    # With no health check configured, the health route always reports ok.
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertEqual(response.text, "ok")

  @requests_mock.Mocker()
  def test_health_route_pass_check(self, mock_requests):
    mock_requests.register_uri(
        "GET",
        "http://localhost:12345/v1/models/default",
        text="assorted_metadata",
        status_code=http.HTTPStatus.OK,
    )

    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor,
        health_check=server_gunicorn.ModelServerHealthCheck(12345, "default"),
    ).load()
    service = app.test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertEqual(response.text, "ok")

  @requests_mock.Mocker()
  def test_health_route_fail_check(self, mock_requests):
    # A connection error from the model server surfaces as 503.
    mock_requests.register_uri(
        "GET",
        "http://localhost:12345/v1/models/default",
        exc=requests.exceptions.ConnectionError,
    )
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )

    app = server_gunicorn.PredictionApplication(
        executor,
        health_check=server_gunicorn.ModelServerHealthCheck(12345, "default"),
    ).load()
    service = app.test_client()

    response = service.get("/fake-health-route")

    self.assertEqual(response.status_code, http.HTTPStatus.SERVICE_UNAVAILABLE)
    self.assertEqual(response.text, "not ok")

  def test_predict_route_no_json(self):
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )
    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()

    # Non-JSON body must be rejected before reaching the executor.
    response = service.post("/fake-predict-route", data="invalid")

    executor.start.assert_called_once()
    executor.execute.assert_not_called()
    self.assertEqual(response.status_code, http.HTTPStatus.BAD_REQUEST)
    self.assertDictEqual({"error": "No JSON body."}, response.get_json())

  def test_predict_route(self):
    executor = mock.create_autospec(
        server_gunicorn.PredictionExecutor,
        instance=True,
    )
    app = server_gunicorn.PredictionApplication(
        executor, health_check=None
    ).load()
    service = app.test_client()
    executor.execute.return_value = {"placeholder": "output"}

    response = service.post(
        "/fake-predict-route", json={"meaningless": "filler"}
    )

    executor.start.assert_called_once()
    executor.execute.assert_called_once_with({"meaningless": "filler"})
    self.assertEqual(response.status_code, http.HTTPStatus.OK)
    self.assertDictEqual({"placeholder": "output"}, response.get_json())

  def test_subprocess_executor_execute(self):
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()
      mock_popen.assert_called_once_with(
          args=["fake_command"],
          stdout=subprocess.PIPE,
          stdin=subprocess.PIPE,
      )
      # Real BytesIO pipes let the executor round-trip the json payload.
      mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n')
      mock_process.stdin = io.BytesIO()

      response = executor.execute({"meaningless": "filler"})

      self.assertEqual(
          b'{"meaningless": "filler"}\n', mock_process.stdin.getvalue()
      )
      self.assertDictEqual({"placeholder": "output"}, response)

  def test_subprocess_executor_execute_error_output_closed(self):
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()

      mock_process.stdout = io.BytesIO()  # empty output simulates closed pipe.
      mock_process.stdin = io.BytesIO()

      with self.assertRaises(RuntimeError) as raised:
        executor.execute({"meaningless": "filler"})
      self.assertEqual(
          raised.exception.args[0], "Executor process output stream closed."
      )
      self.assertEqual(mock_popen.call_count, 2)  # executor restarted.

  def test_subprocess_executor_execute_error_input_broken(self):
    mock_process = mock.create_autospec(subprocess.Popen, instance=True)
    with mock.patch.object(
        subprocess, "Popen", autospec=True, return_value=mock_process
    ) as mock_popen:
      executor = server_gunicorn.SubprocessPredictionExecutor(["fake_command"])
      executor.start()

      mock_process.stdout = io.BytesIO(b'{"placeholder": "output"}\n')
      # Simulate broken pipe.
      mock_process.stdin = mock.create_autospec(io.BytesIO, instance=True)
      mock_process.stdin.write.side_effect = BrokenPipeError

      with self.assertRaises(RuntimeError):
        executor.execute({"meaningless": "filler"})
      self.assertEqual(mock_popen.call_count, 2)  # executor restarted.


if __name__ == "__main__":
  absltest.main()
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /python/serving/serving_framework/tensorflow/inline_model_runner_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from unittest import mock 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from serving.serving_framework.tensorflow import inline_model_runner 23 | 24 | 25 | def tensor_equal(first: tf.Tensor, second: tf.Tensor) -> bool: 26 | """Equality check for tensors. 27 | 28 | As implemented only confirms that the values are equal at all shared 29 | addresses. 30 | It is expected that some tensors of different shape may be misidentified as 31 | equal. 32 | 33 | Args: 34 | first: First tensor to compare. 35 | second: Second tensor to compare. 36 | 37 | Returns: 38 | True if the tensors are equal, False otherwise. 39 | """ 40 | return bool(tf.math.reduce_all(tf.equal(first, second))) 41 | 42 | 43 | class InlineModelRunnerTest(parameterized.TestCase): 44 | 45 | def setUp(self): 46 | super().setUp() 47 | 48 | self._model = mock.create_autospec(tf.train.Checkpoint, instance=True) 49 | self._runner = inline_model_runner.InlineModelRunner(model=self._model) 50 | 51 | def test_singleton_input(self): 52 | # Setup test values. 53 | input_np = np.array([[0.5] * 3] * 3, dtype=np.float32) 54 | input_tensor = tf.constant(input_np, shape=(3, 3), dtype=tf.float32) 55 | output_np = np.ones((3, 2), dtype=np.float32) 56 | output_tensor = tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 57 | # Setup mock model. 58 | model_signature_mock = mock.MagicMock() 59 | model_signature_mock.return_value = {"output_0": output_tensor} 60 | self._model.signatures = {"serving_default": model_signature_mock} 61 | 62 | result = self._runner.run_model(input_np) 63 | 64 | self.assertLen(model_signature_mock.call_args_list, 1) 65 | self.assertTrue( 66 | tensor_equal(model_signature_mock.call_args[0][0], input_tensor), 67 | "Tensor passed to model values do not match input.", 68 | ) 69 | # Missing a check for the shape of the input tensor. 
70 | np.testing.assert_array_equal(result, output_np) 71 | 72 | def test_map_input(self): 73 | # Setup test values. 74 | input_np_map = { 75 | "a": np.array([[0.5] * 3] * 3, dtype=np.float32), 76 | "b": np.array([[0.25] * 3] * 3, dtype=np.float32), 77 | } 78 | input_tensor_map = { 79 | label: tf.constant(input_np_map[label], shape=(3, 3), dtype=tf.float32) 80 | for label in input_np_map 81 | } 82 | output_np = np.ones((3, 2), dtype=np.float32) 83 | output_tensor = tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 84 | # Setup mock model. 85 | model_signature_mock = mock.MagicMock() 86 | model_signature_mock.return_value = {"output_0": output_tensor} 87 | self._model.signatures = {"serving_default": model_signature_mock} 88 | 89 | result = self._runner.run_model(input_np_map) 90 | 91 | self.assertLen(model_signature_mock.call_args_list, 1) 92 | self.assertSameElements( 93 | model_signature_mock.call_args[0][0].keys(), 94 | input_np_map.keys(), 95 | "Input tensor map keys don't match argument keys.", 96 | ) 97 | self.assertTrue( 98 | all([ 99 | tensor_equal( 100 | model_signature_mock.call_args[0][0][key], input_tensor_map[key] 101 | ) 102 | for key in input_tensor_map 103 | ]), 104 | "Tensor passed to model values do not match input.", 105 | ) 106 | # Missing a check for the shape of the input tensor. 107 | np.testing.assert_array_equal(result, output_np) 108 | 109 | def test_batch_singleton_input(self): 110 | # Setup test values. 111 | input_nps = [ 112 | np.array([[0.5] * 3] * 3, dtype=np.float32), 113 | np.array([[0.25] * 3] * 3, dtype=np.float32), 114 | ] 115 | input_tensors = [ 116 | tf.constant(input_np, shape=(3, 3), dtype=tf.float32) 117 | for input_np in input_nps 118 | ] 119 | output_nps = [ 120 | np.ones((3, 2), dtype=np.float32), 121 | np.zeros((3, 2), dtype=np.float32), 122 | ] 123 | output_tensors = [ 124 | tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 125 | for output_np in output_nps 126 | ] 127 | # Setup mock model. 
128 | model_signature_mock = mock.MagicMock() 129 | model_signature_mock.side_effect = [ 130 | {"output_0": output_tensor} for output_tensor in output_tensors 131 | ] 132 | self._model.signatures = {"serving_default": model_signature_mock} 133 | 134 | result = self._runner.batch_model(input_nps) 135 | 136 | self.assertLen(model_signature_mock.call_args_list, len(input_nps)) 137 | for input_tensor, call in zip( 138 | input_tensors, model_signature_mock.call_args_list 139 | ): 140 | self.assertTrue( 141 | tensor_equal(call[0][0], input_tensor), 142 | "Tensor passed to model values do not match input.", 143 | ) 144 | # Missing a check for the shape of the input tensor. 145 | for result_np, output_np in zip(result, output_nps): 146 | np.testing.assert_array_equal(result_np, output_np) 147 | 148 | def test_batch_map_input(self): 149 | # Setup test values. 150 | input_np_maps = [ 151 | { 152 | "a": np.array([[0.5] * 3] * 3, dtype=np.float32), 153 | "b": np.array([[0.25] * 3] * 3, dtype=np.float32), 154 | }, 155 | { 156 | "a": np.array([[0.25] * 3] * 3, dtype=np.float32), 157 | "b": np.array([[0.5] * 3] * 3, dtype=np.float32), 158 | }, 159 | ] 160 | input_tensor_maps = [] 161 | for input_np_map in input_np_maps: 162 | input_tensor_maps.append({ 163 | label: tf.constant( 164 | input_np_map[label], shape=(3, 3), dtype=tf.float32 165 | ) 166 | for label in input_np_map 167 | }) 168 | output_nps = [ 169 | np.ones((3, 2), dtype=np.float32), 170 | np.zeros((3, 2), dtype=np.float32), 171 | ] 172 | output_tensors = [ 173 | tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 174 | for output_np in output_nps 175 | ] 176 | # Setup mock model. 
177 | model_signature_mock = mock.MagicMock() 178 | model_signature_mock.side_effect = [ 179 | {"output_0": output_tensor} for output_tensor in output_tensors 180 | ] 181 | self._model.signatures = {"serving_default": model_signature_mock} 182 | 183 | result = self._runner.batch_model(input_np_maps) 184 | 185 | self.assertLen(model_signature_mock.call_args_list, len(input_np_maps)) 186 | for input_tensor_map, call in zip( 187 | input_tensor_maps, model_signature_mock.call_args_list 188 | ): 189 | self.assertTrue( 190 | all([ 191 | tensor_equal(call[0][0][key], input_tensor_map[key]) 192 | for key in input_tensor_map 193 | ]), 194 | "Tensor passed to model values do not match input.", 195 | ) 196 | # Missing a check for the shape of the input tensor. 197 | for result_np, output_np in zip(result, output_nps): 198 | np.testing.assert_array_equal(result_np, output_np) 199 | 200 | def test_keyed_output(self): 201 | # Setup test values. 202 | input_np = np.array([[0.5] * 3] * 3, dtype=np.float32) 203 | input_tensor = tf.constant(input_np, shape=(3, 3), dtype=tf.float32) 204 | output_np = np.ones((3, 2), dtype=np.float32) 205 | output_tensor = tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 206 | surplus_np = np.array([[0.1] * 3] * 2, dtype=np.float32) 207 | surplus_tensor = tf.constant(surplus_np, shape=(3, 2), dtype=tf.float32) 208 | # Setup mock model. 209 | model_signature_mock = mock.MagicMock() 210 | model_signature_mock.return_value = { 211 | "output_a": output_tensor, 212 | "output_b": surplus_tensor, 213 | } 214 | self._model.signatures = {"serving_default": model_signature_mock} 215 | 216 | result = self._runner.run_model(input_np, model_output_key="output_a") 217 | 218 | self.assertLen(model_signature_mock.call_args_list, 1) 219 | self.assertTrue( 220 | tensor_equal(model_signature_mock.call_args[0][0], input_tensor), 221 | "Tensor passed to model values do not match input.", 222 | ) 223 | # Missing a check for the shape of the input tensor. 
224 | np.testing.assert_array_equal(result, output_np) 225 | 226 | def test_multi_output(self): 227 | # Setup test values. 228 | input_np = np.array([[0.5] * 3] * 3, dtype=np.float32) 229 | input_tensor = tf.constant(input_np, shape=(3, 3), dtype=tf.float32) 230 | output_np = np.ones((3, 2), dtype=np.float32) 231 | output_tensor = tf.constant(output_np, shape=(3, 2), dtype=tf.float32) 232 | second_out_np = np.array([[0.7] * 3] * 3, dtype=np.float32) 233 | second_out_tensor = tf.constant( 234 | second_out_np, shape=(3, 3), dtype=tf.float32 235 | ) 236 | surplus_np = np.array([[0.1] * 3] * 2, dtype=np.float32) 237 | surplus_tensor = tf.constant(surplus_np, shape=(3, 2), dtype=tf.float32) 238 | # Setup mock model. 239 | model_signature_mock = mock.MagicMock() 240 | model_signature_mock.return_value = { 241 | "output_a": output_tensor, 242 | "output_b": surplus_tensor, 243 | "output_c": second_out_tensor, 244 | } 245 | self._model.signatures = {"serving_default": model_signature_mock} 246 | 247 | result = self._runner.run_model_multiple_output( 248 | input_np, model_output_keys={"output_a", "output_c"} 249 | ) 250 | 251 | self.assertLen(model_signature_mock.call_args_list, 1) 252 | self.assertTrue( 253 | tensor_equal(model_signature_mock.call_args[0][0], input_tensor), 254 | "Tensor passed to model values do not match input.", 255 | ) 256 | # Missing a check for the shape of the input tensor. 257 | self.assertEqual(result.keys(), {"output_a", "output_c"}) 258 | np.testing.assert_array_equal(result["output_a"], output_np) 259 | np.testing.assert_array_equal(result["output_c"], second_out_np) 260 | 261 | def test_batch_multi_output(self): 262 | # Setup test values. 
263 | input_nps = [ 264 | np.array([[0.5] * 3] * 3, dtype=np.float32), 265 | np.array([[0.25] * 3] * 3, dtype=np.float32), 266 | ] 267 | input_tensors = [ 268 | tf.constant(input_np, shape=(3, 3), dtype=tf.float32) 269 | for input_np in input_nps 270 | ] 271 | output_np_a_1 = np.array([[0.5] * 2] * 3, dtype=np.float32) 272 | output_np_a_2 = np.array([[0.25] * 2] * 3, dtype=np.float32) 273 | output_np_b_1 = np.array([[0.6] * 2] * 3, dtype=np.float32) 274 | output_np_b_2 = np.array([[0.2] * 2] * 3, dtype=np.float32) 275 | output_np_c_1 = np.array([[0.7] * 2] * 3, dtype=np.float32) 276 | output_np_c_2 = np.array([[0.8] * 2] * 3, dtype=np.float32) 277 | output_maps = [ 278 | { 279 | "output_a": tf.constant( 280 | output_np_a_1, shape=(3, 2), dtype=tf.float32 281 | ), 282 | "output_b": tf.constant( 283 | output_np_b_1, shape=(3, 2), dtype=tf.float32 284 | ), 285 | "output_c": tf.constant( 286 | output_np_c_1, shape=(3, 2), dtype=tf.float32 287 | ), 288 | }, 289 | { 290 | "output_a": tf.constant( 291 | output_np_a_2, shape=(3, 2), dtype=tf.float32 292 | ), 293 | "output_b": tf.constant( 294 | output_np_b_2, shape=(3, 2), dtype=tf.float32 295 | ), 296 | "output_c": tf.constant( 297 | output_np_c_2, shape=(3, 2), dtype=tf.float32 298 | ), 299 | }, 300 | ] 301 | # Setup mock model. 302 | model_signature_mock = mock.MagicMock() 303 | model_signature_mock.side_effect = output_maps 304 | self._model.signatures = {"serving_default": model_signature_mock} 305 | 306 | result = self._runner.batch_model_multiple_output( 307 | input_nps, model_output_keys={"output_a", "output_c"} 308 | ) 309 | 310 | self.assertLen(model_signature_mock.call_args_list, len(input_nps)) 311 | for input_tensor, call in zip( 312 | input_tensors, model_signature_mock.call_args_list 313 | ): 314 | self.assertTrue( 315 | tensor_equal(call[0][0], input_tensor), 316 | "Tensor passed to model values do not match input.", 317 | ) 318 | # Missing a check for the shape of the input tensor. 
319 | np.testing.assert_array_equal(result[0]["output_a"], output_np_a_1) 320 | np.testing.assert_array_equal(result[0]["output_c"], output_np_c_1) 321 | np.testing.assert_array_equal(result[1]["output_a"], output_np_a_2) 322 | np.testing.assert_array_equal(result[1]["output_c"], output_np_c_2) 323 | self.assertNotIn("output_b", result[0]) 324 | self.assertNotIn("output_b", result[1]) 325 | 326 | @parameterized.named_parameters( 327 | ("name_only", "alternate", None), 328 | ("name_and_version", "alternate", 1), 329 | ("version_only", "default", 1), 330 | ) 331 | def test_not_implemented_multiversion(self, name: str, version: int | None): 332 | with self.assertRaises(NotImplementedError): 333 | self._runner.run_model( 334 | model_input=np.array([[0.5] * 3] * 3, dtype=np.float32), 335 | model_name=name, 336 | model_version=version, 337 | ) 338 | 339 | 340 | if __name__ == "__main__": 341 | absltest.main() 342 | -------------------------------------------------------------------------------- /python/serving/logging_lib/cloud_logging_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Wrapper for cloud ops structured logging.""" 16 | from __future__ import annotations 17 | 18 | import os 19 | import sys 20 | import threading 21 | from typing import Any, Mapping, Optional, Union 22 | 23 | from absl import flags 24 | from absl import logging 25 | import google.auth 26 | import psutil 27 | 28 | from serving.logging_lib.flags import secret_flag_utils 29 | from serving.logging_lib import cloud_logging_client_instance 30 | 31 | # name of cloud ops log 32 | CLOUD_OPS_LOG_NAME_FLG = flags.DEFINE_string( 33 | 'ops_log_name', 34 | secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_NAME', 'python'), 35 | 'Cloud ops log name to write logs to.', 36 | ) 37 | CLOUD_OPS_LOG_PROJECT_FLG = flags.DEFINE_string( 38 | 'ops_log_project', 39 | secret_flag_utils.get_secret_or_env('CLOUD_OPS_LOG_PROJECT', None), 40 | 'GCP project name to write log to. Undefined = default', 41 | ) 42 | POD_HOSTNAME_FLG = flags.DEFINE_string( 43 | 'pod_hostname', 44 | secret_flag_utils.get_secret_or_env('HOSTNAME', None), 45 | 'Host name of GKE pod. Set by container ENV. ' 46 | 'Set to mock value in unit test.', 47 | ) 48 | POD_UID_FLG = flags.DEFINE_string( 49 | 'pod_uid', 50 | secret_flag_utils.get_secret_or_env('MY_POD_UID', None), 51 | 'UID of GKE pod. 
Do not set unless in test.', 52 | ) 53 | ENABLE_STRUCTURED_LOGGING_FLG = flags.DEFINE_boolean( 54 | 'enable_structured_logging', 55 | secret_flag_utils.get_bool_secret_or_env( 56 | 'ENABLE_STRUCTURED_LOGGING', 57 | not secret_flag_utils.get_bool_secret_or_env( 58 | 'DISABLE_STRUCTURED_LOGGING' 59 | ), 60 | ), 61 | 'Enable structured logging.', 62 | ) 63 | ENABLE_LOGGING_FLG = flags.DEFINE_boolean( 64 | 'enable_logging', 65 | secret_flag_utils.get_bool_secret_or_env('ENABLE_LOGGING_FLG', True), 66 | 'Enable logging.', 67 | ) 68 | 69 | _DEBUG_LOGGING_USE_ABSL_LOGGING_FLG = flags.DEFINE_boolean( 70 | 'debug_logging_use_absl_logging', 71 | # Confusing double negative is used to enable external env be a postive 72 | # statement and override the existing default 73 | not secret_flag_utils.get_bool_secret_or_env( 74 | 'ENABLE_CLOUD_LOGGING', 75 | not cloud_logging_client_instance.DEBUG_LOGGING_USE_ABSL_LOGGING, 76 | ), 77 | 'Debug/testing option to logs to absl.logger. Automatically set when ' 78 | 'running unit tests.', 79 | ) 80 | 81 | LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG = flags.DEFINE_boolean( 82 | 'log_all_python_logs_to_cloud', 83 | secret_flag_utils.get_bool_secret_or_env('LOG_ALL_PYTHON_LOGS_TO_CLOUD'), 84 | 'Logs every modules log to Cloud Ops.', 85 | ) 86 | 87 | PER_THREAD_LOG_SIGNATURES_FLG = flags.DEFINE_boolean( 88 | 'per_thread_log_signatures', 89 | secret_flag_utils.get_bool_secret_or_env('PER_THREAD_LOG_SIGNATURES', True), 90 | 'If True Log signatures are not shared are across threads if false ' 91 | 'Process threads share a common log signature', 92 | ) 93 | 94 | 95 | def _are_flags_initialized() -> bool: 96 | """Returns True if flags are initialized.""" 97 | try: 98 | return CLOUD_OPS_LOG_PROJECT_FLG.value is not None 99 | except (flags.UnparsedFlagAccessError, AttributeError): 100 | return False 101 | 102 | 103 | def _get_flags() -> Mapping[str, str]: 104 | load_flags = {} 105 | unparsed_flags = [] 106 | for flag_name in flags.FLAGS: 107 | try: 108 | 
load_flags[flag_name] = flags.FLAGS.__getattr__(flag_name) 109 | except flags.UnparsedFlagAccessError: 110 | unparsed_flags.append(flag_name) 111 | if unparsed_flags: 112 | load_flags['unparsed_flags'] = ', '.join(unparsed_flags) 113 | return load_flags 114 | 115 | 116 | def _default_gcp_project() -> str: 117 | try: 118 | _, project = google.auth.default( 119 | scopes=['https://www.googleapis.com/auth/cloud-platform'] 120 | ) 121 | return project if project is not None else '' 122 | except google.auth.exceptions.DefaultCredentialsError: 123 | return '' 124 | 125 | 126 | class CloudLoggingClient( 127 | cloud_logging_client_instance.CloudLoggingClientInstance 128 | ): 129 | """Wrapper for cloud ops structured logging. 130 | 131 | Automatically adds signature to structured logs to make traceable. 132 | """ 133 | 134 | # lock for log makes access to singleton 135 | # safe across threads. Logging used in main thread and ack_timeout_mon 136 | _singleton_instance: Optional[CloudLoggingClient] = None 137 | _startup_message_logged = False 138 | _singleton_lock = threading.RLock() 139 | 140 | @classmethod 141 | def _init_fork_module_state(cls) -> None: 142 | cls._singleton_instance = None 143 | cls._startup_message_logged = True 144 | cls._singleton_lock = threading.RLock() 145 | 146 | @classmethod 147 | def _fork_shutdown(cls) -> None: 148 | with cls._singleton_lock: 149 | cls._singleton_instance = None 150 | 151 | @classmethod 152 | def _set_absl_skip_frames(cls) -> None: 153 | """Sets absl logging attribution to skip over internal logging frames.""" 154 | logging.ABSLLogger.register_frame_to_skip( 155 | __file__, 156 | function_name='debug', 157 | ) 158 | logging.ABSLLogger.register_frame_to_skip( 159 | __file__, 160 | function_name='timed_debug', 161 | ) 162 | logging.ABSLLogger.register_frame_to_skip( 163 | __file__, 164 | function_name='info', 165 | ) 166 | logging.ABSLLogger.register_frame_to_skip( 167 | __file__, 168 | function_name='warning', 169 | ) 170 | 
logging.ABSLLogger.register_frame_to_skip( 171 | __file__, 172 | function_name='error', 173 | ) 174 | logging.ABSLLogger.register_frame_to_skip( 175 | __file__, 176 | function_name='critical', 177 | ) 178 | 179 | def __init__(self): 180 | with CloudLoggingClient._singleton_lock: 181 | if not _are_flags_initialized(): 182 | # if flags are not initialize then init logging flags 183 | flags.FLAGS(sys.argv, known_only=True) 184 | if CloudLoggingClient._singleton_instance is not None: 185 | raise cloud_logging_client_instance.CloudLoggerInstanceExceptionError( 186 | 'Singleton already initialized.' 187 | ) 188 | CloudLoggingClient._set_absl_skip_frames() 189 | gcp_project = ( 190 | _default_gcp_project() 191 | if CLOUD_OPS_LOG_PROJECT_FLG.value is None 192 | else CLOUD_OPS_LOG_PROJECT_FLG.value 193 | ) 194 | pod_host_name = ( 195 | '' if POD_HOSTNAME_FLG.value is None else POD_HOSTNAME_FLG.value 196 | ) 197 | pod_uid = '' if POD_UID_FLG.value is None else POD_UID_FLG.value 198 | super().__init__( 199 | log_name=CLOUD_OPS_LOG_NAME_FLG.value, 200 | gcp_project_to_write_logs_to=gcp_project, 201 | gcp_credentials=None, 202 | pod_hostname=pod_host_name, 203 | pod_uid=pod_uid, 204 | enable_structured_logging=ENABLE_STRUCTURED_LOGGING_FLG.value, 205 | use_absl_logging=_DEBUG_LOGGING_USE_ABSL_LOGGING_FLG.value, 206 | log_all_python_logs_to_cloud=LOG_ALL_PYTHON_LOGS_TO_CLOUD_FLG.value, 207 | per_thread_log_signatures=PER_THREAD_LOG_SIGNATURES_FLG.value, 208 | enabled=ENABLE_LOGGING_FLG.value, 209 | ) 210 | CloudLoggingClient._singleton_instance = self 211 | 212 | def startup_msg(self) -> None: 213 | """Logs default messages after logger fully initialized.""" 214 | if self.use_absl_logging() or CloudLoggingClient._startup_message_logged: 215 | return 216 | CloudLoggingClient._startup_message_logged = True 217 | pid = os.getpid() 218 | process_name = psutil.Process(pid).name() 219 | self.debug( 220 | 'Container process started.', 221 | {'process_name': process_name, 'process_id': 
pid}, 222 | ) 223 | self.debug( 224 | 'Container environmental variables.', os.environ 225 | ) # pytype: disable=wrong-arg-types # kwargs-checking 226 | vm = psutil.virtual_memory() 227 | self.debug( 228 | 'Compute instance', 229 | { 230 | 'processors(count)': os.cpu_count(), 231 | 'total_system_mem_(bytes)': vm.total, 232 | 'available_system_mem_(bytes)': vm.available, 233 | }, 234 | ) 235 | self.debug('Initalized flags', _get_flags()) 236 | project_name = self.gcp_project_name if self.gcp_project_name else 'DEFAULT' 237 | self.debug(f'Logging to GCP project: {project_name}') 238 | 239 | @classmethod 240 | def logger(cls, show_startup_msg: bool = True) -> CloudLoggingClient: 241 | if cls._singleton_instance is None: 242 | with cls._singleton_lock: # makes instance creation thread safe. 243 | if cls._singleton_instance is None: 244 | cls._singleton_instance = CloudLoggingClient() 245 | if not show_startup_msg: 246 | cls._startup_message_logged = True 247 | else: 248 | cls._singleton_instance.startup_msg() # pytype: disable=attribute-error 249 | return cls._singleton_instance # pytype: disable=bad-return-type 250 | 251 | 252 | def logger() -> CloudLoggingClient: 253 | return CloudLoggingClient.logger() 254 | 255 | 256 | def do_not_log_startup_msg() -> None: 257 | CloudLoggingClient.logger(show_startup_msg=False) 258 | 259 | 260 | def debug( 261 | msg: str, 262 | *struct: Union[Mapping[str, Any], Exception, None], 263 | stack_frames_back: int = 0, 264 | ) -> None: 265 | """Logs with debug severity. 266 | 267 | Args: 268 | msg: message to log (string). 269 | *struct: zero or more dict or exception to log in structured log. 270 | stack_frames_back: Additional stack frames back to log source_location. 
271 | """ 272 | logger().debug(msg, *struct, stack_frames_back=stack_frames_back + 1) 273 | 274 | 275 | def timed_debug( 276 | msg: str, 277 | *struct: Union[Mapping[str, Any], Exception, None], 278 | stack_frames_back: int = 0, 279 | ) -> None: 280 | """Logs with debug severity and elapsed time since last timed debug log. 281 | 282 | Args: 283 | msg: message to log (string). 284 | *struct: zero or more dict or exception to log in structured log. 285 | stack_frames_back: Additional stack frames back to log source_location. 286 | """ 287 | logger().timed_debug(msg, *struct, stack_frames_back=stack_frames_back + 1) 288 | 289 | 290 | def info( 291 | msg: str, 292 | *struct: Union[Mapping[str, Any], Exception, None], 293 | stack_frames_back: int = 0, 294 | ) -> None: 295 | """Logs with info severity. 296 | 297 | Args: 298 | msg: message to log (string). 299 | *struct: zero or more dict or exception to log in structured log. 300 | stack_frames_back: Additional stack frames back to log source_location. 301 | """ 302 | logger().info(msg, *struct, stack_frames_back=stack_frames_back + 1) 303 | 304 | 305 | def warning( 306 | msg: str, 307 | *struct: Union[Mapping[str, Any], Exception, None], 308 | stack_frames_back: int = 0, 309 | ) -> None: 310 | """Logs with warning severity. 311 | 312 | Args: 313 | msg: Message to log (string). 314 | *struct: Zero or more dict or exception to log in structured log. 315 | stack_frames_back: Additional stack frames back to log source_location. 316 | """ 317 | logger().warning(msg, *struct, stack_frames_back=stack_frames_back + 1) 318 | 319 | 320 | def error( 321 | msg: str, 322 | *struct: Union[Mapping[str, Any], Exception, None], 323 | stack_frames_back: int = 0, 324 | ) -> None: 325 | """Logs with error severity. 326 | 327 | Args: 328 | msg: Message to log (string). 329 | *struct: Zero or more dict or exception to log in structured log. 330 | stack_frames_back: Additional stack frames back to log source_location. 
331 | """ 332 | logger().error(msg, *struct, stack_frames_back=stack_frames_back + 1) 333 | 334 | 335 | def critical( 336 | msg: str, 337 | *struct: Union[Mapping[str, Any], Exception, None], 338 | stack_frames_back: int = 0, 339 | ) -> None: 340 | """Logs with critical severity. 341 | 342 | Args: 343 | msg: Message to log (string). 344 | *struct: Zero or more dict or exception to log in structured log. 345 | stack_frames_back: Additional stack frames back to log source_location. 346 | """ 347 | logger().critical(msg, *struct, stack_frames_back=stack_frames_back + 1) 348 | 349 | 350 | def clear_log_signature() -> None: 351 | logger().clear_log_signature() 352 | 353 | 354 | def get_log_signature() -> Mapping[str, Any]: 355 | return logger().log_signature 356 | 357 | 358 | def set_log_signature(sig: Mapping[str, Any]) -> None: 359 | logger().log_signature = sig 360 | 361 | 362 | def set_per_thread_log_signatures(val: bool) -> None: 363 | logger().per_thread_log_signatures = val 364 | 365 | 366 | def get_build_version(clip_length: Optional[int] = None) -> str: 367 | if clip_length is not None and clip_length >= 0: 368 | return logger().build_version[:clip_length] 369 | return logger().build_version 370 | 371 | 372 | def set_build_version(build_version: str) -> None: 373 | logger().build_version = build_version 374 | 375 | 376 | def set_log_trace_key(key: str) -> None: 377 | logger().trace_key = key 378 | 379 | 380 | # Logging interfaces are used from processes which are forked (gunicorn, 381 | # DICOM Proxy, Orchestrator, Refresher). In Python, forked processes do not 382 | # copy threads running within parent processes or re-initalize global/module 383 | # state. This can result in forked modules being executed with invalid global 384 | # state, e.g., acquired locks that will not release or references to invalid 385 | # state. The cloud logging library utilizes a background thread transporting 386 | # logs to cloud. 
The background threading is not compatiable with forking and 387 | # will seg-fault (python queue wait). This can be avoided, by stoping and 388 | # the background transport prior to forking and then restarting the transport 389 | # following the fork. 390 | os.register_at_fork( 391 | before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access 392 | after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access 393 | ) 394 | -------------------------------------------------------------------------------- /python/serving/serving_framework/tensorflow/server_model_runner_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

from unittest import mock

import numpy as np
import tensorflow as tf

from absl.testing import absltest
from serving.serving_framework.tensorflow import server_model_runner
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


class ServerModelRunnerTest(absltest.TestCase):
  """Unit tests for server_model_runner.ServerModelRunner.

  Each test builds the PredictRequest proto the runner is expected to send,
  injects a canned PredictResponse through a mocked PredictionService stub,
  and asserts on the deterministic serialization of the request actually sent.
  """

  def setUp(self):
    super().setUp()

    # Autospec'd stub lets tests both feed canned responses and inspect the
    # exact PredictRequest protos the runner passes to Predict.
    self._stub = mock.create_autospec(
        prediction_service_pb2_grpc.PredictionServiceStub, instance=True
    )
    self._stub.Predict = mock.MagicMock()
    self._runner = server_model_runner.ServerModelRunner(stub=self._stub)

  def test_singleton_input(self):
    # A bare ndarray is expected to be sent under the "inputs" key with the
    # "default" model name.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "default"
    output_np = np.ones((3, 2), dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.model_spec.name = "default"
    self._stub.Predict.return_value = output_proto

    result = self._runner.run_model(input_np)

    self.assertLen(self._stub.Predict.call_args_list, 1)
    self.assertEqual(
        self._stub.Predict.call_args[0][0].SerializeToString(
            deterministic=True
        ),
        input_proto.SerializeToString(deterministic=True),
        "Proto passed to model does not match expectation.",
    )
    np.testing.assert_array_equal(result, output_np)

  def test_map_input(self):
    # A dict input is expected to be sent with one request tensor per key.
    input_np_map = {
        "a": np.array([[0.5] * 3] * 3, dtype=np.float32),
        "b": np.array([[0.25] * 3] * 3, dtype=np.float32),
    }
    input_proto = predict_pb2.PredictRequest()
    for label in input_np_map:
      input_proto.inputs[label].CopyFrom(
          tf.make_tensor_proto(input_np_map[label])
      )
    input_proto.model_spec.name = "default"
    output_np = np.ones((3, 2), dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.model_spec.name = "default"
    self._stub.Predict.return_value = output_proto

    result = self._runner.run_model(input_np_map)

    self.assertLen(self._stub.Predict.call_args_list, 1)
    self.assertEqual(
        self._stub.Predict.call_args[0][0].SerializeToString(
            deterministic=True
        ),
        input_proto.SerializeToString(deterministic=True),
        "Tensor passed to model values do not match input.",
    )
    np.testing.assert_array_equal(result, output_np)

  def test_batch_singleton_input(self):
    # batch_model is expected to issue one Predict call per batch element,
    # in order.
    input_nps = [
        np.array([[0.5] * 3] * 3, dtype=np.float32),
        np.array([[0.25] * 3] * 3, dtype=np.float32),
    ]
    input_protos = []
    for input_np in input_nps:
      input_proto = predict_pb2.PredictRequest()
      input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
      input_proto.model_spec.name = "default"
      input_protos.append(input_proto)
    output_nps = [
        np.ones((3, 2), dtype=np.float32),
        np.zeros((3, 2), dtype=np.float32),
    ]
    output_protos = []
    for output_np in output_nps:
      output_proto = predict_pb2.PredictResponse()
      output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
      output_proto.model_spec.name = "default"
      output_protos.append(output_proto)
    self._stub.Predict.side_effect = output_protos

    result = self._runner.batch_model(input_nps)

    self.assertLen(self._stub.Predict.call_args_list, len(input_nps))
    for input_proto, call in zip(
        input_protos, self._stub.Predict.call_args_list
    ):
      self.assertEqual(
          call[0][0].SerializeToString(deterministic=True),
          input_proto.SerializeToString(deterministic=True),
          "Tensor passed to model values do not match input.",
      )
    for result_np, output_np in zip(result, output_nps):
      np.testing.assert_array_equal(result_np, output_np)

  def test_batch_map_input(self):
    # Batch of dict inputs: one Predict call per map, keys preserved.
    input_np_maps = [
        {
            "a": np.array([[0.5] * 3] * 3, dtype=np.float32),
            "b": np.array([[0.25] * 3] * 3, dtype=np.float32),
        },
        {
            "a": np.array([[0.25] * 3] * 3, dtype=np.float32),
            "b": np.array([[0.5] * 3] * 3, dtype=np.float32),
        },
    ]
    input_protos = []
    for input_np_map in input_np_maps:
      input_proto = predict_pb2.PredictRequest()
      for label in input_np_map:
        input_proto.inputs[label].CopyFrom(
            tf.make_tensor_proto(input_np_map[label])
        )
      input_proto.model_spec.name = "default"
      input_protos.append(input_proto)
    output_nps = [
        np.ones((3, 2), dtype=np.float32),
        np.zeros((3, 2), dtype=np.float32),
    ]
    output_protos = []
    for output_np in output_nps:
      output_proto = predict_pb2.PredictResponse()
      output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
      output_proto.model_spec.name = "default"
      output_protos.append(output_proto)
    self._stub.Predict.side_effect = output_protos

    result = self._runner.batch_model(input_np_maps)

    self.assertLen(self._stub.Predict.call_args_list, len(input_np_maps))
    for input_proto, call in zip(
        input_protos, self._stub.Predict.call_args_list
    ):
      self.assertEqual(
          call[0][0].SerializeToString(deterministic=True),
          input_proto.SerializeToString(deterministic=True),
          "Tensor passed to model values do not match input.",
      )
    for result_np, output_np in zip(result, output_nps):
      np.testing.assert_array_equal(result_np, output_np)

  def test_keyed_output(self):
    # Set up test values. The response carries two outputs; only the one
    # selected by model_output_key should be returned.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "default"
    output_np = np.ones((3, 2), dtype=np.float32)
    surplus_np = np.array([[0.1] * 3] * 2, dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_a"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.outputs["output_b"].CopyFrom(tf.make_tensor_proto(surplus_np))
    output_proto.model_spec.name = "default"
    self._stub.Predict.return_value = output_proto

    result = self._runner.run_model(input_np, model_output_key="output_a")

    self.assertLen(self._stub.Predict.call_args_list, 1)
    self.assertEqual(
        self._stub.Predict.call_args[0][0].SerializeToString(
            deterministic=True
        ),
        input_proto.SerializeToString(deterministic=True),
        "Proto passed to model does not match expectation.",
    )
    np.testing.assert_array_equal(result, output_np)

  def test_miskeyed_output(self):
    # Set up test values. Requesting a key absent from the response outputs
    # is expected to raise KeyError.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "default"
    output_np = np.ones((3, 2), dtype=np.float32)
    surplus_np = np.array([[0.1] * 3] * 2, dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_a"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.outputs["output_b"].CopyFrom(tf.make_tensor_proto(surplus_np))
    output_proto.model_spec.name = "default"
    self._stub.Predict.return_value = output_proto

    with self.assertRaises(KeyError):
      _ = self._runner.run_model(input_np, model_output_key="output_c")

  def test_multi_output(self):
    # Set up test values. run_model_multiple_output should return exactly the
    # requested subset of response outputs.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "default"
    output_np = np.ones((3, 2), dtype=np.float32)
    second_out_np = np.array([[0.7] * 3] * 3, dtype=np.float32)
    surplus_np = np.array([[0.1] * 3] * 2, dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_a"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.outputs["output_b"].CopyFrom(tf.make_tensor_proto(surplus_np))
    output_proto.outputs["output_c"].CopyFrom(
        tf.make_tensor_proto(second_out_np)
    )
    output_proto.model_spec.name = "default"
    self._stub.Predict.return_value = output_proto

    result = self._runner.run_model_multiple_output(
        input_np, model_output_keys={"output_a", "output_c"}
    )

    self.assertLen(self._stub.Predict.call_args_list, 1)
    self.assertEqual(
        self._stub.Predict.call_args[0][0].SerializeToString(
            deterministic=True
        ),
        input_proto.SerializeToString(deterministic=True),
        "Proto passed to model does not match expectation.",
    )
    self.assertEqual(result.keys(), {"output_a", "output_c"})
    np.testing.assert_array_equal(result["output_a"], output_np)
    np.testing.assert_array_equal(result["output_c"], second_out_np)

  def test_batch_multi_output(self):
    # Batched multi-output: one Predict call per input, each result map
    # restricted to the requested keys.
    input_nps = [
        np.array([[0.5] * 3] * 3, dtype=np.float32),
        np.array([[0.25] * 3] * 3, dtype=np.float32),
    ]
    input_protos = []
    for input_np in input_nps:
      input_proto = predict_pb2.PredictRequest()
      input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
      input_proto.model_spec.name = "default"
      input_protos.append(input_proto)
    output_nps_a = [
        np.ones((3, 2), dtype=np.float32),
        np.zeros((3, 2), dtype=np.float32),
    ]
    output_nps_b = [
        np.array([[0.6] * 2] * 3, dtype=np.float32),
        np.array([[0.2] * 2] * 3, dtype=np.float32),
    ]
    output_nps_c = [
        np.array([[0.7] * 2] * 3, dtype=np.float32),
        np.array([[0.8] * 2] * 3, dtype=np.float32),
    ]
    output_protos = []
    for output_np in zip(output_nps_a, output_nps_b, output_nps_c):
      output_proto = predict_pb2.PredictResponse()
      output_proto.outputs["output_a"].CopyFrom(
          tf.make_tensor_proto(output_np[0])
      )
      output_proto.outputs["output_b"].CopyFrom(
          tf.make_tensor_proto(output_np[1])
      )
      output_proto.outputs["output_c"].CopyFrom(
          tf.make_tensor_proto(output_np[2])
      )
      output_proto.model_spec.name = "default"
      output_protos.append(output_proto)
    self._stub.Predict.side_effect = output_protos

    result = self._runner.batch_model_multiple_output(
        input_nps, model_output_keys={"output_a", "output_c"}
    )

    self.assertLen(self._stub.Predict.call_args_list, len(input_nps))
    for input_proto, call in zip(
        input_protos, self._stub.Predict.call_args_list
    ):
      self.assertEqual(
          call[0][0].SerializeToString(deterministic=True),
          input_proto.SerializeToString(deterministic=True),
          "Tensor passed to model values do not match input.",
      )
    for result_map, *output_nps in zip(result, output_nps_a, output_nps_c):
      np.testing.assert_array_equal(result_map["output_a"], output_nps[0])
      np.testing.assert_array_equal(result_map["output_c"], output_nps[1])
      self.assertNotIn("output_b", result_map)

  def test_model_name_specification(self):
    # model_name should override the "default" model_spec name in the request.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "alternative"
    output_np = np.ones((3, 2), dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
    output_proto.model_spec.name = "alternative"
    self._stub.Predict.return_value = output_proto

    _ = self._runner.run_model(input_np, model_name="alternative")

    self.assertLen(self._stub.Predict.call_args_list, 1)
    self.assertEqual(
        self._stub.Predict.call_args[0][0].SerializeToString(
            deterministic=True
        ),
        input_proto.SerializeToString(deterministic=True),
        "Proto passed to model does not match expectation.",
    )

  def test_model_version_specification(self):
    # model_version should populate model_spec.version in the request.
    input_np = np.array([[0.5] * 3] * 3, dtype=np.float32)
    input_proto = predict_pb2.PredictRequest()
    input_proto.inputs["inputs"].CopyFrom(tf.make_tensor_proto(input_np))
    input_proto.model_spec.name = "default"
    input_proto.model_spec.version.value = 5
    output_np = np.ones((3, 2), dtype=np.float32)
    output_proto = predict_pb2.PredictResponse()
    output_proto.outputs["output_0"].CopyFrom(tf.make_tensor_proto(output_np))
337 | output_proto.model_spec.name = "alternative" 338 | self._stub.Predict.return_value = output_proto 339 | 340 | _ = self._runner.run_model(input_np, model_version=5) 341 | 342 | self.assertLen(self._stub.Predict.call_args_list, 1) 343 | self.assertEqual( 344 | self._stub.Predict.call_args[0][0].SerializeToString( 345 | deterministic=True 346 | ), 347 | input_proto.SerializeToString(deterministic=True), 348 | "Proto passed to model does not match expectation.", 349 | ) 350 | 351 | 352 | if __name__ == "__main__": 353 | absltest.main() 354 | -------------------------------------------------------------------------------- /python/serving/logging_lib/flags/secret_flag_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""Tests for secret_flag_utils."""
import os
from typing import Any, List
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import google.api_core
from google.cloud import secretmanager

from serving.logging_lib.flags import secret_flag_utils

_MOCK_SECRET = 'projects/test-project/secrets/test-secret'


def _create_mock_secret_value_response(
    data: Any,
) -> secretmanager.AccessSecretVersionResponse:
  """Returns a mock AccessSecretVersionResponse whose payload.data is `data`."""
  payload = mock.create_autospec(secretmanager.SecretPayload, instance=True)
  payload.data = data
  sec_val = mock.create_autospec(
      secretmanager.AccessSecretVersionResponse, instance=True
  )
  sec_val.payload = payload
  return sec_val


def _create_mock_secret_version_response(
    secret_version_prefix: str,
    version_numbers: List[str],
) -> list[secretmanager.ListSecretVersionsResponse]:
  """Returns mock version entries named `secret_version_prefix` + number."""
  version_list = []
  for version_number in version_numbers:
    version = mock.create_autospec(
        secretmanager.ListSecretVersionsResponse, instance=True
    )
    version.name = f'{secret_version_prefix}{version_number}'
    version_list.append(version)
  return version_list


class SecretFlagUtilsTest(parameterized.TestCase):
  """Unit tests for secret_flag_utils against a mocked Secret Manager client."""

  def setUp(self):
    super().setUp()
    # Start every test with an empty module-level secret cache and mocked
    # application-default credentials.
    secret_flag_utils._cache.clear()
    self.enter_context(
        mock.patch(
            'google.auth.default', autospec=True, return_value=('mock', 'mock')
        )
    )

  def tearDown(self):
    super().tearDown()
    # Reset the module flag that several tests below flip to True.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = False

  def test_init_fork_module_state(self):
    secret_flag_utils._cache = None
    secret_flag_utils._cache_lock = None
    secret_flag_utils._init_fork_module_state()
    self.assertIsNotNone(secret_flag_utils._cache)
    self.assertIsNotNone(secret_flag_utils._cache_lock)

  @parameterized.named_parameters(
      dict(
          testcase_name='expected',
          version_numbers=['1', '2', '3'],
          expected_version='3',
      ),
      dict(
          testcase_name='ignore_unexpected_version',
          version_numbers=['A', '1'],
          expected_version='1',
      ),
  )
  def test_get_secret_version(self, version_numbers, expected_version):
    # Highest numeric version wins; non-numeric version names are skipped.
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = (
        _create_mock_secret_version_response(
            f'{_MOCK_SECRET}/versions/', version_numbers
        )
    )
    version = secret_flag_utils._get_secret_version(client, _MOCK_SECRET)
    self.assertEqual(version, expected_version)

  def test_get_secret_version_cannot_find_version(self):
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = (
        _create_mock_secret_version_response('', ['NO_MATCH', 'NO_MATCH'])
    )
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._get_secret_version(client, _MOCK_SECRET)

  def test_get_secret_version_ignore_bad_response(self):
    client = mock.create_autospec(
        secretmanager.SecretManagerServiceClient, instance=True
    )
    client.list_secret_versions.return_value = []
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._get_secret_version(client, _MOCK_SECRET)

  def test_read_secrets_path_to_version_empty(self):
    self.assertEqual(secret_flag_utils._read_secrets(''), {})

  def test_read_secrets_path_invalid_raises(self):
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets('no_match')

  def test_read_secrets_path_disabled_returns_empty_result(self):
    # _ENABLE_ENV_SECRET_MANAGER is False here (see tearDown), so no secret
    # is fetched.
    self.assertEqual(
        secret_flag_utils._read_secrets('projects/prj/secrets/sec'), {}
    )

  @parameterized.named_parameters([
      dict(
          testcase_name='none',
          data=None,
          expected_value={},
      ),
      dict(
          testcase_name='empty',
          data='',
          expected_value={},
      ),
      dict(
          testcase_name='empty_dict',
          data='{}',
          expected_value={},
      ),
      dict(
          testcase_name='defined_dict_str',
          data='{"abc": "123"}',
          expected_value={'abc': '123'},
      ),
      dict(
          testcase_name='defined_dict_bytes',
          data='{"abc": "123"}'.encode('utf-8'),
          expected_value={'abc': '123'},
      ),
  ])
  @mock.patch.object(
      secret_flag_utils, '_get_secret_version', autospec=True, return_value='1'
  )
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_value(
      self, mk_access_secret, mk_get_secret_version, data, expected_value
  ):
    # JSON payloads (str or bytes) decode to a dict; empty payloads yield {}.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec'
    secret_version = '1'
    mk_get_secret_version.return_value = secret_version
    mk_access_secret.return_value = _create_mock_secret_value_response(data)
    self.assertEqual(
        secret_flag_utils._read_secrets(project_secret),
        expected_value,
    )
    mk_get_secret_version.assert_called_once_with(mock.ANY, project_secret)
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': f'{project_secret}/versions/{secret_version}'},
    )

  @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True)
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_value_does_not_call_version_if_in_path(
      self, mk_access_secret, mk_get_secret_version
  ):
    # A path that already names a version skips version discovery.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.return_value = _create_mock_secret_value_response('')
    self.assertEqual(
        secret_flag_utils._read_secrets(project_secret),
        {},
    )
    mk_get_secret_version.assert_not_called()
    mk_access_secret.assert_called_once_with(
        mock.ANY, request={'name': project_secret}
    )

  def test_read_secrets_returns_cache_value(self):
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/3'
    expected = 'EXPECTED_VALUE'
    secret_flag_utils._cache[project_secret] = expected
    self.assertEqual(secret_flag_utils._read_secrets(project_secret), expected)

  @parameterized.named_parameters([
      dict(testcase_name='invalid_json', response='{'),
      dict(testcase_name='not_dict', response='[]'),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_read_secrets_returns_invalid_value_raises(
      self, mk_access_secret, response
  ):
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.return_value = _create_mock_secret_value_response(response)
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.parameters([
      google.api_core.exceptions.NotFound('mock'),
      google.api_core.exceptions.PermissionDenied('mock'),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_access_secret_version_raises(self, exp, mk_access_secret):
    # Cloud API errors surface as SecretDecodeError.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/2'
    mk_access_secret.side_effect = exp
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.parameters([
      google.api_core.exceptions.NotFound('mock'),
      google.api_core.exceptions.PermissionDenied('mock'),
  ])
  @mock.patch.object(secret_flag_utils, '_get_secret_version', autospec=True)
  def test_get_secret_version_raises(self, exp, mk_get_secret_version):
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec'
    mk_get_secret_version.side_effect = exp
    with self.assertRaises(secret_flag_utils.SecretDecodeError):
      secret_flag_utils._read_secrets(project_secret)

  @parameterized.named_parameters([
      dict(
          testcase_name='defined_env',
          env_name='MOCK_ENV',
          expected_value='DEFINED_VALUE',
      ),
      dict(
          testcase_name='not_defined_env',
          env_name='UNDEFINED_ENV',
          expected_value='MOCK_DEFAULT',
      ),
  ])
  def test_get_secret_or_env_no_config_defined_returns_env_value_or_default(
      self, env_name, expected_value
  ):
    with mock.patch.dict(os.environ, {'MOCK_ENV': 'DEFINED_VALUE'}):
      self.assertEqual(
          secret_flag_utils.get_secret_or_env(env_name, 'MOCK_DEFAULT'),
          expected_value,
      )

  @parameterized.named_parameters([
      dict(
          testcase_name='secret_env',
          search_env='SECRET_ENV',
          expected='SECRET_VALUE',
      ),
      dict(
          testcase_name='not_in_sec',
          search_env='NOT_IN_SEC',
          expected='ENV_VALUE',
      ),
      dict(
          testcase_name='not_in_env',
          search_env='NOT_IN_ENV',
          expected='MOCK_DEFAULT',
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_secret_or_env_with_config(
      self, mk_access_secret, search_env, expected
  ):
    # Lookup precedence: secret payload, then os.environ, then the default.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"SECRET_ENV": "SECRET_VALUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
            'NOT_IN_SEC': 'ENV_VALUE',
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_secret_or_env(search_env, 'MOCK_DEFAULT'),
          expected,
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_secret_or_env_default_return_value(self, mk_access_secret):
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"MOCK_ENV": "SECRET_VALUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
        },
    ):
      self.assertIsNone(
          secret_flag_utils.get_secret_or_env('NOT_IN_SECRET_OR_ENV', None)
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @parameterized.named_parameters([
      dict(testcase_name='FOUND', env_name='MOCK_ENV', expected_value=True),
      dict(
          testcase_name='DEFAULT_MISSING',
          env_name='NOT_FOUND_ENV',
          expected_value=False,
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_bool_secret_or_env(
      self, mk_access_secret, env_name, expected_value
  ):
    # "TRUE" in the secret payload parses to True; missing names default
    # to False.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response(
        '{"MOCK_ENV": "TRUE"}'
    )
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_bool_secret_or_env(env_name), expected_value
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )

  @parameterized.named_parameters([
      dict(testcase_name='FOUND', env_name='ENV_FOUND', expected_value=False),
      dict(
          testcase_name='DEFAULT_MISSING',
          env_name='NOT_FOUND_ENV',
          expected_value=True,
      ),
  ])
  @mock.patch.object(
      secretmanager.SecretManagerServiceClient,
      'access_secret_version',
      autospec=True,
  )
  def test_get_bool_secret_or_env_trys_env(
      self, mk_access_secret, env_name, expected_value
  ):
    # With an empty secret payload, the environment variable is consulted
    # before the supplied default.
    secret_flag_utils._ENABLE_ENV_SECRET_MANAGER = True
    project_secret = 'projects/prj/secrets/sec/versions/1'
    mk_access_secret.return_value = _create_mock_secret_value_response('{}')
    with mock.patch.dict(
        os.environ,
        {
            secret_flag_utils._SECRET_MANAGER_ENV_CONFIG: project_secret,
            'ENV_FOUND': 'False',
        },
    ):
      self.assertEqual(
          secret_flag_utils.get_bool_secret_or_env(env_name, True),
          expected_value,
      )
    mk_access_secret.assert_called_once_with(
        mock.ANY,
        request={'name': project_secret},
    )


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------