@@ -67,14 +67,20 @@ def __init__(self, config):
67
67
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
68
68
# which might result in infinite recursion (#3506)
69
69
self ._writing_pyc = False
70
+ self ._basenames_to_check_rewrite = set ('conftest' ,)
71
+ self ._marked_for_rewrite_cache = {}
72
+ self ._session_paths_checked = False
70
73
71
74
def set_session (self , session ):
72
75
self .session = session
76
+ self ._session_paths_checked = False
73
77
74
78
def find_module (self , name , path = None ):
75
79
if self ._writing_pyc :
76
80
return None
77
81
state = self .config ._assertstate
82
+ if self ._early_rewrite_bailout (name , state ):
83
+ return None
78
84
state .trace ("find_module called for: %s" % name )
79
85
names = name .rsplit ("." , 1 )
80
86
lastname = names [- 1 ]
@@ -166,6 +172,41 @@ def find_module(self, name, path=None):
166
172
self .modules [name ] = co , pyc
167
173
return self
168
174
175
+ def _early_rewrite_bailout (self , name , state ):
176
+ """
177
+ This is a fast way to get out of rewriting modules. Profiling has
178
+ shown that the call to imp.find_module (inside of the find_module
179
+ from this class) is a major slowdown, so, this method tries to
180
+ filter what we're sure won't be rewritten before getting to it.
181
+ """
182
+ if not self ._session_paths_checked and self .session is not None \
183
+ and hasattr (self .session , '_initialpaths' ):
184
+ self ._session_paths_checked = True
185
+ for path in self .session ._initialpaths :
186
+ # Make something as c:/projects/my_project/path.py ->
187
+ # ['c:', 'projects', 'my_project', 'path.py']
188
+ parts = str (path ).split (os .path .sep )
189
+ # add 'path' to basenames to be checked.
190
+ self ._basenames_to_check_rewrite .add (os .path .splitext (parts [- 1 ])[0 ])
191
+
192
+ # Note: conftest already by default in _basenames_to_check_rewrite.
193
+ parts = name .split ('.' )
194
+ if parts [- 1 ] in self ._basenames_to_check_rewrite :
195
+ return False
196
+
197
+ # For matching the name it must be as if it was a filename.
198
+ parts [- 1 ] = parts [- 1 ] + '.py'
199
+ fn_pypath = py .path .local (os .path .sep .join (parts ))
200
+ for pat in self .fnpats :
201
+ if fn_pypath .fnmatch (pat ):
202
+ return False
203
+
204
+ if self ._is_marked_for_rewrite (name , state ):
205
+ return False
206
+
207
+ state .trace ("early skip of rewriting module: %s" % (name ,))
208
+ return True
209
+
169
210
def _should_rewrite (self , name , fn_pypath , state ):
170
211
# always rewrite conftest files
171
212
fn = str (fn_pypath )
@@ -185,12 +226,20 @@ def _should_rewrite(self, name, fn_pypath, state):
185
226
state .trace ("matched test file %r" % (fn ,))
186
227
return True
187
228
188
- for marked in self ._must_rewrite :
189
- if name == marked or name .startswith (marked + "." ):
190
- state .trace ("matched marked file %r (from %r)" % (name , marked ))
191
- return True
229
+ return self ._is_marked_for_rewrite (name , state )
192
230
193
- return False
231
+ def _is_marked_for_rewrite (self , name , state ):
232
+ try :
233
+ return self ._marked_for_rewrite_cache [name ]
234
+ except KeyError :
235
+ for marked in self ._must_rewrite :
236
+ if name == marked or name .startswith (marked + "." ):
237
+ state .trace ("matched marked file %r (from %r)" % (name , marked ))
238
+ self ._marked_for_rewrite_cache [name ] = True
239
+ return True
240
+
241
+ self ._marked_for_rewrite_cache [name ] = False
242
+ return False
194
243
195
244
def mark_rewrite (self , * names ):
196
245
"""Mark import names as needing to be rewritten.
@@ -207,6 +256,7 @@ def mark_rewrite(self, *names):
207
256
):
208
257
self ._warn_already_imported (name )
209
258
self ._must_rewrite .update (names )
259
+ self ._marked_for_rewrite_cache .clear ()
210
260
211
261
def _warn_already_imported (self , name ):
212
262
self .config .warn (
@@ -239,16 +289,6 @@ def load_module(self, name):
239
289
raise
240
290
return sys .modules [name ]
241
291
242
- def is_package (self , name ):
243
- try :
244
- fd , fn , desc = imp .find_module (name )
245
- except ImportError :
246
- return False
247
- if fd is not None :
248
- fd .close ()
249
- tp = desc [2 ]
250
- return tp == imp .PKG_DIRECTORY
251
-
252
292
@classmethod
253
293
def _register_with_pkg_resources (cls ):
254
294
"""
0 commit comments