diff --git a/pymongo/client_session.py b/pymongo/client_session.py
index 4a796154d0..24fb979992 100644
--- a/pymongo/client_session.py
+++ b/pymongo/client_session.py
@@ -179,7 +179,7 @@
     from pymongo.pool import Connection
    from pymongo.server import Server
-    from pymongo.typings import _Address
+    from pymongo.typings import ClusterTime, _Address
 
 
 class SessionOptions:
@@ -562,7 +562,7 @@ def session_id(self) -> Mapping[str, Any]:
         return self._server_session.session_id
 
     @property
-    def cluster_time(self) -> Optional[Mapping[str, Any]]:
+    def cluster_time(self) -> Optional[ClusterTime]:
         """The cluster time returned by the last operation executed
         in this session.
         """
diff --git a/pymongo/hello.py b/pymongo/hello.py
index 913a2b0d1a..1715beb5cf 100644
--- a/pymongo/hello.py
+++ b/pymongo/hello.py
@@ -22,7 +22,7 @@
 from bson.objectid import ObjectId
 from pymongo import common
 from pymongo.server_type import SERVER_TYPE
-from pymongo.typings import _DocumentType
+from pymongo.typings import ClusterTime, _DocumentType
 
 
 class HelloCompat:
@@ -155,7 +155,7 @@ def election_id(self) -> Optional[ObjectId]:
         return self._doc.get("electionId")
 
     @property
-    def cluster_time(self) -> Optional[Mapping[str, Any]]:
+    def cluster_time(self) -> Optional[ClusterTime]:
         return self._doc.get("$clusterTime")
 
     @property
diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py
index 9655fde282..5699c3db8b 100644
--- a/pymongo/mongo_client.py
+++ b/pymongo/mongo_client.py
@@ -98,6 +98,7 @@
 from pymongo.topology import Topology, _ErrorContext
 from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
 from pymongo.typings import (
+    ClusterTime,
     _Address,
     _CollationIn,
     _DocumentType,
@@ -1106,10 +1107,10 @@ def primary(self) -> Optional[Tuple[str, int]]:
         .. versionadded:: 3.0
            MongoClient gained this property in version 3.0.
         """
-        return self._topology.get_primary()
+        return self._topology.get_primary()  # type: ignore[return-value]
 
     @property
-    def secondaries(self) -> Set[Tuple[str, int]]:
+    def secondaries(self) -> Set[_Address]:
         """The secondary members known to this client.
 
         A sequence of (host, port) pairs. Empty if this client is not
@@ -1122,7 +1123,7 @@ def secondaries(self) -> Set[Tuple[str, int]]:
         return self._topology.get_secondaries()
 
     @property
-    def arbiters(self) -> Set[Tuple[str, int]]:
+    def arbiters(self) -> Set[_Address]:
         """Arbiters in the replica set.
 
         A sequence of (host, port) pairs. Empty if this client is not
@@ -1729,7 +1730,7 @@ def _kill_cursors(
         if address:
             # address could be a tuple or _CursorAddress, but
             # select_server_by_address needs (host, port).
-            server = topology.select_server_by_address(tuple(address))
+            server = topology.select_server_by_address(tuple(address))  # type: ignore[arg-type]
         else:
             # Application called close_cursor() with no address.
             server = topology.select_server(writable_server_selector)
@@ -1906,7 +1907,7 @@ def _send_cluster_time(
         session_time = session.cluster_time if session else None
         if topology_time and session_time:
             if topology_time["clusterTime"] > session_time["clusterTime"]:
-                cluster_time = topology_time
+                cluster_time: Optional[ClusterTime] = topology_time
             else:
                 cluster_time = session_time
         else:
@@ -2271,7 +2272,7 @@ def contribute_socket(self, conn: Connection, completed_handshake: bool = True)
     def handle(
         self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException]
     ) -> None:
-        if self.handled or exc_type is None:
+        if self.handled or exc_val is None:
             return
         self.handled = True
         if self.session:
@@ -2285,7 +2286,6 @@ def handle(
                 "RetryableWriteError"
             ):
                 self.session._unpin()
-
         err_ctx = _ErrorContext(
             exc_val,
             self.max_wire_version,
@@ -2300,8 +2300,8 @@ def __enter__(self) -> _MongoClientErrorHandler:
 
     def __exit__(
         self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
+        exc_type: Optional[Type[Exception]],
+        exc_val: Optional[Exception],
         exc_tb: Optional[TracebackType],
     ) -> None:
         return self.handle(exc_type, exc_val)
diff --git a/pymongo/pool.py b/pymongo/pool.py
index c89aa10e3b..68052f6495 100644
--- a/pymongo/pool.py
+++ b/pymongo/pool.py
@@ -105,7 +105,7 @@
     from pymongo.read_concern import ReadConcern
     from pymongo.read_preferences import _ServerMode
     from pymongo.server_api import ServerApi
-    from pymongo.typings import _Address, _CollationIn
+    from pymongo.typings import ClusterTime, _Address, _CollationIn
     from pymongo.write_concern import WriteConcern
 
 try:
@@ -779,7 +779,7 @@ def hello(self) -> Hello[Dict[str, Any]]:
 
     def _hello(
         self,
-        cluster_time: Optional[Mapping[str, Any]],
+        cluster_time: Optional[ClusterTime],
         topology_version: Optional[Any],
         heartbeat_frequency: Optional[int],
     ) -> Hello[Dict[str, Any]]:
diff --git a/pymongo/server_description.py b/pymongo/server_description.py
index 0a0ced135c..c2fa030537 100644
--- a/pymongo/server_description.py
+++ b/pymongo/server_description.py
@@ -22,7 +22,7 @@
 from bson.objectid import ObjectId
 from pymongo.hello import Hello
 from pymongo.server_type import SERVER_TYPE
-from pymongo.typings import _Address
+from pymongo.typings import ClusterTime, _Address
 
 
 class ServerDescription:
@@ -176,7 +176,7 @@ def election_id(self) -> Optional[ObjectId]:
         return self._election_id
 
     @property
-    def cluster_time(self) -> Optional[Mapping[str, Any]]:
+    def cluster_time(self) -> Optional[ClusterTime]:
         return self._cluster_time
 
     @property
diff --git a/pymongo/settings.py b/pymongo/settings.py
index 3436fcad6b..d6ef93e5c2 100644
--- a/pymongo/settings.py
+++ b/pymongo/settings.py
@@ -95,11 +95,11 @@ def pool_options(self) -> PoolOptions:
         return self._pool_options
 
     @property
-    def monitor_class(self) -> Optional[Type[monitor.Monitor]]:
+    def monitor_class(self) -> Type[monitor.Monitor]:
         return self._monitor_class
 
     @property
-    def condition_class(self) -> Optional[Type[threading.Condition]]:
+    def condition_class(self) -> Type[threading.Condition]:
         return self._condition_class
 
     @property
diff --git a/pymongo/topology.py b/pymongo/topology.py
index 0a2eaf9420..6fd1138fb2 100644
--- a/pymongo/topology.py
+++ b/pymongo/topology.py
@@ -14,16 +14,29 @@
 
 """Internal class to monitor a topology of one or more servers."""
 
+from __future__ import annotations
+
 import os
 import queue
 import random
 import time
 import warnings
 import weakref
-from typing import Any
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 from pymongo import _csot, common, helpers, periodic_executor
-from pymongo.client_session import _ServerSessionPool
+from pymongo.client_session import _ServerSession, _ServerSessionPool
 from pymongo.errors import (
     ConfigurationError,
     ConnectionFailure,
@@ -38,7 +51,7 @@
 from pymongo.hello import Hello
 from pymongo.lock import _create_lock
 from pymongo.monitor import SrvMonitor
-from pymongo.pool import PoolOptions
+from pymongo.pool import Pool, PoolOptions
 from pymongo.server import Server
 from pymongo.server_description import ServerDescription
 from pymongo.server_selectors import (
@@ -57,8 +70,13 @@
     updated_topology_description,
 )
 
+if TYPE_CHECKING:
+    from bson import ObjectId
+    from pymongo.settings import TopologySettings
+    from pymongo.typings import ClusterTime, _Address
 
-def process_events_queue(queue_ref):
+
+def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
     q = queue_ref()
     if not q:
         return False  # Cancel PeriodicExecutor.
@@ -78,12 +96,11 @@ def process_events_queue(queue_ref):
 class Topology:
     """Monitor a topology of one or more servers."""
 
-    def __init__(self, topology_settings):
+    def __init__(self, topology_settings: TopologySettings):
         self._topology_id = topology_settings._topology_id
         self._listeners = topology_settings._pool_options._event_listeners
-        pub = self._listeners is not None
-        self._publish_server = pub and self._listeners.enabled_for_server
-        self._publish_tp = pub and self._listeners.enabled_for_topology
+        self._publish_server = self._listeners is not None and self._listeners.enabled_for_server
+        self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology
 
         # Create events queue if there are publishers.
         self._events = None
@@ -129,14 +146,16 @@ def __init__(self, topology_settings):
         self._closed = False
         self._lock = _create_lock()
         self._condition = self._settings.condition_class(self._lock)
-        self._servers = {}
-        self._pid = None
-        self._max_cluster_time = None
+        self._servers: Dict[_Address, Server] = {}
+        self._pid: Optional[int] = None
+        self._max_cluster_time: Optional[ClusterTime] = None
         self._session_pool = _ServerSessionPool()
 
         if self._publish_server or self._publish_tp:
+            assert self._events is not None
+            weak: weakref.ReferenceType[queue.Queue]
 
-            def target():
+            def target() -> bool:
                 return process_events_queue(weak)
 
             executor = periodic_executor.PeriodicExecutor(
@@ -157,7 +176,7 @@ def target():
         if self._settings.fqdn is not None and not self._settings.load_balanced:
             self._srv_monitor = SrvMonitor(self, self._settings)
 
-    def open(self):
+    def open(self) -> None:
         """Start monitoring, or restart after a fork.
 
         No effect if called multiple times.
@@ -191,14 +210,19 @@ def open(self):
         with self._lock:
             self._ensure_opened()
 
-    def get_server_selection_timeout(self):
+    def get_server_selection_timeout(self) -> float:
         # CSOT: use remaining timeout when set.
         timeout = _csot.remaining()
         if timeout is None:
             return self._settings.server_selection_timeout
         return timeout
 
-    def select_servers(self, selector, server_selection_timeout=None, address=None):
+    def select_servers(
+        self,
+        selector: Callable[[Selection], Selection],
+        server_selection_timeout: Optional[float] = None,
+        address: Optional[_Address] = None,
+    ) -> List[Server]:
         """Return a list of Servers matching selector, or time out.
 
         :Parameters:
@@ -222,9 +246,16 @@ def select_servers(self, selector, server_selection_timeout=None, address=None):
         with self._lock:
             server_descriptions = self._select_servers_loop(selector, server_timeout, address)
 
-            return [self.get_server_by_address(sd.address) for sd in server_descriptions]
+            return [
+                cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
+            ]
 
-    def _select_servers_loop(self, selector, timeout, address):
+    def _select_servers_loop(
+        self,
+        selector: Callable[[Selection], Selection],
+        timeout: float,
+        address: Optional[_Address],
+    ) -> List[ServerDescription]:
         """select_servers() guts. Hold the lock when calling this."""
         now = time.monotonic()
         end_time = now + timeout
@@ -256,7 +287,12 @@ def _select_servers_loop(self, selector, timeout, address):
         self._description.check_compatible()
         return server_descriptions
 
-    def _select_server(self, selector, server_selection_timeout=None, address=None):
+    def _select_server(
+        self,
+        selector: Callable[[Selection], Selection],
+        server_selection_timeout: Optional[float] = None,
+        address: Optional[_Address] = None,
+    ) -> Server:
         servers = self.select_servers(selector, server_selection_timeout, address)
         if len(servers) == 1:
             return servers[0]
@@ -266,14 +302,21 @@ def _select_server(self, selector, server_selection_timeout=None, address=None):
         else:
             return server2
 
-    def select_server(self, selector, server_selection_timeout=None, address=None):
+    def select_server(
+        self,
+        selector: Callable[[Selection], Selection],
+        server_selection_timeout: Optional[float] = None,
+        address: Optional[_Address] = None,
+    ) -> Server:
         """Like select_servers, but choose a random server if several match."""
         server = self._select_server(selector, server_selection_timeout, address)
         if _csot.get_timeout():
             _csot.set_rtt(server.description.min_round_trip_time)
         return server
 
-    def select_server_by_address(self, address, server_selection_timeout=None):
+    def select_server_by_address(
+        self, address: _Address, server_selection_timeout: Optional[int] = None
+    ) -> Server:
         """Return a Server for "address", reconnecting if necessary.
 
         If the server's type is not known, request an immediate check of all
@@ -293,7 +336,9 @@ def select_server_by_address(self, address, server_selection_timeout=None):
         """
         return self.select_server(any_server_selector, server_selection_timeout, address)
 
-    def _process_change(self, server_description, reset_pool=False):
+    def _process_change(
+        self, server_description: ServerDescription, reset_pool: bool = False
+    ) -> None:
         """Process a new ServerDescription on an opened topology.
 
         Hold the lock when calling this.
@@ -354,7 +399,7 @@ def _process_change(self, server_description, reset_pool=False):
         # Wake waiters in select_servers().
         self._condition.notify_all()
 
-    def on_change(self, server_description, reset_pool=False):
+    def on_change(self, server_description: ServerDescription, reset_pool: bool = False) -> None:
         """Process a new ServerDescription after an hello call completes."""
         # We do no I/O holding the lock.
         with self._lock:
@@ -369,7 +414,7 @@ def on_change(self, server_description, reset_pool=False):
         if self._opened and self._description.has_server(server_description.address):
             self._process_change(server_description, reset_pool)
 
-    def _process_srv_update(self, seedlist):
+    def _process_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
         """Process a new seedlist on an opened topology.
 
         Hold the lock when calling this.
""" @@ -389,14 +434,14 @@ def _process_srv_update(self, seedlist): ) ) - def on_srv_update(self, seedlist): + def on_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None: """Process a new list of nodes obtained from scanning SRV records.""" # We do no I/O holding the lock. with self._lock: if self._opened: self._process_srv_update(seedlist) - def get_server_by_address(self, address): + def get_server_by_address(self, address: _Address) -> Optional[Server]: """Get a Server or None. Returns the current version of the server immediately, even if it's @@ -406,10 +451,10 @@ def get_server_by_address(self, address): """ return self._servers.get(address) - def has_server(self, address): + def has_server(self, address: _Address) -> bool: return address in self._servers - def get_primary(self): + def get_primary(self) -> Optional[_Address]: """Return primary's address or None.""" # Implemented here in Topology instead of MongoClient, so it can lock. with self._lock: @@ -419,7 +464,7 @@ def get_primary(self): return writable_server_selector(self._new_selection())[0].address - def _get_replica_set_members(self, selector): + def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> Set[_Address]: """Return set of replica set member addresses.""" # Implemented here in Topology instead of MongoClient, so it can lock. with self._lock: @@ -430,21 +475,21 @@ def _get_replica_set_members(self, selector): ): return set() - return {sd.address for sd in selector(self._new_selection())} + return {sd.address for sd in iter(selector(self._new_selection()))} - def get_secondaries(self): + def get_secondaries(self) -> Set[_Address]: """Return set of secondary addresses.""" return self._get_replica_set_members(secondary_server_selector) - def get_arbiters(self): + def get_arbiters(self) -> Set[_Address]: """Return set of arbiter addresses.""" return self._get_replica_set_members(arbiter_server_selector) - def max_cluster_time(self): + def max_cluster_time(self) -> Optional[ClusterTime]: """Return a document, the highest seen $clusterTime.""" return self._max_cluster_time - def _receive_cluster_time_no_lock(self, cluster_time): + def _receive_cluster_time_no_lock(self, cluster_time: Optional[Mapping[str, Any]]) -> None: # Driver Sessions Spec: "Whenever a driver receives a cluster time from # a server it MUST compare it to the current highest seen cluster time # for the deployment. If the new cluster time is higher than the @@ -459,17 +504,17 @@ def _receive_cluster_time_no_lock(self, cluster_time): ): self._max_cluster_time = cluster_time - def receive_cluster_time(self, cluster_time): + def receive_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: with self._lock: self._receive_cluster_time_no_lock(cluster_time) - def request_check_all(self, wait_time=5): + def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() self._condition.wait(wait_time) - def data_bearing_servers(self): + def data_bearing_servers(self) -> List[ServerDescription]: """Return a list of all data-bearing servers. This includes any server that might be selected for an operation. @@ -478,7 +523,7 @@ def data_bearing_servers(self): return self._description.known_servers return self._description.readable_servers - def update_pool(self): + def update_pool(self) -> None: # Remove any stale sockets and add new sockets if pool is too small. 
         servers = []
         with self._lock:
@@ -495,7 +540,7 @@ def update_pool(self):
                     self.handle_error(server.description.address, ctx)
                     raise
 
-    def close(self):
+    def close(self) -> None:
         """Clear pools and terminate monitors. Topology does not reopen on
         demand. Any further operations will raise
         :exc:`~.errors.InvalidOperation`."""
@@ -525,19 +570,19 @@ def close(self):
             self.__events_executor.close()
 
     @property
-    def description(self):
+    def description(self) -> TopologyDescription:
         return self._description
 
-    def pop_all_sessions(self):
+    def pop_all_sessions(self) -> List[_ServerSession]:
         """Pop all session ids from the pool."""
         with self._lock:
             return self._session_pool.pop_all()
 
-    def _check_implicit_session_support(self):
+    def _check_implicit_session_support(self) -> None:
         with self._lock:
             self._check_session_support()
 
-    def _check_session_support(self):
+    def _check_session_support(self) -> float:
         """Internal check for session support on clusters."""
         if self._settings.load_balanced:
             # Sessions never time out in load balanced mode.
@@ -560,13 +605,13 @@ def _check_session_support(self):
             raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
         return session_timeout
 
-    def get_server_session(self):
+    def get_server_session(self) -> _ServerSession:
         """Start or resume a server session, or raise ConfigurationError."""
         with self._lock:
             session_timeout = self._check_session_support()
             return self._session_pool.get_server_session(session_timeout)
 
-    def return_server_session(self, server_session, lock):
+    def return_server_session(self, server_session: _ServerSession, lock: bool) -> None:
         if lock:
             with self._lock:
                 self._session_pool.return_server_session(
@@ -576,14 +621,14 @@ def return_server_session(self, server_session, lock):
         else:
             # Called from a __del__ method, can't use a lock.
             self._session_pool.return_server_session_no_lock(server_session)
 
-    def _new_selection(self):
+    def _new_selection(self) -> Selection:
         """A Selection object, initially including all known servers.
 
         Hold the lock when calling this.
         """
         return Selection.from_topology_description(self._description)
 
-    def _ensure_opened(self):
+    def _ensure_opened(self) -> None:
         """Start monitors, or restart after a fork.
 
         Hold the lock when calling this.
@@ -616,7 +661,7 @@ def _ensure_opened(self):
             for server in self._servers.values():
                 server.open()
 
-    def _is_stale_error(self, address, err_ctx):
+    def _is_stale_error(self, address: _Address, err_ctx: _ErrorContext) -> bool:
         server = self._servers.get(address)
         if server is None:
             # Another thread removed this server from the topology.
@@ -636,13 +681,12 @@ def _is_stale_error(self, address, err_ctx):
 
         return _is_stale_error_topology_version(cur_tv, error_tv)
 
-    def _handle_error(self, address, err_ctx):
+    def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
         if self._is_stale_error(address, err_ctx):
             return
 
         server = self._servers[address]
         error = err_ctx.error
-        exc_type = type(error)
         service_id = err_ctx.service_id
 
         # Ignore a handshake error if the server is behind a load balancer but
@@ -652,16 +696,16 @@ def _handle_error(self, address, err_ctx):
         if self._settings.load_balanced and not service_id and not err_ctx.completed_handshake:
             return
 
-        if issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake:
+        if isinstance(error, NetworkTimeout) and err_ctx.completed_handshake:
             # The socket has been closed. Don't reset the server.
             # Server Discovery And Monitoring Spec: "When an application
             # operation fails because of any network error besides a socket
             # timeout...."
             return
-        elif issubclass(exc_type, WriteError):
+        elif isinstance(error, WriteError):
             # Ignore writeErrors.
             return
-        elif issubclass(exc_type, (NotPrimaryError, OperationFailure)):
+        elif isinstance(error, (NotPrimaryError, OperationFailure)):
             # As per the SDAM spec if:
             #   - the server sees a "not primary" error, and
             #   - the server is not shutting down, and
@@ -675,7 +719,7 @@ def _handle_error(self, address, err_ctx):
             else:
                 # Default error code if one does not exist.
                 default = 10107 if isinstance(error, NotPrimaryError) else None
-                err_code = error.details.get("code", default)
+                err_code = error.details.get("code", default)  # type: ignore[union-attr]
             if err_code in helpers._NOT_PRIMARY_CODES:
                 is_shutting_down = err_code in helpers._SHUTDOWN_CODES
                 # Mark server Unknown, clear the pool, and request check.
@@ -691,7 +735,7 @@ def _handle_error(self, address, err_ctx):
                 self._process_change(ServerDescription(address, error=error))
                 # Clear the pool.
                 server.reset(service_id)
-        elif issubclass(exc_type, ConnectionFailure):
+        elif isinstance(error, ConnectionFailure):
             # "Client MUST replace the server's description with type Unknown
             # ... MUST NOT request an immediate check of the server."
             if not self._settings.load_balanced:
@@ -703,7 +747,7 @@ def _handle_error(self, address, err_ctx):
                 # that server and close the current monitoring connection."
                 server._monitor.cancel_check()
 
-    def handle_error(self, address, err_ctx):
+    def handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
        """Handle an application error.
 
         May reset the server to Unknown, clear the pool, and request an
@@ -712,12 +756,12 @@ def handle_error(self, address, err_ctx):
         with self._lock:
             self._handle_error(address, err_ctx)
 
-    def _request_check_all(self):
+    def _request_check_all(self) -> None:
         """Wake all monitors. Hold the lock when calling this."""
         for server in self._servers.values():
             server.request_check()
 
-    def _update_servers(self):
+    def _update_servers(self) -> None:
         """Sync our Servers from TopologyDescription.server_descriptions.
 
         Hold the lock while calling this.
@@ -759,10 +803,10 @@ def _update_servers(self):
                 server.close()
                 self._servers.pop(address)
 
-    def _create_pool_for_server(self, address):
+    def _create_pool_for_server(self, address: _Address) -> Pool:
         return self._settings.pool_class(address, self._settings.pool_options)
 
-    def _create_pool_for_monitor(self, address):
+    def _create_pool_for_monitor(self, address: _Address) -> Pool:
         options = self._settings.pool_options
 
         # According to the Server Discovery And Monitoring Spec, monitors use
@@ -782,7 +826,7 @@ def _create_pool_for_monitor(self, address):
 
         return self._settings.pool_class(address, monitor_pool_options, handshake=False)
 
-    def _error_message(self, selector):
+    def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
         """Format an error message if server selection fails.
 
         Hold the lock when calling this.
@@ -840,7 +884,7 @@ def _error_message(self, selector):
         else:
             return ",".join(str(server.error) for server in servers if server.error)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         msg = ""
         if not self._opened:
             msg = "CLOSED "
@@ -851,19 +895,26 @@ def eq_props(self):
         ts = self._settings
         return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, self.__class__):
             return self.eq_props() == other.eq_props()
         return NotImplemented
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.eq_props())
 
 
 class _ErrorContext:
     """An error with context for SDAM error handling."""
 
-    def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id):
+    def __init__(
+        self,
+        error: BaseException,
+        max_wire_version: int,
+        sock_generation: int,
+        completed_handshake: bool,
+        service_id: Optional[ObjectId],
+    ):
         self.error = error
         self.max_wire_version = max_wire_version
         self.sock_generation = sock_generation
@@ -871,7 +922,9 @@ def __init__(self, error, max_wire_version, sock_generation, completed_handshake
         self.service_id = service_id
 
 
-def _is_stale_error_topology_version(current_tv, error_tv):
+def _is_stale_error_topology_version(
+    current_tv: Optional[Mapping[str, Any]], error_tv: Optional[Mapping[str, Any]]
+) -> bool:
     """Return True if the error's topologyVersion is <= current."""
     if current_tv is None or error_tv is None:
         return False
@@ -880,7 +933,7 @@ def _is_stale_error_topology_version(current_tv, error_tv):
     return current_tv["counter"] >= error_tv["counter"]
 
 
-def _is_stale_server_description(current_sd, new_sd):
+def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
     """Return True if the new topologyVersion is < current."""
     current_tv, new_tv = current_sd.topology_version, new_sd.topology_version
     if current_tv is None or new_tv is None:
diff --git a/pymongo/typings.py b/pymongo/typings.py
index 4630870b81..3464c92945 100644
--- a/pymongo/typings.py
+++ b/pymongo/typings.py
@@ -34,6 +34,7 @@
 _Address = Tuple[str, Optional[int]]
 _CollationIn = Union[Mapping[str, Any], "Collation"]
 _Pipeline = Sequence[Mapping[str, Any]]
+ClusterTime = Mapping[str, Any]
 
 _T = TypeVar("_T")
diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py
index 799083f3b4..1596c1682f 100644
--- a/test/test_max_staleness.py
+++ b/test/test_max_staleness.py
@@ -123,6 +123,9 @@ def test_last_write_date(self):
         time.sleep(1)
         server = client._topology.select_server(writable_server_selector)
         second = server.description.last_write_date
+        assert first is not None
+
+        assert second is not None
         self.assertGreater(second, first)
         self.assertLess(second, first + 10)
diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py
index a3343d07c9..814874a266 100644
--- a/test/test_read_preferences.py
+++ b/test/test_read_preferences.py
@@ -265,7 +265,7 @@ def test_nearest(self):
         not_used = data_members.difference(used)
 
         latencies = ", ".join(
-            "%s: %dms" % (server.description.address, server.description.round_trip_time)
+            "%s: %sms" % (server.description.address, server.description.round_trip_time)
             for server in c._get_topology().select_servers(readable_server_selector)
         )
diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py
index 72df717901..9da5a550aa 100644
--- a/test/test_streaming_protocol.py
+++ b/test/test_streaming_protocol.py
@@ -122,6 +122,7 @@ def rtt_exceeds_250_ms():
             # XXX: Add a public TopologyDescription getter to MongoClient?
             topology = client._topology
             sd = topology.description.server_descriptions()[address]
+            assert sd.round_trip_time is not None
             return sd.round_trip_time > 0.250
 
         wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT")