Skip to content

Detect functions that should use ParamSpec, but don't #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Change Log

## Unreleased

* introduce Y032 (detect functions that should be annotated with ``ParamSpec``,
but aren't).

## 22.1.0

* extend Y001 to cover `ParamSpec` and `TypeVarTuple` in addition to `TypeVar`
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ currently emitted:
| Y029 | It is almost always redundant to define `__str__` or `__repr__` in a stub file, as the signatures are almost always identical to `object.__str__` and `object.__repr__`.
| Y030 | Union expressions should never have more than one `Literal` member, as `Literal[1] \| Literal[2]` is semantically identical to `Literal[1, 2]`.
| Y031 | `TypedDict`s should use class-based syntax instead of assignment-based syntax wherever possible. (In situations where this is not possible, such as if a field is a Python keyword or an invalid identifier, this error will not be raised.)
| Y032 | Use `ParamSpec` to annotate certain kinds of functions. (E.g. `def foo(func: Callable[..., R], *args: Any, **kwargs: Any) -> R: ...` should probably be rewritten as `def foo(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...`.

Many error codes enforce modern conventions, and some cannot yet be used in
all cases:
Expand Down
89 changes: 89 additions & 0 deletions pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,83 @@ def visit_Index(self, node: ast.Index) -> ast.expr:
return node.value


def _is_name(node: ast.expr | None, name: str) -> bool:
"""Return `True` if `node` is the AST representation of `name`."""
return isinstance(node, ast.Name) and node.id == name


def _is_attribute(node: ast.expr | None, attribute: str) -> bool:
"""Determine whether `node` is the AST representation of `attribute`.
Only works if `attribute has a single `.` delimiter, e.g. "collection.abc".
"""
return (
isinstance(node, ast.Attribute)
and isinstance(node.value, ast.Name)
and [node.value.id, node.attr] == attribute.split(".")
)


def _is_ellipsis_callable(annotation: ast.expr | None) -> bool:
"""Evaluate whether `annotation` is an "ellipsis callable".

Return `True` if `annotation` is either:
* `Callable[..., foo]`
* `typing.Callable[..., foo]`
* `collections.abc.Callable[..., foo]`
"""
if not isinstance(annotation, ast.Subscript):
return False

# Now we know it's a subscript e.g. `Foo[bar]`
if not (
isinstance(annotation.slice, ast.Tuple)
and len(annotation.slice.elts) == 2
and isinstance(annotation.slice.elts[0], ast.Ellipsis)
):
return False

# Now we know it's e.g. `Foo[..., bar]`
subscripted_object = annotation.value

if isinstance(subscripted_object, ast.Name):
return subscripted_object.id == "Callable"

if not (
isinstance(subscripted_object, ast.Attribute)
and subscripted_object.attr == "Callable"
):
return False

# Now we know it's an attribute e.g. `Foo.Callable[..., bar]`
module = subscripted_object.value
return _is_name(module, "typing") or _is_attribute(module, "collections.abc")


def _is_Any(annotation: ast.expr | None) -> bool:
"""Return `True` if `annotation` is `Any` or `typing.Any`"""
return _is_name(annotation, "Any") or _is_attribute(annotation, "typing.Any")


def _should_use_ParamSpec(function: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""Determine whether a function needs to be rewritten to use ParamSpec, if it doesn't currently."""
arguments = function.args

non_variadic_args = chain(
arguments.args, arguments.kwonlyargs, getattr(arguments, "posonlyargs", [])
)

if not any(
_is_ellipsis_callable(arg_node.annotation) for arg_node in non_variadic_args
):
return False

# Now check for functions like `def foo(__func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...`
return all(
(isinstance(arg, ast.arg) and _is_Any(arg.annotation))
for arg in (arguments.vararg, arguments.kwarg)
)


def _unparse_assign_node(node: ast.Assign | ast.AnnAssign) -> str:
"""Unparse an Assign node, and remove any newlines in it"""
return unparse(node).replace("\n", "")
Expand Down Expand Up @@ -306,6 +383,7 @@ def __init__(self, filename: Path = Path("none")) -> None:
self.string_literals_allowed = NestingCounter()
self.in_function = NestingCounter()
self.in_class = NestingCounter()
self.current_class_name = ""

def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r})"
Expand Down Expand Up @@ -711,7 +789,10 @@ def _check_platform_check(self, node: ast.Compare) -> None:

def visit_ClassDef(self, node: ast.ClassDef) -> None:
with self.in_class.enabled():
old_class_name = self.current_class_name
self.current_class_name = node.name
self.generic_visit(node)
self.current_class_name = old_class_name

# empty class body should contain "..." not "pass"
if len(node.body) == 1:
Expand Down Expand Up @@ -878,6 +959,13 @@ def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
if self.in_class.active:
self.check_self_typevars(node)

if _should_use_ParamSpec(node):
if self.in_class.active:
funcname = f"{self.current_class_name}.{node.name}"
else:
funcname = node.name
self.error(node, Y032.format(funcname))

def visit_arguments(self, node: ast.arguments) -> None:
self.generic_visit(node)
args = node.args[-len(node.defaults) :]
Expand Down Expand Up @@ -998,3 +1086,4 @@ def parse_options(cls, optmanager, options, extra_args) -> None:
Y029 = "Y029 Defining __repr__ or __str__ in a stub is almost always redundant"
Y030 = "Y030 Multiple Literal members in a union. {suggestion}"
Y031 = "Y031 Use class-based syntax for TypedDicts where possible"
Y032 = 'Y032 Consider using ParamSpec to annotate function "{}"'
36 changes: 36 additions & 0 deletions tests/paramspec.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import collections.abc
import typing
from typing import Any, Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")
_R = TypeVar("_R")

# GOOD FUNCTIONS
def func() -> None: ...
def func1(arg: int) -> str: ...
def func2(arg: Callable[[int], str]) -> Callable[[str], int]: ...
def func3(arg: Callable[..., str]) -> Callable[[str], str]: ...
def func4(arg: Callable[[str], str]) -> Callable[..., str]: ...
def func5(arg: Callable[[str], _R]) -> Callable[[str], _R]: ...
def func6(arg: Callable[[_T], str]) -> Callable[[_T], int]: ...
def func7(arg: Callable[_P, _R]) -> Callable[_P, _R]: ...
def func8(func: Callable[..., _R], *args: str, **kwargs: int) -> _R: ...
def func9(func: Callable[..., _R], *args: Any, **kwargs: int) -> _R: ...
def func10(func: Callable[..., _R], *args: str, **kwargs: Any) -> _R: ...
def func11(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
def func12(func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
def func13(func: typing.Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
def func14(func: collections.abc.Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> int: ...
def func15(arg: Callable[..., int]) -> Callable[..., str]: ...
def func16(arg: Callable[..., _R]) -> Callable[..., _R]: ...

# BAD FUNCTIONS
def func17(arg: Callable[..., _R], *args: Any, **kwargs: Any) -> _R: ... # Y032 Consider using ParamSpec to annotate function "func17"
def func18(arg: Callable[..., str], *args: Any, **kwargs: Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func18"
def func19(arg: collections.abc.Callable[..., str], *args: Any, **kwargs: Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func19"
def func20(arg: typing.Callable[..., str], *args: typing.Any, **kwargs: typing.Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func20"
def func21(arg: collections.abc.Callable[..., str], *args: typing.Any, **kwargs: typing.Any) -> int: ... # Y032 Consider using ParamSpec to annotate function "func21"

class Foo:
def __call__(self, func: Callable[..., _R], *args: Any, **kwargs: Any) -> _R: ... # Y032 Consider using ParamSpec to annotate function "Foo.__call__"