Skip to content

Commit a216788

Browse files
committed
Merge branch 'master' of github.com:mongodb/mongo-python-driver
2 parents 30d1c7f + 7a07c02 commit a216788

File tree

8 files changed

+576
-407
lines changed

8 files changed

+576
-407
lines changed

doc/changelog.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
Changelog
22
=========
33

4-
Changes in Version 4.15.1 (XXXX/XX/XX)
5-
--------------------------------------
6-
7-
Version 4.15.1 is a bug fix release.
8-
9-
- Fixed a bug in ``AsyncMongoClient`` that caused a
10-
``ServerSelectionTimeoutError`` when used with ``uvicorn``, ``FastAPI``, or ``uvloop``.
11-
124
Changes in Version 4.15.0 (2025/09/10)
135
--------------------------------------
146

pymongo/asynchronous/encryption.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
from pymongo.asynchronous.cursor import AsyncCursor
6565
from pymongo.asynchronous.database import AsyncDatabase
6666
from pymongo.asynchronous.mongo_client import AsyncMongoClient
67-
from pymongo.asynchronous.pool import AsyncBaseConnection
6867
from pymongo.common import CONNECT_TIMEOUT
6968
from pymongo.daemon import _spawn_daemon
7069
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts
@@ -77,11 +76,11 @@
7776
ServerSelectionTimeoutError,
7877
)
7978
from pymongo.helpers_shared import _get_timeout_details
80-
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
79+
from pymongo.network_layer import async_socket_sendall
8180
from pymongo.operations import UpdateOne
8281
from pymongo.pool_options import PoolOptions
8382
from pymongo.pool_shared import (
84-
_configured_protocol_interface,
83+
_async_configured_socket,
8584
_raise_connection_failure,
8685
)
8786
from pymongo.read_concern import ReadConcern
@@ -94,8 +93,10 @@
9493
if TYPE_CHECKING:
9594
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9695

96+
from pymongo.pyopenssl_context import _sslConn
9797
from pymongo.typings import _Address
9898

99+
99100
_IS_SYNC = False
100101

101102
_HTTPS_PORT = 443
@@ -110,10 +111,9 @@
110111
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
111112

112113

113-
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
114+
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
114115
try:
115-
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
116-
return AsyncBaseConnection(interface, opts)
116+
return await _async_configured_socket(address, opts)
117117
except Exception as exc:
118118
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119119

@@ -198,11 +198,19 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
198198
try:
199199
conn = await _connect_kms(address, opts)
200200
try:
201-
await async_sendall(conn.conn.get_conn, message)
201+
await async_socket_sendall(conn, message)
202202
while kms_context.bytes_needed > 0:
203203
# CSOT: update timeout.
204-
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205-
data = await async_receive_kms(conn, kms_context.bytes_needed)
204+
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205+
data: memoryview | bytes
206+
if _IS_SYNC:
207+
data = conn.recv(kms_context.bytes_needed)
208+
else:
209+
from pymongo.network_layer import ( # type: ignore[attr-defined]
210+
async_receive_data_socket,
211+
)
212+
213+
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
206214
if not data:
207215
raise OSError("KMS connection closed")
208216
kms_context.feed(data)
@@ -221,7 +229,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
221229
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
222230
)
223231
finally:
224-
await conn.close_conn(None)
232+
conn.close()
225233
except MongoCryptError:
226234
raise # Propagate MongoCryptError errors directly.
227235
except Exception as exc:

pymongo/asynchronous/pool.py

Lines changed: 71 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -123,89 +123,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123123
_IS_SYNC = False
124124

125125

126-
class AsyncBaseConnection:
127-
"""A base connection object for server and kms connections."""
128-
129-
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
130-
self.conn = conn
131-
self.socket_checker: SocketChecker = SocketChecker()
132-
self.cancel_context: _CancellationContext = _CancellationContext()
133-
self.is_sdam = False
134-
self.closed = False
135-
self.last_timeout: float | None = None
136-
self.more_to_come = False
137-
self.opts = opts
138-
self.max_wire_version = -1
139-
140-
def set_conn_timeout(self, timeout: Optional[float]) -> None:
141-
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
142-
if timeout == self.last_timeout:
143-
return
144-
self.last_timeout = timeout
145-
self.conn.get_conn.settimeout(timeout)
146-
147-
def apply_timeout(
148-
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
149-
) -> Optional[float]:
150-
# CSOT: use remaining timeout when set.
151-
timeout = _csot.remaining()
152-
if timeout is None:
153-
# Reset the socket timeout unless we're performing a streaming monitor check.
154-
if not self.more_to_come:
155-
self.set_conn_timeout(self.opts.socket_timeout)
156-
return None
157-
# RTT validation.
158-
rtt = _csot.get_rtt()
159-
if rtt is None:
160-
rtt = self.connect_rtt
161-
max_time_ms = timeout - rtt
162-
if max_time_ms < 0:
163-
timeout_details = _get_timeout_details(self.opts)
164-
formatted = format_timeout_details(timeout_details)
165-
# CSOT: raise an error without running the command since we know it will time out.
166-
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
167-
if self.max_wire_version != -1:
168-
raise ExecutionTimeout(
169-
errmsg,
170-
50,
171-
{"ok": 0, "errmsg": errmsg, "code": 50},
172-
self.max_wire_version,
173-
)
174-
else:
175-
raise TimeoutError(errmsg)
176-
if cmd is not None:
177-
cmd["maxTimeMS"] = int(max_time_ms * 1000)
178-
self.set_conn_timeout(timeout)
179-
return timeout
180-
181-
async def close_conn(self, reason: Optional[str]) -> None:
182-
"""Close this connection with a reason."""
183-
if self.closed:
184-
return
185-
await self._close_conn()
186-
187-
async def _close_conn(self) -> None:
188-
"""Close this connection."""
189-
if self.closed:
190-
return
191-
self.closed = True
192-
self.cancel_context.cancel()
193-
# Note: We catch exceptions to avoid spurious errors on interpreter
194-
# shutdown.
195-
try:
196-
await self.conn.close()
197-
except Exception: # noqa: S110
198-
pass
199-
200-
def conn_closed(self) -> bool:
201-
"""Return True if we know socket has been closed, False otherwise."""
202-
if _IS_SYNC:
203-
return self.socket_checker.socket_closed(self.conn.get_conn)
204-
else:
205-
return self.conn.is_closing()
206-
207-
208-
class AsyncConnection(AsyncBaseConnection):
126+
class AsyncConnection:
209127
"""Store a connection with some metadata.
210128
211129
:param conn: a raw connection object
@@ -223,27 +141,29 @@ def __init__(
223141
id: int,
224142
is_sdam: bool,
225143
):
226-
super().__init__(conn, pool.opts)
227144
self.pool_ref = weakref.ref(pool)
228-
self.address: tuple[str, int] = address
229-
self.id: int = id
145+
self.conn = conn
146+
self.address = address
147+
self.id = id
230148
self.is_sdam = is_sdam
149+
self.closed = False
231150
self.last_checkin_time = time.monotonic()
232151
self.performed_handshake = False
233152
self.is_writable: bool = False
234153
self.max_wire_version = MAX_WIRE_VERSION
235-
self.max_bson_size: int = MAX_BSON_SIZE
236-
self.max_message_size: int = MAX_MESSAGE_SIZE
237-
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
154+
self.max_bson_size = MAX_BSON_SIZE
155+
self.max_message_size = MAX_MESSAGE_SIZE
156+
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
238157
self.supports_sessions = False
239158
self.hello_ok: bool = False
240-
self.is_mongos: bool = False
159+
self.is_mongos = False
241160
self.op_msg_enabled = False
242161
self.listeners = pool.opts._event_listeners
243162
self.enabled_for_cmap = pool.enabled_for_cmap
244163
self.enabled_for_logging = pool.enabled_for_logging
245164
self.compression_settings = pool.opts._compression_settings
246165
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
166+
self.socket_checker: SocketChecker = SocketChecker()
247167
self.oidc_token_gen_id: Optional[int] = None
248168
# Support for mechanism negotiation on the initial handshake.
249169
self.negotiated_mechs: Optional[list[str]] = None
@@ -254,6 +174,9 @@ def __init__(
254174
self.pool_gen = pool.gen
255175
self.generation = self.pool_gen.get_overall()
256176
self.ready = False
177+
self.cancel_context: _CancellationContext = _CancellationContext()
178+
self.opts = pool.opts
179+
self.more_to_come: bool = False
257180
# For load balancer support.
258181
self.service_id: Optional[ObjectId] = None
259182
self.server_connection_id: Optional[int] = None
@@ -269,6 +192,44 @@ def __init__(
269192
# For gossiping $clusterTime from the connection handshake to the client.
270193
self._cluster_time = None
271194

195+
def set_conn_timeout(self, timeout: Optional[float]) -> None:
196+
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
197+
if timeout == self.last_timeout:
198+
return
199+
self.last_timeout = timeout
200+
self.conn.get_conn.settimeout(timeout)
201+
202+
def apply_timeout(
203+
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
204+
) -> Optional[float]:
205+
# CSOT: use remaining timeout when set.
206+
timeout = _csot.remaining()
207+
if timeout is None:
208+
# Reset the socket timeout unless we're performing a streaming monitor check.
209+
if not self.more_to_come:
210+
self.set_conn_timeout(self.opts.socket_timeout)
211+
return None
212+
# RTT validation.
213+
rtt = _csot.get_rtt()
214+
if rtt is None:
215+
rtt = self.connect_rtt
216+
max_time_ms = timeout - rtt
217+
if max_time_ms < 0:
218+
timeout_details = _get_timeout_details(self.opts)
219+
formatted = format_timeout_details(timeout_details)
220+
# CSOT: raise an error without running the command since we know it will time out.
221+
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
222+
raise ExecutionTimeout(
223+
errmsg,
224+
50,
225+
{"ok": 0, "errmsg": errmsg, "code": 50},
226+
self.max_wire_version,
227+
)
228+
if cmd is not None:
229+
cmd["maxTimeMS"] = int(max_time_ms * 1000)
230+
self.set_conn_timeout(timeout)
231+
return timeout
232+
272233
def pin_txn(self) -> None:
273234
self.pinned_txn = True
274235
assert not self.pinned_cursor
@@ -612,6 +573,26 @@ async def close_conn(self, reason: Optional[str]) -> None:
612573
error=reason,
613574
)
614575

576+
async def _close_conn(self) -> None:
577+
"""Close this connection."""
578+
if self.closed:
579+
return
580+
self.closed = True
581+
self.cancel_context.cancel()
582+
# Note: We catch exceptions to avoid spurious errors on interpreter
583+
# shutdown.
584+
try:
585+
await self.conn.close()
586+
except Exception: # noqa: S110
587+
pass
588+
589+
def conn_closed(self) -> bool:
590+
"""Return True if we know socket has been closed, False otherwise."""
591+
if _IS_SYNC:
592+
return self.socket_checker.socket_closed(self.conn.get_conn)
593+
else:
594+
return self.conn.is_closing()
595+
615596
def send_cluster_time(
616597
self,
617598
command: MutableMapping[str, Any],

0 commit comments

Comments
 (0)