├── tests ├── __init__.py ├── grpc │ ├── __init__.py │ ├── thing_service.py │ └── test_stream_stream.py ├── streams │ ├── load_varint_cutoff.in │ ├── delimited_messages.in │ ├── dump_varint_negative.expected │ ├── dump_varint_positive.expected │ ├── message_dump_file_single.expected │ ├── message_dump_file_multiple.expected │ └── java │ │ ├── src │ │ └── main │ │ │ ├── proto │ │ │ └── betterproto │ │ │ │ ├── oneof.proto │ │ │ │ └── nested.proto │ │ │ └── java │ │ │ └── betterproto │ │ │ ├── CompatibilityTest.java │ │ │ └── Tests.java │ │ ├── .gitignore │ │ └── pom.xml ├── inputs │ ├── oneof_empty │ │ ├── test_oneof_empty.py │ │ ├── oneof_empty.json │ │ ├── oneof_empty_maybe1.json │ │ ├── oneof_empty_maybe2.json │ │ └── oneof_empty.proto │ ├── googletypes │ │ ├── googletypes-missing.json │ │ ├── googletypes.json │ │ └── googletypes.proto │ ├── bool │ │ ├── bool.json │ │ ├── bool.proto │ │ └── test_bool.py │ ├── oneof │ │ ├── oneof.json │ │ ├── oneof-name.json │ │ ├── oneof_name.json │ │ ├── oneof.proto │ │ └── test_oneof.py │ ├── double │ │ ├── double.json │ │ ├── double-negative.json │ │ └── double.proto │ ├── bytes │ │ ├── bytes.json │ │ └── bytes.proto │ ├── proto3_field_presence │ │ ├── proto3_field_presence_default.json │ │ ├── proto3_field_presence_missing.json │ │ ├── proto3_field_presence.json │ │ ├── proto3_field_presence.proto │ │ └── test_proto3_field_presence.py │ ├── oneof_enum │ │ ├── oneof_enum-enum-0.json │ │ ├── oneof_enum-enum-1.json │ │ ├── oneof_enum.json │ │ ├── oneof_enum.proto │ │ └── test_oneof_enum.py │ ├── empty_repeated │ │ ├── empty_repeated.json │ │ └── empty_repeated.proto │ ├── int32 │ │ ├── int32.json │ │ └── int32.proto │ ├── repeated │ │ ├── repeated.json │ │ └── repeated.proto │ ├── casing │ │ ├── casing.json │ │ ├── casing.proto │ │ └── test_casing.py │ ├── ref │ │ ├── ref.json │ │ ├── ref.proto │ │ └── repeatedmessage.proto │ ├── proto3_field_presence_oneof │ │ ├── proto3_field_presence_oneof.json │ │ ├── 
proto3_field_presence_oneof.proto │ │ └── test_proto3_field_presence_oneof.py │ ├── googletypes_struct │ │ ├── googletypes_struct.json │ │ └── googletypes_struct.proto │ ├── timestamp_dict_encode │ │ ├── timestamp_dict_encode.json │ │ ├── timestamp_dict_encode.proto │ │ └── test_timestamp_dict_encode.py │ ├── deprecated │ │ ├── deprecated.json │ │ └── deprecated.proto │ ├── map │ │ ├── map.json │ │ └── map.proto │ ├── nested2 │ │ ├── package.proto │ │ └── nested2.proto │ ├── nested │ │ ├── nested.json │ │ └── nested.proto │ ├── signed │ │ ├── signed.json │ │ └── signed.proto │ ├── import_root_sibling │ │ ├── sibling.proto │ │ └── import_root_sibling.proto │ ├── invalid_field │ │ ├── invalid_field.proto │ │ └── test_invalid_field.py │ ├── enum │ │ ├── enum.json │ │ ├── enum.proto │ │ └── test_enum.py │ ├── repeatedpacked │ │ ├── repeatedpacked.json │ │ └── repeatedpacked.proto │ ├── fixed │ │ ├── fixed.json │ │ └── fixed.proto │ ├── import_circular_dependency │ │ ├── root.proto │ │ ├── other.proto │ │ └── import_circular_dependency.proto │ ├── import_packages_same_name │ │ ├── posts_v1.proto │ │ ├── users_v1.proto │ │ └── import_packages_same_name.proto │ ├── import_root_package_from_child │ │ ├── root.proto │ │ └── child.proto │ ├── empty_service │ │ └── empty_service.proto │ ├── import_child_package_from_root │ │ ├── child.proto │ │ └── import_child_package_from_root.proto │ ├── import_cousin_package │ │ ├── cousin.proto │ │ └── test.proto │ ├── import_capitalized_package │ │ ├── capitalized.proto │ │ └── test.proto │ ├── import_cousin_package_same_name │ │ ├── cousin.proto │ │ └── test.proto │ ├── mapmessage │ │ ├── mapmessage.json │ │ └── mapmessage.proto │ ├── import_child_package_from_package │ │ ├── child.proto │ │ ├── package_message.proto │ │ └── import_child_package_from_package.proto │ ├── import_service_input_message │ │ ├── request_message.proto │ │ ├── child_package_request_message.proto │ │ ├── import_service_input_message.proto │ │ └── 
test_import_service_input_message.py │ ├── field_name_identical_to_type │ │ ├── field_name_identical_to_type.json │ │ └── field_name_identical_to_type.proto │ ├── import_parent_package_from_child │ │ ├── parent_package_message.proto │ │ └── import_parent_package_from_child.proto │ ├── repeated_duration_timestamp │ │ ├── repeated_duration_timestamp.json │ │ ├── repeated_duration_timestamp.proto │ │ └── test_repeated_duration_timestamp.py │ ├── repeatedmessage │ │ ├── repeatedmessage.json │ │ └── repeatedmessage.proto │ ├── regression_414 │ │ ├── regression_414.proto │ │ └── test_regression_414.py │ ├── googletypes_value │ │ ├── googletypes_value.json │ │ └── googletypes_value.proto │ ├── regression_387 │ │ ├── regression_387.proto │ │ └── test_regression_387.py │ ├── float │ │ ├── float.json │ │ └── float.proto │ ├── casing_message_field_uppercase │ │ ├── casing_message_field_uppercase.proto │ │ └── casing_message_field_uppercase.py │ ├── recursivemessage │ │ ├── recursivemessage.json │ │ └── recursivemessage.proto │ ├── nestedtwice │ │ ├── nestedtwice.json │ │ ├── test_nestedtwice.py │ │ └── nestedtwice.proto │ ├── service_uppercase │ │ ├── test_service.py │ │ └── service.proto │ ├── casing_inner_class │ │ ├── casing_inner_class.proto │ │ └── test_casing_inner_class.py │ ├── googletypes_service_returns_empty │ │ └── googletypes_service_returns_empty.proto │ ├── google_impl_behavior_equivalence │ │ ├── google_impl_behavior_equivalence.proto │ │ └── test_google_impl_behavior_equivalence.py │ ├── namespace_builtin_types │ │ ├── namespace_builtin_types.json │ │ └── namespace_builtin_types.proto │ ├── service_separate_packages │ │ ├── service.proto │ │ └── messages.proto │ ├── example_service │ │ ├── example_service.proto │ │ └── test_example_service.py │ ├── googletypes_service_returns_googletype │ │ └── googletypes_service_returns_googletype.proto │ ├── namespace_keywords │ │ ├── namespace_keywords.json │ │ └── namespace_keywords.proto │ ├── entry │ │ └── entry.proto 
│ ├── service │ │ └── service.proto │ ├── oneof_default_value_serialization │ │ ├── oneof_default_value_serialization.proto │ │ └── test_oneof_default_value_serialization.py │ ├── googletypes_response_embedded │ │ ├── googletypes_response_embedded.proto │ │ └── test_googletypes_response_embedded.py │ ├── googletypes_response │ │ ├── googletypes_response.proto │ │ └── test_googletypes_response.py │ ├── config.py │ ├── documentation │ │ └── documentation.proto │ └── googletypes_request │ │ ├── googletypes_request.proto │ │ └── test_googletypes_request.py ├── conftest.py ├── test_mapmessage.py ├── test_version.py ├── test_all_definition.py ├── test_timestamp.py ├── test_documentation.py ├── mocks.py ├── test_struct.py ├── oneof_pattern_matching.py ├── test_deprecated.py ├── test_enum.py ├── README.md ├── test_typing_compiler.py ├── test_module_validation.py ├── test_casing.py └── util.py ├── benchmarks ├── __init__.py └── benchmarks.py ├── src └── betterproto │ ├── py.typed │ ├── grpc │ ├── __init__.py │ ├── util │ │ └── __init__.py │ ├── grpclib_server.py │ └── grpclib_client.py │ ├── lib │ ├── __init__.py │ ├── google │ │ ├── __init__.py │ │ └── protobuf │ │ │ ├── __init__.py │ │ │ └── compiler │ │ │ └── __init__.py │ ├── std │ │ ├── __init__.py │ │ └── google │ │ │ └── __init__.py │ └── pydantic │ │ ├── __init__.py │ │ └── google │ │ └── __init__.py │ ├── compile │ ├── __init__.py │ └── naming.py │ ├── plugin │ ├── __init__.py │ ├── __main__.py │ ├── plugin.bat │ ├── main.py │ ├── compiler.py │ ├── module_validation.py │ └── typing_compiler.py │ ├── _version.py │ ├── _types.py │ ├── utils.py │ ├── templates │ └── header.py.j2 │ ├── casing.py │ └── enum.py ├── .env.default ├── MANIFEST.in ├── pytest.ini ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yml │ └── bug_report.yml ├── workflows │ ├── code-quality.yml │ ├── release.yml │ ├── codeql-analysis.yml │ └── ci.yml ├── PULL_REQUEST_TEMPLATE.md └── CONTRIBUTING.md ├── .readthedocs.yml ├── 
.gitignore ├── docs ├── api.rst ├── index.rst ├── conf.py └── quick-start.rst ├── .pre-commit-config.yaml ├── LICENSE.md └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/grpc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.env.default: -------------------------------------------------------------------------------- 1 | PYTHONPATH=. 
2 | -------------------------------------------------------------------------------- /src/betterproto/grpc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/compile/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/grpc/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/lib/google/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/lib/std/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/streams/load_varint_cutoff.in: -------------------------------------------------------------------------------- 1 | ȁ -------------------------------------------------------------------------------- /src/betterproto/lib/pydantic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/lib/std/google/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/betterproto/lib/pydantic/google/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/inputs/oneof_empty/test_oneof_empty.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/inputs/googletypes/googletypes-missing.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude tests * 2 | exclude output 3 | -------------------------------------------------------------------------------- /src/betterproto/plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import main 2 | -------------------------------------------------------------------------------- /tests/inputs/bool/bool.json: -------------------------------------------------------------------------------- 1 | { 2 | "value": true 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/oneof/oneof.json: -------------------------------------------------------------------------------- 1 | { 2 | "pitied": 100 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/double/double.json: -------------------------------------------------------------------------------- 1 | { 2 | "count": 123.45 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/oneof/oneof-name.json: -------------------------------------------------------------------------------- 1 | { 2 | "pitier": "Mr. 
T" 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/oneof/oneof_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "pitier": "Mr. T" 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/oneof_empty/oneof_empty.json: -------------------------------------------------------------------------------- 1 | { 2 | "nothing": {} 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/bytes/bytes.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "SGVsbG8sIFdvcmxkIQ==" 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/double/double-negative.json: -------------------------------------------------------------------------------- 1 | { 2 | "count": -123.45 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence/proto3_field_presence_default.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /src/betterproto/plugin/__main__.py: -------------------------------------------------------------------------------- 1 | from .main import main 2 | 3 | 4 | main() 5 | -------------------------------------------------------------------------------- /tests/inputs/oneof_empty/oneof_empty_maybe1.json: -------------------------------------------------------------------------------- 1 | { 2 | "maybe1": {} 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/oneof_enum/oneof_enum-enum-0.json: -------------------------------------------------------------------------------- 1 | { 2 | "signal": "PASS" 3 | } 4 | 
-------------------------------------------------------------------------------- /tests/inputs/oneof_enum/oneof_enum-enum-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "signal": "RESIGN" 3 | } 4 | -------------------------------------------------------------------------------- /src/betterproto/plugin/plugin.bat: -------------------------------------------------------------------------------- 1 | @SET plugin_dir=%~dp0 2 | @python -m %plugin_dir% %* -------------------------------------------------------------------------------- /tests/inputs/empty_repeated/empty_repeated.json: -------------------------------------------------------------------------------- 1 | { 2 | "msg": [{"values":[]}] 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/int32/int32.json: -------------------------------------------------------------------------------- 1 | { 2 | "positive": 150, 3 | "negative": -150 4 | } 5 | -------------------------------------------------------------------------------- /tests/inputs/repeated/repeated.json: -------------------------------------------------------------------------------- 1 | { 2 | "names": ["one", "two", "three"] 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/casing/casing.json: -------------------------------------------------------------------------------- 1 | { 2 | "camelCase": 1, 3 | "snakeCase": "ONE" 4 | } 5 | -------------------------------------------------------------------------------- /tests/inputs/ref/ref.json: -------------------------------------------------------------------------------- 1 | { 2 | "greeting": { 3 | "greeting": "hello" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /src/betterproto/lib/google/protobuf/__init__.py: -------------------------------------------------------------------------------- 1 | from 
betterproto.lib.std.google.protobuf import * 2 | -------------------------------------------------------------------------------- /tests/inputs/oneof_enum/oneof_enum.json: -------------------------------------------------------------------------------- 1 | { 2 | "move": { 3 | "x": 2, 4 | "y": 3 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json: -------------------------------------------------------------------------------- 1 | { 2 | "nested": {} 3 | } 4 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_struct/googletypes_struct.json: -------------------------------------------------------------------------------- 1 | { 2 | "struct": { 3 | "key": true 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /tests/inputs/oneof_empty/oneof_empty_maybe2.json: -------------------------------------------------------------------------------- 1 | { 2 | "maybe2": { 3 | "sometimes": "now" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json: -------------------------------------------------------------------------------- 1 | { 2 | "ts" : "2023-03-15T22:35:51.253277Z" 3 | } -------------------------------------------------------------------------------- /src/betterproto/lib/google/protobuf/compiler/__init__.py: -------------------------------------------------------------------------------- 1 | from betterproto.lib.std.google.protobuf.compiler import * 2 | -------------------------------------------------------------------------------- /tests/inputs/deprecated/deprecated.json: -------------------------------------------------------------------------------- 1 | { 2 | "message": { 3 | "value": "hello" 4 | }, 5 | "value": 10 6 | } 7 | 
-------------------------------------------------------------------------------- /tests/inputs/map/map.json: -------------------------------------------------------------------------------- 1 | { 2 | "counts": { 3 | "item1": 1, 4 | "item2": 2, 5 | "item3": 3 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/bool/bool.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package bool; 4 | 5 | message Test { 6 | bool value = 1; 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/nested2/package.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package nested2.equipment; 4 | 5 | message Weapon { 6 | 7 | } -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = test_*.py 3 | python_classes = 4 | norecursedirs = **/output_* 5 | addopts = -p no:warnings -------------------------------------------------------------------------------- /tests/inputs/bytes/bytes.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package bytes; 4 | 5 | message Test { 6 | bytes data = 1; 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/nested/nested.json: -------------------------------------------------------------------------------- 1 | { 2 | "nested": { 3 | "count": 150 4 | }, 5 | "sibling": {}, 6 | "msg": "THIS" 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/double/double.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package double; 4 | 5 | message 
Test { 6 | double count = 1; 7 | } 8 | -------------------------------------------------------------------------------- /tests/streams/delimited_messages.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgtaylor/python-betterproto/HEAD/tests/streams/delimited_messages.in -------------------------------------------------------------------------------- /tests/inputs/map/map.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package map; 4 | 5 | message Test { 6 | map counts = 1; 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/signed/signed.json: -------------------------------------------------------------------------------- 1 | { 2 | "signed32": 150, 3 | "negative32": -150, 4 | "string64": "150", 5 | "negative64": "-150" 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/import_root_sibling/sibling.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_root_sibling; 4 | 5 | message SiblingMessage { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/invalid_field/invalid_field.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package invalid_field; 4 | 5 | message Test { 6 | int32 x = 1; 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/repeated/repeated.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package repeated; 4 | 5 | message Test { 6 | repeated string names = 1; 7 | } 8 | -------------------------------------------------------------------------------- 
/tests/streams/dump_varint_negative.expected: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgtaylor/python-betterproto/HEAD/tests/streams/dump_varint_negative.expected -------------------------------------------------------------------------------- /tests/streams/dump_varint_positive.expected: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgtaylor/python-betterproto/HEAD/tests/streams/dump_varint_positive.expected -------------------------------------------------------------------------------- /tests/inputs/enum/enum.json: -------------------------------------------------------------------------------- 1 | { 2 | "choice": "FOUR", 3 | "choices": [ 4 | "ZERO", 5 | "ONE", 6 | "THREE", 7 | "FOUR" 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/repeatedpacked/repeatedpacked.json: -------------------------------------------------------------------------------- 1 | { 2 | "counts": [1, 2, -1, -2], 3 | "signed": ["1", "2", "-1", "-2"], 4 | "fixed": [1.0, 2.7, 3.4] 5 | } 6 | -------------------------------------------------------------------------------- /tests/streams/message_dump_file_single.expected: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgtaylor/python-betterproto/HEAD/tests/streams/message_dump_file_single.expected -------------------------------------------------------------------------------- /tests/inputs/fixed/fixed.json: -------------------------------------------------------------------------------- 1 | { 2 | "foo": 4294967295, 3 | "bar": -2147483648, 4 | "baz": "18446744073709551615", 5 | "qux": "-9223372036854775808" 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/import_circular_dependency/root.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_circular_dependency; 4 | 5 | message RootPackageMessage { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_packages_same_name/posts_v1.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_packages_same_name.posts.v1; 4 | 5 | message Post { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_packages_same_name/users_v1.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_packages_same_name.users.v1; 4 | 5 | message User { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_root_package_from_child/root.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_root_package_from_child; 4 | 5 | 6 | message RootMessage { 7 | } 8 | -------------------------------------------------------------------------------- /tests/streams/message_dump_file_multiple.expected: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgtaylor/python-betterproto/HEAD/tests/streams/message_dump_file_multiple.expected -------------------------------------------------------------------------------- /tests/inputs/empty_service/empty_service.proto: -------------------------------------------------------------------------------- 1 | /* Empty service without comments */ 2 | syntax = "proto3"; 3 | 4 | package empty_service; 5 | 6 | service Test { 7 | } 8 | -------------------------------------------------------------------------------- 
/tests/inputs/import_child_package_from_root/child.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_child_package_from_root.childpackage; 4 | 5 | message Message { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_cousin_package/cousin.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_cousin_package.cousin.cousin_subpackage; 4 | 5 | message CousinMessage { 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/import_capitalized_package/capitalized.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | 4 | package import_capitalized_package.Capitalized; 5 | 6 | message Message { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /tests/inputs/googletypes/googletypes.json: -------------------------------------------------------------------------------- 1 | { 2 | "maybe": false, 3 | "ts": "1972-01-01T10:00:20.021Z", 4 | "duration": "1.200s", 5 | "important": 10, 6 | "empty": {} 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_cousin_package_same_name/cousin.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_cousin_package_same_name.cousin.subpackage; 4 | 5 | message CousinMessage { 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/mapmessage/mapmessage.json: -------------------------------------------------------------------------------- 1 | { 2 | "items": { 3 | "foo": { 4 | "count": 1 5 | }, 6 | "bar": { 7 | "count": 2 8 | } 9 | } 10 | } 11 | 
-------------------------------------------------------------------------------- /tests/inputs/ref/ref.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package ref; 4 | 5 | import "repeatedmessage.proto"; 6 | 7 | message Test { 8 | repeatedmessage.Sub greeting = 1; 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/import_child_package_from_package/child.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_child_package_from_package.package.childpackage; 4 | 5 | message ChildMessage { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tests/inputs/import_service_input_message/request_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_service_input_message; 4 | 5 | message RequestMessage { 6 | int32 argument = 1; 7 | } -------------------------------------------------------------------------------- /tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json: -------------------------------------------------------------------------------- 1 | { 2 | "int": 26, 3 | "float": 26.0, 4 | "str": "value-for-str", 5 | "bytes": "001a", 6 | "bool": true 7 | } -------------------------------------------------------------------------------- /tests/inputs/import_parent_package_from_child/parent_package_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_parent_package_from_child.parent; 4 | 5 | message ParentPackageMessage { 6 | } 7 | -------------------------------------------------------------------------------- /tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "times": ["1972-01-01T10:00:20.021Z", "1972-01-01T10:00:20.021Z"], 3 | "durations": ["1.200s", "1.200s"] 4 | } 5 | -------------------------------------------------------------------------------- /tests/inputs/repeatedmessage/repeatedmessage.json: -------------------------------------------------------------------------------- 1 | { 2 | "greetings": [ 3 | { 4 | "greeting": "hello" 5 | }, 6 | { 7 | "greeting": "hi" 8 | } 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /tests/inputs/fixed/fixed.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package fixed; 4 | 5 | message Test { 6 | fixed32 foo = 1; 7 | sfixed32 bar = 2; 8 | fixed64 baz = 3; 9 | sfixed64 qux = 4; 10 | } 11 | -------------------------------------------------------------------------------- /tests/inputs/regression_414/regression_414.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package regression_414; 4 | 5 | message Test { 6 | bytes body = 1; 7 | bytes auth = 2; 8 | repeated bytes signatures = 3; 9 | } -------------------------------------------------------------------------------- /tests/inputs/mapmessage/mapmessage.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package mapmessage; 4 | 5 | message Test { 6 | map items = 1; 7 | } 8 | 9 | message Nested { 10 | int32 count = 1; 11 | } -------------------------------------------------------------------------------- /tests/inputs/ref/repeatedmessage.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package repeatedmessage; 4 | 5 | message Test { 6 | repeated Sub greetings = 1; 7 | } 8 | 9 | message Sub { 10 | string greeting = 1; 11 | } 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sys 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def reset_sys_path(): 9 | original = copy.deepcopy(sys.path) 10 | yield 11 | sys.path = original 12 | -------------------------------------------------------------------------------- /tests/inputs/import_service_input_message/child_package_request_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_service_input_message.child; 4 | 5 | message ChildRequestMessage { 6 | int32 child_argument = 1; 7 | } -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence/proto3_field_presence_missing.json: -------------------------------------------------------------------------------- 1 | { 2 | "test1": 0, 3 | "test2": false, 4 | "test3": "", 5 | "test4": "", 6 | "test6": "A", 7 | "test7": "0", 8 | "test8": 0 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_value/googletypes_value.json: -------------------------------------------------------------------------------- 1 | { 2 | "value1": "hello world", 3 | "value2": true, 4 | "value3": 1, 5 | "value4": null, 6 | "value5": [ 7 | 1, 8 | 2, 9 | 3 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/repeatedmessage/repeatedmessage.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package repeatedmessage; 4 | 5 | message Test { 6 | repeated Sub greetings = 1; 7 | } 8 | 9 | message Sub { 10 | string greeting = 1; 11 | } -------------------------------------------------------------------------------- 
/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | name: 2 | description: 3 | contact_links: 4 | - name: For questions about the library 5 | about: Support questions are better answered in our Discord group. 6 | url: https://discord.gg/DEVteTupPb 7 | -------------------------------------------------------------------------------- /src/betterproto/_version.py: -------------------------------------------------------------------------------- 1 | try: 2 | from importlib import metadata 3 | except ImportError: # for Python<3.8 4 | import importlib_metadata as metadata # type: ignore 5 | 6 | 7 | __version__ = metadata.version("betterproto") 8 | -------------------------------------------------------------------------------- /tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package timestamp_dict_encode; 4 | 5 | import "google/protobuf/timestamp.proto"; 6 | 7 | message Test { 8 | google.protobuf.Timestamp ts = 1; 9 | } -------------------------------------------------------------------------------- /tests/inputs/empty_repeated/empty_repeated.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package empty_repeated; 4 | 5 | message MessageA { 6 | repeated float values = 1; 7 | } 8 | 9 | message Test { 10 | repeated MessageA msg = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_struct/googletypes_struct.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_struct; 4 | 5 | import "google/protobuf/struct.proto"; 6 | 7 | message Test { 8 | google.protobuf.Struct struct = 1; 9 | } 10 | 
-------------------------------------------------------------------------------- /tests/inputs/import_circular_dependency/other.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "root.proto"; 4 | package import_circular_dependency.other; 5 | 6 | message OtherPackageMessage { 7 | RootPackageMessage rootPackageMessage = 1; 8 | } 9 | -------------------------------------------------------------------------------- /tests/inputs/repeatedpacked/repeatedpacked.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package repeatedpacked; 4 | 5 | message Test { 6 | repeated int32 counts = 1; 7 | repeated sint64 signed = 2; 8 | repeated double fixed = 3; 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/regression_387/regression_387.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package regression_387; 4 | 5 | message Test { 6 | uint64 id = 1; 7 | } 8 | 9 | message ParentElement { 10 | string name = 1; 11 | repeated Test elems = 2; 12 | } -------------------------------------------------------------------------------- /tests/inputs/float/float.json: -------------------------------------------------------------------------------- 1 | { 2 | "positive": "Infinity", 3 | "negative": "-Infinity", 4 | "nan": "NaN", 5 | "three": 3.0, 6 | "threePointOneFour": 3.14, 7 | "negThree": -3.0, 8 | "negThreePointOneFour": -3.14 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package casing_message_field_uppercase; 4 | 5 | message Test { 6 | int32 UPPERCASE = 1; 7 | int32 UPPERCASE_V2 = 2; 8 | 
int32 UPPER_CAMEL_CASE = 3; 9 | } -------------------------------------------------------------------------------- /tests/inputs/recursivemessage/recursivemessage.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Zues", 3 | "child": { 4 | "name": "Hercules" 5 | }, 6 | "intermediate": { 7 | "child": { 8 | "name": "Douglas Adams" 9 | }, 10 | "number": 42 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /tests/inputs/int32/int32.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package int32; 4 | 5 | // Some documentation about the Test message. 6 | message Test { 7 | // Some documentation about the count. 8 | int32 positive = 1; 9 | int32 negative = 2; 10 | } 11 | -------------------------------------------------------------------------------- /tests/inputs/import_child_package_from_package/package_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "child.proto"; 4 | 5 | package import_child_package_from_package.package; 6 | 7 | message PackageMessage { 8 | package.childpackage.ChildMessage c = 1; 9 | } 10 | -------------------------------------------------------------------------------- /tests/inputs/nestedtwice/nestedtwice.json: -------------------------------------------------------------------------------- 1 | { 2 | "top": { 3 | "name": "double-nested", 4 | "middle": { 5 | "bottom": [{"foo": "hello"}], 6 | "enumBottom": ["A"], 7 | "topMiddleBottom": [{"a": "hello"}], 8 | "bar": true 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/service_uppercase/test_service.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from tests.output_betterproto.service_uppercase import TestStub 4 | 5 
| 6 | def test_parameters(): 7 | sig = inspect.signature(TestStub.do_thing) 8 | assert len(sig.parameters) == 5, "Expected 5 parameters" 9 | -------------------------------------------------------------------------------- /tests/inputs/import_root_package_from_child/child.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_root_package_from_child.child; 4 | 5 | import "root.proto"; 6 | 7 | // Verify that we can import root message from child package 8 | 9 | message Test { 10 | RootMessage message = 1; 11 | } 12 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | formats: [] 3 | 4 | build: 5 | image: latest 6 | 7 | sphinx: 8 | configuration: docs/conf.py 9 | fail_on_warning: false 10 | 11 | python: 12 | version: 3.7 13 | install: 14 | - method: pip 15 | path: . 
16 | extra_requirements: 17 | - dev -------------------------------------------------------------------------------- /tests/inputs/casing_inner_class/casing_inner_class.proto: -------------------------------------------------------------------------------- 1 | // https://github.com/danielgtaylor/python-betterproto/issues/344 2 | syntax = "proto3"; 3 | 4 | package casing_inner_class; 5 | 6 | message Test { 7 | message inner_class { 8 | sint32 old_exp = 1; 9 | } 10 | inner_class inner = 2; 11 | } -------------------------------------------------------------------------------- /tests/inputs/import_cousin_package/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_cousin_package.test.subpackage; 4 | 5 | import "cousin.proto"; 6 | 7 | // Verify that we can import message unrelated to us 8 | 9 | message Test { 10 | cousin.cousin_subpackage.CousinMessage message = 1; 11 | } 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | .DS_Store 3 | .env 4 | .vscode/settings.json 5 | .mypy_cache 6 | .pytest_cache 7 | .python-version 8 | build/ 9 | tests/output_* 10 | **/__pycache__ 11 | dist 12 | **/*.egg-info 13 | output 14 | .idea 15 | .DS_Store 16 | .tox 17 | .venv 18 | .asv 19 | venv 20 | .devcontainer 21 | .ruff_cache -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence/proto3_field_presence.json: -------------------------------------------------------------------------------- 1 | { 2 | "test1": 128, 3 | "test2": true, 4 | "test3": "A value", 5 | "test4": "aGVsbG8=", 6 | "test5": { 7 | "test": "Hello" 8 | }, 9 | "test6": "B", 10 | "test7": "8589934592", 11 | "test8": 2.5, 12 | "test9": "2022-01-24T12:12:42Z" 13 | } 14 | 
-------------------------------------------------------------------------------- /tests/inputs/recursivemessage/recursivemessage.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package recursivemessage; 4 | 5 | message Test { 6 | string name = 1; 7 | Test child = 2; 8 | Intermediate intermediate = 3; 9 | } 10 | 11 | 12 | message Intermediate { 13 | int32 number = 1; 14 | Test child = 2; 15 | } 16 | -------------------------------------------------------------------------------- /tests/inputs/import_root_sibling/import_root_sibling.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_root_sibling; 4 | 5 | import "sibling.proto"; 6 | 7 | // Tests generated imports when a message in the root package refers to another message in the root package 8 | 9 | message Test { 10 | SiblingMessage sibling = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/oneof_enum/oneof_enum.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package oneof_enum; 4 | 5 | message Test { 6 | oneof action { 7 | Signal signal = 1; 8 | Move move = 2; 9 | } 10 | } 11 | 12 | enum Signal { 13 | PASS = 0; 14 | RESIGN = 1; 15 | } 16 | 17 | message Move { 18 | int32 x = 1; 19 | int32 y = 2; 20 | } -------------------------------------------------------------------------------- /tests/inputs/import_capitalized_package/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_capitalized_package; 4 | 5 | import "capitalized.proto"; 6 | 7 | // Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. 
8 | 9 | message Test { 10 | Capitalized.Message message = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/import_child_package_from_root/import_child_package_from_root.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_child_package_from_root; 4 | 5 | import "child.proto"; 6 | 7 | // Tests generated imports when a message in root refers to a message in a child package. 8 | 9 | message Test { 10 | childpackage.Message child = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/oneof_empty/oneof_empty.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package oneof_empty; 4 | 5 | message Nothing {} 6 | 7 | message MaybeNothing { 8 | string sometimes = 42; 9 | } 10 | 11 | message Test { 12 | oneof empty { 13 | Nothing nothing = 1; 14 | MaybeNothing maybe1 = 2; 15 | MaybeNothing maybe2 = 3; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_service_returns_empty; 4 | 5 | import "google/protobuf/empty.proto"; 6 | 7 | service Test { 8 | rpc Send (RequestMessage) returns (google.protobuf.Empty) { 9 | } 10 | } 11 | 12 | message RequestMessage { 13 | } -------------------------------------------------------------------------------- /tests/inputs/import_cousin_package_same_name/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_cousin_package_same_name.test.subpackage; 4 | 5 | import "cousin.proto"; 6 | 7 | // Verify that we can import a message unrelated to 
us, in a subpackage with the same name as us. 8 | 9 | message Test { 10 | cousin.subpackage.CousinMessage message = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/casing/casing.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package casing; 4 | 5 | enum my_enum { 6 | ZERO = 0; 7 | ONE = 1; 8 | TWO = 2; 9 | } 10 | 11 | message Test { 12 | int32 camelCase = 1; 13 | my_enum snake_case = 2; 14 | snake_case_message snake_case_message = 3; 15 | int32 UPPERCASE = 4; 16 | } 17 | 18 | message snake_case_message { 19 | 20 | } -------------------------------------------------------------------------------- /tests/inputs/service_uppercase/service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package service_uppercase; 4 | 5 | message DoTHINGRequest { 6 | string name = 1; 7 | repeated string comments = 2; 8 | } 9 | 10 | message DoTHINGResponse { 11 | repeated string names = 1; 12 | } 13 | 14 | service Test { 15 | rpc DoThing (DoTHINGRequest) returns (DoTHINGResponse); 16 | } 17 | -------------------------------------------------------------------------------- /src/betterproto/_types.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | TypeVar, 4 | ) 5 | 6 | 7 | if TYPE_CHECKING: 8 | from grpclib._typing import IProtoMessage 9 | 10 | from . 
import Message 11 | 12 | # Bound type variable to allow methods to return `self` of subclasses 13 | T = TypeVar("T", bound="Message") 14 | ST = TypeVar("ST", bound="IProtoMessage") 15 | -------------------------------------------------------------------------------- /tests/inputs/float/float.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package float; 4 | 5 | // Some documentation about the Test message. 6 | message Test { 7 | double positive = 1; 8 | double negative = 2; 9 | double nan = 3; 10 | double three = 4; 11 | double three_point_one_four = 5; 12 | double neg_three = 6; 13 | double neg_three_point_one_four = 7; 14 | } 15 | -------------------------------------------------------------------------------- /tests/inputs/import_child_package_from_package/import_child_package_from_package.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_child_package_from_package; 4 | 5 | import "package_message.proto"; 6 | 7 | // Tests generated imports when a message in a package refers to a message in a nested child package. 
8 | 9 | message Test { 10 | package.PackageMessage message = 1; 11 | } 12 | -------------------------------------------------------------------------------- /tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package repeated_duration_timestamp; 4 | 5 | import "google/protobuf/duration.proto"; 6 | import "google/protobuf/timestamp.proto"; 7 | 8 | 9 | message Test { 10 | repeated google.protobuf.Timestamp times = 1; 11 | repeated google.protobuf.Duration durations = 2; 12 | } 13 | -------------------------------------------------------------------------------- /tests/inputs/import_packages_same_name/import_packages_same_name.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_packages_same_name; 4 | 5 | import "users_v1.proto"; 6 | import "posts_v1.proto"; 7 | 8 | // Tests generated message can correctly reference two packages with the same leaf-name 9 | 10 | message Test { 11 | users.v1.User user = 1; 12 | posts.v1.Post post = 2; 13 | } 14 | -------------------------------------------------------------------------------- /tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | datetime, 3 | timedelta, 4 | ) 5 | 6 | from tests.output_betterproto.repeated_duration_timestamp import Test 7 | 8 | 9 | def test_roundtrip(): 10 | message = Test() 11 | message.times = [datetime.now(), datetime.now()] 12 | message.durations = [timedelta(), timedelta()] 13 | -------------------------------------------------------------------------------- /tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | 
package field_name_identical_to_type; 4 | 5 | // Tests that messages may contain fields with names that are identical to their python types (PR #294) 6 | 7 | message Test { 8 | int32 int = 1; 9 | float float = 2; 10 | string str = 3; 11 | bytes bytes = 4; 12 | bool bool = 5; 13 | } -------------------------------------------------------------------------------- /tests/streams/java/src/main/proto/betterproto/oneof.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package oneof; 4 | option java_package = "betterproto.oneof"; 5 | 6 | message Test { 7 | oneof foo { 8 | int32 pitied = 1; 9 | string pitier = 2; 10 | } 11 | 12 | int32 just_a_regular_field = 3; 13 | 14 | oneof bar { 15 | int32 drinks = 11; 16 | string bar_name = 12; 17 | } 18 | } 19 | 20 | -------------------------------------------------------------------------------- /.github/workflows/code-quality.yml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - '**' 10 | 11 | jobs: 12 | check-formatting: 13 | name: Check code/doc formatting 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | - uses: pre-commit/action@v3.0.1 19 | -------------------------------------------------------------------------------- /tests/inputs/oneof/oneof.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package oneof; 4 | 5 | message MixedDrink { 6 | int32 shots = 1; 7 | } 8 | 9 | message Test { 10 | oneof foo { 11 | int32 pitied = 1; 12 | string pitier = 2; 13 | } 14 | 15 | int32 just_a_regular_field = 3; 16 | 17 | oneof bar { 18 | int32 drinks = 11; 19 | string bar_name = 12; 20 | MixedDrink mixed_drink = 13; 21 | } 22 | } 23 | 24 | 
-------------------------------------------------------------------------------- /tests/inputs/regression_387/test_regression_387.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.regression_387 import ( 2 | ParentElement, 3 | Test, 4 | ) 5 | 6 | 7 | def test_regression_387(): 8 | el = ParentElement(name="test", elems=[Test(id=0), Test(id=42)]) 9 | binary = bytes(el) 10 | decoded = ParentElement().parse(binary) 11 | assert decoded == el 12 | assert decoded.elems == [Test(id=0), Test(id=42)] 13 | -------------------------------------------------------------------------------- /tests/inputs/nested2/nested2.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package nested2; 4 | 5 | import "package.proto"; 6 | 7 | message Game { 8 | message Player { 9 | enum Race { 10 | human = 0; 11 | orc = 1; 12 | } 13 | } 14 | } 15 | 16 | message Test { 17 | Game game = 1; 18 | Game.Player GamePlayer = 2; 19 | Game.Player.Race GamePlayerRace = 3; 20 | equipment.Weapon Weapon = 4; 21 | } -------------------------------------------------------------------------------- /tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "parent_package_message.proto"; 4 | 5 | package import_parent_package_from_child.parent.child; 6 | 7 | // Tests generated imports when a message refers to a message defined in its parent package 8 | 9 | message Test { 10 | ParentPackageMessage message_implicit = 1; 11 | parent.ParentPackageMessage message_explicit = 2; 12 | } 13 | -------------------------------------------------------------------------------- /tests/inputs/signed/signed.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package signed; 4 | 5 | 
message Test { 6 | // todo: rename fields after fixing bug where 'signed_32_positive' will map to 'signed_32Positive' as output json 7 | sint32 signed32 = 1; // signed_32_positive 8 | sint32 negative32 = 2; // signed_32_negative 9 | sint64 string64 = 3; // signed_64_positive 10 | sint64 negative64 = 4; // signed_64_negative 11 | } 12 | -------------------------------------------------------------------------------- /tests/test_mapmessage.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.mapmessage import ( 2 | Nested, 3 | Test, 4 | ) 5 | 6 | 7 | def test_mapmessage_to_dict_preserves_message(): 8 | message = Test( 9 | items={ 10 | "test": Nested( 11 | count=1, 12 | ) 13 | } 14 | ) 15 | 16 | message.to_dict() 17 | 18 | assert isinstance(message.items["test"], Nested), "Wrong nested type after to_dict" 19 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package proto3_field_presence_oneof; 4 | 5 | message Test { 6 | oneof kind { 7 | Nested nested = 1; 8 | WithOptional with_optional = 2; 9 | } 10 | } 11 | 12 | message InnerNested { 13 | optional bool a = 1; 14 | } 15 | 16 | message Nested { 17 | InnerNested inner = 1; 18 | } 19 | 20 | message WithOptional { 21 | optional bool b = 2; 22 | } 23 | -------------------------------------------------------------------------------- /tests/inputs/regression_414/test_regression_414.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.regression_414 import Test 2 | 3 | 4 | def test_full_cycle(): 5 | body = bytes([0, 1]) 6 | auth = bytes([2, 3]) 7 | sig = [b""] 8 | 9 | obj = Test(body=body, auth=auth, signatures=sig) 10 | 11 | decoded = Test().parse(bytes(obj)) 12 | assert 
decoded == obj 13 | assert decoded.body == body 14 | assert decoded.auth == auth 15 | assert decoded.signatures == sig 16 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_value/googletypes_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_value; 4 | 5 | import "google/protobuf/struct.proto"; 6 | 7 | // Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values. 8 | 9 | message Test { 10 | google.protobuf.Value value1 = 1; 11 | google.protobuf.Value value2 = 2; 12 | google.protobuf.Value value3 = 3; 13 | google.protobuf.Value value4 = 4; 14 | google.protobuf.Value value5 = 5; 15 | } 16 | -------------------------------------------------------------------------------- /tests/inputs/invalid_field/test_invalid_field.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_invalid_field(): 5 | from tests.output_betterproto.invalid_field import Test 6 | 7 | with pytest.raises(TypeError): 8 | Test(unknown_field=12) 9 | 10 | 11 | def test_invalid_field_pydantic(): 12 | from pydantic import ValidationError 13 | 14 | from tests.output_betterproto_pydantic.invalid_field import Test 15 | 16 | with pytest.raises(ValidationError): 17 | Test(unknown_field=12) 18 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import tomlkit 4 | 5 | from betterproto import __version__ 6 | 7 | 8 | PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve() 9 | 10 | 11 | def test_version(): 12 | with PROJECT_TOML.open() as toml_file: 13 | project_config = tomlkit.loads(toml_file.read()) 14 | assert __version__ == project_config["project"]["version"], ( 15 | "Project 
version should match in package and package config" 16 | ) 17 | -------------------------------------------------------------------------------- /tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/protobuf/timestamp.proto"; 4 | package google_impl_behavior_equivalence; 5 | 6 | message Foo { int64 bar = 1; } 7 | 8 | message Test { 9 | oneof group { 10 | string string = 1; 11 | int64 integer = 2; 12 | Foo foo = 3; 13 | } 14 | } 15 | 16 | message Spam { 17 | google.protobuf.Timestamp ts = 1; 18 | } 19 | 20 | message Request { Empty foo = 1; } 21 | 22 | message Empty {} 23 | -------------------------------------------------------------------------------- /tests/inputs/googletypes/googletypes.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes; 4 | 5 | import "google/protobuf/duration.proto"; 6 | import "google/protobuf/timestamp.proto"; 7 | import "google/protobuf/wrappers.proto"; 8 | import "google/protobuf/empty.proto"; 9 | 10 | message Test { 11 | google.protobuf.BoolValue maybe = 1; 12 | google.protobuf.Timestamp ts = 2; 13 | google.protobuf.Duration duration = 3; 14 | google.protobuf.Int32Value important = 4; 15 | google.protobuf.Empty empty = 5; 16 | } 17 | -------------------------------------------------------------------------------- /tests/inputs/deprecated/deprecated.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package deprecated; 4 | 5 | // Some documentation about the Test message. 
6 | message Test { 7 | Message message = 1 [deprecated=true]; 8 | int32 value = 2; 9 | } 10 | 11 | message Message { 12 | option deprecated = true; 13 | string value = 1; 14 | } 15 | 16 | message Empty {} 17 | 18 | service TestService { 19 | rpc func(Empty) returns (Empty); 20 | rpc deprecated_func(Empty) returns (Empty) { option deprecated = true; }; 21 | } 22 | -------------------------------------------------------------------------------- /tests/inputs/namespace_builtin_types/namespace_builtin_types.json: -------------------------------------------------------------------------------- 1 | { 2 | "int": "value-for-int", 3 | "float": "value-for-float", 4 | "complex": "value-for-complex", 5 | "list": "value-for-list", 6 | "tuple": "value-for-tuple", 7 | "range": "value-for-range", 8 | "str": "value-for-str", 9 | "bytearray": "value-for-bytearray", 10 | "bytes": "value-for-bytes", 11 | "memoryview": "value-for-memoryview", 12 | "set": "value-for-set", 13 | "frozenset": "value-for-frozenset", 14 | "map": "value-for-map", 15 | "bool": "value-for-bool" 16 | } -------------------------------------------------------------------------------- /tests/inputs/casing_inner_class/test_casing_inner_class.py: -------------------------------------------------------------------------------- 1 | import tests.output_betterproto.casing_inner_class as casing_inner_class 2 | 3 | 4 | def test_message_casing_inner_class_name(): 5 | assert hasattr(casing_inner_class, "TestInnerClass"), ( 6 | "Inline defined Message is correctly converted to CamelCase" 7 | ) 8 | 9 | 10 | def test_message_casing_inner_class_attributes(): 11 | message = casing_inner_class.Test() 12 | assert hasattr(message.inner, "old_exp"), ( 13 | "Inline defined Message attribute is snake_case" 14 | ) 15 | -------------------------------------------------------------------------------- /tests/inputs/nested/nested.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 
| 3 | package nested; 4 | 5 | // A test message with a nested message inside of it. 6 | message Test { 7 | // This is the nested type. 8 | message Nested { 9 | // Stores a simple counter. 10 | int32 count = 1; 11 | } 12 | // This is the nested enum. 13 | enum Msg { 14 | NONE = 0; 15 | THIS = 1; 16 | } 17 | 18 | Nested nested = 1; 19 | Sibling sibling = 2; 20 | Sibling sibling2 = 3; 21 | Msg msg = 4; 22 | } 23 | 24 | message Sibling { 25 | int32 foo = 1; 26 | } -------------------------------------------------------------------------------- /tests/inputs/service_separate_packages/service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "messages.proto"; 4 | 5 | package service_separate_packages.things.service; 6 | 7 | service Test { 8 | rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse); 9 | rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse); 10 | rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); 11 | rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); 12 | } 13 | -------------------------------------------------------------------------------- /tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.casing_message_field_uppercase import Test 2 | 3 | 4 | def test_message_casing(): 5 | message = Test() 6 | assert hasattr(message, "uppercase"), ( 7 | "UPPERCASE attribute is converted to 'uppercase' in python" 8 | ) 9 | assert hasattr(message, "uppercase_v2"), ( 10 | "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" 11 | ) 12 | assert hasattr(message, "upper_camel_case"), ( 13 | "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in 
python" 14 | ) 15 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: betterproto 2 | 3 | API reference 4 | ============= 5 | 6 | The following document outlines betterproto's api. **None** of these classes should be 7 | extended by the user manually. 8 | 9 | 10 | Message 11 | -------- 12 | 13 | .. autoclass:: betterproto.Message 14 | :members: 15 | :special-members: __bytes__, __bool__ 16 | 17 | 18 | .. autofunction:: betterproto.serialized_on_wire 19 | 20 | .. autofunction:: betterproto.which_one_of 21 | 22 | 23 | Enumerations 24 | ------------- 25 | 26 | .. autoclass:: betterproto.Enum() 27 | :members: 28 | 29 | 30 | .. autoclass:: betterproto.Casing() 31 | :members: 32 | -------------------------------------------------------------------------------- /tests/streams/java/src/main/proto/betterproto/nested.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package nested; 4 | option java_package = "betterproto.nested"; 5 | 6 | // A test message with a nested message inside of it. 7 | message Test { 8 | // This is the nested type. 9 | message Nested { 10 | // Stores a simple counter. 11 | int32 count = 1; 12 | } 13 | // This is the nested enum. 14 | enum Msg { 15 | NONE = 0; 16 | THIS = 1; 17 | } 18 | 19 | Nested nested = 1; 20 | Sibling sibling = 2; 21 | Sibling sibling2 = 3; 22 | Msg msg = 4; 23 | } 24 | 25 | message Sibling { 26 | int32 foo = 1; 27 | } -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | 3 | 4 | 5 | ## Checklist 6 | 7 | 8 | 9 | - [ ] If code changes were made then they have been tested. 10 | - [ ] I have updated the documentation to reflect the changes. 
11 | - [ ] This PR fixes an issue. 12 | - [ ] This PR adds something new (e.g. new method or parameters). 13 | - [ ] This change has an associated test. 14 | - [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed) 15 | - [ ] This PR is **not** a code change (e.g. documentation, README, ...) 16 | 17 | -------------------------------------------------------------------------------- /tests/inputs/bool/test_bool.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.output_betterproto.bool import Test 4 | from tests.output_betterproto_pydantic.bool import Test as TestPyd 5 | 6 | 7 | def test_value(): 8 | message = Test() 9 | assert not message.value, "Boolean is False by default" 10 | 11 | 12 | def test_pydantic_no_value(): 13 | message = TestPyd() 14 | assert not message.value, "Boolean is False by default" 15 | 16 | 17 | def test_pydantic_value(): 18 | message = TestPyd(value=False) 19 | assert not message.value 20 | 21 | 22 | def test_pydantic_bad_value(): 23 | with pytest.raises(ValueError): 24 | TestPyd(value=123) 25 | -------------------------------------------------------------------------------- /tests/streams/java/.gitignore: -------------------------------------------------------------------------------- 1 | ### Output ### 2 | target/ 3 | !.mvn/wrapper/maven-wrapper.jar 4 | !**/src/main/**/target/ 5 | !**/src/test/**/target/ 6 | dependency-reduced-pom.xml 7 | MANIFEST.MF 8 | 9 | ### IntelliJ IDEA ### 10 | .idea/ 11 | *.iws 12 | *.iml 13 | *.ipr 14 | 15 | ### Eclipse ### 16 | .apt_generated 17 | .classpath 18 | .factorypath 19 | .project 20 | .settings 21 | .springBeans 22 | .sts4-cache 23 | 24 | ### NetBeans ### 25 | /nbproject/private/ 26 | /nbbuild/ 27 | /dist/ 28 | /nbdist/ 29 | /.nb-gradle/ 30 | build/ 31 | !**/src/main/**/build/ 32 | !**/src/test/**/build/ 33 | 34 | ### VS Code ### 35 | .vscode/ 36 | 37 | ### Mac OS ### 38 | .DS_Store 
-------------------------------------------------------------------------------- /src/betterproto/compile/naming.py: -------------------------------------------------------------------------------- 1 | from betterproto import casing 2 | 3 | 4 | def pythonize_class_name(name: str) -> str: 5 | return casing.pascal_case(name) 6 | 7 | 8 | def pythonize_field_name(name: str) -> str: 9 | return casing.safe_snake_case(name) 10 | 11 | 12 | def pythonize_method_name(name: str) -> str: 13 | return casing.safe_snake_case(name) 14 | 15 | 16 | def pythonize_enum_member_name(name: str, enum_name: str) -> str: 17 | enum_name = casing.snake_case(enum_name).upper() 18 | find = name.find(enum_name) 19 | if find != -1: 20 | name = name[find + len(enum_name) :].strip("_") 21 | return casing.sanitize_name(name) 22 | -------------------------------------------------------------------------------- /tests/inputs/example_service/example_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package example_service; 4 | 5 | service Test { 6 | rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); 7 | rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); 8 | rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); 9 | rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse); 10 | } 11 | 12 | message ExampleRequest { 13 | string example_string = 1; 14 | int64 example_integer = 2; 15 | } 16 | 17 | message ExampleResponse { 18 | string example_string = 1; 19 | int64 example_integer = 2; 20 | } 21 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence/proto3_field_presence.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package proto3_field_presence; 4 | 5 | import "google/protobuf/timestamp.proto"; 6 | 7 | message InnerTest 
{ 8 | string test = 1; 9 | } 10 | 11 | message Test { 12 | optional uint32 test1 = 1; 13 | optional bool test2 = 2; 14 | optional string test3 = 3; 15 | optional bytes test4 = 4; 16 | optional InnerTest test5 = 5; 17 | optional TestEnum test6 = 6; 18 | optional uint64 test7 = 7; 19 | optional float test8 = 8; 20 | optional google.protobuf.Timestamp test9 = 9; 21 | } 22 | 23 | enum TestEnum { 24 | A = 0; 25 | B = 1; 26 | } 27 | -------------------------------------------------------------------------------- /tests/test_all_definition.py: -------------------------------------------------------------------------------- 1 | def test_all_definition(): 2 | """ 3 | Check that a compiled module defines __all__ with the right value. 4 | 5 | These modules have been chosen since they contain messages, services and enums. 6 | """ 7 | import tests.output_betterproto.enum as enum 8 | import tests.output_betterproto.service as service 9 | 10 | assert service.__all__ == ( 11 | "ThingType", 12 | "DoThingRequest", 13 | "DoThingResponse", 14 | "GetThingRequest", 15 | "GetThingResponse", 16 | "TestStub", 17 | "TestBase", 18 | ) 19 | assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test") 20 | -------------------------------------------------------------------------------- /tests/inputs/enum/enum.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package enum; 4 | 5 | // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values 6 | message Test { 7 | Choice choice = 1; 8 | repeated Choice choices = 2; 9 | } 10 | 11 | enum Choice { 12 | ZERO = 0; 13 | ONE = 1; 14 | // TWO = 2; 15 | FOUR = 4; 16 | THREE = 3; 17 | } 18 | 19 | // A "C" like enum with the enum name prefixed onto members, these should be stripped 20 | enum ArithmeticOperator { 21 | ARITHMETIC_OPERATOR_NONE = 0; 22 | ARITHMETIC_OPERATOR_PLUS = 1; 23 | ARITHMETIC_OPERATOR_MINUS = 2; 24 | 
ARITHMETIC_OPERATOR_0_PREFIXED = 3; 25 | } 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | 4 | repos: 5 | - repo: https://github.com/astral-sh/ruff-pre-commit 6 | rev: v0.9.1 7 | hooks: 8 | - id: ruff-format 9 | args: ["--diff", "src", "tests"] 10 | - id: ruff 11 | args: ["--select", "I", "src", "tests"] 12 | 13 | - repo: https://github.com/PyCQA/doc8 14 | rev: 0.10.1 15 | hooks: 16 | - id: doc8 17 | additional_dependencies: 18 | - toml 19 | 20 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 21 | rev: v2.14.0 22 | hooks: 23 | - id: pretty-format-java 24 | args: [--autofix, --aosp] 25 | files: ^.*\.java$ 26 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_service_returns_googletype; 4 | 5 | import "google/protobuf/empty.proto"; 6 | import "google/protobuf/struct.proto"; 7 | 8 | // Tests that imports are generated correctly when returning Google well-known types 9 | 10 | service Test { 11 | rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty); 12 | rpc GetStruct (RequestMessage) returns (google.protobuf.Struct); 13 | rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue); 14 | rpc GetValue (RequestMessage) returns (google.protobuf.Value); 15 | } 16 | 17 | message RequestMessage { 18 | } -------------------------------------------------------------------------------- /tests/inputs/namespace_keywords/namespace_keywords.json: -------------------------------------------------------------------------------- 1 | { 2 | "False": 1, 3 | "None": 2, 4 | "True": 3, 5 | "and": 4, 6 | "as": 
5, 7 | "assert": 6, 8 | "async": 7, 9 | "await": 8, 10 | "break": 9, 11 | "class": 10, 12 | "continue": 11, 13 | "def": 12, 14 | "del": 13, 15 | "elif": 14, 16 | "else": 15, 17 | "except": 16, 18 | "finally": 17, 19 | "for": 18, 20 | "from": 19, 21 | "global": 20, 22 | "if": 21, 23 | "import": 22, 24 | "in": 23, 25 | "is": 24, 26 | "lambda": 25, 27 | "nonlocal": 26, 28 | "not": 27, 29 | "or": 28, 30 | "pass": 29, 31 | "raise": 30, 32 | "return": 31, 33 | "try": 32, 34 | "while": 33, 35 | "with": 34, 36 | "yield": 35 37 | } 38 | -------------------------------------------------------------------------------- /tests/inputs/import_service_input_message/import_service_input_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_service_input_message; 4 | 5 | import "request_message.proto"; 6 | import "child_package_request_message.proto"; 7 | 8 | // Tests generated service correctly imports the RequestMessage 9 | 10 | service Test { 11 | rpc DoThing (RequestMessage) returns (RequestResponse); 12 | rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse); 13 | rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse); 14 | } 15 | 16 | 17 | message RequestResponse { 18 | int32 value = 1; 19 | } 20 | 21 | message Nested { 22 | message RequestMessage { 23 | int32 nestedArgument = 1; 24 | } 25 | } -------------------------------------------------------------------------------- /tests/inputs/entry/entry.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package entry; 4 | 5 | // This is a minimal example of a repeated message field that caused issues when 6 | // checking whether a message is a map. 
7 | // 8 | // During the check wheter a field is a "map", the string "entry" is added to 9 | // the field name, checked against the type name and then further checks are 10 | // made against the nested type of a parent message. In this edge-case, the 11 | // first check would pass even though it shouldn't and that would cause an 12 | // error because the parent type does not have a "nested_type" attribute. 13 | 14 | message Test { 15 | repeated ExportEntry export = 1; 16 | } 17 | 18 | message ExportEntry { 19 | string name = 1; 20 | } 21 | -------------------------------------------------------------------------------- /tests/inputs/nestedtwice/test_nestedtwice.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.output_betterproto.nestedtwice import ( 4 | Test, 5 | TestTop, 6 | TestTopMiddle, 7 | TestTopMiddleBottom, 8 | TestTopMiddleEnumBottom, 9 | TestTopMiddleTopMiddleBottom, 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize( 14 | ("cls", "expected_comment"), 15 | [ 16 | (Test, "Test doc."), 17 | (TestTopMiddleEnumBottom, "EnumBottom doc."), 18 | (TestTop, "Top doc."), 19 | (TestTopMiddle, "Middle doc."), 20 | (TestTopMiddleTopMiddleBottom, "TopMiddleBottom doc."), 21 | (TestTopMiddleBottom, "Bottom doc."), 22 | ], 23 | ) 24 | def test_comment(cls, expected_comment): 25 | assert cls.__doc__ == expected_comment 26 | -------------------------------------------------------------------------------- /tests/inputs/casing/test_casing.py: -------------------------------------------------------------------------------- 1 | import tests.output_betterproto.casing as casing 2 | from tests.output_betterproto.casing import Test 3 | 4 | 5 | def test_message_attributes(): 6 | message = Test() 7 | assert hasattr(message, "snake_case_message"), ( 8 | "snake_case field name is same in python" 9 | ) 10 | assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" 11 | assert hasattr(message, 
"uppercase"), "UPPERCASE field is lowercase in python" 12 | 13 | 14 | def test_message_casing(): 15 | assert hasattr(casing, "SnakeCaseMessage"), ( 16 | "snake_case Message name is converted to CamelCase in python" 17 | ) 18 | 19 | 20 | def test_enum_casing(): 21 | assert hasattr(casing, "MyEnum"), ( 22 | "snake_case Enum name is converted to CamelCase in python" 23 | ) 24 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | tags: 8 | - '**' 9 | pull_request: 10 | branches: 11 | - '**' 12 | 13 | jobs: 14 | packaging: 15 | name: Distribution 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.9 23 | - name: Install poetry 24 | run: python -m pip install poetry 25 | - name: Build package 26 | run: poetry build 27 | - name: Publish package to PyPI 28 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 29 | env: 30 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }} 31 | run: poetry publish -n 32 | -------------------------------------------------------------------------------- /tests/inputs/service/service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package service; 4 | 5 | enum ThingType { 6 | UNKNOWN = 0; 7 | LIVING = 1; 8 | DEAD = 2; 9 | } 10 | 11 | message DoThingRequest { 12 | string name = 1; 13 | repeated string comments = 2; 14 | ThingType type = 3; 15 | } 16 | 17 | message DoThingResponse { 18 | repeated string names = 1; 19 | } 20 | 21 | message GetThingRequest { 22 | string name = 1; 23 | } 24 | 25 | message GetThingResponse { 26 | string name = 1; 27 | int32 version = 2; 28 | } 29 | 30 | service Test { 31 | rpc DoThing 
(DoThingRequest) returns (DoThingResponse); 32 | rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse); 33 | rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse); 34 | rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse); 35 | } 36 | -------------------------------------------------------------------------------- /tests/inputs/service_separate_packages/messages.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/protobuf/duration.proto"; 4 | import "google/protobuf/timestamp.proto"; 5 | 6 | package service_separate_packages.things.messages; 7 | 8 | message DoThingRequest { 9 | string name = 1; 10 | 11 | // use `repeated` so we can check if `List` is correctly imported 12 | repeated string comments = 2; 13 | 14 | // use google types `timestamp` and `duration` so we can check 15 | // if everything from `datetime` is correctly imported 16 | google.protobuf.Timestamp when = 3; 17 | google.protobuf.Duration duration = 4; 18 | } 19 | 20 | message DoThingResponse { 21 | repeated string names = 1; 22 | } 23 | 24 | message GetThingRequest { 25 | string name = 1; 26 | } 27 | 28 | message GetThingResponse { 29 | string name = 1; 30 | int32 version = 2; 31 | } 32 | -------------------------------------------------------------------------------- /tests/test_timestamp.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | datetime, 3 | timezone, 4 | ) 5 | 6 | import pytest 7 | 8 | from betterproto import _Timestamp 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "dt", 13 | [ 14 | datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc), 15 | datetime.now(timezone.utc), 16 | # potential issue with floating point precision: 17 | datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc), 18 | # potential issue with negative timestamps: 19 | datetime(1969, 12, 31, 23, 0, 0, 1, 
tzinfo=timezone.utc), 20 | ], 21 | ) 22 | def test_timestamp_to_datetime_and_back(dt: datetime): 23 | """ 24 | Make sure converting a datetime to a protobuf timestamp message 25 | and then back again ends up with the same datetime. 26 | """ 27 | assert _Timestamp.from_datetime(dt).to_datetime() == dt 28 | -------------------------------------------------------------------------------- /tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package oneof_default_value_serialization; 4 | 5 | import "google/protobuf/duration.proto"; 6 | import "google/protobuf/timestamp.proto"; 7 | import "google/protobuf/wrappers.proto"; 8 | 9 | message Message{ 10 | int64 value = 1; 11 | } 12 | 13 | message NestedMessage{ 14 | int64 id = 1; 15 | oneof value_type{ 16 | Message wrapped_message_value = 2; 17 | } 18 | } 19 | 20 | message Test{ 21 | oneof value_type { 22 | bool bool_value = 1; 23 | int64 int64_value = 2; 24 | google.protobuf.Timestamp timestamp_value = 3; 25 | google.protobuf.Duration duration_value = 4; 26 | Message wrapped_message_value = 5; 27 | NestedMessage wrapped_nested_message_value = 6; 28 | google.protobuf.BoolValue wrapped_bool_value = 7; 29 | } 30 | } -------------------------------------------------------------------------------- /tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_response_embedded; 4 | 5 | import "google/protobuf/wrappers.proto"; 6 | 7 | // Tests that wrapped values are supported as part of output message 8 | service Test { 9 | rpc getOutput (Input) returns (Output); 10 | } 11 | 12 | message Input { 13 | 14 | } 15 | 16 | message Output { 17 | google.protobuf.DoubleValue double_value = 1; 18 | google.protobuf.FloatValue float_value = 2; 19 | 
google.protobuf.Int64Value int64_value = 3; 20 | google.protobuf.UInt64Value uint64_value = 4; 21 | google.protobuf.Int32Value int32_value = 5; 22 | google.protobuf.UInt32Value uint32_value = 6; 23 | google.protobuf.BoolValue bool_value = 7; 24 | google.protobuf.StringValue string_value = 8; 25 | google.protobuf.BytesValue bytes_value = 9; 26 | } 27 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_response/googletypes_response.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_response; 4 | 5 | import "google/protobuf/wrappers.proto"; 6 | 7 | // Tests that wrapped values can be used directly as return values 8 | 9 | service Test { 10 | rpc GetDouble (Input) returns (google.protobuf.DoubleValue); 11 | rpc GetFloat (Input) returns (google.protobuf.FloatValue); 12 | rpc GetInt64 (Input) returns (google.protobuf.Int64Value); 13 | rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value); 14 | rpc GetInt32 (Input) returns (google.protobuf.Int32Value); 15 | rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value); 16 | rpc GetBool (Input) returns (google.protobuf.BoolValue); 17 | rpc GetString (Input) returns (google.protobuf.StringValue); 18 | rpc GetBytes (Input) returns (google.protobuf.BytesValue); 19 | } 20 | 21 | message Input { 22 | 23 | } 24 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to betterproto's documentation! 2 | ======================================= 3 | 4 | betterproto is a protobuf compiler and interpreter. It improves the experience of using 5 | Protobuf and gRPC in Python, by generating readable, understandable, and idiomatic 6 | Python code, using modern language features. 
7 | 8 | 9 | Features: 10 | ~~~~~~~~~ 11 | 12 | - Generated messages are both binary & JSON serializable 13 | - Messages use relevant python types, e.g. ``Enum``, ``datetime`` and ``timedelta`` 14 | objects 15 | - ``async``/``await`` support for gRPC Clients and Servers 16 | - Generates modern, readable, idiomatic python code 17 | 18 | Contents: 19 | ~~~~~~~~~ 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | 24 | quick-start 25 | api 26 | migrating 27 | 28 | 29 | If you still can't find what you're looking for, try in one of the following pages: 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /tests/inputs/nestedtwice/nestedtwice.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package nestedtwice; 4 | 5 | /* Test doc. */ 6 | message Test { 7 | /* Top doc. */ 8 | message Top { 9 | /* Middle doc. */ 10 | message Middle { 11 | /* TopMiddleBottom doc.*/ 12 | message TopMiddleBottom { 13 | // TopMiddleBottom.a doc. 14 | string a = 1; 15 | } 16 | /* EnumBottom doc. */ 17 | enum EnumBottom{ 18 | /* EnumBottom.A doc. */ 19 | A = 0; 20 | B = 1; 21 | } 22 | /* Bottom doc. */ 23 | message Bottom { 24 | /* Bottom.foo doc. */ 25 | string foo = 1; 26 | } 27 | reserved 1; 28 | /* Middle.bottom doc. */ 29 | repeated Bottom bottom = 2; 30 | repeated EnumBottom enumBottom=3; 31 | repeated TopMiddleBottom topMiddleBottom=4; 32 | bool bar = 5; 33 | } 34 | /* Top.name doc. */ 35 | string name = 1; 36 | Middle middle = 2; 37 | } 38 | /* Test.top doc. 
*/ 39 | Top top = 1; 40 | } 41 | -------------------------------------------------------------------------------- /src/betterproto/grpc/grpclib_server.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from collections.abc import AsyncIterable 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Dict, 7 | ) 8 | 9 | import grpclib 10 | import grpclib.server 11 | 12 | 13 | class ServiceBase(ABC): 14 | """ 15 | Base class for async gRPC servers. 16 | """ 17 | 18 | async def _call_rpc_handler_server_stream( 19 | self, 20 | handler: Callable, 21 | stream: grpclib.server.Stream, 22 | request: Any, 23 | ) -> None: 24 | response_iter = handler(request) 25 | # check if response is actually an AsyncIterator 26 | # this might be false if the method just returns without 27 | # yielding at least once 28 | # in that case, we just interpret it as an empty iterator 29 | if isinstance(response_iter, AsyncIterable): 30 | async for response_message in response_iter: 31 | await stream.send_message(response_message) 32 | else: 33 | response_iter.close() 34 | -------------------------------------------------------------------------------- /tests/inputs/config.py: -------------------------------------------------------------------------------- 1 | # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. 2 | # Remove from list when fixed. 3 | xfail = { 4 | "namespace_keywords", # 70 5 | "googletypes_struct", # 9 6 | "googletypes_value", # 9 7 | "import_capitalized_package", 8 | "example", # This is the example in the readme. Not a test. 
9 | } 10 | 11 | services = { 12 | "googletypes_request", 13 | "googletypes_response", 14 | "googletypes_response_embedded", 15 | "service", 16 | "service_separate_packages", 17 | "import_service_input_message", 18 | "googletypes_service_returns_empty", 19 | "googletypes_service_returns_googletype", 20 | "example_service", 21 | "empty_service", 22 | "service_uppercase", 23 | } 24 | 25 | 26 | # Indicate json sample messages to skip when testing that json (de)serialization 27 | # is symmetrical becuase some cases legitimately are not symmetrical. 28 | # Each key references the name of the test scenario and the values in the tuple 29 | # Are the names of the json files. 30 | non_symmetrical_json = {"empty_repeated": ("empty_repeated",)} 31 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Daniel G. Taylor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_documentation.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | 4 | 5 | def check(generated_doc: str, type: str) -> None: 6 | assert f"Documentation of {type} 1" in generated_doc 7 | assert "other line 1" in generated_doc 8 | assert f"Documentation of {type} 2" in generated_doc 9 | assert "other line 2" in generated_doc 10 | assert f"Documentation of {type} 3" in generated_doc 11 | 12 | 13 | def test_documentation() -> None: 14 | from .output_betterproto.documentation import ( 15 | Enum, 16 | ServiceBase, 17 | ServiceStub, 18 | Test, 19 | ) 20 | 21 | check(Test.__doc__, "message") 22 | 23 | source = inspect.getsource(Test) 24 | tree = ast.parse(source) 25 | check(tree.body[0].body[2].value.value, "field") 26 | 27 | check(Enum.__doc__, "enum") 28 | 29 | source = inspect.getsource(Enum) 30 | tree = ast.parse(source) 31 | check(tree.body[0].body[2].value.value, "variant") 32 | 33 | check(ServiceBase.__doc__, "service") 34 | check(ServiceBase.get.__doc__, "method") 35 | 36 | check(ServiceStub.__doc__, "service") 37 | check(ServiceStub.get.__doc__, "method") 38 | -------------------------------------------------------------------------------- /tests/inputs/documentation/documentation.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package documentation; 3 | 4 | // Documentation of message 1 5 | // other line 1 6 | 7 | // Documentation of message 2 8 | // other line 2 9 | message Test { // Documentation of message 3 10 | // Documentation of field 1 11 | // other 
line 1 12 | 13 | // Documentation of field 2 14 | // other line 2 15 | uint32 x = 1; // Documentation of field 3 16 | } 17 | 18 | // Documentation of enum 1 19 | // other line 1 20 | 21 | // Documentation of enum 2 22 | // other line 2 23 | enum Enum { // Documentation of enum 3 24 | // Documentation of variant 1 25 | // other line 1 26 | 27 | // Documentation of variant 2 28 | // other line 2 29 | Enum_Variant = 0; // Documentation of variant 3 30 | } 31 | 32 | // Documentation of service 1 33 | // other line 1 34 | 35 | // Documentation of service 2 36 | // other line 2 37 | service Service { // Documentation of service 3 38 | // Documentation of method 1 39 | // other line 1 40 | 41 | // Documentation of method 2 42 | // other line 2 43 | rpc get(Test) returns (Test); // Documentation of method 3 44 | } 45 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.proto3_field_presence_oneof import ( 2 | InnerNested, 3 | Nested, 4 | Test, 5 | WithOptional, 6 | ) 7 | 8 | 9 | def test_serialization(): 10 | """Ensure that serialization of fields unset but with explicit field 11 | presence do not bloat the serialized payload with length-delimited fields 12 | with length 0""" 13 | 14 | def test_empty_nested(message: Test) -> None: 15 | # '0a' => tag 1, length delimited 16 | # '00' => length: 0 17 | assert bytes(message) == bytearray.fromhex("0a 00") 18 | 19 | test_empty_nested(Test(nested=Nested())) 20 | test_empty_nested(Test(nested=Nested(inner=None))) 21 | test_empty_nested(Test(nested=Nested(inner=InnerNested(a=None)))) 22 | 23 | def test_empty_with_optional(message: Test) -> None: 24 | # '12' => tag 2, length delimited 25 | # '00' => length: 0 26 | assert bytes(message) == bytearray.fromhex("12 00") 27 | 28 | 
test_empty_with_optional(Test(with_optional=WithOptional())) 29 | test_empty_with_optional(Test(with_optional=WithOptional(b=None))) 30 | -------------------------------------------------------------------------------- /tests/mocks.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from grpclib.client import Channel 4 | 5 | 6 | class MockChannel(Channel): 7 | # noinspection PyMissingConstructor 8 | def __init__(self, responses=None) -> None: 9 | self.responses = responses or [] 10 | self.requests = [] 11 | self._loop = None 12 | 13 | def request(self, route, cardinality, request, response_type, **kwargs): 14 | self.requests.append( 15 | { 16 | "route": route, 17 | "cardinality": cardinality, 18 | "request": request, 19 | "response_type": response_type, 20 | } 21 | ) 22 | return MockStream(self.responses) 23 | 24 | 25 | class MockStream: 26 | def __init__(self, responses: List) -> None: 27 | super().__init__() 28 | self.responses = responses 29 | 30 | async def recv_message(self): 31 | return self.responses.pop(0) 32 | 33 | async def send_message(self, *args, **kwargs): 34 | pass 35 | 36 | async def __aexit__(self, exc_type, exc_val, exc_tb): 37 | return True 38 | 39 | async def __aenter__(self): 40 | return self 41 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_request/googletypes_request.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package googletypes_request; 4 | 5 | import "google/protobuf/duration.proto"; 6 | import "google/protobuf/empty.proto"; 7 | import "google/protobuf/timestamp.proto"; 8 | import "google/protobuf/wrappers.proto"; 9 | 10 | // Tests that google types can be used as params 11 | 12 | service Test { 13 | rpc SendDouble (google.protobuf.DoubleValue) returns (Input); 14 | rpc SendFloat (google.protobuf.FloatValue) returns (Input); 15 | rpc 
SendInt64 (google.protobuf.Int64Value) returns (Input); 16 | rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input); 17 | rpc SendInt32 (google.protobuf.Int32Value) returns (Input); 18 | rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input); 19 | rpc SendBool (google.protobuf.BoolValue) returns (Input); 20 | rpc SendString (google.protobuf.StringValue) returns (Input); 21 | rpc SendBytes (google.protobuf.BytesValue) returns (Input); 22 | rpc SendDatetime (google.protobuf.Timestamp) returns (Input); 23 | rpc SendTimedelta (google.protobuf.Duration) returns (Input); 24 | rpc SendEmpty (google.protobuf.Empty) returns (Input); 25 | } 26 | 27 | message Input { 28 | 29 | } 30 | -------------------------------------------------------------------------------- /tests/inputs/import_circular_dependency/import_circular_dependency.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package import_circular_dependency; 4 | 5 | import "root.proto"; 6 | import "other.proto"; 7 | 8 | // This test-case verifies support for circular dependencies in the generated python files. 9 | // 10 | // This is important because we generate 1 python file/module per package, rather than 1 file per proto file. 11 | // 12 | // Scenario: 13 | // 14 | // The proto messages depend on each other in a non-circular way: 15 | // 16 | // Test -------> RootPackageMessage <--------------. 17 | // `------------------------------------> OtherPackageMessage 18 | // 19 | // Test and RootPackageMessage are in different files, but belong to the same package (root): 20 | // 21 | // (Test -------> RootPackageMessage) <------------. 
22 | // `------------------------------------> OtherPackageMessage 23 | // 24 | // After grouping the packages into single files or modules, a circular dependency is created: 25 | // 26 | // (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) 27 | message Test { 28 | RootPackageMessage message = 1; 29 | other.OtherPackageMessage other_value = 2; 30 | } 31 | -------------------------------------------------------------------------------- /tests/inputs/namespace_keywords/namespace_keywords.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package namespace_keywords; 4 | 5 | // Tests that messages may contain fields that are Python keywords 6 | // 7 | // Generated with Python 3.7.6 8 | // print('\n'.join(f'string {k} = {i+1};' for i,k in enumerate(keyword.kwlist))) 9 | 10 | message Test { 11 | string False = 1; 12 | string None = 2; 13 | string True = 3; 14 | string and = 4; 15 | string as = 5; 16 | string assert = 6; 17 | string async = 7; 18 | string await = 8; 19 | string break = 9; 20 | string class = 10; 21 | string continue = 11; 22 | string def = 12; 23 | string del = 13; 24 | string elif = 14; 25 | string else = 15; 26 | string except = 16; 27 | string finally = 17; 28 | string for = 18; 29 | string from = 19; 30 | string global = 20; 31 | string if = 21; 32 | string import = 22; 33 | string in = 23; 34 | string is = 24; 35 | string lambda = 25; 36 | string nonlocal = 26; 37 | string not = 27; 38 | string or = 28; 39 | string pass = 29; 40 | string raise = 30; 41 | string return = 31; 42 | string try = 32; 43 | string while = 33; 44 | string with = 34; 45 | string yield = 35; 46 | } -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | There's lots to do, and we're working hard, so any help is welcome! 
4 | 5 | - :speech_balloon: Join us on [Discord](https://discord.gg/DEVteTupPb)! 6 | 7 | What can you do? 8 | 9 | - :+1: Vote on [issues](https://github.com/danielgtaylor/python-betterproto/issues). 10 | - :speech_balloon: Give feedback on [Pull Requests](https://github.com/danielgtaylor/python-betterproto/pulls) and [Issues](https://github.com/danielgtaylor/python-betterproto/issues): 11 | - Suggestions 12 | - Express approval 13 | - Raise concerns 14 | - :small_red_triangle: Create an issue: 15 | - File a bug (please check that it's not a duplicate) 16 | - Propose an enhancement 17 | - :white_check_mark: Create a PR: 18 | - [Create a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/tests/README.md) to make bug-fixing easier 19 | - Fix any of the open issues 20 | - [Good first issues](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) 21 | - [Issues with tests](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22has+test%22) 22 | - New bugfix or idea 23 | - If you'd like to discuss your idea first, join us on Discord!
24 | -------------------------------------------------------------------------------- /tests/inputs/import_service_input_message/test_import_service_input_message.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.mocks import MockChannel 4 | from tests.output_betterproto.import_service_input_message import ( 5 | NestedRequestMessage, 6 | RequestMessage, 7 | RequestResponse, 8 | TestStub, 9 | ) 10 | from tests.output_betterproto.import_service_input_message.child import ( 11 | ChildRequestMessage, 12 | ) 13 | 14 | 15 | @pytest.mark.asyncio 16 | async def test_service_correctly_imports_reference_message(): 17 | mock_response = RequestResponse(value=10) 18 | service = TestStub(MockChannel([mock_response])) 19 | response = await service.do_thing(RequestMessage(1)) 20 | assert mock_response == response 21 | 22 | 23 | @pytest.mark.asyncio 24 | async def test_service_correctly_imports_reference_message_from_child_package(): 25 | mock_response = RequestResponse(value=10) 26 | service = TestStub(MockChannel([mock_response])) 27 | response = await service.do_thing2(ChildRequestMessage(1)) 28 | assert mock_response == response 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_service_correctly_imports_nested_reference(): 33 | mock_response = RequestResponse(value=10) 34 | service = TestStub(MockChannel([mock_response])) 35 | response = await service.do_thing3(NestedRequestMessage(1)) 36 | assert mock_response == response 37 | -------------------------------------------------------------------------------- /tests/inputs/namespace_builtin_types/namespace_builtin_types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package namespace_builtin_types; 4 | 5 | // Tests that messages may contain fields with names that are python types 6 | 7 | message Test { 8 | // 
https://docs.python.org/2/library/stdtypes.html#numeric-types-int-float-long-complex 9 | string int = 1; 10 | string float = 2; 11 | string complex = 3; 12 | 13 | // https://docs.python.org/3/library/stdtypes.html#sequence-types-list-tuple-range 14 | string list = 4; 15 | string tuple = 5; 16 | string range = 6; 17 | 18 | // https://docs.python.org/3/library/stdtypes.html#str 19 | string str = 7; 20 | 21 | // https://docs.python.org/3/library/stdtypes.html#bytearray-objects 22 | string bytearray = 8; 23 | 24 | // https://docs.python.org/3/library/stdtypes.html#bytes-and-bytearray-operations 25 | string bytes = 9; 26 | 27 | // https://docs.python.org/3/library/stdtypes.html#memory-views 28 | string memoryview = 10; 29 | 30 | // https://docs.python.org/3/library/stdtypes.html#set-types-set-frozenset 31 | string set = 11; 32 | string frozenset = 12; 33 | 34 | // https://docs.python.org/3/library/stdtypes.html#dict 35 | string map = 13; 36 | string dict = 14; 37 | 38 | // https://docs.python.org/3/library/stdtypes.html#boolean-values 39 | string bool = 15; 40 | } -------------------------------------------------------------------------------- /tests/test_struct.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from betterproto.lib.google.protobuf import Struct 4 | from betterproto.lib.pydantic.google.protobuf import Struct as StructPydantic 5 | 6 | 7 | def test_struct_roundtrip(): 8 | data = { 9 | "foo": "bar", 10 | "baz": None, 11 | "quux": 123, 12 | "zap": [1, {"two": 3}, "four"], 13 | } 14 | data_json = json.dumps(data) 15 | 16 | struct_from_dict = Struct().from_dict(data) 17 | assert struct_from_dict.fields == data 18 | assert struct_from_dict.to_dict() == data 19 | assert struct_from_dict.to_json() == data_json 20 | 21 | struct_from_json = Struct().from_json(data_json) 22 | assert struct_from_json.fields == data 23 | assert struct_from_json.to_dict() == data 24 | assert struct_from_json == struct_from_dict 
25 | assert struct_from_json.to_json() == data_json 26 | 27 | struct_pyd_from_dict = StructPydantic(fields={}).from_dict(data) 28 | assert struct_pyd_from_dict.fields == data 29 | assert struct_pyd_from_dict.to_dict() == data 30 | assert struct_pyd_from_dict.to_json() == data_json 31 | 32 | struct_pyd_from_dict = StructPydantic(fields={}).from_json(data_json) 33 | assert struct_pyd_from_dict.fields == data 34 | assert struct_pyd_from_dict.to_dict() == data 35 | assert struct_pyd_from_dict == struct_pyd_from_dict 36 | assert struct_pyd_from_dict.to_json() == data_json 37 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.mocks import MockChannel 4 | from tests.output_betterproto.googletypes_response_embedded import ( 5 | Input, 6 | Output, 7 | TestStub, 8 | ) 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_service_passes_through_unwrapped_values_embedded_in_response(): 13 | """ 14 | We do not not need to implement value unwrapping for embedded well-known types, 15 | as this is already handled by grpclib. This test merely shows that this is the case. 
16 | """ 17 | output = Output( 18 | double_value=10.0, 19 | float_value=12.0, 20 | int64_value=-13, 21 | uint64_value=14, 22 | int32_value=-15, 23 | uint32_value=16, 24 | bool_value=True, 25 | string_value="string", 26 | bytes_value=bytes(0xFF)[0:4], 27 | ) 28 | 29 | service = TestStub(MockChannel(responses=[output])) 30 | response = await service.get_output(Input()) 31 | 32 | assert response.double_value == 10.0 33 | assert response.float_value == 12.0 34 | assert response.int64_value == -13 35 | assert response.uint64_value == 14 36 | assert response.int32_value == -15 37 | assert response.uint32_value == 16 38 | assert response.bool_value 39 | assert response.string_value == "string" 40 | assert response.bytes_value == bytes(0xFF)[0:4] 41 | -------------------------------------------------------------------------------- /tests/streams/java/src/main/java/betterproto/CompatibilityTest.java: -------------------------------------------------------------------------------- 1 | package betterproto; 2 | 3 | import java.io.IOException; 4 | 5 | public class CompatibilityTest { 6 | public static void main(String[] args) throws IOException { 7 | if (args.length < 2) 8 | throw new RuntimeException("Attempted to run without the required arguments."); 9 | else if (args.length > 2) 10 | throw new RuntimeException( 11 | "Attempted to run with more than the expected number of arguments (>2)."); 12 | 13 | Tests tests = new Tests(args[1]); 14 | 15 | switch (args[0]) { 16 | case "single_varint": 17 | tests.testSingleVarint(); 18 | break; 19 | 20 | case "multiple_varints": 21 | tests.testMultipleVarints(); 22 | break; 23 | 24 | case "single_message": 25 | tests.testSingleMessage(); 26 | break; 27 | 28 | case "multiple_messages": 29 | tests.testMultipleMessages(); 30 | break; 31 | 32 | case "infinite_messages": 33 | tests.testInfiniteMessages(); 34 | break; 35 | 36 | default: 37 | throw new RuntimeException( 38 | "Attempted to run with unknown argument '" + args[0] + "'."); 39 | } 40
| } 41 | } 42 | -------------------------------------------------------------------------------- /tests/inputs/oneof/test_oneof.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import betterproto 4 | from tests.output_betterproto.oneof import ( 5 | MixedDrink, 6 | Test, 7 | ) 8 | from tests.output_betterproto_pydantic.oneof import Test as TestPyd 9 | from tests.util import get_test_case_json_data 10 | 11 | 12 | def test_which_count(): 13 | message = Test() 14 | message.from_json(get_test_case_json_data("oneof")[0].json) 15 | assert betterproto.which_one_of(message, "foo") == ("pitied", 100) 16 | 17 | 18 | def test_which_name(): 19 | message = Test() 20 | message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) 21 | assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") 22 | 23 | 24 | def test_which_count_pyd(): 25 | message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") 26 | assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. 
T") 27 | 28 | 29 | def test_oneof_constructor_assign(): 30 | message = Test(mixed_drink=MixedDrink(shots=42)) 31 | field, value = betterproto.which_one_of(message, "bar") 32 | assert field == "mixed_drink" 33 | assert value.shots == 42 34 | 35 | 36 | # Issue #305: 37 | @pytest.mark.xfail 38 | def test_oneof_nested_assign(): 39 | message = Test() 40 | message.mixed_drink.shots = 42 41 | field, value = betterproto.which_one_of(message, "bar") 42 | assert field == "mixed_drink" 43 | assert value.shots == 42 44 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a feature for this library 3 | labels: ["enhancement"] 4 | 5 | body: 6 | - type: input 7 | attributes: 8 | label: Summary 9 | description: > 10 | What problem is your feature trying to solve? What would become easier or possible if feature was implemented? 11 | validations: 12 | required: true 13 | 14 | - type: dropdown 15 | attributes: 16 | multiple: false 17 | label: What is the feature request for? 18 | options: 19 | - The core library 20 | - RPC handling 21 | - The documentation 22 | validations: 23 | required: true 24 | 25 | - type: textarea 26 | attributes: 27 | label: The Problem 28 | description: > 29 | What problem is your feature trying to solve? 30 | What would become easier or possible if feature was implemented? 31 | validations: 32 | required: true 33 | 34 | - type: textarea 35 | attributes: 36 | label: The Ideal Solution 37 | description: > 38 | What is your ideal solution to the problem? 39 | What would you like this feature to do? 40 | validations: 41 | required: true 42 | 43 | - type: textarea 44 | attributes: 45 | label: The Current Solution 46 | description: > 47 | What is the current solution to the problem, if any? 
48 | validations: 49 | required: false 50 | -------------------------------------------------------------------------------- /tests/oneof_pattern_matching.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import pytest 4 | 5 | import betterproto 6 | 7 | 8 | def test_oneof_pattern_matching(): 9 | @dataclass 10 | class Sub(betterproto.Message): 11 | val: int = betterproto.int32_field(1) 12 | 13 | @dataclass 14 | class Foo(betterproto.Message): 15 | bar: int = betterproto.int32_field(1, group="group1") 16 | baz: str = betterproto.string_field(2, group="group1") 17 | sub: Sub = betterproto.message_field(3, group="group2") 18 | abc: str = betterproto.string_field(4, group="group2") 19 | 20 | foo = Foo(baz="test1", abc="test2") 21 | 22 | match foo: 23 | case Foo(bar=_): 24 | pytest.fail("Matched 'bar' instead of 'baz'") 25 | case Foo(baz=v): 26 | assert v == "test1" 27 | case _: 28 | pytest.fail("Matched neither 'bar' nor 'baz'") 29 | 30 | match foo: 31 | case Foo(sub=_): 32 | pytest.fail("Matched 'sub' instead of 'abc'") 33 | case Foo(abc=v): 34 | assert v == "test2" 35 | case _: 36 | pytest.fail("Matched neither 'sub' nor 'abc'") 37 | 38 | foo.sub = Sub(val=1) 39 | 40 | match foo: 41 | case Foo(sub=Sub(val=v)): 42 | assert v == 1 43 | case Foo(abc=v): 44 | pytest.fail("Matched 'abc' instead of 'sub'") 45 | case _: 46 | pytest.fail("Matched neither 'sub' nor 'abc'") 47 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: 8 | - '**' 9 | schedule: 10 | - cron: '19 1 * * 6' 11 | 12 | jobs: 13 | analyze: 14 | name: Analyze 15 | runs-on: ubuntu-latest 16 | permissions: 17 | actions: read 18 | contents: read 19 | security-events: write 20 | 21 | 
strategy: 22 | fail-fast: false 23 | matrix: 24 | language: [ 'python' ] 25 | 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | 30 | # Initializes the CodeQL tools for scanning. 31 | - name: Initialize CodeQL 32 | uses: github/codeql-action/init@v3 33 | with: 34 | languages: ${{ matrix.language }} 35 | # If you wish to specify custom queries, you can do so here or in a config file. 36 | # By default, queries listed here will override any specified in a config file. 37 | # Prefix the list here with "+" to use these queries and those in the config file. 38 | 39 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 40 | # queries: security-extended,security-and-quality 41 | 42 | - name: Autobuild 43 | uses: github/codeql-action/autobuild@v3 44 | 45 | - name: Perform CodeQL Analysis 46 | uses: github/codeql-action/analyze@v3 47 | -------------------------------------------------------------------------------- /src/betterproto/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Generic, 7 | Optional, 8 | Type, 9 | TypeVar, 10 | ) 11 | 12 | from typing_extensions import ( 13 | Concatenate, 14 | ParamSpec, 15 | Self, 16 | ) 17 | 18 | 19 | SelfT = TypeVar("SelfT") 20 | P = ParamSpec("P") 21 | HybridT = TypeVar("HybridT", covariant=True) 22 | 23 | 24 | class hybridmethod(Generic[SelfT, P, HybridT]): 25 | def __init__( 26 | self, 27 | func: Callable[ 28 | Concatenate[type[SelfT], P], HybridT 29 | ], # Must be the classmethod version 30 | ): 31 | self.cls_func = func 32 | self.__doc__ = func.__doc__ 33 | 34 | def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self: 35 | self.instance_func = func 36 | return self 37 | 38 | def 
__get__( 39 | self, instance: Optional[SelfT], owner: Type[SelfT] 40 | ) -> Callable[P, HybridT]: 41 | if instance is None or self.instance_func is None: 42 | # either bound to the class, or no instance method available 43 | return self.cls_func.__get__(owner, None) 44 | return self.instance_func.__get__(instance, owner) 45 | 46 | 47 | T_co = TypeVar("T_co") 48 | TT_co = TypeVar("TT_co", bound="type[Any]") 49 | 50 | 51 | class classproperty(Generic[TT_co, T_co]): 52 | def __init__(self, func: Callable[[TT_co], T_co]): 53 | self.__func__ = func 54 | 55 | def __get__(self, instance: Any, type: TT_co) -> T_co: 56 | return self.__func__(type) 57 | -------------------------------------------------------------------------------- /src/betterproto/plugin/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | 6 | from betterproto.lib.google.protobuf.compiler import ( 7 | CodeGeneratorRequest, 8 | CodeGeneratorResponse, 9 | ) 10 | from betterproto.plugin.models import monkey_patch_oneof_index 11 | from betterproto.plugin.parser import generate_code 12 | 13 | 14 | def main() -> None: 15 | """The plugin's main entry point.""" 16 | # Read request message from stdin 17 | data = sys.stdin.buffer.read() 18 | 19 | # Apply Work around for proto2/3 difference in protoc messages 20 | monkey_patch_oneof_index() 21 | 22 | # Parse request 23 | request = CodeGeneratorRequest() 24 | request.parse(data) 25 | 26 | dump_file = os.getenv("BETTERPROTO_DUMP") 27 | if dump_file: 28 | dump_request(dump_file, request) 29 | 30 | # Generate code 31 | response = generate_code(request) 32 | 33 | # Serialise response message 34 | output = response.SerializeToString() 35 | 36 | # Write to stdout 37 | sys.stdout.buffer.write(output) 38 | 39 | 40 | def dump_request(dump_file: str, request: CodeGeneratorRequest) -> None: 41 | """ 42 | For developers: Supports running plugin.py standalone so its possible to 
debug it. 43 | Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. 44 | Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. 45 | """ 46 | with open(str(dump_file), "wb") as fh: 47 | sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") 48 | fh.write(request.SerializeToString()) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_request/test_googletypes_request.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | datetime, 3 | timedelta, 4 | ) 5 | from typing import ( 6 | Any, 7 | Callable, 8 | ) 9 | 10 | import pytest 11 | 12 | import betterproto.lib.google.protobuf as protobuf 13 | from tests.mocks import MockChannel 14 | from tests.output_betterproto.googletypes_request import ( 15 | Input, 16 | TestStub, 17 | ) 18 | 19 | 20 | test_cases = [ 21 | (TestStub.send_double, protobuf.DoubleValue, 2.5), 22 | (TestStub.send_float, protobuf.FloatValue, 2.5), 23 | (TestStub.send_int64, protobuf.Int64Value, -64), 24 | (TestStub.send_u_int64, protobuf.UInt64Value, 64), 25 | (TestStub.send_int32, protobuf.Int32Value, -32), 26 | (TestStub.send_u_int32, protobuf.UInt32Value, 32), 27 | (TestStub.send_bool, protobuf.BoolValue, True), 28 | (TestStub.send_string, protobuf.StringValue, "string"), 29 | (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), 30 | (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)), 31 | (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)), 32 | ] 33 | 34 | 35 | @pytest.mark.asyncio 36 | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) 37 | async def test_channel_receives_wrapped_type( 38 | service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value 39 | ): 40 | 
wrapped_value = wrapper_class() 41 | wrapped_value.value = value 42 | channel = MockChannel(responses=[Input()]) 43 | service = TestStub(channel) 44 | 45 | await service_method(service, wrapped_value) 46 | 47 | assert channel.requests[0]["request"] == type(wrapped_value) 48 | -------------------------------------------------------------------------------- /tests/inputs/proto3_field_presence/test_proto3_field_presence.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tests.output_betterproto.proto3_field_presence import ( 4 | InnerTest, 5 | Test, 6 | TestEnum, 7 | ) 8 | 9 | 10 | def test_null_fields_json(): 11 | """Ensure that using "null" in JSON is equivalent to not specifying a 12 | field, for fields with explicit presence""" 13 | 14 | def test_json(ref_json: str, obj_json: str) -> None: 15 | """`ref_json` and `obj_json` are JSON strings describing a `Test` object. 16 | Test that deserializing both leads to the same object, and that 17 | `ref_json` is the normalized format.""" 18 | ref_obj = Test().from_json(ref_json) 19 | obj = Test().from_json(obj_json) 20 | 21 | assert obj == ref_obj 22 | assert json.loads(obj.to_json(0)) == json.loads(ref_json) 23 | 24 | test_json("{}", '{ "test1": null, "test2": null, "test3": null }') 25 | test_json("{}", '{ "test4": null, "test5": null, "test6": null }') 26 | test_json("{}", '{ "test7": null, "test8": null }') 27 | test_json('{ "test5": {} }', '{ "test3": null, "test5": {} }') 28 | 29 | # Make sure that if include_default_values is set, None values are 30 | # exported. 
31 | obj = Test() 32 | assert obj.to_dict() == {} 33 | assert obj.to_dict(include_default_values=True) == { 34 | "test1": None, 35 | "test2": None, 36 | "test3": None, 37 | "test4": None, 38 | "test5": None, 39 | "test6": None, 40 | "test7": None, 41 | "test8": None, 42 | "test9": None, 43 | } 44 | 45 | 46 | def test_unset_access(): # see #523 47 | assert Test().test1 is None 48 | assert Test(test1=None).test1 is None 49 | -------------------------------------------------------------------------------- /tests/inputs/oneof_enum/test_oneof_enum.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import betterproto 4 | from tests.output_betterproto.oneof_enum import ( 5 | Move, 6 | Signal, 7 | Test, 8 | ) 9 | from tests.util import get_test_case_json_data 10 | 11 | 12 | def test_which_one_of_returns_enum_with_default_value(): 13 | """ 14 | returns first field when it is enum and set with default value 15 | """ 16 | message = Test() 17 | message.from_json( 18 | get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json 19 | ) 20 | 21 | assert not hasattr(message, "move") 22 | assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER 23 | assert message.signal == Signal.PASS 24 | assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) 25 | 26 | 27 | def test_which_one_of_returns_enum_with_non_default_value(): 28 | """ 29 | returns first field when it is enum and set with non default value 30 | """ 31 | message = Test() 32 | message.from_json( 33 | get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json 34 | ) 35 | assert not hasattr(message, "move") 36 | assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER 37 | assert message.signal == Signal.RESIGN 38 | assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) 39 | 40 | 41 | def test_which_one_of_returns_second_field_when_set(): 42 | message = Test() 
43 | message.from_json(get_test_case_json_data("oneof_enum")[0].json) 44 | assert message.move == Move(x=2, y=3) 45 | assert not hasattr(message, "signal") 46 | assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER 47 | assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) 48 | -------------------------------------------------------------------------------- /tests/test_deprecated.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | 5 | from tests.mocks import MockChannel 6 | from tests.output_betterproto.deprecated import ( 7 | Empty, 8 | Message, 9 | Test, 10 | TestServiceStub, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def message(): 16 | with warnings.catch_warnings(): 17 | warnings.filterwarnings("ignore", category=DeprecationWarning) 18 | return Message(value="hello") 19 | 20 | 21 | def test_deprecated_message(): 22 | with pytest.warns(DeprecationWarning) as record: 23 | Message(value="hello") 24 | 25 | assert len(record) == 1 26 | assert str(record[0].message) == f"{Message.__name__} is deprecated" 27 | 28 | 29 | def test_message_with_deprecated_field(message): 30 | with pytest.warns(DeprecationWarning) as record: 31 | Test(message=message, value=10) 32 | 33 | assert len(record) == 1 34 | assert str(record[0].message) == f"{Test.__name__}.message is deprecated" 35 | 36 | 37 | def test_message_with_deprecated_field_not_set(message): 38 | with warnings.catch_warnings(): 39 | warnings.simplefilter("error") 40 | Test(value=10) 41 | 42 | 43 | def test_message_with_deprecated_field_not_set_default(message): 44 | with warnings.catch_warnings(): 45 | warnings.simplefilter("error") 46 | _ = Test(value=10).message 47 | 48 | 49 | @pytest.mark.asyncio 50 | async def test_service_with_deprecated_method(): 51 | stub = TestServiceStub(MockChannel([Empty(), Empty()])) 52 | 53 | with pytest.warns(DeprecationWarning) as record: 54 | await 
stub.deprecated_func(Empty()) 55 | 56 | assert len(record) == 1 57 | assert str(record[0].message) == f"TestService.deprecated_func is deprecated" 58 | 59 | with warnings.catch_warnings(): 60 | warnings.simplefilter("error") 61 | await stub.func(Empty()) 62 | -------------------------------------------------------------------------------- /src/betterproto/templates/header.py.j2: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # sources: {{ ', '.join(output_file.input_filenames) }} 3 | # plugin: python-betterproto 4 | # This file has been @generated 5 | 6 | __all__ = ( 7 | {%- for enum in output_file.enums -%} 8 | "{{ enum.py_name }}", 9 | {%- endfor -%} 10 | {%- for message in output_file.messages -%} 11 | "{{ message.py_name }}", 12 | {%- endfor -%} 13 | {%- for service in output_file.services -%} 14 | "{{ service.py_name }}Stub", 15 | "{{ service.py_name }}Base", 16 | {%- endfor -%} 17 | ) 18 | 19 | {% for i in output_file.python_module_imports|sort %} 20 | import {{ i }} 21 | {% endfor %} 22 | 23 | {% if output_file.pydantic_dataclasses %} 24 | from pydantic.dataclasses import dataclass 25 | {%- else -%} 26 | from dataclasses import dataclass 27 | {% endif %} 28 | 29 | {% if output_file.datetime_imports %} 30 | from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} 31 | 32 | {% endif%} 33 | {% set typing_imports = output_file.typing_compiler.imports() %} 34 | {% if typing_imports %} 35 | {% for line in output_file.typing_compiler.import_lines() %} 36 | {{ line }} 37 | {% endfor %} 38 | {% endif %} 39 | 40 | {% if output_file.pydantic_imports %} 41 | from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} 42 | 43 | {% endif %} 44 | 45 | import betterproto 46 | {% if output_file.services %} 47 | from betterproto.grpc.grpclib_server 
import ServiceBase 48 | import grpclib 49 | {% endif %} 50 | 51 | {% if output_file.imports_type_checking_only %} 52 | from typing import TYPE_CHECKING 53 | 54 | if TYPE_CHECKING: 55 | {% for i in output_file.imports_type_checking_only|sort %} {{ i }} 56 | {% endfor %} 57 | {% endif %} 58 | -------------------------------------------------------------------------------- /tests/test_enum.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Optional, 3 | Tuple, 4 | ) 5 | 6 | import pytest 7 | 8 | import betterproto 9 | 10 | 11 | class Colour(betterproto.Enum): 12 | RED = 1 13 | GREEN = 2 14 | BLUE = 3 15 | 16 | 17 | PURPLE = Colour.__new__(Colour, name=None, value=4) 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "member, str_value", 22 | [ 23 | (Colour.RED, "RED"), 24 | (Colour.GREEN, "GREEN"), 25 | (Colour.BLUE, "BLUE"), 26 | ], 27 | ) 28 | def test_str(member: Colour, str_value: str) -> None: 29 | assert str(member) == str_value 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "member, repr_value", 34 | [ 35 | (Colour.RED, "Colour.RED"), 36 | (Colour.GREEN, "Colour.GREEN"), 37 | (Colour.BLUE, "Colour.BLUE"), 38 | ], 39 | ) 40 | def test_repr(member: Colour, repr_value: str) -> None: 41 | assert repr(member) == repr_value 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "member, values", 46 | [ 47 | (Colour.RED, ("RED", 1)), 48 | (Colour.GREEN, ("GREEN", 2)), 49 | (Colour.BLUE, ("BLUE", 3)), 50 | (PURPLE, (None, 4)), 51 | ], 52 | ) 53 | def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None: 54 | assert (member.name, member.value) == values 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "member, input_str", 59 | [ 60 | (Colour.RED, "RED"), 61 | (Colour.GREEN, "GREEN"), 62 | (Colour.BLUE, "BLUE"), 63 | ], 64 | ) 65 | def test_from_string(member: Colour, input_str: str) -> None: 66 | assert Colour.from_string(input_str) == member 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "member, 
input_int", 71 | [ 72 | (Colour.RED, 1), 73 | (Colour.GREEN, 2), 74 | (Colour.BLUE, 3), 75 | (PURPLE, 4), 76 | ], 77 | ) 78 | def test_try_value(member: Colour, input_int: int) -> None: 79 | assert Colour.try_value(input_int) == member 80 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # If extensions (or modules to document with autodoc) are in another directory, 8 | # add these directories to sys.path here. If the directory is relative to the 9 | # documentation root, use os.path.abspath to make it absolute, like shown here. 10 | 11 | import pathlib 12 | 13 | import toml 14 | 15 | 16 | # -- Project information ----------------------------------------------------- 17 | 18 | project = "betterproto" 19 | copyright = "2019 Daniel G. Taylor" 20 | author = "danielgtaylor" 21 | pyproject = toml.load(open(pathlib.Path(__file__).parent.parent / "pyproject.toml")) 22 | 23 | 24 | # The full version, including alpha/beta/rc tags. 25 | release = pyproject["tool"]["poetry"]["version"] 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | "sphinx.ext.autodoc", 35 | "sphinx.ext.intersphinx", 36 | "sphinx.ext.napoleon", 37 | ] 38 | 39 | autodoc_member_order = "bysource" 40 | autodoc_typehints = "none" 41 | 42 | extlinks = { 43 | "issue": ("https://github.com/danielgtaylor/python-betterproto/issues/%s", "GH-"), 44 | } 45 | 46 | # Links used for cross-referencing stuff in other documentation 47 | intersphinx_mapping = { 48 | "py": ("https://docs.python.org/3", None), 49 | } 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | 54 | # The name of the Pygments (syntax highlighting) style to use. 55 | pygments_style = "friendly" 56 | 57 | # The theme to use for HTML and HTML Help pages. See the documentation for 58 | # a list of builtin themes. 59 | 60 | html_theme = "sphinx_rtd_theme" 61 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - '**' 10 | 11 | jobs: 12 | tests: 13 | name: ${{ matrix.os }} / ${{ matrix.python-version }} 14 | runs-on: ${{ matrix.os }}-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [Ubuntu, MacOS, Windows] 19 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Get full Python version 29 | id: full-python-version 30 | shell: bash 31 | run: echo "version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")" >> "$GITHUB_OUTPUT" 32 | 33 | - name: Install poetry 34 | shell: bash 35 | run: | 36 | python -m pip install poetry 37 | echo "$HOME/.poetry/bin" >> $GITHUB_PATH 38 | 39 | - name: Configure poetry 40 | shell: bash 
41 | run: poetry config virtualenvs.in-project true 42 | 43 | - name: Set up cache 44 | uses: actions/cache@v4 45 | id: cache 46 | with: 47 | path: .venv 48 | key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }} 49 | 50 | - name: Ensure cache is healthy 51 | if: steps.cache.outputs.cache-hit == 'true' 52 | shell: bash 53 | run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv 54 | 55 | - name: Install dependencies 56 | shell: bash 57 | run: poetry install -E compiler 58 | 59 | - name: Generate code from proto files 60 | shell: bash 61 | run: poetry run python -m tests.generate -v 62 | 63 | - name: Execute test suite 64 | shell: bash 65 | run: poetry run python -m pytest tests/ 66 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report broken or incorrect behaviour 3 | labels: ["bug", "investigation needed"] 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: > 9 | Thanks for taking the time to fill out a bug report! 10 | 11 | If you're not sure it's a bug and you just have a question, the [community Discord channel](https://discord.gg/DEVteTupPb) is a better place for general questions than a GitHub issue. 12 | 13 | - type: input 14 | attributes: 15 | label: Summary 16 | description: A simple summary of your bug report 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | attributes: 22 | label: Reproduction Steps 23 | description: > 24 | What you did to make it happen. 25 | Ideally there should be a short code snippet in this section to help reproduce the bug. 26 | validations: 27 | required: true 28 | 29 | - type: textarea 30 | attributes: 31 | label: Expected Results 32 | description: > 33 | What did you expect to happen? 
34 | validations: 35 | required: true 36 | 37 | - type: textarea 38 | attributes: 39 | label: Actual Results 40 | description: > 41 | What actually happened? 42 | validations: 43 | required: true 44 | 45 | - type: textarea 46 | attributes: 47 | label: System Information 48 | description: > 49 | Paste the result of `protoc --version; python --version; pip show betterproto` below. 50 | validations: 51 | required: true 52 | 53 | - type: checkboxes 54 | attributes: 55 | label: Checklist 56 | options: 57 | - label: I have searched the issues for duplicates. 58 | required: true 59 | - label: I have shown the entire traceback, if possible. 60 | required: true 61 | - label: I have verified this issue occurs on the latest prerelease of betterproto which can be installed using `pip install -U --pre betterproto`, if possible. 62 | required: true 63 | 64 | -------------------------------------------------------------------------------- /tests/inputs/googletypes_response/test_googletypes_response.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Callable, 4 | Optional, 5 | ) 6 | 7 | import pytest 8 | 9 | import betterproto.lib.google.protobuf as protobuf 10 | from tests.mocks import MockChannel 11 | from tests.output_betterproto.googletypes_response import ( 12 | Input, 13 | TestStub, 14 | ) 15 | 16 | 17 | test_cases = [ 18 | (TestStub.get_double, protobuf.DoubleValue, 2.5), 19 | (TestStub.get_float, protobuf.FloatValue, 2.5), 20 | (TestStub.get_int64, protobuf.Int64Value, -64), 21 | (TestStub.get_u_int64, protobuf.UInt64Value, 64), 22 | (TestStub.get_int32, protobuf.Int32Value, -32), 23 | (TestStub.get_u_int32, protobuf.UInt32Value, 32), 24 | (TestStub.get_bool, protobuf.BoolValue, True), 25 | (TestStub.get_string, protobuf.StringValue, "string"), 26 | (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), 27 | ] 28 | 29 | 30 | @pytest.mark.asyncio 31 | @pytest.mark.parametrize(["service_method",
"wrapper_class", "value"], test_cases) 32 | async def test_channel_receives_wrapped_type( 33 | service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value 34 | ): 35 | wrapped_value = wrapper_class() 36 | wrapped_value.value = value 37 | channel = MockChannel(responses=[wrapped_value]) 38 | service = TestStub(channel) 39 | method_param = Input() 40 | 41 | await service_method(service, method_param) 42 | 43 | assert channel.requests[0]["response_type"] != Optional[type(value)] 44 | assert channel.requests[0]["response_type"] == type(wrapped_value) 45 | 46 | 47 | @pytest.mark.asyncio 48 | @pytest.mark.xfail 49 | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) 50 | async def test_service_unwraps_response( 51 | service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value 52 | ): 53 | """ 54 | grpclib does not unwrap wrapper values returned by services 55 | """ 56 | wrapped_value = wrapper_class() 57 | wrapped_value.value = value 58 | service = TestStub(MockChannel(responses=[wrapped_value])) 59 | method_param = Input() 60 | 61 | response_value = await service_method(service, method_param) 62 | 63 | assert response_value == value 64 | assert type(response_value) == type(value) 65 | -------------------------------------------------------------------------------- /src/betterproto/plugin/compiler.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import subprocess 3 | import sys 4 | 5 | from .module_validation import ModuleValidator 6 | 7 | 8 | try: 9 | # betterproto[compiler] specific dependencies 10 | import jinja2 11 | except ImportError as err: 12 | print( 13 | "\033[31m" 14 | f"Unable to import `{err.name}` from betterproto plugin! " 15 | "Please ensure that you've installed betterproto as " 16 | '`pip install "betterproto[compiler]"` so that compiler dependencies ' 17 | "are included." 
18 | "\033[0m" 19 | ) 20 | raise SystemExit(1) 21 | 22 | from .models import OutputTemplate 23 | 24 | 25 | def outputfile_compiler(output_file: OutputTemplate) -> str: 26 | templates_folder = os.path.abspath( 27 | os.path.join(os.path.dirname(__file__), "..", "templates") 28 | ) 29 | 30 | env = jinja2.Environment( 31 | trim_blocks=True, 32 | lstrip_blocks=True, 33 | loader=jinja2.FileSystemLoader(templates_folder), 34 | undefined=jinja2.StrictUndefined, 35 | ) 36 | # Load the body first so we have a complete list of imports needed. 37 | body_template = env.get_template("template.py.j2") 38 | header_template = env.get_template("header.py.j2") 39 | 40 | code = body_template.render(output_file=output_file) 41 | code = header_template.render(output_file=output_file) + code 42 | 43 | # Sort imports, delete unused ones 44 | code = subprocess.check_output( 45 | ["ruff", "check", "--select", "I,F401", "--fix", "--silent", "-"], 46 | input=code, 47 | encoding="utf-8", 48 | ) 49 | 50 | # Format the code 51 | code = subprocess.check_output( 52 | ["ruff", "format", "-"], input=code, encoding="utf-8" 53 | ) 54 | 55 | # Validate the generated code.
56 | validator = ModuleValidator(iter(code.splitlines())) 57 | if not validator.validate(): 58 | message_builder = ["[WARNING]: Generated code has collisions in the module:"] 59 | for collision, lines in validator.collisions.items(): 60 | message_builder.append(f' "{collision}" on lines:') 61 | for num, line in lines: 62 | message_builder.append(f" {num}:{line}") 63 | print("\n".join(message_builder), file=sys.stderr) 64 | return code 65 | -------------------------------------------------------------------------------- /tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | import betterproto 6 | from tests.output_betterproto.oneof_default_value_serialization import ( 7 | Message, 8 | NestedMessage, 9 | Test, 10 | ) 11 | 12 | 13 | def assert_round_trip_serialization_works(message: Test) -> None: 14 | assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of( 15 | Test().from_json(message.to_json()), "value_type" 16 | ) 17 | 18 | 19 | def test_oneof_default_value_serialization_works_for_all_values(): 20 | """ 21 | Serialization from message with oneof set to default -> JSON -> message should keep 22 | default value field intact. 23 | """ 24 | 25 | test_cases = [ 26 | Test(bool_value=False), 27 | Test(int64_value=0), 28 | Test( 29 | timestamp_value=datetime.datetime( 30 | year=1970, 31 | month=1, 32 | day=1, 33 | hour=0, 34 | minute=0, 35 | tzinfo=datetime.timezone.utc, 36 | ) 37 | ), 38 | Test(duration_value=datetime.timedelta(0)), 39 | Test(wrapped_message_value=Message(value=0)), 40 | # NOTE: Do NOT use betterproto.BoolValue here, it will cause JSON serialization 41 | # errors. 42 | # TODO: Do we want to allow use of BoolValue directly within a wrapped field or 43 | # should we simply hard fail here? 
44 | Test(wrapped_bool_value=False), 45 | ] 46 | for message in test_cases: 47 | assert_round_trip_serialization_works(message) 48 | 49 | 50 | def test_oneof_no_default_values_passed(): 51 | message = Test() 52 | assert ( 53 | betterproto.which_one_of(message, "value_type") 54 | == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") 55 | == ("", None) 56 | ) 57 | 58 | 59 | def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): 60 | """ 61 | Nested messages with oneofs should also be handled 62 | """ 63 | message = Test( 64 | wrapped_nested_message_value=NestedMessage( 65 | id=0, wrapped_message_value=Message(value=0) 66 | ) 67 | ) 68 | assert ( 69 | betterproto.which_one_of(message, "value_type") 70 | == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") 71 | == ( 72 | "wrapped_nested_message_value", 73 | NestedMessage(id=0, wrapped_message_value=Message(value=0)), 74 | ) 75 | ) 76 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Standard Tests Development Guide 2 | 3 | Standard test cases are found in [betterproto/tests/inputs](inputs), where each subdirectory represents a testcase, that is verified in isolation. 4 | 5 | ``` 6 | inputs/ 7 | bool/ 8 | double/ 9 | int32/ 10 | ... 11 | ``` 12 | 13 | ## Test case directory structure 14 | 15 | Each testcase has a `.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`. 
16 | 17 | ```bash 18 | bool/ 19 | bool.proto 20 | bool.json # optional 21 | test_bool.py # optional 22 | ``` 23 | 24 | ### proto 25 | 26 | `.proto` — *The protobuf message to test* 27 | 28 | ```protobuf 29 | syntax = "proto3"; 30 | 31 | message Test { 32 | bool value = 1; 33 | } 34 | ``` 35 | 36 | You can add multiple `.proto` files to the test case, as long as one file matches the directory name. 37 | 38 | ### json 39 | 40 | `.json` — *Test-data to validate the message with* 41 | 42 | ```json 43 | { 44 | "value": true 45 | } 46 | ``` 47 | 48 | ### pytest 49 | 50 | `test_<name>.py` — *Custom test to validate specific aspects of the generated class* 51 | 52 | ```python 53 | from tests.output_betterproto.bool import Test 54 | 55 | def test_value(): 56 | message = Test() 57 | assert not message.value, "Boolean is False by default" 58 | ``` 59 | 60 | ## Standard tests 61 | 62 | The following tests are automatically executed for all cases: 63 | 64 | - [x] Can the generated Python code be imported? 65 | - [x] Can the generated message class be instantiated? 66 | - [x] Is the generated code compatible with Google's `grpc_tools.protoc` implementation? 67 | - _when `.json` is present_ 68 | 69 | ## Running the tests 70 | 71 | - `poetry run python -m tests.generate` 72 | This generates: 73 | - `tests/output_betterproto` — *the plugin-generated Python classes* 74 | - `tests/output_reference` — *reference implementation classes* 75 | - `poetry run python -m pytest tests/` 76 | 77 | ## Intentionally Failing tests 78 | 79 | The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future. 80 | 81 | When running `pytest`, they show up as `x` or `X` in the test results. 82 | 83 | ``` 84 | tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x......................
[ 84%] 85 | ``` 86 | 87 | - `.` — PASSED 88 | - `x` — XFAIL: expected failure 89 | - `X` — XPASS: expected failure, but still passed 90 | 91 | Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py). -------------------------------------------------------------------------------- /tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | datetime, 3 | timedelta, 4 | timezone, 5 | ) 6 | 7 | import pytest 8 | 9 | from tests.output_betterproto.timestamp_dict_encode import Test 10 | 11 | 12 | # Current World Timezone range (UTC-12 to UTC+14) 13 | MIN_UTC_OFFSET_MIN = -12 * 60 14 | MAX_UTC_OFFSET_MIN = 14 * 60 15 | 16 | # Generate all timezones in range in 15 min increments 17 | timezones = [ 18 | timezone(timedelta(minutes=x)) 19 | for x in range(MIN_UTC_OFFSET_MIN, MAX_UTC_OFFSET_MIN + 1, 15) 20 | ] 21 | 22 | 23 | @pytest.mark.parametrize("tz", timezones) 24 | def test_timezone_aware_datetime_dict_encode(tz: timezone): 25 | original_time = datetime.now(tz=tz) 26 | original_message = Test() 27 | original_message.ts = original_time 28 | encoded = original_message.to_dict() 29 | decoded_message = Test() 30 | decoded_message.from_dict(encoded) 31 | 32 | # check that the timestamps are equal after decoding from dict 33 | assert original_message.ts.tzinfo is not None 34 | assert decoded_message.ts.tzinfo is not None 35 | assert original_message.ts == decoded_message.ts 36 | 37 | 38 | def test_naive_datetime_dict_encode(): 39 | # make sure naive datetime objects are still treated as UTC 40 | original_time = datetime.now() 41 | assert original_time.tzinfo is None 42 | original_message = Test() 43 | original_message.ts = original_time 44 | original_time_utc = original_time.replace(tzinfo=timezone.utc) 45 | encoded = original_message.to_dict() 46 | decoded_message = Test() 47 | decoded_message.from_dict(encoded) 48 | 49 | #
check that the timestamps are equal after decoding from dict 50 | assert decoded_message.ts.tzinfo is not None 51 | assert original_time_utc == decoded_message.ts 52 | 53 | 54 | @pytest.mark.parametrize("tz", timezones) 55 | def test_timezone_aware_json_serialize(tz: timezone): 56 | original_time = datetime.now(tz=tz) 57 | original_message = Test() 58 | original_message.ts = original_time 59 | json_serialized = original_message.to_json() 60 | decoded_message = Test() 61 | decoded_message.from_json(json_serialized) 62 | 63 | # check that the timestamps are equal after decoding from JSON 64 | assert original_message.ts.tzinfo is not None 65 | assert decoded_message.ts.tzinfo is not None 66 | assert original_message.ts == decoded_message.ts 67 | 68 | 69 | def test_naive_datetime_json_serialize(): 70 | # make sure naive datetime objects are still treated as UTC 71 | original_time = datetime.now() 72 | assert original_time.tzinfo is None 73 | original_message = Test() 74 | original_message.ts = original_time 75 | original_time_utc = original_time.replace(tzinfo=timezone.utc) 76 | json_serialized = original_message.to_json() 77 | decoded_message = Test() 78 | decoded_message.from_json(json_serialized) 79 | 80 | # check that the timestamps are equal after decoding from JSON 81 | assert decoded_message.ts.tzinfo is not None 82 | assert original_time_utc == decoded_message.ts 83 | -------------------------------------------------------------------------------- /tests/grpc/thing_service.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import grpclib 4 | import grpclib.server 5 | 6 | from tests.output_betterproto.service import ( 7 | DoThingRequest, 8 | DoThingResponse, 9 | GetThingRequest, 10 | GetThingResponse, 11 | ) 12 | 13 | 14 | class ThingService: 15 | def __init__(self, test_hook=None): 16 | # This lets us pass assertions to the servicer ;) 17 | self.test_hook = test_hook 18 | 19 | async def
do_thing( 20 | self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" 21 | ): 22 | request = await stream.recv_message() 23 | if self.test_hook is not None: 24 | self.test_hook(stream) 25 | await stream.send_message(DoThingResponse([request.name])) 26 | 27 | async def do_many_things( 28 | self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" 29 | ): 30 | thing_names = [request.name async for request in stream] 31 | if self.test_hook is not None: 32 | self.test_hook(stream) 33 | await stream.send_message(DoThingResponse(thing_names)) 34 | 35 | async def get_thing_versions( 36 | self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" 37 | ): 38 | request = await stream.recv_message() 39 | if self.test_hook is not None: 40 | self.test_hook(stream) 41 | for version_num in range(1, 6): 42 | await stream.send_message( 43 | GetThingResponse(name=request.name, version=version_num) 44 | ) 45 | 46 | async def get_different_things( 47 | self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" 48 | ): 49 | if self.test_hook is not None: 50 | self.test_hook(stream) 51 | # Respond to each input item immediately 52 | response_num = 0 53 | async for request in stream: 54 | response_num += 1 55 | await stream.send_message( 56 | GetThingResponse(name=request.name, version=response_num) 57 | ) 58 | 59 | def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: 60 | return { 61 | "/service.Test/DoThing": grpclib.const.Handler( 62 | self.do_thing, 63 | grpclib.const.Cardinality.UNARY_UNARY, 64 | DoThingRequest, 65 | DoThingResponse, 66 | ), 67 | "/service.Test/DoManyThings": grpclib.const.Handler( 68 | self.do_many_things, 69 | grpclib.const.Cardinality.STREAM_UNARY, 70 | DoThingRequest, 71 | DoThingResponse, 72 | ), 73 | "/service.Test/GetThingVersions": grpclib.const.Handler( 74 | self.get_thing_versions, 75 | grpclib.const.Cardinality.UNARY_STREAM, 76 | GetThingRequest, 77 | GetThingResponse, 78 | ), 79 | 
"/service.Test/GetDifferentThings": grpclib.const.Handler( 80 | self.get_different_things, 81 | grpclib.const.Cardinality.STREAM_STREAM, 82 | GetThingRequest, 83 | GetThingResponse, 84 | ), 85 | } 86 | -------------------------------------------------------------------------------- /tests/grpc/test_stream_stream.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from typing import AsyncIterator 4 | 5 | import pytest 6 | 7 | import betterproto 8 | from betterproto.grpc.util.async_channel import AsyncChannel 9 | 10 | 11 | @dataclass 12 | class Message(betterproto.Message): 13 | body: str = betterproto.string_field(1) 14 | 15 | 16 | @pytest.fixture 17 | def expected_responses(): 18 | return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] 19 | 20 | 21 | class ClientStub: 22 | async def connect(self, requests: AsyncIterator): 23 | await asyncio.sleep(0.1) 24 | async for request in requests: 25 | await asyncio.sleep(0.1) 26 | yield request 27 | await asyncio.sleep(0.1) 28 | yield Message("Done") 29 | 30 | 31 | async def to_list(generator: AsyncIterator): 32 | return [value async for value in generator] 33 | 34 | 35 | @pytest.fixture 36 | def client(): 37 | # channel = Channel(host='127.0.0.1', port=50051) 38 | # return ClientStub(channel) 39 | return ClientStub() 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_send_from_before_connect_and_close_automatically( 44 | client, expected_responses 45 | ): 46 | requests = AsyncChannel() 47 | await requests.send_from( 48 | [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True 49 | ) 50 | responses = client.connect(requests) 51 | 52 | assert await to_list(responses) == expected_responses 53 | 54 | 55 | @pytest.mark.asyncio 56 | async def test_send_from_after_connect_and_close_automatically( 57 | client, expected_responses 58 | ): 59 | requests = AsyncChannel() 60 | responses = 
client.connect(requests) 61 | await requests.send_from( 62 | [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True 63 | ) 64 | 65 | assert await to_list(responses) == expected_responses 66 | 67 | 68 | @pytest.mark.asyncio 69 | async def test_send_from_close_manually_immediately(client, expected_responses): 70 | requests = AsyncChannel() 71 | responses = client.connect(requests) 72 | await requests.send_from( 73 | [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False 74 | ) 75 | requests.close() 76 | 77 | assert await to_list(responses) == expected_responses 78 | 79 | 80 | @pytest.mark.asyncio 81 | async def test_send_individually_and_close_before_connect(client, expected_responses): 82 | requests = AsyncChannel() 83 | await requests.send(Message(body="Hello world 1")) 84 | await requests.send(Message(body="Hello world 2")) 85 | requests.close() 86 | responses = client.connect(requests) 87 | 88 | assert await to_list(responses) == expected_responses 89 | 90 | 91 | @pytest.mark.asyncio 92 | async def test_send_individually_and_close_after_connect(client, expected_responses): 93 | requests = AsyncChannel() 94 | await requests.send(Message(body="Hello world 1")) 95 | await requests.send(Message(body="Hello world 2")) 96 | responses = client.connect(requests) 97 | requests.close() 98 | 99 | assert await to_list(responses) == expected_responses 100 | -------------------------------------------------------------------------------- /tests/inputs/example_service/test_example_service.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | AsyncIterable, 3 | AsyncIterator, 4 | ) 5 | 6 | import pytest 7 | from grpclib.testing import ChannelFor 8 | 9 | from tests.output_betterproto.example_service import ( 10 | ExampleRequest, 11 | ExampleResponse, 12 | TestBase, 13 | TestStub, 14 | ) 15 | 16 | 17 | class ExampleService(TestBase): 18 | async def example_unary_unary( 19 | self, 
example_request: ExampleRequest 20 | ) -> "ExampleResponse": 21 | return ExampleResponse( 22 | example_string=example_request.example_string, 23 | example_integer=example_request.example_integer, 24 | ) 25 | 26 | async def example_unary_stream( 27 | self, example_request: ExampleRequest 28 | ) -> AsyncIterator["ExampleResponse"]: 29 | response = ExampleResponse( 30 | example_string=example_request.example_string, 31 | example_integer=example_request.example_integer, 32 | ) 33 | yield response 34 | yield response 35 | yield response 36 | 37 | async def example_stream_unary( 38 | self, example_request_iterator: AsyncIterator["ExampleRequest"] 39 | ) -> "ExampleResponse": 40 | async for example_request in example_request_iterator: 41 | return ExampleResponse( 42 | example_string=example_request.example_string, 43 | example_integer=example_request.example_integer, 44 | ) 45 | 46 | async def example_stream_stream( 47 | self, example_request_iterator: AsyncIterator["ExampleRequest"] 48 | ) -> AsyncIterator["ExampleResponse"]: 49 | async for example_request in example_request_iterator: 50 | yield ExampleResponse( 51 | example_string=example_request.example_string, 52 | example_integer=example_request.example_integer, 53 | ) 54 | 55 | 56 | @pytest.mark.asyncio 57 | async def test_calls_with_different_cardinalities(): 58 | example_request = ExampleRequest("test string", 42) 59 | 60 | async with ChannelFor([ExampleService()]) as channel: 61 | stub = TestStub(channel) 62 | 63 | # unary unary 64 | response = await stub.example_unary_unary(example_request) 65 | assert response.example_string == example_request.example_string 66 | assert response.example_integer == example_request.example_integer 67 | 68 | # unary stream 69 | async for response in stub.example_unary_stream(example_request): 70 | assert response.example_string == example_request.example_string 71 | assert response.example_integer == example_request.example_integer 72 | 73 | # stream unary 74 | async def 
request_iterator(): 75 | yield example_request 76 | yield example_request 77 | yield example_request 78 | 79 | response = await stub.example_stream_unary(request_iterator()) 80 | assert response.example_string == example_request.example_string 81 | assert response.example_integer == example_request.example_integer 82 | 83 | # stream stream 84 | async for response in stub.example_stream_stream(request_iterator()): 85 | assert response.example_string == example_request.example_string 86 | assert response.example_integer == example_request.example_integer 87 | -------------------------------------------------------------------------------- /tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py: -------------------------------------------------------------------------------- 1 | from datetime import ( 2 | datetime, 3 | timezone, 4 | ) 5 | 6 | import pytest 7 | from google.protobuf import json_format 8 | from google.protobuf.timestamp_pb2 import Timestamp 9 | 10 | import betterproto 11 | from tests.output_betterproto.google_impl_behavior_equivalence import ( 12 | Empty, 13 | Foo, 14 | Request, 15 | Spam, 16 | Test, 17 | ) 18 | from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( 19 | Empty as ReferenceEmpty, 20 | Foo as ReferenceFoo, 21 | Request as ReferenceRequest, 22 | Spam as ReferenceSpam, 23 | Test as ReferenceTest, 24 | ) 25 | 26 | 27 | def test_oneof_serializes_similar_to_google_oneof(): 28 | tests = [ 29 | (Test(string="abc"), ReferenceTest(string="abc")), 30 | (Test(integer=2), ReferenceTest(integer=2)), 31 | (Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))), 32 | # Default values should also behave the same within oneofs 33 | (Test(string=""), ReferenceTest(string="")), 34 | (Test(integer=0), ReferenceTest(integer=0)), 35 | (Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))), 36 | ] 37 | for message, message_reference in tests: 38 | # NOTE: As of 
July 2020, MessageToJson inserts newlines in the output string so, 39 | # just compare dicts 40 | assert message.to_dict() == json_format.MessageToDict(message_reference) 41 | 42 | 43 | def test_bytes_are_the_same_for_oneof(): 44 | message = Test(string="") 45 | message_reference = ReferenceTest(string="") 46 | 47 | message_bytes = bytes(message) 48 | message_reference_bytes = message_reference.SerializeToString() 49 | 50 | assert message_bytes == message_reference_bytes 51 | 52 | message2 = Test().parse(message_reference_bytes) 53 | message_reference2 = ReferenceTest() 54 | message_reference2.ParseFromString(message_reference_bytes) 55 | 56 | assert message == message2 57 | assert message_reference == message_reference2 58 | 59 | # None of these fields were explicitly set BUT they should not actually be null 60 | # themselves 61 | assert not hasattr(message, "foo") 62 | assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER 63 | assert not hasattr(message2, "foo") 64 | assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER 65 | 66 | assert isinstance(message_reference.foo, ReferenceFoo) 67 | assert isinstance(message_reference2.foo, ReferenceFoo) 68 | 69 | 70 | @pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),)) 71 | def test_datetime_clamping(dt): # see #407 72 | ts = Timestamp() 73 | ts.FromDatetime(dt) 74 | assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString() 75 | message_bytes = bytes(Spam(dt)) 76 | 77 | assert ( 78 | Spam().parse(message_bytes).ts.timestamp() 79 | == ReferenceSpam.FromString(message_bytes).ts.seconds 80 | ) 81 | 82 | 83 | def test_empty_message_field(): 84 | message = Request() 85 | reference_message = ReferenceRequest() 86 | 87 | message.foo = Empty() 88 | reference_message.foo.CopyFrom(ReferenceEmpty()) 89 | 90 | assert betterproto.serialized_on_wire(message.foo) 91 | assert reference_message.HasField("foo") 92 | 93 | assert bytes(message) == 
reference_message.SerializeToString() 94 | -------------------------------------------------------------------------------- /tests/test_typing_compiler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from betterproto.plugin.typing_compiler import ( 4 | DirectImportTypingCompiler, 5 | NoTyping310TypingCompiler, 6 | TypingImportTypingCompiler, 7 | ) 8 | 9 | 10 | def test_direct_import_typing_compiler(): 11 | compiler = DirectImportTypingCompiler() 12 | assert compiler.imports() == {} 13 | assert compiler.optional("str") == "Optional[str]" 14 | assert compiler.imports() == {"typing": {"Optional"}} 15 | assert compiler.list("str") == "List[str]" 16 | assert compiler.imports() == {"typing": {"Optional", "List"}} 17 | assert compiler.dict("str", "int") == "Dict[str, int]" 18 | assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}} 19 | assert compiler.union("str", "int") == "Union[str, int]" 20 | assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}} 21 | assert compiler.iterable("str") == "Iterable[str]" 22 | assert compiler.imports() == { 23 | "typing": {"Optional", "List", "Dict", "Union", "Iterable"} 24 | } 25 | assert compiler.async_iterable("str") == "AsyncIterable[str]" 26 | assert compiler.imports() == { 27 | "typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"} 28 | } 29 | assert compiler.async_iterator("str") == "AsyncIterator[str]" 30 | assert compiler.imports() == { 31 | "typing": { 32 | "Optional", 33 | "List", 34 | "Dict", 35 | "Union", 36 | "Iterable", 37 | "AsyncIterable", 38 | "AsyncIterator", 39 | } 40 | } 41 | 42 | 43 | def test_typing_import_typing_compiler(): 44 | compiler = TypingImportTypingCompiler() 45 | assert compiler.imports() == {} 46 | assert compiler.optional("str") == "typing.Optional[str]" 47 | assert compiler.imports() == {"typing": None} 48 | assert compiler.list("str") == "typing.List[str]" 49 | assert 
compiler.imports() == {"typing": None} 50 | assert compiler.dict("str", "int") == "typing.Dict[str, int]" 51 | assert compiler.imports() == {"typing": None} 52 | assert compiler.union("str", "int") == "typing.Union[str, int]" 53 | assert compiler.imports() == {"typing": None} 54 | assert compiler.iterable("str") == "typing.Iterable[str]" 55 | assert compiler.imports() == {"typing": None} 56 | assert compiler.async_iterable("str") == "typing.AsyncIterable[str]" 57 | assert compiler.imports() == {"typing": None} 58 | assert compiler.async_iterator("str") == "typing.AsyncIterator[str]" 59 | assert compiler.imports() == {"typing": None} 60 | 61 | 62 | def test_no_typing_311_typing_compiler(): 63 | compiler = NoTyping310TypingCompiler() 64 | assert compiler.imports() == {} 65 | assert compiler.optional("str") == '"str | None"' 66 | assert compiler.imports() == {} 67 | assert compiler.list("str") == '"list[str]"' 68 | assert compiler.imports() == {} 69 | assert compiler.dict("str", "int") == '"dict[str, int]"' 70 | assert compiler.imports() == {} 71 | assert compiler.union("str", "int") == '"str | int"' 72 | assert compiler.imports() == {} 73 | assert compiler.iterable("str") == '"Iterable[str]"' 74 | assert compiler.async_iterable("str") == '"AsyncIterable[str]"' 75 | assert compiler.async_iterator("str") == '"AsyncIterator[str]"' 76 | assert compiler.imports() == { 77 | "collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"} 78 | } 79 | -------------------------------------------------------------------------------- /tests/test_module_validation.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | List, 3 | Optional, 4 | Set, 5 | ) 6 | 7 | import pytest 8 | 9 | from betterproto.plugin.module_validation import ModuleValidator 10 | 11 | 12 | @pytest.mark.parametrize( 13 | ["text", "expected_collisions"], 14 | [ 15 | pytest.param( 16 | ["import os"], 17 | None, 18 | id="single import", 19 | ), 20 | 
pytest.param( 21 | ["import os", "import sys"], 22 | None, 23 | id="multiple imports", 24 | ), 25 | pytest.param( 26 | ["import os", "import os"], 27 | {"os"}, 28 | id="duplicate imports", 29 | ), 30 | pytest.param( 31 | ["from os import path", "import os"], 32 | None, 33 | id="duplicate imports with alias", 34 | ), 35 | pytest.param( 36 | ["from os import path", "import os as os_alias"], 37 | None, 38 | id="duplicate imports with alias", 39 | ), 40 | pytest.param( 41 | ["from os import path", "import os as path"], 42 | {"path"}, 43 | id="duplicate imports with alias", 44 | ), 45 | pytest.param( 46 | ["import os", "class os:"], 47 | {"os"}, 48 | id="duplicate import with class", 49 | ), 50 | pytest.param( 51 | ["import os", "class os:", " pass", "import sys"], 52 | {"os"}, 53 | id="duplicate import with class and another", 54 | ), 55 | pytest.param( 56 | ["def test(): pass", "class test:"], 57 | {"test"}, 58 | id="duplicate class and function", 59 | ), 60 | pytest.param( 61 | ["def test(): pass", "def test(): pass"], 62 | {"test"}, 63 | id="duplicate functions", 64 | ), 65 | pytest.param( 66 | ["def test(): pass", "test = 100"], 67 | {"test"}, 68 | id="function and variable", 69 | ), 70 | pytest.param( 71 | ["def test():", " test = 3"], 72 | None, 73 | id="function and variable in function", 74 | ), 75 | pytest.param( 76 | [ 77 | "def test(): pass", 78 | "'''", 79 | "def test(): pass", 80 | "'''", 81 | "def test_2(): pass", 82 | ], 83 | None, 84 | id="duplicate functions with multiline string", 85 | ), 86 | pytest.param( 87 | ["def test(): pass", "# def test(): pass"], 88 | None, 89 | id="duplicate functions with comments", 90 | ), 91 | pytest.param( 92 | ["from test import (", " A", " B", " C", ")"], 93 | None, 94 | id="multiline import", 95 | ), 96 | pytest.param( 97 | ["from test import (", " A", " B", " C", ")", "from test import A"], 98 | {"A"}, 99 | id="multiline import with duplicate", 100 | ), 101 | ], 102 | ) 103 | def test_module_validator(text: 
List[str], expected_collisions: Optional[Set[str]]): 104 | line_iterator = iter(text) 105 | validator = ModuleValidator(line_iterator) 106 | valid = validator.validate() 107 | if expected_collisions is None: 108 | assert valid 109 | else: 110 | assert set(validator.collisions.keys()) == expected_collisions 111 | assert not valid 112 | -------------------------------------------------------------------------------- /tests/streams/java/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | betterproto 8 | compatibility-test 9 | 1.0-SNAPSHOT 10 | jar 11 | 12 | 13 | 11 14 | 11 15 | UTF-8 16 | 3.23.4 17 | 18 | 19 | 20 | 21 | com.google.protobuf 22 | protobuf-java 23 | ${protobuf.version} 24 | 25 | 26 | 27 | 28 | 29 | 30 | kr.motd.maven 31 | os-maven-plugin 32 | 1.7.1 33 | 34 | 35 | 36 | 37 | 38 | org.apache.maven.plugins 39 | maven-shade-plugin 40 | 3.5.0 41 | 42 | 43 | package 44 | 45 | shade 46 | 47 | 48 | 49 | 50 | betterproto.CompatibilityTest 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | org.apache.maven.plugins 60 | maven-jar-plugin 61 | 3.3.0 62 | 63 | 64 | 65 | true 66 | betterproto.CompatibilityTest 67 | 68 | 69 | 70 | 71 | 72 | 73 | org.xolstice.maven.plugins 74 | protobuf-maven-plugin 75 | 0.6.1 76 | 77 | 78 | 79 | compile 80 | 81 | 82 | 83 | 84 | 85 | com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} 86 | 87 | 88 | 89 | 90 | 91 | ${project.artifactId} 92 | 93 | 94 | -------------------------------------------------------------------------------- /tests/test_casing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from betterproto.casing import ( 4 | camel_case, 5 | pascal_case, 6 | snake_case, 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize( 11 | ["value", "expected"], 12 | [ 13 | ("", ""), 14 | ("a", "A"), 15 | ("foobar", "Foobar"), 16 | ("fooBar", "FooBar"), 17 | ("FooBar", "FooBar"), 18 | ("foo.bar", "FooBar"), 19 
| ("foo_bar", "FooBar"), 20 | ("FOOBAR", "Foobar"), 21 | ("FOOBar", "FooBar"), 22 | ("UInt32", "UInt32"), 23 | ("FOO_BAR", "FooBar"), 24 | ("FOOBAR1", "Foobar1"), 25 | ("FOOBAR_1", "Foobar1"), 26 | ("FOO1BAR2", "Foo1Bar2"), 27 | ("foo__bar", "FooBar"), 28 | ("_foobar", "Foobar"), 29 | ("foobaR", "FoobaR"), 30 | ("foo~bar", "FooBar"), 31 | ("foo:bar", "FooBar"), 32 | ("1foobar", "1Foobar"), 33 | ], 34 | ) 35 | def test_pascal_case(value, expected): 36 | actual = pascal_case(value, strict=True) 37 | assert actual == expected, f"{value} => {expected} (actual: {actual})" 38 | 39 | 40 | @pytest.mark.parametrize( 41 | ["value", "expected"], 42 | [ 43 | ("", ""), 44 | ("a", "a"), 45 | ("foobar", "foobar"), 46 | ("fooBar", "fooBar"), 47 | ("FooBar", "fooBar"), 48 | ("foo.bar", "fooBar"), 49 | ("foo_bar", "fooBar"), 50 | ("FOOBAR", "foobar"), 51 | ("FOO_BAR", "fooBar"), 52 | ("FOOBAR1", "foobar1"), 53 | ("FOOBAR_1", "foobar1"), 54 | ("FOO1BAR2", "foo1Bar2"), 55 | ("foo__bar", "fooBar"), 56 | ("_foobar", "foobar"), 57 | ("foobaR", "foobaR"), 58 | ("foo~bar", "fooBar"), 59 | ("foo:bar", "fooBar"), 60 | ("1foobar", "1Foobar"), 61 | ], 62 | ) 63 | def test_camel_case_strict(value, expected): 64 | actual = camel_case(value, strict=True) 65 | assert actual == expected, f"{value} => {expected} (actual: {actual})" 66 | 67 | 68 | @pytest.mark.parametrize( 69 | ["value", "expected"], 70 | [ 71 | ("foo_bar", "fooBar"), 72 | ("FooBar", "fooBar"), 73 | ("foo__bar", "foo_Bar"), 74 | ("foo__Bar", "foo__Bar"), 75 | ], 76 | ) 77 | def test_camel_case_not_strict(value, expected): 78 | actual = camel_case(value, strict=False) 79 | assert actual == expected, f"{value} => {expected} (actual: {actual})" 80 | 81 | 82 | @pytest.mark.parametrize( 83 | ["value", "expected"], 84 | [ 85 | ("", ""), 86 | ("a", "a"), 87 | ("foobar", "foobar"), 88 | ("fooBar", "foo_bar"), 89 | ("FooBar", "foo_bar"), 90 | ("foo.bar", "foo_bar"), 91 | ("foo_bar", "foo_bar"), 92 | ("foo_Bar", "foo_bar"), 93 | ("FOOBAR", 
"foobar"), 94 | ("FOOBar", "foo_bar"), 95 | ("UInt32", "u_int32"), 96 | ("FOO_BAR", "foo_bar"), 97 | ("FOOBAR1", "foobar1"), 98 | ("FOOBAR_1", "foobar_1"), 99 | ("FOOBAR_123", "foobar_123"), 100 | ("FOO1BAR2", "foo1_bar2"), 101 | ("foo__bar", "foo_bar"), 102 | ("_foobar", "foobar"), 103 | ("foobaR", "fooba_r"), 104 | ("foo~bar", "foo_bar"), 105 | ("foo:bar", "foo_bar"), 106 | ("1foobar", "1_foobar"), 107 | ("GetUInt64", "get_u_int64"), 108 | ], 109 | ) 110 | def test_snake_case_strict(value, expected): 111 | actual = snake_case(value) 112 | assert actual == expected, f"{value} => {expected} (actual: {actual})" 113 | 114 | 115 | @pytest.mark.parametrize( 116 | ["value", "expected"], 117 | [ 118 | ("fooBar", "foo_bar"), 119 | ("FooBar", "foo_bar"), 120 | ("foo_Bar", "foo__bar"), 121 | ("foo__bar", "foo__bar"), 122 | ("FOOBar", "foo_bar"), 123 | ("__foo", "__foo"), 124 | ("GetUInt64", "get_u_int64"), 125 | ], 126 | ) 127 | def test_snake_case_not_strict(value, expected): 128 | actual = snake_case(value, strict=False) 129 | assert actual == expected, f"{value} => {expected} (actual: {actual})" 130 | -------------------------------------------------------------------------------- /src/betterproto/casing.py: -------------------------------------------------------------------------------- 1 | import keyword 2 | import re 3 | 4 | 5 | # Word delimiters and symbols that will not be preserved when re-casing. 6 | # language=PythonRegExp 7 | SYMBOLS = "[^a-zA-Z0-9]*" 8 | 9 | # Optionally capitalized word. 10 | # language=PythonRegExp 11 | WORD = "[A-Z]*[a-z]*[0-9]*" 12 | 13 | # Uppercase word, not followed by lowercase letters. 
14 | # language=PythonRegExp 15 | WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*" 16 | 17 | 18 | def safe_snake_case(value: str) -> str: 19 | """Snake case a value taking into account Python keywords.""" 20 | value = snake_case(value) 21 | value = sanitize_name(value) 22 | return value 23 | 24 | 25 | def snake_case(value: str, strict: bool = True) -> str: 26 | """ 27 | Join words with an underscore into lowercase and remove symbols. 28 | 29 | Parameters 30 | ----------- 31 | value: :class:`str` 32 | The value to convert. 33 | strict: :class:`bool` 34 | Whether or not to force single underscores. 35 | 36 | Returns 37 | -------- 38 | :class:`str` 39 | The value in snake_case. 40 | """ 41 | 42 | def substitute_word(symbols: str, word: str, is_start: bool) -> str: 43 | if not word: 44 | return "" 45 | if strict: 46 | delimiter_count = 0 if is_start else 1 # Single underscore if strict. 47 | elif is_start: 48 | delimiter_count = len(symbols) 49 | elif word.isupper() or word.islower(): 50 | delimiter_count = max( 51 | 1, len(symbols) 52 | ) # Preserve all delimiters if not strict. 53 | else: 54 | delimiter_count = len(symbols) + 1 # Extra underscore for leading capital. 55 | 56 | return ("_" * delimiter_count) + word.lower() 57 | 58 | snake = re.sub( 59 | f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})", 60 | lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None), 61 | value, 62 | ) 63 | return snake 64 | 65 | 66 | def pascal_case(value: str, strict: bool = True) -> str: 67 | """ 68 | Capitalize each word and remove symbols. 69 | 70 | Parameters 71 | ----------- 72 | value: :class:`str` 73 | The value to convert. 74 | strict: :class:`bool` 75 | Whether or not to output only alphanumeric characters. 76 | 77 | Returns 78 | -------- 79 | :class:`str` 80 | The value in PascalCase. 
81 | """ 82 | 83 | def substitute_word(symbols, word): 84 | if strict: 85 | return word.capitalize() # Remove all delimiters 86 | 87 | if word.islower(): 88 | delimiter_length = len(symbols[:-1]) # Lose one delimiter 89 | else: 90 | delimiter_length = len(symbols) # Preserve all delimiters 91 | 92 | return ("_" * delimiter_length) + word.capitalize() 93 | 94 | return re.sub( 95 | f"({SYMBOLS})({WORD_UPPER}|{WORD})", 96 | lambda groups: substitute_word(groups[1], groups[2]), 97 | value, 98 | ) 99 | 100 | 101 | def camel_case(value: str, strict: bool = True) -> str: 102 | """ 103 | Capitalize all words except first and remove symbols. 104 | 105 | Parameters 106 | ----------- 107 | value: :class:`str` 108 | The value to convert. 109 | strict: :class:`bool` 110 | Whether or not to output only alphanumeric characters. 111 | 112 | Returns 113 | -------- 114 | :class:`str` 115 | The value in camelCase. 116 | """ 117 | return lowercase_first(pascal_case(value, strict=strict)) 118 | 119 | 120 | def lowercase_first(value: str) -> str: 121 | """ 122 | Lower cases the first character of the value. 123 | 124 | Parameters 125 | ---------- 126 | value: :class:`str` 127 | The value to lower case. 128 | 129 | Returns 130 | ------- 131 | :class:`str` 132 | The lower cased string. 
133 | """ 134 | return value[0:1].lower() + value[1:] 135 | 136 | 137 | def sanitize_name(value: str) -> str: 138 | # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles 139 | if keyword.iskeyword(value): 140 | return f"{value}_" 141 | if not value.isidentifier(): 142 | return f"_{value}" 143 | return value 144 | -------------------------------------------------------------------------------- /tests/inputs/enum/test_enum.py: -------------------------------------------------------------------------------- 1 | from tests.output_betterproto.enum import ( 2 | ArithmeticOperator, 3 | Choice, 4 | Test, 5 | ) 6 | 7 | 8 | def test_enum_set_and_get(): 9 | assert Test(choice=Choice.ZERO).choice == Choice.ZERO 10 | assert Test(choice=Choice.ONE).choice == Choice.ONE 11 | assert Test(choice=Choice.THREE).choice == Choice.THREE 12 | assert Test(choice=Choice.FOUR).choice == Choice.FOUR 13 | 14 | 15 | def test_enum_set_with_int(): 16 | assert Test(choice=0).choice == Choice.ZERO 17 | assert Test(choice=1).choice == Choice.ONE 18 | assert Test(choice=3).choice == Choice.THREE 19 | assert Test(choice=4).choice == Choice.FOUR 20 | 21 | 22 | def test_enum_is_comparable_with_int(): 23 | assert Test(choice=Choice.ZERO).choice == 0 24 | assert Test(choice=Choice.ONE).choice == 1 25 | assert Test(choice=Choice.THREE).choice == 3 26 | assert Test(choice=Choice.FOUR).choice == 4 27 | 28 | 29 | def test_enum_to_dict(): 30 | assert "choice" not in Test(choice=Choice.ZERO).to_dict(), ( 31 | "Default enum value is not serialized" 32 | ) 33 | assert ( 34 | Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] 35 | == "ZERO" 36 | ) 37 | assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE" 38 | assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE" 39 | assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR" 40 | 41 | 42 | def test_repeated_enum_is_comparable_with_int(): 43 | assert Test(choices=[Choice.ZERO]).choices == [0] 44 | assert 
Test(choices=[Choice.ONE]).choices == [1] 45 | assert Test(choices=[Choice.THREE]).choices == [3] 46 | assert Test(choices=[Choice.FOUR]).choices == [4] 47 | 48 | 49 | def test_repeated_enum_set_and_get(): 50 | assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO] 51 | assert Test(choices=[Choice.ONE]).choices == [Choice.ONE] 52 | assert Test(choices=[Choice.THREE]).choices == [Choice.THREE] 53 | assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR] 54 | 55 | 56 | def test_repeated_enum_to_dict(): 57 | assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"] 58 | assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"] 59 | assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"] 60 | assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"] 61 | 62 | all_enums_dict = Test( 63 | choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR] 64 | ).to_dict() 65 | assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"] 66 | 67 | 68 | def test_repeated_enum_with_single_value_to_dict(): 69 | assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"] 70 | assert Test(choices=1).to_dict()["choices"] == ["ONE"] 71 | 72 | 73 | def test_repeated_enum_with_non_list_iterables_to_dict(): 74 | assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] 75 | assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] 76 | assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [ 77 | "ONE", 78 | "THREE", 79 | ] 80 | 81 | def enum_generator(): 82 | yield Choice.ONE 83 | yield Choice.THREE 84 | 85 | assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] 86 | 87 | 88 | def test_enum_mapped_on_parse(): 89 | # test default value 90 | b = Test().parse(bytes(Test())) 91 | assert b.choice.name == Choice.ZERO.name 92 | assert b.choices == [] 93 | 94 | # test non default value 95 | a = Test().parse(bytes(Test(choice=Choice.ONE))) 96 | assert a.choice.name == 
Choice.ONE.name 97 | assert b.choices == [] 98 | 99 | # test repeated 100 | c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR]))) 101 | assert c.choices[0].name == Choice.THREE.name 102 | assert c.choices[1].name == Choice.FOUR.name 103 | 104 | # bonus: defaults after empty init are also mapped 105 | assert Test().choice.name == Choice.ZERO.name 106 | 107 | 108 | def test_renamed_enum_members(): 109 | assert set(ArithmeticOperator.__members__) == { 110 | "NONE", 111 | "PLUS", 112 | "MINUS", 113 | "_0_PREFIXED", 114 | } 115 | -------------------------------------------------------------------------------- /benchmarks/benchmarks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import betterproto 5 | 6 | 7 | @dataclass 8 | class TestMessage(betterproto.Message): 9 | foo: int = betterproto.uint32_field(1) 10 | bar: str = betterproto.string_field(2) 11 | baz: float = betterproto.float_field(3) 12 | 13 | 14 | @dataclass 15 | class TestNestedChildMessage(betterproto.Message): 16 | str_key: str = betterproto.string_field(1) 17 | bytes_key: bytes = betterproto.bytes_field(2) 18 | bool_key: bool = betterproto.bool_field(3) 19 | float_key: float = betterproto.float_field(4) 20 | int_key: int = betterproto.uint64_field(5) 21 | 22 | 23 | @dataclass 24 | class TestNestedMessage(betterproto.Message): 25 | foo: TestNestedChildMessage = betterproto.message_field(1) 26 | bar: TestNestedChildMessage = betterproto.message_field(2) 27 | baz: TestNestedChildMessage = betterproto.message_field(3) 28 | 29 | 30 | @dataclass 31 | class TestRepeatedMessage(betterproto.Message): 32 | foo_repeat: List[str] = betterproto.string_field(1) 33 | bar_repeat: List[int] = betterproto.int64_field(2) 34 | baz_repeat: List[bool] = betterproto.bool_field(3) 35 | 36 | 37 | class BenchMessage: 38 | """Test creation and usage a proto message.""" 39 | 40 | def setup(self): 41 | self.cls = 
TestMessage 42 | self.instance = TestMessage() 43 | self.instance_filled = TestMessage(0, "test", 0.0) 44 | self.instance_filled_bytes = bytes(self.instance_filled) 45 | self.instance_filled_nested = TestNestedMessage( 46 | TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500), 47 | TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, 302), 48 | TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300), 49 | ) 50 | self.instance_filled_nested_bytes = bytes(self.instance_filled_nested) 51 | self.instance_filled_repeated = TestRepeatedMessage( 52 | [f"test{i}" for i in range(1_000)], 53 | [(i - 500) ** 3 for i in range(1_000)], 54 | [i % 2 == 0 for i in range(1_000)], 55 | ) 56 | self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated) 57 | 58 | def time_overhead(self): 59 | """Overhead in class definition.""" 60 | 61 | @dataclass 62 | class Message(betterproto.Message): 63 | foo: int = betterproto.uint32_field(1) 64 | bar: str = betterproto.string_field(2) 65 | baz: float = betterproto.float_field(3) 66 | 67 | def time_instantiation(self): 68 | """Time instantiation""" 69 | self.cls() 70 | 71 | def time_attribute_access(self): 72 | """Time to access an attribute""" 73 | self.instance.foo 74 | self.instance.bar 75 | self.instance.baz 76 | 77 | def time_init_with_values(self): 78 | """Time to set an attribute""" 79 | self.cls(0, "test", 0.0) 80 | 81 | def time_attribute_setting(self): 82 | """Time to set attributes""" 83 | self.instance.foo = 0 84 | self.instance.bar = "test" 85 | self.instance.baz = 0.0 86 | 87 | def time_serialize(self): 88 | """Time serializing a message to wire.""" 89 | bytes(self.instance_filled) 90 | 91 | def time_deserialize(self): 92 | """Time deserialize a message.""" 93 | TestMessage().parse(self.instance_filled_bytes) 94 | 95 | def time_serialize_nested(self): 96 | """Time serializing a nested message to wire.""" 97 | bytes(self.instance_filled_nested) 98 | 99 | def 
time_deserialize_nested(self): 100 | """Time deserialize a nested message.""" 101 | TestNestedMessage().parse(self.instance_filled_nested_bytes) 102 | 103 | def time_serialize_repeated(self): 104 | """Time serializing a repeated message to wire.""" 105 | bytes(self.instance_filled_repeated) 106 | 107 | def time_deserialize_repeated(self): 108 | """Time deserialize a repeated message.""" 109 | TestRepeatedMessage().parse(self.instance_filled_repeated_bytes) 110 | 111 | 112 | class MemSuite: 113 | def setup(self): 114 | self.cls = TestMessage 115 | 116 | def mem_instance(self): 117 | return self.cls() 118 | -------------------------------------------------------------------------------- /tests/streams/java/src/main/java/betterproto/Tests.java: -------------------------------------------------------------------------------- 1 | package betterproto; 2 | 3 | import betterproto.nested.NestedOuterClass; 4 | import betterproto.oneof.Oneof; 5 | 6 | import com.google.protobuf.CodedInputStream; 7 | import com.google.protobuf.CodedOutputStream; 8 | 9 | import java.io.FileInputStream; 10 | import java.io.FileOutputStream; 11 | import java.io.IOException; 12 | 13 | public class Tests { 14 | String path; 15 | 16 | public Tests(String path) { 17 | this.path = path; 18 | } 19 | 20 | public void testSingleVarint() throws IOException { 21 | // Read in the Python-generated single varint file 22 | FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out"); 23 | CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); 24 | 25 | int value = codedInput.readUInt32(); 26 | 27 | inputStream.close(); 28 | 29 | // Write the value back to a file 30 | FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out"); 31 | CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); 32 | 33 | codedOutput.writeUInt32NoTag(value); 34 | 35 | codedOutput.flush(); 36 | outputStream.close(); 37 | } 38 | 39 | public void 
testMultipleVarints() throws IOException { 40 | // Read in the Python-generated multiple varints file 41 | FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out"); 42 | CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); 43 | 44 | int value1 = codedInput.readUInt32(); 45 | int value2 = codedInput.readUInt32(); 46 | long value3 = codedInput.readUInt64(); 47 | 48 | inputStream.close(); 49 | 50 | // Write the values back to a file 51 | FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out"); 52 | CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); 53 | 54 | codedOutput.writeUInt32NoTag(value1); 55 | codedOutput.writeUInt32NoTag(value2); 56 | codedOutput.writeUInt64NoTag(value3); 57 | 58 | codedOutput.flush(); 59 | outputStream.close(); 60 | } 61 | 62 | public void testSingleMessage() throws IOException { 63 | // Read in the Python-generated single message file 64 | FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out"); 65 | CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); 66 | 67 | Oneof.Test message = Oneof.Test.parseFrom(codedInput); 68 | 69 | inputStream.close(); 70 | 71 | // Write the message back to a file 72 | FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out"); 73 | CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); 74 | 75 | message.writeTo(codedOutput); 76 | 77 | codedOutput.flush(); 78 | outputStream.close(); 79 | } 80 | 81 | public void testMultipleMessages() throws IOException { 82 | // Read in the Python-generated multi-message file 83 | FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out"); 84 | 85 | Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream); 86 | NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream); 87 | 88 | inputStream.close(); 89 | 90 | // Write
the messages back to a file 91 | FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out"); 92 | 93 | oneof.writeDelimitedTo(outputStream); 94 | nested.writeDelimitedTo(outputStream); 95 | 96 | outputStream.flush(); 97 | outputStream.close(); 98 | } 99 | 100 | public void testInfiniteMessages() throws IOException { 101 | // Read in as many messages as are present in the Python-generated file and write them back 102 | FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out"); 103 | FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out"); 104 | 105 | Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream); 106 | while (current != null) { 107 | current.writeDelimitedTo(outputStream); 108 | current = Oneof.Test.parseDelimitedFrom(inputStream); 109 | } 110 | 111 | inputStream.close(); 112 | outputStream.flush(); 113 | outputStream.close(); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "betterproto" 3 | version = "2.0.0b7" 4 | description = "A better Protobuf / gRPC generator & library" 5 | authors = [ 6 | {name = "Daniel G. Taylor", email = "danielgtaylor@gmail.com"} 7 | ] 8 | readme = "README.md" 9 | repository = "https://github.com/danielgtaylor/python-betterproto" 10 | keywords = ["protobuf", "gRPC"] 11 | license = "MIT" 12 | packages = [ 13 | { include = "betterproto", from = "src" } 14 | ] 15 | requires-python = ">=3.9,<4.0" 16 | dynamic = ["dependencies"] 17 | 18 | [tool.poetry.dependencies] 19 | # The Ruff version is pinned. 
To update it, also update it in .pre-commit-config.yaml 20 | ruff = { version = "~0.9.1", optional = true } 21 | grpclib = "^0.4.1" 22 | jinja2 = { version = ">=3.0.3", optional = true } 23 | python-dateutil = "^2.8" 24 | typing-extensions = "^4.7.1" 25 | betterproto-rust-codec = { version = "0.1.1", optional = true } 26 | 27 | [tool.poetry.group.dev.dependencies] 28 | asv = "^0.6.4" 29 | bpython = "^0.24" 30 | jinja2 = ">=3.0.3" 31 | mypy = "^1.11.2" 32 | sphinx = "7.4.7" 33 | sphinx-rtd-theme = "3.0.2" 34 | pre-commit = "^4.0.1" 35 | grpcio-tools = "^1.54.2" 36 | tox = "^4.0.0" 37 | 38 | [tool.poetry.group.test.dependencies] 39 | poethepoet = ">=0.9.0" 40 | pytest = "^7.4.4" 41 | pytest-asyncio = "^0.23.8" 42 | pytest-cov = "^6.0.0" 43 | pytest-mock = "^3.1.1" 44 | pydantic = ">=2.0,<3" 45 | protobuf = "^5" 46 | cachelib = "^0.13.0" 47 | tomlkit = ">=0.7.0" 48 | 49 | [project.scripts] 50 | protoc-gen-python_betterproto = "betterproto.plugin:main" 51 | 52 | [project.optional-dependencies] 53 | compiler = ["ruff", "jinja2"] 54 | rust-codec = ["betterproto-rust-codec"] 55 | 56 | [tool.ruff] 57 | extend-exclude = ["tests/output_*"] 58 | target-version = "py38" 59 | 60 | [tool.ruff.lint.isort] 61 | combine-as-imports = true 62 | lines-after-imports = 2 63 | 64 | # Dev workflow tasks 65 | 66 | [tool.poe.tasks.generate] 67 | script = "tests.generate:main" 68 | help = "Generate test cases (do this once before running test)" 69 | 70 | [tool.poe.tasks.test] 71 | cmd = "pytest" 72 | help = "Run tests" 73 | 74 | [tool.poe.tasks.types] 75 | cmd = "mypy src --ignore-missing-imports" 76 | help = "Check types with mypy" 77 | 78 | [tool.poe.tasks.format] 79 | sequence = ["_format", "_sort-imports"] 80 | help = "Format the source code, and sort the imports" 81 | 82 | [tool.poe.tasks.check] 83 | sequence = ["_check-format", "_check-imports"] 84 | help = "Check that the source code is formatted and the imports sorted" 85 | 86 | [tool.poe.tasks._format] 87 | cmd = "ruff format src 
tests" 88 | help = "Format the source code without sorting the imports" 89 | 90 | [tool.poe.tasks._sort-imports] 91 | cmd = "ruff check --select I --fix src tests" 92 | help = "Sort the imports" 93 | 94 | [tool.poe.tasks._check-format] 95 | cmd = "ruff format --diff src tests" 96 | help = "Check that the source code is formatted" 97 | 98 | [tool.poe.tasks._check-imports] 99 | cmd = "ruff check --select I src tests" 100 | help = "Check that the imports are sorted" 101 | 102 | [tool.poe.tasks.docs] 103 | cmd = "sphinx-build docs docs/build" 104 | help = "Build the sphinx docs" 105 | 106 | [tool.poe.tasks.bench] 107 | shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD" 108 | help = "Benchmark current commit vs. master branch" 109 | 110 | [tool.poe.tasks.clean] 111 | cmd = """ 112 | rm -rf .asv .coverage .mypy_cache .pytest_cache 113 | dist betterproto.egg-info **/__pycache__ 114 | testsoutput_* 115 | """ 116 | help = "Clean out generated files from the workspace" 117 | 118 | [tool.poe.tasks.generate_lib] 119 | cmd = """ 120 | protoc 121 | --plugin=protoc-gen-custom=src/betterproto/plugin/main.py 122 | --custom_opt=INCLUDE_GOOGLE 123 | --custom_out=src/betterproto/lib/std 124 | -I /usr/local/include/ 125 | /usr/local/include/google/protobuf/**/*.proto 126 | """ 127 | help = "Regenerate the types in betterproto.lib.std.google" 128 | 129 | # CI tasks 130 | 131 | [tool.poe.tasks.full-test] 132 | shell = "poe generate && tox" 133 | help = "Run tests with multiple pythons" 134 | 135 | [tool.doc8] 136 | paths = ["docs"] 137 | max_line_length = 88 138 | 139 | [tool.doc8.ignore_path_errors] 140 | "docs/migrating.rst" = [ 141 | "D001", # contains table which is longer than 88 characters long 142 | ] 143 | 144 | [tool.coverage.run] 145 | omit = ["betterproto/tests/*"] 146 | 147 | [tool.tox] 148 | legacy_tox_ini = """ 149 | [tox] 150 | requires = 151 | tox>=4.2 152 | tox-poetry-installer[poetry]==1.0.0b1 153 | env_list = 154 | py311 155 | py38 156 | py37 157 | 
158 | [testenv] 159 | commands = 160 | pytest {posargs: --cov betterproto} 161 | poetry_dep_groups = 162 | test 163 | require_locked_deps = true 164 | require_poetry = true 165 | """ 166 | 167 | [build-system] 168 | requires = ["poetry-core>=2.0.0,<3"] 169 | build-backend = "poetry.core.masonry.api" 170 | -------------------------------------------------------------------------------- /src/betterproto/plugin/module_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from dataclasses import ( 4 | dataclass, 5 | field, 6 | ) 7 | from typing import ( 8 | Dict, 9 | Iterator, 10 | List, 11 | Tuple, 12 | ) 13 | 14 | 15 | @dataclass 16 | class ModuleValidator: 17 | line_iterator: Iterator[str] 18 | line_number: int = field(init=False, default=0) 19 | 20 | collisions: Dict[str, List[Tuple[int, str]]] = field( 21 | init=False, default_factory=lambda: defaultdict(list) 22 | ) 23 | 24 | def add_import(self, imp: str, number: int, full_line: str): 25 | """ 26 | Adds an import to be tracked. 27 | """ 28 | self.collisions[imp].append((number, full_line)) 29 | 30 | def process_import(self, imp: str): 31 | """ 32 | Filters out the import to its actual value. 33 | """ 34 | if " as " in imp: 35 | imp = imp[imp.index(" as ") + 4 :] 36 | 37 | imp = imp.strip() 38 | assert " " not in imp, imp 39 | return imp 40 | 41 | def evaluate_multiline_import(self, line: str): 42 | """ 43 | Evaluates a multiline import from a starting line 44 | """ 45 | # Filter the first line and remove anything before the import statement. 46 | full_line = line 47 | line = line.split("import", 1)[1] 48 | if "(" in line: 49 | conditional = lambda line: ")" not in line 50 | else: 51 | conditional = lambda line: "\\" in line 52 | 53 | # Remove open parenthesis if it exists. 54 | if "(" in line: 55 | line = line[line.index("(") + 1 :] 56 | 57 | # Choose the conditional based on how multiline imports are formatted. 
58 | while conditional(line): 59 | # Split the line by commas 60 | imports = line.split(",") 61 | 62 | for imp in imports: 63 | # Add the import to the namespace 64 | imp = self.process_import(imp) 65 | if imp: 66 | self.add_import(imp, self.line_number, full_line) 67 | # Get the next line 68 | full_line = line = next(self.line_iterator) 69 | # Increment the line number 70 | self.line_number += 1 71 | 72 | # validate the last line 73 | if ")" in line: 74 | line = line[: line.index(")")] 75 | imports = line.split(",") 76 | for imp in imports: 77 | imp = self.process_import(imp) 78 | if imp: 79 | self.add_import(imp, self.line_number, full_line) 80 | 81 | def evaluate_import(self, line: str): 82 | """ 83 | Extracts an import from a line. 84 | """ 85 | whole_line = line 86 | line = line[line.index("import") + 6 :] 87 | values = line.split(",") 88 | for v in values: 89 | self.add_import(self.process_import(v), self.line_number, whole_line) 90 | 91 | def next(self): 92 | """ 93 | Evaluate each line for names in the module. 94 | """ 95 | line = next(self.line_iterator) 96 | 97 | # Skip lines with indentation or comments 98 | if ( 99 | # Skip indents and whitespace. 100 | line.startswith(" ") 101 | or line == "\n" 102 | or line.startswith("\t") 103 | or 104 | # Skip comments 105 | line.startswith("#") 106 | or 107 | # Skip decorators 108 | line.startswith("@") 109 | ): 110 | self.line_number += 1 111 | return 112 | 113 | # Skip docstrings. 114 | if line.startswith('"""') or line.startswith("'''"): 115 | quote = line[0] * 3 116 | line = line[3:] 117 | while quote not in line: 118 | line = next(self.line_iterator) 119 | self.line_number += 1 120 | return 121 | 122 | # Evaluate Imports. 123 | if line.startswith("from ") or line.startswith("import "): 124 | if "(" in line or "\\" in line: 125 | self.evaluate_multiline_import(line) 126 | else: 127 | self.evaluate_import(line) 128 | 129 | # Evaluate Classes. 
130 | elif line.startswith("class "): 131 | class_name = re.search(r"class (\w+)", line).group(1) 132 | if class_name: 133 | self.add_import(class_name, self.line_number, line) 134 | 135 | # Evaluate Functions. 136 | elif line.startswith("def "): 137 | function_name = re.search(r"def (\w+)", line).group(1) 138 | if function_name: 139 | self.add_import(function_name, self.line_number, line) 140 | 141 | # Evaluate direct assignments. 142 | elif "=" in line: 143 | assignment = re.search(r"(\w+)\s*=", line).group(1) 144 | if assignment: 145 | self.add_import(assignment, self.line_number, line) 146 | 147 | self.line_number += 1 148 | 149 | def validate(self) -> bool: 150 | """ 151 | Run Validation. 152 | """ 153 | try: 154 | while True: 155 | self.next() 156 | except StopIteration: 157 | pass 158 | 159 | # Filter collisions for those with more than one value. 160 | self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1} 161 | 162 | # Return True if no collisions are found. 
163 | return not bool(self.collisions) 164 | -------------------------------------------------------------------------------- /src/betterproto/plugin/typing_compiler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import defaultdict 3 | from dataclasses import ( 4 | dataclass, 5 | field, 6 | ) 7 | from typing import ( 8 | Dict, 9 | Iterator, 10 | Optional, 11 | Set, 12 | ) 13 | 14 | 15 | class TypingCompiler(metaclass=abc.ABCMeta): 16 | @abc.abstractmethod 17 | def optional(self, type: str) -> str: 18 | raise NotImplementedError() 19 | 20 | @abc.abstractmethod 21 | def list(self, type: str) -> str: 22 | raise NotImplementedError() 23 | 24 | @abc.abstractmethod 25 | def dict(self, key: str, value: str) -> str: 26 | raise NotImplementedError() 27 | 28 | @abc.abstractmethod 29 | def union(self, *types: str) -> str: 30 | raise NotImplementedError() 31 | 32 | @abc.abstractmethod 33 | def iterable(self, type: str) -> str: 34 | raise NotImplementedError() 35 | 36 | @abc.abstractmethod 37 | def async_iterable(self, type: str) -> str: 38 | raise NotImplementedError() 39 | 40 | @abc.abstractmethod 41 | def async_iterator(self, type: str) -> str: 42 | raise NotImplementedError() 43 | 44 | @abc.abstractmethod 45 | def imports(self) -> Dict[str, Optional[Set[str]]]: 46 | """ 47 | Returns either the direct import as a key with none as value, or a set of 48 | values to import from the key. 
49 | """ 50 | raise NotImplementedError() 51 | 52 | def import_lines(self) -> Iterator: 53 | imports = self.imports() 54 | for key, value in imports.items(): 55 | if value is None: 56 | yield f"import {key}" 57 | else: 58 | yield f"from {key} import (" 59 | for v in sorted(value): 60 | yield f" {v}," 61 | yield ")" 62 | 63 | 64 | @dataclass 65 | class DirectImportTypingCompiler(TypingCompiler): 66 | _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) 67 | 68 | def optional(self, type: str) -> str: 69 | self._imports["typing"].add("Optional") 70 | return f"Optional[{type}]" 71 | 72 | def list(self, type: str) -> str: 73 | self._imports["typing"].add("List") 74 | return f"List[{type}]" 75 | 76 | def dict(self, key: str, value: str) -> str: 77 | self._imports["typing"].add("Dict") 78 | return f"Dict[{key}, {value}]" 79 | 80 | def union(self, *types: str) -> str: 81 | self._imports["typing"].add("Union") 82 | return f"Union[{', '.join(types)}]" 83 | 84 | def iterable(self, type: str) -> str: 85 | self._imports["typing"].add("Iterable") 86 | return f"Iterable[{type}]" 87 | 88 | def async_iterable(self, type: str) -> str: 89 | self._imports["typing"].add("AsyncIterable") 90 | return f"AsyncIterable[{type}]" 91 | 92 | def async_iterator(self, type: str) -> str: 93 | self._imports["typing"].add("AsyncIterator") 94 | return f"AsyncIterator[{type}]" 95 | 96 | def imports(self) -> Dict[str, Optional[Set[str]]]: 97 | return {k: v if v else None for k, v in self._imports.items()} 98 | 99 | 100 | @dataclass 101 | class TypingImportTypingCompiler(TypingCompiler): 102 | _imported: bool = False 103 | 104 | def optional(self, type: str) -> str: 105 | self._imported = True 106 | return f"typing.Optional[{type}]" 107 | 108 | def list(self, type: str) -> str: 109 | self._imported = True 110 | return f"typing.List[{type}]" 111 | 112 | def dict(self, key: str, value: str) -> str: 113 | self._imported = True 114 | return f"typing.Dict[{key}, {value}]" 115 | 116 
| def union(self, *types: str) -> str: 117 | self._imported = True 118 | return f"typing.Union[{', '.join(types)}]" 119 | 120 | def iterable(self, type: str) -> str: 121 | self._imported = True 122 | return f"typing.Iterable[{type}]" 123 | 124 | def async_iterable(self, type: str) -> str: 125 | self._imported = True 126 | return f"typing.AsyncIterable[{type}]" 127 | 128 | def async_iterator(self, type: str) -> str: 129 | self._imported = True 130 | return f"typing.AsyncIterator[{type}]" 131 | 132 | def imports(self) -> Dict[str, Optional[Set[str]]]: 133 | if self._imported: 134 | return {"typing": None} 135 | return {} 136 | 137 | 138 | @dataclass 139 | class NoTyping310TypingCompiler(TypingCompiler): 140 | _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) 141 | 142 | @staticmethod 143 | def _fmt(type: str) -> str: # for now this is necessary till 3.14 144 | if type.startswith('"'): 145 | return type[1:-1] 146 | return type 147 | 148 | def optional(self, type: str) -> str: 149 | return f'"{self._fmt(type)} | None"' 150 | 151 | def list(self, type: str) -> str: 152 | return f'"list[{self._fmt(type)}]"' 153 | 154 | def dict(self, key: str, value: str) -> str: 155 | return f'"dict[{key}, {self._fmt(value)}]"' 156 | 157 | def union(self, *types: str) -> str: 158 | return f'"{" | ".join(map(self._fmt, types))}"' 159 | 160 | def iterable(self, type: str) -> str: 161 | self._imports["collections.abc"].add("Iterable") 162 | return f'"Iterable[{type}]"' 163 | 164 | def async_iterable(self, type: str) -> str: 165 | self._imports["collections.abc"].add("AsyncIterable") 166 | return f'"AsyncIterable[{type}]"' 167 | 168 | def async_iterator(self, type: str) -> str: 169 | self._imports["collections.abc"].add("AsyncIterator") 170 | return f'"AsyncIterator[{type}]"' 171 | 172 | def imports(self) -> Dict[str, Optional[Set[str]]]: 173 | return {k: v if v else None for k, v in self._imports.items()} 174 | 
-------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import atexit 3 | import importlib 4 | import os 5 | import platform 6 | import sys 7 | import tempfile 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from types import ModuleType 11 | from typing import ( 12 | Callable, 13 | Dict, 14 | Generator, 15 | List, 16 | Optional, 17 | Tuple, 18 | Union, 19 | ) 20 | 21 | 22 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 23 | 24 | root_path = Path(__file__).resolve().parent 25 | inputs_path = root_path.joinpath("inputs") 26 | output_path_reference = root_path.joinpath("output_reference") 27 | output_path_betterproto = root_path.joinpath("output_betterproto") 28 | output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic") 29 | 30 | 31 | def get_files(path, suffix: str) -> Generator[str, None, None]: 32 | for r, dirs, files in os.walk(path): 33 | for filename in [f for f in files if f.endswith(suffix)]: 34 | yield os.path.join(r, filename) 35 | 36 | 37 | def get_directories(path): 38 | for root, directories, files in os.walk(path): 39 | yield from directories 40 | 41 | 42 | async def protoc( 43 | path: Union[str, Path], 44 | output_dir: Union[str, Path], 45 | reference: bool = False, 46 | pydantic_dataclasses: bool = False, 47 | ): 48 | path: Path = Path(path).resolve() 49 | output_dir: Path = Path(output_dir).resolve() 50 | python_out_option: str = "python_betterproto_out" if not reference else "python_out" 51 | 52 | if pydantic_dataclasses: 53 | plugin_path = Path("src/betterproto/plugin/main.py") 54 | 55 | if "Win" in platform.system(): 56 | with tempfile.NamedTemporaryFile( 57 | "w", encoding="UTF-8", suffix=".bat", delete=False 58 | ) as tf: 59 | # See https://stackoverflow.com/a/42622705 60 | tf.writelines( 61 | [ 62 | "@echo off", 63 | f"\nchdir 
{os.getcwd()}", 64 | f"\n{sys.executable} -u {plugin_path.as_posix()}", 65 | ] 66 | ) 67 | 68 | tf.flush() 69 | 70 | plugin_path = Path(tf.name) 71 | atexit.register(os.remove, plugin_path) 72 | 73 | command = [ 74 | sys.executable, 75 | "-m", 76 | "grpc.tools.protoc", 77 | f"--plugin=protoc-gen-custom={plugin_path.as_posix()}", 78 | "--experimental_allow_proto3_optional", 79 | "--custom_opt=pydantic_dataclasses", 80 | f"--proto_path={path.as_posix()}", 81 | f"--custom_out={output_dir.as_posix()}", 82 | *[p.as_posix() for p in path.glob("*.proto")], 83 | ] 84 | else: 85 | command = [ 86 | sys.executable, 87 | "-m", 88 | "grpc.tools.protoc", 89 | f"--proto_path={path.as_posix()}", 90 | f"--{python_out_option}={output_dir.as_posix()}", 91 | *[p.as_posix() for p in path.glob("*.proto")], 92 | ] 93 | proc = await asyncio.create_subprocess_exec( 94 | *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE 95 | ) 96 | stdout, stderr = await proc.communicate() 97 | return stdout, stderr, proc.returncode 98 | 99 | 100 | @dataclass 101 | class TestCaseJsonFile: 102 | json: str 103 | test_name: str 104 | file_name: str 105 | 106 | def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]): 107 | return self.file_name in non_symmetrical_json.get(self.test_name, tuple()) 108 | 109 | 110 | def get_test_case_json_data( 111 | test_case_name: str, *json_file_names: str 112 | ) -> List[TestCaseJsonFile]: 113 | """ 114 | :return: 115 | A list of all files found in "{inputs_path}/test_case_name" with names matching 116 | f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by 117 | json_file_names 118 | """ 119 | test_case_dir = inputs_path.joinpath(test_case_name) 120 | possible_file_paths = [ 121 | *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names), 122 | test_case_dir.joinpath(f"{test_case_name}.json"), 123 | *test_case_dir.glob(f"{test_case_name}_*.json"), 124 | ] 125 | 126 | result = [] 127 | for 
test_data_file_path in possible_file_paths: 128 | if not test_data_file_path.exists(): 129 | continue 130 | with test_data_file_path.open("r") as fh: 131 | result.append( 132 | TestCaseJsonFile( 133 | fh.read(), test_case_name, test_data_file_path.name.split(".")[0] 134 | ) 135 | ) 136 | 137 | return result 138 | 139 | 140 | def find_module( 141 | module: ModuleType, predicate: Callable[[ModuleType], bool] 142 | ) -> Optional[ModuleType]: 143 | """ 144 | Recursively search module tree for a module that matches the search predicate. 145 | Assumes that the submodules are directories containing __init__.py. 146 | 147 | Example: 148 | 149 | # find module inside foo that contains Test 150 | import foo 151 | test_module = find_module(foo, lambda m: hasattr(m, 'Test')) 152 | """ 153 | if predicate(module): 154 | return module 155 | 156 | module_path = Path(*module.__path__) 157 | 158 | for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]: 159 | if sub == module_path: 160 | continue 161 | sub_module_path = sub.relative_to(module_path) 162 | sub_module_name = ".".join(sub_module_path.parts) 163 | 164 | sub_module = importlib.import_module(f".{sub_module_name}", module.__name__) 165 | 166 | if predicate(sub_module): 167 | return sub_module 168 | 169 | return None 170 | -------------------------------------------------------------------------------- /src/betterproto/grpc/grpclib_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from abc import ABC 3 | from typing import ( 4 | TYPE_CHECKING, 5 | AsyncIterable, 6 | AsyncIterator, 7 | Collection, 8 | Iterable, 9 | Mapping, 10 | Optional, 11 | Tuple, 12 | Type, 13 | Union, 14 | ) 15 | 16 | import grpclib.const 17 | 18 | 19 | if TYPE_CHECKING: 20 | from grpclib.client import Channel 21 | from grpclib.metadata import Deadline 22 | 23 | from .._types import ( 24 | ST, 25 | IProtoMessage, 26 | Message, 27 | T, 28 | ) 29 | 30 | 31 | Value = Union[str, 
bytes] 32 | MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] 33 | MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] 34 | 35 | 36 | class ServiceStub(ABC): 37 | """ 38 | Base class for async gRPC clients. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | channel: "Channel", 44 | *, 45 | timeout: Optional[float] = None, 46 | deadline: Optional["Deadline"] = None, 47 | metadata: Optional[MetadataLike] = None, 48 | ) -> None: 49 | self.channel = channel 50 | self.timeout = timeout 51 | self.deadline = deadline 52 | self.metadata = metadata 53 | 54 | def __resolve_request_kwargs( 55 | self, 56 | timeout: Optional[float], 57 | deadline: Optional["Deadline"], 58 | metadata: Optional[MetadataLike], 59 | ): 60 | return { 61 | "timeout": self.timeout if timeout is None else timeout, 62 | "deadline": self.deadline if deadline is None else deadline, 63 | "metadata": self.metadata if metadata is None else metadata, 64 | } 65 | 66 | async def _unary_unary( 67 | self, 68 | route: str, 69 | request: "IProtoMessage", 70 | response_type: Type["T"], 71 | *, 72 | timeout: Optional[float] = None, 73 | deadline: Optional["Deadline"] = None, 74 | metadata: Optional[MetadataLike] = None, 75 | ) -> "T": 76 | """Make a unary request and return the response.""" 77 | async with self.channel.request( 78 | route, 79 | grpclib.const.Cardinality.UNARY_UNARY, 80 | type(request), 81 | response_type, 82 | **self.__resolve_request_kwargs(timeout, deadline, metadata), 83 | ) as stream: 84 | await stream.send_message(request, end=True) 85 | response = await stream.recv_message() 86 | assert response is not None 87 | return response 88 | 89 | async def _unary_stream( 90 | self, 91 | route: str, 92 | request: "IProtoMessage", 93 | response_type: Type["T"], 94 | *, 95 | timeout: Optional[float] = None, 96 | deadline: Optional["Deadline"] = None, 97 | metadata: Optional[MetadataLike] = None, 98 | ) -> AsyncIterator["T"]: 99 | """Make a unary request and 
return the stream response iterator.""" 100 | async with self.channel.request( 101 | route, 102 | grpclib.const.Cardinality.UNARY_STREAM, 103 | type(request), 104 | response_type, 105 | **self.__resolve_request_kwargs(timeout, deadline, metadata), 106 | ) as stream: 107 | await stream.send_message(request, end=True) 108 | async for message in stream: 109 | yield message 110 | 111 | async def _stream_unary( 112 | self, 113 | route: str, 114 | request_iterator: MessageSource, 115 | request_type: Type["IProtoMessage"], 116 | response_type: Type["T"], 117 | *, 118 | timeout: Optional[float] = None, 119 | deadline: Optional["Deadline"] = None, 120 | metadata: Optional[MetadataLike] = None, 121 | ) -> "T": 122 | """Make a stream request and return the response.""" 123 | async with self.channel.request( 124 | route, 125 | grpclib.const.Cardinality.STREAM_UNARY, 126 | request_type, 127 | response_type, 128 | **self.__resolve_request_kwargs(timeout, deadline, metadata), 129 | ) as stream: 130 | await stream.send_request() 131 | await self._send_messages(stream, request_iterator) 132 | response = await stream.recv_message() 133 | assert response is not None 134 | return response 135 | 136 | async def _stream_stream( 137 | self, 138 | route: str, 139 | request_iterator: MessageSource, 140 | request_type: Type["IProtoMessage"], 141 | response_type: Type["T"], 142 | *, 143 | timeout: Optional[float] = None, 144 | deadline: Optional["Deadline"] = None, 145 | metadata: Optional[MetadataLike] = None, 146 | ) -> AsyncIterator["T"]: 147 | """ 148 | Make a stream request and return an AsyncIterator to iterate over response 149 | messages. 
150 | """ 151 | async with self.channel.request( 152 | route, 153 | grpclib.const.Cardinality.STREAM_STREAM, 154 | request_type, 155 | response_type, 156 | **self.__resolve_request_kwargs(timeout, deadline, metadata), 157 | ) as stream: 158 | await stream.send_request() 159 | sending_task = asyncio.ensure_future( 160 | self._send_messages(stream, request_iterator) 161 | ) 162 | try: 163 | async for response in stream: 164 | yield response 165 | except: 166 | sending_task.cancel() 167 | raise 168 | 169 | @staticmethod 170 | async def _send_messages(stream, messages: MessageSource): 171 | if isinstance(messages, AsyncIterable): 172 | async for message in messages: 173 | await stream.send_message(message) 174 | else: 175 | for message in messages: 176 | await stream.send_message(message) 177 | await stream.end() 178 | -------------------------------------------------------------------------------- /docs/quick-start.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | Installation 5 | ++++++++++++ 6 | 7 | Installation from PyPI is as simple as running: 8 | 9 | .. code-block:: sh 10 | 11 | python3 -m pip install -U betterproto 12 | 13 | If you are using Windows, then the following should be used instead: 14 | 15 | .. code-block:: sh 16 | 17 | py -3 -m pip install -U betterproto 18 | 19 | To include the protoc plugin, install betterproto[compiler] instead of betterproto, 20 | e.g. 21 | 22 | .. code-block:: sh 23 | 24 | python3 -m pip install -U "betterproto[compiler]" 25 | 26 | Compiling proto files 27 | +++++++++++++++++++++ 28 | 29 | 30 | Given you installed the compiler and have a proto file, e.g ``example.proto``: 31 | 32 | .. code-block:: proto 33 | 34 | syntax = "proto3"; 35 | 36 | package hello; 37 | 38 | // Greeting represents a message you can tell a user. 
39 | message Greeting { 40 | string message = 1; 41 | } 42 | 43 | To compile the proto, you can invoke protoc directly: 44 | 45 | 46 | 47 | .. code-block:: sh 48 | 49 | mkdir lib 50 | protoc -I . --python_betterproto_out=lib example.proto 51 | 52 | or run the following to invoke protoc via grpcio-tools: 53 | 54 | .. code-block:: sh 55 | 56 | pip install grpcio-tools 57 | python -m grpc_tools.protoc -I . --python_betterproto_out=lib example.proto 58 | 59 | 60 | This will generate ``lib/__init__.py`` which looks like: 61 | 62 | .. code-block:: python 63 | 64 | # Generated by the protocol buffer compiler. DO NOT EDIT! 65 | # sources: example.proto 66 | # plugin: python-betterproto 67 | from dataclasses import dataclass 68 | 69 | import betterproto 70 | 71 | 72 | @dataclass 73 | class Greeting(betterproto.Message): 74 | """Greeting represents a message you can tell a user.""" 75 | 76 | message: str = betterproto.string_field(1) 77 | 78 | 79 | Then to use it: 80 | 81 | .. code-block:: python 82 | 83 | >>> from lib import Greeting 84 | 85 | >>> test = Greeting() 86 | >>> test 87 | Greeting(message='') 88 | 89 | >>> test.message = "Hey!" 90 | >>> test 91 | Greeting(message='Hey!') 92 | 93 | >>> bytes(test) 94 | b'\n\x04Hey!' 95 | >>> Greeting().parse(bytes(test)) 96 | Greeting(message='Hey!') 97 | 98 | 99 | Async gRPC Support 100 | ++++++++++++++++++ 101 | 102 | The generated code includes `grpclib <https://github.com/vmagamedov/grpclib>`_ based 103 | stub (client and server) classes for RPC services declared in the input proto files. 104 | It is enabled by default. 105 | 106 | 107 | Given a service definition similar to the one below: 108 | 109 | ..
code-block:: proto 110 | 111 | syntax = "proto3"; 112 | 113 | package echo; 114 | 115 | message EchoRequest { 116 | string value = 1; 117 | // Number of extra times to echo 118 | uint32 extra_times = 2; 119 | } 120 | 121 | message EchoResponse { 122 | repeated string values = 1; 123 | } 124 | 125 | message EchoStreamResponse { 126 | string value = 1; 127 | } 128 | 129 | service Echo { 130 | rpc Echo(EchoRequest) returns (EchoResponse); 131 | rpc EchoStream(EchoRequest) returns (stream EchoStreamResponse); 132 | } 133 | 134 | The generated client can be used like so: 135 | 136 | .. code-block:: python 137 | 138 | import asyncio 139 | from grpclib.client import Channel 140 | import echo 141 | 142 | 143 | async def main(): 144 | channel = Channel(host="127.0.0.1", port=50051) 145 | service = echo.EchoStub(channel) 146 | request = echo.EchoRequest(value="hello", extra_times=1) 147 | print(await service.echo(request)) 148 | 149 | async for response in service.echo_stream(request): 150 | print(response) 151 | 152 | # don't forget to close the channel when you're done! 153 | channel.close() 154 | 155 | asyncio.run(main()) # Python 3.7+ 156 | 157 | # outputs 158 | EchoResponse(values=['hello', 'hello']) 159 | EchoStreamResponse(value='hello') 160 | EchoStreamResponse(value='hello') 161 | 162 | 163 | The server-facing stubs can be used to implement a Python 164 | gRPC server. 165 | To use them, simply subclass the base class in the generated files and override the 166 | service methods: 167 | 168 | ..
code-block:: python 169 | 170 | from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse 171 | from grpclib.server import Server 172 | from typing import AsyncIterator 173 | 174 | 175 | class EchoService(EchoBase): 176 | async def echo(self, echo_request: "EchoRequest") -> "EchoResponse": 177 | return EchoResponse([echo_request.value] * (echo_request.extra_times + 1)) 178 | 179 | async def echo_stream( 180 | self, echo_request: "EchoRequest" 181 | ) -> AsyncIterator["EchoStreamResponse"]: 182 | for _ in range(echo_request.extra_times + 1): 183 | yield EchoStreamResponse(value=echo_request.value) 184 | 185 | 186 | async def start_server(): 187 | HOST = "127.0.0.1" 188 | PORT = 1337 189 | server = Server([EchoService()]) 190 | await server.start(HOST, PORT) 191 | await server.serve_forever() 192 | 193 | JSON 194 | ++++ 195 | Message objects include :meth:`betterproto.Message.to_json` and 196 | :meth:`betterproto.Message.from_json` methods for JSON (de)serialization, and 197 | :meth:`betterproto.Message.to_dict`, :meth:`betterproto.Message.from_dict` for 198 | converting back and forth from JSON serializable dicts. 199 | 200 | For compatibility, the default is to convert field names to 201 | :attr:`betterproto.Casing.CAMEL`. You can control this behavior by passing a 202 | different casing value, e.g.: 203 | 204 | .. code-block:: python 205 | 206 | @dataclass 207 | class MyMessage(betterproto.Message): 208 | a_long_field_name: str = betterproto.string_field(1) 209 | 210 | 211 | >>> test = MyMessage(a_long_field_name="Hello World!") 212 | >>> test.to_dict(betterproto.Casing.SNAKE) 213 | {'a_long_field_name': 'Hello World!'} 214 | >>> test.to_dict(betterproto.Casing.CAMEL) 215 | {'aLongFieldName': 'Hello World!'} 216 | 217 | >>> test.to_json(indent=2) 218 | '{\n "aLongFieldName": "Hello World!"\n}' 219 | 220 | >>> test.from_dict({"aLongFieldName": "Goodbye World!"}) 221 | >>> test.a_long_field_name 222 | 'Goodbye World!'
# 223 |
# --------------------------------------------------------------------------------
# /src/betterproto/enum.py:
# --------------------------------------------------------------------------------
"""Enumeration base class and metaclass used for generated protobuf enums."""

from __future__ import annotations

from enum import (
    EnumMeta,
    IntEnum,
)
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Optional,
    Tuple,
)


if TYPE_CHECKING:
    from collections.abc import (
        Generator,
        Mapping,
    )

    from typing_extensions import (
        Never,
        Self,
    )


def _is_descriptor(obj: object) -> bool:
    """Return whether ``obj`` implements any part of the descriptor protocol.

    Such objects (functions, classmethods, properties, ...) are kept as regular
    class attributes instead of being turned into enum members.
    """
    return (
        hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
    )


class EnumType(EnumMeta if TYPE_CHECKING else type):
    """Metaclass of :class:`Enum`.

    For every enum class it builds, a dedicated ``<Name>Type`` metaclass is
    generated. That metaclass holds the value/member lookup tables and carries
    the members themselves as attributes.
    """

    # Populated per enum class on the generated ``<Name>Type`` metaclass.
    _value_map_: Mapping[int, Enum]
    _member_map_: Mapping[str, Enum]

    def __new__(
        mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
    ) -> Self:
        value_map = {}
        member_map = {}

        new_mcs = type(
            f"{name}Type",
            tuple(
                dict.fromkeys(
                    [base.__class__ for base in bases if base.__class__ is not type]
                    + [EnumType, type]
                )
            ),  # reorder the bases so EnumType and type are last to avoid conflicts
            {"_value_map_": value_map, "_member_map_": member_map},
        )

        # Every plain (non-descriptor, non-dunder) class attribute is a member.
        members = {
            member_name: value
            for member_name, value in namespace.items()
            if not _is_descriptor(value) and not member_name.startswith("__")
        }

        cls = type.__new__(
            new_mcs,
            name,
            bases,
            {key: value for key, value in namespace.items() if key not in members},
        )
        # this allows us to disallow member access from other members as
        # members become proper class variables

        # A distinct loop variable keeps the ``name`` argument (the class name)
        # from being shadowed.
        for member_name, value in members.items():
            # Names sharing a value alias the first member created for it.
            member = value_map.get(value)
            if member is None:
                member = cls.__new__(cls, name=member_name, value=value)  # type: ignore
                value_map[value] = member
            member_map[member_name] = member
            type.__setattr__(new_mcs, member_name, member)

        return cls

    if not TYPE_CHECKING:

        def __call__(cls, value: int) -> Enum:
            try:
                return cls._value_map_[value]
            except (KeyError, TypeError):
                raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None

        def __iter__(cls) -> Generator[Enum, None, None]:
            yield from cls._member_map_.values()

        def __reversed__(cls) -> Generator[Enum, None, None]:
            yield from reversed(cls._member_map_.values())

        def __getitem__(cls, key: str) -> Enum:
            return cls._member_map_[key]

        @property
        def __members__(cls) -> MappingProxyType[str, Enum]:
            return MappingProxyType(cls._member_map_)

        def __repr__(cls) -> str:
            # Mirror the standard library's class repr, e.g. ``<enum 'Color'>``.
            return f"<enum {cls.__name__!r}>"

        def __len__(cls) -> int:
            return len(cls._member_map_)

        def __setattr__(cls, name: str, value: Any) -> Never:
            raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")

        def __delattr__(cls, name: str) -> Never:
            raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")

        def __contains__(cls, member: object) -> bool:
            return isinstance(member, cls) and member.name in cls._member_map_


class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
    """
    The base class for protobuf enumerations, all generated enumerations will
    inherit from this. Emulates `enum.IntEnum`.
    """

    name: Optional[str]
    value: int

    if not TYPE_CHECKING:

        def __new__(cls, *, name: Optional[str], value: int) -> Self:
            self = super().__new__(cls, value)
            # Go through the parent __setattr__: ours forbids assignment.
            super().__setattr__(self, "name", name)
            super().__setattr__(self, "value", value)
            return self

    def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]:
        # Pickle support: ``__new__`` only accepts keyword arguments.
        return (), {"name": self.name, "value": self.value}

    def __str__(self) -> str:
        return self.name or "None"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __setattr__(self, key: str, value: Any) -> Never:
        raise AttributeError(
            f"{self.__class__.__name__} Cannot reassign a member's attributes."
        )

    def __delattr__(self, item: Any) -> Never:
        raise AttributeError(
            f"{self.__class__.__name__} Cannot delete a member's attributes."
        )

    def __copy__(self) -> Self:
        return self

    def __deepcopy__(self, memo: Any) -> Self:
        return self

    @classmethod
    def try_value(cls, value: int = 0) -> Self:
        """Return the member which corresponds to ``value``.

        Parameters
        -----------
        value: :class:`int`
            The value of the enum member to get.

        Returns
        -------
        :class:`Enum`
            The corresponding member or a new instance of the enum if
            ``value`` isn't actually a member.
        """
        try:
            return cls._value_map_[value]
        except (KeyError, TypeError):
            # Unknown values get an unnamed instance (``name is None``)
            # instead of raising.
            return cls.__new__(cls, name=None, value=value)

    @classmethod
    def from_string(cls, name: str) -> Self:
        """Return the member which corresponds to the string name.

        Parameters
        -----------
        name: :class:`str`
            The name of the enum member to get.

        Raises
        -------
        :exc:`ValueError`
            The member was not found in the Enum.
        """
        try:
            return cls._member_map_[name]
        except KeyError as e:
            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e


# 198 |
# --------------------------------------------------------------------------------