diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index 82eea32a31d8..c711e0023a0f 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -30,16 +30,18 @@ def _dict_new(cls, *args, **kwargs): def _typeddict_new(cls, _typename, _fields=None, **kwargs): + total = kwargs.pop('total', True) if _fields is None: _fields = kwargs elif kwargs: raise TypeError("TypedDict takes either a dict or keyword arguments," " but not both") - return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)}) + return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields), + '__total__': total}) class _TypedDictMeta(type): - def __new__(cls, name, bases, ns): + def __new__(cls, name, bases, ns, total=True): # Create new typed dict class object. # This method is called directly when TypedDict is subclassed, # or via _typeddict_new when TypedDict is instantiated. This way @@ -59,6 +61,8 @@ def __new__(cls, name, bases, ns): for base in bases: anns.update(base.__dict__.get('__annotations__', {})) tp_dict.__annotations__ = anns + if not hasattr(tp_dict, '__total__'): + tp_dict.__total__ = total return tp_dict __instancecheck__ = __subclasscheck__ = _check_fails diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0101f8eec107..00c95cb2e6d4 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -292,31 +292,31 @@ def check_typeddict_call_with_dict(self, callee: TypedDictType, def check_typeddict_call_with_kwargs(self, callee: TypedDictType, kwargs: 'OrderedDict[str, Expression]', context: Context) -> Type: - if callee.items.keys() != kwargs.keys(): - callee_item_names = callee.items.keys() - kwargs_item_names = kwargs.keys() - + if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())): + expected_item_names = [key for key in callee.items.keys() + if key in callee.required_keys or key in kwargs.keys()] + actual_item_names = kwargs.keys() self.msg.typeddict_instantiated_with_unexpected_items( - expected_item_names=list(callee_item_names), - actual_item_names=list(kwargs_item_names), + expected_item_names=list(expected_item_names), + actual_item_names=list(actual_item_names), context=context) return AnyType() items = OrderedDict() # type: OrderedDict[str, Type] for (item_name, item_expected_type) in callee.items.items(): - item_value = kwargs[item_name] - - self.chk.check_simple_assignment( - lvalue_type=item_expected_type, rvalue=item_value, context=item_value, - msg=messages.INCOMPATIBLE_TYPES, - lvalue_name='TypedDict item "{}"'.format(item_name), - rvalue_name='expression') + if item_name in kwargs: + item_value = kwargs[item_name] + self.chk.check_simple_assignment( + lvalue_type=item_expected_type, rvalue=item_value, context=item_value, + msg=messages.INCOMPATIBLE_TYPES, + lvalue_name='TypedDict item "{}"'.format(item_name), + rvalue_name='expression') items[item_name] = item_expected_type mapping_value_type = join.join_type_list(list(items.values())) fallback = self.chk.named_generic_type('typing.Mapping', [self.chk.str_type(), mapping_value_type]) - return TypedDictType(items, fallback) + return TypedDictType(items, set(callee.required_keys), fallback) # Types and methods that can be used to infer partial types. item_args = {'builtins.list': ['append'], diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 10ad642dcdf0..58835c6de810 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -467,12 +467,15 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: metaclass = stringify_name(metaclass_arg.value) if metaclass is None: metaclass = '' # To be reported later + keywords = [(kw.arg, self.visit(kw.value)) + for kw in n.keywords] cdef = ClassDef(n.name, self.as_block(n.body, n.lineno), None, self.translate_expr_list(n.bases), - metaclass=metaclass) + metaclass=metaclass, + keywords=keywords) cdef.decorators = self.translate_expr_list(n.decorator_list) self.class_nesting -= 1 return cdef diff --git a/mypy/join.py b/mypy/join.py index aaaa99fa3798..0ae8c3ab4058 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -228,11 +228,15 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: items = OrderedDict([ (item_name, s_item_type) for (item_name, s_item_type, t_item_type) in self.s.zip(t) - if is_equivalent(s_item_type, t_item_type) + if (is_equivalent(s_item_type, t_item_type) and + (item_name in t.required_keys) == (item_name in self.s.required_keys)) ]) mapping_value_type = join_type_list(list(items.values())) fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type) - return TypedDictType(items, fallback) + # We need to filter by items.keys() since some required keys present in both t and + # self.s might be missing from the join if the types are incompatible. + required_keys = set(items.keys()) & t.required_keys & self.s.required_keys + return TypedDictType(items, required_keys, fallback) elif isinstance(self.s, Instance): return join_instances(self.s, t.fallback) else: diff --git a/mypy/meet.py b/mypy/meet.py index 62940b08d62b..f0dcd8b56e34 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -252,8 +252,9 @@ def visit_tuple_type(self, t: TupleType) -> Type: def visit_typeddict_type(self, t: TypedDictType) -> Type: if isinstance(self.s, TypedDictType): - for (_, l, r) in self.s.zip(t): - if not is_equivalent(l, r): + for (name, l, r) in self.s.zip(t): + if (not is_equivalent(l, r) or + (name in t.required_keys) != (name in self.s.required_keys)): return self.default(self.s) item_list = [] # type: List[Tuple[str, Type]] for (item_name, s_item_type, t_item_type) in self.s.zipall(t): @@ -266,7 +267,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: items = OrderedDict(item_list) mapping_value_type = join_type_list(list(items.values())) fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type) - return TypedDictType(items, fallback) + required_keys = t.required_keys | self.s.required_keys + return TypedDictType(items, required_keys, fallback) else: return self.default(self.s) diff --git a/mypy/nodes.py b/mypy/nodes.py index 1b58997b65a2..8aec3d6bba3c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2,6 +2,7 @@ import os from abc import abstractmethod +from collections import OrderedDict from typing import ( Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional, Callable, @@ -730,6 +731,7 @@ class ClassDef(Statement): info = None # type: TypeInfo # Related TypeInfo metaclass = '' # type: Optional[str] decorators = None # type: List[Expression] + keywords = None # type: OrderedDict[str, Expression] analyzed = None # type: Optional[Expression] has_incompatible_baseclass = False @@ -738,13 +740,15 @@ def __init__(self, defs: 'Block', type_vars: List['mypy.types.TypeVarDef'] = None, base_type_exprs: List[Expression] = None, - metaclass: str = None) -> None: + metaclass: str = None, + keywords: List[Tuple[str, Expression]] = None) -> None: self.name = name self.defs = defs self.type_vars = type_vars or [] self.base_type_exprs = base_type_exprs or [] self.metaclass = metaclass self.decorators = [] + self.keywords = OrderedDict(keywords or []) def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_class_def(self) diff --git a/mypy/plugin.py b/mypy/plugin.py index f94790a06e96..37e516bc5030 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -1,9 +1,10 @@ """Plugin system for extending mypy.""" +from collections import OrderedDict from abc import abstractmethod from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar -from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context +from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr from mypy.types import ( Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType, AnyType, TypeList, UnboundType @@ -263,17 +264,26 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: and len(ctx.args[0]) == 1 and isinstance(ctx.args[0][0], StrExpr) and len(signature.arg_types) == 2 - and len(signature.variables) == 1): + and len(signature.variables) == 1 + and len(ctx.args[1]) == 1): key = ctx.args[0][0].value value_type = ctx.type.items.get(key) + ret_type = signature.ret_type if value_type: + default_arg = ctx.args[1][0] + if (isinstance(value_type, TypedDictType) + and isinstance(default_arg, DictExpr) + and len(default_arg.items) == 0): + # Caller has empty dict {} as default for typed dict. + value_type = value_type.copy_modified(required_keys=set()) # Tweak the signature to include the value type as context. It's # only needed for type inference since there's a union with a type # variable that accepts everything. tv = TypeVarType(signature.variables[0]) return signature.copy_modified( arg_types=[signature.arg_types[0], - UnionType.make_simplified_union([value_type, tv])]) + UnionType.make_simplified_union([value_type, tv])], + ret_type=ret_type) return signature @@ -288,8 +298,15 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: if value_type: if len(ctx.arg_types) == 1: return UnionType.make_simplified_union([value_type, NoneTyp()]) - elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1: - return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) + elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 + and len(ctx.args[1]) == 1): + default_arg = ctx.args[1][0] + if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 + and isinstance(value_type, TypedDictType)): + # Special case '{}' as the default for a typed dict type. + return value_type.copy_modified(required_keys=set()) + else: + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) else: ctx.api.msg.typeddict_item_name_not_found(ctx.type, key, ctx.context) return AnyType() diff --git a/mypy/semanal.py b/mypy/semanal.py index a922be0568dc..c6749d25a92c 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -666,6 +666,7 @@ def visit_class_def(self, defn: ClassDef) -> None: def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: with self.tvar_scope_frame(self.tvar_scope.class_frame()): self.clean_up_bases_and_infer_type_variables(defn) + self.analyze_class_keywords(defn) if self.analyze_typeddict_classdef(defn): yield False return @@ -715,6 +716,10 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: self.leave_class() + def analyze_class_keywords(self, defn: ClassDef) -> None: + for value in defn.keywords.values(): + value.accept(self) + def enter_class(self, info: TypeInfo) -> None: # Remember previous active class self.type_stack.append(self.type) @@ -1213,8 +1218,8 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: isinstance(defn.base_type_exprs[0], RefExpr) and defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): # Building a new TypedDict - fields, types = self.check_typeddict_classdef(defn) - info = self.build_typeddict_typeinfo(defn.name, fields, types) + fields, types, required_keys = self.check_typeddict_classdef(defn) + info = self.build_typeddict_typeinfo(defn.name, fields, types, required_keys) node.node = info defn.analyzed = TypedDictExpr(info) return True @@ -1224,38 +1229,43 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: not self.is_typeddict(expr) for expr in defn.base_type_exprs): self.fail("All bases of a new TypedDict must be TypedDict types", defn) typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) - newfields = [] # type: List[str] - newtypes = [] # type: List[Type] - tpdict = None # type: OrderedDict[str, Type] + keys = [] # type: List[str] + types = [] + required_keys = set() for base in typeddict_bases: assert isinstance(base, RefExpr) assert isinstance(base.node, TypeInfo) assert isinstance(base.node.typeddict_type, TypedDictType) - tpdict = base.node.typeddict_type.items - newdict = tpdict.copy() - for key in tpdict: - if key in newfields: + base_typed_dict = base.node.typeddict_type + base_items = base_typed_dict.items + valid_items = base_items.copy() + for key in base_items: + if key in keys: self.fail('Cannot overwrite TypedDict field "{}" while merging' .format(key), defn) - newdict.pop(key) - newfields.extend(newdict.keys()) - newtypes.extend(newdict.values()) - fields, types = self.check_typeddict_classdef(defn, newfields) - newfields.extend(fields) - newtypes.extend(types) - info = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) + valid_items.pop(key) + keys.extend(valid_items.keys()) + types.extend(valid_items.values()) + required_keys.update(base_typed_dict.required_keys) + new_keys, new_types, new_required_keys = self.check_typeddict_classdef(defn, keys) + keys.extend(new_keys) + types.extend(new_types) + required_keys.update(new_required_keys) + info = self.build_typeddict_typeinfo(defn.name, keys, types, required_keys) node.node = info defn.analyzed = TypedDictExpr(info) return True return False def check_typeddict_classdef(self, defn: ClassDef, - oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: + oldfields: List[str] = None) -> Tuple[List[str], + List[Type], + Set[str]]: TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' 'expected "field_name: field_type"') if self.options.python_version < (3, 6): self.fail('TypedDict class syntax is only supported in Python 3.6', defn) - return [], [] + return [], [], set() fields = [] # type: List[str] types = [] # type: List[Type] for stmt in defn.defs.body: @@ -1286,7 +1296,14 @@ def check_typeddict_classdef(self, defn: ClassDef, elif not isinstance(stmt.rvalue, TempNode): # x: int assigns rvalue to TempNode(AnyType()) self.fail('Right hand side values are not supported in TypedDict', stmt) - return fields, types + total = True + if 'total' in defn.keywords: + total = self.parse_bool(defn.keywords['total']) + if total is None: + self.fail('Value of "total" must be True or False', defn) + total = True + required_keys = set(fields) if total else set() + return fields, types, required_keys def visit_import(self, i: Import) -> None: for id, as_id in i.ids: @@ -2320,46 +2337,65 @@ def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[Ty fullname = callee.fullname if fullname != 'mypy_extensions.TypedDict': return None - items, types, ok = self.parse_typeddict_args(call, fullname) + items, types, total, ok = self.parse_typeddict_args(call, fullname) if not ok: # Error. Construct dummy return value. - return self.build_typeddict_typeinfo('TypedDict', [], []) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_typeddict_typeinfo(name, items, types) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) - if self.type: - self.type.names[name] = stnode + info = self.build_typeddict_typeinfo('TypedDict', [], [], set()) else: - self.globals[name] = stnode + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + required_keys = set(items) if total else set() + info = self.build_typeddict_typeinfo(name, items, types, required_keys) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) + if self.type: + self.type.names[name] = stnode + else: + self.globals[name] = stnode call.analyzed = TypedDictExpr(info) call.analyzed.set_line(call.line, call.column) return info def parse_typeddict_args(self, call: CallExpr, - fullname: str) -> Tuple[List[str], List[Type], bool]: + fullname: str) -> Tuple[List[str], List[Type], bool, bool]: # TODO: Share code with check_argument_count in checkexpr.py? args = call.args if len(args) < 2: return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) - if len(args) > 2: + if len(args) > 3: return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) # TODO: Support keyword arguments - if call.arg_kinds != [ARG_POS, ARG_POS]: + if call.arg_kinds not in ([ARG_POS, ARG_POS], [ARG_POS, ARG_POS, ARG_NAMED]): return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) + if len(args) == 3 and call.arg_names[2] != 'total': + return self.fail_typeddict_arg( + 'Unexpected keyword argument "{}" for "TypedDict"'.format(call.arg_names[2]), call) if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): return self.fail_typeddict_arg( "TypedDict() expects a string literal as the first argument", call) if not isinstance(args[1], DictExpr): return self.fail_typeddict_arg( "TypedDict() expects a dictionary literal as the second argument", call) + total = True + if len(args) == 3: + total = self.parse_bool(call.args[2]) + if total is None: + return self.fail_typeddict_arg( + 'TypedDict() "total" argument must be True or False', call) dictexpr = args[1] items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items, call) - return items, types, ok + return items, types, total, ok + + def parse_bool(self, expr: Expression) -> Optional[bool]: + if isinstance(expr, NameExpr): + if expr.fullname == 'builtins.True': + return True + if expr.fullname == 'builtins.False': + return False + return None def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], context: Context) -> Tuple[List[str], List[Type], bool]: @@ -2369,21 +2405,24 @@ def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, E if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): items.append(field_name_expr.value) else: - return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) + self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) + return [], [], False try: type = expr_to_unanalyzed_type(field_type_expr) except TypeTranslationError: - return self.fail_typeddict_arg('Invalid field type', field_type_expr) + self.fail_typeddict_arg('Invalid field type', field_type_expr) + return [], [], False types.append(self.anal_type(type)) return items, types, True def fail_typeddict_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: + context: Context) -> Tuple[List[str], List[Type], bool, bool]: self.fail(message, context) - return [], [], False + return [], [], True, False def build_typeddict_typeinfo(self, name: str, items: List[str], - types: List[Type]) -> TypeInfo: + types: List[Type], + required_keys: Set[str]) -> TypeInfo: fallback = (self.named_type_or_none('typing.Mapping', [self.str_type(), self.object_type()]) or self.object_type()) @@ -2398,8 +2437,8 @@ def patch() -> None: self.patches.append(patch) info = self.basic_new_typeinfo(name, fallback) - info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback) - + info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys, + fallback) return info def check_classvar(self, s: AssignmentStmt) -> None: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index b03843fba9a4..ebd6a1d13d3c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -228,9 +228,20 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: elif isinstance(right, TypedDictType): if not left.names_are_wider_than(right): return False - for (_, l, r) in left.zip(right): + for name, l, r in left.zip(right): if not is_equivalent(l, r, self.check_type_parameter): return False + # Non-required key is not compatible with a required key since + # indexing may fail unexpectedly if a required key is missing. + # Required key is not compatible with a non-required key since + # the prior doesn't support 'del' but the latter should support + # it. + # + # NOTE: 'del' support is currently not implemented (#3550). We + # don't want to have to change subtyping after 'del' support + # lands so here we are anticipating that change. + if (name in left.required_keys) != (name in right.required_keys): + return False # (NOTE: Fallbacks don't matter.) return True else: diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index af3916f98e19..2203cf814f00 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -37,6 +37,10 @@ class Point2D(TypedDict): y: int class LabelPoint2D(Point2D, Label): ... + +class Options(TypedDict, total=False): + log_level: int + log_path: str """ if PY36: @@ -58,6 +62,7 @@ def test_basics_iterable_syntax(self): self.assertEqual(Emp.__module__, 'mypy.test.testextensions') self.assertEqual(Emp.__bases__, (dict,)) self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) + self.assertEqual(Emp.__total__, True) def test_basics_keywords_syntax(self): Emp = TypedDict('Emp', name=str, id=int) @@ -72,6 +77,7 @@ def test_basics_keywords_syntax(self): self.assertEqual(Emp.__module__, 'mypy.test.testextensions') self.assertEqual(Emp.__bases__, (dict,)) self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) + self.assertEqual(Emp.__total__, True) def test_typeddict_errors(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) @@ -94,6 +100,7 @@ def test_typeddict_errors(self): def test_py36_class_syntax_usage(self): self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa + self.assertEqual(LabelPoint2D.__total__, True) # noqa self.assertNotIsSubclass(LabelPoint2D, typing.Sequence) # noqa not_origin = Point2D(x=0, y=1) # noqa self.assertEqual(not_origin['x'], 0) @@ -120,6 +127,17 @@ def test_optional(self): self.assertEqual(typing.Optional[EmpD], typing.Union[None, EmpD]) self.assertNotEqual(typing.List[EmpD], typing.Tuple[EmpD]) + def test_total(self): + D = TypedDict('D', {'x': int}, total=False) + self.assertEqual(D(), {}) + self.assertEqual(D(x=1), {'x': 1}) + self.assertEqual(D.__total__, False) + + if PY36: + self.assertEqual(Options(), {}) # noqa + self.assertEqual(Options(log_level=2), {'log_level': 2}) # noqa + self.assertEqual(Options.__total__, False) # noqa + if __name__ == '__main__': main() diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 4dc6dc1bd02f..00b5c9fc52c9 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -375,7 +375,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: (item_name, self.anal_type(item_type)) for (item_name, item_type) in t.items.items() ]) - return TypedDictType(items, t.fallback) + return TypedDictType(items, set(t.required_keys), t.fallback) def visit_star_type(self, t: StarType) -> Type: return StarType(self.anal_type(t.type), t.line) diff --git a/mypy/types.py b/mypy/types.py index 5f5db31e377e..d9ae8bf5fc4d 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -900,12 +900,14 @@ class TypedDictType(Type): whose TypeInfo has a typeddict_type that is anonymous. """ - items = None # type: OrderedDict[str, Type] # (item_name, item_type) + items = None # type: OrderedDict[str, Type] # item_name -> item_type + required_keys = None # type: Set[str] fallback = None # type: Instance - def __init__(self, items: 'OrderedDict[str, Type]', fallback: Instance, - line: int = -1, column: int = -1) -> None: + def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str], + fallback: Instance, line: int = -1, column: int = -1) -> None: self.items = items + self.required_keys = required_keys self.fallback = fallback self.can_be_true = len(self.items) > 0 self.can_be_false = len(self.items) == 0 @@ -917,6 +919,7 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T: def serialize(self) -> JsonDict: return {'.class': 'TypedDictType', 'items': [[n, t.serialize()] for (n, t) in self.items.items()], + 'required_keys': sorted(self.required_keys), 'fallback': self.fallback.serialize(), } @@ -925,6 +928,7 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType': assert data['.class'] == 'TypedDictType' return TypedDictType(OrderedDict([(n, deserialize_type(t)) for (n, t) in data['items']]), + set(data['required_keys']), Instance.deserialize(data['fallback'])) def as_anonymous(self) -> 'TypedDictType': @@ -934,14 +938,17 @@ def as_anonymous(self) -> 'TypedDictType': return self.fallback.type.typeddict_type.as_anonymous() def copy_modified(self, *, fallback: Instance = None, - item_types: List[Type] = None) -> 'TypedDictType': + item_types: List[Type] = None, + required_keys: Set[str] = None) -> 'TypedDictType': if fallback is None: fallback = self.fallback if item_types is None: items = self.items else: items = OrderedDict(zip(self.items, item_types)) - return TypedDictType(items, fallback, self.line, self.column) + if required_keys is None: + required_keys = self.required_keys + return TypedDictType(items, required_keys, fallback, self.line, self.column) def create_anonymous_fallback(self, *, value_type: Type) -> Instance: anonymous = self.as_anonymous() @@ -1350,6 +1357,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: for (item_name, item_type) in t.items.items() ]) return TypedDictType(items, + t.required_keys, # TODO: This appears to be unsafe. cast(Any, t.fallback.accept(self)), t.line, t.column) @@ -1495,11 +1503,17 @@ def visit_tuple_type(self, t: TupleType) -> str: def visit_typeddict_type(self, t: TypedDictType) -> str: s = self.keywords_str(t.items.items()) + if t.required_keys == set(t.items): + keys_str = '' + elif t.required_keys == set(): + keys_str = ', _total=False' + else: + keys_str = ', _required_keys=[{}]'.format(', '.join(sorted(t.required_keys))) if t.fallback and t.fallback.type: if s == '': - return 'TypedDict(_fallback={})'.format(t.fallback.accept(self)) + return 'TypedDict(_fallback={}{})'.format(t.fallback.accept(self), keys_str) else: - return 'TypedDict({}, _fallback={})'.format(s, t.fallback.accept(self)) + return 'TypedDict({}, _fallback={}{})'.format(s, t.fallback.accept(self), keys_str) return 'TypedDict({})'.format(s) def visit_star_type(self, t: StarType) -> str: diff --git a/test-data/unit/check-serialize.test b/test-data/unit/check-serialize.test index 9576d95c7eef..bf0faebe5880 100644 --- a/test-data/unit/check-serialize.test +++ b/test-data/unit/check-serialize.test @@ -1026,6 +1026,19 @@ main:2: error: Revealed type is 'TypedDict(x=builtins.int, _fallback=typing.Mapp main:3: error: Revealed type is 'TypedDict(x=builtins.int, _fallback=ntcrash.C.A@4)' main:4: error: Revealed type is 'def () -> ntcrash.C.A@4' +[case testSerializeNonTotalTypedDict] +from m import d +reveal_type(d) +[file m.py] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}, total=False) +d: D +[builtins fixtures/dict.pyi] +[out1] +main:2: error: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=m.D, _total=False)' +[out2] +main:2: error: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=m.D, _total=False)' + -- -- Modules -- diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 6a16d13bc468..d50cf27344cc 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -520,6 +520,40 @@ def g(x: X, y: I) -> None: pass reveal_type(f(g)) # E: Revealed type is '' [builtins fixtures/dict.pyi] +[case testMeetOfTypedDictsWithNonTotal] +from mypy_extensions import TypedDict +from typing import TypeVar, Callable +XY = TypedDict('XY', {'x': int, 'y': int}, total=False) +YZ = TypedDict('YZ', {'y': int, 'z': int}, total=False) +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: XY, y: YZ) -> None: pass +reveal_type(f(g)) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, z=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int], _total=False)' +[builtins fixtures/dict.pyi] + +[case testMeetOfTypedDictsWithNonTotalAndTotal] +from mypy_extensions import TypedDict +from typing import TypeVar, Callable +XY = TypedDict('XY', {'x': int}, total=False) +YZ = TypedDict('YZ', {'y': int, 'z': int}) +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: XY, y: YZ) -> None: pass +reveal_type(f(g)) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, z=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int], _required_keys=[y, z])' +[builtins fixtures/dict.pyi] + +[case testMeetOfTypedDictsWithIncompatibleNonTotalAndTotal] +# flags: --strict-optional +from mypy_extensions import TypedDict +from typing import TypeVar, Callable +XY = TypedDict('XY', {'x': int, 'y': int}, total=False) +YZ = TypedDict('YZ', {'y': int, 'z': int}) +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: XY, y: YZ) -> None: pass +reveal_type(f(g)) # E: Revealed type is '' +[builtins fixtures/dict.pyi] + -- Constraint Solver @@ -802,7 +836,6 @@ D = TypedDict('D', {'x': int, 'y': str}) E = TypedDict('E', {'d': D}) p = E(d=D(x=0, y='')) reveal_type(p.get('d', {'x': 1, 'y': ''})) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=__main__.D)' -p.get('d', {}) # E: Expected TypedDict keys ('x', 'y') but found no keys [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] @@ -814,6 +847,190 @@ p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str") [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] +[case testTypedDictChainedGetWithEmptyDictDefault] +# flags: --strict-optional +from mypy_extensions import TypedDict +C = TypedDict('C', {'a': int}) +D = TypedDict('D', {'x': C, 'y': str}) +d: D +reveal_type(d.get('x', {})) \ + # E: Revealed type is 'TypedDict(a=builtins.int, _fallback=__main__.C, _total=False)' +reveal_type(d.get('x', None)) \ + # E: Revealed type is 'Union[TypedDict(a=builtins.int, _fallback=__main__.C), builtins.None]' +reveal_type(d.get('x', {}).get('a')) # E: Revealed type is 'Union[builtins.int, builtins.None]' +reveal_type(d.get('x', {})['a']) # E: Revealed type is 'builtins.int' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + + +-- Totality (the "total" keyword argument) + +[case testTypedDictWithTotalTrue] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}, total=True) +d: D +reveal_type(d) \ + # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=__main__.D)' +[builtins fixtures/dict.pyi] + +[case testTypedDictWithInvalidTotalArgument] +from mypy_extensions import TypedDict +A = TypedDict('A', {'x': int}, total=0) # E: TypedDict() "total" argument must be True or False +B = TypedDict('B', {'x': int}, total=bool) # E: TypedDict() "total" argument must be True or False +C = TypedDict('C', {'x': int}, x=False) # E: Unexpected keyword argument "x" for "TypedDict" +D = TypedDict('D', {'x': int}, False) # E: Unexpected arguments to TypedDict() +[builtins fixtures/dict.pyi] + +[case testTypedDictWithTotalFalse] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}, total=False) +def f(d: D) -> None: + reveal_type(d) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=__main__.D, _total=False)' +f({}) +f({'x': 1}) +f({'y': ''}) +f({'x': 1, 'y': ''}) +f({'x': 1, 'z': ''}) # E: Expected TypedDict key 'x' but found keys ('x', 'z') +f({'x': ''}) # E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") +[builtins fixtures/dict.pyi] + +[case testTypedDictConstructorWithTotalFalse] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}, total=False) +def f(d: D) -> None: pass +reveal_type(D()) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=typing.Mapping[builtins.str, builtins.object], _total=False)' +reveal_type(D(x=1)) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=typing.Mapping[builtins.str, builtins.object], _total=False)' +f(D(y='')) +f(D(x=1, y='')) +f(D(x=1, z='')) # E: Expected TypedDict key 'x' but found keys ('x', 'z') +f(D(x='')) # E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") +[builtins fixtures/dict.pyi] + +[case testTypedDictIndexingWithNonRequiredKey] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}, total=False) +d: D +reveal_type(d['x']) # E: Revealed type is 'builtins.int' +reveal_type(d['y']) # E: Revealed type is 'builtins.str' +reveal_type(d.get('x')) # E: Revealed type is 'builtins.int' +reveal_type(d.get('y')) # E: Revealed type is 'builtins.str' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictSubtypingWithTotalFalse] +from mypy_extensions import TypedDict +A = TypedDict('A', {'x': int}) +B = TypedDict('B', {'x': int}, total=False) +C = TypedDict('C', {'x': int, 'y': str}, total=False) +def fa(a: A) -> None: pass +def fb(b: B) -> None: pass +def fc(c: C) -> None: pass +a: A +b: B +c: C +fb(b) +fc(c) +fb(c) +fb(a) # E: Argument 1 to "fb" has incompatible type "A"; expected "B" +fa(b) # E: Argument 1 to "fa" has incompatible type "B"; expected "A" +fc(b) # E: Argument 1 to "fc" has incompatible type "B"; expected "C" +[builtins fixtures/dict.pyi] + +[case testTypedDictJoinWithTotalFalse] +from typing import TypeVar +from mypy_extensions import TypedDict +A = TypedDict('A', {'x': int}) +B = TypedDict('B', {'x': int}, total=False) +C = TypedDict('C', {'x': int, 'y': str}, total=False) +T = TypeVar('T') +def j(x: T, y: T) -> T: return x +a: A +b: B +c: C +reveal_type(j(a, b)) \ + # E: Revealed type is 'TypedDict(_fallback=typing.Mapping[builtins.str, ])' +reveal_type(j(b, b)) \ + # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int], _total=False)' +reveal_type(j(c, c)) \ + # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=typing.Mapping[builtins.str, builtins.object], _total=False)' +reveal_type(j(b, c)) \ + # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int], _total=False)' +reveal_type(j(c, b)) \ + # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=typing.Mapping[builtins.str, builtins.int], _total=False)' +[builtins fixtures/dict.pyi] + +[case testTypedDictClassWithTotalArgument] +from mypy_extensions import TypedDict +class D(TypedDict, total=False): + x: int + y: str +d: D +reveal_type(d) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=__main__.D, _total=False)' +[builtins fixtures/dict.pyi] + +[case testTypedDictClassWithInvalidTotalArgument] +from mypy_extensions import TypedDict +class D(TypedDict, total=1): # E: Value of "total" must be True or False + x: int +class E(TypedDict, total=bool): # E: Value of "total" must be True or False + x: int +class F(TypedDict, total=xyz): # E: Value of "total" must be True or False \ + # E: Name 'xyz' is not defined + x: int +[builtins fixtures/dict.pyi] + +[case testTypedDictClassInheritanceWithTotalArgument] +from mypy_extensions import TypedDict +class A(TypedDict): + x: int +class B(TypedDict, A, total=False): + y: int +class C(TypedDict, B, total=True): + z: str +c: C +reveal_type(c) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.int, z=builtins.str, _fallback=__main__.C, _required_keys=[x, z])' +[builtins fixtures/dict.pyi] + + +-- Create Type (Errors) + +[case testCannotCreateTypedDictTypeWithTooFewArguments] +from mypy_extensions import TypedDict +Point = TypedDict('Point') # E: Too few arguments for TypedDict() +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictTypeWithTooManyArguments] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': int, 'y': int}, dict) # E: Unexpected arguments to TypedDict() +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictTypeWithInvalidName] +from mypy_extensions import TypedDict +Point = TypedDict(dict, {'x': int, 'y': int}) # E: TypedDict() expects a string literal as the first argument +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictTypeWithInvalidItems] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x'}) # E: TypedDict() expects a dictionary literal as the second argument +[builtins fixtures/dict.pyi] + +-- NOTE: The following code works at runtime but is not yet supported by mypy. +-- Keyword arguments may potentially be supported in the future. +[case testCannotCreateTypedDictTypeWithNonpositionalArgs] +from mypy_extensions import TypedDict +Point = TypedDict(typename='Point', fields={'x': int, 'y': int}) # E: Unexpected arguments to TypedDict() +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictTypeWithInvalidItemName] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {int: int, int: int}) # E: Invalid TypedDict() field name +[builtins fixtures/dict.pyi] + +[case testCannotCreateTypedDictTypeWithInvalidItemType] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type +[builtins fixtures/dict.pyi] + -- Special cases diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index fa540b99f4cd..a604c9684eeb 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -16,6 +16,6 @@ def VarArg(type: _T = ...) -> _T: ... def KwArg(type: _T = ...) -> _T: ... -def TypedDict(typename: str, fields: Dict[str, Type[_T]]) -> Type[dict]: ... +def TypedDict(typename: str, fields: Dict[str, Type[_T]], *, total: Any = ...) -> Type[dict]: ... class NoReturn: pass diff --git a/test-data/unit/semanal-typeddict.test b/test-data/unit/semanal-typeddict.test index ab6c428752d6..9c1454e49a06 100644 --- a/test-data/unit/semanal-typeddict.test +++ b/test-data/unit/semanal-typeddict.test @@ -34,43 +34,3 @@ MypyFile:1( AssignmentStmt:2( NameExpr(Point* [__main__.Point]) TypedDictExpr:2(Point))) - - --- Create Type (Errors) - -[case testCannotCreateTypedDictTypeWithTooFewArguments] -from mypy_extensions import TypedDict -Point = TypedDict('Point') # E: Too few arguments for TypedDict() -[builtins fixtures/dict.pyi] - -[case testCannotCreateTypedDictTypeWithTooManyArguments] -from mypy_extensions import TypedDict -Point = TypedDict('Point', {'x': int, 'y': int}, dict) # E: Too many arguments for TypedDict() -[builtins fixtures/dict.pyi] - -[case testCannotCreateTypedDictTypeWithInvalidName] -from mypy_extensions import TypedDict -Point = TypedDict(dict, {'x': int, 'y': int}) # E: TypedDict() expects a string literal as the first argument -[builtins fixtures/dict.pyi] - -[case testCannotCreateTypedDictTypeWithInvalidItems] -from mypy_extensions import TypedDict -Point = TypedDict('Point', {'x'}) # E: TypedDict() expects a dictionary literal as the second argument -[builtins fixtures/dict.pyi] - --- NOTE: The following code works at runtime but is not yet supported by mypy. --- Keyword arguments may potentially be supported in the future. -[case testCannotCreateTypedDictTypeWithNonpositionalArgs] -from mypy_extensions import TypedDict -Point = TypedDict(typename='Point', fields={'x': int, 'y': int}) # E: Unexpected arguments to TypedDict() -[builtins fixtures/dict.pyi] - -[case testCannotCreateTypedDictTypeWithInvalidItemName] -from mypy_extensions import TypedDict -Point = TypedDict('Point', {int: int, int: int}) # E: Invalid TypedDict() field name -[builtins fixtures/dict.pyi] - -[case testCannotCreateTypedDictTypeWithInvalidItemType] -from mypy_extensions import TypedDict -Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type -[builtins fixtures/dict.pyi]