Skip to content

Commit 814dbf2

Browse files
Add method to asynchronously prepare CQL statements
1 parent 6e2ffd4 commit 814dbf2

File tree

4 files changed

+172
-38
lines changed

4 files changed

+172
-38
lines changed

cassandra/cluster.py

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2717,7 +2717,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None
27172717
if execute_as:
27182718
custom_payload[_proxy_execute_key] = execute_as.encode()
27192719

2720-
future = self._create_response_future(
2720+
future = self._create_execute_response_future(
27212721
query, parameters, trace, custom_payload, timeout,
27222722
execution_profile, paging_state, host)
27232723
future._protocol_handler = self.client_protocol_handler
@@ -2782,8 +2782,8 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
27822782
custom_payload[_proxy_execute_key] = execute_as.encode()
27832783
custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000))
27842784

2785-
future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
2786-
timeout=_NOT_SET, execution_profile=execution_profile)
2785+
future = self._create_execute_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
2786+
timeout=_NOT_SET, execution_profile=execution_profile)
27872787

27882788
future.message.query_params = graph_parameters
27892789
future._protocol_handler = self.client_protocol_handler
@@ -2885,9 +2885,9 @@ def _transform_params(self, parameters, graph_options):
28852885

28862886
def _target_analytics_master(self, future):
28872887
future._start_timer()
2888-
master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
2889-
parameters=None, trace=False,
2890-
custom_payload=None, timeout=future.timeout)
2888+
master_query_future = self._create_execute_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
2889+
parameters=None, trace=False,
2890+
custom_payload=None, timeout=future.timeout)
28912891
master_query_future.row_factory = tuple_factory
28922892
master_query_future.send_request()
28932893

@@ -2910,9 +2910,37 @@ def _on_analytics_master_result(self, response, master_future, query_future):
29102910

29112911
self.submit(query_future.send_request)
29122912

2913-
def _create_response_future(self, query, parameters, trace, custom_payload,
2914-
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
2915-
paging_state=None, host=None):
2913+
def prepare_async(self, query, custom_payload=None, keyspace=None):
2914+
"""
2915+
Prepare the given query and return a :class:`~.PrepareFuture`
2916+
object. You may also call :meth:`~.PrepareFuture.result()`
2917+
on the :class:`.PrepareFuture` to synchronously block for
2918+
prepared statement object at any time.
2919+
2920+
See :meth:`Session.prepare` for parameter definitions.
2921+
2922+
Example usage::
2923+
2924+
>>> future = session.prepare_async("SELECT * FROM mycf")
2925+
>>> # do other stuff...
2926+
2927+
>>> try:
2928+
... prepared_statement = future.result()
2929+
... except Exception:
2930+
... log.exception("Operation failed:")
2931+
"""
2932+
future = self._create_prepare_response_future(query, keyspace, custom_payload)
2933+
future._protocol_handler = self.client_protocol_handler
2934+
self._on_request(future)
2935+
future.send_request()
2936+
return future
2937+
2938+
def _create_prepare_response_future(self, query, keyspace, custom_payload):
2939+
return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout)
2940+
2941+
def _create_execute_response_future(self, query, parameters, trace, custom_payload,
2942+
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
2943+
paging_state=None, host=None):
29162944
""" Returns the ResponseFuture before calling send_request() on it """
29172945

29182946
prepared_statement = None
@@ -3121,33 +3149,9 @@ def prepare(self, query, custom_payload=None, keyspace=None):
31213149
`custom_payload` is a key value map to be passed along with the prepare
31223150
message. See :ref:`custom_payload`.
31233151
"""
3124-
message = PrepareMessage(query=query, keyspace=keyspace)
3125-
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
3126-
try:
3127-
future.send_request()
3128-
response = future.result().one()
3129-
except Exception:
3130-
log.exception("Error preparing query:")
3131-
raise
3152+
return self.prepare_async(query, custom_payload, keyspace).result()
31323153

3133-
prepared_keyspace = keyspace if keyspace else None
3134-
prepared_statement = PreparedStatement.from_message(
3135-
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
3136-
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
3137-
prepared_statement.custom_payload = future.custom_payload
3138-
3139-
self.cluster.add_prepared(response.query_id, prepared_statement)
3140-
3141-
if self.cluster.prepare_on_all_hosts:
3142-
host = future._current_host
3143-
try:
3144-
self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace)
3145-
except Exception:
3146-
log.exception("Error preparing query on all hosts:")
3147-
3148-
return prepared_statement
3149-
3150-
def prepare_on_all_hosts(self, query, excluded_host, keyspace=None):
3154+
def prepare_on_all_nodes(self, query, excluded_host, keyspace=None):
31513155
"""
31523156
Prepare the given query on all hosts, excluding ``excluded_host``.
31533157
Intended for internal use only.
@@ -5105,6 +5109,48 @@ def __str__(self):
51055109
__repr__ = __str__
51065110

51075111

5112+
class PrepareFuture(ResponseFuture):
5113+
_final_prepare_result = _NOT_SET
5114+
5115+
def __init__(self, session, query, keyspace, custom_payload, timeout):
5116+
super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout)
5117+
self.query_string = query
5118+
self._keyspace = keyspace
5119+
self._custom_payload = custom_payload
5120+
5121+
def _set_final_result(self, response):
5122+
session = self.session
5123+
cluster = session.cluster
5124+
prepared_statement = PreparedStatement.from_message(
5125+
response.query_id, response.bind_metadata, response.pk_indexes, cluster.metadata, self.query_string,
5126+
self._keyspace, session._protocol_version, response.column_metadata, response.result_metadata_id,
5127+
cluster.column_encryption_policy)
5128+
prepared_statement.custom_payload = response.custom_payload
5129+
cluster.add_prepared(response.query_id, prepared_statement)
5130+
self._final_prepare_result = prepared_statement
5131+
5132+
if cluster.prepare_on_all_hosts:
5133+
# trigger asynchronous preparation of query on other C* nodes,
5134+
# we are on event loop thread, so do not execute those synchronously
5135+
session.submit(
5136+
session.prepare_on_all_nodes,
5137+
self.query_string, self._current_host, self._keyspace)
5138+
5139+
super()._set_final_result(response)
5140+
5141+
def result(self):
5142+
self._event.wait()
5143+
if self._final_prepare_result is not _NOT_SET:
5144+
return self._final_prepare_result
5145+
else:
5146+
raise self._final_exception
5147+
5148+
def __str__(self):
5149+
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
5150+
return "<PrepareFuture: query='%s' request_id=%s result=%s exception=%s coordinator_host=%s>" \
5151+
% (self.query_string, self._req_id, result, self._final_exception, self.coordinator_host)
5152+
__repr__ = __str__
5153+
51085154
class QueryExhausted(Exception):
51095155
"""
51105156
Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and

cassandra/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
10351035

10361036
def _execute(self, query, parameters, time_spent, max_wait):
10371037
timeout = (max_wait - time_spent) if max_wait is not None else None
1038-
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
1038+
future = self._session._create_execute_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
10391039
# in case the user switched the row factory, set it to namedtuple for this query
10401040
future.row_factory = named_tuple_factory
10411041
future.send_request()

tests/integration/standard/test_prepared_statements.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from cassandra import InvalidRequest, DriverException
2323

2424
from cassandra import ConsistencyLevel, ProtocolVersion
25+
from cassandra.cluster import PrepareFuture
2526
from cassandra.query import PreparedStatement, UNSET_VALUE
2627
from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50,
2728
requirecassandra, BasicSharedKeyspaceUnitTestCase)
@@ -121,6 +122,83 @@ def test_basic(self):
121122
results = self.session.execute(bound)
122123
self.assertEqual(results, [('x', 'y', 'z')])
123124

125+
def test_basic_async(self):
126+
"""
127+
Test basic asynchronous PreparedStatement usage
128+
"""
129+
self.session.execute(
130+
"""
131+
DROP KEYSPACE IF EXISTS preparedtests
132+
"""
133+
)
134+
self.session.execute(
135+
"""
136+
CREATE KEYSPACE preparedtests
137+
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
138+
""")
139+
140+
self.session.set_keyspace("preparedtests")
141+
self.session.execute(
142+
"""
143+
CREATE TABLE cf0 (
144+
a text,
145+
b text,
146+
c text,
147+
PRIMARY KEY (a, b)
148+
)
149+
""")
150+
151+
prepared_future = self.session.prepare_async(
152+
"""
153+
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
154+
""")
155+
self.assertIsInstance(prepared_future, PrepareFuture)
156+
prepared = prepared_future.result()
157+
self.assertIsInstance(prepared, PreparedStatement)
158+
159+
bound = prepared.bind(('a', 'b', 'c'))
160+
self.session.execute(bound)
161+
162+
prepared_future = self.session.prepare_async(
163+
"""
164+
SELECT * FROM cf0 WHERE a=?
165+
""")
166+
self.assertIsInstance(prepared_future, PrepareFuture)
167+
prepared = prepared_future.result()
168+
self.assertIsInstance(prepared, PreparedStatement)
169+
170+
bound = prepared.bind(('a'))
171+
results = self.session.execute(bound)
172+
self.assertEqual(results, [('a', 'b', 'c')])
173+
174+
# test with new dict binding
175+
prepared_future = self.session.prepare_async(
176+
"""
177+
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
178+
""")
179+
self.assertIsInstance(prepared_future, PrepareFuture)
180+
prepared = prepared_future.result()
181+
self.assertIsInstance(prepared, PreparedStatement)
182+
183+
bound = prepared.bind({
184+
'a': 'x',
185+
'b': 'y',
186+
'c': 'z'
187+
})
188+
self.session.execute(bound)
189+
190+
prepared_future = self.session.prepare_async(
191+
"""
192+
SELECT * FROM cf0 WHERE a=?
193+
""")
194+
self.assertIsInstance(prepared_future, PrepareFuture)
195+
prepared = prepared_future.result()
196+
self.assertIsInstance(prepared, PreparedStatement)
197+
198+
bound = prepared.bind({'a': 'x'})
199+
results = self.session.execute(bound)
200+
self.assertEqual(results, [('x', 'y', 'z')])
201+
124202
def test_missing_primary_key(self):
125203
"""
126204
Ensure an InvalidRequest is thrown

tests/integration/standard/test_query.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@ def test_prepared_statement(self):
647647

648648
prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
649649
prepared.consistency_level = ConsistencyLevel.ONE
650-
651650
self.assertEqual(str(prepared),
652651
'<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')
653652

@@ -717,6 +716,17 @@ def test_prepared_statements(self):
717716
self.session.execute_async(batch).result()
718717
self.confirm_results()
719718

719+
def test_prepare_async(self):
720+
prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result()
721+
722+
batch = BatchStatement(BatchType.LOGGED)
723+
for i in range(10):
724+
batch.add(prepared, (i, i))
725+
726+
self.session.execute(batch)
727+
self.session.execute_async(batch).result()
728+
self.confirm_results()
729+
720730
def test_bound_statements(self):
721731
prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")
722732

@@ -942,7 +952,7 @@ def test_no_connection_refused_on_timeout(self):
942952
exception_type = type(result).__name__
943953
if exception_type == "NoHostAvailable":
944954
self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message)
945-
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]:
955+
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessage", "ErrorMessageSub"]:
946956
if type(result).__name__ in ["WriteTimeout", "WriteFailure"]:
947957
received_timeout = True
948958
continue

0 commit comments

Comments
 (0)