Skip to content

Commit f90729e

Browse files
Initial support for to_attr inferrence in Prefetch calls (#2779)
* Initial support for `to_attr` inferrence * Remove `OrderedDict` * Only support inline `Prefetch` calls in `prefetch_related`
1 parent 59bb8a6 commit f90729e

File tree

5 files changed

+237
-1
lines changed

5 files changed

+237
-1
lines changed

mypy_django_plugin/lib/fullnames.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
1717
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
1818
RELATED_MANAGER_CLASS = "django.db.models.fields.related_descriptors.RelatedManager"
19+
PREFETCH_CLASS_FULLNAME = "django.db.models.query.Prefetch"
1920

2021
CHOICES_CLASS_FULLNAME = "django.db.models.enums.Choices"
2122
CHOICES_TYPE_METACLASS_FULLNAME = "django.db.models.enums.ChoicesType"

mypy_django_plugin/lib/helpers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
ArgKind,
1414
AssignmentStmt,
1515
Block,
16+
CallExpr,
1617
ClassDef,
1718
Context,
1819
Expression,
1920
MemberExpr,
2021
MypyFile,
2122
NameExpr,
23+
Node,
2224
RefExpr,
2325
StrExpr,
2426
SymbolNode,
@@ -39,6 +41,7 @@
3941
from mypy.typeanal import make_optional_type
4042
from mypy.types import (
4143
AnyType,
44+
CallableType,
4245
Instance,
4346
LiteralType,
4447
NoneTyp,
@@ -198,6 +201,44 @@ def get_min_argument_count(ctx: MethodContext | FunctionContext) -> int:
198201
return sum(not kind.is_star() for kinds in ctx.arg_kinds for kind in kinds)
199202

200203

204+
def _get_class_init_type(call: CallExpr) -> CallableType | None:
205+
callee_node: Node | None = call.callee
206+
207+
if isinstance(callee_node, RefExpr):
208+
callee_node = callee_node.node
209+
210+
if (
211+
isinstance(callee_node, TypeInfo)
212+
and (init_sym := callee_node.get("__init__"))
213+
and isinstance((init_type := get_proper_type(init_sym.type)), CallableType)
214+
):
215+
return init_type
216+
return None
217+
218+
219+
def get_class_init_argument_by_name(call: CallExpr, name: str) -> Expression | None:
220+
"""Adaptation of `mypy.plugins.common._get_argument` for class initializers"""
221+
callee_type = _get_class_init_type(call)
222+
if not callee_type:
223+
return None
224+
225+
argument = callee_type.argument_by_name(name)
226+
if not argument:
227+
return None
228+
assert argument.name
229+
230+
for i, (attr_name, attr_value) in enumerate(
231+
zip(call.arg_names, call.args, strict=False),
232+
start=1, # Start at one to skip first `self` arg
233+
):
234+
if argument.pos is not None and not attr_name and i == argument.pos:
235+
return attr_value
236+
if attr_name == argument.name:
237+
return attr_value
238+
239+
return None
240+
241+
201242
def get_call_argument_by_name(ctx: FunctionContext | MethodContext, name: str) -> Expression | None:
202243
"""
203244
Return the expression for the specific argument.

mypy_django_plugin/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext
160160
"filter": typecheck_filtering_method,
161161
"get": typecheck_filtering_method,
162162
"exclude": typecheck_filtering_method,
163+
"prefetch_related": partial(
164+
querysets.extract_prefetch_related_annotations, django_context=self.django_context
165+
),
163166
}
164167

165168
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], MypyType] | None:

mypy_django_plugin/transformers/querysets.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django.db.models.fields.related import RelatedField
66
from django.db.models.fields.reverse_related import ForeignObjectRel
77
from mypy.checker import TypeChecker
8-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression
8+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, CallExpr, Expression
99
from mypy.plugin import FunctionContext, MethodContext
1010
from mypy.types import AnyType, Instance, TupleType, TypedDictType, TypeOfAny, get_proper_type
1111
from mypy.types import Type as MypyType
@@ -372,3 +372,82 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
372372

373373
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()), set())
374374
return default_return_type.copy_modified(args=[model_type, row_type])
375+
376+
377+
def _infer_prefetch_element_model_type(queryset_expr: Expression | None, api: TypeChecker) -> Instance | None:
378+
"""Infer the model Instance from `Prefetch(queryset=...)`"""
379+
if queryset_expr is None:
380+
# TODO: Infer the model type from the lookup in `Prefetch(lookup=..., to_attr=...)`
381+
return None
382+
try:
383+
qs_type = get_proper_type(api.expr_checker.accept(queryset_expr))
384+
except Exception:
385+
return None
386+
if isinstance(qs_type, Instance):
387+
return _extract_model_type_from_queryset(qs_type, api)
388+
return None
389+
390+
391+
def extract_prefetch_related_annotations(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
392+
"""
393+
Extract annotated attributes via `prefetch_related(Prefetch(..., to_attr=...))`
394+
395+
See https://docs.djangoproject.com/en/5.2/ref/models/querysets/#prefetch-objects
396+
"""
397+
if not (
398+
isinstance(ctx.type, Instance)
399+
and isinstance((default_return_type := get_proper_type(ctx.default_return_type)), Instance)
400+
and (api := helpers.get_typechecker_api(ctx))
401+
and (model_type := _extract_model_type_from_queryset(ctx.type, api)) is not None
402+
and ctx.args
403+
and ctx.arg_types
404+
and ctx.arg_types[0]
405+
):
406+
return ctx.default_return_type
407+
408+
fields: dict[str, MypyType] = {}
409+
410+
for expr, typ in zip(ctx.args[0], ctx.arg_types[0], strict=False):
411+
typ = get_proper_type(typ)
412+
if not (
413+
isinstance(typ, Instance)
414+
and typ.type.fullname == fullnames.PREFETCH_CLASS_FULLNAME
415+
and isinstance(expr, CallExpr)
416+
and (to_attr_expr := helpers.get_class_init_argument_by_name(expr, "to_attr"))
417+
and (to_attr_value := helpers.resolve_string_attribute_value(to_attr_expr, django_context))
418+
):
419+
continue
420+
421+
# Determine model type from the `queryset` attr
422+
queryset_expr = helpers.get_class_init_argument_by_name(expr, "queryset")
423+
elem_model = _infer_prefetch_element_model_type(queryset_expr, api)
424+
value_type = api.named_generic_type(
425+
"builtins.list",
426+
[elem_model if elem_model is not None else AnyType(TypeOfAny.special_form)],
427+
)
428+
429+
fields[to_attr_value] = value_type
430+
431+
if not fields:
432+
return ctx.default_return_type
433+
434+
fields_dict = helpers.make_typeddict(
435+
api,
436+
fields=fields,
437+
required_keys=set(fields.keys()),
438+
readonly_keys=set(),
439+
)
440+
441+
annotated_model = get_annotated_type(api, model_type, fields_dict=fields_dict)
442+
443+
# Keep row shape; if row is a model instance, update it to annotated
444+
# Todo: consolidate with `extract_proper_type_queryset_annotate` row handling above.
445+
if len(default_return_type.args) > 1:
446+
original_row = get_proper_type(default_return_type.args[1])
447+
row_type: MypyType = original_row
448+
if isinstance(original_row, Instance) and helpers.is_model_type(original_row.type):
449+
row_type = annotated_model
450+
else:
451+
row_type = annotated_model
452+
453+
return default_return_type.copy_modified(args=[annotated_model, row_type])
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
- case: prefetch_related_to_attr
2+
main: |
3+
from myapp.models import Article, Tag
4+
from django.db.models import Prefetch
5+
from django.db import models
6+
7+
# Noop (to_attr not provided)
8+
reveal_type(Article.objects.prefetch_related(Prefetch("tags")).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
9+
reveal_type(Article.objects.prefetch_related(Prefetch("tags", Tag.objects.all())).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
10+
11+
# On the QuerySet
12+
article_qs = Article.objects.all().prefetch_related(Prefetch("tags", Tag.objects.all(), to_attr="every_tags"))
13+
reveal_type(article_qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
14+
reveal_type(article_qs.get().every_tags) # N: Revealed type is "builtins.list[myapp.models.Tag]"
15+
reveal_type(
16+
Tag.objects.all().prefetch_related( # N: Revealed type is "builtins.list[myapp.models.Article]"
17+
Prefetch("article_set", Article.objects.all(), to_attr="every_articles")
18+
).get().every_articles
19+
)
20+
21+
# On the Manager
22+
reveal_type(
23+
Article.objects.prefetch_related( # N: Revealed type is "builtins.list[myapp.models.Tag]"
24+
Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
25+
).get().every_tags
26+
)
27+
reveal_type(
28+
Tag.objects.prefetch_related( # N: Revealed type is "builtins.list[myapp.models.Article]"
29+
Prefetch("article_set", Article.objects.all(), to_attr="every_articles")
30+
).get().every_articles
31+
)
32+
33+
# Member expr
34+
reveal_type(
35+
Article.objects.prefetch_related( # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
36+
models.Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
37+
).all()
38+
)
39+
# pos args only
40+
reveal_type(
41+
Article.objects.prefetch_related( # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
42+
models.Prefetch("tags", Tag.objects.all(), "every_tags")
43+
).all()
44+
)
45+
46+
# Multiple Prefetch items: both `to_attr` annotations should be present
47+
multi_qs = Article.objects.all().prefetch_related(
48+
Prefetch("tags", Tag.objects.all(), to_attr="ts"),
49+
Prefetch("tags", Tag.objects.all(), to_attr="ts2"),
50+
)
51+
reveal_type(multi_qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'ts': builtins.list[myapp.models.Tag], 'ts2': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'ts': builtins.list[myapp.models.Tag], 'ts2': builtins.list[myapp.models.Tag]})]]"
52+
reveal_type(multi_qs.get().ts) # N: Revealed type is "builtins.list[myapp.models.Tag]"
53+
reveal_type(multi_qs.get().ts2) # N: Revealed type is "builtins.list[myapp.models.Tag]"
54+
55+
# Mixed inline `Prefetch` and plain lookup string in one call; string should be ignored
56+
mixed_plain = Article.objects.prefetch_related(
57+
"article_set",
58+
Prefetch("article_set", Article.objects.all(), to_attr="arts3"),
59+
)
60+
reveal_type(mixed_plain) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'arts3': builtins.list[myapp.models.Article]})], myapp.models.Article@AnnotatedWith[TypedDict({'arts3': builtins.list[myapp.models.Article]})]]"
61+
62+
63+
## Not Supported
64+
65+
# Prefetch with `to_attr` arg but without the `queryset` arg
66+
# TODO: We should be able to resolve a more accurate type using existing lookup `resolve_lookup_expected_type` machinery
67+
reveal_type(Article.objects.prefetch_related(models.Prefetch("tags", to_attr="just_tags")).get().just_tags) # N: Revealed type is "builtins.list[Any]"
68+
69+
# Intermediary variable -- function scope
70+
def foo() -> None:
71+
tag_prefetch = Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
72+
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
73+
74+
# Intermediary variable -- module scope
75+
tag_prefetch = Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
76+
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
77+
78+
# Mixed inline `Prefetch` and variable `Prefetch` in one call
79+
mixed_qs = Article.objects.prefetch_related(
80+
tag_prefetch,
81+
Prefetch("article_set", Article.objects.all(), to_attr="arts2"),
82+
)
83+
reveal_type(mixed_qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'arts2': builtins.list[myapp.models.Article]})], myapp.models.Article@AnnotatedWith[TypedDict({'arts2': builtins.list[myapp.models.Article]})]]"
84+
85+
installed_apps:
86+
- myapp
87+
files:
88+
- path: myapp/__init__.py
89+
- path: myapp/models.py
90+
content: |
91+
from django.db import models
92+
93+
class Tag(models.Model): ...
94+
class Article(models.Model):
95+
tags = models.ManyToManyField(to=Tag, related_name="articles", blank=True)
96+
97+
- case: prefetch_related_and_annotate
98+
main: |
99+
from django.db.models import Prefetch, F
100+
from django.contrib.auth.models import User, Group
101+
102+
user = (
103+
User.objects
104+
.annotate(annotated_user=F("username"))
105+
.prefetch_related(Prefetch("groups", Group.objects.all(), to_attr="to_attr_groups"))
106+
.get()
107+
)
108+
reveal_type(user.annotated_user) # N: Revealed type is "Any"
109+
reveal_type(user.to_attr_groups) # N: Revealed type is "builtins.list[django.contrib.auth.models.Group]"
110+
111+
installed_apps:
112+
- django.contrib.auth

0 commit comments

Comments
 (0)