diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py
index 2e606d1903a..3b7b6431ffd 100644
--- a/src/_pytest/assertion/rewrite.py
+++ b/src/_pytest/assertion/rewrite.py
@@ -80,6 +80,7 @@ def __init__(self, config: Config) -> None:
         self._basenames_to_check_rewrite = {"conftest"}
         self._marked_for_rewrite_cache: dict[str, bool] = {}
         self._session_paths_checked = False
+        self._fns: dict[str, str] = {}
 
     def set_session(self, session: Session | None) -> None:
         self.session = session
@@ -126,7 +127,7 @@ def find_spec(
         ):
             return None
         else:
-            fn = spec.origin
+            self._fns[name] = fn = spec.origin
 
         if not self._should_rewrite(name, fn, state):
             return None
@@ -143,14 +144,11 @@ def create_module(
     ) -> types.ModuleType | None:
         return None  # default behaviour is fine
 
-    def exec_module(self, module: types.ModuleType) -> None:
-        assert module.__spec__ is not None
-        assert module.__spec__.origin is not None
-        fn = Path(module.__spec__.origin)
+    def get_code(self, fullname: str) -> types.CodeType:
+        assert fullname in self._fns
+        fn = Path(self._fns[fullname])
         state = self.config.stash[assertstate_key]
 
-        self._rewritten_names[module.__name__] = fn
-
         # The requested module looks like a test file, so rewrite it. This is
         # the most magical part of the process: load the source, rewrite the
         # asserts, and load the rewritten source. We also cache the rewritten
@@ -183,7 +181,21 @@ def exec_module(self, module: types.ModuleType) -> None:
                     self._writing_pyc = False
         else:
             state.trace(f"found cached rewritten pyc for {fn}")
-        exec(co, module.__dict__)
+
+        return co
+
+    def exec_module(self, module: types.ModuleType) -> None:
+        module_name = module.__name__
+
+        assert (
+            module_name in self._fns
+            and module.__spec__ is not None
+            and module.__spec__.origin == self._fns[module_name]
+        )
+
+        self._rewritten_names[module_name] = Path(self._fns[module_name])
+
+        exec(self.get_code(module_name), module.__dict__)
 
     def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
         """A fast way to get out of rewriting modules.