Skip to content

Commit a702c40

Browse files
ecpricerwbarton
authored and committed
Change ConditionalTypeBinder to mainly use with statements (#1731)
Rather than use `binder.push_frame()` and `binder.pop_frame(*args)`, you now do `with binder.frame_context(*args)`. There was a previous context manager used as `with binder`, but it was only used in a few places because it didn't support options. In the process, I discovered several bugs in the old control flow analysis; new test cases have been added to catch them. This commit also moves ConditionalTypeBinder to its own file and includes other code clean-up and new comments.
1 parent baf0580 commit a702c40

File tree

5 files changed

+747
-513
lines changed

5 files changed

+747
-513
lines changed

mypy/binder.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from typing import (Any, Dict, List, Set, Iterator)
2+
from contextlib import contextmanager
3+
4+
from mypy.types import Type, AnyType, PartialType
5+
from mypy.nodes import (Node, Var)
6+
7+
from mypy.subtypes import is_subtype
8+
from mypy.join import join_simple
9+
from mypy.sametypes import is_same_type
10+
11+
12+
class Frame(Dict[Any, Type]):
    """A dictionary from binder keys to types.

    ConditionalTypeBinder keeps a stack of these (one per nested scope of
    conditional knowledge) and also uses one to record declared types.
    """
14+
15+
16+
class Key(AnyType):
    """Annotation-only type for the keys stored in binder frames.

    Keys are `expr.literal_hash` values; ConditionalTypeBinder treats a
    tuple key as being composed of sub-keys when recording dependencies.
    """
    # NOTE(review): presumably this derives from AnyType so that arbitrary
    # literal_hash values type-check as a Key -- confirm.
18+
19+
20+
class ConditionalTypeBinder:
    """Keep track of conditional types of variables.

    NB: Variables are tracked by literal expression, so it is possible
    to confuse the binder; for example,

    ```
    class A:
        a = None          # type: Union[int, str]
    x = A()
    lst = [x]
    reveal_type(x.a)      # Union[int, str]
    x.a = 1
    reveal_type(x.a)      # int
    reveal_type(lst[0].a) # Union[int, str]
    lst[0].a = 'a'
    reveal_type(x.a)      # int
    reveal_type(lst[0].a) # str
    ```
    """

    def __init__(self) -> None:
        # The stack of frames currently in use.  Each frame maps
        # expr.literal_hash -- literals like 'foo.bar' -- to types.
        # The bottom frame is always present.
        self.frames = [Frame()]

        # For each pushed frame (i.e. every frame except the bottom one),
        # the list of Frames that can escape to it.  Because the bottom
        # frame has no entry, options_on_return[i] belongs to frames[i + 1].
        self.options_on_return = []  # type: List[List[Frame]]

        # Maps expr.literal_hash to get_declaration(expr)
        # for every expr stored in the binder.
        self.declarations = Frame()
        # Set of other keys to invalidate if a key is changed, e.g. x -> {x.a, x[0]}.
        # Whenever a new key (e.g. x.a.b) is added, we update this.
        self.dependencies = {}  # type: Dict[Key, Set[Key]]

        # breaking_out is set to True on return/break/continue/raise.
        # pop_frame() copies it into last_pop_breaking_out, and
        # frame_context() restores the value it had on entry.
        # Lines of code after breaking_out = True are unreachable and not
        # typechecked.
        self.breaking_out = False

        # Whether the last pop changed the newly top frame on exit.
        self.last_pop_changed = False
        # Whether the last pop was necessarily breaking out, and couldn't fall through.
        self.last_pop_breaking_out = False

        # Indices passed to allow_jump() on every assignment (see assign_type).
        self.try_frames = set()  # type: Set[int]
        # Frame-stack positions recorded by push_loop_frame().
        self.loop_frames = []  # type: List[int]

    def _add_dependencies(self, key: Key, value: Key = None) -> None:
        """Record that `value` must be invalidated whenever `key` changes.

        On the initial call (value is None) nothing is recorded for the key
        itself; the key becomes the value that is registered against each of
        its components.  Tuple keys are walked recursively so that every
        nested component maps to the original key.
        """
        if value is None:
            value = key
        else:
            self.dependencies.setdefault(key, set()).add(value)
        if isinstance(key, tuple):
            for elt in key:
                self._add_dependencies(elt, value)

    def push_frame(self) -> Frame:
        """Push a new frame into the binder."""
        f = Frame()
        self.frames.append(f)
        # Keep options_on_return in step with the frame stack.
        self.options_on_return.append([])
        return f

    def _push(self, key: Key, type: Type, index: int=-1) -> None:
        """Bind `key` to `type` in the frame at `index` (the top frame by default)."""
        self.frames[index][key] = type

    def _get(self, key: Key, index: int=-1) -> Type:
        """Return the innermost binding of `key` at or below frame `index`.

        Frames are searched from `index` down to the bottom of the stack;
        None is returned if no frame binds the key.
        """
        if index < 0:
            index += len(self.frames)
        for i in range(index, -1, -1):
            if key in self.frames[i]:
                return self.frames[i][key]
        return None

    def push(self, expr: Node, typ: Type) -> None:
        """Bind the literal expression `expr` to `typ` in the top frame.

        Does nothing for non-literal expressions.  The first time a key is
        seen, its declared type and its dependencies are recorded.
        """
        if not expr.literal:
            return
        key = expr.literal_hash
        if key not in self.declarations:
            self.declarations[key] = self.get_declaration(expr)
            self._add_dependencies(key)
        self._push(key, typ)

    def get(self, expr: Node) -> Type:
        """Return the current type bound to `expr`, or None if there is none."""
        return self._get(expr.literal_hash)

    def cleanse(self, expr: Node) -> None:
        """Remove all references to a Node from the binder."""
        self._cleanse_key(expr.literal_hash)

    def _cleanse_key(self, key: Key) -> None:
        """Remove all references to a key from the binder."""
        for frame in self.frames:
            frame.pop(key, None)

    def update_from_options(self, frames: List[Frame]) -> bool:
        """Update the frame to reflect that each key will be updated
        as in one of the frames.  Return whether any item changes.

        If a key is declared as AnyType, only update it if all the
        options are the same.
        """
        changed = False
        keys = {key for f in frames for key in f}

        for key in keys:
            current_value = self._get(key)
            # An option that does not mention the key leaves it at its
            # current value.
            resulting_values = [f.get(key, current_value) for f in frames]
            if any(x is None for x in resulting_values):
                # Unbound along at least one path: nothing can be concluded.
                continue

            if isinstance(self.declarations.get(key), AnyType):
                new_type = resulting_values[0]
                if not all(is_same_type(new_type, t) for t in resulting_values[1:]):
                    new_type = AnyType()
            else:
                new_type = resulting_values[0]
                for other in resulting_values[1:]:
                    new_type = join_simple(self.declarations[key], new_type, other)
            # current_value is None when every option binds the key but no
            # remaining frame does; treat that as a change explicitly rather
            # than handing None to is_same_type().
            if current_value is None or not is_same_type(new_type, current_value):
                self._push(key, new_type)
                changed = True

        return changed

    def pop_frame(self, fall_through: int = 0) -> Frame:
        """Pop a frame and return it.

        See frame_context() for documentation of fall_through.
        """
        if fall_through and not self.breaking_out:
            self.allow_jump(-fall_through)

        result = self.frames.pop()
        options = self.options_on_return.pop()

        # Merge everything that could escape to the popped frame into the
        # new top frame, and remember the outcome for callers.
        self.last_pop_changed = self.update_from_options(options)
        self.last_pop_breaking_out = self.breaking_out

        return result

    def get_declaration(self, expr: Any) -> Type:
        """Return the declared type of the variable `expr` refers to.

        Returns None when `expr` has no `node` attribute that is a Var, or
        when that variable's type is still a PartialType.
        """
        if hasattr(expr, 'node') and isinstance(expr.node, Var):
            declared = expr.node.type
            if isinstance(declared, PartialType):
                return None
            return declared
        return None

    def assign_type(self, expr: Node,
                    type: Type,
                    declared_type: Type,
                    restrict_any: bool = False) -> None:
        """Record the effect of assigning a value of type `type` to `expr`.

        Knowledge about expressions that depend on `expr` is invalidated
        first.  If `restrict_any` is true, the new type is recorded even when
        the most recent enclosing type of `expr` is Any.
        """
        if not expr.literal:
            return
        self.invalidate_dependencies(expr)

        if declared_type is None:
            # Not sure why this happens.  It seems to mainly happen in
            # member initialization.
            return
        if not is_subtype(type, declared_type):
            # Pretty sure this only happens when there's a type error.

            # Ideally this function wouldn't be called if the
            # expression has a type error, though -- do other kinds of
            # errors cause this function to get called at invalid
            # times?
            return

        # If x is Any and y is int, after x = y we do not infer that x is int.
        # This could be changed.
        # Eric: I'm changing it in weak typing mode, since Any is so common.

        if (isinstance(self.most_recent_enclosing_type(expr, type), AnyType)
                and not restrict_any):
            pass
        elif isinstance(type, AnyType):
            self.push(expr, declared_type)
        else:
            self.push(expr, type)

        for i in self.try_frames:
            # XXX This should probably not copy the entire frame, but
            # just copy this variable into a single stored frame.
            self.allow_jump(i)

    def invalidate_dependencies(self, expr: Node) -> None:
        """Invalidate knowledge of types that include expr, but not expr itself.

        For example, when expr is foo.bar, invalidate foo.bar.baz.

        It is overly conservative: it invalidates globally, including
        in code paths unreachable from here.
        """
        for dep in self.dependencies.get(expr.literal_hash, set()):
            self._cleanse_key(dep)

    def most_recent_enclosing_type(self, expr: Node, type: Type) -> Type:
        """Return the innermost known type of `expr` that `type` is a subtype of.

        Falls back to the declared type when `type` is Any or when no frame
        holds a suitable binding.
        """
        if isinstance(type, AnyType):
            return self.get_declaration(expr)
        key = expr.literal_hash
        # The declaration comes first so that it is the result when no frame
        # qualifies; later (inner) frames take precedence via [-1].
        enclosers = ([self.get_declaration(expr)] +
                     [f[key] for f in self.frames
                      if key in f and is_subtype(type, f[key])])
        return enclosers[-1]

    def allow_jump(self, index: int) -> None:
        """Record that the current state may escape to the frame for `index`.

        `index` refers to options_on_return.  The bindings of that frame and
        of all frames above it are flattened into one Frame (inner frames
        win) and stored as an option for it.
        """
        # self.frames and self.options_on_return have different lengths
        # so make sure the index is positive
        if index < 0:
            index += len(self.options_on_return)
        frame = Frame()
        for f in self.frames[index + 1:]:
            frame.update(f)
        self.options_on_return[index].append(frame)

    def push_loop_frame(self) -> None:
        """Remember the current top of the frame stack as a loop frame."""
        self.loop_frames.append(len(self.frames) - 1)

    def pop_loop_frame(self) -> None:
        """Forget the most recently recorded loop frame."""
        self.loop_frames.pop()

    @contextmanager
    def frame_context(self, fall_through: int = 0) -> Iterator[Frame]:
        """Return a context manager that pushes/pops frames on enter/exit.

        If fall_through > 0, then it will allow the frame to escape to
        its ancestor `fall_through` levels higher.

        A simple 'with binder.frame_context(): pass' will change the
        last_pop_* flags but nothing else.
        """
        was_breaking_out = self.breaking_out
        yield self.push_frame()
        self.pop_frame(fall_through)
        # Breaking out inside the block does not leak into the enclosing
        # code; callers consult last_pop_breaking_out instead.
        self.breaking_out = was_breaking_out

0 commit comments

Comments
 (0)