@@ -123,89 +123,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123
123
_IS_SYNC = False
124
124
125
125
126
- class AsyncBaseConnection :
127
- """A base connection object for server and kms connections."""
128
-
129
- def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130
- self .conn = conn
131
- self .socket_checker : SocketChecker = SocketChecker ()
132
- self .cancel_context : _CancellationContext = _CancellationContext ()
133
- self .is_sdam = False
134
- self .closed = False
135
- self .last_timeout : float | None = None
136
- self .more_to_come = False
137
- self .opts = opts
138
- self .max_wire_version = - 1
139
-
140
- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141
- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142
- if timeout == self .last_timeout :
143
- return
144
- self .last_timeout = timeout
145
- self .conn .get_conn .settimeout (timeout )
146
-
147
- def apply_timeout (
148
- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149
- ) -> Optional [float ]:
150
- # CSOT: use remaining timeout when set.
151
- timeout = _csot .remaining ()
152
- if timeout is None :
153
- # Reset the socket timeout unless we're performing a streaming monitor check.
154
- if not self .more_to_come :
155
- self .set_conn_timeout (self .opts .socket_timeout )
156
- return None
157
- # RTT validation.
158
- rtt = _csot .get_rtt ()
159
- if rtt is None :
160
- rtt = self .connect_rtt
161
- max_time_ms = timeout - rtt
162
- if max_time_ms < 0 :
163
- timeout_details = _get_timeout_details (self .opts )
164
- formatted = format_timeout_details (timeout_details )
165
- # CSOT: raise an error without running the command since we know it will time out.
166
- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167
- if self .max_wire_version != - 1 :
168
- raise ExecutionTimeout (
169
- errmsg ,
170
- 50 ,
171
- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172
- self .max_wire_version ,
173
- )
174
- else :
175
- raise TimeoutError (errmsg )
176
- if cmd is not None :
177
- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178
- self .set_conn_timeout (timeout )
179
- return timeout
180
-
181
- async def close_conn (self , reason : Optional [str ]) -> None :
182
- """Close this connection with a reason."""
183
- if self .closed :
184
- return
185
- await self ._close_conn ()
186
-
187
- async def _close_conn (self ) -> None :
188
- """Close this connection."""
189
- if self .closed :
190
- return
191
- self .closed = True
192
- self .cancel_context .cancel ()
193
- # Note: We catch exceptions to avoid spurious errors on interpreter
194
- # shutdown.
195
- try :
196
- await self .conn .close ()
197
- except Exception : # noqa: S110
198
- pass
199
-
200
- def conn_closed (self ) -> bool :
201
- """Return True if we know socket has been closed, False otherwise."""
202
- if _IS_SYNC :
203
- return self .socket_checker .socket_closed (self .conn .get_conn )
204
- else :
205
- return self .conn .is_closing ()
206
-
207
-
208
- class AsyncConnection (AsyncBaseConnection ):
126
+ class AsyncConnection :
209
127
"""Store a connection with some metadata.
210
128
211
129
:param conn: a raw connection object
@@ -223,27 +141,29 @@ def __init__(
223
141
id : int ,
224
142
is_sdam : bool ,
225
143
):
226
- super ().__init__ (conn , pool .opts )
227
144
self .pool_ref = weakref .ref (pool )
228
- self .address : tuple [str , int ] = address
229
- self .id : int = id
145
+ self .conn = conn
146
+ self .address = address
147
+ self .id = id
230
148
self .is_sdam = is_sdam
149
+ self .closed = False
231
150
self .last_checkin_time = time .monotonic ()
232
151
self .performed_handshake = False
233
152
self .is_writable : bool = False
234
153
self .max_wire_version = MAX_WIRE_VERSION
235
- self .max_bson_size : int = MAX_BSON_SIZE
236
- self .max_message_size : int = MAX_MESSAGE_SIZE
237
- self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
154
+ self .max_bson_size = MAX_BSON_SIZE
155
+ self .max_message_size = MAX_MESSAGE_SIZE
156
+ self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
238
157
self .supports_sessions = False
239
158
self .hello_ok : bool = False
240
- self .is_mongos : bool = False
159
+ self .is_mongos = False
241
160
self .op_msg_enabled = False
242
161
self .listeners = pool .opts ._event_listeners
243
162
self .enabled_for_cmap = pool .enabled_for_cmap
244
163
self .enabled_for_logging = pool .enabled_for_logging
245
164
self .compression_settings = pool .opts ._compression_settings
246
165
self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
166
+ self .socket_checker : SocketChecker = SocketChecker ()
247
167
self .oidc_token_gen_id : Optional [int ] = None
248
168
# Support for mechanism negotiation on the initial handshake.
249
169
self .negotiated_mechs : Optional [list [str ]] = None
@@ -254,6 +174,9 @@ def __init__(
254
174
self .pool_gen = pool .gen
255
175
self .generation = self .pool_gen .get_overall ()
256
176
self .ready = False
177
+ self .cancel_context : _CancellationContext = _CancellationContext ()
178
+ self .opts = pool .opts
179
+ self .more_to_come : bool = False
257
180
# For load balancer support.
258
181
self .service_id : Optional [ObjectId ] = None
259
182
self .server_connection_id : Optional [int ] = None
@@ -269,6 +192,44 @@ def __init__(
269
192
# For gossiping $clusterTime from the connection handshake to the client.
270
193
self ._cluster_time = None
271
194
195
+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
196
+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
197
+ if timeout == self .last_timeout :
198
+ return
199
+ self .last_timeout = timeout
200
+ self .conn .get_conn .settimeout (timeout )
201
+
202
+ def apply_timeout (
203
+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
204
+ ) -> Optional [float ]:
205
+ # CSOT: use remaining timeout when set.
206
+ timeout = _csot .remaining ()
207
+ if timeout is None :
208
+ # Reset the socket timeout unless we're performing a streaming monitor check.
209
+ if not self .more_to_come :
210
+ self .set_conn_timeout (self .opts .socket_timeout )
211
+ return None
212
+ # RTT validation.
213
+ rtt = _csot .get_rtt ()
214
+ if rtt is None :
215
+ rtt = self .connect_rtt
216
+ max_time_ms = timeout - rtt
217
+ if max_time_ms < 0 :
218
+ timeout_details = _get_timeout_details (self .opts )
219
+ formatted = format_timeout_details (timeout_details )
220
+ # CSOT: raise an error without running the command since we know it will time out.
221
+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
222
+ raise ExecutionTimeout (
223
+ errmsg ,
224
+ 50 ,
225
+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
226
+ self .max_wire_version ,
227
+ )
228
+ if cmd is not None :
229
+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
230
+ self .set_conn_timeout (timeout )
231
+ return timeout
232
+
272
233
def pin_txn (self ) -> None :
273
234
self .pinned_txn = True
274
235
assert not self .pinned_cursor
@@ -612,6 +573,26 @@ async def close_conn(self, reason: Optional[str]) -> None:
612
573
error = reason ,
613
574
)
614
575
576
+ async def _close_conn (self ) -> None :
577
+ """Close this connection."""
578
+ if self .closed :
579
+ return
580
+ self .closed = True
581
+ self .cancel_context .cancel ()
582
+ # Note: We catch exceptions to avoid spurious errors on interpreter
583
+ # shutdown.
584
+ try :
585
+ await self .conn .close ()
586
+ except Exception : # noqa: S110
587
+ pass
588
+
589
+ def conn_closed (self ) -> bool :
590
+ """Return True if we know socket has been closed, False otherwise."""
591
+ if _IS_SYNC :
592
+ return self .socket_checker .socket_closed (self .conn .get_conn )
593
+ else :
594
+ return self .conn .is_closing ()
595
+
615
596
def send_cluster_time (
616
597
self ,
617
598
command : MutableMapping [str , Any ],
0 commit comments