├── tests ├── __init__.py ├── conftest.py ├── test_descriptor_to_message_class.py ├── test_descriptor_to_file.py ├── test_validation.py └── test_json_to_service.py ├── .prettierrc.yaml ├── .coveragerc ├── .prettierignore ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── user-story.md │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── release.yml │ └── tests.yml ├── .gitignore ├── requirements_test.txt ├── requirements.txt ├── .isort.cfg ├── py_to_proto ├── compat_annotated.py ├── compat.py ├── __init__.py ├── descriptor_to_message_class.py ├── jtd_to_proto.py ├── descriptor_to_file.py ├── json_to_service.py ├── utils.py ├── dataclass_to_proto.py ├── validation.py └── converter_base.py ├── .pre-commit-config.yaml ├── scripts ├── release.sh ├── publish.sh ├── run_tests.sh ├── fmt.sh ├── install_release.sh └── build_wheel.sh ├── ppa.dockerfile ├── LICENSE ├── Makefile ├── setup.py ├── Dockerfile ├── CODE_OF_CONDUCT.md ├── README.md └── CONTRIBUTING.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.prettierrc.yaml: -------------------------------------------------------------------------------- 1 | tabWidth: 4 2 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | tests/** 4 | py_to_proto/compat_annotated.py 5 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | # Ignore this auto-generated json file 2 | test/sample_libs/sample_lib/__static_import_tracker__.json 3 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | gabe.l.hart@gmail.com 2 | ghart@us.ibm.com 3 | joseph.runde@ibm.com 4 | joe@joerun.de 5 | evaline.ju@ibm.com 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | htmlcov 3 | __pycache__ 4 | *.egg-info/ 5 | build/ 6 | dist/ 7 | .DS_Store 8 | .pytest_cache 9 | .bash_history 10 | .python_history 11 | .idea/ 12 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | # Testing 2 | pytest>=6.2.5 3 | pytest-cov>=3.0.0 4 | pytest-xdist>=2.5.0 5 | tls_test_tools>=0.1.1 6 | 7 | # Round-trip proto compilation 8 | grpcio-tools>=1.46.3 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # NOTE: protobuf 3.19 is the highest version allowed by tensorflow (currently), 2 | # so we explicitly pin the lower bound to allow compatibility with tf 3 | protobuf>=3.19.0,<7.0.0 4 | alchemy-logging>=1.0.3 5 | typing-extensions>=4.5.0,<5; python_version < '3.9' 6 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | from_first=true 4 | import_heading_future=Future 5 | import_heading_stdlib=Standard 6 | import_heading_thirdparty=Third Party 7 | import_heading_firstparty=First Party 8 | import_heading_localfolder=Local 9 | known_firstparty=alog 10 | known_localfolder=py_to_proto,tests 11 | -------------------------------------------------------------------------------- /py_to_proto/compat_annotated.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper module to allow Annotated to be imported in 3.7 and 3.8 3 | """ 4 | 5 | try: 6 | # Standard 7 | from typing import Annotated, get_args, get_origin 8 | except ImportError: 9 | # Third Party 10 | from typing_extensions import Annotated, get_args, get_origin 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-prettier 3 | rev: v2.1.2 4 | hooks: 5 | - id: prettier 6 | - repo: https://github.com/psf/black 7 | rev: 22.3.0 8 | hooks: 9 | - id: black 10 | exclude: imports 11 | - repo: https://github.com/PyCQA/isort 12 | rev: 5.11.5 13 | hooks: 14 | - id: isort 15 | exclude: imports 16 | -------------------------------------------------------------------------------- /scripts/release.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run from the project root 4 | cd $(dirname ${BASH_SOURCE[0]})/.. 5 | 6 | # Get the tag for this release 7 | tag=$(echo $REF | cut -d'/' -f3-) 8 | 9 | # Build the docker phase that will release and then test it 10 | docker build . \ 11 | --target=release_test \ 12 | --build-arg RELEASE_VERSION=$tag \ 13 | --build-arg PYPI_TOKEN=${PYPI_TOKEN:-""} \ 14 | --build-arg RELEASE_DRY_RUN=${RELEASE_DRY_RUN:-"false"} \ 15 | --build-arg PYTHON_VERSION=${PYTHON_VERSION:-"3.7"} 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/user-story.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: User story 3 | about: A user-oriented story describing a piece of work to do 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Description 10 | 11 | As a , I want to , so that I can 12 | 13 | ## Discussion 14 | 15 | Provide detailed discussion here 16 | 17 | ## Acceptance Criteria 18 | 19 | 20 | 21 | - [ ] Unit tests cover new/changed code 22 | - [ ] Examples build against new/changed code 23 | - [ ] READMEs are updated 24 | - [ ] Type of [semantic version](https://semver.org/) change is identified 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Is your feature request related to a problem? Please describe. 10 | 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | ## Describe the solution you'd like 14 | 15 | A clear and concise description of what you want to happen. 16 | 17 | ## Describe alternatives you've considered 18 | 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | ## Additional context 22 | 23 | Add any other context about the feature request here. 24 | -------------------------------------------------------------------------------- /scripts/publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run from the base of the python directory 4 | cd $(dirname ${BASH_SOURCE[0]})/.. 5 | 6 | # Clear out old publication files in case they're still around 7 | rm -rf build dist *.egg-info/ 8 | 9 | # Build 10 | py_tag="py$(echo $PYTHON_VERSION | cut -d'.' -f 1,2 | sed 's,\.,,g')" 11 | ./scripts/build_wheel.sh -v $RELEASE_VERSION -p $py_tag 12 | 13 | # Publish to PyPi 14 | if [ "${RELEASE_DRY_RUN}" != "true" ] 15 | then 16 | un_arg="" 17 | pw_arg="" 18 | if [ "$PYPI_TOKEN" != "" ] 19 | then 20 | un_arg="--username __token__" 21 | pw_arg="--password $PYPI_TOKEN" 22 | fi 23 | twine upload $un_arg $pw_arg dist/* 24 | else 25 | echo "Release DRY RUN" 26 | fi 27 | 28 | # Clean up 29 | rm -rf build dist *.egg-info/ 30 | -------------------------------------------------------------------------------- /scripts/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | BASE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" 5 | cd "$BASE_DIR" 6 | 7 | if [ "$PARALLEL" == "1" ] 8 | then 9 | if [[ "$OSTYPE" =~ "darwin"* ]] 10 | then 11 | num_procs=$(sysctl -n hw.physicalcpu) 12 | else 13 | num_procs=$(nproc) 14 | fi 15 | procs=${NPROCS:-$num_procs} 16 | echo "Running tests in parallel with [$procs] workers" 17 | procs_arg="-n $procs" 18 | else 19 | echo "Running tests in serial" 20 | procs_arg="" 21 | fi 22 | 23 | FAIL_THRESH=100.0 24 | python3 -m pytest \ 25 | $procs_arg \ 26 | --cov-config=.coveragerc \ 27 | --cov=py_to_proto \ 28 | --cov-report=term \ 29 | --cov-report=html \ 30 | --cov-fail-under=$FAIL_THRESH \ 31 | -W error "$@" 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Describe the bug 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | ## Platform 14 | 15 | Please provide details about the environment you are using, including the following: 16 | 17 | - Interpreter version: 18 | - Library version: 19 | 20 | ## Sample Code 21 | 22 | Please include a minimal sample of the code that will (if possible) reproduce the bug in isolation 23 | 24 | ## Expected behavior 25 | 26 | A clear and concise description of what you expected to happen. 27 | 28 | ## Observed behavior 29 | 30 | What you see happening (error messages, stack traces, etc...) 31 | 32 | ## Additional context 33 | 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common test helpers 3 | """ 4 | 5 | # Standard 6 | import os 7 | 8 | # Third Party 9 | from google.protobuf import descriptor_pool, struct_pb2, timestamp_pb2 10 | import pytest 11 | 12 | # First Party 13 | import alog 14 | 15 | # Global logging config 16 | alog.configure( 17 | default_level=os.environ.get("LOG_LEVEL", "info"), 18 | filters=os.environ.get("LOG_FILTERS", ""), 19 | formatter="json" if os.environ.get("LOG_JSON", "").lower() == "true" else "pretty", 20 | thread_id=os.environ.get("LOG_THREAD_ID", "").lower() == "true", 21 | ) 22 | 23 | 24 | @pytest.fixture 25 | def temp_dpool(): 26 | """Fixture to isolate the descriptor pool used in each test""" 27 | dpool = descriptor_pool.DescriptorPool() 28 | dpool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb) 29 | dpool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) 30 | yield dpool 31 | -------------------------------------------------------------------------------- /scripts/fmt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # If disabled, do nothing (for docker build) 4 | if [ "${RUN_FMT:-"true"}" != "true" ] 5 | then 6 | echo "fmt disabled" 7 | exit 8 | fi 9 | 10 | pre-commit run --all-files 11 | RETURN_CODE=$? 12 | 13 | function echoWarning() { 14 | LIGHT_YELLOW='\033[1;33m' 15 | NC='\033[0m' # No Color 16 | echo -e "${LIGHT_YELLOW}${1}${NC}" 17 | } 18 | 19 | if [ "$RETURN_CODE" -ne 0 ]; then 20 | if [ "${CI}" != "true" ]; then 21 | echoWarning "☝️ This appears to have failed, but actually your files have been formatted." 22 | echoWarning "Make a new commit with these changes before making a pull request." 23 | else 24 | echoWarning "This test failed because your code isn't formatted correctly." 25 | echoWarning 'Locally, run `make run fmt`, it will appear to fail, but change files.' 26 | echoWarning "Add the changed files to your commit and this stage will pass." 27 | fi 28 | 29 | exit $RETURN_CODE 30 | fi 31 | -------------------------------------------------------------------------------- /ppa.dockerfile: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # This dockerfile builds a base image that can be used to validate the library 3 | # against ubuntu PPA python builds 4 | # 5 | # Reference: https://github.com/IBM/import-tracker/issues/40 6 | ################################################################################ 7 | 8 | FROM ubuntu:18.04 9 | 10 | ARG PYTHON_VERSION=3.7 11 | RUN true && \ 12 | apt-get update && \ 13 | apt-get install software-properties-common curl -y && \ 14 | add-apt-repository ppa:deadsnakes/ppa -y && \ 15 | apt-get update && \ 16 | DEBIAN_FRONTEND="noninteractive" apt-get install -y \ 17 | python${PYTHON_VERSION} \ 18 | python${PYTHON_VERSION}-distutils && \ 19 | curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \ 20 | ln -s $(which python${PYTHON_VERSION}) /usr/local/bin/python && \ 21 | ln -s $(which python${PYTHON_VERSION}) /usr/local/bin/python3 && \ 22 | true 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 International Business Machines 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /py_to_proto/compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compatibility module for API changes between different versions of protobuf 3 | """ 4 | 5 | # Standard 6 | from typing import Type 7 | import types 8 | 9 | # Third Party 10 | from google.protobuf.descriptor import ServiceDescriptor 11 | 12 | # protobuf >= 6 13 | try: # pragma: no cover 14 | # Third Party 15 | from google.protobuf.service_reflection import GeneratedServiceType 16 | 17 | def make_service_class( 18 | service_descriptor: ServiceDescriptor, 19 | ) -> Type[GeneratedServiceType]: 20 | return GeneratedServiceType( 21 | service_descriptor.name, 22 | (), 23 | {"DESCRIPTOR": service_descriptor}, 24 | ) 25 | 26 | 27 | # protobuf < 6 28 | except ImportError: # pragma: no cover 29 | # Third Party 30 | from google.protobuf.service import Service as GeneratedServiceType 31 | 32 | def make_service_class( 33 | service_descriptor: ServiceDescriptor, 34 | ) -> Type[GeneratedServiceType]: 35 | return types.new_class( 36 | service_descriptor.name, 37 | (GeneratedServiceType,), 38 | {"metaclass": GeneratedServiceType}, 39 | lambda ns: ns.update({"DESCRIPTOR": service_descriptor}), 40 | ) 41 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ##@ General 2 | 3 | all: help 4 | 5 | # NOTE: Help stolen from operator-sdk auto-generated makfile! 6 | .PHONY: help 7 | help: ## Display this help. 8 | @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-\\.]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) 9 | 10 | .PHONY: test 11 | test: ## Run the unit tests 12 | PARALLEL=1 ./scripts/run_tests.sh 13 | 14 | .PHONY: fmt 15 | fmt: ## Run code formatting 16 | ./scripts/fmt.sh 17 | 18 | .PHONY: wheel 19 | wheel: ## Build release wheels 20 | ./scripts/build_wheel.sh 21 | 22 | ##@ Develop 23 | 24 | PYTHON_VERSION ?= 3.8 25 | PROTOBUF_VERSION ?= 26 | 27 | .PHONY: develop.build 28 | develop.build: ## Build the development environment container 29 | docker build . --target=base \ 30 | -t py-to-proto-develop \ 31 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \ 32 | --build-arg PROTOBUF_VERSION="${PROTOBUF_VERSION}" 33 | 34 | .PHONY: develop 35 | develop: develop.build ## Run the develop shell with the local codebase mounted 36 | touch .bash_history 37 | docker run --rm -it \ 38 | --entrypoint bash \ 39 | -w /src \ 40 | -v ${PWD}:/src \ 41 | -v ${PWD}/.bash_history:/root/.bash_history \ 42 | py-to-proto-develop 43 | -------------------------------------------------------------------------------- /py_to_proto/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This library holds utilities for converting JSON Typedef to Protobuf. 3 | 4 | Rerferences: 5 | * https://jsontypedef.com/ 6 | * https://developers.google.com/protocol-buffers 7 | 8 | Example: 9 | 10 | ``` 11 | import py_to_proto 12 | 13 | # Declare the Foo protobuf message class 14 | Foo = py_to_proto.descriptor_to_message_class( 15 | py_to_proto.py_to_proto( 16 | name="Foo", 17 | package="foobar", 18 | jtd_def={ 19 | "properties": { 20 | # Bool field 21 | "foo": { 22 | "type": "boolean", 23 | }, 24 | # Array of nested enum values 25 | "bar": { 26 | "elements": { 27 | "enum": ["EXAM", "JOKE_SETTING"], 28 | } 29 | } 30 | } 31 | }, 32 | ) 33 | ) 34 | 35 | def write_foo_proto(filename: str): 36 | \"\"\"Write out the .proto file for Foo to the given filename\"\"\" 37 | with open(filename, "w") as handle: 38 | handle.write(Foo.to_proto_file()) 39 | ``` 40 | """ 41 | 42 | # Local 43 | from .dataclass_to_proto import dataclass_to_proto 44 | from .descriptor_to_file import descriptor_to_file 45 | from .descriptor_to_message_class import descriptor_to_message_class 46 | from .jtd_to_proto import jtd_to_proto 47 | -------------------------------------------------------------------------------- /scripts/install_release.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ################################################################################ 4 | # This script is used to execute unit tests against a recent release without the 5 | # code held locally. It's intended to be used inside of the `release_test` 6 | # phase of the central Dockerfile. 7 | ################################################################################ 8 | 9 | # Make sure RELEASE_VERSION is defined 10 | if [ -z ${RELEASE_VERSION+x} ] 11 | then 12 | echo "RELEASE_VERSION must be set" 13 | exit 1 14 | fi 15 | 16 | # The name of the library we're testing 17 | LIBRARY_NAME="py_to_proto" 18 | 19 | # 10 minutes max for trying to install the new version 20 | MAX_DURATION="${MAX_DURATION:-600}" 21 | 22 | # Time to wait between attempts to install the version 23 | RETRY_SLEEP=5 24 | 25 | # Retry the install until it succeeds 26 | start_time=$(date +%s) 27 | success="0" 28 | while [ "$(expr "$(date +%s)" "-" "${start_time}" )" -lt "${MAX_DURATION}" ] 29 | do 30 | pip cache purge 31 | pip install ${LIBRARY_NAME}==${RELEASE_VERSION} 32 | exit_code=$? 33 | if [ "$exit_code" != "0" ] 34 | then 35 | echo "Trying again in [${RETRY_SLEEP}s]" 36 | sleep ${RETRY_SLEEP} 37 | else 38 | success="1" 39 | break 40 | fi 41 | done 42 | 43 | # If the install didn't succeed, exit with failure 44 | if [ "$success" == "0" ] 45 | then 46 | echo "Unable to install [${LIBRARY_NAME}==${RELEASE_VERSION}]!" 47 | exit 1 48 | fi 49 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow runs the typescript implementation unit tests 2 | name: release 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: {} 7 | jobs: 8 | build-38: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Run release 13 | env: 14 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 15 | PYTHON_VERSION: "3.8" 16 | run: REF="${{ github.ref }}" ./scripts/release.sh 17 | build-39: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Run release 22 | env: 23 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 24 | PYTHON_VERSION: "3.9" 25 | run: REF="${{ github.ref }}" ./scripts/release.sh 26 | build-310: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - uses: actions/checkout@v2 30 | - name: Run release 31 | env: 32 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 33 | PYTHON_VERSION: "3.10" 34 | run: REF="${{ github.ref }}" ./scripts/release.sh 35 | build-311: 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v2 39 | - name: Run release 40 | env: 41 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 42 | PYTHON_VERSION: "3.11" 43 | run: REF="${{ github.ref }}" ./scripts/release.sh 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """A setuptools setup module for py_to_proto""" 2 | 3 | # Standard 4 | import os 5 | 6 | # Third Party 7 | from setuptools import setup 8 | 9 | # Read the README to provide the long description 10 | python_base = os.path.abspath(os.path.dirname(__file__)) 11 | with open(os.path.join(python_base, "README.md"), "r") as handle: 12 | long_description = handle.read() 13 | 14 | # Read version from the env 15 | version = os.environ.get("RELEASE_VERSION") 16 | assert version is not None, "Must set RELEASE_VERSION" 17 | 18 | # Read in the requirements 19 | with open(os.path.join(python_base, "requirements.txt"), "r") as handle: 20 | requirements = handle.read() 21 | 22 | setup( 23 | name="py_to_proto", 24 | version=version, 25 | description="A tool to dynamically create protobuf message classes from python data schemas", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/IBM/py-to-proto", 29 | author="Gabe Goodhart", 30 | author_email="gabe.l.hart@gmail.com", 31 | license="MIT", 32 | classifiers=[ 33 | "Intended Audience :: Developers", 34 | "Programming Language :: Python :: 3", 35 | "Programming Language :: Python :: 3.7", 36 | "Programming Language :: Python :: 3.8", 37 | "Programming Language :: Python :: 3.9", 38 | "Programming Language :: Python :: 3.10", 39 | "Programming Language :: Python :: 3.11", 40 | ], 41 | keywords=["json", "json typedef", "jtd", "protobuf", "proto", "dataclass"], 42 | packages=["py_to_proto"], 43 | install_requires=requirements, 44 | ) 45 | -------------------------------------------------------------------------------- /scripts/build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Version of the library that we want to tag our wheel as 4 | release_version=${RELEASE_VERSION:-""} 5 | # Python tags we want to support 6 | python_versions="py37 py38 py39 py310" 7 | GREEN='\033[0;32m' 8 | NC='\033[0m' 9 | 10 | function show_help 11 | { 12 | cat <<- EOM 13 | Usage: scripts/build_wheels.sh -v [Library Version] -p [python versions] 14 | EOM 15 | } 16 | 17 | while (($# > 0)); do 18 | case "$1" in 19 | -h | --h | --he | --hel | --help) 20 | show_help 21 | exit 2 22 | ;; 23 | -p | --python_versions) 24 | shift 25 | python_versions="" 26 | while [ "$#" -gt "0" ] 27 | do 28 | if [ "$python_versions" != "" ] 29 | then 30 | python_versions="$python_versions " 31 | fi 32 | python_versions="$python_versions$1" 33 | if [ "$#" -gt "1" ] && [[ "$2" == "-"* ]] 34 | then 35 | break 36 | fi 37 | shift 38 | done 39 | ;; 40 | -v | --release_version) 41 | shift; release_version="$1";; 42 | *) 43 | echo "Unkown argument: $1" 44 | show_help 45 | exit 2 46 | ;; 47 | esac 48 | shift 49 | done 50 | 51 | if [ "$release_version" == "" ]; then 52 | echo "ERROR: a release version for the library must be specified." 53 | show_help 54 | exit 1 55 | else 56 | echo -e "Building wheels for version: ${GREEN}${release_version}${NC}" 57 | sleep 2 58 | fi 59 | for python_version in $python_versions; do 60 | echo -e "${GREEN}Building wheel for Python version [${python_version}]${NC}" 61 | RELEASE_VERSION=$release_version python3 setup.py bdist_wheel --python-tag ${python_version} clean --all 62 | echo -e "${GREEN}Done building wheel for Python version [${python_version}]${NC}" 63 | sleep 1 64 | done 65 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow runs the typescript implementation unit tests 2 | name: tests 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: {} 9 | jobs: 10 | build-38: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Run unit tests 15 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} 16 | env: 17 | PYTHON_VERSION: "3.8" 18 | build-39: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Run unit tests 23 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} 24 | env: 25 | PYTHON_VERSION: "3.9" 26 | build-310: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - uses: actions/checkout@v2 30 | - name: Run unit tests 31 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} 32 | env: 33 | PYTHON_VERSION: "3.10" 34 | build-311: 35 | runs-on: ubuntu-latest 36 | steps: 37 | - uses: actions/checkout@v2 38 | - name: Run unit tests 39 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} 40 | env: 41 | PYTHON_VERSION: "3.11" 42 | 43 | # Builds to validate alternate versions of protobuf 44 | build-38-pb319: 45 | runs-on: ubuntu-latest 46 | steps: 47 | - uses: actions/checkout@v2 48 | - name: Run unit tests 49 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} --build-arg PROTOBUF_VERSION=">=3.19.0,<3.20" 50 | env: 51 | PYTHON_VERSION: "3.8" 52 | build-38-pb320: 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: actions/checkout@v2 56 | - name: Run unit tests 57 | run: docker build . --target=test --build-arg PYTHON_VERSION=${PYTHON_VERSION} --build-arg PROTOBUF_VERSION=">=3.20.0,<3.21" 58 | env: 59 | PYTHON_VERSION: "3.8" 60 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ## Base ######################################################################## 2 | # 3 | # This phase sets up dependencies for the other phases 4 | ## 5 | ARG PYTHON_VERSION=3.8 6 | ARG BASE_IMAGE=python:${PYTHON_VERSION}-slim 7 | FROM ${BASE_IMAGE} as base 8 | ARG PROTOBUF_VERSION="" 9 | 10 | # This image is only for building, so we run as root 11 | WORKDIR /src 12 | 13 | # Install build, test, and publish dependencies 14 | COPY requirements.txt requirements_test.txt /src/ 15 | RUN true && \ 16 | apt-get update -y && \ 17 | apt-get install make git -y && \ 18 | apt-get clean autoclean && \ 19 | apt-get autoremove --yes && \ 20 | pip install pip --upgrade && \ 21 | pip install twine pre-commit && \ 22 | pip install -r /src/requirements.txt && \ 23 | pip install -r /src/requirements_test.txt && \ 24 | if [ "$PROTOBUF_VERSION" != "" ]; then \ 25 | pip uninstall -y protobuf grpcio-tools && \ 26 | pip install "protobuf${PROTOBUF_VERSION}" grpcio-tools; \ 27 | fi && \ 28 | true 29 | 30 | ## Test ######################################################################## 31 | # 32 | # This phase runs the unit tests for the library 33 | ## 34 | FROM base as test 35 | COPY . /src 36 | ARG RUN_FMT="true" 37 | RUN true && \ 38 | ./scripts/run_tests.sh && \ 39 | RELEASE_DRY_RUN=true RELEASE_VERSION=0.0.0 \ 40 | ./scripts/publish.sh && \ 41 | ./scripts/fmt.sh && \ 42 | true 43 | 44 | ## Release ##################################################################### 45 | # 46 | # This phase builds the release and publishes it to pypi 47 | ## 48 | FROM test as release 49 | ARG PYPI_TOKEN 50 | ARG RELEASE_VERSION 51 | ARG RELEASE_DRY_RUN 52 | RUN ./scripts/publish.sh 53 | # Create a temp file that the release_test stage uses to ensure 54 | # correct order of build stages 55 | RUN touch RELEASED.txt 56 | 57 | ## Release Test ################################################################ 58 | # 59 | # This phase installs the indicated version from PyPi and runs the unit tests 60 | # against the installed version. 61 | ## 62 | FROM base as release_test 63 | # Copy a random file from the release phase just 64 | # to ensure release_test runs _after_ release 65 | COPY --from=release /src/RELEASED.txt . 66 | ARG RELEASE_VERSION 67 | ARG RELEASE_DRY_RUN 68 | COPY ./tests /src/tests 69 | COPY ./scripts/install_release.sh /src/scripts/install_release.sh 70 | RUN true && \ 71 | ./scripts/install_release.sh && \ 72 | python3 -m pytest -W error && \ 73 | true 74 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 8 | 9 | ## Our Standards 10 | 11 | Examples of behavior that contributes to a positive environment for our 12 | community include: 13 | 14 | - Demonstrating empathy and kindness toward other people 15 | - Being respectful of differing opinions, viewpoints, and experiences 16 | - Giving and gracefully accepting constructive feedback 17 | - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 18 | - Focusing on what is best not just for us as individuals, but for the overall community 19 | 20 | Examples of unacceptable behavior include: 21 | 22 | - The use of sexualized language or imagery, and sexual attention or advances of any kind 23 | - Trolling, insulting or derogatory comments, and personal or political attacks 24 | - Public or private harassment 25 | - Publishing others' private information, such as a physical or email address, without their explicit permission 26 | - Other conduct which could reasonably be considered inappropriate in a professional setting 27 | 28 | ## Enforcement Responsibilities 29 | 30 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. 31 | 32 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 33 | 34 | ## Scope 35 | 36 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. 37 | 38 | ## Enforcement 39 | 40 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement - [CODOWNERS](.github/CODEOWNERS.md). All complaints will be reviewed and investigated promptly and fairly. 41 | 42 | All community leaders are obligated to respect the privacy and security of the reporter of any incident. 43 | 44 | ## Enforcement Guidelines 45 | 46 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: 47 | 48 | ### 1. Correction 49 | 50 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 51 | 52 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. 53 | 54 | ### 2. Warning 55 | 56 | **Community Impact**: A violation through a single incident or series of actions. 57 | 58 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. 59 | 60 | ### 3. Temporary Ban 61 | 62 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 63 | 64 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. 65 | 66 | ### 4. Permanent Ban 67 | 68 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 69 | 70 | **Consequence**: A permanent ban from any sort of public interaction within the community. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 75 | 76 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). 77 | 78 | [homepage]: https://www.contributor-covenant.org 79 | 80 | For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. 81 | -------------------------------------------------------------------------------- /py_to_proto/descriptor_to_message_class.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements a helper to create python classes from in-memory protobuf 3 | Descriptor objects 4 | """ 5 | 6 | # Standard 7 | from functools import wraps 8 | from types import MethodType 9 | from typing import Any, Callable, Type, Union 10 | import os 11 | 12 | # Third Party 13 | from google.protobuf import descriptor as _descriptor 14 | from google.protobuf import message as _message 15 | from google.protobuf import reflection 16 | from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper 17 | 18 | # Local 19 | from .compat import GeneratedServiceType 20 | from .descriptor_to_file import descriptor_to_file 21 | 22 | 23 | def descriptor_to_message_class( 24 | descriptor: Union[_descriptor.Descriptor, _descriptor.EnumDescriptor], 25 | ) -> Union[Type[_message.Message], EnumTypeWrapper]: 26 | """Create the proto class from the given descriptor 27 | 28 | Args: 29 | descriptor: Union[_descriptor.Descriptor, _descriptor.EnumDescriptor] 30 | The message or enum Descriptor 31 | 32 | Returns: 33 | generated: Union[Type[_message.Message], EnumTypeWrapper] 34 | The generated message class or the enum wrapper 35 | """ 36 | # Handle enum descriptors 37 | if isinstance(descriptor, _descriptor.EnumDescriptor): 38 | message_class = EnumTypeWrapper(descriptor) 39 | 40 | # Handle message descriptors 41 | else: 42 | # Check to see whether this descriptor already has a concrete class. 43 | # NOTE: The MessageFactory already does this for newer versions of 44 | # proto, but in order to maintain compatibility with older versions, 45 | # this is needed here. 46 | try: 47 | message_class = descriptor._concrete_class 48 | except (TypeError, SystemError, AttributeError): 49 | # protobuf version compatibility 50 | if hasattr(reflection.message_factory, "GetMessageClass"): 51 | # Newer protobuf versions use GetMessageClass 52 | message_class = reflection.message_factory.GetMessageClass( 53 | descriptor 54 | ) # pragma: no cover 55 | else: 56 | # Older protobuf versions require creating an instance of a MessageFactory 57 | message_class = ( 58 | reflection.message_factory.MessageFactory().GetPrototype(descriptor) 59 | ) # pragma: no cover 60 | 61 | # Recursively add nested messages 62 | for nested_message_descriptor in descriptor.nested_types: 63 | nested_message_class = descriptor_to_message_class( 64 | nested_message_descriptor 65 | ) 66 | setattr(message_class, nested_message_descriptor.name, nested_message_class) 67 | 68 | # Recursively add nested enums 69 | for nested_enum_descriptor in descriptor.enum_types: 70 | setattr( 71 | message_class, 72 | nested_enum_descriptor.name, 73 | descriptor_to_message_class(nested_enum_descriptor), 74 | ) 75 | 76 | message_class = _add_protobuf_serializers(message_class, descriptor) 77 | return message_class 78 | 79 | 80 | ## Implementation Details ###################################################### 81 | 82 | 83 | def _maybe_classmethod(func: Callable, parent: Any): 84 | """Helper to attach the given function to the parent as either a classmethod 85 | of an instance method 86 | """ 87 | 88 | if isinstance(parent, type): 89 | 90 | @classmethod 91 | @wraps(func) 92 | def _wrapper(cls, *args, **kwargs): 93 | return func(cls, *args, **kwargs) 94 | 95 | else: 96 | 97 | @wraps(func) 98 | def _wrapper(self, *args, **kwargs): 99 | return func(self, *args, **kwargs) 100 | 101 | _wrapper = MethodType(_wrapper, parent) 102 | 103 | setattr(parent, func.__name__, _wrapper) 104 | 105 | 106 | def _add_protobuf_serializers( 107 | type_class: Union[ 108 | Type[_message.Message], EnumTypeWrapper, Type[GeneratedServiceType] 109 | ], 110 | descriptor: Union[ 111 | _descriptor.Descriptor, 112 | _descriptor.EnumDescriptor, 113 | _descriptor.ServiceDescriptor, 114 | ], 115 | ) -> Union[Type[_message.Message], EnumTypeWrapper, Type[GeneratedServiceType]]: 116 | """Helper to add the to_proto_file and write_proto_file to a given type class. 117 | 118 | Args: 119 | descriptor: Union[_descriptor.Descriptor, _descriptor.EnumDescriptor, _descriptor.ServiceDescriptor] 120 | The message or enum Descriptor 121 | type_class: Union[Type[_message.Message], EnumTypeWrapper, Type[GeneratedServiceType]] 122 | 123 | Returns: 124 | Union[Type[_message.Message], EnumTypeWrapper, Type[GeneratedServiceType]] 125 | A new class with the to_proto_file and write_proto_file added 126 | """ 127 | # Add to_proto_file 128 | if not hasattr(type_class, "to_proto_file"): 129 | 130 | def to_proto_file(first_arg) -> str: 131 | f"Create the serialized .proto file content holding all definitions for {descriptor.name}" 132 | return descriptor_to_file(first_arg.DESCRIPTOR) 133 | 134 | _maybe_classmethod(to_proto_file, type_class) 135 | 136 | # Add write_proto_file 137 | if not hasattr(type_class, "write_proto_file"): 138 | 139 | def write_proto_file(first_arg, root_dir: str = "."): 140 | "Write out the proto file to the target directory" 141 | if not os.path.exists(root_dir): 142 | os.makedirs(root_dir) 143 | with open( 144 | os.path.join(root_dir, first_arg.DESCRIPTOR.file.name), "w" 145 | ) as handle: 146 | handle.write(first_arg.to_proto_file()) 147 | 148 | _maybe_classmethod(write_proto_file, type_class) 149 | 150 | return type_class 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PY To Proto 2 | 3 | This library holds utilities for converting in-memory data schema representations to [Protobuf](https://developers.google.com/protocol-buffers). The intent is to allow python libraries to leverage the power of `protobuf` while maintaining the source-of-truth for their data in pure python and avoiding static build steps. 4 | 5 | ## Why? 6 | 7 | The `protobuf` langauge is a powerful tool for defining language-agnostic, composable datastructures. `Protobuf` also offers cross-language compatibility so that a given set of definitions can be compiled into numerous target programming languages. The downside is that `protobuf` requires_a static built step to perform this `proto` -> `X` conversion step. Alternately, there are multiple ways of representing data schemas in pure python which allow a python library to interact with well-typed data objects. The downside here is that these structures can not easily be used from other programming languages. The pros/cons of these generally fall along the following lines: 8 | 9 | - `Protobuf`: 10 | - **Advantages** 11 | - Compact serialization 12 | - Auto-generated [`grpc`](https://grpc.io/) client and service libraries 13 | - Client libraries can be used from different programming languages 14 | - **Disadvantages** 15 | - Learning curve to understand the full ecosystem 16 | - Not a familiar tool outside of service engineering 17 | - Static compilation step required to use in code 18 | - Python schemas: 19 | - **Advantages** 20 | - Can be learned quickly using pure-python documentation 21 | - Can be written inline in pure python 22 | - **Disadvantages** 23 | - Generally, no standard serialization beyond `json` 24 | - No automated service implementations 25 | - No/manual mechanism for usage in other programming languages 26 | 27 | This project aims to bring the advantages of both types of schema representation so that a given project can take advantage of the best of both: 28 | 29 | - Define your structures in pure python for simplicity 30 | - Dynamically create [`google.protobuf.Descriptor`](https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/descriptor.py#L245) objects to allow for `protobuf` serialization and deserialization 31 | - Reverse render a `.proto` file from the generated `Descriptor` so that stubs can be generated in other languages 32 | - No static compiliation needed! 33 | 34 | ## Supported Python Schema Types 35 | 36 | Currently, objects can be declared using either [python `dataclasses`](https://docs.python.org/3/library/dataclasses.html) or [Json TypeDef (JTD)](https://jsontypedef.com/). Additional schemas can be added by [subclassing `ConverterBase`](py_to_proto/converter_base.py). 37 | 38 | ### Dataclass To Proto 39 | 40 | The following example illustrates how `dataclasses` and `enums` can be converted to proto: 41 | 42 | ```py 43 | from dataclasses import dataclass 44 | from enum import Enum 45 | from typing import Annotated, Dict, List, Enum 46 | import py_to_proto 47 | 48 | # Define the Foo structure as a python dataclass, including a nested enum 49 | @dataclass 50 | class Foo: 51 | 52 | class BarEnum(Enum): 53 | EXAM: 0 54 | JOKE_SETTING: 1 55 | 56 | foo: bool 57 | bar: List[BarEnum] 58 | 59 | # Define the Foo protobuf message class 60 | FooProto = py_to_proto.descriptor_to_message_class( 61 | py_to_proto.dataclass_to_proto( 62 | package="foobar", 63 | dataclass_=Foo, 64 | ) 65 | ) 66 | 67 | # Declare the Bar structure as a python dataclass with a reference to the 68 | # FooProto type 69 | @dataclass 70 | class Bar: 71 | baz: FooProto 72 | 73 | # Define the Bar protobuf message class 74 | BarProto = py_to_proto.descriptor_to_message_class( 75 | py_to_proto.dataclass_to_proto( 76 | package="foobar", 77 | dataclass_=Bar, 78 | ) 79 | ) 80 | 81 | # Instantiate a BarProto 82 | print(BarProto(baz=FooProto(foo=True, bar=[Foo.BarEnum.EXAM.value]))) 83 | 84 | def write_protos(proto_dir: str): 85 | """Write out the .proto files for FooProto and BarProto to the given 86 | directory 87 | """ 88 | FooProto.write_proto_file(proto_dir) 89 | BarProto.write_proto_file(proto_dir) 90 | ``` 91 | 92 | ### JTD To Proto 93 | 94 | The following example illustrates how JTD schemas can be converted to proto: 95 | 96 | ```py 97 | import py_to_proto 98 | 99 | # Declare the Foo protobuf message class 100 | Foo = py_to_proto.descriptor_to_message_class( 101 | py_to_proto.jtd_to_proto( 102 | name="Foo", 103 | package="foobar", 104 | jtd_def={ 105 | "properties": { 106 | # Bool field 107 | "foo": { 108 | "type": "boolean", 109 | }, 110 | # Array of nested enum values 111 | "bar": { 112 | "elements": { 113 | "enum": ["EXAM", "JOKE_SETTING"], 114 | } 115 | } 116 | } 117 | }, 118 | ) 119 | ) 120 | 121 | # Declare an object that references Foo as the type for a field 122 | Bar = py_to_proto.descriptor_to_message_class( 123 | py_to_proto.jtd_to_proto( 124 | name="Bar", 125 | package="foobar", 126 | jtd_def={ 127 | "properties": { 128 | "baz": { 129 | "type": Foo.DESCRIPTOR, 130 | }, 131 | }, 132 | }, 133 | ), 134 | ) 135 | 136 | def write_protos(proto_dir: str): 137 | """Write out the .proto files for Foo and Bar to the given directory""" 138 | Foo.write_proto_file(proto_dir) 139 | Bar.write_proto_file(proto_dir) 140 | ``` 141 | 142 | ## Similar Projects 143 | 144 | There are a number of similar projects in this space that offer slightly different value: 145 | 146 | - [`jtd-codegen`](https://jsontypedef.com/docs/jtd-codegen/): This project focuses on statically generating language-native code (including `python`) to represent the JTD schema. 147 | - [`py-json-to-proto`](https://pypi.org/project/py-json-to-proto/): This project aims to deduce a schema from an instance of a `json` object. 148 | - [`pure-protobuf`](https://pypi.org/project/pure-protobuf/): This project has a very similar aim to `py-to-proto`, but it skips the intermediate `descriptor` representation and thus is not able to produce native `message.Message` classes. 149 | -------------------------------------------------------------------------------- /tests/test_descriptor_to_message_class.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for descriptor_to_message_class 3 | """ 4 | 5 | # Standard 6 | import os 7 | import tempfile 8 | 9 | # Third Party 10 | from google.protobuf import message 11 | from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper 12 | 13 | # Local 14 | from .conftest import temp_dpool 15 | from py_to_proto.descriptor_to_message_class import descriptor_to_message_class 16 | from py_to_proto.jtd_to_proto import jtd_to_proto 17 | 18 | 19 | def test_descriptor_to_message_class_generated_descriptor(temp_dpool): 20 | """Make sure that a generated descriptor can be used to create a class""" 21 | descriptor = jtd_to_proto( 22 | "Foo", 23 | "foo.bar", 24 | { 25 | "properties": { 26 | "foo": {"type": "boolean"}, 27 | "bar": {"type": "float32"}, 28 | } 29 | }, 30 | descriptor_pool=temp_dpool, 31 | ) 32 | Foo = descriptor_to_message_class(descriptor) 33 | assert issubclass(Foo, message.Message) 34 | foo = Foo(foo=True, bar=1.234) 35 | assert foo.foo is True 36 | assert foo.bar is not None # NOTE: There are precision errors comparing == 1.234 37 | 38 | # Make sure the class can be serialized 39 | serialized_content = Foo.to_proto_file() 40 | assert "message Foo" in serialized_content 41 | 42 | 43 | def test_descriptor_to_message_class_write_proto_file(temp_dpool): 44 | """Make sure that each message class has write_proto_files attached to it 45 | and that it correctly writes the protobufs to the right named files. 46 | """ 47 | Foo = descriptor_to_message_class( 48 | jtd_to_proto( 49 | name="Foo", 50 | package="foobar", 51 | jtd_def={ 52 | "properties": { 53 | "foo": { 54 | "type": "boolean", 55 | }, 56 | } 57 | }, 58 | descriptor_pool=temp_dpool, 59 | ) 60 | ) 61 | 62 | Bar = descriptor_to_message_class( 63 | jtd_to_proto( 64 | name="Bar", 65 | package="foobar", 66 | jtd_def={ 67 | "properties": { 68 | "bar": { 69 | "type": Foo.DESCRIPTOR, 70 | }, 71 | }, 72 | }, 73 | descriptor_pool=temp_dpool, 74 | ), 75 | ) 76 | 77 | with tempfile.TemporaryDirectory() as workdir: 78 | Foo.write_proto_file(workdir) 79 | Bar.write_proto_file(workdir) 80 | assert set(os.listdir(workdir)) == { 81 | Foo.DESCRIPTOR.file.name, 82 | Bar.DESCRIPTOR.file.name, 83 | } 84 | with open(os.path.join(workdir, Bar.DESCRIPTOR.file.name), "r") as handle: 85 | bar_content = handle.read() 86 | assert f'import "{Foo.DESCRIPTOR.file.name}"' in bar_content 87 | 88 | 89 | def test_descriptor_to_message_class_write_proto_file_no_dir(temp_dpool): 90 | """Make sure that each message class has write_proto_files attached to it 91 | and that it correctly writes the protobufs to the right named files. 92 | Also ensures that the directory gets created if it doesn't exist 93 | """ 94 | Foo = descriptor_to_message_class( 95 | jtd_to_proto( 96 | name="Foo", 97 | package="foobar", 98 | jtd_def={ 99 | "properties": { 100 | "foo": { 101 | "type": "boolean", 102 | }, 103 | } 104 | }, 105 | descriptor_pool=temp_dpool, 106 | ) 107 | ) 108 | 109 | with tempfile.TemporaryDirectory() as workdir: 110 | protos_dir_path = os.path.join(workdir, "protos") 111 | Foo.write_proto_file(protos_dir_path) 112 | assert set(os.listdir(protos_dir_path)) == { 113 | Foo.DESCRIPTOR.file.name, 114 | } 115 | 116 | 117 | def test_descriptor_to_message_class_nested_messages(temp_dpool): 118 | """Make sure that nested messages are wrapped and added to the parents""" 119 | top = descriptor_to_message_class( 120 | jtd_to_proto( 121 | name="Top", 122 | package="foobar", 123 | jtd_def={ 124 | "properties": { 125 | "ghost": { 126 | "properties": { 127 | "boo": { 128 | "type": "string", 129 | } 130 | } 131 | } 132 | } 133 | }, 134 | descriptor_pool=temp_dpool, 135 | ) 136 | ) 137 | assert issubclass(top, message.Message) 138 | assert issubclass(top.Ghost, message.Message) 139 | 140 | 141 | def test_descriptor_to_message_class_nested_enums(temp_dpool): 142 | """Make sure that nested enums are wrapped and added to the parents""" 143 | top = descriptor_to_message_class( 144 | jtd_to_proto( 145 | name="Top", 146 | package="foobar", 147 | jtd_def={ 148 | "properties": { 149 | "bat": { 150 | "enum": ["VAMPIRE", "BASEBALL"], 151 | } 152 | } 153 | }, 154 | descriptor_pool=temp_dpool, 155 | ) 156 | ) 157 | assert issubclass(top, message.Message) 158 | assert isinstance(top.Bat, EnumTypeWrapper) 159 | 160 | 161 | def test_descriptor_to_message_class_top_level_enum(temp_dpool): 162 | """Make sure that a top-level EnumDescriptor results in an EnumTypeWrapper""" 163 | top = descriptor_to_message_class( 164 | jtd_to_proto( 165 | name="Top", 166 | package="foobar", 167 | jtd_def={"enum": ["VAMPIRE", "DRACULA"]}, 168 | descriptor_pool=temp_dpool, 169 | ) 170 | ) 171 | assert isinstance(top, EnumTypeWrapper) 172 | with tempfile.TemporaryDirectory() as workdir: 173 | top.write_proto_file(workdir) 174 | assert os.listdir(workdir) == [top.DESCRIPTOR.file.name] 175 | 176 | 177 | def test_multiple_invocations_of_descriptor_to_message(temp_dpool): 178 | """Ensure that invoking descriptor_to_message_class with the same descriptor 179 | returns the same instance of a class. 180 | """ 181 | descriptor = jtd_to_proto( 182 | "Foo", 183 | "foo.bar", 184 | { 185 | "properties": { 186 | "foo": {"type": "boolean"}, 187 | "bar": {"type": "float32"}, 188 | } 189 | }, 190 | descriptor_pool=temp_dpool, 191 | ) 192 | Foo = descriptor_to_message_class(descriptor) 193 | foo = Foo(foo=True, bar=1.234) 194 | 195 | Bar = descriptor_to_message_class(descriptor) 196 | bar = Bar(foo=True, bar=1.234) 197 | 198 | assert Foo is Bar 199 | assert Foo == Bar 200 | assert id(Foo) == id(Bar) 201 | assert foo == bar 202 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | 👍🎉 First off, thank you for taking the time to contribute! 🎉👍 4 | 5 | The following is a set of guidelines for contributing. These are just guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. 6 | 7 | ## What Should I Know Before I Get Started? 8 | 9 | If you're new to GitHub and working with open source repositories, this section will be helpful. Otherwise, you can skip to learning how to [set up your dev environment](#set-up-your-dev-environment) 10 | 11 | ### Code of Conduct 12 | 13 | This project adheres to the [Contributor Covenant](./CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. 14 | 15 | Please report unacceptable behavior to one of the [Code Owners](./.github/CODEOWNERS). 16 | 17 | ### How Do I Start Contributing? 18 | 19 | The below workflow is designed to help you begin your first contribution journey. It will guide you through creating and picking up issues, working through them, having your work reviewed, and then merging. 20 | 21 | Help on open source projects is always welcome and there is always something that can be improved. For example, documentation (like the text you are reading now) can always use improvement, code can always be clarified, variables or functions can always be renamed or commented on, and there is always a need for more test coverage. If you see something that you think should be fixed, take ownership! Here is how you get started: 22 | 23 | ## How Can I Contribute? 24 | 25 | When contributing, it's useful to start by looking at [issues](https://github.com/IBM/py-to-proto/issues). After picking up an issue, writing code, or updating a document, make a pull request and your work will be reviewed and merged. If you're adding a new feature, it's best to [write an issue](https://github.com/IBM/py-to-proto/issues/new?assignees=&labels=&template=feature_request.md&title=) first to discuss it with maintainers first. 26 | 27 | ### Reporting Bugs 28 | 29 | This section guides you through submitting a bug report. Following these guidelines helps maintainers and the community understand your report ✏️, reproduce the behavior 💻, and find related reports 🔎. 30 | 31 | #### How Do I Submit A (Good) Bug Report? 32 | 33 | Bugs are tracked as [GitHub issues using the Bug Report template](https://github.com/IBM/py-to-proto/issues/new?assignees=&labels=&template=bug_report.md&title=). Create an issue on that and provide the information suggested in the bug report issue template. 34 | 35 | ### Suggesting Enhancements 36 | 37 | This section guides you through submitting an enhancement suggestion, including completely new features, tools, and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion ✏️ and find related suggestions 🔎 38 | 39 | #### How Do I Submit A (Good) Enhancement Suggestion? 40 | 41 | Enhancement suggestions are tracked as [GitHub issues using the Feature Request template](https://github.com/IBM/py-to-proto/issues/new?assignees=&labels=&template=feature_request.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template. 42 | 43 | #### How Do I Submit A (Good) Improvement Item? 44 | 45 | Improvements to existing functionality are tracked as [GitHub issues using the User Story template](https://github.com/IBM/py-to-proto/issues/new?assignees=&labels=&template=user-story.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template. 46 | 47 | ## Development 48 | 49 | ### Set up your dev environments 50 | 51 | #### Using Docker 52 | 53 | The easiest way to get up and running is to use the dockerized development environment which you can launch using: 54 | 55 | ```sh 56 | make develop 57 | ``` 58 | 59 | Within the `develop` shell, any of the `make` targets that do not require `docker` can be run directly. The shell has the local files mounted, so changes to the files on your host machine will be reflected when commands are run in the `develop` shell. 60 | 61 | #### Locally 62 | 63 | You can also develop locally using standard python development practices. You'll need to install the dependencies for the unit tests. It is recommended that you do this in a virtual environment such as [`conda`](https://docs.conda.io/en/latest/miniconda.html) or [`pyenv`](https://github.com/pyenv/pyenv) so that you avoid version conflicts in a shared global dependency set. 64 | 65 | ```sh 66 | pip install -r requirements_test.txt 67 | ``` 68 | 69 | ### Run unit tests 70 | 71 | Running the tests is as simple as: 72 | 73 | ```sh 74 | make test 75 | ``` 76 | 77 | If you want to use the full set of [`pytest` CLI arguments](https://docs.pytest.org/en/6.2.x/usage.html), you can run the `scripts/run_tests.sh` script directly with any arguments added to the command. For example, to run only a single test without capturing output, you can do: 78 | 79 | ```sh 80 | ./scripts/run_tests.sh tests/test_jtd_to_proto.py 81 | ``` 82 | 83 | ### Code formatting 84 | 85 | This project uses [pre-commit](https://pre-commit.com/) to enforce coding style using [black](https://github.com/psf/black). To set up `pre-commit` locally, you can: 86 | 87 | ```sh 88 | pip install pre-commit 89 | ``` 90 | 91 | Coding style is enforced by the CI tests, so if not installed locally, your PR will fail until formatting has been applied. 92 | 93 | ## Your First Code Contribution 94 | 95 | Unsure where to begin contributing? You can start by looking through these issues: 96 | 97 | - Issues with the [`good first issue` label](https://github.com/IBM/py-to-proto/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) - these should only require a few lines of code and are good targets if you're just starting contributing. 98 | - Issues with the [`help wanted` label](https://github.com/IBM/py-to-proto/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) - these range from simple to more complex, but are generally things we want but can't get to in a short time frame. 99 | 100 | ### How to contribute 101 | 102 | To contribute to this repo, you'll use the Fork and Pull model common in many open source repositories. For details on this process, watch [how to contribute](https://egghead.io/courses/how-to-contribute-to-an-open-source-project-on-github). 103 | 104 | When ready, you can create a pull request. Pull requests are often referred to as "PR". In general, we follow the standard [github pull request](https://help.github.com/en/articles/about-pull-requests) process. Follow the template to provide details about your pull request to the maintainers. 105 | 106 | Before sending pull requests, make sure your changes pass tests. 107 | 108 | #### Code Review 109 | 110 | Once you've [created a pull request](#how-to-contribute), maintainers will review your code and likely make suggestions to fix before merging. It will be easier for your pull request to receive reviews if you consider the criteria the reviewers follow while working. Remember to: 111 | 112 | - Run tests locally and ensure they pass 113 | - Follow the project coding conventions 114 | - Write detailed commit messages 115 | - Break large changes into a logical series of smaller patches, which are easy to understand individually and combine to solve a broader issue 116 | 117 | ## Releasing (Maintainers only) 118 | 119 | The responsibility for releasing new versions of the libraries falls to the maintainers. Releases will follow standard [semantic versioning](https://semver.org/) and be hosted on [pypi](https://pypi.org/project/py-to-proto/). 120 | -------------------------------------------------------------------------------- /py_to_proto/jtd_to_proto.py: -------------------------------------------------------------------------------- 1 | # Standard 2 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 3 | 4 | # Third Party 5 | from google.protobuf import any_pb2 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import timestamp_pb2 9 | 10 | # First Party 11 | import alog 12 | 13 | # Local 14 | from .converter_base import ConverterBase 15 | from .validation import is_valid_jtd 16 | 17 | log = alog.use_channel("JTD2P") 18 | 19 | 20 | ## Globals ##################################################################### 21 | 22 | JTD_TO_PROTO_TYPES = { 23 | "any": any_pb2.Any, 24 | "boolean": _descriptor.FieldDescriptor.TYPE_BOOL, 25 | "string": _descriptor.FieldDescriptor.TYPE_STRING, 26 | "timestamp": timestamp_pb2.Timestamp, 27 | "float32": _descriptor.FieldDescriptor.TYPE_FLOAT, 28 | "float64": _descriptor.FieldDescriptor.TYPE_DOUBLE, 29 | # NOTE: All number types except fixed, double, and float are stored as 30 | # varints meaning as long as your numbers stay in about the int8 or int16 31 | # range, they are only 1 or 2 bytes, even though it says int32. 32 | # 33 | # CITE: https://groups.google.com/g/protobuf/c/Er39mNGnRWU/m/x6Srz_GrZPgJ 34 | "int8": _descriptor.FieldDescriptor.TYPE_INT32, 35 | "uint8": _descriptor.FieldDescriptor.TYPE_UINT32, 36 | "int16": _descriptor.FieldDescriptor.TYPE_INT32, 37 | "uint16": _descriptor.FieldDescriptor.TYPE_UINT32, 38 | "int32": _descriptor.FieldDescriptor.TYPE_INT32, 39 | "uint32": _descriptor.FieldDescriptor.TYPE_UINT32, 40 | "int64": _descriptor.FieldDescriptor.TYPE_INT64, 41 | "uint64": _descriptor.FieldDescriptor.TYPE_UINT64, 42 | # Not strictly part of the JTD spec, but important for protobuf messages 43 | "bytes": _descriptor.FieldDescriptor.TYPE_BYTES, 44 | } 45 | 46 | # Common type used everywhere for a JTD dict 47 | _JtdDefType = Dict[str, Union[dict, str]] 48 | 49 | 50 | ## Interface ################################################################### 51 | 52 | 53 | def jtd_to_proto( 54 | name: str, 55 | package: str, 56 | jtd_def: _JtdDefType, 57 | *, 58 | validate_jtd: bool = False, 59 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] = None, 60 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 61 | ) -> _descriptor.Descriptor: 62 | """Convert a JTD schema into a set of proto DESCRIPTOR objects. 63 | 64 | Reference: https://jsontypedef.com/docs/jtd-in-5-minutes/ 65 | 66 | Args: 67 | name: str 68 | The name for the top-level message object 69 | package: str 70 | The proto package name to use for this object 71 | jtd_def: Dict[str, Union[dict, str]] 72 | The full JTD schema dict 73 | 74 | Kwargs: 75 | validate_jtd: bool 76 | Whether or not to validate the JTD schema 77 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] 78 | A non-default mapping from JTD type names to proto types 79 | descriptor_pool: Optional[descriptor_pool.DescriptorPool] 80 | If given, this DescriptorPool will be used to aggregate the set of 81 | message descriptors 82 | 83 | Returns: 84 | descriptor: descriptor.Descriptor 85 | The top-level MessageDescriptor corresponding to this jtd definition 86 | """ 87 | return JTDConverter( 88 | name=name, 89 | package=package, 90 | jtd_def=jtd_def, 91 | validate=validate_jtd, 92 | type_mapping=type_mapping, 93 | descriptor_pool=descriptor_pool, 94 | ).descriptor 95 | 96 | 97 | ## Impl ######################################################################## 98 | 99 | 100 | class JTDConverter(ConverterBase): 101 | """Converter implementation for JTD source schemas""" 102 | 103 | def __init__( 104 | self, 105 | name: str, 106 | package: str, 107 | jtd_def: _JtdDefType, 108 | *, 109 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] = None, 110 | validate: bool = False, 111 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 112 | ): 113 | """Fill in the default type mapping and additional default vals, then 114 | initialize the parent 115 | """ 116 | type_mapping = type_mapping or JTD_TO_PROTO_TYPES 117 | super().__init__( 118 | name=name, 119 | package=package, 120 | source_schema=jtd_def, 121 | type_mapping=type_mapping, 122 | validate=validate, 123 | descriptor_pool=descriptor_pool, 124 | ) 125 | 126 | ## Abstract Interface ###################################################### 127 | 128 | def validate(self, source_schema: _JtdDefType) -> bool: 129 | """Perform preprocess validation of the input""" 130 | log.debug2("Validating JTD") 131 | valid_types = self.type_mapping.keys() 132 | return is_valid_jtd(source_schema, valid_types=valid_types) 133 | 134 | ## Types ## 135 | 136 | def get_concrete_type(self, entry: _JtdDefType) -> Any: 137 | """If this is a concrete type, get the JTD key for it""" 138 | return entry.get("type") 139 | 140 | ## Maps ## 141 | 142 | def get_map_key_val_types( 143 | self, 144 | entry: _JtdDefType, 145 | ) -> Optional[Tuple[int, ConverterBase.ConvertOutputTypes]]: 146 | """Get the key and value types for a given map type""" 147 | values = entry.get("values") 148 | if values is not None: 149 | string_type = self.type_mapping.get("string") 150 | if string_type is None: 151 | raise ValueError( 152 | "Provided type mapping has no key for 'string', so values maps cannot be used" 153 | ) 154 | val_type = self._convert(entry=values, name="value") 155 | return (string_type, val_type) 156 | 157 | ## Enums ## 158 | 159 | def get_enum_vals(self, entry: _JtdDefType) -> Optional[List[Tuple[str, int]]]: 160 | """Get the ordered list of enum name -> number mappings if this entry is 161 | an enum 162 | 163 | NOTE: If any values appear multiple times, this implies an alias 164 | 165 | NOTE 2: All names must be unique 166 | """ 167 | enum = entry.get("enum") 168 | if enum is not None: 169 | return [ 170 | (entry_name, entry_idx) for entry_idx, entry_name in enumerate(enum) 171 | ] 172 | 173 | ## Messages ## 174 | 175 | def get_message_fields( 176 | self, 177 | entry: _JtdDefType, 178 | ) -> Optional[Iterable[Tuple[str, Any]]]: 179 | """Get the mapping of names to type-specific field descriptors""" 180 | properties = entry.get("properties", {}) 181 | optional_properties = entry.get("optionalProperties", {}) 182 | all_properties = {**properties, **optional_properties} 183 | if all_properties: 184 | return all_properties.items() 185 | 186 | def has_additional_fields(self, entry: _JtdDefType) -> bool: 187 | """Check whether the given entry expects to support arbitrary key/val 188 | additional properties 189 | """ 190 | return entry.get("additionalProperties", False) 191 | 192 | def get_optional_field_names(self, entry: _JtdDefType) -> List[str]: 193 | """Get the names of any fields which are explicitly marked 'optional'""" 194 | return entry.get("optionalProperties", {}).keys() 195 | 196 | ## Fields ## 197 | 198 | def get_field_number(self, num_fields: int, field_def: _JtdDefType) -> int: 199 | """If the field has a metadata field "field_number" use that, otherwise, 200 | use the next field number sequentially 201 | """ 202 | return field_def.get("metadata", {}).get("field_number", num_fields + 1) 203 | 204 | def get_oneof_fields( 205 | self, field_def: _JtdDefType 206 | ) -> Optional[Iterable[Tuple[str, Any]]]: 207 | """If the given field def is a discriminator, it's a oneof""" 208 | discriminator = field_def.get("discriminator") 209 | if discriminator is not None: 210 | mapping = field_def.get("mapping") 211 | assert isinstance(mapping, dict), "Invalid discriminator without mapping" 212 | return mapping.items() 213 | 214 | def get_oneof_name(self, field_def: _JtdDefType) -> str: 215 | """For an identified oneof field def, get the name""" 216 | return field_def.get("discriminator") 217 | 218 | def get_field_type(self, field_def: _JtdDefType) -> Any: 219 | """Get the type of the field. The definition of type here will be 220 | specific to the converter (e.g. string for JTD, py type for dataclass) 221 | """ 222 | elements = field_def.get("elements") 223 | if elements is not None: 224 | return elements 225 | return field_def 226 | 227 | def is_repeated_field(self, field_def: _JtdDefType) -> bool: 228 | """Determine if the given field def is repeated""" 229 | return "elements" in field_def 230 | -------------------------------------------------------------------------------- /py_to_proto/descriptor_to_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements serialization of an in-memory Descriptor to a portable 3 | .proto file 4 | """ 5 | 6 | # Standard 7 | from typing import List, Optional, Union 8 | 9 | # Third Party 10 | from google.protobuf import descriptor as _descriptor 11 | from google.protobuf import descriptor_pb2 12 | 13 | ## Globals ##################################################################### 14 | 15 | 16 | PROTO_FILE_PRIMITIVE_TYPE_NAMES = { 17 | type_val: type_name[5:].lower() 18 | for type_name, type_val in vars(_descriptor.FieldDescriptor).items() 19 | if type_name.startswith("TYPE_") 20 | } 21 | 22 | PROTO_FILE_INDENT = " " 23 | 24 | PROTO_FILE_AUTOGEN_HEADER = """ 25 | /*------------------------------------------------------------------------------ 26 | * AUTO GENERATED 27 | *----------------------------------------------------------------------------*/ 28 | """ 29 | 30 | PROTO_FILE_ENUM_HEADER = """ 31 | /*-- ENUMS -------------------------------------------------------------------*/ 32 | """ 33 | 34 | PROTO_FILE_MESSAGE_HEADER = """ 35 | /*-- MESSAGES ----------------------------------------------------------------*/ 36 | """ 37 | 38 | PROTO_FILE_SERVICES_HEADER = """ 39 | /*-- SERVICES ----------------------------------------------------------------*/ 40 | """ 41 | 42 | PROTO_FILE_NESTED_ENUM_HEADER = f"{PROTO_FILE_INDENT}/*-- nested enums --*/" 43 | PROTO_FILE_NESTED_MESSAGE_HEADER = f"{PROTO_FILE_INDENT}/*-- nested messages --*/" 44 | PROTO_FILE_FIELD_HEADER = f"{PROTO_FILE_INDENT}/*-- fields --*/" 45 | PROTO_FILE_ONEOF_HEADER = f"{PROTO_FILE_INDENT}/*-- oneofs --*/" 46 | 47 | 48 | ## Interface ################################################################### 49 | 50 | 51 | def descriptor_to_file( 52 | descriptor: Union[ 53 | _descriptor.FileDescriptor, 54 | _descriptor.Descriptor, 55 | _descriptor.ServiceDescriptor, 56 | ], 57 | ) -> str: 58 | """Serialize a .proto file from a FileDescriptor 59 | 60 | Args: 61 | descriptor: Union[descriptor.FileDescriptor, descriptor.MessageDescriptor] 62 | The file or message descriptor to serialize 63 | 64 | Returns: 65 | proto_file_content: str 66 | The serialized file content for the .proto file 67 | """ 68 | 69 | # If this is a message descriptor, use its corresponding FileDescriptor 70 | if isinstance( 71 | descriptor, 72 | ( 73 | _descriptor.Descriptor, 74 | _descriptor.EnumDescriptor, 75 | _descriptor.ServiceDescriptor, 76 | ), 77 | ): 78 | descriptor = descriptor.file 79 | if not isinstance(descriptor, _descriptor.FileDescriptor): 80 | raise ValueError(f"Invalid file descriptor of type {type(descriptor)}") 81 | proto_file_lines = [] 82 | 83 | # Create the header 84 | proto_file_lines.append(PROTO_FILE_AUTOGEN_HEADER) 85 | 86 | # Add package, syntax, and imports 87 | syntax = getattr(descriptor, "syntax", "proto3") 88 | proto_file_lines.append(f'syntax = "{syntax}";') 89 | if descriptor.package: 90 | proto_file_lines.append(f"package {descriptor.package};") 91 | for dep in descriptor.dependencies: 92 | proto_file_lines.append(f'import "{dep.name}";') 93 | proto_file_lines.append("") 94 | 95 | # Add all enums 96 | if descriptor.enum_types_by_name: 97 | proto_file_lines.append(PROTO_FILE_ENUM_HEADER) 98 | for enum_descriptor in descriptor.enum_types_by_name.values(): 99 | proto_file_lines.extend(_enum_descriptor_to_file(enum_descriptor)) 100 | proto_file_lines.append("") 101 | 102 | # Add all messages 103 | if descriptor.message_types_by_name: 104 | proto_file_lines.append(PROTO_FILE_MESSAGE_HEADER) 105 | for message_descriptor in descriptor.message_types_by_name.values(): 106 | proto_file_lines.extend(_message_descriptor_to_file(message_descriptor)) 107 | proto_file_lines.append("") 108 | 109 | if descriptor.services_by_name: 110 | proto_file_lines.append(PROTO_FILE_SERVICES_HEADER) 111 | for service_descriptor in descriptor.services_by_name.values(): 112 | proto_file_lines.extend(_service_descriptor_to_file(service_descriptor)) 113 | proto_file_lines.append("") 114 | 115 | return "\n".join(proto_file_lines) 116 | 117 | 118 | ## Impl ######################################################################## 119 | 120 | 121 | def _indent_lines(indent: int, lines: List[str]) -> List[str]: 122 | """Add indentation to the given lines""" 123 | if not indent: 124 | return lines 125 | return [ 126 | indent * PROTO_FILE_INDENT + line if line else line 127 | for line in "\n".join(lines).split("\n") 128 | ] 129 | 130 | 131 | def _enum_descriptor_to_file( 132 | enum_descriptor: _descriptor.EnumDescriptor, 133 | indent: int = 0, 134 | ) -> List[str]: 135 | """Make the string representation of an enum""" 136 | lines = [] 137 | lines.append(f"enum {enum_descriptor.name} {{") 138 | for val in enum_descriptor.values: 139 | lines.append(f"{PROTO_FILE_INDENT}{val.name} = {val.number};") 140 | lines.append("}") 141 | return _indent_lines(indent, lines) 142 | 143 | 144 | def _message_descriptor_to_file( 145 | message_descriptor: _descriptor.Descriptor, 146 | indent: int = 0, 147 | ) -> List[str]: 148 | """Make the string representation of an enum""" 149 | lines = [] 150 | lines.append(f"message {message_descriptor.name} {{") 151 | 152 | # Add nested enums 153 | if message_descriptor.enum_types: 154 | lines.append("") 155 | lines.append(PROTO_FILE_NESTED_ENUM_HEADER) 156 | for enum_descriptor in message_descriptor.enum_types: 157 | lines.extend(_enum_descriptor_to_file(enum_descriptor, indent=1)) 158 | 159 | # Add nested messages 160 | if message_descriptor.nested_types: 161 | lines.append("") 162 | lines.append(PROTO_FILE_NESTED_MESSAGE_HEADER) 163 | for nested_msg_descriptor in message_descriptor.nested_types: 164 | if _is_map_entry(nested_msg_descriptor): 165 | continue 166 | lines.extend(_message_descriptor_to_file(nested_msg_descriptor, indent=1)) 167 | 168 | # Add fields 169 | if message_descriptor.fields: 170 | lines.append("") 171 | lines.append(PROTO_FILE_FIELD_HEADER) 172 | for field_descriptor in message_descriptor.fields: 173 | # If the field is part of a oneof, defer it until adding oneofs 174 | # Unless the oneof is internal bookkeeping for an optional field 175 | if field_descriptor.containing_oneof and not _is_optional_field_oneof( 176 | field_descriptor.containing_oneof 177 | ): 178 | continue 179 | lines.extend(_field_descriptor_to_file(field_descriptor, indent=1)) 180 | 181 | # Add oneofs 182 | oneofs = ( 183 | [ 184 | oneof 185 | for oneof in message_descriptor.oneofs 186 | if not _is_optional_field_oneof(oneof) 187 | ] 188 | if message_descriptor.oneofs 189 | else [] 190 | ) 191 | if oneofs: 192 | lines.append("") 193 | lines.append(PROTO_FILE_ONEOF_HEADER) 194 | for oneof_descriptor in oneofs: 195 | lines.extend(_oneof_descriptor_to_file(oneof_descriptor, indent=1)) 196 | 197 | lines.append("}") 198 | return _indent_lines(indent, lines) 199 | 200 | 201 | def _service_descriptor_to_file( 202 | service_descriptor: _descriptor.ServiceDescriptor, 203 | indent: int = 0, 204 | ) -> List[str]: 205 | """Make the string representation of a service""" 206 | lines = [] 207 | lines.append(f"service {service_descriptor.name} {{") 208 | for method in service_descriptor.methods: 209 | # The MethodDescriptor protobuf representation holds fields to represent 210 | # server and client streaming, but these are not exposed in the python 211 | # class that wraps a MethodDescriptor. They are, however, held in the 212 | # underlying C implementation in upb, so the information is retained but 213 | # is only accessible when re-serializing the python object to a proto 214 | # representation of the descriptor. 215 | md_proto = descriptor_pb2.MethodDescriptorProto() 216 | method.CopyToProto(md_proto) 217 | client_streaming = "stream " if md_proto.client_streaming else "" 218 | server_streaming = "stream " if md_proto.server_streaming else "" 219 | 220 | lines.append( 221 | "{}rpc {}({}{}) returns ({}{});".format( 222 | PROTO_FILE_INDENT, 223 | method.name, 224 | client_streaming, 225 | method.input_type.full_name, 226 | server_streaming, 227 | method.output_type.full_name, 228 | ) 229 | ) 230 | lines.append("}") 231 | return _indent_lines(indent, lines) 232 | 233 | 234 | def _field_descriptor_to_file( 235 | field_descriptor: _descriptor.FieldDescriptor, 236 | indent: int = 0, 237 | ) -> List[str]: 238 | """Get the string version of a field""" 239 | 240 | # Add the repeated qualifier if needed 241 | field_line = "" 242 | if ( 243 | not _is_map_entry(field_descriptor.message_type) 244 | and field_descriptor.label == field_descriptor.LABEL_REPEATED 245 | ): 246 | field_line += "repeated " 247 | 248 | # Add the optional qualifier if needed 249 | if _is_optional_field_oneof(field_descriptor.containing_oneof): 250 | field_line += "optional " 251 | 252 | # Add the type 253 | field_line += _get_field_type_str(field_descriptor) 254 | 255 | # Add the name and number 256 | field_line += f" {field_descriptor.name} = {field_descriptor.number};" 257 | return _indent_lines(indent, [field_line]) 258 | 259 | 260 | def _oneof_descriptor_to_file( 261 | oneof_descriptor: _descriptor.OneofDescriptor, 262 | indent: int = 0, 263 | ) -> List[str]: 264 | """Get the string version of a oneof""" 265 | lines = [] 266 | lines.append(f"oneof {oneof_descriptor.name} {{") 267 | for field_descriptor in oneof_descriptor.fields: 268 | lines.extend(_field_descriptor_to_file(field_descriptor, indent=1)) 269 | lines.append("}") 270 | return _indent_lines(indent, lines) 271 | 272 | 273 | def _get_field_type_str(field_descriptor: _descriptor.FieldDescriptor) -> str: 274 | """Get the string version of a field's type""" 275 | 276 | # Add the type 277 | if field_descriptor.type == field_descriptor.TYPE_MESSAGE: 278 | if _is_map_entry(field_descriptor.message_type): 279 | key_type = _get_field_type_str( 280 | field_descriptor.message_type.fields_by_name["key"] 281 | ) 282 | val_type = _get_field_type_str( 283 | field_descriptor.message_type.fields_by_name["value"] 284 | ) 285 | return f"map<{key_type}, {val_type}>" 286 | else: 287 | return field_descriptor.message_type.full_name 288 | elif field_descriptor.type == field_descriptor.TYPE_ENUM: 289 | return field_descriptor.enum_type.full_name 290 | else: 291 | return PROTO_FILE_PRIMITIVE_TYPE_NAMES[field_descriptor.type] 292 | 293 | 294 | def _is_map_entry(message_descriptor: _descriptor.Descriptor) -> bool: 295 | """Check whether this message is a map entry""" 296 | return message_descriptor is not None and getattr( 297 | message_descriptor.GetOptions(), "map_entry", False 298 | ) 299 | 300 | 301 | def _is_optional_field_oneof(oneof_descriptor: Optional[_descriptor.OneofDescriptor]): 302 | """Check whether the oneof is an internal detail for dealing with an optional 303 | field, rather than an explicit oneof in the message description""" 304 | return ( 305 | oneof_descriptor 306 | and len(oneof_descriptor.fields) == 1 307 | and oneof_descriptor.name.startswith("_") 308 | ) 309 | -------------------------------------------------------------------------------- /tests/test_descriptor_to_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for descriptor_to_file 3 | """ 4 | 5 | # Standard 6 | from types import ModuleType 7 | from typing import Dict, List, Optional 8 | import importlib 9 | import os 10 | import random 11 | import shlex 12 | import string 13 | import subprocess 14 | import sys 15 | import tempfile 16 | 17 | # Third Party 18 | import pytest 19 | 20 | # First Party 21 | import alog 22 | 23 | # Local 24 | from .conftest import temp_dpool 25 | from py_to_proto.descriptor_to_file import descriptor_to_file 26 | from py_to_proto.json_to_service import json_to_service 27 | from py_to_proto.jtd_to_proto import jtd_to_proto 28 | 29 | log = alog.use_channel("TEST") 30 | 31 | ## Helpers ##################################################################### 32 | 33 | sample_jtd_def = jtd_def = { 34 | "properties": { 35 | # bool field 36 | "foo": { 37 | "type": "boolean", 38 | }, 39 | # Array of strings 40 | "bar": { 41 | "elements": { 42 | "type": "string", 43 | } 44 | }, 45 | # Nested Object 46 | "buz": { 47 | "properties": { 48 | "bee": { 49 | "type": "boolean", 50 | } 51 | }, 52 | # Arbitrary map 53 | "additionalProperties": True, 54 | }, 55 | # timestamp field 56 | "time": { 57 | "type": "timestamp", 58 | }, 59 | # Array of objects 60 | "baz": { 61 | "elements": { 62 | "properties": { 63 | "nested": { 64 | "type": "int8", 65 | } 66 | } 67 | } 68 | }, 69 | # Enum 70 | "bat": { 71 | "enum": ["VAMPIRE", "DRACULA"], 72 | }, 73 | # Array of enums 74 | "bif": { 75 | "elements": { 76 | "enum": ["NAME", "SOUND_EFFECT"], 77 | } 78 | }, 79 | # Typed dict with primitive values 80 | "biz": { 81 | "values": { 82 | "type": "float32", 83 | } 84 | }, 85 | # Dict with message values 86 | "bonk": { 87 | "values": { 88 | "properties": { 89 | "how_hard": {"type": "float32"}, 90 | } 91 | } 92 | }, 93 | # Dict with enum values 94 | "bang": { 95 | "values": { 96 | "enum": ["BLAM", "KAPOW"], 97 | } 98 | }, 99 | # Descriminator (oneof) 100 | "bit": { 101 | "discriminator": "bitType", 102 | "mapping": { 103 | "SCREW_DRIVER": { 104 | "properties": { 105 | "isPhillips": {"type": "boolean"}, 106 | } 107 | }, 108 | "DRILL": { 109 | "properties": { 110 | "size": {"type": "float32"}, 111 | } 112 | }, 113 | }, 114 | }, 115 | }, 116 | # optionalProperties are also handled as properties 117 | "optionalProperties": { 118 | # Optional primitive 119 | "optionalString": { 120 | "type": "string", 121 | }, 122 | # Optional array 123 | "optionalList": { 124 | "elements": { 125 | "type": "string", 126 | } 127 | }, 128 | }, 129 | } 130 | 131 | 132 | def compile_proto_module( 133 | proto_content: str, imported_file_contents: Dict[str, str] = None 134 | ) -> Optional[ModuleType]: 135 | """Compile the proto file content locally""" 136 | with tempfile.TemporaryDirectory() as dirname: 137 | mod_name = "{}_temp".format( 138 | "".join([random.choice(string.ascii_lowercase) for _ in range(8)]) 139 | ) 140 | 141 | fname = os.path.join(dirname, f"{mod_name}.proto") 142 | with open(fname, "w") as handle: 143 | handle.write(proto_content) 144 | 145 | # Write out any files that need to be imported 146 | if imported_file_contents: 147 | for file_name, file_content in imported_file_contents.items(): 148 | file_path = os.path.join(dirname, file_name) 149 | with open(file_path, "w") as handle: 150 | handle.write(file_content) 151 | 152 | proto_files_to_compile = " ".join(os.listdir(dirname)) 153 | 154 | proc = subprocess.Popen( 155 | shlex.split( 156 | f"{sys.executable} -m grpc_tools.protoc -I '{dirname}' --python_out {dirname} {proto_files_to_compile}" 157 | ), 158 | stdout=subprocess.PIPE, 159 | stderr=subprocess.PIPE, 160 | ) 161 | stdout, stderr = proc.communicate() 162 | log.debug("Std Out--------\n%s", stdout) 163 | log.debug("Std Err--------\n%s", stderr) 164 | if proc.returncode != 0: 165 | return 166 | 167 | # Put this dir on the sys.path and load the module 168 | sys.path.append(dirname) 169 | 170 | mod = importlib.import_module(f"{mod_name}_pb2") 171 | sys.path.pop() 172 | return mod 173 | 174 | 175 | ## Tests ####################################################################### 176 | 177 | 178 | def test_descriptor_to_file_compilable_proto(temp_dpool): 179 | """Make sure that the generated protobuf can be compiled""" 180 | assert compile_proto_module( 181 | descriptor_to_file( 182 | jtd_to_proto( 183 | "Widgets", 184 | "foo.bar.baz.bat", 185 | sample_jtd_def, 186 | descriptor_pool=temp_dpool, 187 | validate_jtd=True, 188 | ) 189 | ) 190 | ) 191 | 192 | 193 | def test_descriptor_to_file_non_generated_proto(): 194 | """Make sure that a descriptor for an object generated with protoc can be 195 | serialized 196 | """ 197 | # Make a "standard" protobuf module 198 | temp_pb2 = compile_proto_module( 199 | """ 200 | syntax = "proto3"; 201 | package foo.bar.baz.biz; 202 | 203 | enum FooEnum { 204 | FOO = 0; 205 | BAR = 1; 206 | } 207 | 208 | message MsgWithMap { 209 | map the_map = 1; 210 | } 211 | 212 | message MsgWithOneof { 213 | oneof test_oneof { 214 | string str_version = 1; 215 | MsgWithMap msg_version = 2; 216 | } 217 | } 218 | """ 219 | ) 220 | assert temp_pb2 221 | 222 | # Try to serialize from the file descriptor 223 | auto_gen_content = descriptor_to_file(temp_pb2.DESCRIPTOR) 224 | assert "enum FooEnum" in auto_gen_content 225 | assert "message MsgWithMap" in auto_gen_content 226 | assert "message MsgWithOneof" in auto_gen_content 227 | 228 | # Serialize from one of the messages 229 | # NOTE: This just de-aliases to the file, so the generated content will hold 230 | # all of the messages 231 | auto_gen_content = descriptor_to_file(temp_pb2.MsgWithMap.DESCRIPTOR) 232 | assert "enum FooEnum" in auto_gen_content 233 | assert "message MsgWithMap" in auto_gen_content 234 | assert "message MsgWithOneof" in auto_gen_content 235 | 236 | 237 | def test_descriptor_to_file_invalid_descriptor_arg(): 238 | """Make sure an error is raised if the argument is not a valid descriptor""" 239 | with pytest.raises(ValueError): 240 | descriptor_to_file({"foo": "bar"}) 241 | 242 | 243 | def test_descriptor_to_file_enum_descriptor(temp_dpool): 244 | """Make sure descriptor_to_file can be called on a EnumDescriptor""" 245 | enum_descriptor = jtd_to_proto( 246 | "Foo", 247 | "foo.bar", 248 | {"enum": ["FOO", "BAR"]}, 249 | descriptor_pool=temp_dpool, 250 | ) 251 | res = descriptor_to_file(enum_descriptor) 252 | assert "enum Foo {" in res 253 | 254 | 255 | def test_descriptor_to_file_optional_properties(temp_dpool): 256 | """Make sure descriptor_to_file sticks `optional` in front of optional fields""" 257 | raw_protobuf = descriptor_to_file( 258 | jtd_to_proto( 259 | "Widgets", 260 | "foo.bar.baz.bat", 261 | sample_jtd_def, 262 | descriptor_pool=temp_dpool, 263 | validate_jtd=True, 264 | ) 265 | ) 266 | raw_protobuf_lines = raw_protobuf.splitlines() 267 | # Non-array things in `optionalProperties` should have `optional` 268 | assert any( 269 | "optional string optionalString" in line for line in raw_protobuf_lines 270 | ), f"optionalString not in {raw_protobuf}" 271 | # But fields cannot be both `repeated` and `optional` 272 | assert any( 273 | "repeated string optionalList" in line for line in raw_protobuf_lines 274 | ), f"optionalList broken in {raw_protobuf}" 275 | # Additionally, check that the internal oneof was not rendered 276 | assert "_optionalString" not in raw_protobuf 277 | 278 | 279 | def test_descriptor_to_file_service_descriptor(temp_dpool): 280 | """Make sure descriptor_to_file can be called on a ServiceDescriptor""" 281 | foo_message_descriptor = jtd_to_proto( 282 | name="Foo", 283 | package="foo.bar", 284 | jtd_def={ 285 | "properties": { 286 | "foo": {"type": "boolean"}, 287 | "bar": {"type": "float32"}, 288 | } 289 | }, 290 | descriptor_pool=temp_dpool, 291 | ) 292 | service_descriptor = json_to_service( 293 | name="FooService", 294 | package="foo.bar", 295 | json_service_def={ 296 | "service": { 297 | "rpcs": [ 298 | { 299 | "name": "FooPredictUnaryUnary", 300 | "input_type": "foo.bar.Foo", 301 | "output_type": "foo.bar.Foo", 302 | }, 303 | { 304 | "name": "FooPredictUnaryStream", 305 | "input_type": "foo.bar.Foo", 306 | "output_type": "foo.bar.Foo", 307 | "server_streaming": True, 308 | }, 309 | { 310 | "name": "FooPredictStreamUnary", 311 | "input_type": "foo.bar.Foo", 312 | "output_type": "foo.bar.Foo", 313 | "client_streaming": True, 314 | }, 315 | { 316 | "name": "FooPredictStreamStream", 317 | "input_type": "foo.bar.Foo", 318 | "output_type": "foo.bar.Foo", 319 | "client_streaming": True, 320 | "server_streaming": True, 321 | }, 322 | ] 323 | } 324 | }, 325 | descriptor_pool=temp_dpool, 326 | ).descriptor 327 | # TODO: type annotation fixup 328 | res = descriptor_to_file(service_descriptor) 329 | assert "service FooService {" in res 330 | assert "rpc FooPredictUnaryUnary(foo.bar.Foo) returns (foo.bar.Foo)" in res 331 | assert "rpc FooPredictUnaryStream(foo.bar.Foo) returns (stream foo.bar.Foo)" in res 332 | assert "rpc FooPredictStreamUnary(stream foo.bar.Foo) returns (foo.bar.Foo)" in res 333 | assert ( 334 | "rpc FooPredictStreamStream(stream foo.bar.Foo) returns (stream foo.bar.Foo)" 335 | in res 336 | ) 337 | 338 | 339 | def test_descriptor_to_file_compilable_proto_with_service_descriptor(temp_dpool): 340 | """Make sure descriptor_to_file can be called on a ServiceDescriptor""" 341 | 342 | random_message_name = "".join( 343 | [random.choice(string.ascii_lowercase) for _ in range(8)] 344 | ) 345 | # 🌶️🌶️🌶️ The message names must be capitalized to work 346 | random_message_name = random_message_name.capitalize() 347 | 348 | foo_message_descriptor = jtd_to_proto( 349 | name=f"{random_message_name}", 350 | package="foo.bar", 351 | jtd_def={ 352 | "properties": { 353 | "foo": {"type": "boolean"}, 354 | "bar": {"type": "float32"}, 355 | } 356 | }, 357 | descriptor_pool=temp_dpool, 358 | ) 359 | message_descriptor_file = descriptor_to_file(foo_message_descriptor) 360 | imported_files = {foo_message_descriptor.file.name: message_descriptor_file} 361 | service_descriptor = json_to_service( 362 | name=f"{random_message_name}Service", 363 | package="foo.bar", 364 | json_service_def={ 365 | "service": { 366 | "rpcs": [ 367 | { 368 | "name": "FooPredict", 369 | "input_type": f"foo.bar.{random_message_name}", 370 | "output_type": f"foo.bar.{random_message_name}", 371 | } 372 | ] 373 | } 374 | }, 375 | descriptor_pool=temp_dpool, 376 | ).descriptor 377 | res = descriptor_to_file(service_descriptor) 378 | assert compile_proto_module(res, imported_file_contents=imported_files) 379 | -------------------------------------------------------------------------------- /py_to_proto/json_to_service.py: -------------------------------------------------------------------------------- 1 | # Standard 2 | from typing import Callable, Dict, List, Optional, Type 3 | import dataclasses 4 | import types 5 | 6 | # Third Party 7 | from google.protobuf import descriptor_pb2 8 | from google.protobuf import descriptor_pool as _descriptor_pool 9 | from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor 10 | from google.protobuf.service_reflection import GeneratedServiceType 11 | import grpc 12 | 13 | # First Party 14 | import alog 15 | 16 | # Local 17 | from .compat import GeneratedServiceType, make_service_class 18 | from .descriptor_to_message_class import ( 19 | _add_protobuf_serializers, 20 | descriptor_to_message_class, 21 | ) 22 | from .utils import safe_add_fd_to_pool 23 | from .validation import JTD_TYPE_VALIDATORS, validate_jtd 24 | 25 | log = alog.use_channel("JSON2S") 26 | 27 | SERVICE_JTD_SCHEMA = { 28 | "properties": { 29 | "service": { 30 | "properties": { 31 | "rpcs": { 32 | "elements": { 33 | "properties": { 34 | "input_type": {"type": "string"}, 35 | "name": {"type": "string"}, 36 | "output_type": {"type": "string"}, 37 | }, 38 | "optionalProperties": { 39 | "server_streaming": {"type": "boolean"}, 40 | "client_streaming": {"type": "boolean"}, 41 | }, 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | EXTENDED_TYPE_VALIDATORS = dict( 50 | bytes=lambda x: isinstance(x, bytes), **JTD_TYPE_VALIDATORS 51 | ) 52 | 53 | # Python type hint equivalent of jtd service schema 54 | ServiceJsonType = Dict[str, Dict[str, List[Dict[str, str]]]] 55 | 56 | 57 | @dataclasses.dataclass 58 | class GRPCService: 59 | descriptor: ServiceDescriptor 60 | registration_function: Callable[[GeneratedServiceType, grpc.Server], None] 61 | client_stub_class: Type 62 | service_class: Type[GeneratedServiceType] 63 | 64 | 65 | def json_to_service( 66 | name: str, 67 | package: str, 68 | json_service_def: ServiceJsonType, 69 | *, 70 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 71 | ) -> GRPCService: 72 | """Convert a JSON representation of an RPC service into a GRPCService. 73 | 74 | Reference: https://jsontypedef.com/docs/jtd-in-5-minutes/ 75 | 76 | Args: 77 | name: str 78 | The name for the top-level service object 79 | package: str 80 | The proto package name to use for this service 81 | json_service_def: Dict[str, Union[dict, str]] 82 | A JSON dict describing a service that matches the SERVICE_JTD_SCHEMA 83 | 84 | Kwargs: 85 | descriptor_pool: Optional[descriptor_pool.DescriptorPool] 86 | If given, this DescriptorPool will be used to aggregate the set of 87 | message descriptors 88 | 89 | Returns: 90 | grpc_service: GRPCService 91 | The GRPCService container with the service descriptor and other associated 92 | grpc bits required to boot a server: 93 | - Servicer registration function 94 | - Client stub class 95 | - Servicer base class 96 | """ 97 | # Ensure we have a valid service spec 98 | log.debug2("Validating service json") 99 | if not validate_jtd(json_service_def, SERVICE_JTD_SCHEMA, EXTENDED_TYPE_VALIDATORS): 100 | raise ValueError("Invalid service json") 101 | 102 | # And descriptor pool 103 | if descriptor_pool is None: 104 | log.debug2("Using the default descriptor pool") 105 | descriptor_pool = _descriptor_pool.Default() 106 | 107 | # First get the descriptor proto: 108 | service_fd_proto = _json_to_service_file_descriptor_proto( 109 | name, package, json_service_def, descriptor_pool=descriptor_pool 110 | ) 111 | assert ( 112 | len(service_fd_proto.service) == 1 113 | ), f"File Descriptor {service_fd_proto.name} should only have one service" 114 | service_descriptor_proto = service_fd_proto.service[0] 115 | 116 | # Then put that in the pool to get the real descriptor back 117 | log.debug("Adding Descriptors to DescriptorPool") 118 | safe_add_fd_to_pool(service_fd_proto, descriptor_pool) 119 | service_fullname = name if not package else ".".join([package, name]) 120 | service_descriptor = descriptor_pool.FindServiceByName(service_fullname) 121 | 122 | # Then the client stub: 123 | client_stub = _service_descriptor_to_client_stub( 124 | service_descriptor, service_descriptor_proto 125 | ) 126 | 127 | # And the registration function: 128 | registration_function = _service_descriptor_to_server_registration_function( 129 | service_descriptor, service_descriptor_proto 130 | ) 131 | 132 | # And service class! 133 | service_class = _service_descriptor_to_service(service_descriptor) 134 | 135 | return GRPCService( 136 | descriptor=service_descriptor, 137 | service_class=service_class, 138 | client_stub_class=client_stub, 139 | registration_function=registration_function, 140 | ) 141 | 142 | 143 | def _json_to_service_file_descriptor_proto( 144 | name: str, 145 | package: str, 146 | json_service_def: ServiceJsonType, 147 | *, 148 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 149 | ) -> descriptor_pb2.FileDescriptorProto: 150 | """Creates the FileDescriptorProto for the service definition""" 151 | 152 | method_descriptor_protos: List[descriptor_pb2.MethodDescriptorProto] = [] 153 | imports: List[str] = [] 154 | 155 | json_service = json_service_def["service"] 156 | rpcs_def = json_service["rpcs"] 157 | for rpc_def in rpcs_def: 158 | rpc_input_type = rpc_def["input_type"] 159 | input_descriptor = descriptor_pool.FindMessageTypeByName(rpc_input_type) 160 | 161 | rpc_output_type = rpc_def["output_type"] 162 | output_descriptor = descriptor_pool.FindMessageTypeByName(rpc_output_type) 163 | 164 | method_descriptor_protos.append( 165 | descriptor_pb2.MethodDescriptorProto( 166 | name=rpc_def["name"], 167 | input_type=input_descriptor.full_name, 168 | output_type=output_descriptor.full_name, 169 | client_streaming=rpc_def.get("client_streaming", False), 170 | server_streaming=rpc_def.get("server_streaming", False), 171 | ) 172 | ) 173 | imports.append(input_descriptor.file.name) 174 | imports.append(output_descriptor.file.name) 175 | 176 | imports = sorted(list(set(imports))) 177 | 178 | service_descriptor_proto = descriptor_pb2.ServiceDescriptorProto( 179 | name=name, method=method_descriptor_protos 180 | ) 181 | 182 | fd_proto = descriptor_pb2.FileDescriptorProto( 183 | name=f"{name.lower()}.proto", 184 | package=package, 185 | syntax="proto3", 186 | dependency=imports, 187 | # **proto_kwargs, 188 | service=[service_descriptor_proto], 189 | ) 190 | 191 | return fd_proto 192 | 193 | 194 | def _service_descriptor_to_service( 195 | service_descriptor: ServiceDescriptor, 196 | ) -> Type[GeneratedServiceType]: 197 | """Create a service class from a service descriptor 198 | 199 | Args: 200 | service_descriptor: google.protobuf.descriptor.ServiceDescriptor 201 | The ServiceDescriptor to generate a service interface for 202 | 203 | Returns: 204 | Type[google.protobuf.service_reflection.GeneratedServiceType] 205 | A new class with metaclass 206 | google.protobuf.service_reflection.GeneratedServiceType containing 207 | the methods from the service_descriptor 208 | """ 209 | service_class = make_service_class(service_descriptor) 210 | service_class = _add_protobuf_serializers(service_class, service_descriptor) 211 | 212 | return service_class 213 | 214 | 215 | def _service_descriptor_to_client_stub( 216 | service_descriptor: ServiceDescriptor, 217 | service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto, 218 | ) -> Type: 219 | """Generates a new client stub class from the service descriptor 220 | 221 | Args: 222 | service_descriptor: google.protobuf.descriptor.ServiceDescriptor 223 | The ServiceDescriptor to generate a service interface for 224 | service_descriptor_proto: google.protobuf.descriptor_pb2.ServiceDescriptorProto 225 | The descriptor proto for that service. This holds the I/O streaming information 226 | for each method 227 | """ 228 | _assert_method_lists_same(service_descriptor, service_descriptor_proto) 229 | 230 | def _get_channel_func( 231 | channel: grpc.Channel, method: descriptor_pb2.MethodDescriptorProto 232 | ) -> Callable: 233 | if method.client_streaming and method.server_streaming: 234 | return channel.stream_stream 235 | if not method.client_streaming and method.server_streaming: 236 | return channel.unary_stream 237 | if method.client_streaming and not method.server_streaming: 238 | return channel.stream_unary 239 | return channel.unary_unary 240 | 241 | # Initializer 242 | def initializer(self, channel: grpc.Channel): 243 | f"""Initializes a client stub with for the {service_descriptor.name} Service""" 244 | for method, method_proto in zip( 245 | service_descriptor.methods, service_descriptor_proto.method 246 | ): 247 | setattr( 248 | self, 249 | method.name, 250 | _get_channel_func(channel, method_proto)( 251 | _get_method_fullname(method), 252 | request_serializer=descriptor_to_message_class( 253 | method.input_type 254 | ).SerializeToString, 255 | response_deserializer=descriptor_to_message_class( 256 | method.output_type 257 | ).FromString, 258 | ), 259 | ) 260 | 261 | # Creating class dynamically 262 | return type( 263 | f"{service_descriptor.name}Stub", 264 | (object,), 265 | { 266 | "__init__": initializer, 267 | }, 268 | ) 269 | 270 | 271 | def _service_descriptor_to_server_registration_function( 272 | service_descriptor: ServiceDescriptor, 273 | service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto, 274 | ) -> Callable[[GeneratedServiceType, grpc.Server], None]: 275 | """Generates a server registration function from the service descriptor 276 | 277 | Args: 278 | service_descriptor: google.protobuf.descriptor.ServiceDescriptor 279 | The ServiceDescriptor to generate a service interface for 280 | service_descriptor_proto: google.protobuf.descriptor_pb2.ServiceDescriptorProto 281 | The descriptor proto for that service. This holds the I/O streaming information 282 | for each method 283 | 284 | Returns: 285 | function: Server registration function to add service handlers to a server 286 | """ 287 | _assert_method_lists_same(service_descriptor, service_descriptor_proto) 288 | 289 | def _get_handler(method: descriptor_pb2.MethodDescriptorProto): 290 | if method.client_streaming and method.server_streaming: 291 | return grpc.stream_stream_rpc_method_handler 292 | if not method.client_streaming and method.server_streaming: 293 | return grpc.unary_stream_rpc_method_handler 294 | if method.client_streaming and not method.server_streaming: 295 | return grpc.stream_unary_rpc_method_handler 296 | return grpc.unary_unary_rpc_method_handler 297 | 298 | def registration_function(servicer: GeneratedServiceType, server: grpc.Server): 299 | """Server registration function""" 300 | rpc_method_handlers = { 301 | method.name: _get_handler(method_proto)( 302 | getattr(servicer, method.name), 303 | request_deserializer=descriptor_to_message_class( 304 | method.input_type 305 | ).FromString, 306 | response_serializer=descriptor_to_message_class( 307 | method.output_type 308 | ).SerializeToString, 309 | ) 310 | for method, method_proto in zip( 311 | service_descriptor.methods, service_descriptor_proto.method 312 | ) 313 | } 314 | generic_handler = grpc.method_handlers_generic_handler( 315 | service_descriptor.full_name, rpc_method_handlers 316 | ) 317 | server.add_generic_rpc_handlers((generic_handler,)) 318 | 319 | return registration_function 320 | 321 | 322 | def _get_method_fullname(method: MethodDescriptor): 323 | method_name_parts = method.full_name.split(".") 324 | return f"/{'.'.join(method_name_parts[:-1])}/{method_name_parts[-1]}" 325 | 326 | 327 | def _assert_method_lists_same( 328 | service_descriptor: ServiceDescriptor, 329 | service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto, 330 | ): 331 | assert len(service_descriptor.methods) == len(service_descriptor_proto.method), ( 332 | f"Method count mismatch: {service_descriptor.full_name} has" 333 | f" {len(service_descriptor.methods)} methods but proto descriptor" 334 | f" {service_descriptor_proto.name} has {len(service_descriptor_proto.method)} methods" 335 | ) 336 | 337 | for m1, m2 in zip(service_descriptor.methods, service_descriptor_proto.method): 338 | assert m1.name == m2.name, f"Method mismatch: {m1.name}, {m2.name}" 339 | -------------------------------------------------------------------------------- /py_to_proto/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities that are shared across converters 3 | """ 4 | 5 | # Standard 6 | from typing import Any, List 7 | import re 8 | 9 | # Third Party 10 | from google.protobuf import descriptor_pb2 11 | import google.protobuf.descriptor_pool 12 | 13 | # First Party 14 | import alog 15 | 16 | log = alog.use_channel("2PUTL") 17 | 18 | 19 | def to_upper_camel(snake_str: str) -> str: 20 | """Convert a snake_case string to UpperCamelCase""" 21 | if not snake_str: 22 | return snake_str 23 | return ( 24 | snake_str[0].upper() 25 | + re.sub("_([a-zA-Z])", lambda pat: pat.group(1).upper(), snake_str)[1:] 26 | ) 27 | 28 | 29 | def safe_add_fd_to_pool( 30 | fd_proto: descriptor_pb2.FileDescriptorProto, 31 | descriptor_pool: google.protobuf.descriptor_pool.DescriptorPool, 32 | ): 33 | """Safely add a new file descriptor to a descriptor pool. This function will 34 | look for naming collisions and if one occurs, it will validate the inbound 35 | descriptor against the conflicting descriptor in the pool to see if they are 36 | the same. If they are, no further action is taken. If they are not, an error 37 | is raised. 38 | """ 39 | try: 40 | existing_fd = descriptor_pool.FindFileByName(fd_proto.name) 41 | # Rebuild the file descriptor proto so that we can compare; there is 42 | # almost certainly a more efficient way to compare that avoids this. 43 | existing_proto = descriptor_pb2.FileDescriptorProto() 44 | existing_fd.CopyToProto(existing_proto) 45 | # Raise if the file exists already with different content 46 | # Otherwise, do not attempt to re-add the file 47 | if not _are_same_file_descriptors(fd_proto, existing_proto): 48 | # NOTE: This is a TypeError because that is what you get most of the time when you 49 | # have conflict issues in the descriptor pool arising from JTD to Proto followed by 50 | # importing differing defs for the same top level message type using different file 51 | # names (i.e., skipping this validation) compiled by protoc. Raising TypeError here 52 | # ensures that we at least usually raise the same error type regardless of 53 | # import / operation order. 54 | raise TypeError( 55 | f"Cannot add new file {fd_proto.name} to descriptor pool, file already exists with different content" 56 | ) 57 | except KeyError: 58 | # It's okay for the file to not already exist, we'll add it! 59 | try: 60 | descriptor_pool.Add(fd_proto) 61 | except TypeError as e: 62 | # More likely than not, this is a duplicate symbol; the main case in which 63 | # this could occur is when you've compiled files with protoc, added them to your 64 | # descriptor pool, and ALSO added the defs in your py_to_proto schema, but the 65 | # lookup validation with fd_proto.name is skipped because the .proto file fed to 66 | # protoc had a different name! 67 | raise TypeError( 68 | f"Failed to add {fd_proto.name} to descriptor pool with error: [{e}]; Hint: if you previously used protoc to compile this definition, you must recompile it with the name {fd_proto.name} to avoid the conflict." 69 | ) 70 | 71 | 72 | ## Implementation Details ###################################################### 73 | 74 | 75 | def _are_same_file_descriptors( 76 | d1: descriptor_pb2.FileDescriptorProto, d2: descriptor_pb2.FileDescriptorProto 77 | ) -> bool: 78 | """Validate that there are no consistency issues in the message descriptors of 79 | our proto file descriptors. 80 | 81 | Args: 82 | d1: descriptor_pb2.FileDescriptorProto 83 | First FileDescriptorProto we want to compare. 84 | d2: descriptor_pb2.FileDescriptorProto 85 | second FileDescriptorProto we want to compare. 86 | 87 | Returns: 88 | True if the provided file descriptor proto files are identical. 89 | """ 90 | have_same_deps = d1.dependency == d2.dependency 91 | are_same_package = d1.package == d2.package 92 | have_aligned_enums = _are_same_enum_descriptor(d1.enum_type, d2.enum_type) 93 | have_aligned_messages = _check_message_descs_alignment( 94 | d1.message_type, d2.message_type 95 | ) 96 | have_aligned_services = _check_service_desc_alignment(d1.service, d2.service) 97 | return ( 98 | have_same_deps 99 | and are_same_package 100 | and have_aligned_enums 101 | and have_aligned_messages 102 | and have_aligned_services 103 | ) 104 | 105 | 106 | def _are_same_enum_descriptor(d1_enums: Any, d2_enums: Any) -> bool: 107 | """Determine if two iterables of EnumDescriptorProtos have the same enums. 108 | This means the following: 109 | 110 | 1. They have the same names in their respective .enum_type properties. 111 | 2. For every enum in enum_type, they have the same number of values & the same names. 112 | 113 | Args: 114 | d1_enums: Any 115 | First iterable of enum desc protos to compare, e.g., RepeatedCompositeContainer. 116 | d2_enums: Any 117 | Second iterable of enum desc protos to compare, e.g., RepeatedCompositeContainer. 118 | 119 | Returns: 120 | True if the provided iterable enum descriptors are identical. 121 | """ 122 | d1_enum_map = {enum.name: enum for enum in d1_enums} 123 | d2_enum_map = {enum.name: enum for enum in d2_enums} 124 | if d1_enum_map.keys() != d2_enum_map.keys(): 125 | return False 126 | 127 | for enum_name in d1_enum_map.keys(): 128 | d1_enum_descriptor = d1_enum_map[enum_name] 129 | d2_enum_descriptor = d2_enum_map[enum_name] 130 | if len(d1_enum_descriptor.value) != len(d2_enum_descriptor.value): 131 | return False 132 | # Compare each entry in the repeated composite container, 133 | # i.e., all of our EnumValueDescriptorProto objects 134 | for first_enum_val, second_enum_val in zip( 135 | d1_enum_descriptor.value, d2_enum_descriptor.value 136 | ): 137 | if ( 138 | first_enum_val.name != second_enum_val.name 139 | or first_enum_val.number != second_enum_val.number 140 | ): 141 | return False 142 | return True 143 | 144 | 145 | def _check_message_descs_alignment( 146 | d1_msg_container: Any, d2_msg_container: Any 147 | ) -> bool: 148 | """Determine if two message descriptor proto containers, i.e., RepeatedCompositeContainers 149 | have the same message types. This means the following: 150 | 151 | 1. The messages contained in each FileDescriptorProto are the same. 152 | 2. For each of those respective messages, their respective fields are roughly the same. 153 | Note that this includes nested_types, which are verified recursively. 154 | 155 | Args: 156 | d1_msg_container: Any 157 | First container iterable of message descriptors protos to be verified. 158 | d2_msg_container: Any 159 | Second container iterable of message descriptors protos to be verified. 160 | 161 | Returns: 162 | bool 163 | True if the contained message descriptor protos are identical. 164 | """ 165 | d1_msg_descs = {msg.name: msg for msg in d1_msg_container} 166 | d2_msg_descs = {msg.name: msg for msg in d2_msg_container} 167 | 168 | # Ensure that our descriptors have the same dependencies & top level message types 169 | if d1_msg_descs.keys() != d2_msg_descs.keys(): 170 | return False 171 | # For every encapsulated message descriptor, ensure that every field has the same 172 | # name, number, label, type, and type name 173 | for msg_name in d1_msg_descs.keys(): 174 | d1_message_descriptor = d1_msg_descs[msg_name] 175 | d2_message_descriptor = d2_msg_descs[msg_name] 176 | # Ensure that these messages are actually the same 177 | if not _are_same_message_descriptor( 178 | d1_message_descriptor, d2_message_descriptor 179 | ): 180 | return False 181 | return True 182 | 183 | 184 | def _check_service_desc_alignment( 185 | d1_service_list: List[descriptor_pb2.ServiceDescriptorProto], 186 | d2_service_list: List[descriptor_pb2.ServiceDescriptorProto], 187 | ) -> bool: 188 | d1_service_descs = {svc.name: svc for svc in d1_service_list} 189 | d2_service_descs = {svc.name: svc for svc in d2_service_list} 190 | 191 | log.debug( 192 | "Checking service descriptors: [%s] and [%s]", 193 | d1_service_descs, 194 | d2_service_descs, 195 | ) 196 | # Ensure that our service names are the same set 197 | if d1_service_descs.keys() != d2_service_descs.keys(): 198 | # Excluding from code coverage: We can't actually generate file descriptors with multiple services in them. 199 | # But, this check seems pretty basic and worth leaving in if this ever gets extended in the future. 200 | return False # pragma: no cover 201 | 202 | # For every service, ensure that every method is the same 203 | for svc_name in d1_service_descs.keys(): 204 | d1_service = d1_service_descs[svc_name] 205 | d2_service = d2_service_descs[svc_name] 206 | 207 | if not _are_same_service_descriptor(d1_service, d2_service): 208 | return False 209 | return True 210 | 211 | 212 | def _are_same_service_descriptor( 213 | d1_service: descriptor_pb2.ServiceDescriptorProto, 214 | d2_service: descriptor_pb2.ServiceDescriptorProto, 215 | ) -> bool: 216 | # Not checking service.name because we only compare services with the same name 217 | 218 | d1_methods = {method.name: method for method in d1_service.method} 219 | d2_methods = {method.name: method for method in d2_service.method} 220 | 221 | # Ensure that our service names are the same set 222 | if d1_methods.keys() != d2_methods.keys(): 223 | return False 224 | 225 | # For every service, ensure that every method is the same 226 | for method_name in d1_methods.keys(): 227 | d1_method = d1_methods[method_name] 228 | d2_method = d2_methods[method_name] 229 | 230 | if not _are_same_method_descriptor(d1_method, d2_method): 231 | return False 232 | 233 | return True 234 | 235 | 236 | def _are_same_method_descriptor( 237 | d1_method: descriptor_pb2.MethodDescriptorProto, 238 | d2_method: descriptor_pb2.MethodDescriptorProto, 239 | ) -> bool: 240 | # Not checking method.name because we only compare services with the same name 241 | 242 | if not _are_types_similar(d1_method.input_type, d2_method.input_type): 243 | return False 244 | if not _are_types_similar(d1_method.output_type, d2_method.output_type): 245 | return False 246 | # TODO: Add the ability for `json_to_service` to set options 247 | # Then we can test this! 248 | if d1_method.options != d2_method.options: 249 | log.debug( # pragma: no cover 250 | "Method options differ! [%s] vs. [%s]", d1_method.options, d2_method.options 251 | ) 252 | return False # pragma: no cover 253 | if d1_method.client_streaming != d2_method.client_streaming: 254 | return False 255 | if d1_method.server_streaming != d2_method.server_streaming: 256 | return False 257 | return True 258 | 259 | 260 | def _are_types_similar(type_1: str, type_2: str) -> bool: 261 | """Returns true iff type names are the same or differ only by a leading `.`""" 262 | # TODO: figure out why when you `json_to_service` the same thing twice, on of the service descriptors ends up with 263 | # fully qualified names (.foo.bar.Foo) and the other does not (foo.bar.Foo) 264 | return type_1.lstrip(".") == type_2.lstrip(".") 265 | 266 | 267 | def _are_same_message_descriptor( 268 | d1: descriptor_pb2.DescriptorProto, d2: descriptor_pb2.DescriptorProto 269 | ) -> bool: 270 | """Determine if two message descriptors proto are representing the same thing. We do this by 271 | ensuring that their fields all have the same fields, then inspecting each of their labels, 272 | names, etc, for alignment. We do the same for any nested fields. 273 | 274 | Args: 275 | d1: descriptor_pb2.DescriptorProto 276 | First message descriptor to be compared. 277 | d2: descriptor_pb2.DescriptorProto 278 | second message descriptor to be compared. 279 | 280 | Returns: 281 | bool 282 | True of messages are identical, False otherwise. 283 | """ 284 | # Compare any nested enums in our message. 285 | if not _are_same_enum_descriptor(d1.enum_type, d2.enum_type): 286 | return False 287 | # Make sure all of our named fields align, then check them individually 288 | d1_field_descs = {field.name: field for field in d1.field} 289 | d2_field_descs = {field.name: field for field in d2.field} 290 | if d1_field_descs.keys() != d2_field_descs.keys(): 291 | return False 292 | for field_name in d1_field_descs.keys(): 293 | # We consider two fields equal if they have the same name, label 294 | d1_field_descriptor = d1_field_descs[field_name] 295 | d2_field_descriptor = d2_field_descs[field_name] 296 | if ( 297 | d1_field_descriptor.label != d2_field_descriptor.label 298 | or d1_field_descriptor.type != d2_field_descriptor.type 299 | ): 300 | return False 301 | # For nested fields, we treat them similarly to how we've treated messages 302 | # and recurse into comparisons used for the top level messages. 303 | if d1.nested_type or d2.nested_type: 304 | return _check_message_descs_alignment(d1.nested_type, d2.nested_type) 305 | # Otherwise, we have no more nested layers to check; we're done! 306 | return True 307 | -------------------------------------------------------------------------------- /tests/test_validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for validation logic. These tests exercise all examples from the RFC 3 | https://www.rfc-editor.org/rfc/rfc8927 4 | """ 5 | 6 | # Third Party 7 | import pytest 8 | 9 | # First Party 10 | import alog 11 | 12 | # Local 13 | from py_to_proto.jtd_to_proto import jtd_to_proto 14 | from py_to_proto.validation import _validate_jtd_impl, is_valid_jtd, validate_jtd 15 | 16 | log = alog.use_channel("TEST") 17 | 18 | ## is_valid_jtd ################################################################ 19 | 20 | SampleDescriptor = jtd_to_proto( 21 | "Sample", "foo.bar", {"properties": {"foo": {"type": "string"}}} 22 | ) 23 | 24 | 25 | VALID_SCHEMAS = [ 26 | # Empty 27 | {}, 28 | {"nullable": True}, 29 | {"metadata": {"foo": 12345}}, 30 | {"definitions": {}}, 31 | # Ref 32 | { 33 | "definitions": { 34 | "coordinates": { 35 | "properties": {"lat": {"type": "float32"}, "lng": {"type": "float32"}} 36 | } 37 | }, 38 | "properties": { 39 | "user_location": {"ref": "coordinates"}, 40 | "server_location": {"ref": "coordinates"}, 41 | }, 42 | }, 43 | # Type 44 | {"type": "uint8"}, 45 | {"type": SampleDescriptor}, 46 | # Enum 47 | {"enum": ["PENDING", "IN_PROGRESS", "DONE"]}, 48 | # Elements 49 | {"elements": {"type": "uint8"}}, 50 | # Properties 51 | {"optionalProperties": {"foo": {}}}, 52 | {"optionalProperties": {"foo": {}}, "additionalProperties": True}, 53 | { 54 | "properties": { 55 | "users": { 56 | "elements": { 57 | "properties": { 58 | "id": {"type": "string"}, 59 | "name": {"type": "string"}, 60 | "create_time": {"type": "timestamp"}, 61 | }, 62 | "optionalProperties": {"delete_time": {"type": "timestamp"}}, 63 | } 64 | }, 65 | "next_page_token": {"type": "string"}, 66 | } 67 | }, 68 | # Values 69 | {"values": {"type": "uint8"}}, 70 | # Discriminator 71 | { 72 | "discriminator": "event_type", 73 | "mapping": { 74 | "account_deleted": {"properties": {"account_id": {"type": "string"}}}, 75 | "account_payment_plan_changed": { 76 | "properties": { 77 | "account_id": {"type": "string"}, 78 | "payment_plan": {"enum": ["FREE", "PAID"]}, 79 | }, 80 | "optionalProperties": {"upgraded_by": {"type": "string"}}, 81 | }, 82 | }, 83 | }, 84 | { 85 | "discriminator": "event_type", 86 | "nullable": True, 87 | "mapping": { 88 | "account_deleted": { 89 | "nullable": True, 90 | "properties": {"account_id": {"type": "string"}}, 91 | }, 92 | }, 93 | }, 94 | ] 95 | 96 | INVALID_SCHEMAS = [ 97 | # Empty 98 | {"nullable": "foo"}, 99 | {"metadata": "foo"}, 100 | # Ref 101 | {"ref": "foo"}, 102 | {"ref": 1234}, 103 | {"definitions": {"foo": {}}, "ref": "bar"}, 104 | {"definitions": 1234}, 105 | {"definitions": {"foo": {"definitions": {}}}}, 106 | # Type 107 | {"type": True}, 108 | {"type": "foo"}, 109 | # Enum 110 | {"enum": []}, 111 | {"enum": 1234}, 112 | {"enum": ["a\\b", "a\u005Cb"]}, 113 | # Elements 114 | {"elements": True}, 115 | {"elements": {"type": "foo"}}, 116 | # Properties 117 | { 118 | "properties": {"confusing": {}}, 119 | "optionalProperties": {"confusing": {}}, 120 | }, 121 | {"optionalProperties": {}}, 122 | {"properties": {}}, 123 | {"properties": {}, "optionalProperties": {}}, 124 | {"properties": 1234}, 125 | {"optionalProperties": {"foo": {}}, "additionalProperties": 12345}, 126 | # Values 127 | {"values": True}, 128 | {"values": {"type": "foo"}}, 129 | # Discriminator 130 | { 131 | "discriminator": "event_type", 132 | "mapping": { 133 | "can_the_object_be_null_or_not?": { 134 | "nullable": True, 135 | "properties": {"foo": {"type": "string"}}, 136 | } 137 | }, 138 | }, 139 | { 140 | "discriminator": "event_type", 141 | "mapping": { 142 | "is_event_type_a_string_or_a_float32?": { 143 | "properties": {"event_type": {"type": "float32"}} 144 | } 145 | }, 146 | }, 147 | { 148 | "discriminator": "event_type", 149 | "mapping": { 150 | "is_event_type_a_string_or_an_optional_float32?": { 151 | "optionalProperties": {"event_type": {"type": "float32"}} 152 | } 153 | }, 154 | }, 155 | { 156 | "discriminator": "key", 157 | "mapping": {"int": {"type": "int32"}, "str": {"type": "string"}}, 158 | }, 159 | ] 160 | 161 | 162 | @pytest.mark.parametrize("schema", VALID_SCHEMAS) 163 | def test_valid_schemas(schema): 164 | """Make sure all valid schemas return True as expected""" 165 | log.debug("Testing valid schema: %s", schema) 166 | assert is_valid_jtd(schema) 167 | 168 | 169 | @pytest.mark.parametrize("schema", INVALID_SCHEMAS) 170 | def test_invalid_schemas(schema): 171 | """Make sure all invalid schemas return False as expected""" 172 | log.debug("Testing invalid schema: %s", schema) 173 | assert not is_valid_jtd(schema) 174 | 175 | 176 | ## validate_jtd ################################################################ 177 | 178 | 179 | class CustomClass: 180 | pass 181 | 182 | 183 | # (object, schema) 184 | VALID_JTD = [ 185 | # Empty 186 | ({"foo": 1234, "bar": CustomClass()}, {}), 187 | (None, {"nullable": True}), 188 | (CustomClass(), {"metadata": {"foo": "bar"}}), 189 | # Ref 190 | (123, {"definitions": {"a": {"type": "float32"}}, "ref": "a"}), 191 | (None, {"definitions": {"a": {"type": "float32"}}, "ref": "a", "nullable": True}), 192 | # Type 193 | (123, {"type": "int32"}), 194 | (123, {"type": "float64"}), 195 | (1.23, {"type": "float64"}), 196 | (None, {"type": "boolean", "nullable": True}), 197 | # Enum 198 | ("FOO", {"enum": ["FOO", "BAR"]}), 199 | (None, {"enum": ["FOO", "BAR"], "nullable": True}), 200 | # Elements 201 | ([1, 2], {"elements": {"type": "int32"}}), 202 | ([], {"elements": {"type": "int32"}}), 203 | (None, {"elements": {"type": "int32"}, "nullable": True}), 204 | ( 205 | [{"foo": 1}, {"foo": 2}], 206 | {"elements": {"properties": {"foo": {"type": "int32"}}}}, 207 | ), 208 | # Properties 209 | ({"foo": 123}, {"properties": {"foo": {"type": "int32"}}}), 210 | ({"foo": ["bar"]}, {"properties": {"foo": {"elements": {"type": "string"}}}}), 211 | ( 212 | {"foo": 123, "bar": "baz"}, 213 | {"properties": {"foo": {"type": "int32"}, "bar": {"type": "string"}}}, 214 | ), 215 | ( 216 | {"foo": 123, "bar": "baz"}, 217 | { 218 | "properties": {"foo": {"type": "int32"}}, 219 | "optionalProperties": {"bar": {"type": "string"}}, 220 | }, 221 | ), 222 | ( 223 | {"foo": 123}, 224 | { 225 | "properties": {"foo": {"type": "int32"}}, 226 | "optionalProperties": {"bar": {"type": "string"}}, 227 | }, 228 | ), 229 | ({}, {"optionalProperties": {"bar": {"type": "string"}}}), 230 | ( 231 | {"buz": 123}, 232 | { 233 | "optionalProperties": {"bar": {"type": "string"}}, 234 | "additionalProperties": True, 235 | }, 236 | ), 237 | # Values 238 | ({"foo": 123, "bar": -2}, {"values": {"type": "int32"}}), 239 | ({"foo": {"bar": -2}}, {"values": {"properties": {"bar": {"type": "int32"}}}}), 240 | # Discriminator 241 | ( 242 | {"key": "str", "val": "this is a test"}, 243 | { 244 | "discriminator": "key", 245 | "mapping": { 246 | "int": {"properties": {"val": {"type": "int32"}}}, 247 | "str": {"properties": {"val": {"type": "string"}}}, 248 | }, 249 | }, 250 | ), 251 | ( 252 | {"key": "int", "val": 123}, 253 | { 254 | "discriminator": "key", 255 | "mapping": { 256 | "int": {"properties": {"val": {"type": "int32"}}}, 257 | "str": {"properties": {"val": {"type": "string"}}}, 258 | }, 259 | }, 260 | ), 261 | ( 262 | {"key": "int", "val_int": 123}, 263 | { 264 | "discriminator": "key", 265 | "mapping": { 266 | "int": {"properties": {"val_int": {"type": "int32"}}}, 267 | "str": {"properties": {"val_str": {"type": "string"}}}, 268 | }, 269 | }, 270 | ), 271 | ( 272 | {"key": "str", "val_str": "asdf"}, 273 | { 274 | "discriminator": "key", 275 | "mapping": { 276 | "int": {"properties": {"val_int": {"type": "int32"}}}, 277 | "str": {"properties": {"val_str": {"type": "string"}}}, 278 | }, 279 | }, 280 | ), 281 | ( 282 | {"key": "str", "val": "this is a test", "something": "else"}, 283 | { 284 | "discriminator": "key", 285 | "mapping": { 286 | "int": {"properties": {"val": {"type": "int32"}}}, 287 | "str": { 288 | "properties": {"val": {"type": "string"}}, 289 | "additionalProperties": True, 290 | }, 291 | }, 292 | }, 293 | ), 294 | ] 295 | 296 | INVALID_JTD = [ 297 | # Ref 298 | (None, {"definitions": {"a": {"type": "float32"}}, "ref": "a", "nullable": False}), 299 | ({"foo": "bar"}, {"definitions": {"a": {"type": "float32"}}, "ref": "a"}), 300 | # Type 301 | (1.23, {"type": "int8"}), 302 | (-2, {"type": "uint8"}), 303 | (None, {"type": "boolean"}), 304 | # Enum 305 | ("BAZ", {"enum": ["FOO", "BAR"]}), 306 | (0, {"enum": ["FOO", "BAR"]}), 307 | ({}, {"enum": ["FOO", "BAR"]}), 308 | (None, {"enum": ["FOO", "BAR"]}), 309 | # Elements 310 | ([1, 2, "foo"], {"elements": {"type": "int32"}}), 311 | (None, {"elements": {"type": "int32"}, "nullable": False}), 312 | ( 313 | [{"foo": 1}, {"foo": 2}], 314 | {"elements": {"properties": {"foo": {"type": "string"}}}}, 315 | ), 316 | # Properties 317 | ({"foo": 123}, {"properties": {"foo": {"type": "string"}}}), 318 | ({"foo": [123]}, {"properties": {"foo": {"elements": {"type": "string"}}}}), 319 | ( 320 | {"bar": "baz"}, 321 | {"properties": {"foo": {"type": "int32"}, "bar": {"type": "string"}}}, 322 | ), 323 | ( 324 | {"bar": "baz"}, 325 | { 326 | "properties": {"foo": {"type": "int32"}}, 327 | "optionalProperties": {"bar": {"type": "string"}}, 328 | }, 329 | ), 330 | ( 331 | {}, 332 | { 333 | "properties": {"foo": {"type": "int32"}}, 334 | "optionalProperties": {"bar": {"type": "string"}}, 335 | }, 336 | ), 337 | ({"buz": 123}, {"optionalProperties": {"bar": {"type": "string"}}}), 338 | ( 339 | {"buz": 123}, 340 | { 341 | "optionalProperties": {"bar": {"type": "string"}}, 342 | "additionalProperties": False, 343 | }, 344 | ), 345 | ({"bar": 123}, {"optionalProperties": {"bar": {"type": "string"}}}), 346 | ([{"foo": 123}], {"properties": {"foo": {"type": "string"}}}), 347 | # Values 348 | ({"foo": 123, "bar": "asdf"}, {"values": {"type": "int32"}}), 349 | ({"foo": {"bar": "test"}}, {"values": {"properties": {"bar": {"type": "int32"}}}}), 350 | # Discriminator 351 | ( 352 | {"key": "str", "val": 123}, 353 | { 354 | "discriminator": "key", 355 | "mapping": { 356 | "int": {"properties": {"val": {"type": "int32"}}}, 357 | "str": {"properties": {"val": {"type": "string"}}}, 358 | }, 359 | }, 360 | ), 361 | ( 362 | {"key": "int", "val": "asdf"}, 363 | { 364 | "discriminator": "key", 365 | "mapping": { 366 | "int": {"properties": {"val": {"type": "int32"}}}, 367 | "str": {"properties": {"val": {"type": "string"}}}, 368 | }, 369 | }, 370 | ), 371 | ( 372 | {"key": "str", "val": "this is a test", "something": "else"}, 373 | { 374 | "discriminator": "key", 375 | "mapping": { 376 | "int": {"properties": {"val": {"type": "int32"}}}, 377 | "str": {"properties": {"val": {"type": "string"}}}, 378 | }, 379 | }, 380 | ), 381 | ( 382 | 123, 383 | { 384 | "discriminator": "key", 385 | "mapping": { 386 | "int": {"properties": {"val": {"type": "int32"}}}, 387 | "str": {"properties": {"val": {"type": "string"}}}, 388 | }, 389 | }, 390 | ), 391 | ] 392 | 393 | 394 | @pytest.mark.parametrize("obj,schema", VALID_JTD) 395 | def test_valid_jtd(obj, schema): 396 | """Test all valid object validations""" 397 | log.debug("Comparing %s to %s", obj, schema) 398 | assert validate_jtd(obj, schema) 399 | 400 | 401 | @pytest.mark.parametrize("obj,schema", INVALID_JTD) 402 | def test_invalid_jtd(obj, schema): 403 | """Test all invalid object validations""" 404 | log.debug("Comparing %s to %s", obj, schema) 405 | assert not validate_jtd(obj, schema) 406 | 407 | 408 | def test_custom_type_validator(): 409 | """Make sure that a custom type validator works as expected""" 410 | assert validate_jtd( 411 | CustomClass(), 412 | {"type": "CustomClass"}, 413 | {"CustomClass": lambda x: isinstance(x, CustomClass)}, 414 | ) 415 | assert not validate_jtd( 416 | 123, 417 | {"type": "CustomClass"}, 418 | {"CustomClass": lambda x: isinstance(x, CustomClass)}, 419 | ) 420 | 421 | 422 | def test_validate_jtd_invalid_schema(): 423 | """Make sure that an invalid schema causes an error in validate_jtd""" 424 | with pytest.raises(ValueError): 425 | validate_jtd({}, {"not": "a valid schema"}) 426 | 427 | 428 | def test_validate_jtd_impl_invalid_schema(): 429 | """COV! Make sure an error is raised if somehow the schema isn't valid""" 430 | with pytest.raises(ValueError): 431 | _validate_jtd_impl({}, {"invalid": "schema"}, {}) 432 | -------------------------------------------------------------------------------- /py_to_proto/dataclass_to_proto.py: -------------------------------------------------------------------------------- 1 | # Standard 2 | from datetime import datetime 3 | from enum import Enum 4 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 5 | import dataclasses 6 | 7 | # Third Party 8 | from google.protobuf import any_pb2 9 | from google.protobuf import descriptor as _descriptor 10 | from google.protobuf import descriptor_pool as _descriptor_pool 11 | from google.protobuf import timestamp_pb2 12 | 13 | # First Party 14 | import alog 15 | 16 | # Local 17 | from .compat_annotated import Annotated, get_args, get_origin 18 | from .converter_base import ConverterBase 19 | 20 | log = alog.use_channel("DCLS2P") 21 | 22 | 23 | ## Globals ##################################################################### 24 | 25 | PY_TO_PROTO_TYPES = { 26 | Any: any_pb2.Any, 27 | bool: _descriptor.FieldDescriptor.TYPE_BOOL, 28 | str: _descriptor.FieldDescriptor.TYPE_STRING, 29 | bytes: _descriptor.FieldDescriptor.TYPE_BYTES, 30 | datetime: timestamp_pb2.Timestamp, 31 | float: _descriptor.FieldDescriptor.TYPE_DOUBLE, 32 | # TODO: support more integer types with numpy dtypes 33 | int: _descriptor.FieldDescriptor.TYPE_INT64, 34 | } 35 | 36 | ## Interface ################################################################### 37 | 38 | 39 | class FieldNumber(int): 40 | """A positive number used to identify a field""" 41 | 42 | def __new__(cls, *args, **kwargs): 43 | inst = super().__new__(cls, *args, **kwargs) 44 | if inst <= 0: 45 | raise ValueError("A field number must be a positive integer") 46 | return inst 47 | 48 | 49 | class OneofField(str): 50 | """A field name for an element of a oneof""" 51 | 52 | 53 | def dataclass_to_proto( 54 | package: str, 55 | dataclass_: type, 56 | *, 57 | name: Optional[str] = None, 58 | validate: bool = False, 59 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] = None, 60 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 61 | ) -> _descriptor.Descriptor: 62 | """Convert a dataclass into a set of proto DESCRIPTOR objects. 63 | 64 | Reference: https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass 65 | 66 | Args: 67 | name: str 68 | The name for the top-level message object 69 | package: str 70 | The proto package name to use for this object 71 | dataclass_: type 72 | The dataclass class 73 | 74 | Kwargs: 75 | validate: bool 76 | Whether or not to validate the class proactively 77 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] 78 | A non-default mapping from JTD type names to proto types 79 | descriptor_pool: Optional[descriptor_pool.DescriptorPool] 80 | If given, this DescriptorPool will be used to aggregate the set of 81 | message descriptors 82 | 83 | Returns: 84 | descriptor: descriptor.Descriptor 85 | The top-level MessageDescriptor corresponding to this jtd definition 86 | """ 87 | return DataclassConverter( 88 | dataclass_=dataclass_, 89 | package=package, 90 | name=name, 91 | validate=validate, 92 | type_mapping=type_mapping, 93 | descriptor_pool=descriptor_pool, 94 | ).descriptor 95 | 96 | 97 | ## Impl ######################################################################## 98 | 99 | 100 | class DataclassConverter(ConverterBase): 101 | """Converter implementation for dataclasses as the source""" 102 | 103 | def __init__( 104 | self, 105 | dataclass_: type, 106 | package: str, 107 | *, 108 | name: Optional[str] = None, 109 | type_mapping: Optional[Dict[str, Union[int, _descriptor.Descriptor]]] = None, 110 | validate: bool = False, 111 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None, 112 | ): 113 | """Fill in the default type mapping and additional default vals, then 114 | initialize the parent 115 | """ 116 | type_mapping = type_mapping or PY_TO_PROTO_TYPES 117 | name = name or getattr(dataclass_, "__name__", "") 118 | super().__init__( 119 | name=name, 120 | package=package, 121 | source_schema=dataclass_, 122 | type_mapping=type_mapping, 123 | validate=validate, 124 | descriptor_pool=descriptor_pool, 125 | ) 126 | 127 | ## Abstract Interface ###################################################### 128 | 129 | def validate(self, source_schema: type) -> bool: 130 | """Perform preprocess validation of the input""" 131 | if not dataclasses.is_dataclass(source_schema) and not ( 132 | isinstance(source_schema, type) and issubclass(source_schema, Enum) 133 | ): 134 | return False 135 | # TODO: More validation! 136 | return True 137 | 138 | ## Types ## 139 | 140 | def get_concrete_type(self, entry: Any) -> Any: 141 | """If this is a concrete type, get the type map key for it""" 142 | # Unwrap any Annotations 143 | entry_type = self._resolve_wrapped_type(entry) 144 | 145 | # If it's a known type, just return it 146 | if entry_type in self.type_mapping or isinstance( 147 | entry_type, (_descriptor.Descriptor, _descriptor.EnumDescriptor) 148 | ): 149 | return entry_type 150 | 151 | # If it's a type with a descriptor, return that descriptor 152 | descriptor_attr = getattr(entry_type, "DESCRIPTOR", None) 153 | if descriptor_attr is not None: 154 | return descriptor_attr 155 | 156 | ## Maps ## 157 | 158 | def get_map_key_val_types( 159 | self, 160 | entry: Any, 161 | ) -> Optional[Tuple[int, ConverterBase.ConvertOutputTypes]]: 162 | """Get the key and value types for a given map type""" 163 | if get_origin(entry) is dict: 164 | key_type, val_type = get_args(entry) 165 | return ( 166 | self._convert(key_type, name="key"), 167 | self._convert(val_type, name="value"), 168 | ) 169 | 170 | ## Enums ## 171 | 172 | def get_enum_vals(self, entry: Any) -> Optional[Iterable[Tuple[str, int]]]: 173 | """Get the ordered list of enum name -> number mappings if this entry is 174 | an enum 175 | 176 | NOTE: If any values appear multiple times, this implies an alias 177 | 178 | NOTE 2: All names must be unique 179 | """ 180 | if isinstance(entry, type) and issubclass(entry, Enum): 181 | values = [(name, val.value) for name, val in entry.__members__.items()] 182 | # NOTE: proto3 _requires_ a placeholder 0-value for every enum that 183 | # is the equivalent of unset. Some enums may do this intentionally 184 | # while others won't, so we add one in here if not in the python 185 | # version. 186 | if 0 not in [entry[1] for entry in values]: 187 | log.debug3("Adding placeholder 0-val for enum %s", entry) 188 | values = [("PLACEHOLDER_UNSET", 0)] + values 189 | return values 190 | 191 | ## Messages ## 192 | 193 | def get_message_fields(self, entry: Any) -> Optional[Iterable[Tuple[str, Any]]]: 194 | """Get the mapping of names to type-specific field descriptors if this 195 | entry is a message 196 | """ 197 | if dataclasses.is_dataclass(entry): 198 | return entry.__dataclass_fields__.items() 199 | 200 | def has_additional_fields(self, entry: Any) -> bool: 201 | """Check whether the given entry expects to support arbitrary key/val 202 | additional properties 203 | """ 204 | # There's no way to do additional keys with a dataclass 205 | return False 206 | 207 | def get_optional_field_names(self, entry: Any) -> List[str]: 208 | """Get the names of any fields which are explicitly marked 'optional'. 209 | 210 | For a dataclass this means looking at the types of the members for ones 211 | that either have default values. Fields marked as Optional that do not 212 | have default values are NOT considered optional since they are required 213 | in the __init__. 214 | """ 215 | return [ 216 | field_name 217 | for field_name, field in entry.__dataclass_fields__.items() 218 | if ( 219 | field.default is not dataclasses.MISSING 220 | or field.default_factory is not dataclasses.MISSING 221 | ) 222 | ] 223 | 224 | ## Fields ## 225 | 226 | def get_field_number( 227 | self, 228 | num_fields: int, 229 | field_def: Union[dataclasses.Field, type], 230 | ) -> int: 231 | """From the given field definition and index, get the proto field number 232 | from any metadata in the field definition and fall back to the next 233 | sequential value 234 | """ 235 | field_type = ( 236 | field_def.type if isinstance(field_def, dataclasses.Field) else field_def 237 | ) 238 | field_num = self._get_unique_annotation(field_type, FieldNumber) 239 | if field_num is not None: 240 | return field_num 241 | return num_fields + 1 242 | 243 | def get_oneof_fields( 244 | self, field_def: dataclasses.Field 245 | ) -> Optional[Iterable[Tuple[str, Any]]]: 246 | """If the given field is a Union, return an iterable of the sub-field 247 | definitions for its 248 | """ 249 | field_type = self._resolve_wrapped_type(field_def.type) 250 | oneof_fields = [] 251 | if get_origin(field_type) is Union: 252 | for arg in get_args(field_type): 253 | oneof_field_name = self._get_unique_annotation(arg, OneofField) 254 | res_type = self._resolve_wrapped_type(arg) 255 | # handle list type separately 256 | if get_origin(res_type) is list: 257 | assert get_args( 258 | res_type 259 | ), f"List {arg} does not have any type argument" 260 | field_type = get_args(res_type)[0] 261 | oneof_field_name = oneof_field_name or ( 262 | f"{field_def.name}_{str(field_type.__name__)}_sequence".lower() 263 | ) 264 | arg = dataclasses.make_dataclass( 265 | f"{field_def.name.capitalize()}{str(field_type.__name__).capitalize()}Sequence", 266 | [("values", List[field_type])], 267 | ) 268 | elif oneof_field_name is None: 269 | oneof_field_name = ( 270 | f"{field_def.name}_{str(res_type.__name__)}".lower() 271 | ) 272 | log.debug3("Using default oneof field name: %s", oneof_field_name) 273 | oneof_fields.append((oneof_field_name, arg)) 274 | 275 | # here it's not a union, but it's still annotated. 276 | # Special case in which we only have one field in the Union 277 | # but we still want to create a one-of in case OneofField is present 278 | # see https://github.com/IBM/py-to-proto/issues/63 279 | elif get_origin(field_def.type) is Annotated and any( 280 | type(arg) is OneofField for arg in get_args(field_def.type) 281 | ): 282 | # it can only be 1 arg, hence no need to iterate through the args 283 | oneof_field_name = self._get_unique_annotation(field_def.type, OneofField) 284 | assert ( 285 | len(oneof_field_name) > 0 286 | ), "Got OneofField annotation without any name?" 287 | 288 | log.debug3("Using oneof field name: %s", oneof_field_name) 289 | oneof_fields.append((oneof_field_name, field_def.type)) 290 | return oneof_fields 291 | 292 | def get_oneof_name(self, field_def: dataclasses.Field) -> str: 293 | """For an identified oneof field def, get the name""" 294 | return field_def.name 295 | 296 | def get_field_type(self, field_def: dataclasses.Field) -> Any: 297 | """Get the type of the field. The definition of type here will be 298 | specific to the converter (e.g. string for JTD, py type for dataclass) 299 | """ 300 | field_type = self._resolve_wrapped_type(field_def.type) 301 | if get_origin(field_type) is list: 302 | args = get_args(field_type) 303 | if len(args) == 1: 304 | return args[0] 305 | return field_type 306 | 307 | def is_repeated_field(self, field_def: dataclasses.Field) -> bool: 308 | """Determine if the given field def is repeated""" 309 | return get_origin(self._resolve_wrapped_type(field_def.type)) is list 310 | 311 | ## Implementation Details ################################################## 312 | 313 | @classmethod 314 | def _resolve_wrapped_type(cls, field_type: type) -> type: 315 | """Unwrap the type inside an Annotated or Optional, or just return the 316 | type if not wrapped 317 | """ 318 | origin = get_origin(field_type) 319 | args = get_args(field_type) 320 | 321 | # Unwrap Annotated and recurse in case it's an Annotated[Optional] 322 | if origin is Annotated: 323 | return cls._resolve_wrapped_type(args[0]) 324 | 325 | # Unwrap Optional and recurse in case it's an Optional[Annotated] 326 | if origin is Union and type(None) in args: 327 | non_none_args = [arg for arg in args if arg is not type(None)] 328 | assert non_none_args, f"Cannot have a union with only one NoneType arg" 329 | if len(non_none_args) > 1: 330 | res_type = Union.__getitem__(tuple(non_none_args)) 331 | else: 332 | res_type = non_none_args[0] 333 | return cls._resolve_wrapped_type(res_type) 334 | 335 | # If not Annotated or Optional, return as is 336 | return field_type 337 | 338 | @staticmethod 339 | def _get_annotations(field_type: type, annotation_type: type) -> List: 340 | """Get all annotations of the given annotation type from the given field 341 | type if it's annotated 342 | """ 343 | if get_origin(field_type) is Annotated: 344 | return [ 345 | arg 346 | for arg in get_args(field_type)[1:] 347 | if isinstance(arg, annotation_type) 348 | ] 349 | return [] 350 | 351 | @classmethod 352 | def _get_unique_annotation( 353 | cls, field_type: type, annotation_type: type 354 | ) -> Optional[Any]: 355 | """Get any annotations of the given annotation type and ensure they're 356 | unique 357 | """ 358 | annos = cls._get_annotations(field_type, annotation_type) 359 | if annos: 360 | if len(annos) > 1: 361 | raise ValueError(f"Multiple {annotation_type} annotations found") 362 | return annos[0] 363 | -------------------------------------------------------------------------------- /py_to_proto/validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements recursive JTD schema validation 3 | 4 | https://www.rfc-editor.org/rfc/rfc8927 5 | """ 6 | # Standard 7 | from datetime import datetime 8 | from typing import Any, Callable, Dict, List, Optional 9 | 10 | # Third Party 11 | from google.protobuf import descriptor as _descriptor 12 | 13 | # First Party 14 | import alog 15 | 16 | log = alog.use_channel("JTD2P") 17 | 18 | # Map of type validators for standard JTD types 19 | JTD_TYPE_VALIDATORS = { 20 | "boolean": lambda x: isinstance(x, bool), 21 | "float32": lambda x: isinstance(x, (int, float)), 22 | "float64": lambda x: isinstance(x, (int, float)), 23 | "int8": lambda x: isinstance(x, int), 24 | "uint8": lambda x: isinstance(x, int) and x >= 0, 25 | "int16": lambda x: isinstance(x, int), 26 | "uint16": lambda x: isinstance(x, int) and x >= 0, 27 | "int32": lambda x: isinstance(x, int), 28 | "uint32": lambda x: isinstance(x, int) and x >= 0, 29 | "string": lambda x: isinstance(x, str), 30 | "timestamp": lambda x: isinstance(x, datetime), 31 | } 32 | 33 | # List of standard type names 34 | JTD_TYPES = list(JTD_TYPE_VALIDATORS.keys()) 35 | 36 | 37 | def is_valid_jtd( 38 | schema: Dict[str, Any], valid_types: Optional[List[str]] = None 39 | ) -> bool: 40 | """Determine whether the given dict represents a valid JTD schema 41 | 42 | Args: 43 | schema (Dict[str, Any]) 44 | The candidate schema for validation 45 | valid_types (Optional[List[str]]) 46 | List of valid type name strings. This defaults to the standard JTD 47 | types, but can be changed/extended to support additional types 48 | 49 | Returns: 50 | is_valid (bool) 51 | True if the schema is valid, False otherwise 52 | """ 53 | valid_types = valid_types or JTD_TYPES 54 | return _is_valid_jtd_impl(schema, valid_types, is_root_schema=True) 55 | 56 | 57 | def validate_jtd( 58 | obj: Any, 59 | schema: Dict[str, Any], 60 | type_validators: Optional[Dict[str, Callable[[Any], bool]]] = None, 61 | ) -> bool: 62 | """Validate the given object against the given schema 63 | 64 | Args: 65 | obj (Any) 66 | The candidate object to validate 67 | schema (Dict[str, Any]) 68 | The schema to validate against 69 | type_validators (Optional[Dict[str, Callable[[Any], bool]]]) 70 | Mapping from types string names to validation functions 71 | 72 | Returns: 73 | is_valid (bool) 74 | True if the object matches the schema, False otherwise 75 | """ 76 | type_validators = type_validators or JTD_TYPE_VALIDATORS 77 | if not is_valid_jtd(schema, type_validators.keys()): 78 | raise ValueError(f"Invalid schema: {schema}") 79 | return _validate_jtd_impl(obj, schema, type_validators, is_root_schema=True) 80 | 81 | 82 | ## Implementation ############################################################## 83 | 84 | _SHARED_KEYS = {"nullable", "metadata", "definitions"} 85 | 86 | 87 | def _is_string_key_dict(value: Any) -> bool: 88 | return isinstance(value, dict) and all(isinstance(key, str) for key in value) 89 | 90 | 91 | def _is_valid_jtd_impl( 92 | schema: Dict[str, Any], 93 | valid_types: List[str], 94 | definitions: Optional[Dict[str, Any]] = None, 95 | *, 96 | is_root_schema: bool = False, 97 | ) -> bool: 98 | """Recursive implementation of schema validation""" 99 | 100 | # Make sure it is a dict with string keys 101 | if not _is_string_key_dict(schema): 102 | log.debug4("Invalid jtd: Not a dict with string keys") 103 | return False 104 | 105 | # Check for metadata and/or nullable keywords which any form can contain 106 | if not isinstance(schema.get("nullable", False), bool): 107 | log.debug4("Invalid jtd: Found non-bool 'nullable'") 108 | return False 109 | if not _is_string_key_dict(schema.get("metadata", {})): 110 | log.debug4("Invalid jtd: Found 'metadata' that is not a dict of strings") 111 | return False 112 | 113 | # Definitions (2.1) 114 | definitions = definitions or {} 115 | if is_root_schema: 116 | definitions = schema.get("definitions", {}) 117 | if not _is_string_key_dict(definitions): 118 | log.debug4("Invalid jtd: Found 'definitions' that is not a dict of strings") 119 | return False 120 | # TODO: Can definitions refer to _other_ definitions? The RFC is 121 | # ambiguous here, so I think it should _technically_ be possible, but 122 | # for our sake, we won't allow it for now. 123 | if any( 124 | not _is_valid_jtd_impl(val, valid_types) for val in definitions.values() 125 | ): 126 | log.debug4("Invalid jtd: Found 'definitions' value that is not valid jtd") 127 | return False 128 | elif "definitions" in schema: 129 | log.debug4("Found 'definitions' in non-root schema") 130 | return False 131 | 132 | # Get the set of keys in this schema with universal keys removed 133 | schema_keys = set(schema.keys()) - _SHARED_KEYS 134 | 135 | # Empty (2.2.1) 136 | if schema_keys == set(): 137 | return True 138 | 139 | # Ref (2.2.2) 140 | if schema_keys == {"ref"}: 141 | ref_val = schema["ref"] 142 | if not isinstance(ref_val, str) or ref_val not in definitions: 143 | log.debug4("Invalid jtd: Bad reference <%s>", ref_val) 144 | return False 145 | return True 146 | 147 | # Type (2.2.3) 148 | if schema_keys == {"type"}: 149 | type_val = schema["type"] 150 | if ( 151 | # All protobuf descriptors are "special" cases 152 | not isinstance(type_val, _descriptor.Descriptor) 153 | and 154 | # All non-descriptor types must be valid types 155 | (not isinstance(type_val, str) or (type_val not in valid_types)) 156 | ): 157 | log.debug4("Invalid jtd: Bad type <%s>", type_val) 158 | return False 159 | return True 160 | 161 | # Enum (2.2.4) 162 | if schema_keys == {"enum"}: 163 | enum_val = schema["enum"] 164 | if ( 165 | not isinstance(enum_val, list) # Must be a list 166 | or not enum_val # Must be non-empty 167 | or len(set(enum_val)) != len(enum_val) # Must have no duplicate entries 168 | ): 169 | log.debug4("Invalid jtd: Bad enum <%s>", enum_val) 170 | return False 171 | return True 172 | 173 | # Elements (2.2.5) 174 | if schema_keys == {"elements"}: 175 | elements_val = schema["elements"] 176 | if not _is_valid_jtd_impl(elements_val, valid_types, definitions): 177 | log.debug4("Invalid jtd: Bad elements <%s>", elements_val) 178 | return False 179 | return True 180 | 181 | # Properties (2.2.6) 182 | if "properties" in schema_keys or "optionalProperties" in schema_keys: 183 | properties_val = schema.get("properties", {}) 184 | opt_properties_val = schema.get("optionalProperties", {}) 185 | if ( 186 | # No extra keys beyond additionalProperties 187 | schema_keys - {"properties", "optionalProperties", "additionalProperties"} 188 | # additionalProperties must be a bool 189 | or not isinstance(schema.get("additionalProperties", False), bool) 190 | # String dict properties 191 | or not _is_string_key_dict(properties_val) 192 | # String dict optionalProperties 193 | or not _is_string_key_dict(opt_properties_val) 194 | # Non-empty 195 | or (not properties_val and not opt_properties_val) 196 | # No overlapping keys 197 | or set(properties_val.keys()).intersection(opt_properties_val.keys()) 198 | # Valid properties definitions 199 | or any( 200 | not _is_valid_jtd_impl(val, valid_types, definitions) 201 | for val in properties_val.values() 202 | ) 203 | # Valid optionalProperties definitions 204 | or any( 205 | not _is_valid_jtd_impl(val, valid_types, definitions) 206 | for val in opt_properties_val.values() 207 | ) 208 | ): 209 | log.debug4( 210 | "Invalid jtd: Bad properties <%s> / optionalProperties <%s>", 211 | properties_val, 212 | opt_properties_val, 213 | ) 214 | return False 215 | return True 216 | 217 | # Values (2.2.7) 218 | if schema_keys == {"values"}: 219 | values_val = schema["values"] 220 | if not _is_valid_jtd_impl(values_val, valid_types, definitions): 221 | log.debug4("Invalid jtd: Bad 'values' <%s>", values_val) 222 | return False 223 | return True 224 | 225 | # Discriminator (2.2.8) 226 | if schema_keys == {"discriminator", "mapping"}: 227 | discriminator_val = schema["discriminator"] 228 | mapping_val = schema["mapping"] 229 | nullable = schema.get("nullable", False) 230 | if ( 231 | # Discriminator is a string 232 | not isinstance(discriminator_val, str) 233 | # Mapping is a string dict 234 | or not _is_string_key_dict(mapping_val) 235 | # Mapping entries are valid JTD 236 | or any( 237 | not _is_valid_jtd_impl(val, valid_types, definitions) 238 | for val in mapping_val.values() 239 | ) 240 | # Mapping entries are of the "properties" form 241 | or any( 242 | "properties" not in val and "optionalProperties" not in val 243 | for val in mapping_val.values() 244 | ) 245 | # Mapping entry "nullable" matches discriminator "nullable" 246 | or any( 247 | val.get("nullable", False) != nullable for val in mapping_val.values() 248 | ) 249 | # Discriminator must not shadow properties in mapping elements 250 | or discriminator_val 251 | in set.union( 252 | *[ 253 | set(entry.get("properties", {}).keys()) 254 | for entry in mapping_val.values() 255 | ], 256 | *[ 257 | set(entry.get("optionalProperties", {}).keys()) 258 | for entry in mapping_val.values() 259 | ], 260 | ) 261 | ): 262 | log.debug4( 263 | "Invalid jtd: Bad discriminator <%s> / mapping <%s>", 264 | discriminator_val, 265 | mapping_val, 266 | ) 267 | return False 268 | return True 269 | 270 | # All other sets of keys are invalid 271 | log.debug4("Invalid jtd: Bad key set <%s>", schema_keys) 272 | return False 273 | 274 | 275 | def _validate_jtd_impl( 276 | obj: Any, 277 | schema: Dict[str, Any], 278 | type_validators: Dict[str, Callable[[Any], bool]], 279 | definitions: Optional[Dict[str, Any]] = None, 280 | *, 281 | is_root_schema: bool = False, 282 | ): 283 | """Recursive validation implementation""" 284 | 285 | # Pull out common definitions from the root that will be passed along 286 | # everywhere 287 | definitions = definitions or {} 288 | if is_root_schema: 289 | definitions = schema.get("definitions", {}) 290 | 291 | # Check to see if this schema is null and nullable 292 | if obj is None and schema.get("nullable", False): 293 | return True 294 | 295 | # Get the set of keys in this schema with universal keys removed 296 | schema_keys = set(schema.keys()) - _SHARED_KEYS 297 | 298 | # Empty (3.3.1) 299 | if not schema_keys: 300 | return True 301 | 302 | # Ref (3.3.2) 303 | if schema_keys == {"ref"}: 304 | ref_val = schema["ref"] 305 | if not _validate_jtd_impl( 306 | obj, definitions[ref_val], type_validators, definitions 307 | ): 308 | log.debug4("Invalid value <%s> or ref <%s>", obj, ref_val) 309 | return False 310 | return True 311 | 312 | # Type (3.3.3) 313 | if schema_keys == {"type"}: 314 | type_val = schema["type"] 315 | validator = type_validators.get(type_val) 316 | if not (validator is not None and validator(obj)): 317 | log.debug4("Invalid value <%s> for type <%s>", obj, type_val) 318 | return False 319 | return True 320 | 321 | # Enum (3.3.4) 322 | if schema_keys == {"enum"}: 323 | enum_vals = schema["enum"] 324 | if obj not in enum_vals: 325 | log.debug4("Invalid enum value <%s> for enum <%s>", obj, enum_vals) 326 | return False 327 | return True 328 | 329 | # Elements (3.3.5) 330 | if schema_keys == {"elements"}: 331 | element_schema = schema["elements"] 332 | if not isinstance(obj, list) or any( 333 | not _validate_jtd_impl(entry, element_schema, type_validators, definitions) 334 | for entry in obj 335 | ): 336 | log.debug4( 337 | "Invalid elements value <%s> for element schema <%s>", 338 | obj, 339 | element_schema, 340 | ) 341 | return False 342 | return True 343 | 344 | # Properties (3.3.6) 345 | if "properties" in schema_keys or "optionalProperties" in schema_keys: 346 | if not _is_string_key_dict(obj): 347 | log.debug4("Invalid properties <%s> is not a string key dict", obj) 348 | return False 349 | schema_properties = schema.get("properties", {}) 350 | schema_opt_properties = schema.get("optionalProperties", {}) 351 | if any( 352 | prop not in obj 353 | or not _validate_jtd_impl( 354 | obj[prop], 355 | prop_schema, 356 | type_validators, 357 | definitions, 358 | ) 359 | for prop, prop_schema in schema_properties.items() 360 | ): 361 | log.debug4( 362 | "Invalid properties <%s> for properties %s", obj, schema_properties 363 | ) 364 | return False 365 | if any( 366 | prop in obj 367 | and not _validate_jtd_impl( 368 | obj[prop], prop_schema, type_validators, definitions 369 | ) 370 | for prop, prop_schema in schema_opt_properties.items() 371 | ): 372 | log.debug4( 373 | "Invalid optional properties <%s> for optional properties %s", 374 | obj, 375 | schema_opt_properties, 376 | ) 377 | return False 378 | all_props = set.union( 379 | set(schema_properties.keys()), schema_opt_properties.keys() 380 | ) 381 | if ( 382 | not schema.get("additionalProperties", False) 383 | and set(obj.keys()) - all_props - _SHARED_KEYS 384 | ): 385 | log.debug4("Invalid additional properties in <%s> for %s", obj, schema) 386 | return False 387 | return True 388 | 389 | # Values (3.3.7) 390 | if schema_keys == {"values"}: 391 | value_schema = schema["values"] 392 | if not _is_string_key_dict(obj) or any( 393 | not _validate_jtd_impl(entry, value_schema, type_validators, definitions) 394 | for entry in obj.values() 395 | ): 396 | log.debug4("Invalid values <%s> for values schema <%s>", obj, value_schema) 397 | return False 398 | return True 399 | 400 | # Discriminator (3.3.8) 401 | if schema_keys == {"discriminator", "mapping"}: 402 | if not _is_string_key_dict(obj): 403 | log.debug4("Invalid discriminator <%s> which is not a string key dict", obj) 404 | return False 405 | schema_discriminator = schema["discriminator"] 406 | schema_mapping = schema["mapping"] 407 | discriminator_val = obj.get(schema_discriminator) 408 | if discriminator_val not in schema_mapping or not _validate_jtd_impl( 409 | {key: val for key, val in obj.items() if key != schema_discriminator}, 410 | schema_mapping[discriminator_val], 411 | type_validators, 412 | definitions, 413 | ): 414 | log.debug4("Invalid discriminator <%s> for schema %s", obj, schema) 415 | return False 416 | return True 417 | 418 | # Since the schema must be valid, we should never get here! 419 | raise ValueError(f"Programming Error: unhandled schema {schema}") 420 | -------------------------------------------------------------------------------- /tests/test_json_to_service.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for json_to_service functions 3 | """ 4 | # Standard 5 | from concurrent import futures 6 | from contextlib import contextmanager 7 | from typing import Type 8 | import os 9 | import types 10 | 11 | # Third Party 12 | import grpc 13 | import pytest 14 | import tls_test_tools 15 | 16 | # Local 17 | from py_to_proto import descriptor_to_message_class 18 | from py_to_proto.json_to_service import GRPCService, json_to_service 19 | from py_to_proto.jtd_to_proto import jtd_to_proto 20 | 21 | ## Helpers ##################################################################### 22 | 23 | 24 | @pytest.fixture 25 | def foo_message(temp_dpool): 26 | """Message foo fixture""" 27 | # google.protobuf.message.Message 28 | return descriptor_to_message_class( 29 | jtd_to_proto( 30 | "Foo", 31 | "foo.bar", 32 | { 33 | "properties": { 34 | "foo": {"type": "boolean"}, 35 | }, 36 | "optionalProperties": { 37 | "bar": {"type": "float32"}, 38 | }, 39 | }, 40 | descriptor_pool=temp_dpool, 41 | ) 42 | ) 43 | 44 | 45 | @pytest.fixture 46 | def bar_message(temp_dpool): 47 | """Message bar fixture""" 48 | # google.protobuf.message.Message 49 | return descriptor_to_message_class( 50 | jtd_to_proto( 51 | "Bar", 52 | "foo.bar", 53 | { 54 | "properties": { 55 | "boo": {"type": "int32"}, 56 | "baz": {"type": "boolean"}, 57 | } 58 | }, 59 | descriptor_pool=temp_dpool, 60 | ) 61 | ) 62 | 63 | 64 | @pytest.fixture 65 | def foo_service_json(): 66 | return { 67 | "service": { 68 | "rpcs": [ 69 | { 70 | "name": "FooPredict", 71 | "input_type": "foo.bar.Foo", 72 | "output_type": "foo.bar.Bar", 73 | } 74 | ] 75 | } 76 | } 77 | 78 | 79 | @pytest.fixture 80 | def foo_service(temp_dpool, foo_message, bar_message, foo_service_json): 81 | """Service descriptor fixture""" 82 | return json_to_service( 83 | package="foo.bar", 84 | name="FooService", 85 | json_service_def=foo_service_json, 86 | descriptor_pool=temp_dpool, 87 | ) 88 | 89 | 90 | @contextmanager 91 | def _test_server_client( 92 | grpc_service: GRPCService, 93 | servicer_impl_class: Type, 94 | ): 95 | # Boot the server 96 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=50)) 97 | grpc_service.registration_function(servicer_impl_class(), server) 98 | open_port = tls_test_tools.open_port() 99 | server.add_insecure_port(f"[::]:{open_port}") 100 | server.start() 101 | 102 | # Create the client-side connection 103 | chan = grpc.insecure_channel(f"localhost:{open_port}") 104 | my_stub = grpc_service.client_stub_class(chan) 105 | 106 | yield my_stub 107 | 108 | server.stop(grace=0) 109 | 110 | 111 | ## Tests ####################################################################### 112 | 113 | 114 | def test_json_to_service_descriptor(temp_dpool, foo_message, bar_message): 115 | """Ensure that json can be converted to service descriptor""" 116 | 117 | service_json = { 118 | "service": { 119 | "rpcs": [ 120 | { 121 | "name": "FooTrain", 122 | "input_type": "foo.bar.Foo", 123 | "output_type": "foo.bar.Bar", 124 | }, 125 | { 126 | "name": "FooPredict", 127 | "input_type": "foo.bar.Foo", 128 | "output_type": "foo.bar.Foo", 129 | }, 130 | ] 131 | } 132 | } 133 | # _descriptor.ServiceDescriptor 134 | service = json_to_service( 135 | package="foo.bar", 136 | name="FooService", 137 | json_service_def=service_json, 138 | descriptor_pool=temp_dpool, 139 | ) 140 | # Validate message naming 141 | assert service.descriptor.name == "FooService" 142 | assert len(service.descriptor.methods) == 2 143 | 144 | 145 | def test_duplicate_services_are_okay(temp_dpool, foo_message, bar_message): 146 | """Ensure that json can be converted to service descriptor multiple times""" 147 | 148 | service_json = { 149 | "service": { 150 | "rpcs": [ 151 | { 152 | "name": "FooTrain", 153 | "input_type": "foo.bar.Foo", 154 | "output_type": "foo.bar.Bar", 155 | }, 156 | { 157 | "name": "FooPredict", 158 | "input_type": "foo.bar.Foo", 159 | "output_type": "foo.bar.Foo", 160 | }, 161 | ] 162 | } 163 | } 164 | # _descriptor.ServiceDescriptor 165 | service = json_to_service( 166 | package="foo.bar", 167 | name="FooService", 168 | json_service_def=service_json, 169 | descriptor_pool=temp_dpool, 170 | ) 171 | 172 | another_service = json_to_service( 173 | package="foo.bar", 174 | name="FooService", 175 | json_service_def=service_json, 176 | descriptor_pool=temp_dpool, 177 | ) 178 | assert service.descriptor == another_service.descriptor 179 | 180 | 181 | ORIGINAL_SERVICE = { 182 | "service": { 183 | "rpcs": [ 184 | { 185 | "name": "FooTrain", 186 | "input_type": "foo.bar.Foo", 187 | "output_type": "foo.bar.Bar", 188 | } 189 | ] 190 | } 191 | } 192 | INVALID_DUPLICATE_SERVICES = [ 193 | { 194 | "service": { 195 | "rpcs": [ 196 | { 197 | "name": "FooPredict", # Different method name 198 | "input_type": "foo.bar.Foo", 199 | "output_type": "foo.bar.Foo", 200 | } 201 | ] 202 | } 203 | }, 204 | { 205 | "service": { 206 | "rpcs": [ 207 | { 208 | "name": "FooTrain", 209 | "input_type": "foo.bar.Bar", # Different input 210 | "output_type": "foo.bar.Bar", 211 | } 212 | ] 213 | } 214 | }, 215 | { 216 | "service": { 217 | "rpcs": [ 218 | { 219 | "name": "FooTrain", 220 | "input_type": "foo.bar.Foo", 221 | "output_type": "foo.bar.Foo", # Different output 222 | } 223 | ] 224 | } 225 | }, 226 | { 227 | "service": { 228 | "rpcs": [ 229 | { 230 | "name": "FooTrain", 231 | "input_type": "foo.bar.Foo", 232 | "output_type": "foo.bar.Bar", 233 | "client_streaming": True, # Different client streaming 234 | } 235 | ] 236 | } 237 | }, 238 | { 239 | "service": { 240 | "rpcs": [ 241 | { 242 | "name": "FooTrain", 243 | "input_type": "foo.bar.Foo", 244 | "output_type": "foo.bar.Bar", 245 | "server_streaming": True, # Different server streaming 246 | } 247 | ] 248 | } 249 | }, 250 | ] 251 | 252 | 253 | @pytest.mark.parametrize("schema", INVALID_DUPLICATE_SERVICES) 254 | def test_multiple_services_with_the_same_name_are_not_okay( 255 | schema, temp_dpool, foo_message, bar_message 256 | ): 257 | """Ensure that json can be converted to service descriptor""" 258 | 259 | json_to_service( 260 | package="foo.bar", 261 | name="FooService", 262 | json_service_def=ORIGINAL_SERVICE, 263 | descriptor_pool=temp_dpool, 264 | ) 265 | 266 | with pytest.raises(TypeError): 267 | json_to_service( 268 | package="foo.bar", 269 | name="FooService", 270 | json_service_def=schema, 271 | descriptor_pool=temp_dpool, 272 | ) 273 | 274 | 275 | def test_json_to_service_input_validation(temp_dpool, foo_message): 276 | """Make sure that an error is raised if the service definition is invalid""" 277 | # This def is missing the `input_type` field 278 | service_json = { 279 | "service": { 280 | "rpcs": [ 281 | { 282 | "name": "FooPredict", 283 | "output_type": "foo.bar.Foo", 284 | } 285 | ] 286 | } 287 | } 288 | with pytest.raises(ValueError) as excinfo: 289 | json_to_service( 290 | package="foo.bar", 291 | name="FooService", 292 | json_service_def=service_json, 293 | descriptor_pool=temp_dpool, 294 | ) 295 | assert "Invalid service json" in str(excinfo.value) 296 | 297 | 298 | def test_service_descriptor_to_service(foo_service): 299 | """Ensure that service class can be created from service descriptor""" 300 | ServiceClass = foo_service.service_class 301 | 302 | assert hasattr(ServiceClass, "FooPredict") 303 | assert ServiceClass.__name__ == foo_service.descriptor.name 304 | 305 | 306 | def test_services_can_be_written_to_protobuf_files(foo_service, tmp_path): 307 | """Ensure that service class can be created from service descriptor""" 308 | ServiceClass = foo_service.service_class 309 | 310 | assert hasattr(ServiceClass, "to_proto_file") 311 | assert hasattr(ServiceClass, "write_proto_file") 312 | 313 | tempdir = str(tmp_path) 314 | ServiceClass.write_proto_file(tempdir) 315 | assert "fooservice.proto" in os.listdir(tempdir) 316 | with open(os.path.join(tempdir, "fooservice.proto"), "r") as f: 317 | assert "service FooService {" in f.read() 318 | 319 | 320 | def test_service_descriptor_to_client_stub(foo_service): 321 | """Ensure that client stub can be created from service descriptor""" 322 | stub_class = foo_service.client_stub_class 323 | assert hasattr(stub_class(grpc.insecure_channel("localhost:9000")), "FooPredict") 324 | assert stub_class.__name__ == "FooServiceStub" 325 | 326 | 327 | def test_service_descriptor_to_registration_function(foo_service): 328 | """Ensure that server registration function can be created from service descriptor""" 329 | 330 | registration_fn = foo_service.registration_function 331 | assert isinstance(registration_fn, types.FunctionType) 332 | 333 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=50)) 334 | service_class = foo_service.service_class 335 | 336 | registration_fn(service_class(), server) 337 | 338 | # GORP 339 | assert ( 340 | "/foo.bar.FooService/FooPredict" 341 | in server._state.generic_handlers[0]._method_handlers 342 | ) 343 | 344 | 345 | def test_end_to_end_unary_unary_integration( 346 | foo_message, bar_message, foo_service, temp_dpool 347 | ): 348 | """Test a full grpc service integration""" 349 | # Define and start a gRPC service 350 | class Servicer(foo_service.service_class): 351 | """gRPC Service Impl""" 352 | 353 | def FooPredict(self, request, context): 354 | # Test that the `optionalProperty` "bar" of the request can be checked for existence 355 | if request.foo: 356 | assert request.HasField("bar") 357 | else: 358 | assert not request.HasField("bar") 359 | return bar_message(boo=42, baz=True) 360 | 361 | with _test_server_client(foo_service, Servicer) as client: 362 | # nb: we'll set "foo" to the existence of "bar" to put asserts in the request handler 363 | input = foo_message(foo=True, bar=-9000) 364 | 365 | # Make a gRPC call 366 | response = client.FooPredict(request=input) 367 | assert isinstance(response, bar_message) 368 | assert response.boo == 42 369 | assert response.baz 370 | 371 | # Test that we can not set `bar` and correctly check that it was not set on the server side 372 | input = foo_message(foo=False) 373 | response = client.FooPredict(request=input) 374 | assert isinstance(response, bar_message) 375 | 376 | 377 | def test_end_to_end_server_streaming_integration(foo_message, bar_message, temp_dpool): 378 | service_json = { 379 | "service": { 380 | "rpcs": [ 381 | { 382 | "name": "FooPredict", 383 | "input_type": "foo.bar.Foo", 384 | "output_type": "foo.bar.Bar", 385 | "server_streaming": True, 386 | } 387 | ] 388 | } 389 | } 390 | service = json_to_service( 391 | package="foo.bar", 392 | name="FooService", 393 | json_service_def=service_json, 394 | descriptor_pool=temp_dpool, 395 | ) 396 | 397 | class Servicer(service.service_class): 398 | """gRPC Service Impl""" 399 | 400 | def FooPredict(self, request, context): 401 | return iter(map(lambda i: bar_message(boo=i, baz=True), range(100))) 402 | 403 | with _test_server_client(service, Servicer) as client: 404 | for i, bar in enumerate(client.FooPredict(request=foo_message())): 405 | assert bar.boo == i 406 | assert i == 99 407 | 408 | 409 | def test_end_to_end_client_streaming_integration(foo_message, bar_message, temp_dpool): 410 | service_json = { 411 | "service": { 412 | "rpcs": [ 413 | { 414 | "name": "FooPredict", 415 | "input_type": "foo.bar.Foo", 416 | "client_streaming": True, 417 | "output_type": "foo.bar.Bar", 418 | } 419 | ] 420 | } 421 | } 422 | service = json_to_service( 423 | package="foo.bar", 424 | name="FooService", 425 | json_service_def=service_json, 426 | descriptor_pool=temp_dpool, 427 | ) 428 | 429 | class Servicer(service.service_class): 430 | """gRPC Service Impl""" 431 | 432 | def FooPredict(self, request_stream, context): 433 | return bar_message(boo=int(sum(i.bar for i in request_stream)), baz=True) 434 | 435 | with _test_server_client(service, Servicer) as client: 436 | input = iter(map(lambda i: foo_message(foo=True, bar=i), range(100))) 437 | 438 | # Make a gRPC call 439 | response = client.FooPredict(input) 440 | assert response.boo == 4950 # sum of range(100) 441 | 442 | 443 | def test_end_to_end_client_and_server_streaming_integration( 444 | foo_message, bar_message, temp_dpool 445 | ): 446 | service_json = { 447 | "service": { 448 | "rpcs": [ 449 | { 450 | "name": "FooPredict", 451 | "input_type": "foo.bar.Foo", 452 | "client_streaming": True, 453 | "output_type": "foo.bar.Bar", 454 | "server_streaming": True, 455 | } 456 | ] 457 | } 458 | } 459 | service = json_to_service( 460 | package="foo.bar", 461 | name="FooService", 462 | json_service_def=service_json, 463 | descriptor_pool=temp_dpool, 464 | ) 465 | 466 | class Servicer(service.service_class): 467 | """gRPC Service Impl""" 468 | 469 | def FooPredict(self, request_stream, context): 470 | count = sum(i.bar for i in request_stream) 471 | return iter( 472 | map(lambda i: bar_message(boo=int(count), baz=True), range(100)) 473 | ) 474 | 475 | with _test_server_client(service, Servicer) as client: 476 | input = iter(map(lambda i: foo_message(foo=True, bar=i), range(100))) 477 | 478 | # Make a gRPC call 479 | for i, bar in enumerate(client.FooPredict(input)): 480 | assert bar.boo == 4950 # sum of range(100) 481 | assert i == 99 482 | 483 | 484 | def test_multiple_rpcs_with_streaming(foo_message, bar_message, temp_dpool): 485 | """ensuring that everything works with more than one endpoint""" 486 | service_json = { 487 | "service": { 488 | "rpcs": [ 489 | { 490 | "name": "BarPredict", 491 | "input_type": "foo.bar.Foo", 492 | "client_streaming": True, 493 | "output_type": "foo.bar.Bar", 494 | "server_streaming": True, 495 | }, 496 | { 497 | "name": "FooPredict", 498 | "input_type": "foo.bar.Foo", 499 | "client_streaming": True, 500 | "output_type": "foo.bar.Foo", 501 | "server_streaming": True, 502 | }, 503 | ] 504 | } 505 | } 506 | service = json_to_service( 507 | package="foo.bar", 508 | name="FooService", 509 | json_service_def=service_json, 510 | descriptor_pool=temp_dpool, 511 | ) 512 | 513 | class Servicer(service.service_class): 514 | """gRPC Service Impl""" 515 | 516 | def BarPredict(self, request_stream, context): 517 | count = sum(i.bar for i in request_stream) 518 | return iter( 519 | map(lambda i: bar_message(boo=int(count), baz=True), range(100)) 520 | ) 521 | 522 | def FooPredict(self, request_stream, context): 523 | count = sum(i.bar for i in request_stream) 524 | return iter( 525 | map(lambda i: foo_message(foo=True, bar=float(count)), range(10)) 526 | ) 527 | 528 | with _test_server_client(service, Servicer) as client: 529 | request = iter(map(lambda i: foo_message(foo=True, bar=i), range(100))) 530 | 531 | # Make a gRPC call 532 | for i, bar in enumerate(client.BarPredict(request)): 533 | assert bar.boo == 4950 # sum of range(100) 534 | assert i == 99 535 | 536 | request_2 = iter(map(lambda i: foo_message(foo=True, bar=i), range(100))) 537 | for i, foo in enumerate(client.FooPredict(request_2)): 538 | assert foo.bar == 4950.0 # sum of range(100) 539 | assert i == 9 540 | -------------------------------------------------------------------------------- /py_to_proto/converter_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | This base class provides the abstract interface that needs to be implemented to 3 | convert from some schema format into a protobuf descriptor. It also implements 4 | the common conversion scaffolding that all converters will use to create the 5 | descriptor. 6 | """ 7 | 8 | # Standard 9 | from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union 10 | import abc 11 | import copy 12 | 13 | # Third Party 14 | from google.protobuf import descriptor as _descriptor 15 | from google.protobuf import descriptor_pb2 16 | from google.protobuf import descriptor_pool as _descriptor_pool 17 | from google.protobuf import struct_pb2 18 | 19 | # First Party 20 | import alog 21 | 22 | # Local 23 | from .utils import safe_add_fd_to_pool, to_upper_camel 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | log = alog.use_channel("2PCVRT") 29 | 30 | 31 | # Top level descriptor types 32 | _DescriptorTypes = (_descriptor.Descriptor, _descriptor.EnumDescriptor) 33 | _DescriptorTypesUnion = Union[_descriptor.Descriptor, _descriptor.EnumDescriptor] 34 | 35 | 36 | class ConverterBase(Generic[T], abc.ABC): 37 | __doc__ = __doc__ 38 | 39 | # Types that can be returned from _convert. This is exposed as public so 40 | # that derived converters can reference it 41 | ConvertOutputTypes = Union[ 42 | # Concrete type 43 | int, 44 | # Message descriptor reference 45 | _descriptor.Descriptor, 46 | # Enum descriptor reference 47 | _descriptor.EnumDescriptor, 48 | # Nested message 49 | descriptor_pb2.DescriptorProto, 50 | # Nested enum 51 | descriptor_pb2.EnumDescriptorProto, 52 | ] 53 | 54 | def __init__( 55 | self, 56 | name: str, 57 | package: str, 58 | source_schema: T, 59 | type_mapping: Dict[Any, Union[int, _descriptor.Descriptor]], 60 | validate: bool, 61 | descriptor_pool: Optional[_descriptor_pool.DescriptorPool], 62 | ): 63 | """This class performs its work on initialization by invoking the 64 | abstract methods that the child class will implement. 65 | 66 | Args: 67 | name (str) 68 | The name for the top-level message or enum descriptor 69 | package (str) 70 | The proto package name to use for this object 71 | source_schema (T) 72 | The source schema object for the derived converter 73 | type_mapping (Dict[Any, Union[int, _descriptor.Descriptor]]) 74 | A mapping from how types are represented in T to the protobuf 75 | type enum and/or Descriptor to be used to represent the type in 76 | proto. 77 | validate (bool) 78 | Whether or not to perform validation before attempting 79 | conversion 80 | descriptor_pool (Optional[_descriptor_pool.DescriptorPool]) 81 | An explicit descriptor pool to use for the new descriptor 82 | """ 83 | # Set up the shared members for this converter 84 | self.type_mapping = type_mapping 85 | self.package = package 86 | self.imports = set() 87 | 88 | # Perform validation if requested 89 | if validate: 90 | log.debug2("Validating") 91 | if not self.validate(source_schema): 92 | raise ValueError(f"Invalid Schema: {source_schema}") 93 | 94 | # Figure out which descriptor pool to use 95 | if descriptor_pool is None: 96 | log.debug2("Using default descriptor pool") 97 | descriptor_pool = _descriptor_pool.Default() 98 | self.descriptor_pool = descriptor_pool 99 | 100 | # Perform the recursive conversion to update the descriptors and enums in 101 | # place 102 | log.debug("Performing conversion") 103 | descriptor_proto = self._convert(entry=source_schema, name=name) 104 | proto_kwargs = {} 105 | is_enum = False 106 | if isinstance(descriptor_proto, descriptor_pb2.DescriptorProto): 107 | proto_kwargs["message_type"] = [descriptor_proto] 108 | elif isinstance(descriptor_proto, descriptor_pb2.EnumDescriptorProto): 109 | is_enum = True 110 | proto_kwargs["enum_type"] = [descriptor_proto] 111 | else: 112 | raise ValueError("Only messages and enums are supported") 113 | 114 | # Create the FileDescriptorProto with all messages 115 | log.debug("Creating FileDescriptorProto") 116 | fd_proto = descriptor_pb2.FileDescriptorProto( 117 | name=f"{package}.{name.lower()}.proto", 118 | package=package, 119 | syntax="proto3", 120 | dependency=sorted(list(self.imports)), 121 | **proto_kwargs, 122 | ) 123 | log.debug4("Full FileDescriptorProto:\n%s", fd_proto) 124 | 125 | # Add the new file descriptor to the pool 126 | log.debug("Adding Descriptors to DescriptorPool") 127 | safe_add_fd_to_pool(fd_proto, self.descriptor_pool) 128 | 129 | # Return the descriptor for the top-level message 130 | fullname = name if not package else ".".join([package, name]) 131 | if is_enum: 132 | self.descriptor = descriptor_pool.FindEnumTypeByName(fullname) 133 | else: 134 | self.descriptor = descriptor_pool.FindMessageTypeByName(fullname) 135 | 136 | ## Abstract Interface ###################################################### 137 | 138 | @abc.abstractmethod 139 | def validate(self, source_schema: T) -> bool: 140 | """Perform preprocess validation of the input""" 141 | 142 | ## Types ## 143 | 144 | @abc.abstractmethod 145 | def get_concrete_type(self, entry: Any) -> Any: 146 | """If this is a concrete type, get the type map key for it""" 147 | 148 | ## Maps ## 149 | 150 | @abc.abstractmethod 151 | def get_map_key_val_types( 152 | self, 153 | entry: Any, 154 | ) -> Optional[Tuple[int, ConvertOutputTypes]]: 155 | """Get the key and value types for a given map type""" 156 | 157 | ## Enums ## 158 | 159 | @abc.abstractmethod 160 | def get_enum_vals(self, entry: Any) -> Optional[Iterable[Tuple[str, int]]]: 161 | """Get the ordered list of enum name -> number mappings if this entry is 162 | an enum 163 | 164 | NOTE: If any values appear multiple times, this implies an alias 165 | 166 | NOTE 2: All names must be unique 167 | """ 168 | 169 | ## Messages ## 170 | 171 | @abc.abstractmethod 172 | def get_message_fields(self, entry: Any) -> Optional[Iterable[Tuple[str, Any]]]: 173 | """Get the mapping of names to type-specific field descriptors if this 174 | entry is a message 175 | """ 176 | 177 | @abc.abstractmethod 178 | def has_additional_fields(self, entry: Any) -> bool: 179 | """Check whether the given entry expects to support arbitrary key/val 180 | additional properties 181 | """ 182 | 183 | @abc.abstractmethod 184 | def get_optional_field_names(self, entry: Any) -> List[str]: 185 | """Get the names of any fields which are explicitly marked 'optional'""" 186 | 187 | ## Fields ## 188 | 189 | @abc.abstractmethod 190 | def get_field_number(self, num_fields: int, field_def: Any) -> int: 191 | """From the given field definition and index, get the proto field number""" 192 | 193 | @abc.abstractmethod 194 | def get_oneof_fields(self, field_def: Any) -> Optional[Iterable[Tuple[str, Any]]]: 195 | """If the given field is a oneof, return an iterable of the sub-field 196 | definitions 197 | """ 198 | 199 | @abc.abstractmethod 200 | def get_oneof_name(self, field_def: Any) -> str: 201 | """For an identified oneof field def, get the name""" 202 | 203 | @abc.abstractmethod 204 | def get_field_type(self, field_def: Any) -> Any: 205 | """Get the type of the field. The definition of type here will be 206 | specific to the converter (e.g. string for JTD, py type for dataclass) 207 | """ 208 | 209 | @abc.abstractmethod 210 | def is_repeated_field(self, field_def: Any) -> bool: 211 | """Determine if the given field def is repeated""" 212 | 213 | ## Implementation Details ################################################## 214 | 215 | def get_descriptor(self, entry: Any) -> Optional[_DescriptorTypesUnion]: 216 | """Given an entry, try to get a pre-existing descriptor from it. Child 217 | classes may overwrite this for alternate converters that have other 218 | known ways of getting a descriptor beyond these basics. 219 | """ 220 | if isinstance(entry, _DescriptorTypes): 221 | return entry 222 | descriptor_attr = getattr(entry, "DESCRIPTOR", None) 223 | if descriptor_attr and isinstance(descriptor_attr, _DescriptorTypes): 224 | return descriptor_attr 225 | return None 226 | 227 | def _add_descriptor_imports(self, descriptor: _DescriptorTypesUnion): 228 | """Helper to add the descriptor's file to the required imports""" 229 | import_file = descriptor.file.name 230 | log.debug3("Adding import file %s", import_file) 231 | 232 | # If the referenced descriptor lives in a different descriptor pool, we 233 | # need to copy it over to the target pool 234 | if descriptor.file.pool != self.descriptor_pool: 235 | log.debug2("Copying descriptor file %s to pool", import_file) 236 | fd_proto = descriptor_pb2.FileDescriptorProto() 237 | descriptor.file.CopyToProto(fd_proto) 238 | safe_add_fd_to_pool(fd_proto, self.descriptor_pool) 239 | self.imports.add(import_file) 240 | 241 | @staticmethod 242 | def _get_field_type_name(field_type: Any, field_name: str) -> str: 243 | """If the nested field definition is a type (a class), the expectation 244 | is that the nested object will have the same name as the class itself, 245 | otherwise we use the field name as the implicit name for nested objects. 246 | """ 247 | if isinstance(field_type, type): 248 | return field_type.__name__ 249 | return field_name 250 | 251 | def _convert(self, entry: Any, name: str) -> ConvertOutputTypes: 252 | """This is the core recursive implementation detail function that does 253 | the common conversion logic for all converters. 254 | """ 255 | 256 | # Handle concrete types 257 | concrete_type = self.get_concrete_type(entry) 258 | if concrete_type: 259 | log.debug2("Handling concrete type: %s", concrete_type) 260 | return self._convert_concrete_type(concrete_type) 261 | 262 | # Handle Dicts 263 | map_info = self.get_map_key_val_types(entry) 264 | if map_info: 265 | log.debug2("Handling map type: %s", entry) 266 | return self._convert_map(name, *map_info) 267 | 268 | # Handle enums 269 | enum_entries = self.get_enum_vals(entry) 270 | if enum_entries is not None: 271 | log.debug2("Handling Enum: %s", entry) 272 | return self._convert_enum(name, entry, enum_entries) 273 | 274 | # Handle messages 275 | # 276 | # Returns: descriptor_pb2.DescriptorProto 277 | message_fields = self.get_message_fields(entry) 278 | if message_fields is not None: 279 | log.debug2("Handling Message") 280 | return self._convert_message(name, entry, message_fields) 281 | 282 | # We should never get here! 283 | raise ValueError(f"Got unsupported entry type {entry}") 284 | 285 | def _convert_concrete_type( 286 | self, concrete_type: Any 287 | ) -> Union[int, _descriptor.Descriptor, _descriptor.EnumDescriptor]: 288 | """Perform the common conversion for an extracted concrete type""" 289 | entry_type = self.type_mapping.get(concrete_type, concrete_type) 290 | proto_type_descriptor = None 291 | descriptor_ref = self.get_descriptor(entry_type) 292 | if descriptor_ref is not None: 293 | proto_type_descriptor = descriptor_ref 294 | else: 295 | if concrete_type not in self.type_mapping: 296 | raise ValueError(f"Invalid type specifier: {concrete_type}") 297 | proto_type_val = self.type_mapping[concrete_type] 298 | proto_type_descriptor = getattr(proto_type_val, "DESCRIPTOR", None) 299 | if proto_type_descriptor is None: 300 | if not isinstance(proto_type_val, int): 301 | raise ValueError( 302 | "All proto_type_map values must be Descriptors or int" 303 | ) 304 | proto_type_descriptor = proto_type_val 305 | 306 | # If this is a non-primitive type, make sure any import files are added 307 | if isinstance(proto_type_descriptor, _DescriptorTypes): 308 | self._add_descriptor_imports(proto_type_descriptor) 309 | log.debug3("Returning type %s", proto_type_descriptor) 310 | return proto_type_descriptor 311 | 312 | def _convert_map( 313 | self, 314 | name: str, 315 | key_type: int, 316 | val_type: ConvertOutputTypes, 317 | ) -> descriptor_pb2.DescriptorProto: 318 | """Handle map conversion 319 | 320 | If this is a Dict, handle it by making the "special" submessage and then 321 | making this field's type be that submessage 322 | 323 | Maps in descriptors are implemented in a _funky_ way. The map syntax 324 | map the_map = 1; 325 | 326 | gets converted to a repeated message as follows: 327 | option map_entry = true; 328 | optional KeyType key = 1; 329 | optional ValType value = 2; 330 | 331 | CITE: https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/descriptor.cc#L7512 332 | """ 333 | nested_cls_name = f"{to_upper_camel(name)}Entry" 334 | log.debug3("Making nested map<> class: %s", nested_cls_name) 335 | key_field = descriptor_pb2.FieldDescriptorProto( 336 | name="key", 337 | type=key_type, 338 | number=1, 339 | ) 340 | val_field_kwargs = {} 341 | msg_descriptor_kwargs = {} 342 | if isinstance(val_type, int): 343 | val_field_kwargs = {"type": val_type} 344 | elif isinstance(val_type, _descriptor.EnumDescriptor): 345 | val_field_kwargs = { 346 | "type": _descriptor.FieldDescriptor.TYPE_ENUM, 347 | "type_name": val_type.name, 348 | } 349 | elif isinstance(val_type, _descriptor.Descriptor): 350 | val_field_kwargs = { 351 | "type": _descriptor.FieldDescriptor.TYPE_MESSAGE, 352 | "type_name": val_type.name, 353 | } 354 | elif isinstance(val_type, descriptor_pb2.EnumDescriptorProto): 355 | val_field_kwargs = { 356 | "type": _descriptor.FieldDescriptor.TYPE_ENUM, 357 | "type_name": val_type.name, 358 | } 359 | msg_descriptor_kwargs["enum_type"] = [val_type] 360 | elif isinstance(val_type, descriptor_pb2.DescriptorProto): 361 | val_field_kwargs = { 362 | "type": _descriptor.FieldDescriptor.TYPE_MESSAGE, 363 | "type_name": val_type.name, 364 | } 365 | msg_descriptor_kwargs["nested_type"] = [val_type] 366 | assert ( 367 | val_field_kwargs 368 | ), f"Programming Error: Got unhandled map value type: {val_type}" 369 | val_field = descriptor_pb2.FieldDescriptorProto( 370 | name="value", 371 | number=2, 372 | **val_field_kwargs, 373 | ) 374 | nested = descriptor_pb2.DescriptorProto( 375 | name=nested_cls_name, 376 | field=[key_field, val_field], 377 | options=descriptor_pb2.MessageOptions(map_entry=True), 378 | **msg_descriptor_kwargs, 379 | ) 380 | return nested 381 | 382 | def _convert_enum( 383 | self, name: str, entry: Any, enum_entries: Iterable[Tuple[str, int]] 384 | ) -> descriptor_pb2.EnumDescriptorProto: 385 | """Convert nested enums""" 386 | enum_name = self._get_field_type_name(entry, to_upper_camel(name)) 387 | log.debug("Enum name: %s", enum_name) 388 | has_aliases = len(set([entry[1] for entry in enum_entries])) != len( 389 | enum_entries 390 | ) 391 | options = descriptor_pb2.EnumOptions(allow_alias=has_aliases) 392 | enum_proto = descriptor_pb2.EnumDescriptorProto( 393 | name=enum_name, 394 | value=[ 395 | descriptor_pb2.EnumValueDescriptorProto( 396 | name=entry[0], 397 | number=entry[1], 398 | ) 399 | for entry in enum_entries 400 | ], 401 | options=options, 402 | ) 403 | return enum_proto 404 | 405 | def _convert_message( 406 | self, 407 | name: str, 408 | entry: Any, 409 | message_fields, 410 | ) -> descriptor_pb2.DescriptorProto: 411 | """Convert a nested message""" 412 | field_descriptors = [] 413 | nested_enums = [] 414 | nested_messages = [] 415 | nested_oneofs = [] 416 | message_name = to_upper_camel(name) 417 | log.debug("Message name: %s", message_name) 418 | 419 | for field_name, field_def in message_fields: 420 | field_number = self.get_field_number(len(field_descriptors), field_def) 421 | log.debug2( 422 | "Handling field [%s.%s] (%d)", 423 | message_name, 424 | field_name, 425 | field_number, 426 | ) 427 | 428 | # Get the field's number 429 | field_kwargs = { 430 | "name": field_name, 431 | "number": field_number, 432 | "label": _descriptor.FieldDescriptor.LABEL_OPTIONAL, 433 | } 434 | 435 | # Check to see if the field is repeated 436 | if self.is_repeated_field(field_def): 437 | log.debug3("Handling repeated field %s", field_name) 438 | field_kwargs["label"] = _descriptor.FieldDescriptor.LABEL_REPEATED 439 | 440 | # If the field is a oneof, handle it as such 441 | oneof_fields = self.get_oneof_fields(field_def) 442 | if oneof_fields: 443 | log.debug2("Handling oneof field %s", field_name) 444 | nested_results = [ 445 | ( 446 | self._convert( 447 | entry=oneof_field_def, 448 | name=self._get_field_type_name( 449 | oneof_field_def, oneof_field_name 450 | ), 451 | ), 452 | { 453 | "oneof_index": len(nested_oneofs), 454 | "number": self.get_field_number( 455 | len(field_descriptors) + oneof_field_idx, 456 | oneof_field_def, 457 | ), 458 | "name": oneof_field_name.lower(), 459 | }, 460 | ) 461 | for oneof_field_idx, ( 462 | oneof_field_name, 463 | oneof_field_def, 464 | ) in enumerate(oneof_fields) 465 | ] 466 | # Add the name for this oneof 467 | nested_oneofs.append( 468 | descriptor_pb2.OneofDescriptorProto( 469 | name=self.get_oneof_name(field_def) 470 | ) 471 | ) 472 | 473 | # Otherwise, it's a "regular" field, so just recurse on the type 474 | else: 475 | log.debug3("Handling non-oneof field: %s", field_name) 476 | # If the nested field definition is a type (a class), the 477 | # expectation is that the nested object will have the same name 478 | # as the class itself, otherwise we use the field name as the 479 | # implicit name for nested objects. 480 | field_type = self.get_field_type(field_def) 481 | nested_name = self._get_field_type_name(field_type, field_name) 482 | nested_result = self._convert(entry=field_type, name=nested_name) 483 | nested_results = [(nested_result, {})] 484 | 485 | # For all nested fields produced by either the onoof logic or 486 | # the single-field logic, construct a FieldDescriptor and add it 487 | # to the message descriptor. 488 | for nested, extra_kwargs in nested_results: 489 | nested_field_kwargs = copy.copy(field_kwargs) 490 | nested_field_kwargs.update(extra_kwargs) 491 | 492 | # If the result is an int, it's a type value 493 | if isinstance(nested, int): 494 | nested_field_kwargs["type"] = nested 495 | 496 | # If the result is an enum descriptor ref, it's an external 497 | # enum 498 | elif isinstance(nested, _descriptor.EnumDescriptor): 499 | nested_field_kwargs["type"] = _descriptor.FieldDescriptor.TYPE_ENUM 500 | nested_field_kwargs["type_name"] = nested.full_name 501 | 502 | # If the result is a message descriptor ref, it's an 503 | # external message 504 | elif isinstance(nested, _descriptor.Descriptor): 505 | nested_field_kwargs[ 506 | "type" 507 | ] = _descriptor.FieldDescriptor.TYPE_MESSAGE 508 | nested_field_kwargs["type_name"] = nested.full_name 509 | 510 | # If the result is an enum proto, it's a nested enum 511 | elif isinstance(nested, descriptor_pb2.EnumDescriptorProto): 512 | log.debug3("Adding nested enum %s", nested.name) 513 | nested_field_kwargs["type"] = _descriptor.FieldDescriptor.TYPE_ENUM 514 | nested_field_kwargs["type_name"] = nested.name 515 | nested_enums.append(nested) 516 | 517 | # If the result is a message proto, it's a nested message 518 | elif isinstance(nested, descriptor_pb2.DescriptorProto): 519 | log.debug3("Adding nested message %s", nested.name) 520 | nested_field_kwargs[ 521 | "type" 522 | ] = _descriptor.FieldDescriptor.TYPE_MESSAGE 523 | nested_field_kwargs["type_name"] = nested.name 524 | nested_messages.append(nested) 525 | 526 | # If the message has map_entry set, we need to indicate that 527 | # it's repeated 528 | if nested.options.map_entry: 529 | nested_field_kwargs[ 530 | "label" 531 | ] = _descriptor.FieldDescriptor.LABEL_REPEATED 532 | 533 | # If the nested map entry itself has nested types or enums, 534 | # they need to be moved up to this message 535 | while nested.nested_type: 536 | nested_type = nested.nested_type.pop() 537 | plain_name = nested_type.name 538 | nested_name = to_upper_camel( 539 | "_".join([field_name, plain_name]) 540 | ) 541 | nested_type.MergeFrom( 542 | descriptor_pb2.DescriptorProto(name=nested_name) 543 | ) 544 | for field in nested.field: 545 | if field.type_name == plain_name: 546 | field.MergeFrom( 547 | descriptor_pb2.FieldDescriptorProto( 548 | type_name=nested_name 549 | ) 550 | ) 551 | nested_messages.append(nested_type) 552 | while nested.enum_type: 553 | nested_enum = nested.enum_type.pop() 554 | plain_name = nested_enum.name 555 | nested_name = to_upper_camel( 556 | "_".join([field_name, plain_name]) 557 | ) 558 | nested_enum.MergeFrom( 559 | descriptor_pb2.EnumDescriptorProto(name=nested_name) 560 | ) 561 | for field in nested.field: 562 | if field.type_name == plain_name: 563 | field.MergeFrom( 564 | descriptor_pb2.FieldDescriptorProto( 565 | type_name=nested_name 566 | ) 567 | ) 568 | nested_enums.append(nested_enum) 569 | 570 | # Create the field descriptor 571 | field_descriptors.append( 572 | descriptor_pb2.FieldDescriptorProto(**nested_field_kwargs) 573 | ) 574 | 575 | # If additional keys/vals allowed, add a 'special' field for this. 576 | # This is one place where there's not a good mapping between some 577 | # schema types and proto since proto does not allow for arbitrary 578 | # mappings _in addition_ to specific keys. Instead, there needs to 579 | # be a special Struct field to hold these additional fields. 580 | if self.has_additional_fields(entry): 581 | if "additionalProperties" in [field.name for field in field_descriptors]: 582 | raise ValueError( 583 | "Cannot specify 'additionalProperties' as a field and support arbitrary key/vals" 584 | ) 585 | field_descriptors.append( 586 | descriptor_pb2.FieldDescriptorProto( 587 | name="additionalProperties", 588 | number=len(field_descriptors) + 1, 589 | type=_descriptor.FieldDescriptor.TYPE_MESSAGE, 590 | label=_descriptor.FieldDescriptor.LABEL_OPTIONAL, 591 | type_name=struct_pb2.Struct.DESCRIPTOR.full_name, 592 | ) 593 | ) 594 | self.imports.add(struct_pb2.Struct.DESCRIPTOR.file.name) 595 | 596 | # Support optional properties as oneofs i.e. optional int32 foo = 1; 597 | # becomes interpreted as oneof _foo { int32 foo = 1; } 598 | optional_oneofs: List[descriptor_pb2.OneofDescriptorProto] = [] 599 | for field in field_descriptors: 600 | if ( 601 | field.name in self.get_optional_field_names(entry) 602 | and field.label == _descriptor.FieldDescriptor.LABEL_OPTIONAL 603 | ): 604 | log.debug3("Making field %s as optional with oneof", field.name) 605 | # OneofDescriptorProto do not contain fields themselves. 606 | # Instead the FieldDescriptorProto must contain the index of 607 | # the oneof inside the DescriptorProto 608 | optional_oneofs.append( 609 | descriptor_pb2.OneofDescriptorProto(name=f"_{field.name}") 610 | ) 611 | field.oneof_index = len(nested_oneofs) + len(optional_oneofs) - 1 612 | 613 | # Construct the message descriptor proto with the aggregated fields 614 | # and nested enums/messages/oneofs 615 | log.debug3( 616 | "All field descriptors for [%s]:\n%s", message_name, field_descriptors 617 | ) 618 | descriptor_proto = descriptor_pb2.DescriptorProto( 619 | name=message_name, 620 | field=field_descriptors, 621 | enum_type=nested_enums, 622 | nested_type=nested_messages, 623 | oneof_decl=nested_oneofs + optional_oneofs, 624 | ) 625 | return descriptor_proto 626 | --------------------------------------------------------------------------------