diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c53af51..62863b54 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,7 +49,7 @@ jobs: - name: Move compiled files to betterproto2 shell: bash - run: mv betterproto2_compiler/tests/output_betterproto betterproto2_compiler/tests/output_betterproto_pydantic betterproto2_compiler/tests/output_reference betterproto2/tests + run: mv betterproto2_compiler/tests/output_betterproto betterproto2_compiler/tests/output_betterproto_pydantic betterproto2_compiler/tests/output_betterproto_descriptor betterproto2_compiler/tests/output_reference betterproto2/tests - name: Execute test suite working-directory: ./betterproto2 diff --git a/.gitignore b/.gitignore index de01ba67..442c3f77 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ .pytest_cache .python-version build/ -tests/output_* +*/tests/output_* **/__pycache__ dist **/*.egg-info diff --git a/betterproto2/docs/descriptors.md b/betterproto2/docs/descriptors.md new file mode 100644 index 00000000..de521dee --- /dev/null +++ b/betterproto2/docs/descriptors.md @@ -0,0 +1,11 @@ +# Google Protobuf Descriptors + +Google's protoc plugin for Python generates DESCRIPTOR fields that enable reflection capabilities in many libraries (e.g. grpc, grpclib, mcap). + +By default, betterproto2 doesn't generate them, since doing so introduces a dependency on `protobuf`. If you're okay with this dependency and want DESCRIPTOR fields to be generated, use the compiler option `python_betterproto2_opt=google_protobuf_descriptors`. + + +## grpclib Reflection + +For now, using reflection properly requires overriding the `DescriptorPool` that is used by grpclib's `ServerReflection`. To do so, take a look at the use of `ServerReflection.extend` in the `test_grpclib_reflection` test (`tests/grpc/test_grpclib_reflection.py`) in this repository. In the future, once https://github.com/vmagamedov/grpclib/pull/204 is merged, you will be able to pass `default_google_proto_descriptor_pool` directly to the `ServerReflection.extend` class method.
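A minimal sketch of the grpclib workaround described in descriptors.md above, mirroring what the new `test_grpclib_reflection` test does. The generated package name `my_generated_package` is an illustrative assumption; overriding the private `_pool` attribute is a stopgap that should become unnecessary once grpclib PR 204 is merged.

```python
# Sketch: wire grpclib server reflection to the descriptor pool produced by
# python_betterproto2_opt=google_protobuf_descriptors. Mirrors the approach
# used in tests/grpc/test_grpclib_reflection.py in this PR.
from grpclib.reflection.service import ServerReflection
from grpclib.reflection.v1.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1
from grpclib.reflection.v1alpha.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1Alpha

# Hypothetical generated package name, used for illustration only.
from my_generated_package.google_proto_descriptor_pool import default_google_proto_descriptor_pool


def extend_with_reflection(services: list) -> list:
    """Add reflection services that answer from the generated descriptor pool."""
    services = ServerReflection.extend(services)
    for service in services:
        if isinstance(service, (ServerReflectionBaseV1, ServerReflectionBaseV1Alpha)):
            # Private-attribute override; temporary until grpclib exposes a public hook.
            service._pool = default_google_proto_descriptor_pool
    return services
```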
diff --git a/betterproto2/mkdocs.yml b/betterproto2/mkdocs.yml index a8c90113..9d5b7289 100644 --- a/betterproto2/mkdocs.yml +++ b/betterproto2/mkdocs.yml @@ -14,6 +14,7 @@ nav: - Clients: tutorial/clients.md - API: api.md - Development: development.md + - Protobuf Descriptors: descriptors.md plugins: diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 41eb24f7..516bdb51 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -22,7 +22,8 @@ Repository = "https://github.com/betterproto/python-betterproto2" grpcio = ["grpcio>=1.72.1"] grpclib = ["grpclib>=0.4.8"] pydantic = ["pydantic>=2.11.5"] -all = ["grpclib>=0.4.8", "grpcio>=1.72.1", "pydantic>=2.11.5"] +protobuf = ["protobuf>=5.29.3"] +all = ["grpclib>=0.4.8", "grpcio>=1.72.1", "pydantic>=2.11.5", "protobuf>=5.29.3"] [dependency-groups] dev = [ @@ -38,7 +39,6 @@ dev = [ test = [ "cachelib>=0.13.0", "poethepoet>=0.34.0", - "protobuf>=5.29.3", "pytest>=8.4.0", "pytest-asyncio>=1.0.0", "pytest-cov>=6.1.1", @@ -144,6 +144,7 @@ rm -rf tests/output_* && git clone https://github.com/betterproto/python-betterproto2-compiler --branch compiled-test-files --single-branch compiled_files && mv compiled_files/tests_betterproto tests/output_betterproto && mv compiled_files/tests_betterproto_pydantic tests/output_betterproto_pydantic && +mv compiled_files/tests_betterproto_descriptor tests/output_betterproto_descriptor && mv compiled_files/tests_reference tests/output_reference && rm -rf compiled_files """ diff --git a/betterproto2/src/betterproto2/__init__.py b/betterproto2/src/betterproto2/__init__.py index 239d379b..7ceda2f0 100644 --- a/betterproto2/src/betterproto2/__init__.py +++ b/betterproto2/src/betterproto2/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["__version__", "check_compiler_version", "unwrap", "MessagePool", "validators"] +__all__ = ["__version__", "check_compiler_version", "classproperty", "unwrap", "MessagePool", "validators"] import dataclasses import enum as builtin_enum diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py new file mode 100644 index 00000000..3680cf7a --- /dev/null +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -0,0 +1,98 @@ +import asyncio +from typing import Generic, TypeVar + +import pytest +from google.protobuf import descriptor_pb2 +from grpclib.reflection.service import ServerReflection +from grpclib.reflection.v1.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1 +from grpclib.reflection.v1alpha.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1Alpha +from grpclib.testing import ChannelFor + +from tests.output_betterproto.example_service import TestBase +from tests.output_betterproto.grpc.reflection.v1 import ( + ErrorResponse, + ListServiceResponse, + ServerReflectionRequest, + ServerReflectionStub, + ServiceResponse, +) +from tests.output_betterproto_descriptor.google_proto_descriptor_pool import default_google_proto_descriptor_pool + + +class TestService(TestBase): + pass + + +T = TypeVar("T") + + +class AsyncIterableQueue(Generic[T]): + CLOSED_SENTINEL = object() + + def __init__(self): + self._queue = asyncio.Queue() + self._done = asyncio.Event() + + def put(self, item: T): + self._queue.put_nowait(item) + + def close(self): + self._queue.put_nowait(self.CLOSED_SENTINEL) + + def __aiter__(self): + return self + + async def __anext__(self) -> T: + val = await self._queue.get() + if val is self.CLOSED_SENTINEL:
raise StopAsyncIteration + return val + + +@pytest.mark.asyncio +async def test_grpclib_reflection(): + service = TestService() + services = ServerReflection.extend([service]) + for service in services: + # This won't be needed once https://github.com/vmagamedov/grpclib/pull/204 is in. + if isinstance(service, ServerReflectionBaseV1Alpha | ServerReflectionBaseV1): + service._pool = default_google_proto_descriptor_pool + + async with ChannelFor(services) as channel: + requests = AsyncIterableQueue[ServerReflectionRequest]() + responses = ServerReflectionStub(channel).server_reflection_info(requests) + + # list services + requests.put(ServerReflectionRequest(list_services="")) + response = await anext(responses) + assert response.list_services_response == ListServiceResponse( + service=[ServiceResponse(name="example_service.Test")] + ) + + # list methods + + # should fail before we've added descriptors to the protobuf pool + requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) + response = await anext(responses) + assert response.error_response == ErrorResponse(error_code=5, error_message="not found") + assert response.file_descriptor_response is None + + # now it should work + import tests.output_betterproto_descriptor.example_service as example_service_with_desc + + requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) + response = await anext(responses) + expected = descriptor_pb2.FileDescriptorProto.FromString( + example_service_with_desc.EXAMPLE_SERVICE_PROTO_DESCRIPTOR.serialized_pb + ) + assert response.error_response is None + assert response.file_descriptor_response is not None + assert len(response.file_descriptor_response.file_descriptor_proto) == 1 + actual = descriptor_pb2.FileDescriptorProto.FromString( + response.file_descriptor_response.file_descriptor_proto[0] + ) + assert actual == expected + + requests.close() + + await anext(responses, None) diff --git a/betterproto2/tests/grpc/test_message_enum_descriptors.py b/betterproto2/tests/grpc/test_message_enum_descriptors.py new file mode 100644 index 00000000..a7922820 --- /dev/null +++ b/betterproto2/tests/grpc/test_message_enum_descriptors.py @@ -0,0 +1,19 @@ +import pytest + +from tests.output_betterproto.import_cousin_package_same_name.test.subpackage import Test + +# importing the cousin should cause no descriptor pool errors since the subpackage imports it once already +from tests.output_betterproto_descriptor.import_cousin_package_same_name.cousin.subpackage import CousinMessage +from tests.output_betterproto_descriptor.import_cousin_package_same_name.test.subpackage import Test as TestWithDesc + + +def test_message_enum_descriptors(): + # Normally descriptors are not available, as they require protobuf support + # to interoperate with other libraries. + with pytest.raises(AttributeError): + Test.DESCRIPTOR.full_name + + # But the python_betterproto2_opt=google_protobuf_descriptors option + # will add them in, as long as the protobuf package is installed.
+ assert TestWithDesc.DESCRIPTOR.full_name == "import_cousin_package_same_name.test.subpackage.Test" + assert CousinMessage.DESCRIPTOR.full_name == "import_cousin_package_same_name.cousin.subpackage.CousinMessage" diff --git a/betterproto2/tests/test_deprecated.py b/betterproto2/tests/test_deprecated.py index ea16d370..2930f6cf 100644 --- a/betterproto2/tests/test_deprecated.py +++ b/betterproto2/tests/test_deprecated.py @@ -7,6 +7,7 @@ Empty, Message, Test, + TestNested, TestServiceStub, ) @@ -26,6 +27,14 @@ def test_deprecated_message(): assert str(record[0].message) == f"{Message.__name__} is deprecated" +def test_deprecated_nested_message_field(): + with pytest.warns(DeprecationWarning) as record: + TestNested(nested_value="hello") + + assert len(record) == 1 + assert str(record[0].message) == f"TestNested.nested_value is deprecated" + + def test_message_with_deprecated_field(message): with pytest.warns(DeprecationWarning) as record: Test(message=message, value=10) diff --git a/betterproto2/uv.lock b/betterproto2/uv.lock index 8db85352..fcf92ad3 100644 --- a/betterproto2/uv.lock +++ b/betterproto2/uv.lock @@ -68,6 +68,7 @@ dependencies = [ all = [ { name = "grpcio" }, { name = "grpclib" }, + { name = "protobuf" }, { name = "pydantic" }, ] grpcio = [ @@ -76,6 +77,9 @@ grpcio = [ grpclib = [ { name = "grpclib" }, ] +protobuf = [ + { name = "protobuf" }, +] pydantic = [ { name = "pydantic" }, ] @@ -93,7 +97,6 @@ dev = [ test = [ { name = "cachelib" }, { name = "poethepoet" }, - { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -106,12 +109,14 @@ requires-dist = [ { name = "grpcio", marker = "extra == 'grpcio'", specifier = ">=1.72.1" }, { name = "grpclib", marker = "extra == 'all'", specifier = ">=0.4.8" }, { name = "grpclib", marker = "extra == 'grpclib'", specifier = ">=0.4.8" }, + { name = "protobuf", marker = "extra == 'all'", specifier = ">=5.29.3" }, + { name = "protobuf", marker = "extra == 'protobuf'", specifier = ">=5.29.3" }, { name = "pydantic", marker = "extra == 'all'", specifier = ">=2.11.5" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.11.5" }, { name = "python-dateutil", specifier = ">=2.9.0.post0" }, { name = "typing-extensions", specifier = ">=4.14.0" }, ] -provides-extras = ["grpcio", "grpclib", "pydantic", "all"] +provides-extras = ["grpcio", "grpclib", "pydantic", "protobuf", "all"] [package.metadata.requires-dev] dev = [ @@ -126,7 +131,6 @@ dev = [ test = [ { name = "cachelib", specifier = ">=0.13.0" }, { name = "poethepoet", specifier = ">=0.34.0" }, - { name = "protobuf", specifier = ">=5.29.3" }, { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, { name = "pytest-cov", specifier = ">=6.1.1" }, diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index 28bc9a45..ab3804a2 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -122,6 +122,20 @@ python -m grpc.tools.protoc \ google/protobuf/timestamp.proto \ google/protobuf/type.proto \ google/protobuf/wrappers.proto + +python -m grpc.tools.protoc \ + --python_betterproto2_out=tests/output_betterproto_descriptor \ + --python_betterproto2_opt=google_protobuf_descriptors \ + google/protobuf/any.proto \ + google/protobuf/api.proto \ + google/protobuf/duration.proto \ + google/protobuf/empty.proto \ + google/protobuf/field_mask.proto \ + google/protobuf/source_context.proto \ + google/protobuf/struct.proto \ + 
google/protobuf/timestamp.proto \ + google/protobuf/type.proto \ + google/protobuf/wrappers.proto """ [tool.poe.tasks.typecheck] diff --git a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py index 9e4901db..6b7b8c08 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py +++ b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py @@ -55,6 +55,32 @@ def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) raise ValueError(f"can't find type name: {field_type_name}") +def get_symbol_reference( + *, + package: str, + imports: set, + source_package: str, + symbol: str, +) -> tuple[str, str | None]: + """ + Return a Python symbol within a proto package. Adds the import if + necessary and returns it as well for usage. Unwraps well known type if required. + """ + current_package: list[str] = package.split(".") if package else [] + py_package: list[str] = source_package.split(".") if source_package else [] + + if py_package == current_package: + return (reference_sibling(symbol), None) + + if py_package[: len(current_package)] == current_package: + return reference_descendent(current_package, imports, py_package, symbol) + + if current_package[: len(py_package)] == py_package: + return reference_ancestor(current_package, imports, py_package, symbol) + + return reference_cousin(current_package, imports, py_package, symbol) + + def get_type_reference( *, package: str, @@ -73,30 +99,25 @@ def get_type_reference( if wrap and (source_package, source_type) in WRAPPED_TYPES: return WRAPPED_TYPES[(source_package, source_type)] - current_package: list[str] = package.split(".") if package else [] - py_package: list[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) - - if py_package == current_package: - return reference_sibling(py_type) - - if py_package[: len(current_package)] == current_package: - return reference_descendent(current_package, imports, py_package, py_type) - - if current_package[: len(py_package)] == py_package: - return reference_ancestor(current_package, imports, py_package, py_type) - - return reference_cousin(current_package, imports, py_package, py_type) + (ref, _) = get_symbol_reference( + package=package, + imports=imports, + source_package=source_package, + symbol=py_type, + ) + return ref -def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> tuple[str, str]: """ Returns a reference to a python type located in the root, i.e. sys.path. 
""" string_import = ".".join(py_package) string_alias = "__".join([safe_snake_case(name) for name in py_package]) - imports.add(f"import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) def reference_sibling(py_type: str) -> str: @@ -106,7 +127,9 @@ def reference_sibling(py_type: str) -> str: return f"{py_type}" -def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_descendent( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str +) -> tuple[str, str]: """ Returns a reference to a python type in a package that is a descendent of the current package, and adds the required import that is aliased to avoid name @@ -116,15 +139,19 @@ def reference_descendent(current_package: list[str], imports: set[str], py_packa string_from = ".".join(importing_descendent[:-1]) string_import = importing_descendent[-1] if string_from: - string_alias = "_".join(importing_descendent) - imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + string_alias = f"{'_'.join(importing_descendent)}" + import_to_add = f"from .{string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) else: - imports.add(f"from . import {string_import}") - return f"{string_import}.{py_type}" + import_to_add = f"from . import {string_import}" + imports.add(import_to_add) + return (f"{string_import}.{py_type}", import_to_add) -def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_ancestor( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str +) -> tuple[str, str]: """ Returns a reference to a python type in a package which is an ancestor to the current package, and adds the required import that is aliased (if possible) to avoid @@ -137,15 +164,19 @@ def reference_ancestor(current_package: list[str], imports: set[str], py_package string_import = py_package[-1] string_alias = f"_{'_' * distance_up}{string_import}__" string_from = f"..{'.' * distance_up}" - imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"from {string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) else: string_alias = f"{'_' * distance_up}{py_type}__" - imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") - return string_alias + import_to_add = f"from .{'.' * distance_up} import {py_type} as {string_alias}" + imports.add(import_to_add) + return (string_alias, import_to_add) -def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_cousin( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str +) -> tuple[str, str]: """ Returns a reference to a python type in a package that is not descendent, ancestor or sibling, and adds the required import that is aliased to avoid name conflicts. 
@@ -161,5 +192,6 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package: + "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]]) + "__" ) - imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"from {string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 52b69d54..5b339735 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -53,6 +53,7 @@ MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto, + SourceCodeInfo, ) from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest from betterproto2_compiler.settings import Settings @@ -216,6 +217,33 @@ def input_filenames(self) -> list[str]: """ return sorted([f.name for f in self.input_files]) + def get_descriptor_name(self, source_file: FileDescriptorProto): + return f"{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR" + + @property + def descriptors(self): + """Google protobuf library descriptors. + + Returns + ------- + str + A list of pool registrations for proto descriptors. + """ + descriptors: list[str] = [] + + for f in self.input_files: + # Remove the source_code_info field since it is not needed at runtime. + source_code_info: SourceCodeInfo | None = f.source_code_info + f.source_code_info = None + + descriptors.append( + f"{self.get_descriptor_name(f)} = default_google_proto_descriptor_pool.AddSerializedFile({bytes(f)})" + ) + + f.source_code_info = source_code_info + + return "\n".join(descriptors) + @dataclass(kw_only=True) class MessageCompiler(ProtoContentBase): @@ -223,6 +251,7 @@ class MessageCompiler(ProtoContentBase): output_file: OutputTemplate proto_obj: DescriptorProto + prefixed_proto_name: str fields: list["FieldCompiler"] = field(default_factory=list) oneofs: list["OneofCompiler"] = field(default_factory=list) builtins_types: set[str] = field(default_factory=set) @@ -233,7 +262,7 @@ def proto_name(self) -> str: @property def py_name(self) -> str: - return pythonize_class_name(self.proto_name) + return pythonize_class_name(self.prefixed_proto_name) @property def deprecated(self) -> bool: @@ -266,6 +295,17 @@ def custom_methods(self) -> list[str]: return methods_source + @property + def descriptor_name(self) -> str: + """Google protobuf library descriptor name. + + Returns + ------- + str + The Python name of the descriptor to reference. 
+ """ + return self.output_file.get_descriptor_name(self.source_file) + def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool: """True if proto_field_obj is a map, otherwise False.""" @@ -562,6 +602,7 @@ class EnumDefinitionCompiler(ProtoContentBase): output_file: OutputTemplate proto_obj: EnumDescriptorProto + prefixed_proto_name: str entries: list["EnumDefinitionCompiler.EnumEntry"] = field(default_factory=list) @dataclass(unsafe_hash=True, kw_only=True) @@ -589,12 +630,23 @@ def proto_name(self) -> str: @property def py_name(self) -> str: - return pythonize_class_name(self.proto_name) + return pythonize_class_name(self.prefixed_proto_name) @property def deprecated(self) -> bool: return bool(self.proto_obj.options and self.proto_obj.options.deprecated) + @property + def descriptor_name(self) -> str: + """Google protobuf library descriptor name. + + Returns + ------- + str + The Python name of the descriptor to reference. + """ + return self.output_file.get_descriptor_name(self.source_file) + @dataclass(kw_only=True) class ServiceCompiler(ProtoContentBase): diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py index 72435bb1..1d609a74 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py @@ -35,20 +35,21 @@ def traverse( proto_file: FileDescriptorProto, -) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: +) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int], str], None, None]: # Todo: Keep information about nested hierarchy def _traverse( path: list[int], items: list[EnumDescriptorProto] | list[DescriptorProto], prefix: str = "", - ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: + ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int], str], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. - # Todo: don't change the name, but include full name in returned tuple should_rename = not isinstance(item, DescriptorProto) or not item.options or not item.options.map_entry - item.name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name - yield item, [*path, i] + # Record prefixed name but *do not* mutate original file. + # We use this prefixed name to create pythonized names. + prefixed_name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name + yield item, [*path, i], prefixed_name if isinstance(item, DescriptorProto): # Get nested types. 
@@ -81,6 +82,7 @@ def get_settings(plugin_options: list[str]) -> Settings: return Settings( pydantic_dataclasses="pydantic_dataclasses" in plugin_options, + google_protobuf_descriptors="google_protobuf_descriptors" in plugin_options, client_generation=client_generation, server_generation=server_generation, ) @@ -109,12 +111,13 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # get the references to input/output messages for each service for output_package_name, output_package in request_data.output_packages.items(): for proto_input_file in output_package.input_files: - for item, path in traverse(proto_input_file): + for item, path, prefixed_proto_name in traverse(proto_input_file): read_protobuf_type( source_file=proto_input_file, item=item, path=path, output_package=output_package, + prefixed_proto_name=prefixed_proto_name, ) # Read Services @@ -168,6 +171,15 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: ) ) + if settings.google_protobuf_descriptors: + response.file.append( + CodeGeneratorResponseFile( + name="google_proto_descriptor_pool.py", + content="from google.protobuf import descriptor_pool\n\n" + + "default_google_proto_descriptor_pool = descriptor_pool.DescriptorPool()\n", + ) + ) + for output_package_name in sorted(output_paths.union(init_files)): print(f"Writing {output_package_name}", file=sys.stderr) @@ -179,6 +191,7 @@ def read_protobuf_type( path: list[int], source_file: "FileDescriptorProto", output_package: OutputTemplate, + prefixed_proto_name: str, ) -> None: if isinstance(item, DescriptorProto): if item.options and item.options.map_entry: @@ -188,10 +201,11 @@ def read_protobuf_type( message_data = MessageCompiler( source_file=source_file, output_file=output_package, + prefixed_proto_name=prefixed_proto_name, proto_obj=item, path=path, ) - output_package.messages[message_data.proto_name] = message_data + output_package.messages[message_data.prefixed_proto_name] = message_data for index, field in enumerate(item.field): if is_map(field, item): @@ -243,10 +257,11 @@ def read_protobuf_type( enum = EnumDefinitionCompiler( source_file=source_file, output_file=output_package, + prefixed_proto_name=prefixed_proto_name, proto_obj=item, path=path, ) - output_package.enums[enum.proto_name] = enum + output_package.enums[enum.prefixed_proto_name] = enum def read_protobuf_service( diff --git a/betterproto2_compiler/src/betterproto2_compiler/settings.py b/betterproto2_compiler/src/betterproto2_compiler/settings.py index a5269939..2952a383 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/settings.py +++ b/betterproto2_compiler/src/betterproto2_compiler/settings.py @@ -68,6 +68,7 @@ class ServerGeneration(StrEnum): @dataclass class Settings: pydantic_dataclasses: bool + google_protobuf_descriptors: bool client_generation: ClientGeneration server_generation: ServerGeneration diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 64a2de57..aa348bf7 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -32,12 +32,19 @@ import betterproto2 from betterproto2.grpc.grpclib_server import ServiceBase import grpc import grpclib +from google.protobuf.descriptor import Descriptor, EnumDescriptor {# Import the message pool of the generated code. #} {% if output_file.package %} from {{ "." 
* output_file.package.count(".") }}..message_pool import default_message_pool +{% if output_file.settings.google_protobuf_descriptors %} +from {{ "." * output_file.package.count(".") }}..google_proto_descriptor_pool import default_google_proto_descriptor_pool +{% endif %} {% else %} from .message_pool import default_message_pool +{% if output_file.settings.google_protobuf_descriptors %} +from .google_proto_descriptor_pool import default_google_proto_descriptor_pool +{% endif %} {% endif %} if TYPE_CHECKING: @@ -45,4 +52,5 @@ if TYPE_CHECKING: from betterproto2.grpc.grpclib_client import MetadataLike from grpclib.metadata import Deadline -betterproto2.check_compiler_version("{{ version }}") +_COMPILER_VERSION="{{ version }}" +betterproto2.check_compiler_version(_COMPILER_VERSION) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index f51a3b39..09562fb5 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -6,6 +6,13 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): """ {% endif %} + {% if output_file.settings.google_protobuf_descriptors %} + {# Add descriptor class property to be more drop-in compatible with other libraries. #} + @betterproto2.classproperty + def DESCRIPTOR(self) -> EnumDescriptor: + return {{ enum.descriptor_name }}.enum_types_by_name['{{ enum.prefixed_proto_name }}'] + {% endif %} + {% for entry in enum.entries %} {{ entry.name }} = {{ entry.value }} {% if entry.comment %} @@ -45,6 +52,13 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): """ {% endif %} + {% if output_file.settings.google_protobuf_descriptors %} + {# Add descriptor class property to be more drop-in compatible with other libraries. #} + @betterproto2.classproperty + def DESCRIPTOR(self) -> Descriptor: + return {{ message.descriptor_name }}.message_types_by_name['{{ message.prefixed_proto_name }}'] + {% endif %} + {% for field in message.fields %} {{ field.get_field_string() }} {% if field.comment %} @@ -81,7 +95,7 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): {{ method_source }} {% endfor %} -default_message_pool.register_message("{{ output_file.package }}", "{{ message.proto_name }}", {{ message.py_name }}) +default_message_pool.register_message("{{ output_file.package }}", "{{ message.prefixed_proto_name }}", {{ message.py_name }}) {% endfor %} @@ -102,6 +116,11 @@ default_message_pool.register_message("{{ output_file.package }}", "{{ message.p {{ i }} {% endfor %} +{% if output_file.settings.google_protobuf_descriptors %} +{# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. 
#} +{{ output_file.descriptors }} +{% endif %} + {% if output_file.settings.server_generation == "async" %} {% for _, service in output_file.services|dictsort(by="key") %} class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): @@ -127,6 +146,10 @@ class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): {% endif %} raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + {% if method.server_streaming %} + {# yielding here changes the return type from a coroutine to an async_generator #} + yield {{ method.py_output_message_type }}() + {% endif %} {% endfor %} diff --git a/betterproto2_compiler/tests/generate.py b/betterproto2_compiler/tests/generate.py index 8e422d6a..16b5297b 100644 --- a/betterproto2_compiler/tests/generate.py +++ b/betterproto2_compiler/tests/generate.py @@ -9,6 +9,7 @@ get_directories, inputs_path, output_path_betterproto, + output_path_betterproto_descriptor, output_path_betterproto_pydantic, output_path_reference, protoc, @@ -57,23 +58,28 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: test_case_output_path_reference = output_path_reference.joinpath(test_case_name) test_case_output_path_betterproto = output_path_betterproto test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic + test_case_output_path_betterproto_desc = output_path_betterproto_descriptor os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True) + os.makedirs(test_case_output_path_betterproto_desc, exist_ok=True) clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) clear_directory(test_case_output_path_betterproto_pyd) + clear_directory(test_case_output_path_betterproto_desc) ( (ref_out, ref_err, ref_code), (plg_out, plg_err, plg_code), (plg_out_pyd, plg_err_pyd, plg_code_pyd), + (plg_out_desc, plg_err_desc, plg_code_desc), ) = await asyncio.gather( protoc(test_case_input_path, test_case_output_path_reference, True), protoc(test_case_input_path, test_case_output_path_betterproto, False), protoc(test_case_input_path, test_case_output_path_betterproto_pyd, False, True), + protoc(test_case_input_path, test_case_output_path_betterproto_desc, False, False, True), ) if ref_code == 0: @@ -127,7 +133,26 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: sys.stderr.buffer.write(plg_err_pyd) sys.stderr.buffer.flush() - return max(ref_code, plg_code, plg_code_pyd) + if plg_code_desc == 0: + print(f"\033[31;1;4mGenerated plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") + else: + print( + f"\033[31;1;4mFailed to generate plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m" + ) + print(plg_err_desc.decode()) + + if verbose: + if plg_out_desc: + print("Plugin stdout:") + sys.stdout.buffer.write(plg_out_desc) + sys.stdout.buffer.flush() + + if plg_err_desc: + print("Plugin stderr:") + sys.stderr.buffer.write(plg_err_desc) + sys.stderr.buffer.flush() + + return max(ref_code, plg_code, plg_code_pyd, plg_code_desc) def main(): diff --git a/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto b/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto index f504d03a..2e64c621 100644 --- a/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto +++ b/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto @@ -6,6 +6,9 @@ package deprecated; message Test { Message 
message = 1 [deprecated=true]; int32 value = 2; + message Nested { + int32 nested_value = 1 [deprecated=true]; + } } message Message { diff --git a/betterproto2_compiler/tests/inputs/example_service/example_service.proto b/betterproto2_compiler/tests/inputs/example_service/example_service.proto index 96455cc3..4ef60236 100644 --- a/betterproto2_compiler/tests/inputs/example_service/example_service.proto +++ b/betterproto2_compiler/tests/inputs/example_service/example_service.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package example_service; +import "google/protobuf/struct.proto"; + service Test { rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); @@ -12,9 +14,11 @@ service Test { message ExampleRequest { string example_string = 1; int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; } message ExampleResponse { string example_string = 1; int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; } diff --git a/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto b/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto new file mode 100644 index 00000000..f9f349fe --- /dev/null +++ b/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto @@ -0,0 +1,146 @@ +// Copyright 2016 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Service exported by server reflection. A more complete description of how +// server reflection works can be found at +// https://github.com/grpc/grpc/blob/master/doc/server-reflection.md +// +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/reflection/v1/reflection.proto + +syntax = "proto3"; + +package grpc.reflection.v1; + +option go_package = "google.golang.org/grpc/reflection/grpc_reflection_v1"; +option java_multiple_files = true; +option java_package = "io.grpc.reflection.v1"; +option java_outer_classname = "ServerReflectionProto"; + +service ServerReflection { + // The reflection service is structured as a bidirectional stream, ensuring + // all related requests go to a single server. + rpc ServerReflectionInfo(stream ServerReflectionRequest) + returns (stream ServerReflectionResponse); +} + +// The message sent by the client when calling ServerReflectionInfo method. +message ServerReflectionRequest { + string host = 1; + // To use reflection service, the client should set one of the following + // fields in message_request. The server distinguishes requests by their + // defined field and then handles them using corresponding methods. + oneof message_request { + // Find a proto file by the file name. + string file_by_filename = 3; + + // Find the proto file that declares the given fully-qualified symbol name. + // This field should be a fully-qualified symbol name + // (e.g. .[.] or .). 
+ string file_containing_symbol = 4; + + // Find the proto file which defines an extension extending the given + // message type with the given field number. + ExtensionRequest file_containing_extension = 5; + + // Finds the tag numbers used by all known extensions of the given message + // type, and appends them to ExtensionNumberResponse in an undefined order. + // Its corresponding method is best-effort: it's not guaranteed that the + // reflection service will implement this method, and it's not guaranteed + // that this method will provide all extensions. Returns + // StatusCode::UNIMPLEMENTED if it's not implemented. + // This field should be a fully-qualified type name. The format is + // . + string all_extension_numbers_of_type = 6; + + // List the full names of registered services. The content will not be + // checked. + string list_services = 7; + } +} + +// The type name and extension number sent by the client when requesting +// file_containing_extension. +message ExtensionRequest { + // Fully-qualified type name. The format should be . + string containing_type = 1; + int32 extension_number = 2; +} + +// The message sent by the server to answer ServerReflectionInfo method. +message ServerReflectionResponse { + string valid_host = 1; + ServerReflectionRequest original_request = 2; + // The server sets one of the following fields according to the message_request + // in the request. + oneof message_response { + // This message is used to answer file_by_filename, file_containing_symbol, + // file_containing_extension requests with transitive dependencies. + // As the repeated label is not allowed in oneof fields, we use a + // FileDescriptorResponse message to encapsulate the repeated fields. + // The reflection service is allowed to avoid sending FileDescriptorProtos + // that were previously sent in response to earlier requests in the stream. + FileDescriptorResponse file_descriptor_response = 4; + + // This message is used to answer all_extension_numbers_of_type requests. + ExtensionNumberResponse all_extension_numbers_response = 5; + + // This message is used to answer list_services requests. + ListServiceResponse list_services_response = 6; + + // This message is used when an error occurs. + ErrorResponse error_response = 7; + } +} + +// Serialized FileDescriptorProto messages sent by the server answering +// a file_by_filename, file_containing_symbol, or file_containing_extension +// request. +message FileDescriptorResponse { + // Serialized FileDescriptorProto messages. We avoid taking a dependency on + // descriptor.proto, which uses proto2 only features, by making them opaque + // bytes instead. + repeated bytes file_descriptor_proto = 1; +} + +// A list of extension numbers sent by the server answering +// all_extension_numbers_of_type request. +message ExtensionNumberResponse { + // Full name of the base type, including the package name. The format + // is . + string base_type_name = 1; + repeated int32 extension_number = 2; +} + +// A list of ServiceResponse sent by the server answering list_services request. +message ListServiceResponse { + // The information of each service may be expanded in the future, so we use + // ServiceResponse message to encapsulate it. + repeated ServiceResponse service = 1; +} + +// The information of a single service used by ListServiceResponse to answer +// list_services request. +message ServiceResponse { + // Full name of a registered service, including its package name. The format + // is . 
+ string name = 1; +} + +// The error code and error message sent by the server when an error occurs. +message ErrorResponse { + // This field uses the error codes defined in grpc::StatusCode. + int32 error_code = 1; + string error_message = 2; +} diff --git a/betterproto2_compiler/tests/util.py b/betterproto2_compiler/tests/util.py index 0e8366ff..8e12dda7 100644 --- a/betterproto2_compiler/tests/util.py +++ b/betterproto2_compiler/tests/util.py @@ -10,6 +10,7 @@ output_path_reference = root_path.joinpath("output_reference") output_path_betterproto = root_path.joinpath("output_betterproto") output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic") +output_path_betterproto_descriptor = root_path.joinpath("output_betterproto_descriptor") def get_directories(path): @@ -17,18 +18,24 @@ def get_directories(path): yield from directories -async def protoc(path: str | Path, output_dir: str | Path, reference: bool = False, pydantic_dataclasses: bool = False): - path: Path = Path(path).resolve() - output_dir: Path = Path(output_dir).resolve() +async def protoc( + path: str | Path, + output_dir: str | Path, + reference: bool = False, + pydantic_dataclasses: bool = False, + google_protobuf_descriptors: bool = False, +): + resolved_path: Path = Path(path).resolve() + resolved_output_dir: Path = Path(output_dir).resolve() python_out_option: str = "python_out" if reference else "python_betterproto2_out" command = [ sys.executable, "-m", "grpc.tools.protoc", - f"--proto_path={path.as_posix()}", - f"--{python_out_option}={output_dir.as_posix()}", - *[p.as_posix() for p in path.glob("*.proto")], + f"--proto_path={resolved_path.as_posix()}", + f"--{python_out_option}={resolved_output_dir.as_posix()}", + *[p.as_posix() for p in resolved_path.glob("*.proto")], ] if not reference: @@ -38,10 +45,13 @@ async def protoc(path: str | Path, output_dir: str | Path, reference: bool = Fal if pydantic_dataclasses: command.insert(3, "--python_betterproto2_opt=pydantic_dataclasses") + if google_protobuf_descriptors: + command.insert(3, "--python_betterproto2_opt=google_protobuf_descriptors") + proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() - return stdout, stderr, proc.returncode + return stdout, stderr, proc.returncode or 0
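A minimal sketch of driving the updated `protoc()` test helper with the new `google_protobuf_descriptors` flag, which is what appends `--python_betterproto2_opt=google_protobuf_descriptors` to the generated command. The input path and the standalone driver below are illustrative assumptions; tests/generate.py is the real entry point.

```python
# Sketch: compile one test case with Google protobuf descriptors enabled,
# using the protoc() helper and output path added in tests/util.py above.
import asyncio
from pathlib import Path

from tests.util import output_path_betterproto_descriptor, protoc


async def compile_with_descriptors(test_case_input_path: Path) -> int:
    stdout, stderr, return_code = await protoc(
        test_case_input_path,
        output_path_betterproto_descriptor,
        reference=False,
        pydantic_dataclasses=False,
        google_protobuf_descriptors=True,
    )
    if return_code != 0:
        # Surface the plugin's error output, as generate.py does.
        print(stderr.decode())
    return return_code


if __name__ == "__main__":
    # Hypothetical test case directory; any directory of .proto files works.
    asyncio.run(compile_with_descriptors(Path("tests/inputs/example_service")))
```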