diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 679293a52d..0bd6e9f516 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -73,7 +73,8 @@ ExponentialReconnectionPolicy, HostDistance, RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy, - NeverRetryPolicy) + NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy, + ShardConnectionScheduler) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) @@ -757,6 +758,11 @@ def auth_provider(self, value): self._auth_provider = value + _shard_connection_backoff_policy: ShardConnectionBackoffPolicy + @property + def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy: + return self._shard_connection_backoff_policy + _load_balancing_policy = None @property def load_balancing_policy(self): @@ -1219,7 +1225,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout=None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info: Optional[ApplicationInfoBase] = None, + shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None, ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1325,6 +1332,13 @@ def __init__(self, else: self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode + if shard_connection_backoff_policy is not None: + if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy): + raise TypeError("shard_connection_backoff_policy should be an instance of a class derived from ShardConnectionBackoffPolicy") + self._shard_connection_backoff_policy = shard_connection_backoff_policy + else: + self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy() + if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") @@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl): _metrics = None _request_init_callbacks = None _graph_paging_available = False + shard_connection_backoff_scheduler: ShardConnectionScheduler def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster @@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None): self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() + self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_scheduler(self) # create connection pools in parallel self._initial_connect_futures = set() diff --git a/cassandra/policies.py b/cassandra/policies.py index cb83238e87..d4f665a291 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+from __future__ import annotations + import random +import threading +import time +import weakref +from abc import ABC, abstractmethod from collections import namedtuple +from enum import Enum from functools import lru_cache from itertools import islice, cycle, groupby, repeat import logging @@ -21,11 +28,14 @@ from threading import Lock import socket import warnings +from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional, Dict log = logging.getLogger(__name__) from cassandra import WriteType as WT +if TYPE_CHECKING: + from cassandra.cluster import Session # This is done this way because WriteType was originally # defined here and in order not to break the API. @@ -864,6 +874,339 @@ def _add_jitter(self, value): return min(max(self.base_delay, delay), self.max_delay) +class ShardConnectionScheduler(ABC): + """ + A base class for a scheduler for a shard connection backoff policy. + A ``ShardConnectionScheduler`` is created per ``Session`` instance and implements + the logic described by the ``ShardConnectionBackoffPolicy`` that instantiates it. + """ + + @abstractmethod + def schedule( + self, + host_id: str, + shard_id: int, + method: Callable[..., None], + *args: Any, + **kwargs: Any) -> None: + """ + Schedule a connection-creation request for the given host and shard according to the policy, or execute it right away. + This is a non-blocking call; even if the policy executes the request right away, it is executed in a separate thread. + + ``host_id`` - the id of the host of the shard + ``shard_id`` - the id of the shard + ``method`` - a callable that creates a connection and stores it in the connection pool. + Currently, it is `HostConnection._open_connection_to_missing_shard` + ``*args`` and ``**kwargs`` are passed to ``method`` when the policy executes it + """ + raise NotImplementedError() + + +class ShardConnectionBackoffPolicy(ABC): + """ + Base class for shard connection backoff policies. + These policies allow the user to control the pace at which new connections to shards are established. + + ``new_scheduler`` instantiates a scheduler that behaves according to the policy. + """ + + @abstractmethod + def new_scheduler(self, session: Session) -> ShardConnectionScheduler: + raise NotImplementedError() + + +class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy): + """ + A shard connection backoff policy with no delay between attempts and no concurrency restrictions. + It ensures that there is at most one pending connection per (host, shard) pair; + additional attempts for the same (host, shard) pair are silently dropped. + + ``new_scheduler`` instantiates a scheduler that behaves according to the policy. + """ + + def new_scheduler(self, session: Session) -> ShardConnectionScheduler: + return _NoDelayShardConnectionBackoffScheduler(session) + + +class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler): + """ + A shard connection backoff scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``. + It does not introduce any delay or concurrency restrictions. + It only ensures that there is at most one pending or scheduled connection per (host, shard) pair.
+ """ + session: Session + already_scheduled: dict[str, bool] + lock: threading.Lock + + def __init__(self, session: Session): + self.session = weakref.proxy(session) + self.already_scheduled = {} + self.lock = threading.Lock() + + def _execute( + self, + scheduled_key: str, + method: Callable[..., None], + *args: List[Any], + **kwargs: dict[Any, Any]) -> None: + try: + method(*args, **kwargs) + finally: + with self.lock: + self.already_scheduled[scheduled_key] = False + + def schedule( + self, + host_id: str, + shard_id: int, + method: Callable[..., None], + *args: List[Any], + **kwargs: dict[Any, Any]) -> None: + scheduled_key = f'{host_id}-{shard_id}' + + with self.lock: + if self.already_scheduled.get(scheduled_key): + return + self.already_scheduled[scheduled_key] = True + + if not self.session.is_shutdown: + self.session.submit(self._execute, scheduled_key, method, *args, **kwargs) + + +class ShardConnectionBackoffScope(Enum): + """ + A scope for `ShardConnectionBackoffPolicy`, in particular ``LimitedConcurrencyShardConnectionBackoffPolicy`` + + Scope defines concurrency limitation scope, for instance : + ``LimitedConcurrencyShardConnectionBackoffPolicy`` - allows only one pending connection per scope, if you set it to Cluster, + only one connection per cluster will be allowed + """ + Cluster = 0 + Host = 1 + + +class ShardConnectionBackoffSchedule(ABC): + @abstractmethod + def new_schedule(self) -> Iterator[float]: + """ + This should return a finite or infinite iterable of delays (each as a + floating point number of seconds). + Note that if the iterable is finite, schedule will be recreated right after iterable is exhausted. + """ + raise NotImplementedError() + + +class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy): + """ + A shard connection backoff policy that allows only ``max_concurrent`` concurrent connection per scope. + Scope could be ``Host``or ``Cluster`` + For backoff calculation ir needs ``ShardConnectionBackoffSchedule`` or ``ReconnectionPolicy``, since both share same API. + When there is no more scheduled connections schedule of the backoff is reset. + + it also does not allow multiple pending or scheduled connections for same host+shard, + it silently drops attempts to schedule it. + + On ``new_scheduler`` instantiate a scheduler that behaves according to the policy + """ + scope: ShardConnectionBackoffScope + backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy + + max_concurrent: int + """ + Max concurrent connection creation requests per scope. + """ + + def __init__( + self, + scope: ShardConnectionBackoffScope, + backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy, + max_concurrent: int = 1, + ): + if not isinstance(scope, ShardConnectionBackoffScope): + raise ValueError("scope must be a ShardConnectionBackoffScope") + if not isinstance(backoff_policy, (ShardConnectionBackoffSchedule, ReconnectionPolicy)): + raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy") + if max_concurrent < 1: + raise ValueError("max_concurrent must be a positive integer") + self.scope = scope + self.backoff_policy = backoff_policy + self.max_concurrent = max_concurrent + + def new_scheduler(self, session: Session) -> ShardConnectionScheduler: + return _LimitedConcurrencyShardConnectionScheduler(session, self.scope, self.backoff_policy, self.max_concurrent) + + +class CreateConnectionCallback: + method: Callable[..., None] + args: Tuple[Any, ...] 
+ kwargs: Dict[str, Any] + + def __init__(self, method: Callable[..., None], *args, **kwargs) -> None: + self.method = method + self.args = args + self.kwargs = kwargs + + +class _ScopeBucket: + """ + Holds the state for one shard connection backoff scope; schedules and executes its connection-creation requests. + """ + session: Session + backoff_policy: ShardConnectionBackoffSchedule + lock: threading.Lock + + schedule: Iterator[float] + """ + Current schedule generated by ``backoff_policy`` + """ + + max_concurrent: int + """ + Max concurrent connection creation requests in the scope. + """ + + currently_pending: int + """ + Number of currently pending connections + """ + + items: List[CreateConnectionCallback] + """ + Scheduled connection-creation requests + """ + + def __init__( + self, + session: Session, + backoff_policy: ShardConnectionBackoffSchedule, + max_concurrent: int, + ): + self.items = [] + self.session = session + self.backoff_policy = backoff_policy + self.lock = threading.Lock() + self.schedule = self.backoff_policy.new_schedule() + self.max_concurrent = max_concurrent + self.currently_pending = 0 + + def _get_delay(self) -> float: + try: + return next(self.schedule) + except StopIteration: + # A bit of trickery to avoid needing a lock around self.schedule + schedule = self.backoff_policy.new_schedule() + delay = next(schedule) + self.schedule = schedule + return delay + + def _schedule(self): + if self.session.is_shutdown: + return + delay = self._get_delay() + if delay: + self.session.cluster.scheduler.schedule(delay, self._run) + else: + self.session.submit(self._run) + + def _run(self): + if self.session.is_shutdown: + return + + with self.lock: + try: + cb = self.items.pop() + except IndexError: + # Guard against going negative, just in case + if self.currently_pending > 0: + self.currently_pending -= 1 + # When items are exhausted, reset the schedule so that new items get a fresh schedule + # This is important for exponential backoff policies + self.schedule = self.backoff_policy.new_schedule() + return + + try: + cb.method(*cb.args, **cb.kwargs) + finally: + self._schedule() + + def schedule_new_connection(self, cb: CreateConnectionCallback): + with self.lock: + self.items.append(cb) + if self.currently_pending < self.max_concurrent: + self.currently_pending += 1 + self._schedule() + + +class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler): + """ + A shard connection backoff scheduler for ``LimitedConcurrencyShardConnectionBackoffPolicy``. + """ + already_scheduled: dict[str, bool] + """ + Dict of (host, shard) flags; a flag is true if a connection creation request is scheduled or currently running for that host and shard + """ + + scopes: dict[str, _ScopeBucket] + """ + Storage of per-scope buckets + """ + + scope: ShardConnectionBackoffScope + """ + Scope type + """ + + backoff_policy: ShardConnectionBackoffSchedule + session: Session + lock: threading.Lock + + max_concurrent: int + """ + Max concurrent connection creation requests per scope.
+ """ + + def __init__( + self, + session: Session, + scope: ShardConnectionBackoffScope, + backoff_policy: ShardConnectionBackoffSchedule, + max_concurrent: int, + ): + self.already_scheduled = {} + self.scopes = {} + self.scope = scope + self.backoff_policy = backoff_policy + self.max_concurrent = max_concurrent + self.session = session + self.lock = threading.Lock() + + def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs): + try: + method(*args, **kwargs) + finally: + with self.lock: + self.already_scheduled[scheduled_key] = False + + def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs): + if self.scope == ShardConnectionBackoffScope.Cluster: + scope_hash = "global-cluster-scope" + elif self.scope == ShardConnectionBackoffScope.Host: + scope_hash = host_id + else: + raise ValueError("scope must be Cluster or Host") + + scheduled_key = f'{host_id}-{shard_id}' + + with self.lock: + if self.already_scheduled.get(scheduled_key): + return False + self.already_scheduled[scheduled_key] = True + + scope_info = self.scopes.get(scope_hash) + if not scope_info: + scope_info = _ScopeBucket(self.session, self.backoff_policy, self.max_concurrent) + self.scopes[scope_hash] = scope_info + scope_info.schedule_new_connection(CreateConnectionCallback(self._execute, scheduled_key, method, *args, **kwargs)) + return True + + class RetryPolicy(object): """ A policy that describes whether to retry, rethrow, or ignore coordinator diff --git a/cassandra/pool.py b/cassandra/pool.py index d1f6604abf..07fd3ec3a4 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session): # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. self._stream_available_condition = Condition(Lock()) self._is_replacing = False - self._connecting = set() self._connections = {} self._pending_connections = [] # A pool of additional connections which are not used but affect how Scylla @@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session): # and are waiting until all requests time out or complete # so that we can dispose of them. self._trash = set() - self._shard_connections_futures = [] self.advanced_shardaware_block_until = 0 if host_distance == HostDistance.IGNORED: @@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table self.host, routing_key ) - if conn.orphaned_threshold_reached and shard_id not in self._connecting: + if conn.orphaned_threshold_reached: # The connection has met its orphaned stream ID limit # and needs to be replaced. Start opening a connection # to the same shard and replace when it is opened. 
- self._connecting.add(shard_id) - self._session.submit(self._open_connection_to_missing_shard, shard_id) + self._session.shard_connection_backoff_scheduler.schedule( + self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id) log.debug( - "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)", + "Connection to shard_id=%i reached orphaned stream limit, scheduling replacement on host %s (%s/%i)", shard_id, self.host, len(self._connections.keys()), self.host.sharding_info.shards_count ) - elif shard_id not in self._connecting: + else: # rate controlled optimistic attempt to connect to a missing shard - self._connecting.add(shard_id) - self._session.submit(self._open_connection_to_missing_shard, shard_id) + self._session.shard_connection_backoff_scheduler.schedule( + self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id) log.debug( - "Trying to connect to missing shard_id=%i on host %s (%s/%i)", + "Scheduling connection to missing shard_id=%i on host %s (%s/%i)", shard_id, self.host, len(self._connections.keys()), @@ -609,8 +607,8 @@ def _replace(self, connection): if connection.features.shard_id in self._connections.keys(): del self._connections[connection.features.shard_id] if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: - self._connecting.add(connection.features.shard_id) - self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) + self._session.shard_connection_backoff_scheduler.schedule( + self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id) else: connection = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) @@ -635,9 +633,6 @@ def shutdown(self): with self._stream_available_condition: self._stream_available_condition.notify_all() - for future in self._shard_connections_futures: - future.cancel() - connections_to_close = self._connections.copy() pending_connections_to_close = self._pending_connections.copy() self._connections.clear() @@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id): self._excess_connections.add(conn) if close_connection: conn.close() - self._connecting.discard(shard_id) def _open_connections_for_all_shards(self, skip_shard_id=None): """ @@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None): for shard_id in range(self.host.sharding_info.shards_count): if skip_shard_id is not None and skip_shard_id == shard_id: continue - future = self._session.submit(self._open_connection_to_missing_shard, shard_id) - if isinstance(future, Future): - self._connecting.add(shard_id) - self._shard_connections_futures.append(future) + self._session.shard_connection_backoff_scheduler.schedule( + self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id) trash_conns = None with self._lock: diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index f00d4c7126..206fac860b 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -396,6 +396,7 @@ def _id_and_mark(f): reason='Scylla does not support custom payloads. 
Cassandra requires native protocol v4.0+') xfail_scylla = lambda reason, *args, **kwargs: pytest.mark.xfail(SCYLLA_VERSION is not None, reason=reason, *args, **kwargs) incorrect_test = lambda reason='This test seems to be incorrect and should be fixed', *args, **kwargs: pytest.mark.xfail(reason=reason, *args, **kwargs) +requires_scylla = pytest.mark.skipif(not SCYLLA_VERSION, reason='This test is designed for Scylla only') pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy") requiresmallclockgranularity = unittest.skipIf("Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER, diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py index 33f35ced0d..1c81f42927 100644 --- a/tests/integration/long/test_policies.py +++ b/tests/integration/long/test_policies.py @@ -11,16 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import time import unittest +from typing import Optional from cassandra import ConsistencyLevel, Unavailable -from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT, Session +from cassandra.policies import LimitedConcurrencyShardConnectionBackoffPolicy, ShardConnectionBackoffScope, \ + ConstantReconnectionPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy +from cassandra.shard_info import _ShardingInfo from tests.integration import use_cluster, get_cluster, get_node, TestCluster def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 4" use_cluster('test_cluster', [4]) @@ -65,3 +71,127 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL) self.assertEqual(exception.required_replicas, 2) self.assertEqual(exception.alive_replicas, 1) + + +class ShardBackoffPolicyTests(unittest.TestCase): + @classmethod + def tearDownClass(cls): + cluster = get_cluster() + cluster.start(wait_for_binary_proto=True, wait_other_notice=True) # make sure other nodes are restarted + + def test_limited_concurrency_1_connection_per_cluster(self): + self._test_backoff( + LimitedConcurrencyShardConnectionBackoffPolicy( + backoff_policy=ConstantReconnectionPolicy(0.1), + max_concurrent=1, + scope=ShardConnectionBackoffScope.Cluster, + ) + ) + + def test_limited_concurrency_2_connection_per_cluster(self): + self._test_backoff( + LimitedConcurrencyShardConnectionBackoffPolicy( + backoff_policy=ConstantReconnectionPolicy(0.1), + max_concurrent=2, + scope=ShardConnectionBackoffScope.Cluster, + ) + ) + + def test_limited_concurrency_1_connection_per_host(self): + self._test_backoff( + LimitedConcurrencyShardConnectionBackoffPolicy( + backoff_policy=ConstantReconnectionPolicy(0.1), + max_concurrent=1, + scope=ShardConnectionBackoffScope.Host, + ) + ) + + def test_limited_concurrency_2_connection_per_host(self): + self._test_backoff( + LimitedConcurrencyShardConnectionBackoffPolicy( + backoff_policy=ConstantReconnectionPolicy(0.1), + max_concurrent=2, + scope=ShardConnectionBackoffScope.Host, + ) + ) + + def test_no_delay(self): + self._test_backoff(NoDelayShardConnectionBackoffPolicy()) + + def _test_backoff(self, shard_connection_backoff_policy: ShardConnectionBackoffPolicy): + backoff_policy = None + if isinstance(shard_connection_backoff_policy, 
LimitedConcurrencyShardConnectionBackoffPolicy): + backoff_policy = shard_connection_backoff_policy.backoff_policy + + cluster = TestCluster( + shard_connection_backoff_policy=shard_connection_backoff_policy, + reconnection_policy=ConstantReconnectionPolicy(0), + ) + session = cluster.connect() + sharding_info = get_sharding_info(session) + + # Even if a backoff policy is set, when there is no sharding info + # the behavior should be the same as if there were no backoff policy + if not backoff_policy or not sharding_info: + time.sleep(2) + expected_connections = 1 + if sharding_info: + expected_connections = sharding_info.shards_count + for host_id, connections_count in get_connections_per_host(session).items(): + self.assertEqual(connections_count, expected_connections) + return + + sleep_time = 0 + schedule = backoff_policy.new_schedule() + # Calculate the total time needed to establish all connections + if shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Cluster: + for _ in session.hosts: + for _ in range(sharding_info.shards_count - 1): + sleep_time += next(schedule) + sleep_time /= shard_connection_backoff_policy.max_concurrent + elif shard_connection_backoff_policy.scope == ShardConnectionBackoffScope.Host: + for _ in range(sharding_info.shards_count - 1): + sleep_time += next(schedule) + sleep_time /= shard_connection_backoff_policy.max_concurrent + else: + raise ValueError("Unknown scope {}".format(shard_connection_backoff_policy.scope)) + + time.sleep(sleep_time / 2) + self.assertFalse( + is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count)) + time.sleep(sleep_time / 2 + 1) + self.assertTrue( + is_connection_filled(shard_connection_backoff_policy.scope, session, sharding_info.shards_count)) + + +def is_connection_filled(scope: ShardConnectionBackoffScope, session: Session, shards_count: int) -> bool: + if scope == ShardConnectionBackoffScope.Cluster: + expected_connections = shards_count * len(session.hosts) + total_connections = sum(get_connections_per_host(session).values()) + return expected_connections == total_connections + elif scope == ShardConnectionBackoffScope.Host: + expected_connections_per_host = shards_count + for connections_count in get_connections_per_host(session).values(): + if connections_count < expected_connections_per_host: + return False + if connections_count == expected_connections_per_host: + continue + assert False, "Expected {} or fewer connections but got {}".format(expected_connections_per_host, + connections_count) + return True + else: + raise ValueError("Unknown scope {}".format(scope)) + + +def get_connections_per_host(session: Session) -> dict[str, int]: + host_connections = {} + for host, pool in session._pools.items(): + host_connections[host.host_id] = len(pool._connections) + return host_connections + + +def get_sharding_info(session: Session) -> Optional[_ShardingInfo]: + for host in session.hosts: + if host.sharding_info: + return host.sharding_info + return None diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index b4cb067d2f..8495d033af 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -26,7 +26,7 @@ from cassandra.connection import Connection from cassandra.pool import HostConnection, HostConnectionPool from cassandra.pool import Host, NoConnectionsAvailable -from cassandra.policies import HostDistance, SimpleConvictionPolicy +from cassandra.policies import HostDistance, 
SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler LOGGER = logging.getLogger(__name__) @@ -41,6 +41,8 @@ def make_session(self): session.cluster.get_core_connections_per_host.return_value = 1 session.cluster.get_max_requests_per_connection.return_value = 1 session.cluster.get_max_connections_per_host.return_value = 1 + session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(session) + session.is_shutdown = False return session def test_borrow_and_return(self): diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e7757aedfc..083499c209 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -26,13 +26,16 @@ from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata -from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, + DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, - IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy) + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, + ExponentialBackoffRetryPolicy, _ScopeBucket, _LimitedConcurrencyShardConnectionScheduler, + ShardConnectionBackoffScope, _NoDelayShardConnectionBackoffScheduler, CreateConnectionCallback) from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint from cassandra.pool import Host from cassandra.query import Statement @@ -1579,3 +1582,253 @@ def test_create_whitelist(self): # Only the filtered replicas should be allowed self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) + + +class ScopeBucketTests(unittest.TestCase): + def setUp(self): + self.mock_scheduler = Mock() + self.mock_cluster = Mock(scheduler=self.mock_scheduler) + self.mock_session = Mock() + self.mock_session.is_shutdown = False + self.mock_session.cluster = self.mock_cluster + self.mock_session.submit = Mock() + + self.reconnection_policy = Mock() + self.reconnection_policy.new_schedule.side_effect = lambda: iter([0.1, 0.2, 0.3]) + + def test_add_schedules_initial_task(self): + method = Mock() + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 1) + bucket.schedule_new_connection(CreateConnectionCallback(method, 1, x=2)) + + self.assertEqual(bucket.currently_pending, 1) + self.mock_scheduler.schedule.assert_called_once() + delay, func = self.mock_scheduler.schedule.call_args[0] + self.assertEqual(delay, 0.1) + self.assertTrue(callable(func)) + + def test_multiple_adds_only_schedule_once(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 1) + self._test_multiple_adds(bucket, 1) + + def test_multiple_adds_schedule_twice(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + self._test_multiple_adds(bucket, 2) + + def _test_multiple_adds(self, bucket, expected_running): + method1 = Mock() + method2 = Mock() + + bucket.schedule_new_connection(CreateConnectionCallback(method1, "a")) + 
bucket.schedule_new_connection(CreateConnectionCallback(method2, "b")) + # Only the expected number of schedules should be triggered + self.assertEqual(self.mock_scheduler.schedule.call_count, expected_running) + + # Both methods are enqueued + self.assertEqual(len(bucket.items), 2) + self.assertEqual(bucket.currently_pending, expected_running) + + + def test_run_executes_and_reschedules(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + method = Mock() + bucket.schedule_new_connection(CreateConnectionCallback(method, "arg", kwarg=123)) + + _, run_func = self.mock_scheduler.schedule.call_args[0] + run_func() + + method.assert_called_once_with("arg", kwarg=123) + self.mock_scheduler.schedule.assert_called_with(0.2, bucket._run) + + def test_run_stops_on_empty_queue(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + + bucket.currently_pending = 1 + bucket.items = [] + bucket._run() + + bucket._run() + + self.assertEqual(bucket.currently_pending, 0) + self.assertIsNotNone(bucket.schedule) + + def test_does_not_schedule_if_shutdown(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + self.mock_session.is_shutdown = True + method = Mock() + bucket.schedule_new_connection(CreateConnectionCallback(method)) + + self.mock_scheduler.schedule.assert_not_called() + self.mock_session.submit.assert_not_called() + + def test_schedule_uses_submit_if_delay_is_zero(self): + self.reconnection_policy.new_schedule.side_effect = lambda: iter([0]) + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 1) + + method = Mock() + bucket.schedule_new_connection(CreateConnectionCallback(method)) + + self.mock_session.submit.assert_called_once_with(bucket._run) + + def test_get_delay_resets_schedule_on_stopiteration(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + empty_iter = iter([]) + second_iter = iter([42.0]) + + bucket.schedule = empty_iter + + self.reconnection_policy.new_schedule = Mock(side_effect=[second_iter]) + bucket.backoff_policy = self.reconnection_policy + + delay = bucket._get_delay() + self.assertEqual(delay, 42.0) + + def test_run_with_multiple_items(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + + m1 = Mock() + m2 = Mock() + + bucket.schedule_new_connection(CreateConnectionCallback(m1)) + bucket.schedule_new_connection(CreateConnectionCallback(m2)) + + # Simulate the first scheduled run + delay, run_func = self.mock_scheduler.schedule.call_args[0] + run_func() + + # Because of LIFO order, m2 should be called first + m2.assert_called_once() + m1.assert_not_called() + + # One item should remain in the queue + self.assertEqual(len(bucket.items), 1) + self.assertIs(bucket.items[0].method, m1) + + def test_run_handles_index_error_gracefully(self): + bucket = _ScopeBucket(self.mock_session, self.reconnection_policy, 2) + + bucket.items = [] + bucket.currently_pending = 1 + + try: + bucket._run() + except IndexError: + self.fail("_run raised IndexError unexpectedly") + + self.assertEqual(bucket.currently_pending, 0) + self.assertIsNotNone(bucket.schedule) + + +class NoDelayShardConnectionBackoffSchedulerTests(unittest.TestCase): + def setUp(self): + self.mock_session = Mock() + self.mock_session.is_shutdown = False + self.scheduler = _NoDelayShardConnectionBackoffScheduler(self.mock_session) + + def test_schedule_executes_method_immediately(self): + method = Mock() + self.scheduler.schedule('host1', 0, method, 1, 2, key='val') + 
self.mock_session.submit.assert_called_once() + submitted_fn = self.mock_session.submit.call_args[0][0] + submitted_args = self.mock_session.submit.call_args[0][1:] + submitted_kwargs = self.mock_session.submit.call_args[1] + + submitted_fn(*submitted_args, **submitted_kwargs) + + method.assert_called_once_with(1, 2, key='val') + + def test_schedule_skips_if_already_scheduled(self): + method = Mock() + self.scheduler.already_scheduled['host1-0'] = True + + self.scheduler.schedule('host1', 0, method) + + self.mock_session.submit.assert_not_called() + method.assert_not_called() + + def test_already_scheduled_resets_after_execution(self): + method = Mock() + + self.scheduler.schedule('host1', 0, method) + submitted_fn = self.mock_session.submit.call_args[0][0] + submitted_args = self.mock_session.submit.call_args[0][1:] + submitted_fn(*submitted_args) + + self.assertFalse(self.scheduler.already_scheduled['host1-0']) + + def test_schedule_skips_if_session_shutdown(self): + self.mock_session.is_shutdown = True + method = Mock() + + self.scheduler.schedule('host1', 0, method) + + self.mock_session.submit.assert_not_called() + method.assert_not_called() + + +class LimitedConcurrencyShardConnectionSchedulerTests(unittest.TestCase): + def setUp(self): + self.session = Mock() + self.session.is_shutdown = False + self.session.cluster.scheduler.schedule = Mock() + self.session.submit = Mock() + + self.reconnection_policy = Mock() + self.reconnection_policy.new_schedule.return_value = cycle([0]) + + self.method = Mock() + self.host_id = 'host123' + self.shard_id = 0 + + def test_schedules_once_per_key(self): + scheduler = _LimitedConcurrencyShardConnectionScheduler( + self.session, ShardConnectionBackoffScope.Cluster, self.reconnection_policy, 1 + ) + + scheduled = scheduler.schedule(self.host_id, self.shard_id, self.method) + self.assertTrue(scheduled) + # Try to schedule again for the same key: should be rejected + scheduled2 = scheduler.schedule(self.host_id, self.shard_id, self.method) + self.assertFalse(scheduled2) + + # _ScopeBucket should have been created for cluster scope + self.assertEqual(len(scheduler.scopes), 1) + scope = next(iter(scheduler.scopes.values())) + self.assertIsNotNone(scope) + self.assertEqual(len(scope.items), 1) + + def test_schedule_separate_keys(self): + scheduler = _LimitedConcurrencyShardConnectionScheduler( + self.session, ShardConnectionBackoffScope.Host, self.reconnection_policy, 1 + ) + + scheduled1 = scheduler.schedule('host1', 1, self.method) + scheduled2 = scheduler.schedule('host1', 2, self.method) + scheduled3 = scheduler.schedule('host2', 1, self.method) + + self.assertTrue(scheduled1) + self.assertTrue(scheduled2) + self.assertTrue(scheduled3) + + # Should create scopes for both hosts + self.assertIn('host1', scheduler.scopes) + self.assertIn('host2', scheduler.scopes) + self.assertEqual(len(scheduler.scopes['host1'].items), 2) + self.assertEqual(len(scheduler.scopes['host2'].items), 1) + + def test_execute_resets_already_scheduled_flag(self): + scheduler = _LimitedConcurrencyShardConnectionScheduler( + self.session, ShardConnectionBackoffScope.Cluster, self.reconnection_policy, 1 + ) + + scheduler.schedule(self.host_id, self.shard_id, self.method) + key = f"{self.host_id}-{self.shard_id}" + + # Simulate running the scheduled task manually + self.assertTrue(scheduler.already_scheduled[key]) + scheduler._execute(key, self.method) + + # Should now be marked not scheduled + self.assertFalse(scheduler.already_scheduled[key]) + self.method.assert_called_once() \ No 
newline at end of file diff --git a/tests/unit/test_shard_aware.py b/tests/unit/test_shard_aware.py index fe7b95edba..bfac4cb601 100644 --- a/tests/unit/test_shard_aware.py +++ b/tests/unit/test_shard_aware.py @@ -11,17 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import uuid try: import unittest2 as unittest except ImportError: import unittest # noqa +import time import logging from mock import MagicMock from concurrent.futures import ThreadPoolExecutor -from cassandra.cluster import ShardAwareOptions +from cassandra.cluster import ShardAwareOptions, _Scheduler +from cassandra.policies import ConstantReconnectionPolicy, \ + NoDelayShardConnectionBackoffPolicy, LimitedConcurrencyShardConnectionBackoffPolicy, ShardConnectionBackoffScope from cassandra.pool import HostConnection, HostDistance from cassandra.connection import ShardingInfo, DefaultEndPoint from cassandra.metadata import Murmur3Token @@ -53,7 +57,21 @@ class OptionsHolder(object): self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4) self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2) - def test_advanced_shard_aware_port(self): + def test_shard_aware_reconnection_policy_no_delay(self): + # with NoDelayShardConnectionBackoffPolicy all the connections should be created right away + self._test_shard_aware_reconnection_policy(4, NoDelayShardConnectionBackoffPolicy(), 4, 4) + + def test_shard_aware_reconnection_policy_delay(self): + # with ConstantReconnectionPolicy the first connection is created right away, the others are delayed + self._test_shard_aware_reconnection_policy( + 4, + LimitedConcurrencyShardConnectionBackoffPolicy( + ShardConnectionBackoffScope.Cluster, + ConstantReconnectionPolicy(0.1), + 1 + ), 1, 4) + + def _test_shard_aware_reconnection_policy(self, shard_count, shard_connection_backoff_policy, expected_count, expected_after): """ Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class) the next connections would be open using this port @@ -71,17 +89,25 @@ def __init__(self, is_ssl=False, *args, **kwargs): self.cluster.ssl_options = None self.cluster.shard_aware_options = ShardAwareOptions() self.cluster.executor = ThreadPoolExecutor(max_workers=2) + self._executor_submit_original = self.cluster.executor.submit + self.cluster.executor.submit = self._executor_submit + self.cluster.scheduler = _Scheduler(self.cluster.executor) self.cluster.signal_connection_failure = lambda *args, **kwargs: False self.cluster.connection_factory = self.mock_connection_factory self.connection_counter = 0 + self.shard_connection_backoff_scheduler = shard_connection_backoff_policy.new_scheduler(self) self.futures = [] def submit(self, fn, *args, **kwargs): + if self.is_shutdown: + return None + return self.cluster.executor.submit(fn, *args, **kwargs) + + def _executor_submit(self, fn, *args, **kwargs): logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs) - if not self.is_shutdown: - f = self.cluster.executor.submit(fn, *args, **kwargs) - self.futures += [f] - return f + f = self._executor_submit_original(fn, *args, **kwargs) + self.futures += [f] + return f def mock_connection_factory(self, *args, **kwargs): connection = MagicMock() @@ -90,26 +116,50 @@ def mock_connection_factory(self, *args, **kwargs): connection = MagicMock() connection.is_shutdown = False connection.is_closed = False connection.orphaned_threshold_reached = False 
connection.endpoint = args[0] - sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) + sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info) self.connection_counter += 1 return connection host = MagicMock() + host.host_id = uuid.uuid4() host.endpoint = DefaultEndPoint("1.2.3.4") + session = None + backoff_policy = None + if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy): + backoff_policy = shard_connection_backoff_policy.backoff_policy + try: + for port, is_ssl in [(19042, False), (19045, True)]: + session = MockSession(is_ssl=is_ssl) + pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) + for f in session.futures: + f.result() + assert len(pool._connections) == expected_count + for shard_id, connection in pool._connections.items(): + assert connection.features.shard_id == shard_id + if shard_id == 0: + assert connection.endpoint == DefaultEndPoint("1.2.3.4") + else: + assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) - for port, is_ssl in [(19042, False), (19045, True)]: - session = MockSession(is_ssl=is_ssl) - pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) - for f in session.futures: - f.result() - assert len(pool._connections) == 4 - for shard_id, connection in pool._connections.items(): - assert connection.features.shard_id == shard_id - if shard_id == 0: - assert connection.endpoint == DefaultEndPoint("1.2.3.4") - else: - assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) + sleep_time = 0 + if backoff_policy: + # Check that connections to shards are being established according to the policy + # Calculate the total time needed to establish all connections + # Sleep half of the time and check that connections are not there yet + # Sleep the rest of the time + 1 second and check that all connections have been established + schedule = backoff_policy.new_schedule() + for _ in range(shard_count): + sleep_time += next(schedule) + if sleep_time > 0: + time.sleep(sleep_time/2) + # Check that connections are not being established more quickly than expected + assert len(pool._connections) < expected_after + time.sleep(sleep_time/2 + 1) - session.cluster.executor.shutdown(wait=True) + assert len(pool._connections) == expected_after + finally: + if session: + session.cluster.scheduler.shutdown() + session.cluster.executor.shutdown(wait=True)
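
A minimal usage sketch of the new option, not part of the patch itself. It only uses APIs introduced above (``Cluster(shard_connection_backoff_policy=...)``, ``LimitedConcurrencyShardConnectionBackoffPolicy``, ``ShardConnectionBackoffScope``) plus the existing ``ConstantReconnectionPolicy``; the contact point is a placeholder:

    from cassandra.cluster import Cluster
    from cassandra.policies import (ConstantReconnectionPolicy,
                                    LimitedConcurrencyShardConnectionBackoffPolicy,
                                    ShardConnectionBackoffScope)

    # Allow at most two concurrent shard connection attempts per host,
    # pacing attempts within each host scope with a constant 0.1 s delay.
    policy = LimitedConcurrencyShardConnectionBackoffPolicy(
        scope=ShardConnectionBackoffScope.Host,
        backoff_policy=ConstantReconnectionPolicy(0.1),
        max_concurrent=2,
    )

    # When shard_connection_backoff_policy is omitted, the driver defaults to
    # NoDelayShardConnectionBackoffPolicy, which preserves the previous behavior.
    cluster = Cluster(contact_points=["127.0.0.1"],
                      shard_connection_backoff_policy=policy)
    session = cluster.connect()

Using ``ShardConnectionBackoffScope.Cluster`` instead would apply the same ``max_concurrent`` budget across all hosts combined, which is the more conservative choice for large clusters.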