Skip to content

Commit 463b726

Browse files
committed
Add tests for transaction lifetime
E.g., what happens when you commit, a closed (committed, rolled back, ...) transaction or when you try to manage a transaction inside a transaction function.
1 parent 2185a7c commit 463b726

File tree

7 files changed

+159
-11
lines changed

7 files changed

+159
-11
lines changed

nutkit/frontend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .driver import Driver
2+
from .exceptions import ApplicationCodeError

nutkit/frontend/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class ApplicationCodeError(Exception):
2+
pass

nutkit/frontend/session.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
from .. import protocol
2+
from .exceptions import ApplicationCodeError
23
from .result import Result
34
from .transaction import Transaction
45

56

6-
class ApplicationCodeError(Exception):
7-
pass
8-
9-
107
class Session:
118
def __init__(self, driver, session):
129
self._driver = driver

tests/neo4j/test_tx_func_run.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from nutkit.frontend.session import ApplicationCodeError
1+
from nutkit.frontend import ApplicationCodeError
22
import nutkit.protocol as types
33
from tests.neo4j.shared import (
44
get_driver,
@@ -127,10 +127,6 @@ def run(tx):
127127
self._session1 = self._driver.session("w")
128128
expected_exc = types.FrontendError
129129
# TODO: remove this block once all languages work
130-
if get_driver_name() in ["javascript"]:
131-
expected_exc = types.DriverError
132-
if get_driver_name() in ["dotnet"]:
133-
expected_exc = types.BackendError
134130
with self.assertRaises(expected_exc):
135131
self._session1.write_transaction(run)
136132
bookmarks = self._session1.last_bookmarks()
@@ -161,8 +157,6 @@ def assertion_query(tx):
161157
# TODO: remove this block once all languages work
162158
if get_driver_name() in ["javascript"]:
163159
expected_exc = types.DriverError
164-
if get_driver_name() in ["dotnet"]:
165-
expected_exc = types.BackendError
166160
with self.assertRaises(expected_exc):
167161
self._session1.write_transaction(run)
168162

tests/stub/tx_lifetime/__init__.py

Whitespace-only changes.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
!: BOLT 4.4
2+
3+
A: HELLO {"{}": "*"}
4+
*: RESET
5+
C: BEGIN {"{}": "*"}
6+
S: SUCCESS {}
7+
C: RUN {"U": "*"} {"{}": "*"} {"{}": "*"}
8+
S: SUCCESS {"fields": ["n"]}
9+
{*
10+
C: PULL {"n": {"Z": "*"}, "[qid]": -1}
11+
S: RECORD [1]
12+
RECORD [2]
13+
SUCCESS {"has_more": true}
14+
*}
15+
{{
16+
C: DISCARD {"n": -1, "[qid]": -1}
17+
S: SUCCESS {"type": "r"}
18+
{?
19+
{{
20+
C: ROLLBACK
21+
----
22+
C: COMMIT
23+
----
24+
C: RESET
25+
}}
26+
S: SUCCESS {}
27+
?}
28+
----
29+
A: RESET
30+
}}
31+
*: RESET
32+
?: GOODBYE
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from contextlib import contextmanager
2+
3+
from nutkit.frontend import Driver
4+
import nutkit.protocol as types
5+
from tests.shared import (
6+
get_driver_name,
7+
TestkitTestCase,
8+
)
9+
from tests.stub.shared import StubServer
10+
11+
12+
class TestTxLifetime(TestkitTestCase):
13+
def setUp(self):
14+
super().setUp()
15+
self._server = StubServer(9000)
16+
17+
def tearDown(self):
18+
# If test raised an exception this will make sure that the stub server
19+
# is killed and it's output is dumped for analysis.
20+
self._server.reset()
21+
super().tearDown()
22+
23+
@contextmanager
24+
def _start_session(self, script):
25+
uri = "bolt://%s" % self._server.address
26+
driver = Driver(self._backend, uri,
27+
types.AuthorizationToken("basic", principal="",
28+
credentials=""))
29+
self._server.start(path=self.script_path("v4x4", script))
30+
session = driver.session("r", fetch_size=2)
31+
try:
32+
yield session
33+
finally:
34+
session.close()
35+
driver.close()
36+
37+
def _asserts_tx_closed_error(self, exc):
38+
driver = get_driver_name()
39+
assert isinstance(exc, types.DriverError)
40+
if driver in ["python"]:
41+
self.assertEqual(exc.errorType,
42+
"<class 'neo4j.exceptions.TransactionError'>")
43+
self.assertIn("closed", exc.msg.lower())
44+
elif driver in ["javascript", "go", "dotnet"]:
45+
self.assertIn("transaction", exc.msg.lower())
46+
elif driver in ["java"]:
47+
self.assertEqual(exc.errorType,
48+
"org.neo4j.driver.exceptions.ClientException")
49+
else:
50+
self.fail("no error mapping is defined for %s driver" % driver)
51+
52+
def _asserts_tx_managed_error(self, exc):
53+
driver = get_driver_name()
54+
if driver in ["python"]:
55+
self.assertEqual(exc.errorType, "<class 'AttributeError'>")
56+
self.assertIn("managed", exc.msg.lower())
57+
elif driver in ["go"]:
58+
self.assertIn("retryable transaction", exc.msg.lower())
59+
else:
60+
self.fail("no error mapping is defined for %s driver" % driver)
61+
62+
def _test_unmanaged_tx(self, first_action, second_action):
63+
exc = None
64+
script = "tx_inf_results_until_end.script"
65+
with self._start_session(script) as session:
66+
tx = session.begin_transaction()
67+
res = tx.run("Query")
68+
res.consume()
69+
getattr(tx, first_action)()
70+
if second_action == "close":
71+
getattr(tx, second_action)()
72+
elif second_action == "run":
73+
with self.assertRaises(types.DriverError) as exc:
74+
tx.run("Query").consume()
75+
else:
76+
with self.assertRaises(types.DriverError) as exc:
77+
getattr(tx, second_action)()
78+
79+
self._server.done()
80+
self.assertEqual(
81+
self._server.count_requests("ROLLBACK"),
82+
int(first_action in ["rollback", "close"])
83+
)
84+
self.assertEqual(
85+
self._server.count_requests("COMMIT"),
86+
int(first_action == "commit")
87+
)
88+
if exc is not None:
89+
self._asserts_tx_closed_error(exc.exception)
90+
91+
def test_unmanaged_tx_raises_tx_closed_exec(self):
92+
for first_action in ("commit", "rollback", "close"):
93+
for second_action in ("commit", "rollback", "close", "run"):
94+
with self.subTest(first_action=first_action,
95+
second_action=second_action):
96+
self._test_unmanaged_tx(first_action, second_action)
97+
self._server.reset()
98+
99+
def _test_managed_tx(self, close_action):
100+
def work(tx_):
101+
res_ = tx_.run("Query")
102+
res_.consume()
103+
with self.assertRaises(types.DriverError) as exc_:
104+
getattr(tx_, close_action)()
105+
self._asserts_tx_managed_error(exc_.exception)
106+
raise exc_.exception
107+
108+
script = "tx_inf_results_until_end.script"
109+
with self._start_session(script) as session:
110+
with self.assertRaises(types.DriverError):
111+
session.read_transaction(work)
112+
113+
self._server.done()
114+
self._server._dump()
115+
self.assertEqual(self._server.count_requests("ROLLBACK"), 1)
116+
self.assertEqual(self._server.count_requests("COMMIT"), 0)
117+
118+
def test_managed_tx_raises_tx_managed_exec(self):
119+
for close_action in ("commit", "rollback", "close"):
120+
with self.subTest(close_action=close_action):
121+
self._test_managed_tx(close_action)
122+
self._server.reset()

0 commit comments

Comments
 (0)