├── example ├── __init__.py ├── example.proto ├── example_pb2.py └── example_pb2.pyi ├── pbspark ├── _version.py ├── __init__.py ├── _timestamp.py └── _proto.py ├── .tool-versions ├── tests ├── fixtures.py └── test_proto.py ├── Makefile ├── LICENSE.txt ├── pyproject.toml ├── .github └── workflows │ └── build.yml ├── CHANGELOG.md ├── .gitignore ├── README.md └── poetry.lock /example/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pbspark/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.9.0" 2 | -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | python 3.9.5 2 | poetry 1.5.1 3 | protoc 21.1 4 | java adoptopenjdk-8.0.275+1 5 | -------------------------------------------------------------------------------- /pbspark/__init__.py: -------------------------------------------------------------------------------- 1 | from ._proto import MessageConverter 2 | from ._proto import df_from_protobuf 3 | from ._proto import df_to_protobuf 4 | from ._proto import from_protobuf 5 | from ._proto import to_protobuf 6 | from ._version import __version__ 7 | -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | """These items are in a separate module so that spark workers can deserialize them. 2 | 3 | When referenced from the same file as the test a ModuleNotFoundError is raised. 4 | """ 5 | import json 6 | from decimal import Decimal 7 | 8 | from google.protobuf.json_format import MessageToDict 9 | 10 | from example.example_pb2 import DecimalMessage 11 | from example.example_pb2 import RecursiveMessage 12 | 13 | 14 | def encode_recursive(message: RecursiveMessage, depth=0): 15 | if depth == 2: 16 | return json.dumps(MessageToDict(message)) 17 | return { 18 | "note": message.note, 19 | "message": encode_recursive(message.message, depth=depth + 1), 20 | } 21 | 22 | 23 | def decimal_serializer(message: DecimalMessage): 24 | return Decimal(message.value) 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | export PROTO_PATH=. 2 | 3 | fmt: 4 | poetry run isort . 5 | poetry run black . 6 | poetry run mypy . 
--show-error-codes 7 | 8 | gen: 9 | poetry run protoc -I $$PROTO_PATH --python_out=$$PROTO_PATH --mypy_out=$$PROTO_PATH --proto_path=$$PROTO_PATH $$PROTO_PATH/example/*.proto 10 | poetry run isort ./example 11 | poetry run black ./example 12 | 13 | test: 14 | poetry run pytest tests/ 15 | 16 | clean: 17 | rm -rf dist 18 | 19 | .PHONY: dist 20 | dist: 21 | poetry build 22 | 23 | sdist: 24 | poetry build -f sdist 25 | 26 | publish: clean dist 27 | poetry publish 28 | 29 | release: clean sdist 30 | ghr -u crflynn -r pbspark -c $(shell git rev-parse HEAD) -delete -b "release" -n $(shell poetry version -s) $(shell poetry version -s) dist/*.tar.gz 31 | 32 | setup: 33 | asdf plugin add python || true 34 | asdf plugin add poetry || true 35 | asdf plugin add protoc || true 36 | asdf plugin add java || true 37 | asdf install 38 | poetry install 39 | -------------------------------------------------------------------------------- /pbspark/_timestamp.py: -------------------------------------------------------------------------------- 1 | """Timestamp serde.""" 2 | import calendar 3 | import datetime 4 | 5 | from google.protobuf.internal import well_known_types 6 | from google.protobuf.timestamp_pb2 import Timestamp 7 | 8 | 9 | def _to_datetime(message: Timestamp) -> datetime.datetime: 10 | """Convert a Timestamp to a python datetime.""" 11 | return well_known_types._EPOCH_DATETIME_NAIVE + datetime.timedelta( # type: ignore[attr-defined] 12 | seconds=message.seconds, 13 | microseconds=well_known_types._RoundTowardZero( # type: ignore[attr-defined] 14 | message.nanos, 15 | well_known_types._NANOS_PER_MICROSECOND, # type: ignore[attr-defined] 16 | ), 17 | ) 18 | 19 | 20 | def _from_datetime(dt: datetime.datetime, timestamp: Timestamp, path: str): 21 | """Convert a python datetime to Timestamp.""" 22 | timestamp.seconds = calendar.timegm(dt.utctimetuple()) 23 | timestamp.nanos = dt.microsecond * well_known_types._NANOS_PER_MICROSECOND # type: ignore[attr-defined] 24 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Christopher Flynn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pbspark" 3 | version = "0.9.0" 4 | description = "Convert between protobuf messages and pyspark dataframes" 5 | authors = ["flynn "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/crflynn/pbspark" 9 | repository = "https://github.com/crflynn/pbspark" 10 | documentation = "https://github.com/crflynn/pbspark" 11 | include = [ 12 | { path="CHANGELOG.md", format="sdist" }, 13 | { path="LICENSE.txt", format="sdist" }, 14 | { path="README.md", format="sdist" }, 15 | ] 16 | keywords = ["spark", "protobuf", "pyspark"] 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "License :: OSI Approved :: MIT License", 20 | "Programming Language :: Python", 21 | "Topic :: Database", 22 | ] 23 | 24 | [tool.poetry.dependencies] 25 | python = "^3.7" 26 | pyspark = ">=3.1.1" 27 | protobuf = ">=3.20.0" 28 | 29 | [tool.poetry.dev-dependencies] 30 | black = "^21.11b1" 31 | isort = "^5.10.1" 32 | pytest = "^6.2.5" 33 | mypy-protobuf = "^3.0.0" 34 | types-protobuf = "^3.18.2" 35 | click = "8.0.4" # https://github.com/psf/black/issues/2964 36 | mypy = "^0.942" 37 | 38 | [build-system] 39 | requires = ["poetry-core>=1.0.0"] 40 | build-backend = "poetry.core.masonry.api" 41 | 42 | [tool.black] 43 | line-length = 88 44 | target-version = ['py37'] 45 | include = '\.pyi?$' 46 | exclude = ''' 47 | ( 48 | /( 49 | \.eggs 50 | | \.circleci 51 | | \.git 52 | | \.github 53 | | \.hg 54 | | \.mypy_cache 55 | | \.pytest_cache 56 | | \.tox 57 | | \.venv 58 | | _build 59 | | buck-out 60 | | build 61 | | dist 62 | )/ 63 | ) 64 | ''' 65 | 66 | [tool.isort] 67 | force_single_line = true 68 | multi_line_output = 3 69 | include_trailing_comma = true 70 | force_grid_wrap = 0 71 | use_parentheses = true 72 | line_length = 88 73 | -------------------------------------------------------------------------------- /example/example.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package example; 4 | 5 | import "google/protobuf/timestamp.proto"; 6 | import "google/protobuf/duration.proto"; 7 | import "google/protobuf/wrappers.proto"; 8 | 9 | message SimpleMessage { 10 | string name = 1; 11 | int64 quantity = 2; 12 | float measure = 3; 13 | } 14 | 15 | message NestedMessage { 16 | string key = 1; 17 | string value = 2; 18 | } 19 | 20 | message DecimalMessage { 21 | string value = 1; 22 | } 23 | 24 | message ExampleMessage { 25 | int32 int32 = 1; 26 | int64 int64 = 2; 27 | uint32 uint32 = 3; 28 | uint64 uint64 = 4; 29 | double double = 5; 30 | float float = 6; 31 | bool bool = 7; 32 | enum SomeEnum { 33 | unspecified = 0; 34 | first = 1; 35 | second = 2; 36 | } 37 | SomeEnum enum = 8; 38 | string string = 9; 39 | NestedMessage nested = 10; 40 | repeated string stringlist = 11; 41 | bytes bytes = 12; 42 | sfixed32 sfixed32 = 13; 43 | sfixed64 sfixed64 = 14; 44 | sint32 sint32 = 15; 45 | sint64 sint64 = 16; 46 | fixed32 fixed32 = 17; 47 | fixed64 fixed64 = 18; 48 | oneof oneof { 49 | string oneofstring = 19; 50 | int32 oneofint32 = 20; 51 | } 52 | map map = 21; 53 | google.protobuf.Timestamp timestamp = 22; 54 | google.protobuf.Duration duration = 23; 55 | DecimalMessage decimal = 24; 56 | google.protobuf.DoubleValue doublevalue = 25; 57 | google.protobuf.FloatValue floatvalue = 26; 58 | google.protobuf.Int64Value int64value = 
27; 59 | google.protobuf.UInt64Value uint64value = 28; 60 | google.protobuf.Int32Value int32value = 29; 61 | google.protobuf.UInt32Value uint32value = 30; 62 | google.protobuf.BoolValue boolvalue = 31; 63 | google.protobuf.StringValue stringvalue = 32; 64 | google.protobuf.BytesValue bytesvalue = 33; 65 | string case_name = 34; 66 | } 67 | 68 | message RecursiveMessage { 69 | string note = 1; 70 | RecursiveMessage message = 2; 71 | } -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | 12 | linux: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: [3.9.5] 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v1 22 | 23 | # required for coverage 24 | - name: Install sqlite 25 | run: | 26 | sudo apt-get update 27 | sudo apt-get install -y libsqlite3-dev 28 | 29 | - name: Install asdf 30 | uses: asdf-vm/actions/setup@v1.0.1 31 | 32 | - name: Install asdf plugins 33 | run: | 34 | asdf plugin add python 35 | asdf plugin add poetry 36 | asdf plugin add protoc 37 | asdf plugin add java 38 | 39 | - name: Cache dependencies 40 | id: cache-deps 41 | uses: actions/cache@v1 42 | with: 43 | path: ~/.asdf 44 | key: v2-${{ runner.os }}-${{ matrix.python-version }}-asdf-${{ hashFiles(format('{0}{1}', github.workspace, '/poetry.lock')) }} 45 | restore-keys: | 46 | v2-${{ runner.os }}-${{ matrix.python-version }}-asdf- 47 | 48 | - name: Install python 49 | if: steps.cache-deps.outputs.cache-hit != 'true' 50 | run: | 51 | asdf install python ${{ matrix.python-version }} 52 | 53 | - name: Set python 54 | run: | 55 | asdf local python ${{ matrix.python-version }} 56 | 57 | - name: Install tools 58 | run: | 59 | asdf install 60 | poetry config virtualenvs.create false 61 | 62 | - name: Install deps 63 | run: | 64 | poetry install 65 | 66 | - name: Fmt 67 | run: | 68 | poetry run isort --check . 69 | poetry run black --check . 70 | poetry run mypy . 71 | 72 | # # broken in ci 73 | # - name: Generate 74 | # run: | 75 | # make gen 76 | 77 | - name: Test 78 | run: | 79 | make test 80 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 2023-06-07 - 0.9.0 4 | 5 | * Relax pyspark constraint 6 | 7 | ## 2023-01-11 - 0.8.0 8 | 9 | * Breaking: Provide the same kwargs used in the protobuf lib on encoding/decoding rather than the ``options`` dict, except ``DescriptorPool`` which is unserializable. 10 | * Breaking: Change param ``mc`` -> ``message_converter`` on top level functions. 11 | 12 | ## 2022-07-07 - 0.7.0 13 | 14 | * Bugfix: Fixed a bug where int64 protobuf types were not being properly converted into spark types. 15 | * Added support for protobuf wrapper types. 16 | 17 | ## 2022-06-22 - 0.6.1 18 | 19 | * Bugfix: Fixed a bug where ``options`` was not being passed recursively in ``get_spark_schema``. 20 | 21 | ## 2022-06-13 - 0.6.0 22 | 23 | * Add ``to_protobuf`` and ``from_protobuf`` functions to operate on columns without needing a ``MessageConverter``. 24 | * Add ``df_to_protobuf`` and ``df_from_protobuf`` functions to operate on DataFrames without needing a ``MessageConverter``. These functions also optionally handle struct expansion. 
25 | 26 | ## 2022-06-12 - 0.5.1 27 | 28 | * Bugfix: Fix ``bytearray`` TypeError when using newer versions of protobuf 29 | 30 | ## 2022-05-20 - 0.5.0 31 | 32 | * Breaking: return type instances to be passed to custom serializers rather than type class + init kwargs 33 | * Bugfix: `get_spark_schema` now returns properly when the descriptor passed has a registered custom serializer 34 | 35 | ## 2022-05-19 - 0.4.0 36 | 37 | * Breaking: pbspark now encodes the well known type `Timestamp` to spark `TimestampType` by default. 38 | * Bugfix: protobuf bytes now properly convert to spark BinaryType 39 | * Bugfix: message decoding now properly works by populating the passed message instance rather than returning a new one 40 | * protobuf objects are now patched only temporarily when being used by pbspark 41 | * Timestamp conversion now references protobuf well known type objects rather than objects copied from protobuf to pbspark 42 | * Modify the encoding function to convert udf-passed Row objects to dictionaries before passing to the parser. 43 | * Documentation fixes and more details on custom encoders. 44 | 45 | ## 2022-04-19 - 0.3.0 46 | 47 | * Breaking: protobuf bytes fields will now convert directly to spark ByteType and vice versa. 48 | * Relax constraint on pyspark 49 | * Bump minimum protobuf version to 3.20.0 50 | 51 | ## 2021-12-05 - 0.2.0 52 | 53 | * Add `to_protobuf` method to encode pyspark structs to protobuf 54 | 55 | ## 2021-12-01 - 0.1.0 56 | 57 | * initial release 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | .idea/ 142 | .DS_Store -------------------------------------------------------------------------------- /example/example_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: example/example.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 16 | from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 17 | from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 18 | 19 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 20 | b'\n\x15\x65xample/example.proto\x12\x07\x65xample\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1egoogle/protobuf/wrappers.proto"@\n\rSimpleMessage\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08quantity\x18\x02 \x01(\x03\x12\x0f\n\x07measure\x18\x03 \x01(\x02"+\n\rNestedMessage\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t"\x1f\n\x0e\x44\x65\x63imalMessage\x12\r\n\x05value\x18\x01 \x01(\t"\x89\t\n\x0e\x45xampleMessage\x12\r\n\x05int32\x18\x01 \x01(\x05\x12\r\n\x05int64\x18\x02 \x01(\x03\x12\x0e\n\x06uint32\x18\x03 \x01(\r\x12\x0e\n\x06uint64\x18\x04 \x01(\x04\x12\x0e\n\x06\x64ouble\x18\x05 \x01(\x01\x12\r\n\x05\x66loat\x18\x06 \x01(\x02\x12\x0c\n\x04\x62ool\x18\x07 \x01(\x08\x12.\n\x04\x65num\x18\x08 \x01(\x0e\x32 .example.ExampleMessage.SomeEnum\x12\x0e\n\x06string\x18\t \x01(\t\x12&\n\x06nested\x18\n \x01(\x0b\x32\x16.example.NestedMessage\x12\x12\n\nstringlist\x18\x0b \x03(\t\x12\r\n\x05\x62ytes\x18\x0c \x01(\x0c\x12\x10\n\x08sfixed32\x18\r \x01(\x0f\x12\x10\n\x08sfixed64\x18\x0e \x01(\x10\x12\x0e\n\x06sint32\x18\x0f \x01(\x11\x12\x0e\n\x06sint64\x18\x10 \x01(\x12\x12\x0f\n\x07\x66ixed32\x18\x11 \x01(\x07\x12\x0f\n\x07\x66ixed64\x18\x12 \x01(\x06\x12\x15\n\x0boneofstring\x18\x13 \x01(\tH\x00\x12\x14\n\noneofint32\x18\x14 \x01(\x05H\x00\x12-\n\x03map\x18\x15 \x03(\x0b\x32 .example.ExampleMessage.MapEntry\x12-\n\ttimestamp\x18\x16 
\x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12+\n\x08\x64uration\x18\x17 \x01(\x0b\x32\x19.google.protobuf.Duration\x12(\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32\x17.example.DecimalMessage\x12\x31\n\x0b\x64oublevalue\x18\x19 \x01(\x0b\x32\x1c.google.protobuf.DoubleValue\x12/\n\nfloatvalue\x18\x1a \x01(\x0b\x32\x1b.google.protobuf.FloatValue\x12/\n\nint64value\x18\x1b \x01(\x0b\x32\x1b.google.protobuf.Int64Value\x12\x31\n\x0buint64value\x18\x1c \x01(\x0b\x32\x1c.google.protobuf.UInt64Value\x12/\n\nint32value\x18\x1d \x01(\x0b\x32\x1b.google.protobuf.Int32Value\x12\x31\n\x0buint32value\x18\x1e \x01(\x0b\x32\x1c.google.protobuf.UInt32Value\x12-\n\tboolvalue\x18\x1f \x01(\x0b\x32\x1a.google.protobuf.BoolValue\x12\x31\n\x0bstringvalue\x18 \x01(\x0b\x32\x1c.google.protobuf.StringValue\x12/\n\nbytesvalue\x18! \x01(\x0b\x32\x1b.google.protobuf.BytesValue\x12\x11\n\tcase_name\x18" \x01(\t\x1a*\n\x08MapEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"2\n\x08SomeEnum\x12\x0f\n\x0bunspecified\x10\x00\x12\t\n\x05\x66irst\x10\x01\x12\n\n\x06second\x10\x02\x42\x07\n\x05oneof"L\n\x10RecursiveMessage\x12\x0c\n\x04note\x18\x01 \x01(\t\x12*\n\x07message\x18\x02 \x01(\x0b\x32\x19.example.RecursiveMessageb\x06proto3' 21 | ) 22 | 23 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 24 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "example.example_pb2", globals()) 25 | if _descriptor._USE_C_DESCRIPTORS == False: 26 | 27 | DESCRIPTOR._options = None 28 | _EXAMPLEMESSAGE_MAPENTRY._options = None 29 | _EXAMPLEMESSAGE_MAPENTRY._serialized_options = b"8\001" 30 | _SIMPLEMESSAGE._serialized_start = 131 31 | _SIMPLEMESSAGE._serialized_end = 195 32 | _NESTEDMESSAGE._serialized_start = 197 33 | _NESTEDMESSAGE._serialized_end = 240 34 | _DECIMALMESSAGE._serialized_start = 242 35 | _DECIMALMESSAGE._serialized_end = 273 36 | _EXAMPLEMESSAGE._serialized_start = 276 37 | _EXAMPLEMESSAGE._serialized_end = 1437 38 | _EXAMPLEMESSAGE_MAPENTRY._serialized_start = 1334 39 | _EXAMPLEMESSAGE_MAPENTRY._serialized_end = 1376 40 | _EXAMPLEMESSAGE_SOMEENUM._serialized_start = 1378 41 | _EXAMPLEMESSAGE_SOMEENUM._serialized_end = 1428 42 | _RECURSIVEMESSAGE._serialized_start = 1439 43 | _RECURSIVEMESSAGE._serialized_end = 1515 44 | # @@protoc_insertion_point(module_scope) 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pbspark 2 | 3 | This package provides a way to convert protobuf messages into pyspark dataframes and vice versa using pyspark `udf`s. 4 | 5 | ## Installation 6 | 7 | To install: 8 | 9 | ```bash 10 | pip install pbspark 11 | ``` 12 | 13 | ## Usage 14 | 15 | Suppose we have a pyspark DataFrame which contains a column `value` which has protobuf encoded messages of our `SimpleMessage`: 16 | 17 | ```protobuf 18 | syntax = "proto3"; 19 | 20 | package example; 21 | 22 | message SimpleMessage { 23 | string name = 1; 24 | int64 quantity = 2; 25 | float measure = 3; 26 | } 27 | ``` 28 | 29 | ### Basic conversion functions 30 | 31 | There are two functions for operating on columns, `to_protobuf` and `from_protobuf`. These operations convert to/from an encoded protobuf column to a column of a struct representing the inferred message structure. `MessageConverter` instances (discussed below) can optionally be passed to these functions. 
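The struct schema is inferred from the message descriptor. Before the full example below, it can be useful to preview what `pbspark` will infer for a message type. The following is a minimal sketch, not canonical usage: it uses `MessageConverter.get_spark_schema` (the method exercised in this repository's tests) and passes the converter to `from_protobuf` explicitly via the `message_converter` keyword argument named in the 0.8.0 changelog entry. The basic usage without an explicit converter follows.

```python
from pyspark.sql.session import SparkSession
from example.example_pb2 import SimpleMessage
from pbspark import MessageConverter
from pbspark import from_protobuf

spark = SparkSession.builder.getOrCreate()

mc = MessageConverter()
# Inspect the spark StructType that pbspark infers for SimpleMessage.
print(mc.get_spark_schema(SimpleMessage))

example = SimpleMessage(name="hello", quantity=5, measure=12.3)
df_encoded = spark.createDataFrame([{"value": example.SerializeToString()}])

# The converter can be passed explicitly to the column-level helper
# (kwarg name taken from the 0.8.0 changelog entry).
df_decoded = df_encoded.select(
    from_protobuf(df_encoded.value, SimpleMessage, message_converter=mc).alias("value")
)
```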
32 | 33 | ```python 34 | from pyspark.sql.session import SparkSession 35 | from example.example_pb2 import SimpleMessage 36 | from pbspark import from_protobuf 37 | from pbspark import to_protobuf 38 | 39 | spark = SparkSession.builder.getOrCreate() 40 | 41 | example = SimpleMessage(name="hello", quantity=5, measure=12.3) 42 | data = [{"value": example.SerializeToString()}] 43 | df_encoded = spark.createDataFrame(data) 44 | 45 | df_decoded = df_encoded.select(from_protobuf(df_encoded.value, SimpleMessage).alias("value")) 46 | df_expanded = df_decoded.select("value.*") 47 | df_expanded.show() 48 | 49 | # +-----+--------+-------+ 50 | # | name|quantity|measure| 51 | # +-----+--------+-------+ 52 | # |hello| 5| 12.3| 53 | # +-----+--------+-------+ 54 | 55 | df_reencoded = df_decoded.select(to_protobuf(df_decoded.value, SimpleMessage).alias("value")) 56 | ``` 57 | 58 | There are two helper functions, `df_to_protobuf` and `df_from_protobuf` for use on dataframes. They have a kwarg `expanded`, which will also take care of expanding/contracting the data between the single `value` column used in these examples and a dataframe which contains a column for each message field. `MessageConverter` instances (discussed below) can optionally be passed to these functions. 59 | 60 | ```python 61 | from pyspark.sql.session import SparkSession 62 | from example.example_pb2 import SimpleMessage 63 | from pbspark import df_from_protobuf 64 | from pbspark import df_to_protobuf 65 | 66 | spark = SparkSession.builder.getOrCreate() 67 | 68 | example = SimpleMessage(name="hello", quantity=5, measure=12.3) 69 | data = [{"value": example.SerializeToString()}] 70 | df_encoded = spark.createDataFrame(data) 71 | 72 | # expanded=True will perform a `.select("value.*")` after converting, 73 | # resulting in each protobuf field having its own column 74 | df_expanded = df_from_protobuf(df_encoded, SimpleMessage, expanded=True) 75 | df_expanded.show() 76 | 77 | # +-----+--------+-------+ 78 | # | name|quantity|measure| 79 | # +-----+--------+-------+ 80 | # |hello| 5| 12.3| 81 | # +-----+--------+-------+ 82 | 83 | # expanded=True will first pack data using `struct([df[c] for c in df.columns])`, 84 | # use this if the passed dataframe is already expanded 85 | df_reencoded = df_to_protobuf(df_expanded, SimpleMessage, expanded=True) 86 | ``` 87 | 88 | ### Column conversion using the `MessageConverter` 89 | 90 | The four helper functions above are also available as methods on the `MessageConverter` class. Using an instance of `MessageConverter` we can decode the column of encoded messages into a column of spark `StructType` and then expand the fields. 
91 | 92 | ```python 93 | from pyspark.sql.session import SparkSession 94 | from pbspark import MessageConverter 95 | from example.example_pb2 import SimpleMessage 96 | 97 | spark = SparkSession.builder.getOrCreate() 98 | 99 | example = SimpleMessage(name="hello", quantity=5, measure=12.3) 100 | data = [{"value": example.SerializeToString()}] 101 | df_encoded = spark.createDataFrame(data) 102 | 103 | mc = MessageConverter() 104 | df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, SimpleMessage).alias("value")) 105 | df_expanded = df_decoded.select("value.*") 106 | df_expanded.show() 107 | 108 | # +-----+--------+-------+ 109 | # | name|quantity|measure| 110 | # +-----+--------+-------+ 111 | # |hello| 5| 12.3| 112 | # +-----+--------+-------+ 113 | 114 | df_expanded.schema 115 | # StructType(List(StructField(name,StringType,true),StructField(quantity,IntegerType,true),StructField(measure,FloatType,true)) 116 | ``` 117 | 118 | We can also re-encode them into protobuf. 119 | 120 | ```python 121 | df_reencoded = df_decoded.select(mc.to_protobuf(df_decoded.value, SimpleMessage).alias("value")) 122 | ``` 123 | 124 | For expanded data, we can also encode after packing into a struct column: 125 | 126 | ```python 127 | from pyspark.sql.functions import struct 128 | 129 | df_unexpanded = df_expanded.select( 130 | struct([df_expanded[c] for c in df_expanded.columns]).alias("value") 131 | ) 132 | df_reencoded = df_unexpanded.select( 133 | mc.to_protobuf(df_unexpanded.value, SimpleMessage).alias("value") 134 | ) 135 | ``` 136 | 137 | ### Conversion details 138 | 139 | Internally, `pbspark` uses protobuf's `MessageToDict`, which deserializes everything into JSON compatible objects by default. The exceptions are 140 | * protobuf's bytes type, which `MessageToDict` would decode to a base64-encoded string; `pbspark` will decode any bytes fields directly to a spark `BinaryType`. 141 | * protobuf's well known type, Timestamp type, which `MessageToDict` would decode to a string; `pbspark` will decode any Timestamp messages directly to a spark `TimestampType` (via python datetime objects). 142 | * protobuf's int64 types, which `MessageToDict` would decode to a string for compatibility reasons; `pbspark` will decode these to `LongType`. 143 | 144 | ### Custom conversion of message types 145 | 146 | Custom serde is also supported. Suppose we use our `NestedMessage` from the repository's example and we want to serialize the key and value together into a single string. 147 | 148 | ```protobuf 149 | message NestedMessage { 150 | string key = 1; 151 | string value = 2; 152 | } 153 | ``` 154 | 155 | We can create and register a custom serializer with the `MessageConverter`. 156 | 157 | ```python 158 | from pbspark import MessageConverter 159 | from example.example_pb2 import ExampleMessage 160 | from example.example_pb2 import NestedMessage 161 | from pyspark.sql.types import StringType 162 | 163 | mc = MessageConverter() 164 | 165 | # register a custom serializer 166 | # this will serialize the NestedMessages into a string rather than a 167 | # struct with `key` and `value` fields 168 | encode_nested = lambda message: message.key + ":" + message.value 169 | 170 | mc.register_serializer(NestedMessage, encode_nested, StringType()) 171 | 172 | # ... 
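# Note: the spark session below is built on a SparkContext configured with the
# CloudPickleSerializer, mirroring this repository's test fixtures. This is
# assumed (not stated explicitly in this README) to be what allows a registered
# custom serializer such as `encode_nested` to be pickled and shipped to spark workers.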
173 | 174 | from pyspark.sql.session import SparkSession 175 | from pyspark import SparkContext 176 | from pyspark.serializers import CloudPickleSerializer 177 | 178 | sc = SparkContext(serializer=CloudPickleSerializer()) 179 | spark = SparkSession(sc).builder.getOrCreate() 180 | 181 | message = ExampleMessage(nested=NestedMessage(key="hello", value="world")) 182 | data = [{"value": message.SerializeToString()}] 183 | df_encoded = spark.createDataFrame(data) 184 | 185 | df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, ExampleMessage).alias("value")) 186 | # rather than a struct the value of `nested` is a string 187 | df_decoded.select("value.nested").show() 188 | 189 | # +-----------+ 190 | # | nested| 191 | # +-----------+ 192 | # |hello:world| 193 | # +-----------+ 194 | ``` 195 | 196 | ### How to write conversion functions 197 | 198 | More generally, custom serde functions should be written in the following format. 199 | 200 | ```python 201 | # Encoding takes a message instance and returns the result 202 | # of the custom transformation. 203 | def encode_nested(message: NestedMessage) -> str: 204 | return message.key + ":" + message.value 205 | 206 | # Decoding takes the encoded value, a message instance, and path string 207 | # and populates the fields of the message instance. It returns `None`. 208 | # The path str is used in the protobuf parser to log parse error info. 209 | # Note that the first argument type should match the return type of the 210 | # encoder if using both. 211 | def decode_nested(s: str, message: NestedMessage, path: str): 212 | key, value = s.split(":") 213 | message.key = key 214 | message.value = value 215 | ``` 216 | 217 | ### Avoiding PicklingErrors 218 | 219 | A seemingly common issue with protobuf and distributed processing is when a `PicklingError` is encountered when transmitting (pickling) protobuf message types from a main process to a fork. To avoid this, you need to ensure that the fully qualified module name in your protoc-generated python file is the same as the module path from which the message type is imported. In other words, for the example here, the descriptor module passed to the builder is `example.example_pb2` 220 | 221 | ```python 222 | # from example/example_pb2.py 223 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "example.example_pb2", globals()) 224 | ^^^^^^^^^^^^^^^^^^^ 225 | ``` 226 | 227 | And to import the message type we would call the same module path: 228 | 229 | ```python 230 | from example.example_pb2 import ExampleMessage 231 | ^^^^^^^^^^^^^^^^^^^ 232 | ``` 233 | 234 | Note that the import module is the same as the one passed to the builder from the protoc-generated python. If these do not match, then you will encounter a `PicklingError`. From the pickle documentation: *pickle can save and restore class instances transparently, however the class definition must be importable and live in the same module as when the object was stored.* 235 | 236 | To ensure that the module path is correct, you should run `protoc` from the relative root path of your proto files. For example, in this project, in the `Makefile` under the `gen` command, we call `protoc` from the project root rather than from within the `example` directory. 237 | 238 | ```makefile 239 | export PROTO_PATH=. 
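# PROTO_PATH is the project root; running protoc from here (see the gen target
# below) keeps the generated module importable as example.example_pb2, matching
# the descriptor module name passed to the builder, as described above.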
240 | 241 | gen: 242 | poetry run protoc -I $$PROTO_PATH --python_out=$$PROTO_PATH --mypy_out=$$PROTO_PATH --proto_path=$$PROTO_PATH $$PROTO_PATH/example/*.proto 243 | ``` 244 | 245 | ### Known issues 246 | 247 | `RecursionError` when using self-referencing protobuf messages. Spark schemas do not allow for arbitrary depth, so protobuf messages which are circular- or self-referencing will result in infinite recursion errors when inferring the schema. If you have message structures like this you should resort to creating custom conversion functions, which forcibly limit the structural depth when converting these messages. 248 | 249 | ## Development 250 | 251 | Ensure that [asdf](https://asdf-vm.com/) is installed, then run `make setup`. 252 | 253 | * To format code `make fmt` 254 | * To test code `make test` 255 | * To run protoc `make gen` 256 | -------------------------------------------------------------------------------- /example/example_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | import builtins 6 | import google.protobuf.descriptor 7 | import google.protobuf.duration_pb2 8 | import google.protobuf.internal.containers 9 | import google.protobuf.internal.enum_type_wrapper 10 | import google.protobuf.message 11 | import google.protobuf.timestamp_pb2 12 | import google.protobuf.wrappers_pb2 13 | import typing 14 | import typing_extensions 15 | 16 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 17 | 18 | class SimpleMessage(google.protobuf.message.Message): 19 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 20 | NAME_FIELD_NUMBER: builtins.int 21 | QUANTITY_FIELD_NUMBER: builtins.int 22 | MEASURE_FIELD_NUMBER: builtins.int 23 | name: typing.Text 24 | quantity: builtins.int 25 | measure: builtins.float 26 | def __init__( 27 | self, 28 | *, 29 | name: typing.Text = ..., 30 | quantity: builtins.int = ..., 31 | measure: builtins.float = ..., 32 | ) -> None: ... 33 | def ClearField( 34 | self, 35 | field_name: typing_extensions.Literal[ 36 | "measure", b"measure", "name", b"name", "quantity", b"quantity" 37 | ], 38 | ) -> None: ... 39 | 40 | global___SimpleMessage = SimpleMessage 41 | 42 | class NestedMessage(google.protobuf.message.Message): 43 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 44 | KEY_FIELD_NUMBER: builtins.int 45 | VALUE_FIELD_NUMBER: builtins.int 46 | key: typing.Text 47 | value: typing.Text 48 | def __init__( 49 | self, 50 | *, 51 | key: typing.Text = ..., 52 | value: typing.Text = ..., 53 | ) -> None: ... 54 | def ClearField( 55 | self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] 56 | ) -> None: ... 57 | 58 | global___NestedMessage = NestedMessage 59 | 60 | class DecimalMessage(google.protobuf.message.Message): 61 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 62 | VALUE_FIELD_NUMBER: builtins.int 63 | value: typing.Text 64 | def __init__( 65 | self, 66 | *, 67 | value: typing.Text = ..., 68 | ) -> None: ... 69 | def ClearField( 70 | self, field_name: typing_extensions.Literal["value", b"value"] 71 | ) -> None: ... 
72 | 73 | global___DecimalMessage = DecimalMessage 74 | 75 | class ExampleMessage(google.protobuf.message.Message): 76 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 77 | class _SomeEnum: 78 | ValueType = typing.NewType("ValueType", builtins.int) 79 | V: typing_extensions.TypeAlias = ValueType 80 | class _SomeEnumEnumTypeWrapper( 81 | google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ 82 | ExampleMessage._SomeEnum.ValueType 83 | ], 84 | builtins.type, 85 | ): 86 | DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor 87 | unspecified: ExampleMessage._SomeEnum.ValueType # 0 88 | first: ExampleMessage._SomeEnum.ValueType # 1 89 | second: ExampleMessage._SomeEnum.ValueType # 2 90 | class SomeEnum(_SomeEnum, metaclass=_SomeEnumEnumTypeWrapper): 91 | pass 92 | unspecified: ExampleMessage.SomeEnum.ValueType # 0 93 | first: ExampleMessage.SomeEnum.ValueType # 1 94 | second: ExampleMessage.SomeEnum.ValueType # 2 95 | class MapEntry(google.protobuf.message.Message): 96 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 97 | KEY_FIELD_NUMBER: builtins.int 98 | VALUE_FIELD_NUMBER: builtins.int 99 | key: typing.Text 100 | value: typing.Text 101 | def __init__( 102 | self, 103 | *, 104 | key: typing.Text = ..., 105 | value: typing.Text = ..., 106 | ) -> None: ... 107 | def ClearField( 108 | self, 109 | field_name: typing_extensions.Literal["key", b"key", "value", b"value"], 110 | ) -> None: ... 111 | INT32_FIELD_NUMBER: builtins.int 112 | INT64_FIELD_NUMBER: builtins.int 113 | UINT32_FIELD_NUMBER: builtins.int 114 | UINT64_FIELD_NUMBER: builtins.int 115 | DOUBLE_FIELD_NUMBER: builtins.int 116 | FLOAT_FIELD_NUMBER: builtins.int 117 | BOOL_FIELD_NUMBER: builtins.int 118 | ENUM_FIELD_NUMBER: builtins.int 119 | STRING_FIELD_NUMBER: builtins.int 120 | NESTED_FIELD_NUMBER: builtins.int 121 | STRINGLIST_FIELD_NUMBER: builtins.int 122 | BYTES_FIELD_NUMBER: builtins.int 123 | SFIXED32_FIELD_NUMBER: builtins.int 124 | SFIXED64_FIELD_NUMBER: builtins.int 125 | SINT32_FIELD_NUMBER: builtins.int 126 | SINT64_FIELD_NUMBER: builtins.int 127 | FIXED32_FIELD_NUMBER: builtins.int 128 | FIXED64_FIELD_NUMBER: builtins.int 129 | ONEOFSTRING_FIELD_NUMBER: builtins.int 130 | ONEOFINT32_FIELD_NUMBER: builtins.int 131 | MAP_FIELD_NUMBER: builtins.int 132 | TIMESTAMP_FIELD_NUMBER: builtins.int 133 | DURATION_FIELD_NUMBER: builtins.int 134 | DECIMAL_FIELD_NUMBER: builtins.int 135 | DOUBLEVALUE_FIELD_NUMBER: builtins.int 136 | FLOATVALUE_FIELD_NUMBER: builtins.int 137 | INT64VALUE_FIELD_NUMBER: builtins.int 138 | UINT64VALUE_FIELD_NUMBER: builtins.int 139 | INT32VALUE_FIELD_NUMBER: builtins.int 140 | UINT32VALUE_FIELD_NUMBER: builtins.int 141 | BOOLVALUE_FIELD_NUMBER: builtins.int 142 | STRINGVALUE_FIELD_NUMBER: builtins.int 143 | BYTESVALUE_FIELD_NUMBER: builtins.int 144 | CASE_NAME_FIELD_NUMBER: builtins.int 145 | int32: builtins.int 146 | int64: builtins.int 147 | uint32: builtins.int 148 | uint64: builtins.int 149 | double: builtins.float 150 | float: builtins.float 151 | bool: builtins.bool 152 | enum: global___ExampleMessage.SomeEnum.ValueType 153 | string: typing.Text 154 | @property 155 | def nested(self) -> global___NestedMessage: ... 156 | @property 157 | def stringlist( 158 | self, 159 | ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ 160 | typing.Text 161 | ]: ... 
162 | bytes: builtins.bytes 163 | sfixed32: builtins.int 164 | sfixed64: builtins.int 165 | sint32: builtins.int 166 | sint64: builtins.int 167 | fixed32: builtins.int 168 | fixed64: builtins.int 169 | oneofstring: typing.Text 170 | oneofint32: builtins.int 171 | @property 172 | def map( 173 | self, 174 | ) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... 175 | @property 176 | def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... 177 | @property 178 | def duration(self) -> google.protobuf.duration_pb2.Duration: ... 179 | @property 180 | def decimal(self) -> global___DecimalMessage: ... 181 | @property 182 | def doublevalue(self) -> google.protobuf.wrappers_pb2.DoubleValue: ... 183 | @property 184 | def floatvalue(self) -> google.protobuf.wrappers_pb2.FloatValue: ... 185 | @property 186 | def int64value(self) -> google.protobuf.wrappers_pb2.Int64Value: ... 187 | @property 188 | def uint64value(self) -> google.protobuf.wrappers_pb2.UInt64Value: ... 189 | @property 190 | def int32value(self) -> google.protobuf.wrappers_pb2.Int32Value: ... 191 | @property 192 | def uint32value(self) -> google.protobuf.wrappers_pb2.UInt32Value: ... 193 | @property 194 | def boolvalue(self) -> google.protobuf.wrappers_pb2.BoolValue: ... 195 | @property 196 | def stringvalue(self) -> google.protobuf.wrappers_pb2.StringValue: ... 197 | @property 198 | def bytesvalue(self) -> google.protobuf.wrappers_pb2.BytesValue: ... 199 | case_name: typing.Text 200 | def __init__( 201 | self, 202 | *, 203 | int32: builtins.int = ..., 204 | int64: builtins.int = ..., 205 | uint32: builtins.int = ..., 206 | uint64: builtins.int = ..., 207 | double: builtins.float = ..., 208 | float: builtins.float = ..., 209 | bool: builtins.bool = ..., 210 | enum: global___ExampleMessage.SomeEnum.ValueType = ..., 211 | string: typing.Text = ..., 212 | nested: typing.Optional[global___NestedMessage] = ..., 213 | stringlist: typing.Optional[typing.Iterable[typing.Text]] = ..., 214 | bytes: builtins.bytes = ..., 215 | sfixed32: builtins.int = ..., 216 | sfixed64: builtins.int = ..., 217 | sint32: builtins.int = ..., 218 | sint64: builtins.int = ..., 219 | fixed32: builtins.int = ..., 220 | fixed64: builtins.int = ..., 221 | oneofstring: typing.Text = ..., 222 | oneofint32: builtins.int = ..., 223 | map: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., 224 | timestamp: typing.Optional[google.protobuf.timestamp_pb2.Timestamp] = ..., 225 | duration: typing.Optional[google.protobuf.duration_pb2.Duration] = ..., 226 | decimal: typing.Optional[global___DecimalMessage] = ..., 227 | doublevalue: typing.Optional[google.protobuf.wrappers_pb2.DoubleValue] = ..., 228 | floatvalue: typing.Optional[google.protobuf.wrappers_pb2.FloatValue] = ..., 229 | int64value: typing.Optional[google.protobuf.wrappers_pb2.Int64Value] = ..., 230 | uint64value: typing.Optional[google.protobuf.wrappers_pb2.UInt64Value] = ..., 231 | int32value: typing.Optional[google.protobuf.wrappers_pb2.Int32Value] = ..., 232 | uint32value: typing.Optional[google.protobuf.wrappers_pb2.UInt32Value] = ..., 233 | boolvalue: typing.Optional[google.protobuf.wrappers_pb2.BoolValue] = ..., 234 | stringvalue: typing.Optional[google.protobuf.wrappers_pb2.StringValue] = ..., 235 | bytesvalue: typing.Optional[google.protobuf.wrappers_pb2.BytesValue] = ..., 236 | case_name: typing.Text = ..., 237 | ) -> None: ... 
238 | def HasField( 239 | self, 240 | field_name: typing_extensions.Literal[ 241 | "boolvalue", 242 | b"boolvalue", 243 | "bytesvalue", 244 | b"bytesvalue", 245 | "decimal", 246 | b"decimal", 247 | "doublevalue", 248 | b"doublevalue", 249 | "duration", 250 | b"duration", 251 | "floatvalue", 252 | b"floatvalue", 253 | "int32value", 254 | b"int32value", 255 | "int64value", 256 | b"int64value", 257 | "nested", 258 | b"nested", 259 | "oneof", 260 | b"oneof", 261 | "oneofint32", 262 | b"oneofint32", 263 | "oneofstring", 264 | b"oneofstring", 265 | "stringvalue", 266 | b"stringvalue", 267 | "timestamp", 268 | b"timestamp", 269 | "uint32value", 270 | b"uint32value", 271 | "uint64value", 272 | b"uint64value", 273 | ], 274 | ) -> builtins.bool: ... 275 | def ClearField( 276 | self, 277 | field_name: typing_extensions.Literal[ 278 | "bool", 279 | b"bool", 280 | "boolvalue", 281 | b"boolvalue", 282 | "bytes", 283 | b"bytes", 284 | "bytesvalue", 285 | b"bytesvalue", 286 | "case_name", 287 | b"case_name", 288 | "decimal", 289 | b"decimal", 290 | "double", 291 | b"double", 292 | "doublevalue", 293 | b"doublevalue", 294 | "duration", 295 | b"duration", 296 | "enum", 297 | b"enum", 298 | "fixed32", 299 | b"fixed32", 300 | "fixed64", 301 | b"fixed64", 302 | "float", 303 | b"float", 304 | "floatvalue", 305 | b"floatvalue", 306 | "int32", 307 | b"int32", 308 | "int32value", 309 | b"int32value", 310 | "int64", 311 | b"int64", 312 | "int64value", 313 | b"int64value", 314 | "map", 315 | b"map", 316 | "nested", 317 | b"nested", 318 | "oneof", 319 | b"oneof", 320 | "oneofint32", 321 | b"oneofint32", 322 | "oneofstring", 323 | b"oneofstring", 324 | "sfixed32", 325 | b"sfixed32", 326 | "sfixed64", 327 | b"sfixed64", 328 | "sint32", 329 | b"sint32", 330 | "sint64", 331 | b"sint64", 332 | "string", 333 | b"string", 334 | "stringlist", 335 | b"stringlist", 336 | "stringvalue", 337 | b"stringvalue", 338 | "timestamp", 339 | b"timestamp", 340 | "uint32", 341 | b"uint32", 342 | "uint32value", 343 | b"uint32value", 344 | "uint64", 345 | b"uint64", 346 | "uint64value", 347 | b"uint64value", 348 | ], 349 | ) -> None: ... 350 | def WhichOneof( 351 | self, oneof_group: typing_extensions.Literal["oneof", b"oneof"] 352 | ) -> typing.Optional[typing_extensions.Literal["oneofstring", "oneofint32"]]: ... 353 | 354 | global___ExampleMessage = ExampleMessage 355 | 356 | class RecursiveMessage(google.protobuf.message.Message): 357 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 358 | NOTE_FIELD_NUMBER: builtins.int 359 | MESSAGE_FIELD_NUMBER: builtins.int 360 | note: typing.Text 361 | @property 362 | def message(self) -> global___RecursiveMessage: ... 363 | def __init__( 364 | self, 365 | *, 366 | note: typing.Text = ..., 367 | message: typing.Optional[global___RecursiveMessage] = ..., 368 | ) -> None: ... 369 | def HasField( 370 | self, field_name: typing_extensions.Literal["message", b"message"] 371 | ) -> builtins.bool: ... 372 | def ClearField( 373 | self, 374 | field_name: typing_extensions.Literal["message", b"message", "note", b"note"], 375 | ) -> None: ... 
376 | 377 | global___RecursiveMessage = RecursiveMessage 378 | -------------------------------------------------------------------------------- /tests/test_proto.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | from decimal import Decimal 4 | 5 | import pytest 6 | from google.protobuf import descriptor_pb2 7 | from google.protobuf import json_format 8 | from google.protobuf.descriptor_pool import DescriptorPool 9 | from google.protobuf.duration_pb2 import Duration 10 | from google.protobuf.json_format import MessageToDict 11 | from google.protobuf.timestamp_pb2 import Timestamp 12 | from google.protobuf.wrappers_pb2 import BoolValue 13 | from google.protobuf.wrappers_pb2 import BytesValue 14 | from google.protobuf.wrappers_pb2 import DoubleValue 15 | from google.protobuf.wrappers_pb2 import FloatValue 16 | from google.protobuf.wrappers_pb2 import Int32Value 17 | from google.protobuf.wrappers_pb2 import Int64Value 18 | from google.protobuf.wrappers_pb2 import StringValue 19 | from google.protobuf.wrappers_pb2 import UInt32Value 20 | from google.protobuf.wrappers_pb2 import UInt64Value 21 | from pyspark import SparkContext 22 | from pyspark.serializers import CloudPickleSerializer 23 | from pyspark.sql.functions import col 24 | from pyspark.sql.functions import struct 25 | from pyspark.sql.session import SparkSession 26 | from pyspark.sql.types import ArrayType 27 | from pyspark.sql.types import BinaryType 28 | from pyspark.sql.types import BooleanType 29 | from pyspark.sql.types import DecimalType 30 | from pyspark.sql.types import DoubleType 31 | from pyspark.sql.types import FloatType 32 | from pyspark.sql.types import IntegerType 33 | from pyspark.sql.types import LongType 34 | from pyspark.sql.types import StringType 35 | from pyspark.sql.types import StructField 36 | from pyspark.sql.types import StructType 37 | from pyspark.sql.types import TimestampType 38 | from pyspark.sql.utils import PythonException 39 | 40 | from example.example_pb2 import DecimalMessage 41 | from example.example_pb2 import ExampleMessage 42 | from example.example_pb2 import NestedMessage 43 | from example.example_pb2 import RecursiveMessage 44 | from pbspark._proto import MessageConverter 45 | from pbspark._proto import _patched_convert_scalar_field_value 46 | from pbspark._proto import df_from_protobuf 47 | from pbspark._proto import df_to_protobuf 48 | from pbspark._proto import from_protobuf 49 | from pbspark._proto import to_protobuf 50 | from tests.fixtures import decimal_serializer # type: ignore[import] 51 | from tests.fixtures import encode_recursive 52 | 53 | 54 | @pytest.fixture() 55 | def example(): 56 | ts = Timestamp() 57 | ts.FromDatetime(datetime.datetime.utcnow()) 58 | dur = Duration(seconds=1, nanos=1) 59 | ex = ExampleMessage( 60 | string="asdf", 61 | int32=69, 62 | int64=789, 63 | uint64=404, 64 | float=4.20, 65 | stringlist=["one", "two", "three"], 66 | bytes=b"something", 67 | nested=NestedMessage( 68 | key="hello", 69 | value="world", 70 | ), 71 | enum=ExampleMessage.SomeEnum.first, 72 | timestamp=ts, 73 | duration=dur, 74 | decimal=DecimalMessage( 75 | value="3.50", 76 | ), 77 | doublevalue=DoubleValue(value=1.23), 78 | floatvalue=FloatValue(value=2.34), 79 | int64value=Int64Value(value=9001), 80 | uint64value=UInt64Value(value=9002), 81 | int32value=Int32Value(value=666), 82 | uint32value=UInt32Value(value=789), 83 | boolvalue=BoolValue(value=True), 84 | stringvalue=StringValue(value="qwerty"), 85 | 
bytesvalue=BytesValue(value=b"buf"), 86 | ) 87 | return ex 88 | 89 | 90 | @pytest.fixture(scope="session") 91 | def spark(): 92 | sc = SparkContext(serializer=CloudPickleSerializer()) 93 | spark = SparkSession(sc).builder.getOrCreate() 94 | spark.conf.set("spark.sql.session.timeZone", "UTC") 95 | return spark 96 | 97 | 98 | @pytest.fixture(params=[True, False]) 99 | def expanded(request): 100 | return request.param 101 | 102 | 103 | @pytest.fixture(params=[True, False]) 104 | def including_default_value_fields(request): 105 | return request.param 106 | 107 | 108 | @pytest.fixture(params=[True, False]) 109 | def use_integers_for_enums(request): 110 | return request.param 111 | 112 | 113 | @pytest.fixture(params=[True, False]) 114 | def preserving_proto_field_name(request): 115 | return request.param 116 | 117 | 118 | @pytest.fixture(params=[True, False]) 119 | def ignore_unknown_fields(request): 120 | return request.param 121 | 122 | 123 | def test_get_spark_schema(): 124 | mc = MessageConverter() 125 | mc.register_serializer( 126 | DecimalMessage, decimal_serializer, DecimalType(precision=10, scale=2) 127 | ) 128 | schema = mc.get_spark_schema(ExampleMessage) 129 | expected_schema = StructType( 130 | [ 131 | StructField("int32", IntegerType(), True), 132 | StructField("int64", LongType(), True), 133 | StructField("uint32", LongType(), True), 134 | StructField("uint64", LongType(), True), 135 | StructField("double", DoubleType(), True), 136 | StructField("float", FloatType(), True), 137 | StructField("bool", BooleanType(), True), 138 | StructField("enum", StringType(), True), 139 | StructField("string", StringType(), True), 140 | StructField( 141 | "nested", 142 | StructType( 143 | [ 144 | StructField("key", StringType(), True), 145 | StructField("value", StringType(), True), 146 | ] 147 | ), 148 | True, 149 | ), 150 | StructField("stringlist", ArrayType(StringType(), True), True), 151 | StructField("bytes", BinaryType(), True), 152 | StructField("sfixed32", IntegerType(), True), 153 | StructField("sfixed64", LongType(), True), 154 | StructField("sint32", IntegerType(), True), 155 | StructField("sint64", LongType(), True), 156 | StructField("fixed32", LongType(), True), 157 | StructField("fixed64", LongType(), True), 158 | StructField("oneofstring", StringType(), True), 159 | StructField("oneofint32", IntegerType(), True), 160 | StructField( 161 | "map", 162 | ArrayType( 163 | StructType( 164 | [ 165 | StructField("key", StringType(), True), 166 | StructField("value", StringType(), True), 167 | ] 168 | ), 169 | True, 170 | ), 171 | True, 172 | ), 173 | StructField("timestamp", TimestampType(), True), 174 | StructField("duration", StringType(), True), 175 | StructField("decimal", DecimalType(10, 2), True), 176 | StructField("doublevalue", DoubleType(), True), 177 | StructField("floatvalue", FloatType(), True), 178 | StructField("int64value", LongType(), True), 179 | StructField("uint64value", LongType(), True), 180 | StructField("int32value", IntegerType(), True), 181 | StructField("uint32value", LongType(), True), 182 | StructField("boolvalue", BooleanType(), True), 183 | StructField("stringvalue", StringType(), True), 184 | StructField("bytesvalue", BinaryType(), True), 185 | StructField("caseName", StringType(), True), 186 | ] 187 | ) 188 | assert schema == expected_schema 189 | 190 | 191 | def test_patched_convert_scalar_field_value(): 192 | assert not hasattr(json_format._ConvertScalarFieldValue, "__wrapped__") 193 | with _patched_convert_scalar_field_value(): 194 | assert 
hasattr(json_format._ConvertScalarFieldValue, "__wrapped__") 195 | assert not hasattr(json_format._ConvertScalarFieldValue, "__wrapped__") 196 | 197 | 198 | def test_get_decoder(example): 199 | mc = MessageConverter() 200 | mc.register_serializer( 201 | DecimalMessage, decimal_serializer, DecimalType(precision=10, scale=2) 202 | ) 203 | decoder = mc.get_decoder(ExampleMessage) 204 | s = example.SerializeToString() 205 | decoded = decoder(s) 206 | assert decoded == mc.message_to_dict(example) 207 | expected = { 208 | "int32": 69, 209 | "int64": 789, 210 | "uint64": 404, 211 | "float": 4.2, 212 | "enum": "first", 213 | "string": "asdf", 214 | "nested": {"key": "hello", "value": "world"}, 215 | "stringlist": ["one", "two", "three"], 216 | "bytes": b"something", 217 | "timestamp": example.timestamp.ToDatetime(), 218 | "duration": example.duration.ToJsonString(), 219 | "decimal": Decimal(example.decimal.value), 220 | "doublevalue": 1.23, 221 | "floatvalue": 2.34, 222 | "int64value": 9001, 223 | "uint64value": 9002, 224 | "int32value": 666, 225 | "uint32value": 789, 226 | "boolvalue": True, 227 | "stringvalue": "qwerty", 228 | "bytesvalue": b"buf", 229 | } 230 | assert decoded == expected 231 | 232 | 233 | def test_from_protobuf( 234 | example, spark, preserving_proto_field_name, use_integers_for_enums 235 | ): 236 | mc = MessageConverter() 237 | mc.register_serializer( 238 | DecimalMessage, decimal_serializer, DecimalType(precision=10, scale=2) 239 | ) 240 | 241 | data = [{"value": example.SerializeToString()}] 242 | 243 | df = spark.createDataFrame(data) # type: ignore[type-var] 244 | dfs = df.select( 245 | mc.from_protobuf( 246 | data=df.value, 247 | message_type=ExampleMessage, 248 | preserving_proto_field_name=preserving_proto_field_name, 249 | use_integers_for_enums=use_integers_for_enums, 250 | ).alias("value") 251 | ) 252 | dfe = dfs.select("value.*") 253 | dfe.show() 254 | dfe.printSchema() 255 | 256 | if preserving_proto_field_name: 257 | field_names = [field.name for field in ExampleMessage.DESCRIPTOR.fields] 258 | else: 259 | field_names = [ 260 | field.camelcase_name for field in ExampleMessage.DESCRIPTOR.fields 261 | ] 262 | for field_name in field_names: 263 | assert field_name in dfe.columns 264 | 265 | if use_integers_for_enums: 266 | assert StructField("enum", IntegerType(), True) in dfe.schema.fields 267 | else: 268 | assert StructField("enum", StringType(), True) in dfe.schema.fields 269 | 270 | 271 | def test_round_trip(example, spark): 272 | mc = MessageConverter() 273 | 274 | data = [{"value": example.SerializeToString()}] 275 | 276 | df = spark.createDataFrame(data) # type: ignore[type-var] 277 | df.show() 278 | 279 | df.printSchema() 280 | dfs = df.select(mc.from_protobuf(df.value, ExampleMessage).alias("value")) 281 | df_again = dfs.select(mc.to_protobuf(dfs.value, ExampleMessage).alias("value")) 282 | df_again.show() 283 | assert df.schema == df_again.schema 284 | assert df.collect() == df_again.collect() 285 | 286 | # make a flattened df and then encode from unflattened df 287 | df_flattened = dfs.select("value.*") 288 | 289 | df_unflattened = df_flattened.select( 290 | struct([df_flattened[c] for c in df_flattened.columns]).alias("value") 291 | ) 292 | df_unflattened.show() 293 | schema = df_unflattened.schema 294 | # this will be false because there are no null records 295 | schema.fields[0].nullable = True 296 | assert dfs.schema == schema 297 | assert dfs.collect() == df_unflattened.collect() 298 | df_again = df_unflattened.select( 299 | 
mc.to_protobuf(df_unflattened.value, ExampleMessage).alias("value") 300 | ) 301 | df_again.show() 302 | assert df.schema == df_again.schema 303 | assert df.collect() == df_again.collect() 304 | 305 | 306 | def test_recursive_message(spark): 307 | message = RecursiveMessage( 308 | note="one", 309 | message=RecursiveMessage(note="two", message=RecursiveMessage(note="three")), 310 | ) 311 | 312 | return_type = StructType( 313 | [ 314 | StructField("note", StringType(), True), 315 | StructField( 316 | "message", 317 | StructType( 318 | [ 319 | StructField("note", StringType(), True), 320 | StructField("message", StringType(), True), 321 | ] 322 | ), 323 | True, 324 | ), 325 | ] 326 | ) 327 | expected = { 328 | "note": "one", 329 | "message": { 330 | "note": "two", 331 | "message": json.dumps(MessageToDict(message.message.message)), 332 | }, 333 | } 334 | assert encode_recursive(message) == expected 335 | mc = MessageConverter() 336 | mc.register_serializer(RecursiveMessage, encode_recursive, return_type) 337 | 338 | data = [{"value": message.SerializeToString()}] 339 | 340 | df = spark.createDataFrame(data) # type: ignore[type-var] 341 | df.show() 342 | 343 | dfs = df.select(mc.from_protobuf(df.value, RecursiveMessage).alias("value")) 344 | dfs.show(truncate=False) 345 | data = dfs.collect() 346 | assert data[0].asDict(True)["value"] == expected 347 | 348 | 349 | def test_to_from_protobuf(example, spark, expanded): 350 | data = [{"value": example.SerializeToString()}] 351 | 352 | df = spark.createDataFrame(data) # type: ignore[type-var] 353 | 354 | df_decoded = df.select(from_protobuf(df.value, ExampleMessage).alias("value")) 355 | 356 | mc = MessageConverter() 357 | assert df_decoded.schema.fields[0].dataType == mc.get_spark_schema(ExampleMessage) 358 | 359 | df_encoded = df_decoded.select( 360 | to_protobuf(df_decoded.value, ExampleMessage).alias("value") 361 | ) 362 | 363 | assert df_encoded.columns == ["value"] 364 | assert df_encoded.schema == df.schema 365 | assert df.collect() == df_encoded.collect() 366 | 367 | 368 | def test_df_to_from_protobuf(example, spark, expanded): 369 | data = [{"value": example.SerializeToString()}] 370 | 371 | df = spark.createDataFrame(data) # type: ignore[type-var] 372 | 373 | df_decoded = df_from_protobuf(df, ExampleMessage, expanded=expanded) 374 | 375 | mc = MessageConverter() 376 | schema = mc.get_spark_schema(ExampleMessage) 377 | if expanded: 378 | assert df_decoded.schema == schema 379 | else: 380 | assert df_decoded.schema.fields[0].dataType == schema 381 | 382 | df_encoded = df_to_protobuf(df_decoded, ExampleMessage, expanded=expanded) 383 | 384 | assert df_encoded.columns == ["value"] 385 | assert df_encoded.schema == df.schema 386 | assert df.collect() == df_encoded.collect() 387 | 388 | 389 | def test_including_default_value_fields(spark, including_default_value_fields): 390 | example = ExampleMessage(string="asdf") 391 | data = [{"value": example.SerializeToString()}] 392 | 393 | df = spark.createDataFrame(data) # type: ignore[type-var] 394 | 395 | df_decoded = df_from_protobuf( 396 | df=df, 397 | message_type=ExampleMessage, 398 | expanded=True, 399 | including_default_value_fields=including_default_value_fields, 400 | ) 401 | data = df_decoded.collect() 402 | if including_default_value_fields: 403 | assert data[0].asDict(True)["int32"] == 0 404 | else: 405 | assert data[0].asDict(True)["int32"] is None 406 | 407 | 408 | def test_use_integers_for_enums(spark, use_integers_for_enums): 409 | example = 
ExampleMessage(enum=ExampleMessage.SomeEnum.first) 410 | data = [{"value": example.SerializeToString()}] 411 | 412 | df = spark.createDataFrame(data) # type: ignore[type-var] 413 | 414 | df_decoded = df_from_protobuf( 415 | df=df, 416 | message_type=ExampleMessage, 417 | expanded=True, 418 | use_integers_for_enums=use_integers_for_enums, 419 | ) 420 | data = df_decoded.collect() 421 | if use_integers_for_enums: 422 | assert data[0].asDict(True)["enum"] == 1 423 | else: 424 | assert data[0].asDict(True)["enum"] == "first" 425 | 426 | 427 | def test_preserving_proto_field_name(spark, preserving_proto_field_name): 428 | example = ExampleMessage(case_name="asdf") 429 | data = [{"value": example.SerializeToString()}] 430 | 431 | df = spark.createDataFrame(data) # type: ignore[type-var] 432 | 433 | df_decoded = df_from_protobuf( 434 | df=df, 435 | message_type=ExampleMessage, 436 | expanded=True, 437 | preserving_proto_field_name=preserving_proto_field_name, 438 | ) 439 | data = df_decoded.collect() 440 | if preserving_proto_field_name: 441 | assert data[0].asDict(True)["case_name"] == "asdf" 442 | else: 443 | assert data[0].asDict(True)["caseName"] == "asdf" 444 | 445 | 446 | def test_float_precision(spark): 447 | example = ExampleMessage(float=1.234567) 448 | data = [{"value": example.SerializeToString()}] 449 | 450 | df = spark.createDataFrame(data) # type: ignore[type-var] 451 | 452 | df_decoded = df_from_protobuf( 453 | df=df, 454 | message_type=ExampleMessage, 455 | expanded=True, 456 | float_precision=2, 457 | ) 458 | data = df_decoded.collect() 459 | assert data[0].asDict(True)["float"] == pytest.approx(1.2) 460 | 461 | 462 | def test_ignore_unknown_fields(spark, ignore_unknown_fields): 463 | example = ExampleMessage(string="asdf") 464 | data = [{"value": example.SerializeToString()}] 465 | 466 | df = spark.createDataFrame(data) # type: ignore[type-var] 467 | 468 | df_decoded = df_from_protobuf( 469 | df=df, 470 | message_type=ExampleMessage, 471 | expanded=True, 472 | ) 473 | df_decoded = df_decoded.withColumn("unknown", col("string")) 474 | df_decoded.show() 475 | df_recoded = df_to_protobuf( 476 | df=df_decoded, 477 | message_type=ExampleMessage, 478 | ignore_unknown_fields=ignore_unknown_fields, 479 | expanded=True, 480 | ) 481 | if not ignore_unknown_fields: 482 | with pytest.raises(PythonException): 483 | df_recoded.collect() 484 | else: 485 | df_recoded.collect() 486 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. 2 | 3 | [[package]] 4 | name = "atomicwrites" 5 | version = "1.4.0" 6 | description = "Atomic file writes." 
7 | optional = false 8 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 9 | files = [ 10 | {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, 11 | {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, 12 | ] 13 | 14 | [[package]] 15 | name = "attrs" 16 | version = "21.4.0" 17 | description = "Classes Without Boilerplate" 18 | optional = false 19 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 20 | files = [ 21 | {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, 22 | {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, 23 | ] 24 | 25 | [package.extras] 26 | dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "sphinx", "sphinx-notfound-page", "zope.interface"] 27 | docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] 28 | tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "zope.interface"] 29 | tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six"] 30 | 31 | [[package]] 32 | name = "black" 33 | version = "21.12b0" 34 | description = "The uncompromising code formatter." 35 | optional = false 36 | python-versions = ">=3.6.2" 37 | files = [ 38 | {file = "black-21.12b0-py3-none-any.whl", hash = "sha256:a615e69ae185e08fdd73e4715e260e2479c861b5740057fde6e8b4e3b7dd589f"}, 39 | {file = "black-21.12b0.tar.gz", hash = "sha256:77b80f693a569e2e527958459634f18df9b0ba2625ba4e0c2d5da5be42e6f2b3"}, 40 | ] 41 | 42 | [package.dependencies] 43 | click = ">=7.1.2" 44 | mypy-extensions = ">=0.4.3" 45 | pathspec = ">=0.9.0,<1" 46 | platformdirs = ">=2" 47 | tomli = ">=0.2.6,<2.0.0" 48 | typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} 49 | typing-extensions = [ 50 | {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}, 51 | {version = ">=3.10.0.0,<3.10.0.1 || >3.10.0.1", markers = "python_version >= \"3.10\""}, 52 | ] 53 | 54 | [package.extras] 55 | colorama = ["colorama (>=0.4.3)"] 56 | d = ["aiohttp (>=3.7.4)"] 57 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] 58 | python2 = ["typed-ast (>=1.4.3)"] 59 | uvloop = ["uvloop (>=0.15.2)"] 60 | 61 | [[package]] 62 | name = "click" 63 | version = "8.0.4" 64 | description = "Composable command line interface toolkit" 65 | optional = false 66 | python-versions = ">=3.6" 67 | files = [ 68 | {file = "click-8.0.4-py3-none-any.whl", hash = "sha256:6a7a62563bbfabfda3a38f3023a1db4a35978c0abd76f6c9605ecd6554d6d9b1"}, 69 | {file = "click-8.0.4.tar.gz", hash = "sha256:8458d7b1287c5fb128c90e23381cf99dcde74beaf6c7ff6384ce84d6fe090adb"}, 70 | ] 71 | 72 | [package.dependencies] 73 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 74 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 75 | 76 | [[package]] 77 | name = "colorama" 78 | version = "0.4.4" 79 | description = "Cross-platform colored terminal text." 
80 | optional = false 81 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 82 | files = [ 83 | {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, 84 | {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, 85 | ] 86 | 87 | [[package]] 88 | name = "importlib-metadata" 89 | version = "4.11.3" 90 | description = "Read metadata from Python packages" 91 | optional = false 92 | python-versions = ">=3.7" 93 | files = [ 94 | {file = "importlib_metadata-4.11.3-py3-none-any.whl", hash = "sha256:1208431ca90a8cca1a6b8af391bb53c1a2db74e5d1cef6ddced95d4b2062edc6"}, 95 | {file = "importlib_metadata-4.11.3.tar.gz", hash = "sha256:ea4c597ebf37142f827b8f39299579e31685c31d3a438b59f469406afd0f2539"}, 96 | ] 97 | 98 | [package.dependencies] 99 | typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} 100 | zipp = ">=0.5" 101 | 102 | [package.extras] 103 | docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] 104 | perf = ["ipython"] 105 | testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] 106 | 107 | [[package]] 108 | name = "iniconfig" 109 | version = "1.1.1" 110 | description = "iniconfig: brain-dead simple config-ini parsing" 111 | optional = false 112 | python-versions = "*" 113 | files = [ 114 | {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, 115 | {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, 116 | ] 117 | 118 | [[package]] 119 | name = "isort" 120 | version = "5.10.1" 121 | description = "A Python utility / library to sort Python imports." 
122 | optional = false 123 | python-versions = ">=3.6.1,<4.0" 124 | files = [ 125 | {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, 126 | {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, 127 | ] 128 | 129 | [package.extras] 130 | colors = ["colorama (>=0.4.3,<0.5.0)"] 131 | pipfile-deprecated-finder = ["pipreqs", "requirementslib"] 132 | plugins = ["setuptools"] 133 | requirements-deprecated-finder = ["pip-api", "pipreqs"] 134 | 135 | [[package]] 136 | name = "mypy" 137 | version = "0.942" 138 | description = "Optional static typing for Python" 139 | optional = false 140 | python-versions = ">=3.6" 141 | files = [ 142 | {file = "mypy-0.942-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5bf44840fb43ac4074636fd47ee476d73f0039f4f54e86d7265077dc199be24d"}, 143 | {file = "mypy-0.942-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dcd955f36e0180258a96f880348fbca54ce092b40fbb4b37372ae3b25a0b0a46"}, 144 | {file = "mypy-0.942-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6776e5fa22381cc761df53e7496a805801c1a751b27b99a9ff2f0ca848c7eca0"}, 145 | {file = "mypy-0.942-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:edf7237137a1a9330046dbb14796963d734dd740a98d5e144a3eb1d267f5f9ee"}, 146 | {file = "mypy-0.942-cp310-cp310-win_amd64.whl", hash = "sha256:64235137edc16bee6f095aba73be5334677d6f6bdb7fa03cfab90164fa294a17"}, 147 | {file = "mypy-0.942-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b840cfe89c4ab6386c40300689cd8645fc8d2d5f20101c7f8bd23d15fca14904"}, 148 | {file = "mypy-0.942-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2b184db8c618c43c3a31b32ff00cd28195d39e9c24e7c3b401f3db7f6e5767f5"}, 149 | {file = "mypy-0.942-cp36-cp36m-win_amd64.whl", hash = "sha256:1a0459c333f00e6a11cbf6b468b870c2b99a906cb72d6eadf3d1d95d38c9352c"}, 150 | {file = "mypy-0.942-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c3e497588afccfa4334a9986b56f703e75793133c4be3a02d06a3df16b67a58"}, 151 | {file = "mypy-0.942-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6f6ad963172152e112b87cc7ec103ba0f2db2f1cd8997237827c052a3903eaa6"}, 152 | {file = "mypy-0.942-cp37-cp37m-win_amd64.whl", hash = "sha256:0e2dd88410937423fba18e57147dd07cd8381291b93d5b1984626f173a26543e"}, 153 | {file = "mypy-0.942-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:246e1aa127d5b78488a4a0594bd95f6d6fb9d63cf08a66dafbff8595d8891f67"}, 154 | {file = "mypy-0.942-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d8d3ba77e56b84cd47a8ee45b62c84b6d80d32383928fe2548c9a124ea0a725c"}, 155 | {file = "mypy-0.942-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2bc249409a7168d37c658e062e1ab5173300984a2dada2589638568ddc1db02b"}, 156 | {file = "mypy-0.942-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9521c1265ccaaa1791d2c13582f06facf815f426cd8b07c3a485f486a8ffc1f3"}, 157 | {file = "mypy-0.942-cp38-cp38-win_amd64.whl", hash = "sha256:e865fec858d75b78b4d63266c9aff770ecb6a39dfb6d6b56c47f7f8aba6baba8"}, 158 | {file = "mypy-0.942-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6ce34a118d1a898f47def970a2042b8af6bdcc01546454726c7dd2171aa6dfca"}, 159 | {file = "mypy-0.942-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:10daab80bc40f84e3f087d896cdb53dc811a9f04eae4b3f95779c26edee89d16"}, 160 | {file = "mypy-0.942-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3841b5433ff936bff2f4dc8d54cf2cdbfea5d8e88cedfac45c161368e5770ba6"}, 161 | {file = "mypy-0.942-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6f7106cbf9cc2f403693bf50ed7c9fa5bb3dfa9007b240db3c910929abe2a322"}, 162 | {file = "mypy-0.942-cp39-cp39-win_amd64.whl", hash = "sha256:7742d2c4e46bb5017b51c810283a6a389296cda03df805a4f7869a6f41246534"}, 163 | {file = "mypy-0.942-py3-none-any.whl", hash = "sha256:a1b383fe99678d7402754fe90448d4037f9512ce70c21f8aee3b8bf48ffc51db"}, 164 | {file = "mypy-0.942.tar.gz", hash = "sha256:17e44649fec92e9f82102b48a3bf7b4a5510ad0cd22fa21a104826b5db4903e2"}, 165 | ] 166 | 167 | [package.dependencies] 168 | mypy-extensions = ">=0.4.3" 169 | tomli = ">=1.1.0" 170 | typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} 171 | typing-extensions = ">=3.10" 172 | 173 | [package.extras] 174 | dmypy = ["psutil (>=4.0)"] 175 | python2 = ["typed-ast (>=1.4.0,<2)"] 176 | reports = ["lxml"] 177 | 178 | [[package]] 179 | name = "mypy-extensions" 180 | version = "0.4.3" 181 | description = "Experimental type system extensions for programs checked with the mypy typechecker." 182 | optional = false 183 | python-versions = "*" 184 | files = [ 185 | {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, 186 | {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, 187 | ] 188 | 189 | [[package]] 190 | name = "mypy-protobuf" 191 | version = "3.2.0" 192 | description = "Generate mypy stub files from protobuf specs" 193 | optional = false 194 | python-versions = ">=3.6" 195 | files = [ 196 | {file = "mypy-protobuf-3.2.0.tar.gz", hash = "sha256:730aa15337c38f0446fbe08f6c6c2370ee01d395125369d4b70e08b1e2ee30ee"}, 197 | {file = "mypy_protobuf-3.2.0-py3-none-any.whl", hash = "sha256:65fc0492165f4a3c0aff69b03e34096fc1453e4dac8f14b4e9c2306cdde06010"}, 198 | ] 199 | 200 | [package.dependencies] 201 | protobuf = ">=3.19.3" 202 | types-protobuf = ">=3.19.5" 203 | 204 | [[package]] 205 | name = "packaging" 206 | version = "21.3" 207 | description = "Core utilities for Python packages" 208 | optional = false 209 | python-versions = ">=3.6" 210 | files = [ 211 | {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, 212 | {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, 213 | ] 214 | 215 | [package.dependencies] 216 | pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" 217 | 218 | [[package]] 219 | name = "pathspec" 220 | version = "0.9.0" 221 | description = "Utility library for gitignore style pattern matching of file paths." 222 | optional = false 223 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 224 | files = [ 225 | {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, 226 | {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, 227 | ] 228 | 229 | [[package]] 230 | name = "platformdirs" 231 | version = "2.5.2" 232 | description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
233 | optional = false 234 | python-versions = ">=3.7" 235 | files = [ 236 | {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"}, 237 | {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"}, 238 | ] 239 | 240 | [package.extras] 241 | docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"] 242 | test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] 243 | 244 | [[package]] 245 | name = "pluggy" 246 | version = "1.0.0" 247 | description = "plugin and hook calling mechanisms for python" 248 | optional = false 249 | python-versions = ">=3.6" 250 | files = [ 251 | {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, 252 | {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, 253 | ] 254 | 255 | [package.dependencies] 256 | importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} 257 | 258 | [package.extras] 259 | dev = ["pre-commit", "tox"] 260 | testing = ["pytest", "pytest-benchmark"] 261 | 262 | [[package]] 263 | name = "protobuf" 264 | version = "4.21.1" 265 | description = "" 266 | optional = false 267 | python-versions = ">=3.7" 268 | files = [ 269 | {file = "protobuf-4.21.1-cp310-abi3-win32.whl", hash = "sha256:52c1e44e25f2949be7ffa7c66acbfea940b0945dd416920231f7cb30ea5ac6db"}, 270 | {file = "protobuf-4.21.1-cp310-abi3-win_amd64.whl", hash = "sha256:72d357cc4d834cc85bd957e8b8e1f4b64c2eac9ca1a942efeb8eb2e723fca852"}, 271 | {file = "protobuf-4.21.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3767c64593a49c7ac0accd08ed39ce42744405f0989d468f0097a17496fdbe84"}, 272 | {file = "protobuf-4.21.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0d4719e724472e296062ba8e82a36d64693fcfdb550d9dff98af70ca79eafe3d"}, 273 | {file = "protobuf-4.21.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:a4c0c6f2f95a559e59a0258d3e4b186f340cbdc5adec5ce1bc06d01972527c88"}, 274 | {file = "protobuf-4.21.1-cp37-cp37m-win32.whl", hash = "sha256:32fff501b6df3050936d1839b80ea5899bf34db24792d223d7640611f67de15a"}, 275 | {file = "protobuf-4.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b3d7d4b4945fe3c001403b6c24442901a5e58c0a3059290f5a63523ed4435f82"}, 276 | {file = "protobuf-4.21.1-cp38-cp38-win32.whl", hash = "sha256:34400fd76f85bdae9a2e9c1444ea4699c0280962423eff4418765deceebd81b5"}, 277 | {file = "protobuf-4.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:c8829092c5aeb61619161269b2f8a2e36fd7cb26abbd9282d3bc453f02769146"}, 278 | {file = "protobuf-4.21.1-cp39-cp39-win32.whl", hash = "sha256:2b35602cb65d53c168c104469e714bf68670335044c38eee3c899d6a8af03ffc"}, 279 | {file = "protobuf-4.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:3f2ed842e8ca43b790cb4a101bcf577226e0ded98a6a6ba2d5e12095a08dc4da"}, 280 | {file = "protobuf-4.21.1-py2.py3-none-any.whl", hash = "sha256:b309fda192850ac4184ca1777aab9655564bc8d10a9cc98f10e1c8bf11295c22"}, 281 | {file = "protobuf-4.21.1-py3-none-any.whl", hash = "sha256:79cd8d0a269b714f6b32641f86928c718e8d234466919b3f552bfb069dbb159b"}, 282 | {file = "protobuf-4.21.1.tar.gz", hash = "sha256:5d9b5c8270461706973c3871c6fbdd236b51dfe9dab652f1fb6a36aa88287e47"}, 283 | ] 284 | 285 | [[package]] 286 | name = "py" 287 | version = "1.11.0" 288 | description = "library with cross-python path, ini-parsing, io, 
code, log facilities" 289 | optional = false 290 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 291 | files = [ 292 | {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, 293 | {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, 294 | ] 295 | 296 | [[package]] 297 | name = "py4j" 298 | version = "0.10.9.3" 299 | description = "Enables Python programs to dynamically access arbitrary Java objects" 300 | optional = false 301 | python-versions = "*" 302 | files = [ 303 | {file = "py4j-0.10.9.3-py2.py3-none-any.whl", hash = "sha256:04f5b06917c0c8a81ab34121dda09a2ba1f74e96d59203c821d5cb7d28c35363"}, 304 | {file = "py4j-0.10.9.3.tar.gz", hash = "sha256:0d92844da4cb747155b9563c44fc322c9a1562b3ef0979ae692dbde732d784dd"}, 305 | ] 306 | 307 | [[package]] 308 | name = "pyparsing" 309 | version = "3.0.8" 310 | description = "pyparsing module - Classes and methods to define and execute parsing grammars" 311 | optional = false 312 | python-versions = ">=3.6.8" 313 | files = [ 314 | {file = "pyparsing-3.0.8-py3-none-any.whl", hash = "sha256:ef7b523f6356f763771559412c0d7134753f037822dad1b16945b7b846f7ad06"}, 315 | {file = "pyparsing-3.0.8.tar.gz", hash = "sha256:7bf433498c016c4314268d95df76c81b842a4cb2b276fa3312cfb1e1d85f6954"}, 316 | ] 317 | 318 | [package.extras] 319 | diagrams = ["jinja2", "railroad-diagrams"] 320 | 321 | [[package]] 322 | name = "pyspark" 323 | version = "3.2.1" 324 | description = "Apache Spark Python API" 325 | optional = false 326 | python-versions = ">=3.6" 327 | files = [ 328 | {file = "pyspark-3.2.1.tar.gz", hash = "sha256:0b81359262ec6e9ac78c353344e7de026027d140c6def949ff0d80ab70f89a54"}, 329 | ] 330 | 331 | [package.dependencies] 332 | py4j = "0.10.9.3" 333 | 334 | [package.extras] 335 | ml = ["numpy (>=1.7)"] 336 | mllib = ["numpy (>=1.7)"] 337 | pandas-on-spark = ["numpy (>=1.14)", "pandas (>=0.23.2)", "pyarrow (>=1.0.0)"] 338 | sql = ["pandas (>=0.23.2)", "pyarrow (>=1.0.0)"] 339 | 340 | [[package]] 341 | name = "pytest" 342 | version = "6.2.5" 343 | description = "pytest: simple powerful testing with Python" 344 | optional = false 345 | python-versions = ">=3.6" 346 | files = [ 347 | {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, 348 | {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, 349 | ] 350 | 351 | [package.dependencies] 352 | atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} 353 | attrs = ">=19.2.0" 354 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 355 | importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} 356 | iniconfig = "*" 357 | packaging = "*" 358 | pluggy = ">=0.12,<2.0" 359 | py = ">=1.8.2" 360 | toml = "*" 361 | 362 | [package.extras] 363 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 364 | 365 | [[package]] 366 | name = "toml" 367 | version = "0.10.2" 368 | description = "Python Library for Tom's Obvious, Minimal Language" 369 | optional = false 370 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 371 | files = [ 372 | {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, 373 | {file = "toml-0.10.2.tar.gz", hash = 
"sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, 374 | ] 375 | 376 | [[package]] 377 | name = "tomli" 378 | version = "1.2.3" 379 | description = "A lil' TOML parser" 380 | optional = false 381 | python-versions = ">=3.6" 382 | files = [ 383 | {file = "tomli-1.2.3-py3-none-any.whl", hash = "sha256:e3069e4be3ead9668e21cb9b074cd948f7b3113fd9c8bba083f48247aab8b11c"}, 384 | {file = "tomli-1.2.3.tar.gz", hash = "sha256:05b6166bff487dc068d322585c7ea4ef78deed501cc124060e0f238e89a9231f"}, 385 | ] 386 | 387 | [[package]] 388 | name = "typed-ast" 389 | version = "1.5.3" 390 | description = "a fork of Python 2 and 3 ast modules with type comment support" 391 | optional = false 392 | python-versions = ">=3.6" 393 | files = [ 394 | {file = "typed_ast-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ad3b48cf2b487be140072fb86feff36801487d4abb7382bb1929aaac80638ea"}, 395 | {file = "typed_ast-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:542cd732351ba8235f20faa0fc7398946fe1a57f2cdb289e5497e1e7f48cfedb"}, 396 | {file = "typed_ast-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc2c11ae59003d4a26dda637222d9ae924387f96acae9492df663843aefad55"}, 397 | {file = "typed_ast-1.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fd5df1313915dbd70eaaa88c19030b441742e8b05e6103c631c83b75e0435ccc"}, 398 | {file = "typed_ast-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:e34f9b9e61333ecb0f7d79c21c28aa5cd63bec15cb7e1310d7d3da6ce886bc9b"}, 399 | {file = "typed_ast-1.5.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f818c5b81966d4728fec14caa338e30a70dfc3da577984d38f97816c4b3071ec"}, 400 | {file = "typed_ast-1.5.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3042bfc9ca118712c9809201f55355479cfcdc17449f9f8db5e744e9625c6805"}, 401 | {file = "typed_ast-1.5.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4fff9fdcce59dc61ec1b317bdb319f8f4e6b69ebbe61193ae0a60c5f9333dc49"}, 402 | {file = "typed_ast-1.5.3-cp36-cp36m-win_amd64.whl", hash = "sha256:8e0b8528838ffd426fea8d18bde4c73bcb4167218998cc8b9ee0a0f2bfe678a6"}, 403 | {file = "typed_ast-1.5.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8ef1d96ad05a291f5c36895d86d1375c0ee70595b90f6bb5f5fdbee749b146db"}, 404 | {file = "typed_ast-1.5.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed44e81517364cb5ba367e4f68fca01fba42a7a4690d40c07886586ac267d9b9"}, 405 | {file = "typed_ast-1.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f60d9de0d087454c91b3999a296d0c4558c1666771e3460621875021bf899af9"}, 406 | {file = "typed_ast-1.5.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9e237e74fd321a55c90eee9bc5d44be976979ad38a29bbd734148295c1ce7617"}, 407 | {file = "typed_ast-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee852185964744987609b40aee1d2eb81502ae63ee8eef614558f96a56c1902d"}, 408 | {file = "typed_ast-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:27e46cdd01d6c3a0dd8f728b6a938a6751f7bd324817501c15fb056307f918c6"}, 409 | {file = "typed_ast-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d64dabc6336ddc10373922a146fa2256043b3b43e61f28961caec2a5207c56d5"}, 410 | {file = "typed_ast-1.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:8cdf91b0c466a6c43f36c1964772918a2c04cfa83df8001ff32a89e357f8eb06"}, 411 | {file = "typed_ast-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:9cc9e1457e1feb06b075c8ef8aeb046a28ec351b1958b42c7c31c989c841403a"}, 412 | {file = "typed_ast-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e20d196815eeffb3d76b75223e8ffed124e65ee62097e4e73afb5fec6b993e7a"}, 413 | {file = "typed_ast-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:37e5349d1d5de2f4763d534ccb26809d1c24b180a477659a12c4bde9dd677d74"}, 414 | {file = "typed_ast-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9f1a27592fac87daa4e3f16538713d705599b0a27dfe25518b80b6b017f0a6d"}, 415 | {file = "typed_ast-1.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8831479695eadc8b5ffed06fdfb3e424adc37962a75925668deeb503f446c0a3"}, 416 | {file = "typed_ast-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:20d5118e494478ef2d3a2702d964dae830aedd7b4d3b626d003eea526be18718"}, 417 | {file = "typed_ast-1.5.3.tar.gz", hash = "sha256:27f25232e2dd0edfe1f019d6bfaaf11e86e657d9bdb7b0956db95f560cceb2b3"}, 418 | ] 419 | 420 | [[package]] 421 | name = "types-protobuf" 422 | version = "3.19.17" 423 | description = "Typing stubs for protobuf" 424 | optional = false 425 | python-versions = "*" 426 | files = [ 427 | {file = "types-protobuf-3.19.17.tar.gz", hash = "sha256:d0d930326cd76d9e85fd592e18c2248636bdbbf0226618082f57a82f00dd7d25"}, 428 | {file = "types_protobuf-3.19.17-py3-none-any.whl", hash = "sha256:68e9e7bd9439a7b5ee679ada87a2622f2f7df2c66434ff49a5729007ceebc501"}, 429 | ] 430 | 431 | [[package]] 432 | name = "typing-extensions" 433 | version = "4.2.0" 434 | description = "Backported and Experimental Type Hints for Python 3.7+" 435 | optional = false 436 | python-versions = ">=3.7" 437 | files = [ 438 | {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, 439 | {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, 440 | ] 441 | 442 | [[package]] 443 | name = "zipp" 444 | version = "3.8.0" 445 | description = "Backport of pathlib-compatible object wrapper for zip files" 446 | optional = false 447 | python-versions = ">=3.7" 448 | files = [ 449 | {file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"}, 450 | {file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"}, 451 | ] 452 | 453 | [package.extras] 454 | docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] 455 | testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] 456 | 457 | [metadata] 458 | lock-version = "2.0" 459 | python-versions = "^3.7" 460 | content-hash = "7a85a10720c9e2670c2b09fef1986d23eeecbb4b7fac8f56e847c34303c1d876" 461 | -------------------------------------------------------------------------------- /pbspark/_proto.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing as t 3 | from contextlib import contextmanager 4 | from functools import wraps 5 | 6 | from google.protobuf import json_format 7 | from google.protobuf.descriptor import Descriptor 8 | from google.protobuf.descriptor import 
FieldDescriptor 9 | from google.protobuf.descriptor_pool import DescriptorPool 10 | from google.protobuf.message import Message 11 | from google.protobuf.timestamp_pb2 import Timestamp 12 | from pyspark.sql import Column 13 | from pyspark.sql import DataFrame 14 | from pyspark.sql.functions import col 15 | from pyspark.sql.functions import struct 16 | from pyspark.sql.functions import udf 17 | from pyspark.sql.types import ArrayType 18 | from pyspark.sql.types import BinaryType 19 | from pyspark.sql.types import BooleanType 20 | from pyspark.sql.types import DataType 21 | from pyspark.sql.types import DoubleType 22 | from pyspark.sql.types import FloatType 23 | from pyspark.sql.types import IntegerType 24 | from pyspark.sql.types import LongType 25 | from pyspark.sql.types import Row 26 | from pyspark.sql.types import StringType 27 | from pyspark.sql.types import StructField 28 | from pyspark.sql.types import StructType 29 | from pyspark.sql.types import TimestampType 30 | 31 | from pbspark._timestamp import _from_datetime 32 | from pbspark._timestamp import _to_datetime 33 | 34 | # Built in types like these have special methods 35 | # for serialization via MessageToDict. Because the 36 | # MessageToDict function is an intermediate step to 37 | # JSON, some types are serialized to strings. 38 | _MESSAGETYPE_TO_SPARK_TYPE_MAP: t.Dict[str, DataType] = { 39 | # google/protobuf/timestamp.proto 40 | "google.protobuf.Timestamp": StringType(), 41 | # google/protobuf/duration.proto 42 | "google.protobuf.Duration": StringType(), 43 | # google/protobuf/wrappers.proto 44 | "google.protobuf.DoubleValue": DoubleType(), 45 | "google.protobuf.FloatValue": FloatType(), 46 | "google.protobuf.Int64Value": LongType(), 47 | "google.protobuf.UInt64Value": LongType(), 48 | "google.protobuf.Int32Value": IntegerType(), 49 | "google.protobuf.UInt32Value": LongType(), 50 | "google.protobuf.BoolValue": BooleanType(), 51 | "google.protobuf.StringValue": StringType(), 52 | "google.protobuf.BytesValue": BinaryType(), 53 | } 54 | 55 | # Protobuf types map to these CPP Types. We map 56 | # them to Spark types for generating a spark schema. 57 | # Note that bytes fields are specified by the `type` attribute in addition to 58 | # the `cpp_type` attribute so there is special handling in the `get_spark_schema` 59 | # method. 
60 | _CPPTYPE_TO_SPARK_TYPE_MAP: t.Dict[int, DataType] = { 61 | FieldDescriptor.CPPTYPE_INT32: IntegerType(), 62 | FieldDescriptor.CPPTYPE_INT64: LongType(), 63 | FieldDescriptor.CPPTYPE_UINT32: LongType(), 64 | FieldDescriptor.CPPTYPE_UINT64: LongType(), 65 | FieldDescriptor.CPPTYPE_DOUBLE: DoubleType(), 66 | FieldDescriptor.CPPTYPE_FLOAT: FloatType(), 67 | FieldDescriptor.CPPTYPE_BOOL: BooleanType(), 68 | FieldDescriptor.CPPTYPE_ENUM: StringType(), 69 | FieldDescriptor.CPPTYPE_STRING: StringType(), 70 | } 71 | 72 | 73 | # region serde overrides 74 | class _Printer(json_format._Printer): # type: ignore 75 | """Printer override to handle custom messages and byte fields.""" 76 | 77 | def __init__(self, custom_serializers=None, **kwargs): 78 | self._custom_serializers = custom_serializers or {} 79 | super().__init__(**kwargs) 80 | 81 | def _MessageToJsonObject(self, message): 82 | full_name = message.DESCRIPTOR.full_name 83 | if full_name in self._custom_serializers: 84 | return self._custom_serializers[full_name](message) 85 | return super()._MessageToJsonObject(message) 86 | 87 | def _FieldToJsonObject(self, field, value): 88 | # specially handle bytes before protobuf's method does 89 | if ( 90 | field.cpp_type == FieldDescriptor.CPPTYPE_STRING 91 | and field.type == FieldDescriptor.TYPE_BYTES 92 | ): 93 | return value 94 | # don't convert int64s to string (protobuf does this for js precision compat) 95 | elif field.cpp_type in json_format._INT64_TYPES: 96 | return value 97 | return super()._FieldToJsonObject(field, value) 98 | 99 | 100 | class _Parser(json_format._Parser): # type: ignore 101 | """Parser override to handle custom messages.""" 102 | 103 | def __init__(self, custom_deserializers=None, **kwargs): 104 | self._custom_deserializers = custom_deserializers or {} 105 | super().__init__(**kwargs) 106 | 107 | def ConvertMessage(self, value, message, path): 108 | full_name = message.DESCRIPTOR.full_name 109 | if full_name in self._custom_deserializers: 110 | self._custom_deserializers[full_name](value, message, path) 111 | return 112 | with _patched_convert_scalar_field_value(): 113 | super().ConvertMessage(value, message, path) 114 | 115 | 116 | # protobuf converts to/from b64 strings, but we prefer to stay as bytes. 
117 | # we handle bytes parser by decorating to handle byte fields first 118 | def _handle_bytes(func): 119 | @wraps(func) 120 | def wrapper(value, field, path, require_str=False): 121 | if ( 122 | field.cpp_type == FieldDescriptor.CPPTYPE_STRING 123 | and field.type == FieldDescriptor.TYPE_BYTES 124 | ): 125 | return bytes(value) # convert from bytearray to bytes 126 | return func(value=value, field=field, path=path, require_str=require_str) 127 | 128 | return wrapper 129 | 130 | 131 | @contextmanager 132 | def _patched_convert_scalar_field_value(): 133 | """Temporarily patch the scalar field conversion function.""" 134 | convert_scalar_field_value_func = json_format._ConvertScalarFieldValue # type: ignore[attr-defined] 135 | json_format._ConvertScalarFieldValue = _handle_bytes( # type: ignore[attr-defined] 136 | json_format._ConvertScalarFieldValue # type: ignore[attr-defined] 137 | ) 138 | try: 139 | yield 140 | finally: 141 | json_format._ConvertScalarFieldValue = convert_scalar_field_value_func 142 | 143 | 144 | # endregion 145 | 146 | 147 | class MessageConverter: 148 | def __init__(self): 149 | self._custom_serializers: t.Dict[str, t.Callable] = {} 150 | self._custom_deserializers: t.Dict[str, t.Callable] = {} 151 | self._message_type_to_spark_type_map = _MESSAGETYPE_TO_SPARK_TYPE_MAP.copy() 152 | self.register_timestamp_serializer() 153 | self.register_timestamp_deserializer() 154 | 155 | def register_serializer( 156 | self, 157 | message: t.Type[Message], 158 | serializer: t.Callable, 159 | return_type: DataType, 160 | ): 161 | """Map a message type to a custom serializer and spark output type. 162 | 163 | The serializer should be a function which returns an object which 164 | can be coerced into the spark return type. 165 | """ 166 | full_name = message.DESCRIPTOR.full_name 167 | self._custom_serializers[full_name] = serializer 168 | self._message_type_to_spark_type_map[full_name] = return_type 169 | 170 | def unregister_serializer(self, message: t.Type[Message]): 171 | full_name = message.DESCRIPTOR.full_name 172 | self._custom_serializers.pop(full_name, None) 173 | self._message_type_to_spark_type_map.pop(full_name, None) 174 | if full_name in _MESSAGETYPE_TO_SPARK_TYPE_MAP: 175 | self._message_type_to_spark_type_map[ 176 | full_name 177 | ] = _MESSAGETYPE_TO_SPARK_TYPE_MAP[full_name] 178 | 179 | def register_deserializer(self, message: t.Type[Message], deserializer: t.Callable): 180 | full_name = message.DESCRIPTOR.full_name 181 | self._custom_deserializers[full_name] = deserializer 182 | 183 | def unregister_deserializer(self, message: t.Type[Message]): 184 | full_name = message.DESCRIPTOR.full_name 185 | self._custom_deserializers.pop(full_name, None) 186 | 187 | # region timestamp 188 | def register_timestamp_serializer(self): 189 | self.register_serializer(Timestamp, _to_datetime, TimestampType()) 190 | 191 | def unregister_timestamp_serializer(self): 192 | self.unregister_serializer(Timestamp) 193 | 194 | def register_timestamp_deserializer(self): 195 | self.register_deserializer(Timestamp, _from_datetime) 196 | 197 | def unregister_timestamp_deserializer(self): 198 | self.unregister_deserializer(Timestamp) 199 | 200 | # endregion 201 | 202 | def message_to_dict( 203 | self, 204 | message: Message, 205 | including_default_value_fields: bool = False, 206 | preserving_proto_field_name: bool = False, 207 | use_integers_for_enums: bool = False, 208 | descriptor_pool: t.Optional[DescriptorPool] = None, 209 | float_precision: t.Optional[int] = None, 210 | ): 211 | """Custom 
MessageToDict using overridden printer. 212 | 213 | Args: 214 | message: The protocol buffers message instance to serialize. 215 | including_default_value_fields: If True, singular primitive fields, 216 | repeated fields, and map fields will always be serialized. If 217 | False, only serialize non-empty fields. Singular message fields 218 | and oneof fields are not affected by this option. 219 | preserving_proto_field_name: If True, use the original proto field 220 | names as defined in the .proto file. If False, convert the field 221 | names to lowerCamelCase. 222 | use_integers_for_enums: If true, print integers instead of enum names. 223 | descriptor_pool: A Descriptor Pool for resolving types. If None use the 224 | default. 225 | float_precision: If set, use this to specify float field valid digits. 226 | """ 227 | printer = _Printer( 228 | custom_serializers=self._custom_serializers, 229 | including_default_value_fields=including_default_value_fields, 230 | preserving_proto_field_name=preserving_proto_field_name, 231 | use_integers_for_enums=use_integers_for_enums, 232 | descriptor_pool=descriptor_pool, 233 | float_precision=float_precision, 234 | ) 235 | return printer._MessageToJsonObject(message=message) 236 | 237 | def parse_dict( 238 | self, 239 | value: dict, 240 | message: Message, 241 | ignore_unknown_fields: bool = False, 242 | descriptor_pool: t.Optional[DescriptorPool] = None, 243 | max_recursion_depth: int = 100, 244 | ): 245 | """Custom ParseDict using overridden parser.""" 246 | parser = _Parser( 247 | custom_deserializers=self._custom_deserializers, 248 | ignore_unknown_fields=ignore_unknown_fields, 249 | descriptor_pool=descriptor_pool, 250 | max_recursion_depth=max_recursion_depth, 251 | ) 252 | return parser.ConvertMessage(value=value, message=message, path=None) 253 | 254 | def get_spark_schema( 255 | self, 256 | descriptor: t.Union[t.Type[Message], Descriptor], 257 | preserving_proto_field_name: bool = False, 258 | use_integers_for_enums: bool = False, 259 | ) -> DataType: 260 | """Generate a spark schema from a message type or descriptor 261 | 262 | Given a message type generated from protoc (or its descriptor), 263 | create a spark schema derived from the protobuf schema when 264 | serializing with ``MessageToDict``. 265 | 266 | Args: 267 | descriptor: A message type or its descriptor 268 | preserving_proto_field_name: If True, use the original proto field 269 | names as defined in the .proto file. If False, convert the field 270 | names to lowerCamelCase. 271 | use_integers_for_enums: If true, print integers instead of enum names. 
272 | """ 273 | schema = [] 274 | if inspect.isclass(descriptor) and issubclass(descriptor, Message): 275 | descriptor_ = descriptor.DESCRIPTOR 276 | else: 277 | descriptor_ = descriptor # type: ignore[assignment] 278 | full_name = descriptor_.full_name 279 | if full_name in self._message_type_to_spark_type_map: 280 | return self._message_type_to_spark_type_map[full_name] 281 | for field in descriptor_.fields: 282 | spark_type: DataType 283 | if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: 284 | full_name = field.message_type.full_name 285 | if full_name in self._message_type_to_spark_type_map: 286 | spark_type = self._message_type_to_spark_type_map[full_name] 287 | else: 288 | spark_type = self.get_spark_schema( 289 | descriptor=field.message_type, 290 | preserving_proto_field_name=preserving_proto_field_name, 291 | ) 292 | # protobuf converts to/from b64 strings, but we prefer to stay as bytes 293 | elif ( 294 | field.cpp_type == FieldDescriptor.CPPTYPE_STRING 295 | and field.type == FieldDescriptor.TYPE_BYTES 296 | ): 297 | spark_type = BinaryType() 298 | elif ( 299 | field.cpp_type == FieldDescriptor.CPPTYPE_ENUM 300 | and use_integers_for_enums 301 | ): 302 | spark_type = IntegerType() 303 | else: 304 | spark_type = _CPPTYPE_TO_SPARK_TYPE_MAP[field.cpp_type] 305 | if field.label == FieldDescriptor.LABEL_REPEATED: 306 | spark_type = ArrayType(spark_type, True) 307 | field_name = ( 308 | field.camelcase_name if not preserving_proto_field_name else field.name 309 | ) 310 | schema.append((field_name, spark_type, True)) 311 | struct_args = [StructField(*entry) for entry in schema] 312 | return StructType(struct_args) 313 | 314 | def get_decoder( 315 | self, 316 | message_type: t.Type[Message], 317 | including_default_value_fields: bool = False, 318 | preserving_proto_field_name: bool = False, 319 | use_integers_for_enums: bool = False, 320 | float_precision: t.Optional[int] = None, 321 | ) -> t.Callable: 322 | """Create a deserialization function for a message type. 323 | 324 | Create a function that accepts a serialized message bytestring 325 | and returns a dictionary representing the message. 326 | 327 | Args: 328 | message_type: The message type for decoding. 329 | including_default_value_fields: If True, singular primitive fields, 330 | repeated fields, and map fields will always be serialized. If 331 | False, only serialize non-empty fields. Singular message fields 332 | and oneof fields are not affected by this option. 333 | preserving_proto_field_name: If True, use the original proto field 334 | names as defined in the .proto file. If False, convert the field 335 | names to lowerCamelCase. 336 | use_integers_for_enums: If true, print integers instead of enum names. 337 | float_precision: If set, use this to specify float field valid digits. 
338 | """ 339 | 340 | def decoder(s: bytes) -> dict: 341 | if isinstance(s, bytearray): 342 | s = bytes(s) 343 | return self.message_to_dict( 344 | message_type.FromString(s), 345 | including_default_value_fields=including_default_value_fields, 346 | preserving_proto_field_name=preserving_proto_field_name, 347 | use_integers_for_enums=use_integers_for_enums, 348 | float_precision=float_precision, 349 | ) 350 | 351 | return decoder 352 | 353 | def get_decoder_udf( 354 | self, 355 | message_type: t.Type[Message], 356 | including_default_value_fields: bool = False, 357 | preserving_proto_field_name: bool = False, 358 | use_integers_for_enums: bool = False, 359 | float_precision: t.Optional[int] = None, 360 | ) -> t.Callable: 361 | """Create a deserialization udf for a message type. 362 | 363 | Creates a function for deserializing messages to dict 364 | with spark schema for expected output. 365 | 366 | Args: 367 | message_type: The message type for decoding. 368 | including_default_value_fields: If True, singular primitive fields, 369 | repeated fields, and map fields will always be serialized. If 370 | False, only serialize non-empty fields. Singular message fields 371 | and oneof fields are not affected by this option. 372 | preserving_proto_field_name: If True, use the original proto field 373 | names as defined in the .proto file. If False, convert the field 374 | names to lowerCamelCase. 375 | use_integers_for_enums: If true, print integers instead of enum names. 376 | float_precision: If set, use this to specify float field valid digits. 377 | """ 378 | return udf( 379 | self.get_decoder( 380 | message_type=message_type, 381 | including_default_value_fields=including_default_value_fields, 382 | preserving_proto_field_name=preserving_proto_field_name, 383 | use_integers_for_enums=use_integers_for_enums, 384 | float_precision=float_precision, 385 | ), 386 | self.get_spark_schema( 387 | descriptor=message_type.DESCRIPTOR, 388 | preserving_proto_field_name=preserving_proto_field_name, 389 | use_integers_for_enums=use_integers_for_enums, 390 | ), 391 | ) 392 | 393 | def from_protobuf( 394 | self, 395 | data: t.Union[Column, str], 396 | message_type: t.Type[Message], 397 | including_default_value_fields: bool = False, 398 | preserving_proto_field_name: bool = False, 399 | use_integers_for_enums: bool = False, 400 | float_precision: t.Optional[int] = None, 401 | ) -> Column: 402 | """Deserialize protobuf messages to spark structs. 403 | 404 | Given a column and protobuf message type, deserialize 405 | protobuf messages also using our custom serializers. 406 | 407 | Args: 408 | message_type: The message type for decoding. 409 | including_default_value_fields: If True, singular primitive fields, 410 | repeated fields, and map fields will always be serialized. If 411 | False, only serialize non-empty fields. Singular message fields 412 | and oneof fields are not affected by this option. 413 | preserving_proto_field_name: If True, use the original proto field 414 | names as defined in the .proto file. If False, convert the field 415 | names to lowerCamelCase. 416 | use_integers_for_enums: If true, print integers instead of enum names. 417 | float_precision: If set, use this to specify float field valid digits. 
418 | """ 419 | column = col(data) if isinstance(data, str) else data 420 | protobuf_decoder_udf = self.get_decoder_udf( 421 | message_type=message_type, 422 | including_default_value_fields=including_default_value_fields, 423 | preserving_proto_field_name=preserving_proto_field_name, 424 | use_integers_for_enums=use_integers_for_enums, 425 | float_precision=float_precision, 426 | ) 427 | return protobuf_decoder_udf(column) 428 | 429 | def get_encoder( 430 | self, 431 | message_type: t.Type[Message], 432 | ignore_unknown_fields: bool = False, 433 | max_recursion_depth: int = 100, 434 | ) -> t.Callable: 435 | """Create an encoding function for a message type. 436 | 437 | Create a function that accepts a dictionary representing the message 438 | and returns a serialized message bytestring. 439 | 440 | Args: 441 | message_type: The message type for encoding. 442 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 443 | max_recursion_depth: max recursion depth of JSON message to be 444 | deserialized. JSON messages over this depth will fail to be 445 | deserialized. Default value is 100. 446 | """ 447 | 448 | def encoder(s: dict) -> bytes: 449 | message = message_type() 450 | # udf may pass a Row object, but we want to pass a dict to the parser 451 | if isinstance(s, Row): 452 | s = s.asDict(recursive=True) 453 | self.parse_dict( 454 | s, 455 | message, 456 | ignore_unknown_fields=ignore_unknown_fields, 457 | max_recursion_depth=max_recursion_depth, 458 | ) 459 | return message.SerializeToString() 460 | 461 | return encoder 462 | 463 | def get_encoder_udf( 464 | self, 465 | message_type: t.Type[Message], 466 | ignore_unknown_fields: bool = False, 467 | max_recursion_depth: int = 100, 468 | ) -> t.Callable: 469 | """Get a pyspark udf for encoding to protobuf. 470 | 471 | Args: 472 | message_type: The message type for encoding. 473 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 474 | max_recursion_depth: max recursion depth of JSON message to be 475 | deserialized. JSON messages over this depth will fail to be 476 | deserialized. Default value is 100. 477 | """ 478 | return udf( 479 | self.get_encoder( 480 | message_type=message_type, 481 | ignore_unknown_fields=ignore_unknown_fields, 482 | max_recursion_depth=max_recursion_depth, 483 | ), 484 | BinaryType(), 485 | ) 486 | 487 | def to_protobuf( 488 | self, 489 | data: t.Union[Column, str], 490 | message_type: t.Type[Message], 491 | ignore_unknown_fields: bool = False, 492 | max_recursion_depth: int = 100, 493 | ) -> Column: 494 | """Serialize spark structs to protobuf messages. 495 | 496 | Given a column and protobuf message type, serialize 497 | protobuf messages also using our custom serializers. 498 | 499 | Args: 500 | data: A pyspark column. 501 | message_type: The message type for encoding. 502 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 503 | max_recursion_depth: max recursion depth of JSON message to be 504 | deserialized. JSON messages over this depth will fail to be 505 | deserialized. Default value is 100. 
506 | """ 507 | column = col(data) if isinstance(data, str) else data 508 | protobuf_encoder_udf = self.get_encoder_udf( 509 | message_type, 510 | ignore_unknown_fields=ignore_unknown_fields, 511 | max_recursion_depth=max_recursion_depth, 512 | ) 513 | return protobuf_encoder_udf(column) 514 | 515 | def df_from_protobuf( 516 | self, 517 | df: DataFrame, 518 | message_type: t.Type[Message], 519 | including_default_value_fields: bool = False, 520 | preserving_proto_field_name: bool = False, 521 | use_integers_for_enums: bool = False, 522 | float_precision: t.Optional[int] = None, 523 | expanded: bool = False, 524 | ) -> DataFrame: 525 | """Decode a dataframe of encoded protobuf. 526 | 527 | Args: 528 | df: A pyspark dataframe with encoded protobuf in the column at index 0. 529 | message_type: The message type for decoding. 530 | including_default_value_fields: If True, singular primitive fields, 531 | repeated fields, and map fields will always be serialized. If 532 | False, only serialize non-empty fields. Singular message fields 533 | and oneof fields are not affected by this option. 534 | preserving_proto_field_name: If True, use the original proto field 535 | names as defined in the .proto file. If False, convert the field 536 | names to lowerCamelCase. 537 | use_integers_for_enums: If true, print integers instead of enum names. 538 | float_precision: If set, use this to specify float field valid digits. 539 | expanded: If True, return a dataframe in which each field is its own 540 | column. Otherwise, return a dataframe with a single struct column 541 | named `value`. 542 | """ 543 | df_decoded = df.select( 544 | self.from_protobuf( 545 | data=df.columns[0], 546 | message_type=message_type, 547 | including_default_value_fields=including_default_value_fields, 548 | preserving_proto_field_name=preserving_proto_field_name, 549 | use_integers_for_enums=use_integers_for_enums, 550 | float_precision=float_precision, 551 | ).alias("value") 552 | ) 553 | if expanded: 554 | df_decoded = df_decoded.select("value.*") 555 | return df_decoded 556 | 557 | def df_to_protobuf( 558 | self, 559 | df: DataFrame, 560 | message_type: t.Type[Message], 561 | ignore_unknown_fields: bool = False, 562 | max_recursion_depth: int = 100, 563 | expanded: bool = False, 564 | ) -> DataFrame: 565 | """Encode data in a dataframe to protobuf as column `value`. 566 | 567 | Args: 568 | df: A pyspark dataframe. 569 | message_type: The message type for encoding. 570 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 571 | max_recursion_depth: max recursion depth of JSON message to be 572 | deserialized. JSON messages over this depth will fail to be 573 | deserialized. Default value is 100. 574 | expanded: If True, the passed dataframe columns will be packed into a 575 | struct before converting. Otherwise, it is assumed that the 576 | dataframe passed is a single column of data already packed into a 577 | struct. 578 | 579 | Returns a dataframe with a single column named `value` containing encoded data. 
580 | """ 581 | if expanded: 582 | df_struct = df.select( 583 | struct([df[c] for c in df.columns]).alias("value") # type: ignore[arg-type] 584 | ) 585 | else: 586 | df_struct = df.select(col(df.columns[0]).alias("value")) 587 | df_encoded = df_struct.select( 588 | self.to_protobuf( 589 | data=df_struct.value, 590 | message_type=message_type, 591 | ignore_unknown_fields=ignore_unknown_fields, 592 | max_recursion_depth=max_recursion_depth, 593 | ).alias("value") 594 | ) 595 | return df_encoded 596 | 597 | 598 | def from_protobuf( 599 | data: t.Union[Column, str], 600 | message_type: t.Type[Message], 601 | including_default_value_fields: bool = False, 602 | preserving_proto_field_name: bool = False, 603 | use_integers_for_enums: bool = False, 604 | float_precision: t.Optional[int] = None, 605 | message_converter: MessageConverter = None, 606 | ) -> Column: 607 | """Deserialize protobuf messages to spark structs 608 | 609 | Args: 610 | data: A pyspark column. 611 | message_type: The message type for decoding. 612 | including_default_value_fields: If True, singular primitive fields, 613 | repeated fields, and map fields will always be serialized. If 614 | False, only serialize non-empty fields. Singular message fields 615 | and oneof fields are not affected by this option. 616 | preserving_proto_field_name: If True, use the original proto field 617 | names as defined in the .proto file. If False, convert the field 618 | names to lowerCamelCase. 619 | use_integers_for_enums: If true, print integers instead of enum names. 620 | float_precision: If set, use this to specify float field valid digits. 621 | message_converter: An instance of a message converter. If None, use the default. 622 | """ 623 | message_converter = message_converter or MessageConverter() 624 | return message_converter.from_protobuf( 625 | data=data, 626 | message_type=message_type, 627 | including_default_value_fields=including_default_value_fields, 628 | preserving_proto_field_name=preserving_proto_field_name, 629 | use_integers_for_enums=use_integers_for_enums, 630 | float_precision=float_precision, 631 | ) 632 | 633 | 634 | def to_protobuf( 635 | data: t.Union[Column, str], 636 | message_type: t.Type[Message], 637 | ignore_unknown_fields: bool = False, 638 | max_recursion_depth: int = 100, 639 | message_converter: MessageConverter = None, 640 | ) -> Column: 641 | """Serialize spark structs to protobuf messages. 642 | 643 | Given a column and protobuf message type, serialize 644 | protobuf messages also using our custom serializers. 645 | 646 | Args: 647 | data: A pyspark column. 648 | message_type: The message type for encoding. 649 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 650 | max_recursion_depth: max recursion depth of JSON message to be 651 | deserialized. JSON messages over this depth will fail to be 652 | deserialized. Default value is 100. 653 | message_converter: An instance of a message converter. If None, use the default. 
654 | """ 655 | message_converter = message_converter or MessageConverter() 656 | return message_converter.to_protobuf( 657 | data=data, 658 | message_type=message_type, 659 | ignore_unknown_fields=ignore_unknown_fields, 660 | max_recursion_depth=max_recursion_depth, 661 | ) 662 | 663 | 664 | def df_from_protobuf( 665 | df: DataFrame, 666 | message_type: t.Type[Message], 667 | including_default_value_fields: bool = False, 668 | preserving_proto_field_name: bool = False, 669 | use_integers_for_enums: bool = False, 670 | float_precision: t.Optional[int] = None, 671 | expanded: bool = False, 672 | message_converter: MessageConverter = None, 673 | ) -> DataFrame: 674 | """Decode a dataframe of encoded protobuf. 675 | 676 | Args: 677 | df: A pyspark dataframe with encoded protobuf in the column at index 0. 678 | message_type: The message type for decoding. 679 | including_default_value_fields: If True, singular primitive fields, 680 | repeated fields, and map fields will always be serialized. If 681 | False, only serialize non-empty fields. Singular message fields 682 | and oneof fields are not affected by this option. 683 | preserving_proto_field_name: If True, use the original proto field 684 | names as defined in the .proto file. If False, convert the field 685 | names to lowerCamelCase. 686 | use_integers_for_enums: If true, print integers instead of enum names. 687 | float_precision: If set, use this to specify float field valid digits. 688 | expanded: If True, return a dataframe in which each field is its own 689 | column. Otherwise, return a dataframe with a single struct column 690 | named `value`. 691 | message_converter: An instance of a message converter. If None, use the default. 692 | """ 693 | message_converter = message_converter or MessageConverter() 694 | return message_converter.df_from_protobuf( 695 | df=df, 696 | message_type=message_type, 697 | including_default_value_fields=including_default_value_fields, 698 | preserving_proto_field_name=preserving_proto_field_name, 699 | use_integers_for_enums=use_integers_for_enums, 700 | float_precision=float_precision, 701 | expanded=expanded, 702 | ) 703 | 704 | 705 | def df_to_protobuf( 706 | df: DataFrame, 707 | message_type: t.Type[Message], 708 | ignore_unknown_fields: bool = False, 709 | max_recursion_depth: int = 100, 710 | expanded: bool = False, 711 | message_converter: MessageConverter = None, 712 | ) -> DataFrame: 713 | """Encode data in a dataframe to protobuf as column `value`. 714 | 715 | Args: 716 | df: A pyspark dataframe. 717 | message_type: The message type for encoding. 718 | ignore_unknown_fields: If True, do not raise errors for unknown fields. 719 | max_recursion_depth: max recursion depth of JSON message to be 720 | deserialized. JSON messages over this depth will fail to be 721 | deserialized. Default value is 100. 722 | expanded: If True, the passed dataframe columns will be packed into a 723 | struct before converting. Otherwise, it is assumed that the 724 | dataframe passed is a single column of data already packed into a 725 | struct. 726 | message_converter: An instance of a message converter. If None, use the default. 727 | 728 | Returns a dataframe with a single column named `value` containing encoded data. 
729 | """ 730 | message_converter = message_converter or MessageConverter() 731 | return message_converter.df_to_protobuf( 732 | df=df, 733 | message_type=message_type, 734 | ignore_unknown_fields=ignore_unknown_fields, 735 | max_recursion_depth=max_recursion_depth, 736 | expanded=expanded, 737 | ) 738 | --------------------------------------------------------------------------------