Skip to content

Add gRPC aio stub and servicer generation #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ jobs:
- name: Run formatters and linters
run: |
pip3 install black isort flake8-pyi flake8-noqa flake8-bugbear
black --check .
black --check --extend-exclude '(_pb2_grpc|_pb2).pyi?$' .
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should already be happening in the pyproject.toml

Was intentionally having the generated .pyi files match black formatting. Would be willing to relax that constraint if it's really difficult to maintain, but I think it's worth a little bit of effort to see if we can get it to work because people end up command-clicking to pyi files in VSCode.

Doesn't need to block this diff - I can look later before releasing - filed #489 .

isort --check . --diff
flake8 .
- name: run shellcheck
Expand Down
112 changes: 88 additions & 24 deletions mypy_protobuf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,30 +663,77 @@ def _map_key_value_types(

return ktype, vtype

def _callable_type(self, method: d.MethodDescriptorProto) -> str:
def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
module = "grpc.aio" if is_async else "grpc"
if method.client_streaming:
if method.server_streaming:
return self._import("grpc", "StreamStreamMultiCallable")
return self._import(module, "StreamStreamMultiCallable")
else:
return self._import("grpc", "StreamUnaryMultiCallable")
return self._import(module, "StreamUnaryMultiCallable")
else:
if method.server_streaming:
return self._import("grpc", "UnaryStreamMultiCallable")
return self._import(module, "UnaryStreamMultiCallable")
else:
return self._import("grpc", "UnaryUnaryMultiCallable")
return self._import(module, "UnaryUnaryMultiCallable")

def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
def _input_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.input_type)
if use_stream_iterator and method.client_streaming:
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result

def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.input_type)
if method.client_streaming:
# See write_grpc_async_hacks().
result = f"_MaybeAsyncIterator[{result}]"
return result

def _output_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.output_type)
if use_stream_iterator and method.server_streaming:
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result

def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.output_type)
if method.server_streaming:
# Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
# So both can be used in the covariant function return position.
iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]"
aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]"
result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
else:
# Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
# So both can be used in the covariant function return position.
# Awaitable[Resp] is equivalent to async def.
awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]"
result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
return result

def write_grpc_async_hacks(self) -> None:
wl = self._write_line
# _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
# So both can be used in the contravariant function parameter position.
wl("_T = {}('_T')", self._import("typing", "TypeVar"))
wl("")
wl(
"class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):",
self._import("collections.abc", "AsyncIterator"),
self._import("collections.abc", "Iterator"),
self._import("abc", "ABCMeta"),
)
with self._indent():
wl("...")
wl("")

# _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
# So both can be used in the contravariant function parameter position.
wl(
"class _ServicerContext({}, {}): # type: ignore",
self._import("grpc", "ServicerContext"),
self._import("grpc.aio", "ServicerContext"),
)
with self._indent():
wl("...")
wl("")

def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
Expand All @@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
with self._indent():
wl("self,")
input_name = "request_iterator" if method.client_streaming else "request"
input_type = self._input_type(method)
input_type = self._servicer_input_type(method)
wl(f"{input_name}: {input_type},")
wl("context: {},", self._import("grpc", "ServicerContext"))
wl("context: _ServicerContext,")
wl(
") -> {}:{}",
self._output_type(method),
self._servicer_output_type(method),
" ..." if not self._has_comments(scl) else "",
)
if self._has_comments(scl):
with self._indent():
if not self._write_comments(scl):
wl("...")

def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
Expand All @@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
for i, method in methods:
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

wl("{}: {}[", method.name, self._callable_type(method))
wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method, False))
wl("{},", self._output_type(method, False))
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("]")
self._write_comments(scl)

Expand All @@ -743,17 +790,31 @@ def write_grpc_services(
scl = scl_prefix + [i]

# The stub client
wl(f"class {service.name}Stub:")
wl(
"class {}Stub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
wl(
"def __init__(self, channel: {}) -> None: ...",
self._import("grpc", "Channel"),
)
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
wl("def __init__(self, channel: {}) -> None: ...", channel)
self.write_grpc_stub_methods(service, scl)
wl("")

# The (fake) async stub client
wl(
"class {}AsyncStub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
self.write_grpc_stub_methods(service, scl, is_async=True)
wl("")

# The service definition interface
wl(
"class {}Servicer(metaclass={}):",
Expand All @@ -765,11 +826,13 @@ def write_grpc_services(
wl("")
self.write_grpc_methods(service, scl)
wl("")
server = self._import("grpc", "Server")
aserver = self._import("grpc.aio", "Server")
wl(
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
service.name,
service.name,
self._import("grpc", "Server"),
f"{self._import('typing', 'Union')}[{server}, {aserver}]",
)
wl("")

Expand Down Expand Up @@ -960,6 +1023,7 @@ def generate_mypy_grpc_stubs(
relax_strict_optional_primitives,
grpc=True,
)
pkg_writer.write_grpc_async_hacks()
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])

assert name == fd.name
Expand Down
42 changes: 21 additions & 21 deletions run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ RED="\033[0;31m"
NC='\033[0m'

PY_VER_MYPY_PROTOBUF=${PY_VER_MYPY_PROTOBUF:=3.10.6}
PY_VER_MYPY_PROTOBUF_SHORT=$(echo $PY_VER_MYPY_PROTOBUF | cut -d. -f1-2)
PY_VER_MYPY_PROTOBUF_SHORT=$(echo "$PY_VER_MYPY_PROTOBUF" | cut -d. -f1-2)
PY_VER_MYPY=${PY_VER_MYPY:=3.8.13}
PY_VER_UNIT_TESTS="${PY_VER_UNIT_TESTS:=3.8.13}"

Expand Down Expand Up @@ -45,16 +45,16 @@ MYPY_VENV=venv_$PY_VER_MYPY
(
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER_MYPY
pyenv shell "$PY_VER_MYPY"

if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_VENV ]]; then
python3 --version
python3 -m pip --version
python -m pip install virtualenv
python3 -m virtualenv $MYPY_VENV
$MYPY_VENV/bin/python3 -m pip install -r mypy_requirements.txt
python3 -m virtualenv "$MYPY_VENV"
"$MYPY_VENV"/bin/python3 -m pip install -r mypy_requirements.txt
fi
$MYPY_VENV/bin/mypy --version
"$MYPY_VENV"/bin/mypy --version
)

# Create unit tests venvs
Expand All @@ -63,14 +63,14 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
UNIT_TESTS_VENV=venv_$PY_VER
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER
pyenv shell "$PY_VER"

if [[ -z $SKIP_CLEAN ]] || [[ ! -e $UNIT_TESTS_VENV ]]; then
python -m pip install virtualenv
python -m virtualenv $UNIT_TESTS_VENV
$UNIT_TESTS_VENV/bin/python -m pip install -r test_requirements.txt
python -m virtualenv "$UNIT_TESTS_VENV"
"$UNIT_TESTS_VENV"/bin/python -m pip install -r test_requirements.txt
fi
$UNIT_TESTS_VENV/bin/py.test --version
"$UNIT_TESTS_VENV"/bin/py.test --version
)
done

Expand All @@ -79,19 +79,19 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF
(
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER_MYPY_PROTOBUF
pyenv shell "$PY_VER_MYPY_PROTOBUF"

# Create virtualenv + Install requirements for mypy-protobuf
if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_PROTOBUF_VENV ]]; then
python -m pip install virtualenv
python -m virtualenv $MYPY_PROTOBUF_VENV
$MYPY_PROTOBUF_VENV/bin/python -m pip install -e .
python -m virtualenv "$MYPY_PROTOBUF_VENV"
"$MYPY_PROTOBUF_VENV"/bin/python -m pip install -e .
fi
)

# Run mypy-protobuf
(
source $MYPY_PROTOBUF_VENV/bin/activate
source "$MYPY_PROTOBUF_VENV"/bin/activate

# Confirm version number
test "$(protoc-gen-mypy -V)" = "mypy-protobuf 3.4.0"
Expand Down Expand Up @@ -138,22 +138,22 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF

for PY_VER in $PY_VER_UNIT_TESTS; do
UNIT_TESTS_VENV=venv_$PY_VER
PY_VER_MYPY_TARGET=$(echo $PY_VER | cut -d. -f1-2)
PY_VER_MYPY_TARGET=$(echo "$PY_VER" | cut -d. -f1-2)

# Generate GRPC protos for mypy / tests
(
source $UNIT_TESTS_VENV/bin/activate
source "$UNIT_TESTS_VENV"/bin/activate
find proto/testproto/grpc -name "*.proto" -print0 | xargs -0 python -m grpc_tools.protoc "${PROTOC_ARGS[@]}" --grpc_python_out=test/generated
)

# Run mypy on unit tests / generated output
(
source $MYPY_VENV/bin/activate
source "$MYPY_VENV"/bin/activate
export MYPYPATH=$MYPYPATH:test/generated

# Run mypy
MODULES=( "-m" "test" )
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable=$UNIT_TESTS_VENV/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"
MODULES=( -m test.test_generated_mypy -m test.test_grpc_usage -m test.test_grpc_async_usage )
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="$UNIT_TESTS_VENV"/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"

# Run stubtest. Stubtest does not work with python impl - only cpp impl
API_IMPL="$(python3 -c "import google.protobuf.internal.api_implementation as a ; print(a.Type())")"
Expand All @@ -173,12 +173,12 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
}

call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
if ! diff "$MYPY_OUTPUT/mypy_output" "test_negative/output.expected.$PY_VER_MYPY_TARGET" || ! diff "$MYPY_OUTPUT/mypy_output.omit_linenos" "test_negative/output.expected.$PY_VER_MYPY_TARGET.omit_linenos"; then
echo -e "${RED}test_negative/output.expected.$PY_VER_MYPY_TARGET didnt match. Copying over for you. Now rerun${NC}"

# Copy over all the mypy results for the developer.
call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
cp "$MYPY_OUTPUT/mypy_output" test_negative/output.expected.3.8
cp "$MYPY_OUTPUT/mypy_output.omit_linenos" test_negative/output.expected.3.8.omit_linenos
exit 1
Expand All @@ -187,7 +187,7 @@ for PY_VER in $PY_VER_UNIT_TESTS; do

(
# Run unit tests.
source $UNIT_TESTS_VENV/bin/activate
source "$UNIT_TESTS_VENV"/bin/activate
PYTHONPATH=test/generated py.test --ignore=test/generated -v
)
done
4 changes: 4 additions & 0 deletions stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ testproto.readme_enum_pb2._?MyEnum(EnumTypeWrapper)?
testproto.nested.nested_pb2.AnotherNested._?NestedEnum(EnumTypeWrapper)?
testproto.nested.nested_pb2.AnotherNested.NestedMessage._?NestedEnum2(EnumTypeWrapper)?

# Our fake async stubs are not there at runtime (yet)
testproto.grpc.dummy_pb2_grpc.DummyServiceAsyncStub
testproto.grpc.import_pb2_grpc.SimpleServiceAsyncStub

# Part of an "EXPERIMENTAL API" according to comment. Not documented.
testproto.grpc.dummy_pb2_grpc.DummyService
testproto.grpc.import_pb2_grpc.SimpleService
Expand Down
Loading