Skip to content

Commit a724c3b

Browse files
committed
Improve
1 parent ed25ec4 commit a724c3b

File tree

2 files changed

+38
-39
lines changed

2 files changed

+38
-39
lines changed

pyi.py

+27-35
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,23 @@ def visit_Index(self, node: ast.Index) -> ast.expr:
233233
return node.value
234234

235235

236-
def _annotation_is_ellipsis_callable(annotation: ast.expr | None) -> bool:
236+
def _is_name(node: ast.expr | None, name: str) -> bool:
237+
"""Return `True` if `node` is the AST representation of `name`."""
238+
return isinstance(node, ast.Name) and node.id == name
239+
240+
241+
def _is_attribute(node: ast.expr | None, attribute: str) -> bool:
242+
"""Determine whether `node` is the AST representation of `attribute`.
243+
Only works if `attribute has a single `.` delimiter, e.g. "collection.abc".
244+
"""
245+
return (
246+
isinstance(node, ast.Attribute)
247+
and isinstance(node.value, ast.Name)
248+
and [node.value.id, node.attr] == attribute.split(".")
249+
)
250+
251+
252+
def _is_ellipsis_callable(annotation: ast.expr | None) -> bool:
237253
"""Evaluate whether `annotation` is an "ellipsis callable".
238254
239255
Return `True` if `annotation` is either:
@@ -254,7 +270,6 @@ def _annotation_is_ellipsis_callable(annotation: ast.expr | None) -> bool:
254270
return False
255271

256272
# Now we know it's e.g. `Foo[..., bar]`
257-
258273
subscripted_object = annotation.value
259274

260275
if isinstance(subscripted_object, ast.Name):
@@ -267,18 +282,13 @@ def _annotation_is_ellipsis_callable(annotation: ast.expr | None) -> bool:
267282
return False
268283

269284
# Now we know it's an attribute e.g. `Foo.Callable[..., bar]`
270-
271285
module = subscripted_object.value
286+
return _is_name(module, "typing") or _is_attribute(module, "collections.abc")
272287

273-
if isinstance(module, ast.Name):
274-
return module.id == "typing"
275288

276-
return (
277-
isinstance(module, ast.Attribute)
278-
and isinstance(module.value, ast.Name)
279-
and module.value.id == "collections"
280-
and module.attr == "abc"
281-
)
289+
def _is_Any(annotation: ast.expr | None) -> bool:
290+
"""Return `True` if `annotation` is `Any` or `typing.Any`"""
291+
return _is_name(annotation, "Any") or _is_attribute(annotation, "typing.Any")
282292

283293

284294
def _should_use_ParamSpec(function: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
@@ -290,37 +300,19 @@ def _should_use_ParamSpec(function: ast.FunctionDef | ast.AsyncFunctionDef) -> b
290300
)
291301

292302
if not any(
293-
_annotation_is_ellipsis_callable(arg_node.annotation)
294-
for arg_node in non_variadic_args
303+
_is_ellipsis_callable(arg_node.annotation) for arg_node in non_variadic_args
295304
):
296305
return False
297306

298307
# First check for functions like `def foo(func: Callable[P, R]) -> Callable[P, R]: ...`
299-
300-
if _annotation_is_ellipsis_callable(function.returns):
308+
if _is_ellipsis_callable(function.returns):
301309
return True
302310

303311
# Now check for functions like `def foo(__func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...`
304-
305-
vararg, kwarg = arguments.vararg, arguments.kwarg
306-
if not (isinstance(vararg, ast.arg) and isinstance(kwarg, ast.arg)):
307-
return False
308-
309-
for annotation in (vararg.annotation, kwarg.annotation):
310-
if isinstance(annotation, ast.Name):
311-
if annotation.id != "Any":
312-
return False
313-
elif isinstance(annotation, ast.Attribute):
314-
if not (
315-
isinstance(annotation.value, ast.Name)
316-
and annotation.value.id == "typing"
317-
and annotation.value == "Any"
318-
):
319-
return False
320-
else:
321-
return False
322-
323-
return True
312+
return all(
313+
(isinstance(arg, ast.arg) and _is_Any(arg.annotation))
314+
for arg in (arguments.vararg, arguments.kwarg)
315+
)
324316

325317

326318
def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str:

tests/paramspec.pyi

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import collections.abc
2+
import typing
13
from typing import Any, Callable, ParamSpec, TypeVar
24

35
_P = ParamSpec("_P")
@@ -18,12 +20,17 @@ def func9(func: Callable[..., _R], *args: Any, **kwargs: int) -> _R: ...
1820
def func10(func: Callable[..., _R], *args: str, **kwargs: Any) -> _R: ...
1921
def func11(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
2022
def func12(func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
23+
def func13(func: typing.Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
24+
def func14(func: collections.abc.Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
2125

2226
# BAD FUNCTIONS
23-
def func13(arg: Callable[..., int]) -> Callable[..., str]: ... # Y032 Consider using ParamSpec to annotate function "func13"
24-
def func14(arg: Callable[..., _R]) -> Callable[..., _R]: ... # Y032 Consider using ParamSpec to annotate function "func14"
25-
def func15(arg: Callable[..., _R], *args: Any, **kwargs: Any) -> _R: ... # Y032 Consider using ParamSpec to annotate function "func15"
26-
def func16(arg: Callable[..., str], *args: Any, **kwargs: Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func16"
27+
def func15(arg: Callable[..., int]) -> Callable[..., str]: ... # Y032 Consider using ParamSpec to annotate function "func15"
28+
def func16(arg: Callable[..., _R]) -> Callable[..., _R]: ... # Y032 Consider using ParamSpec to annotate function "func16"
29+
def func17(arg: Callable[..., _R], *args: Any, **kwargs: Any) -> _R: ... # Y032 Consider using ParamSpec to annotate function "func17"
30+
def func18(arg: Callable[..., str], *args: Any, **kwargs: Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func18"
31+
def func19(arg: collections.abc.Callable[..., str], *args: Any, **kwargs: Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func19"
32+
def func20(arg: typing.Callable[..., str], *args: typing.Any, **kwargs: typing.Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func20"
33+
def func21(arg: collections.abc.Callable[..., str], *args: typing.Any, **kwargs: typing.Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func21"
2734

2835
class Foo:
2936
def __call__(self, func: Callable[..., _R], *args: Any, **kwargs: Any) -> _R: ... # Y032 Consider using ParamSpec to annotate function "Foo.__call__"

0 commit comments

Comments
 (0)