Skip to content

Support TypedDicts with missing keys (total=False) #3558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions extensions/mypy_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@ def _dict_new(cls, *args, **kwargs):


def _typeddict_new(cls, _typename, _fields=None, **kwargs):
total = kwargs.pop('total', True)
if _fields is None:
_fields = kwargs
elif kwargs:
raise TypeError("TypedDict takes either a dict or keyword arguments,"
" but not both")
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)})
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields),
'__total__': total})


class _TypedDictMeta(type):
def __new__(cls, name, bases, ns):
def __new__(cls, name, bases, ns, total=True):
# Create new typed dict class object.
# This method is called directly when TypedDict is subclassed,
# or via _typeddict_new when TypedDict is instantiated. This way
Expand All @@ -59,6 +61,8 @@ def __new__(cls, name, bases, ns):
for base in bases:
anns.update(base.__dict__.get('__annotations__', {}))
tp_dict.__annotations__ = anns
if not hasattr(tp_dict, '__total__'):
tp_dict.__total__ = total
return tp_dict

__instancecheck__ = __subclasscheck__ = _check_fails
Expand Down
28 changes: 14 additions & 14 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,31 +292,31 @@ def check_typeddict_call_with_dict(self, callee: TypedDictType,
def check_typeddict_call_with_kwargs(self, callee: TypedDictType,
kwargs: 'OrderedDict[str, Expression]',
context: Context) -> Type:
if callee.items.keys() != kwargs.keys():
callee_item_names = callee.items.keys()
kwargs_item_names = kwargs.keys()

if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
expected_item_names = [key for key in callee.items.keys()
if key in callee.required_keys or key in kwargs.keys()]
actual_item_names = kwargs.keys()
self.msg.typeddict_instantiated_with_unexpected_items(
expected_item_names=list(callee_item_names),
actual_item_names=list(kwargs_item_names),
expected_item_names=list(expected_item_names),
actual_item_names=list(actual_item_names),
context=context)
return AnyType()

items = OrderedDict() # type: OrderedDict[str, Type]
for (item_name, item_expected_type) in callee.items.items():
item_value = kwargs[item_name]

self.chk.check_simple_assignment(
lvalue_type=item_expected_type, rvalue=item_value, context=item_value,
msg=messages.INCOMPATIBLE_TYPES,
lvalue_name='TypedDict item "{}"'.format(item_name),
rvalue_name='expression')
if item_name in kwargs:
item_value = kwargs[item_name]
self.chk.check_simple_assignment(
lvalue_type=item_expected_type, rvalue=item_value, context=item_value,
msg=messages.INCOMPATIBLE_TYPES,
lvalue_name='TypedDict item "{}"'.format(item_name),
rvalue_name='expression')
items[item_name] = item_expected_type

mapping_value_type = join.join_type_list(list(items.values()))
fallback = self.chk.named_generic_type('typing.Mapping',
[self.chk.str_type(), mapping_value_type])
return TypedDictType(items, fallback)
return TypedDictType(items, set(callee.required_keys), fallback)

# Types and methods that can be used to infer partial types.
item_args = {'builtins.list': ['append'],
Expand Down
5 changes: 4 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,15 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
metaclass = stringify_name(metaclass_arg.value)
if metaclass is None:
metaclass = '<error>' # To be reported later
keywords = [(kw.arg, self.visit(kw.value))
for kw in n.keywords]

cdef = ClassDef(n.name,
self.as_block(n.body, n.lineno),
None,
self.translate_expr_list(n.bases),
metaclass=metaclass)
metaclass=metaclass,
keywords=keywords)
cdef.decorators = self.translate_expr_list(n.decorator_list)
self.class_nesting -= 1
return cdef
Expand Down
8 changes: 6 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,15 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
items = OrderedDict([
(item_name, s_item_type)
for (item_name, s_item_type, t_item_type) in self.s.zip(t)
if is_equivalent(s_item_type, t_item_type)
if (is_equivalent(s_item_type, t_item_type) and
(item_name in t.required_keys) == (item_name in self.s.required_keys))
])
mapping_value_type = join_type_list(list(items.values()))
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
return TypedDictType(items, fallback)
# We need to filter by items.keys() since some required keys present in both t and
# self.s might be missing from the join if the types are incompatible.
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
return TypedDictType(items, required_keys, fallback)
elif isinstance(self.s, Instance):
return join_instances(self.s, t.fallback)
else:
Expand Down
8 changes: 5 additions & 3 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:

def visit_typeddict_type(self, t: TypedDictType) -> Type:
if isinstance(self.s, TypedDictType):
for (_, l, r) in self.s.zip(t):
if not is_equivalent(l, r):
for (name, l, r) in self.s.zip(t):
if (not is_equivalent(l, r) or
(name in t.required_keys) != (name in self.s.required_keys)):
return self.default(self.s)
item_list = [] # type: List[Tuple[str, Type]]
for (item_name, s_item_type, t_item_type) in self.s.zipall(t):
Expand All @@ -266,7 +267,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
items = OrderedDict(item_list)
mapping_value_type = join_type_list(list(items.values()))
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
return TypedDictType(items, fallback)
required_keys = t.required_keys | self.s.required_keys
return TypedDictType(items, required_keys, fallback)
else:
return self.default(self.s)

Expand Down
6 changes: 5 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from abc import abstractmethod
from collections import OrderedDict

from typing import (
Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional, Callable,
Expand Down Expand Up @@ -730,6 +731,7 @@ class ClassDef(Statement):
info = None # type: TypeInfo # Related TypeInfo
metaclass = '' # type: Optional[str]
decorators = None # type: List[Expression]
keywords = None # type: OrderedDict[str, Expression]
analyzed = None # type: Optional[Expression]
has_incompatible_baseclass = False

Expand All @@ -738,13 +740,15 @@ def __init__(self,
defs: 'Block',
type_vars: List['mypy.types.TypeVarDef'] = None,
base_type_exprs: List[Expression] = None,
metaclass: str = None) -> None:
metaclass: str = None,
keywords: List[Tuple[str, Expression]] = None) -> None:
self.name = name
self.defs = defs
self.type_vars = type_vars or []
self.base_type_exprs = base_type_exprs or []
self.metaclass = metaclass
self.decorators = []
self.keywords = OrderedDict(keywords or [])

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_class_def(self)
Expand Down
27 changes: 22 additions & 5 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Plugin system for extending mypy."""

from collections import OrderedDict
from abc import abstractmethod
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar

from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr
from mypy.types import (
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType,
AnyType, TypeList, UnboundType
Expand Down Expand Up @@ -263,17 +264,26 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1):
and len(signature.variables) == 1
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
ret_type = signature.ret_type
if value_type:
default_arg = ctx.args[1][0]
if (isinstance(value_type, TypedDictType)
and isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0):
# Caller has empty dict {} as default for typed dict.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to special-case this? Assuming the value type is some TypedDict, and the default passed to get() is some dict literal, isn't the natural union type resulting from the two an appropriate TypedDict? Even if it isn't, shouldn't we use the value type as a context for inferring the type of the dict literal? ISTM that d.get('x', {'y': 1}) ought to work too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We special this so that the context will be a non-total TypedDict in case the default is {}. If we don't do that, {} won't be accepted for total TypedDicts since it has missing keys (all keys are missing).

Any subset of keys should work but this will be harder to implement and likely not very common, so I think it's okay to postpone it until later. I can create an issue to track that.

value_type = value_type.copy_modified(required_keys=set())
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = TypeVarType(signature.variables[0])
return signature.copy_modified(
arg_types=[signature.arg_types[0],
UnionType.make_simplified_union([value_type, tv])])
UnionType.make_simplified_union([value_type, tv])],
ret_type=ret_type)
return signature


Expand All @@ -288,8 +298,15 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
if value_type:
if len(ctx.arg_types) == 1:
return UnionType.make_simplified_union([value_type, NoneTyp()])
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1:
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):
# Special case '{}' as the default for a typed dict type.
return value_type.copy_modified(required_keys=set())
else:
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_item_name_not_found(ctx.type, key, ctx.context)
return AnyType()
Expand Down
Loading