Skip to content

Accept transaction config for execute_query #991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,9 @@ Closing a driver will immediately shut down all connections in the pool.
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
def work(tx):
result = tx.run(query_, parameters_, **kwargs)
result = tx.run(query_.text, parameters_, **kwargs)
return result_transformer_(result)

with driver.session(
Expand Down Expand Up @@ -245,16 +246,19 @@ Closing a driver will immediately shut down all connections in the pool.
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Dict[str, typing.Any] | None
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -375,6 +379,10 @@ Closing a driver will immediately shut down all connections in the pool.
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.


.. _driver-configuration-ref:

Expand Down
18 changes: 13 additions & 5 deletions docs/source/async_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ Closing a driver will immediately shut down all connections in the pool.
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
async def work(tx):
result = await tx.run(query_, parameters_, **kwargs)
result = await tx.run(query_.text, parameters_, **kwargs)
return await result_transformer_(result)

async with driver.session(
Expand Down Expand Up @@ -232,16 +233,19 @@ Closing a driver will immediately shut down all connections in the pool.
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Dict[str, typing.Any] | None
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -362,6 +366,10 @@ Closing a driver will immediately shut down all connections in the pool.
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.


.. _async-driver-configuration-ref:

Expand Down
42 changes: 31 additions & 11 deletions src/neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
experimental_warn,
unclosed_resource_warn,
)
from .._work import EagerResult
from .._work import (
EagerResult,
Query,
unit_of_work,
)
from ..addressing import Address
from ..api import (
AsyncBookmarkManager,
Expand Down Expand Up @@ -581,7 +585,7 @@ async def close(self) -> None:
@t.overload
async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -600,7 +604,7 @@ async def execute_query(
@t.overload
async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -618,7 +622,7 @@ async def execute_query(

async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand Down Expand Up @@ -651,8 +655,9 @@ async def execute_query(
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
async def work(tx):
result = await tx.run(query_, parameters_, **kwargs)
result = await tx.run(query_.text, parameters_, **kwargs)
return await result_transformer_(result)

async with driver.session(
Expand Down Expand Up @@ -709,16 +714,19 @@ async def example(driver: neo4j.AsyncDriver) -> int:
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -838,6 +846,10 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::

.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.
"""
self._check_state()
invalid_kwargs = [k for k in kwargs if
Expand All @@ -850,6 +862,14 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
"latter case, use the `parameters_` dictionary instead."
% invalid_kwargs
)
if isinstance(query_, Query):
timeout = query_.timeout
metadata = query_.metadata
query_str = query_.text
work = unit_of_work(metadata, timeout)(_work)
else:
query_str = query_
work = _work
parameters = dict(parameters_ or {}, **kwargs)

if bookmark_manager_ is _default:
Expand All @@ -876,7 +896,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
with session._pipelined_begin:
return await session._run_transaction(
access_mode, TelemetryAPI.DRIVER,
_work, (query_, parameters, result_transformer_), {}
work, (query_str, parameters, result_transformer_), {}
)

@property
Expand Down Expand Up @@ -1195,7 +1215,7 @@ async def _get_server_info(self, session_config) -> ServerInfo:

async def _work(
tx: AsyncManagedTransaction,
query: str,
query: te.LiteralString,
parameters: t.Dict[str, t.Any],
transformer: t.Callable[[AsyncResult], t.Awaitable[_T]]
) -> _T:
Expand Down
42 changes: 31 additions & 11 deletions src/neo4j/_sync/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
experimental_warn,
unclosed_resource_warn,
)
from .._work import EagerResult
from .._work import (
EagerResult,
Query,
unit_of_work,
)
from ..addressing import Address
from ..api import (
Auth,
Expand Down Expand Up @@ -580,7 +584,7 @@ def close(self) -> None:
@t.overload
def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -599,7 +603,7 @@ def execute_query(
@t.overload
def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -617,7 +621,7 @@ def execute_query(

def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand Down Expand Up @@ -650,8 +654,9 @@ def execute_query(
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
def work(tx):
result = tx.run(query_, parameters_, **kwargs)
result = tx.run(query_.text, parameters_, **kwargs)
return result_transformer_(result)

with driver.session(
Expand Down Expand Up @@ -708,16 +713,19 @@ def example(driver: neo4j.Driver) -> int:
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -837,6 +845,10 @@ def example(driver: neo4j.Driver) -> neo4j.Record::

.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.
"""
self._check_state()
invalid_kwargs = [k for k in kwargs if
Expand All @@ -849,6 +861,14 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
"latter case, use the `parameters_` dictionary instead."
% invalid_kwargs
)
if isinstance(query_, Query):
timeout = query_.timeout
metadata = query_.metadata
query_str = query_.text
work = unit_of_work(metadata, timeout)(_work)
else:
query_str = query_
work = _work
parameters = dict(parameters_ or {}, **kwargs)

if bookmark_manager_ is _default:
Expand All @@ -875,7 +895,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
with session._pipelined_begin:
return session._run_transaction(
access_mode, TelemetryAPI.DRIVER,
_work, (query_, parameters, result_transformer_), {}
work, (query_str, parameters, result_transformer_), {}
)

@property
Expand Down Expand Up @@ -1194,7 +1214,7 @@ def _get_server_info(self, session_config) -> ServerInfo:

def _work(
tx: ManagedTransaction,
query: str,
query: te.LiteralString,
parameters: t.Dict[str, t.Any],
transformer: t.Callable[[Result], t.Union[_T]]
) -> _T:
Expand Down
13 changes: 10 additions & 3 deletions src/neo4j/_work/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ class Query:
"""A query with attached extra data.

This wrapper class for queries is used to attach extra data to queries
passed to :meth:`.Session.run` and :meth:`.AsyncSession.run`, fulfilling
a similar role as :func:`.unit_of_work` for transactions functions.
passed to :meth:`.Session.run`/:meth:`.AsyncSession.run` and
:meth:`.Driver.execute_query`/:meth:`.AsyncDriver.execute_query`,
fulfilling a similar role as :func:`.unit_of_work` for transactions
functions.

:param text: The query text.
:type text: typing.LiteralString
Expand Down Expand Up @@ -74,7 +76,12 @@ def __init__(
self.timeout = timeout

def __str__(self) -> te.LiteralString:
return str(self.text)
# we know that if Query is constructed with a LiteralString,
# str(self.text) will be a LiteralString as well. The conversion isn't
# necessary if the user adheres to the type hints. However, it was
# here before, and we don't want to break backwards compatibility.
text: te.LiteralString = str(self.text) # type: ignore[assignment]
return text


def unit_of_work(
Expand Down
7 changes: 6 additions & 1 deletion testkitbackend/_async/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ async def ExecuteQuery(backend, data):
value = config.get(config_key, None)
if value is not None:
kwargs[kwargs_key] = value
tx_kwargs = fromtestkit.to_tx_kwargs(config)
if tx_kwargs:
query = neo4j.Query(cypher, **tx_kwargs)
else:
query = cypher
bookmark_manager_id = config.get("bookmarkManagerId")
if bookmark_manager_id is not None:
if bookmark_manager_id == -1:
Expand All @@ -371,7 +376,7 @@ async def ExecuteQuery(backend, data):
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
kwargs["bookmark_manager_"] = bookmark_manager

eager_result = await driver.execute_query(cypher, params, **kwargs)
eager_result = await driver.execute_query(query, params, **kwargs)
await backend.send_response("EagerResult", {
"keys": eager_result.keys,
"records": list(map(totestkit.record, eager_result.records)),
Expand Down
7 changes: 6 additions & 1 deletion testkitbackend/_sync/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ def ExecuteQuery(backend, data):
value = config.get(config_key, None)
if value is not None:
kwargs[kwargs_key] = value
tx_kwargs = fromtestkit.to_tx_kwargs(config)
if tx_kwargs:
query = neo4j.Query(cypher, **tx_kwargs)
else:
query = cypher
bookmark_manager_id = config.get("bookmarkManagerId")
if bookmark_manager_id is not None:
if bookmark_manager_id == -1:
Expand All @@ -371,7 +376,7 @@ def ExecuteQuery(backend, data):
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
kwargs["bookmark_manager_"] = bookmark_manager

eager_result = driver.execute_query(cypher, params, **kwargs)
eager_result = driver.execute_query(query, params, **kwargs)
backend.send_response("EagerResult", {
"keys": eager_result.keys,
"records": list(map(totestkit.record, eager_result.records)),
Expand Down
Loading