Skip to content

Refactor plugin system and special case TypedDict get and int.__pow__ #3501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 7, 2017
4 changes: 3 additions & 1 deletion mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from mypy.stats import dump_type_stats
from mypy.types import Type
from mypy.version import __version__
from mypy.plugin import DefaultPlugin


# We need to know the location of this file to load data, but
Expand Down Expand Up @@ -1505,8 +1506,9 @@ def type_check_first_pass(self) -> None:
if self.options.semantic_analysis_only:
return
with self.wrap_context():
plugin = DefaultPlugin(self.options.python_version)
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
self.tree, self.xpath)
self.tree, self.xpath, plugin)
self.type_checker.check_first_pass()

def type_check_second_pass(self) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from mypy.binder import ConditionalTypeBinder, get_declaration
from mypy.meet import is_overlapping_types
from mypy.options import Options
from mypy.plugin import Plugin

from mypy import experiments

Expand Down Expand Up @@ -127,8 +128,12 @@ class TypeChecker(NodeVisitor[None]):
# directly or indirectly.
module_refs = None # type: Set[str]

# Plugin that provides special type checking rules for specific library
# functions such as open(), etc.
plugin = None # type: Plugin

def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Options,
tree: MypyFile, path: str) -> None:
tree: MypyFile, path: str, plugin: Plugin) -> None:
"""Construct a type checker.

Use errors to report type check errors.
Expand All @@ -139,7 +144,8 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
self.tree = tree
self.path = path
self.msg = MessageBuilder(errors, modules)
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg)
self.plugin = plugin
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg, self.plugin)
self.scope = Scope(tree)
self.binder = ConditionalTypeBinder()
self.globals = tree.names
Expand Down
160 changes: 131 additions & 29 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback
from mypy.plugin import Plugin, PluginContext, MethodSignatureHook
from mypy.typeanal import make_optional_type

from mypy import experiments
Expand Down Expand Up @@ -105,17 +105,18 @@ class ExpressionChecker(ExpressionVisitor[Type]):
type_context = None # type: List[Optional[Type]]

strfrm_checker = None # type: StringFormatterChecker
function_plugins = None # type: Dict[str, PluginCallback]
plugin = None # type: Plugin

def __init__(self,
chk: 'mypy.checker.TypeChecker',
msg: MessageBuilder) -> None:
msg: MessageBuilder,
plugin: Plugin) -> None:
"""Construct an expression type checker."""
self.chk = chk
self.msg = msg
self.plugin = plugin
self.type_context = [None]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -208,11 +209,33 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
isinstance(callee_type, CallableType)
and callee_type.implicit):
return self.msg.untyped_function_call(callee_type, e)
# Figure out the full name of the callee for plugin loopup.
object_type = None
if not isinstance(e.callee, RefExpr):
fullname = None
else:
fullname = e.callee.fullname
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
if (fullname is None
and isinstance(e.callee, MemberExpr)
and isinstance(callee_type, FunctionLike)):
# For method calls we include the defining class for the method
# in the full name (example: 'typing.Mapping.get').
callee_expr_type = self.chk.type_map.get(e.callee.expr)
info = None
# TODO: Support fallbacks of other kinds of types as well?
if isinstance(callee_expr_type, Instance):
info = callee_expr_type.type
elif isinstance(callee_expr_type, TypedDictType):
info = callee_expr_type.fallback.type.get_containing_type_info(e.callee.name)
if info:
fullname = '{}.{}'.format(info.fullname(), e.callee.name)
object_type = callee_expr_type
# Apply plugin signature hook that may generate a better signature.
signature_hook = self.plugin.get_method_signature_hook(fullname)
if signature_hook:
callee_type = self.apply_method_signature_hook(
e, callee_type, object_type, signature_hook)
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type)
if isinstance(ret_type, UninhabitedType):
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
Expand Down Expand Up @@ -351,8 +374,10 @@ def apply_function_plugin(self,
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: Optional[str]) -> Type:
"""Use special case logic to infer the return type for of a particular named function.
fullname: Optional[str],
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.

Return the inferred return type.
"""
Expand All @@ -362,41 +387,90 @@ def apply_function_plugin(self,
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
return self.function_plugins[fullname](
formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type)

def check_call_expr_with_callee_type(self, callee_type: Type,
e: CallExpr, callable_name: Optional[str]) -> Type:
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(formal_arg_types, formal_arg_exprs, inferred_ret_type,
self.chk.named_generic_type)
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(object_type, formal_arg_types, formal_arg_exprs,
inferred_ret_type, self.create_plugin_context(context))

def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_type: Type,
signature_hook: MethodSignatureHook) -> FunctionLike:
"""Apply a plugin hook that may infer a more precise signature for a method."""
if isinstance(callee, CallableType):
arg_kinds = e.arg_kinds
arg_names = e.arg_names
args = e.args
num_formals = len(callee.arg_kinds)
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
lambda i: self.accept(args[i]))
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_exprs[formal].append(args[actual])
return signature_hook(object_type, formal_arg_exprs, callee,
self.chk.named_generic_type)
else:
assert isinstance(callee, Overloaded)
items = []
for item in callee.items():
adjusted = self.apply_method_signature_hook(e, item, object_type, signature_hook)
assert isinstance(adjusted, CallableType)
items.append(adjusted)
return Overloaded(items)

def create_plugin_context(self, context: Context) -> PluginContext:
return PluginContext(self.chk.named_generic_type, self.msg, context)

def check_call_expr_with_callee_type(self,
callee_type: Type,
e: CallExpr,
callable_name: Optional[str],
object_type: Optional[Type]) -> Type:
"""Type check call expression.

The given callee type overrides the type of the callee
expression.
"""
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee,
callable_name=callable_name)[0]
callable_name=callable_name,
object_type=object_type)[0]

def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: List[str] = None,
callable_node: Expression = None,
arg_messages: MessageBuilder = None,
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
callable_name: Optional[str] = None,
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
"""Type check a call.

Also infer type arguments if the callee is a generic function.

Return (result type, inferred callee type).

Arguments:
callee: type of the called value
args: actual argument expressions
arg_kinds: contains nodes.ARG_* constant for each argument in args
describing whether the argument is positional, *arg, etc.
arg_names: names of arguments (optional)
callable_node: associate the inferred callable type to this node,
if specified
arg_messages: TODO
callee: type of the called value
args: actual argument expressions
arg_kinds: contains nodes.ARG_* constant for each argument in args
describing whether the argument is positional, *arg, etc.
arg_names: names of arguments (optional)
callable_node: associate the inferred callable type to this node,
if specified
arg_messages: TODO
callable_name: Fully-qualified name of the function/method to call,
or None if unavaiable (examples: 'builtins.open', 'typing.Mapping.get')
object_type: If callable_name refers to a method, the type of the object
on which the method is being called
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
Expand Down Expand Up @@ -443,10 +517,12 @@ def check_call(self, callee: Type, args: List[Expression],
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if callable_name in self.function_plugins:

if ((object_type is None and self.plugin.get_function_hook(callable_name))
or (object_type is not None and self.plugin.get_method_hook(callable_name))):
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
args, len(callee.arg_types), callable_name)
args, len(callee.arg_types), callable_name, object_type, context)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
Expand All @@ -461,7 +537,9 @@ def check_call(self, callee: Type, args: List[Expression],
callee, context,
messages=arg_messages)
return self.check_call(target, args, arg_kinds, context, arg_names,
arg_messages=arg_messages)
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
self.infer_arg_types_in_context(None, args)
return AnyType(), AnyType()
Expand Down Expand Up @@ -1295,8 +1373,16 @@ def check_op_local(self, method: str, base_type: Type, arg: Expression,
method_type = analyze_member_access(method, base_type, context, False, False, True,
self.named_type, self.not_ready_callback, local_errors,
original_type=base_type, chk=self.chk)
callable_name = None
object_type = None
if isinstance(base_type, Instance):
# TODO: Find out in which class the method was defined originally?
# TODO: Support non-Instance types.
callable_name = '{}.{}'.format(base_type.type.fullname(), method)
object_type = base_type
return self.check_call(method_type, [arg], [nodes.ARG_POS],
context, arg_messages=local_errors)
context, arg_messages=local_errors,
callable_name=callable_name, object_type=object_type)

def check_op(self, method: str, base_type: Type, arg: Expression,
context: Context,
Expand Down Expand Up @@ -1769,13 +1855,14 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
# an error, but returns the TypedDict type that matches the literal it found
# that would cause a second error when that TypedDict type is returned upstream
# to avoid the second error, we always return TypedDict type that was requested
if isinstance(self.type_context[-1], TypedDictType):
typeddict_context = self.find_typeddict_context(self.type_context[-1])
if typeddict_context:
self.check_typeddict_call_with_dict(
callee=self.type_context[-1],
callee=typeddict_context,
kwargs=e,
context=e
)
return self.type_context[-1].copy_modified()
return typeddict_context.copy_modified()

# Collect function arguments, watching out for **expr.
args = [] # type: List[Expression] # Regular "key: value"
Expand Down Expand Up @@ -1826,6 +1913,21 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
self.check_call(method, [arg], [nodes.ARG_POS], arg)
return rv

def find_typeddict_context(self, context: Type) -> Optional[TypedDictType]:
if isinstance(context, TypedDictType):
return context
elif isinstance(context, UnionType):
items = []
for item in context.items:
item_context = self.find_typeddict_context(item)
if item_context:
items.append(item_context)
if len(items) == 1:
# Only one union item is TypedDict, so use the context as it's unambiguous.
return items[0]
# No TypedDict type in context.
return None

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
inferred_type, type_override = self.infer_lambda_type_using_context(e)
Expand Down
Loading