diff --git a/CHANGELOG.md b/CHANGELOG.md index ba40b95a9..e4f3e2bcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -96,6 +96,26 @@ - ANSI colour codes for log output are now opt-in - Prepend log format with log-level (if colours are disabled) - Prepend log format with thread name and id +- Importing submodules from `neo4j.time` (`neo4j.time.xyz`) has been deprecated. + Everything needed should be imported from `neo4j.time` directly. +- `neo4j.spatial.hydrate_point` and `neo4j.spatial.dehydrate_point` have been + deprecated without replacement. They are internal functions. +- Importing `neo4j.packstream` has been deprecated. It's internal and should not + be used by client code. +- Importing `neo4j.routing` has been deprecated. It's internal and should not + be used by client code. +- Importing `neo4j.config` has been deprecated. It's internal and should not + be used by client code. +- `neo4j.Config`, `neo4j.PoolConfig`, `neo4j.SessionConfig`, and + `neo4j.WorkspaceConfig` have been deprecated without replacement. They are + internal classes. +- Importing `neo4j.meta` has been deprecated. It's internal and should not + be used by client code. `ExperimentalWarning` should be imported directly from + `neo4j`. `neo4j.meta.version` is exposed through `neo4j.__version__`. +- Importing `neo4j.data` has been deprecated. It's internal and should not + be used by client code. `Record` should be imported directly from `neo4j` + instead. `neo4j.data.DataHydrator` and `neo4j.data.DataDehydrator` have been + removed without replacement. ## Version 4.4 diff --git a/bin/dist-functions b/bin/dist-functions index f4d6780c2..04688be10 100644 --- a/bin/dist-functions +++ b/bin/dist-functions @@ -6,22 +6,22 @@ DIST="${ROOT}/dist" function get_package { - python -c "from neo4j.meta import package; print(package)" + python -c "from neo4j._meta import package; print(package)" } function set_package { - sed -i 's/^package = .*/package = "'$1'"/g' neo4j/meta.py + sed -i 's/^package = .*/package = "'$1'"/g' neo4j/_meta.py } function get_version { - python -c "from neo4j.meta import version; print(version)" + python -c "from neo4j._meta import version; print(version)" } function set_version { - sed -i 's/^version = .*/version = "'$1'"/g' neo4j/meta.py + sed -i 's/^version = .*/version = "'$1'"/g' neo4j/_meta.py } function check_file @@ -49,8 +49,8 @@ function set_metadata_and_setup ORIGINAL_VERSION=$(get_version) echo "Source code originally configured for package ${ORIGINAL_PACKAGE}/${ORIGINAL_VERSION}" echo "----------------------------------------" - grep "package\s\+=" neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+=" neo4j/_meta.py echo "----------------------------------------" function cleanup() { @@ -59,8 +59,8 @@ function set_metadata_and_setup set_version "${ORIGINAL_VERSION}" echo "Source code reconfigured back to original package ${ORIGINAL_PACKAGE}/${ORIGINAL_VERSION}" echo "----------------------------------------" - grep "package\s\+=" neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+=" neo4j/_meta.py echo "----------------------------------------" } trap cleanup EXIT @@ -70,8 +70,8 @@ function set_metadata_and_setup set_version "${VERSION}" echo "Source code reconfigured for package ${PACKAGE}/${VERSION}" echo "----------------------------------------" - grep "package\s\+=" neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+="
neo4j/_meta.py echo "----------------------------------------" # Create source distribution diff --git a/bin/make-unasync b/bin/make-unasync index 6fe47b8c5..d99a51a9b 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -21,10 +21,10 @@ import collections import errno import os -from pathlib import Path import re import sys import tokenize as std_tokenize +from pathlib import Path import isort import isort.files diff --git a/docs/source/conf.py b/docs/source/conf.py index 9b2c2ff4c..d6cfe9fe4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) -from neo4j.meta import version as project_version +from neo4j import __version__ as project_version # -- General configuration ------------------------------------------------ diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 435bf3649..89acc690d 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -16,60 +16,7 @@ # limitations under the License. -__all__ = [ - "__version__", - "Address", - "AsyncBoltDriver", - "AsyncDriver", - "AsyncGraphDatabase", - "AsyncManagedTransaction", - "AsyncNeo4jDriver", - "AsyncResult", - "AsyncSession", - "AsyncTransaction", - "Auth", - "AuthToken", - "basic_auth", - "bearer_auth", - "BoltDriver", - "Bookmark", - "Bookmarks", - "Config", - "custom_auth", - "DEFAULT_DATABASE", - "Driver", - "ExperimentalWarning", - "get_user_agent", - "GraphDatabase", - "IPv4Address", - "IPv6Address", - "kerberos_auth", - "ManagedTransaction", - "Neo4jDriver", - "PoolConfig", - "Query", - "READ_ACCESS", - "Record", - "Result", - "ResultSummary", - "ServerInfo", - "Session", - "SessionConfig", - "SummaryCounters", - "Transaction", - "TRUST_ALL_CERTIFICATES", - "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", - "TrustAll", - "TrustCustomCAs", - "TrustSystemCAs", - "unit_of_work", - "Version", - "WorkspaceConfig", - "WRITE_ACCESS", -] - - -from logging import getLogger +from logging import getLogger as _getLogger from ._async.driver import ( AsyncBoltDriver, @@ -84,9 +31,19 @@ AsyncTransaction, ) from ._conf import ( + Config as _Config, + PoolConfig as _PoolConfig, + SessionConfig as _SessionConfig, TrustAll, TrustCustomCAs, TrustSystemCAs, + WorkspaceConfig as _WorkspaceConfig, +) +from ._data import Record +from ._meta import ( + ExperimentalWarning, + get_user_agent, + version as __version__, ) from ._sync.driver import ( BoltDriver, @@ -125,19 +82,6 @@ Version, WRITE_ACCESS, ) -from .conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) -from .data import Record -from .meta import ( - experimental, - ExperimentalWarning, - get_user_agent, - version as __version__, -) from .work import ( Query, ResultSummary, @@ -146,4 +90,77 @@ ) -log = getLogger("neo4j") +__all__ = [ + "__version__", + "Address", + "AsyncBoltDriver", + "AsyncDriver", + "AsyncGraphDatabase", + "AsyncManagedTransaction", + "AsyncNeo4jDriver", + "AsyncResult", + "AsyncSession", + "AsyncTransaction", + "Auth", + "AuthToken", + "basic_auth", + "bearer_auth", + "BoltDriver", + "Bookmark", + "Bookmarks", + "Config", + "custom_auth", + "DEFAULT_DATABASE", + "Driver", + "ExperimentalWarning", + "get_user_agent", + "GraphDatabase", + "IPv4Address", + "IPv6Address", + "kerberos_auth", + "ManagedTransaction", + "Neo4jDriver", + "PoolConfig", + "Query", + "READ_ACCESS", + "Record", + "Result", + "ResultSummary", + "ServerInfo", + "Session", + "SessionConfig", + "SummaryCounters", + "Transaction", + "TRUST_ALL_CERTIFICATES", + "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", + 
"TrustAll", + "TrustCustomCAs", + "TrustSystemCAs", + "unit_of_work", + "Version", + "WorkspaceConfig", + "WRITE_ACCESS", +] + + +_log = _getLogger("neo4j") + + +def __getattr__(name): + # TODO 6.0 - remove this + if name in ( + "log", "Config", "PoolConfig", "SessionConfig", "WorkspaceConfig" + ): + from ._meta import deprecation_warn + deprecation_warn( + "Importing {} from neo4j is deprecated without replacement. It's " + "internal and will be removed in a future version." + .format(name), + stack_level=2 + ) + return globals()[f"_{name}"] + raise AttributeError(f"module {__name__} has no attribute {name}") + + +def __dir__(): + return __all__ diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 4d171f0db..bec20108d 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -16,32 +16,27 @@ # limitations under the License. -import warnings - from .._async_compat.util import AsyncUtil from .._conf import ( - TrustAll, - TrustStore, -) -from ..addressing import Address -from ..api import ( - READ_ACCESS, - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, -) -from ..conf import ( Config, PoolConfig, SessionConfig, + TrustAll, + TrustStore, WorkspaceConfig, ) -from ..meta import ( +from .._meta import ( deprecation_warn, experimental, experimental_warn, - ExperimentalWarning, unclosed_resource_warn, ) +from ..addressing import Address +from ..api import ( + READ_ACCESS, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) class AsyncGraphDatabase: @@ -145,7 +140,8 @@ def driver(cls, uri, *, auth=None, **config): "Creating a direct driver (`bolt://` scheme) with routing " "context (URI parameters) is deprecated. They will be " "ignored. This will raise an error in a future release. " - 'Given URI "{}"'.format(uri) + 'Given URI "{}"'.format(uri), + stack_level=2 ) # TODO: 6.0 - raise instead of warning # raise ValueError( diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 8c7e82dea..7f31f32dd 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -24,17 +24,20 @@ from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil +from ..._codec.hydration import v1 as hydration_v1 +from ..._codec.packstream import v1 as packstream_v1 +from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, SocketDeadlineExceeded, ) +from ..._meta import get_user_agent from ...addressing import Address from ...api import ( ServerInfo, Version, ) -from ...conf import PoolConfig from ...exceptions import ( AuthError, DriverError, @@ -42,15 +45,10 @@ ServiceUnavailable, SessionExpired, ) -from ...meta import get_user_agent -from ...packstream import ( - Packer, - Unpacker, -) from ._common import ( AsyncInbox, + AsyncOutbox, CommitResponse, - Outbox, ) @@ -68,6 +66,13 @@ class AsyncBolt: the handshake was carried out. """ + # TODO: let packer/unpacker know of hydration (give them hooks?) + # TODO: make sure query parameter dehydration gets clear error message. + + PACKER_CLS = packstream_v1.Packer + UNPACKER_CLS = packstream_v1.Unpacker + HYDRATION_HANDLER_CLS = hydration_v1.HydrationHandler + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" PROTOCOL_VERSION = None @@ -107,10 +112,16 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. 
self.configuration_hints = {} - self.outbox = Outbox() - self.inbox = AsyncInbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) + self.patch = {} + self.outbox = AsyncOutbox( + self.socket, on_error=self._set_defunct_write, + packer_cls=self.PACKER_CLS + ) + self.inbox = AsyncInbox( + self.socket, on_error=self._set_defunct_read, + unpacker_cls=self.UNPACKER_CLS + ) + self.hydration_handler = self.HYDRATION_HANDLER_CLS() self.responses = deque() self._max_connection_lifetime = max_connection_lifetime self._creation_timestamp = perf_counter() @@ -376,14 +387,17 @@ def der_encoded_server_certificate(self): pass @abc.abstractmethod - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a HELLO message to the outgoing queue, sends it and consumes all remaining messages. """ pass @abc.abstractmethod - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -396,13 +410,22 @@ async def route(self, database=None, imp_user=None, bookmarks=None): Requires Bolt 4.4+. :param bookmarks: iterable of bookmark values after which this transaction should begin - :return: dictionary of raw routing data + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -415,36 +438,60 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a DISCARD message to the output queue. :param n: number of records to discard, default = -1 (ALL) :param qid: query ID to discard for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function).
Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a PULL message to the output queue. :param n: number of records to pull, default = -1 (ALL) :param qid: query ID to pull for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) @@ -455,53 +502,99 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+ + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object :return: Response object """ pass @abc.abstractmethod - def commit(self, **handlers): - """ Appends a COMMIT message to the output queue.""" + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a COMMIT message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass @abc.abstractmethod - def rollback(self, **handlers): - """ Appends a ROLLBACK message to the output queue.""" + def rollback(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a ROLLBACK message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function).
Hydration functions receive the value of + type understood by packstream and are free to return anything.""" pass @abc.abstractmethod - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a RESET message to the outgoing queue, sends it and consumes all remaining messages. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod - def goodbye(self): - """Append a GOODBYE message to the outgoing queue.""" + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + """Append a GOODBYE message to the outgoing queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + hydration function). Hydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass - def _append(self, signature, fields=(), response=None): + def new_hydration_scope(self): + return self.hydration_handler.new_hydration_scope() + + def _append(self, signature, fields=(), response=None, + dehydration_hooks=None): """ Appends a message to the outgoing queue. :param signature: the signature of the message :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and return an + object of type understood by packstream. """ - with self.outbox.tmp_buffer(): - self.packer.pack_struct(signature, fields) - self.outbox.wrap_message() + self.outbox.append_message(signature, fields, dehydration_hooks) self.responses.append(response) async def _send_all(self): - data = self.outbox.view() - if data: - try: - await self.socket.sendall(data) - except OSError as error: - await self._set_defunct_write(error) - self.outbox.clear() + if await self.outbox.flush(): self.idle_since = perf_counter() async def send_all(self): @@ -523,8 +616,7 @@ async def send_all(self): await self._send_all() @abc.abstractmethod - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Receive at most one message from the server, if available.
:return: 2-tuple of number of detail messages and number of summary @@ -549,11 +641,10 @@ async def fetch_message(self): return 0, 0 # Receive exactly one message - details, summary_signature, summary_metadata = \ - await AsyncUtil.next(self.inbox) - res = await self._process_message( - details, summary_signature, summary_metadata + tag, fields = await self.inbox.pop( + hydration_hooks=self.responses[0].hydration_hooks ) + res = await self._process_message(tag, fields) self.idle_since = perf_counter() return res diff --git a/neo4j/_async/io/_bolt3.py b/neo4j/_async/io/_bolt3.py index b361e1f52..653f2a10a 100644 --- a/neo4j/_async/io/_bolt3.py +++ b/neo4j/_async/io/_bolt3.py @@ -142,7 +142,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -150,13 +150,17 @@ logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=self.server_info.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if database is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -183,16 +187,20 @@ async def route(self, database=None, imp_user=None, bookmarks=None): "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if Neo4j 3.5 is set up with clustering. {"context": self.routing_context}, mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, + on_success=metadata.update, on_records=records.extend) await self.send_all() await self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -231,20 +239,29 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: DISCARD_ALL", self.local_port) - self._append(b"\x2F", (), Response(self, "discard", **handlers)) + self._append(b"\x2F", (), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: PULL_ALL", self.local_port) - self._append(b"\x3F", (), Response(self, "pull", **handlers)) + self._append(b"\x3F", (), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -280,17 +297,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ @@ -299,21 +324,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Process at most one message from the server, if available. 
:return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data await self.responses[0].on_records(details) diff --git a/neo4j/_async/io/_bolt4.py b/neo4j/_async/io/_bolt4.py index d1fbf035b..f82cf7069 100644 --- a/neo4j/_async/io/_bolt4.py +++ b/neo4j/_async/io/_bolt4.py @@ -34,11 +34,11 @@ NotALeader, ServiceUnavailable, ) +from ._bolt import AsyncBolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, CommitResponse, @@ -95,7 +95,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -103,13 +103,19 @@ async def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse( + self, "hello", hydration_hooks, + on_success=self.server_info.update + ), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -138,14 +144,20 @@ async def route(self, database=None, imp_user=None, bookmarks=None): db=SYSTEM_DATABASE, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull( + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, + on_success=metadata.update, + on_records=records.extend + ) await self.send_all() await self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -179,24 +191,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -227,17 +248,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. 
""" @@ -246,21 +275,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data await self.responses[0].on_records(details) @@ -341,7 +382,15 @@ class AsyncBolt4x3(AsyncBolt4x2): PROTOCOL_VERSION = Version(4, 3) - async def route(self, database=None, imp_user=None, bookmarks=None): + def get_base_headers(self): + headers = super().get_base_headers() + headers["patch_bolt"] = ["utc"] + return headers + + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -359,13 +408,14 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, database), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -380,6 +430,9 @@ def on_success(metadata): "connection.recv_timeout_seconds (%r). 
Make sure " "the server and network is set up correctly.", self.local_port, recv_timeout) + self.patch = set(metadata.pop("patch_bolt", [])) + if "utc" in self.patch: + self.hydration_handler.patch_utc() headers = self.get_base_headers() headers.update(self.auth_dict) @@ -388,8 +441,9 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -403,7 +457,10 @@ class AsyncBolt4x4(AsyncBolt4x3): PROTOCOL_VERSION = Version(4, 4) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -418,14 +475,16 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -456,10 +515,13 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -486,4 +548,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/neo4j/_async/io/_bolt5.py b/neo4j/_async/io/_bolt5.py index 99901ff31..aa399e8ac 100644 --- a/neo4j/_async/io/_bolt5.py +++ b/neo4j/_async/io/_bolt5.py @@ -19,28 +19,24 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import AsyncUtil -from ..._exceptions import ( - BoltError, - BoltProtocolError, -) +from ..._codec.hydration import v2 as hydration_v2 +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, ) from ...exceptions import ( DatabaseUnavailable, - DriverError, 
ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, ServiceUnavailable, ) +from ._bolt import AsyncBolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, CommitResponse, @@ -57,6 +53,8 @@ class AsyncBolt5x0(AsyncBolt): PROTOCOL_VERSION = Version(5, 0) + HYDRATION_HANDLER_CLS = hydration_v2.HydrationHandler + supports_multiple_results = True supports_multiple_databases = True @@ -95,7 +93,7 @@ def get_base_headers(self): headers["routing"] = self.routing_context return headers - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -118,13 +116,15 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route(self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -139,14 +139,16 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -177,24 +179,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), +
dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -221,17 +232,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """Reset the connection. Add a RESET message to the outgoing queue, send it and consume all @@ -243,22 +262,33 @@ def fail(metadata): self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", - on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """Process at most one message from the server, if available.
:return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: # Do not log any data log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 7ab3cebe9..98744453a 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -17,7 +17,6 @@ import asyncio -from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -30,132 +29,122 @@ SessionExpired, UnsupportedServerProduct, ) -from ...packstream import ( - UnpackableBuffer, - Unpacker, -) log = logging.getLogger("neo4j") -class AsyncMessageInbox: +class AsyncInbox: - def __init__(self, s, on_error): + def __init__(self, sock, on_error, unpacker_cls): self.on_error = on_error - self._local_port = s.getsockname()[1] - self._messages = self._yield_messages(s) - - async def _yield_messages(self, sock): + self._local_port = sock.getsockname()[1] + self._socket = sock + self._buffer = unpacker_cls.new_unpackable_buffer() + self._unpacker = unpacker_cls(self._buffer) + self._broken = False + + async def _buffer_one_chunk(self): + assert not self._broken try: - buffer = UnpackableBuffer() - unpacker = Unpacker(buffer) chunk_size = 0 while True: - while chunk_size == 0: # Determine the chunk size and skip noop - await receive_into_buffer(sock, buffer, 2) - chunk_size = buffer.pop_u16() + await receive_into_buffer(self._socket, self._buffer, 2) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: log.debug("[#%04X] S: <NOOP>", self._local_port) - await receive_into_buffer(sock, buffer, chunk_size + 2) - chunk_size = buffer.pop_u16() + await receive_into_buffer( + self._socket, self._buffer, chunk_size + 2 + ) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: # chunk_size was the end marker for the message - size, tag = unpacker.unpack_structure_header() - fields = [unpacker.unpack() for _ in range(size)] - yield tag, fields - # Reset for new message - unpacker.reset() + return except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + self._broken = True await AsyncUtil.callback(self.on_error, error) + raise - async def pop(self): - return await AsyncUtil.next(self._messages) - - -class AsyncInbox(AsyncMessageInbox): - - async def __anext__(self): - tag, fields = await self.pop() - if tag == b"\x71": - return fields, None, None - elif fields: - return [], tag, fields[0] - else: - return [], tag, None + async def pop(self, hydration_hooks): + await self._buffer_one_chunk() + try: + size, tag = self._unpacker.unpack_structure_header() + fields = [self._unpacker.unpack(hydration_hooks) + for _ in range(size)] + return tag, fields + finally: + # Reset for new message + self._unpacker.reset() -class Outbox: +class AsyncOutbox: - def __init__(self, max_chunk_size=16384): + def __init__(self, sock, on_error, packer_cls, max_chunk_size=16384): self._max_chunk_size = max_chunk_size self._chunked_data = bytearray() - self._raw_data = bytearray() - self.write = self._raw_data.extend - self._tmp_buffering = 0 + self._buffer = packer_cls.new_packable_buffer() + self._packer = packer_cls(self._buffer) + self.socket = sock + self.on_error = on_error def max_chunk_size(self): return self._max_chunk_size - def clear(self): - if self._tmp_buffering: -
raise RuntimeError("Cannot clear while buffering") + def _clear(self): + assert not self._buffer.is_tmp_buffering() self._chunked_data = bytearray() - self._raw_data.clear() + self._buffer.clear() def _chunk_data(self): - data_len = len(self._raw_data) + data_len = len(self._buffer.data) num_full_chunks, chunk_rest = divmod( data_len, self._max_chunk_size ) num_chunks = num_full_chunks + bool(chunk_rest) - data_view = memoryview(self._raw_data) - header_start = len(self._chunked_data) - data_start = header_start + 2 - raw_data_start = 0 - for i in range(num_chunks): - chunk_size = min(data_len - raw_data_start, - self._max_chunk_size) - self._chunked_data[header_start:data_start] = struct_pack( - ">H", chunk_size - ) - self._chunked_data[data_start:(data_start + chunk_size)] = \ - data_view[raw_data_start:(raw_data_start + chunk_size)] - header_start += chunk_size + 2 + with memoryview(self._buffer.data) as data_view: + header_start = len(self._chunked_data) data_start = header_start + 2 - raw_data_start += chunk_size - del data_view - self._raw_data.clear() - - def wrap_message(self): - if self._tmp_buffering: - raise RuntimeError("Cannot wrap message while buffering") + raw_data_start = 0 + for i in range(num_chunks): + chunk_size = min(data_len - raw_data_start, + self._max_chunk_size) + self._chunked_data[header_start:data_start] = struct_pack( + ">H", chunk_size + ) + self._chunked_data[data_start:(data_start + chunk_size)] = \ + data_view[raw_data_start:(raw_data_start + chunk_size)] + header_start += chunk_size + 2 + data_start = header_start + 2 + raw_data_start += chunk_size + self._buffer.clear() + + def _wrap_message(self): + assert not self._buffer.is_tmp_buffering() self._chunk_data() self._chunked_data += b"\x00\x00" - def view(self): - if self._tmp_buffering: - raise RuntimeError("Cannot view while buffering") - self._chunk_data() - return memoryview(self._chunked_data) + def append_message(self, tag, fields, dehydration_hooks): + with self._buffer.tmp_buffer(): + self._packer.pack_struct(tag, fields, dehydration_hooks) + self._wrap_message() - @contextmanager - def tmp_buffer(self): - self._tmp_buffering += 1 - old_len = len(self._raw_data) - try: - yield - except Exception: - del self._raw_data[old_len:] - raise - finally: - self._tmp_buffering -= 1 + async def flush(self): + data = self._chunked_data + if data: + try: + await self.socket.sendall(data) + except OSError as error: + await self.on_error(error) + return False + self._clear() + return True + return False class ConnectionErrorHandler: @@ -218,8 +207,9 @@ class Response: more detail messages followed by one summary message). 
""" - def __init__(self, connection, message, **handlers): + def __init__(self, connection, message, hydration_hooks, **handlers): self.connection = connection + self.hydration_hooks = hydration_hooks self.handlers = handlers self.message = message self.complete = False @@ -294,9 +284,9 @@ async def receive_into_buffer(sock, buffer, n_bytes): end = buffer.used + n_bytes if end > len(buffer.data): buffer.data += bytearray(end - len(buffer.data)) - view = memoryview(buffer.data) - while buffer.used < end: - n = await sock.recv_into(view[buffer.used:end], end - buffer.used) - if n == 0: - raise OSError("No data") - buffer.used += n + with memoryview(buffer.data) as view: + while buffer.used < end: + n = await sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 198e0f8ac..0a3bfac58 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -17,12 +17,12 @@ import abc +import logging from collections import ( defaultdict, deque, ) from contextlib import asynccontextmanager -import logging from logging import getLogger from random import choice @@ -31,6 +31,10 @@ AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil +from ..._conf import ( + PoolConfig, + WorkspaceConfig, +) from ..._deadline import ( connection_deadline, Deadline, @@ -38,14 +42,11 @@ merge_deadlines_and_timeouts, ) from ..._exceptions import BoltError +from ..._routing import RoutingTable from ...api import ( READ_ACCESS, WRITE_ACCESS, ) -from ...conf import ( - PoolConfig, - WorkspaceConfig, -) from ...exceptions import ( ClientError, ConfigurationError, @@ -56,7 +57,6 @@ SessionExpired, WriteServiceUnavailable, ) -from ...routing import RoutingTable from ._bolt import AsyncBolt diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 9a2b6af3d..7f913cb72 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -20,15 +20,15 @@ from warnings import warn from ..._async_compat.util import AsyncUtil -from ...data import ( - DataDehydrator, +from ..._data import ( + Record, RecordTableRowExporter, ) +from ..._meta import experimental from ...exceptions import ( ResultConsumedError, ResultNotSingleError, ) -from ...meta import experimental from ...time import ( Date, DateTime, @@ -54,10 +54,9 @@ class AsyncResult: :meth:`.AyncSession.run` and :meth:`.AsyncTransaction.run`. 
""" - def __init__(self, connection, hydrant, fetch_size, on_closed, - on_error): + def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = ConnectionErrorHandler(connection, on_error) - self._hydrant = hydrant + self._hydration_scope = connection.new_hydration_scope() self._on_closed = on_closed self._metadata = None self._keys = None @@ -104,7 +103,7 @@ async def _run( query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + parameters = dict(parameters or {}, **kwargs) self._metadata = { "query": query_text, @@ -135,6 +134,7 @@ async def on_failed_attach(metadata): timeout=query_timeout, db=db, imp_user=imp_user, + dehydration_hooks=self._hydration_scope.dehydration_hooks, on_success=on_attached, on_failure=on_failed_attach, ) @@ -145,7 +145,10 @@ async def on_failed_attach(metadata): def _pull(self): def on_records(records): if not self._discarding: - self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + self._record_buffer.extend(( + Record(zip(self._keys, record)) + for record in records + )) async def on_summary(): self._attached = False @@ -167,6 +170,7 @@ def on_success(summary_metadata): self._connection.pull( n=self._fetch_size, qid=self._qid, + hydration_hooks=self._hydration_scope.hydration_hooks, on_records=on_records, on_success=on_success, on_failure=on_failure, @@ -479,7 +483,7 @@ async def graph(self): Can raise :exc:`ResultConsumedError`. """ await self._buffer_all() - return self._hydrant.graph + return self._hydration_scope.get_graph() async def value(self, key=0, default=None): """Helper function that return the remainder of the result as a list of values. 
diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index d8719cc56..650a1c634 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -21,13 +21,16 @@ from time import perf_counter from ..._async_compat import async_sleep +from ..._conf import SessionConfig +from ..._meta import ( + deprecated, + deprecation_warn, +) from ...api import ( Bookmarks, READ_ACCESS, WRITE_ACCESS, ) -from ...conf import SessionConfig -from ...data import DataHydrator from ...exceptions import ( ClientError, DriverError, @@ -36,10 +39,6 @@ SessionExpired, TransactionError, ) -from ...meta import ( - deprecated, - deprecation_warn, -) from ...work import Query from .result import AsyncResult from .transaction import ( @@ -228,10 +227,8 @@ async def run(self, query, parameters=None, **kwargs): protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info - hydrant = DataHydrator() - self._auto_result = AsyncResult( - cx, hydrant, self._config.fetch_size, self._result_closed, + cx, self._config.fetch_size, self._result_closed, self._result_error ) await self._auto_result._run( diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index e293a8ecd..03b76476e 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -19,7 +19,6 @@ from functools import wraps from ..._async_compat.util import AsyncUtil -from ...data import DataHydrator from ...exceptions import TransactionError from ...work import Query from ..io import ConnectionErrorHandler @@ -123,8 +122,7 @@ async def run(self, query, parameters=None, **kwparameters): await self._results[-1]._buffer_all() result = AsyncResult( - self._connection, DataHydrator(), self._fetch_size, - self._result_on_closed_handler, + self._connection, self._fetch_size, self._result_on_closed_handler, self._error_handler ) self._results.append(result) diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index e4d1d3118..9c589db57 100644 --- a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -18,16 +18,16 @@ import asyncio +from ..._conf import WorkspaceConfig from ..._deadline import Deadline -from ...conf import WorkspaceConfig +from ..._meta import ( + deprecation_warn, + unclosed_resource_warn, +) from ...exceptions import ( ServiceUnavailable, SessionExpired, ) -from ...meta import ( - deprecation_warn, - unclosed_resource_warn, -) from ..io import AsyncNeo4jPool diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index e8405bc18..475536bc6 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -20,6 +20,7 @@ import logging import selectors import socket +import struct from socket import ( AF_INET, AF_INET6, @@ -34,7 +35,6 @@ HAS_SNI, SSLError, ) -import struct from time import perf_counter from ... 
import addressing diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py index 98e29369a..fd510d02b 100644 --- a/neo4j/_async_compat/util.py +++ b/neo4j/_async_compat/util.py @@ -18,7 +18,7 @@ import inspect -from ..meta import experimental +from .._meta import experimental __all__ = [ diff --git a/tests/unit/common/data/__init__.py b/neo4j/_codec/__init__.py similarity index 100% rename from tests/unit/common/data/__init__.py rename to neo4j/_codec/__init__.py diff --git a/neo4j/_codec/hydration/__init__.py b/neo4j/_codec/hydration/__init__.py new file mode 100644 index 000000000..bd4fdb81f --- /dev/null +++ b/neo4j/_codec/hydration/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._common import HydrationScope +from ._interface import HydrationHandlerABC + + +__all__ = [ + "HydrationHandlerABC", + "HydrationScope", +] diff --git a/neo4j/_codec/hydration/_common.py b/neo4j/_codec/hydration/_common.py new file mode 100644 index 000000000..3a51b030d --- /dev/null +++ b/neo4j/_codec/hydration/_common.py @@ -0,0 +1,50 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...graph import Graph +from ..packstream import Structure + + +class GraphHydrator: + def __init__(self): + self.graph = Graph() + self.struct_hydration_functions = {} + + +class HydrationScope: + + def __init__(self, hydration_handler, graph_hydrator): + self._hydration_handler = hydration_handler + self._graph_hydrator = graph_hydrator + self._struct_hydration_functions = { + **hydration_handler.struct_hydration_functions, + **graph_hydrator.struct_hydration_functions, + } + self.hydration_hooks = { + Structure: self._hydrate_structure, + } + self.dehydration_hooks = hydration_handler.dehydration_functions + + def _hydrate_structure(self, value): + f = self._struct_hydration_functions.get(value.tag) + if not f: + return value + return f(*value.fields) + + def get_graph(self): + return self._graph_hydrator.graph diff --git a/neo4j/_codec/hydration/_interface/__init__.py b/neo4j/_codec/hydration/_interface/__init__.py new file mode 100644 index 000000000..5092d5e0d --- /dev/null +++ b/neo4j/_codec/hydration/_interface/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc + + +class HydrationHandlerABC(abc.ABC): + def __init__(self): + self.struct_hydration_functions = {} + self.dehydration_functions = {} + + @abc.abstractmethod + def new_hydration_scope(self): + ... diff --git a/neo4j/_codec/hydration/v1/__init__.py b/neo4j/_codec/hydration/v1/__init__.py new file mode 100644 index 000000000..985f6f033 --- /dev/null +++ b/neo4j/_codec/hydration/v1/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .hydration_handler import HydrationHandler + + +__all__ = [ + "HydrationHandler", +] diff --git a/neo4j/_codec/hydration/v1/hydration_handler.py b/neo4j/_codec/hydration/v1/hydration_handler.py new file mode 100644 index 000000000..89839503f --- /dev/null +++ b/neo4j/_codec/hydration/v1/hydration_handler.py @@ -0,0 +1,196 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +from ....graph import ( + Graph, + Node, + Path, +) +from ....spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from ....time import ( + Date, + DateTime, + Duration, + Time, +) +from .._common import ( + GraphHydrator, + HydrationScope, +) +from .._interface import HydrationHandlerABC +from . 
import ( + spatial, + temporal, +) + + +class _GraphHydrator(GraphHydrator): + def __init__(self): + super().__init__() + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"N": self.hydrate_node, + b"R": self.hydrate_relationship, + b"r": self.hydrate_unbound_relationship, + b"P": self.hydrate_path, + } + + def hydrate_node(self, id_, labels=None, + properties=None, element_id=None): + assert isinstance(self.graph, Graph) + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + + try: + inst = self.graph._nodes[element_id] + except KeyError: + inst = Node(self.graph, element_id, id_, labels, properties) + self.graph._nodes[element_id] = inst + self.graph._legacy_nodes[id_] = inst + else: + # If we have already hydrated this node as the endpoint of + # a relationship, it won't have any labels or properties. + # Therefore, we need to add the ones we have here. + if labels: + inst._labels = inst._labels.union(labels) # frozen_set + if properties: + inst._properties.update(properties) + return inst + + def hydrate_relationship(self, id_, n0_id, n1_id, type_, + properties=None, element_id=None, + n0_element_id=None, n1_element_id=None): + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + if n0_element_id is None: + n0_element_id = str(n0_id) + if n1_element_id is None: + n1_element_id = str(n1_id) + + inst = self.hydrate_unbound_relationship(id_, type_, properties, + element_id) + inst._start_node = self.hydrate_node(n0_id, + element_id=n0_element_id) + inst._end_node = self.hydrate_node(n1_id, element_id=n1_element_id) + return inst + + def hydrate_unbound_relationship(self, id_, type_, properties=None, + element_id=None): + assert isinstance(self.graph, Graph) + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + + try: + inst = self.graph._relationships[element_id] + except KeyError: + r = self.graph.relationship_type(type_) + inst = r( + self.graph, element_id, id_, properties + ) + self.graph._relationships[element_id] = inst + self.graph._legacy_relationships[id_] = inst + return inst + + def hydrate_path(self, nodes, relationships, sequence): + assert isinstance(self.graph, Graph) + assert len(nodes) >= 1 + assert len(sequence) % 2 == 0 + last_node = nodes[0] + entities = [last_node] + for i, rel_index in enumerate(sequence[::2]): + assert rel_index != 0 + next_node = nodes[sequence[2 * i + 1]] + if rel_index > 0: + r = relationships[rel_index - 1] + r._start_node = last_node + r._end_node = next_node + entities.append(r) + else: + r = relationships[-rel_index - 1] + r._start_node = next_node + r._end_node = last_node + entities.append(r) + last_node = next_node + return Path(*entities) + + +class HydrationHandler(HydrationHandlerABC): + def __init__(self): + super().__init__() + self._created_scope = False + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"X": spatial.hydrate_point, + b"Y": spatial.hydrate_point, + b"D": temporal.hydrate_date, + b"T": temporal.hydrate_time, # time zone offset + b"t": temporal.hydrate_time, # no time zone + b"F": temporal.hydrate_datetime, # time zone offset + b"f": temporal.hydrate_datetime, # time zone name + b"d": temporal.hydrate_datetime, # no time zone + b"E": temporal.hydrate_duration, + } + self.dehydration_functions = { + **self.dehydration_functions, + Point: spatial.dehydrate_point, + CartesianPoint: spatial.dehydrate_point, + WGS84Point: spatial.dehydrate_point, + 
Date: temporal.dehydrate_date, + date: temporal.dehydrate_date, + Time: temporal.dehydrate_time, + time: temporal.dehydrate_time, + DateTime: temporal.dehydrate_datetime, + datetime: temporal.dehydrate_datetime, + Duration: temporal.dehydrate_duration, + timedelta: temporal.dehydrate_timedelta, + } + + def patch_utc(self): + from ..v2 import temporal as temporal_v2 + + assert not self._created_scope + + del self.struct_hydration_functions[b"F"] + del self.struct_hydration_functions[b"f"] + self.struct_hydration_functions.update({ + b"I": temporal_v2.hydrate_datetime, + b"i": temporal_v2.hydrate_datetime, + }) + + self.dehydration_functions.update({ + DateTime: temporal_v2.dehydrate_datetime, + datetime: temporal_v2.dehydrate_datetime, + }) + + def new_hydration_scope(self): + self._created_scope = True + return HydrationScope(self, _GraphHydrator()) diff --git a/neo4j/_codec/hydration/v1/spatial.py b/neo4j/_codec/hydration/v1/spatial.py new file mode 100644 index 000000000..6e2f7a6f5 --- /dev/null +++ b/neo4j/_codec/hydration/v1/spatial.py @@ -0,0 +1,63 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...._spatial import ( + Point, + srid_table, +) +from ...packstream import Structure + + +def hydrate_point(srid, *coordinates): + """ Create a new instance of a Point subclass from a raw + set of fields. The subclass chosen is determined by the + given SRID. If the SRID is unknown, a generic Point carrying + the SRID is returned; a ValueError is raised if the number of + coordinates does not match the dimensionality registered for + the SRID. + """ + try: + point_class, dim = srid_table[srid] + except KeyError: + point = Point(coordinates) + point.srid = srid + return point + else: + if len(coordinates) != dim: + raise ValueError("SRID %d requires %d coordinates (%d provided)" % (srid, dim, len(coordinates))) + return point_class(coordinates) + + +def dehydrate_point(value): + """ Dehydrator for Point data. + + :param value: + :type value: Point + :return: + """ + dim = len(value) + if dim == 2: + return Structure(b"X", value.srid, *value) + elif dim == 3: + return Structure(b"Y", value.srid, *value) + else: + raise ValueError("Cannot dehydrate Point with %d dimensions" % dim) + + +__all__ = [ + "hydrate_point", + "dehydrate_point", +] diff --git a/neo4j/_codec/hydration/v1/temporal.py b/neo4j/_codec/hydration/v1/temporal.py new file mode 100644 index 000000000..a9c511a5a --- /dev/null +++ b/neo4j/_codec/hydration/v1/temporal.py @@ -0,0 +1,207 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import ( + datetime, + time, + timedelta, +) + +from ....time import ( + Date, + DateTime, + Duration, + Time, +) +from ...packstream import Structure + + +def get_date_unix_epoch(): + return Date(1970, 1, 1) + + +def get_date_unix_epoch_ordinal(): + return get_date_unix_epoch().to_ordinal() + + +def get_datetime_unix_epoch_utc(): + from pytz import utc + return DateTime(1970, 1, 1, 0, 0, 0, utc) + + +def hydrate_date(days): + """ Hydrator for `Date` values. + + :param days: + :return: Date + """ + return Date.from_ordinal(get_date_unix_epoch_ordinal() + days) + + +def dehydrate_date(value): + """ Dehydrator for `date` values. + + :param value: + :type value: Date + :return: + """ + return Structure(b"D", value.toordinal() - get_date_unix_epoch().toordinal()) + + +def hydrate_time(nanoseconds, tz=None): + """ Hydrator for `Time` and `LocalTime` values. + + :param nanoseconds: + :param tz: + :return: Time + """ + from pytz import FixedOffset + seconds, nanoseconds = map(int, divmod(nanoseconds, 1000000000)) + minutes, seconds = map(int, divmod(seconds, 60)) + hours, minutes = map(int, divmod(minutes, 60)) + t = Time(hours, minutes, seconds, nanoseconds) + if tz is None: + return t + tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) + zone = FixedOffset(tz_offset_minutes) + return zone.localize(t) + + +def dehydrate_time(value): + """ Dehydrator for `time` values. + + :param value: + :type value: Time + :return: + """ + if isinstance(value, Time): + nanoseconds = value.ticks + elif isinstance(value, time): + nanoseconds = (3600000000000 * value.hour + 60000000000 * value.minute + + 1000000000 * value.second + 1000 * value.microsecond) + else: + raise TypeError("Value must be a neo4j.time.Time or a datetime.time") + if value.tzinfo: + return Structure(b"T", nanoseconds, + int(value.tzinfo.utcoffset(value).total_seconds())) + else: + return Structure(b"t", nanoseconds) + + +def hydrate_datetime(seconds, nanoseconds, tz=None): + """ Hydrator for `DateTime` and `LocalDateTime` values. + + :param seconds: + :param nanoseconds: + :param tz: + :return: datetime + """ + from pytz import ( + FixedOffset, + timezone, + ) + minutes, seconds = map(int, divmod(seconds, 60)) + hours, minutes = map(int, divmod(minutes, 60)) + days, hours = map(int, divmod(hours, 24)) + t = DateTime.combine( + Date.from_ordinal(get_date_unix_epoch_ordinal() + days), + Time(hours, minutes, seconds, nanoseconds) + ) + if tz is None: + return t + if isinstance(tz, int): + tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) + zone = FixedOffset(tz_offset_minutes) + else: + zone = timezone(tz) + return zone.localize(t) + + +def dehydrate_datetime(value): + """ Dehydrator for `datetime` values. 
+ + :param value: + :type value: datetime or DateTime + :return: + """ + + def seconds_and_nanoseconds(dt): + if isinstance(dt, datetime): + dt = DateTime.from_native(dt) + zone_epoch = DateTime(1970, 1, 1, tzinfo=dt.tzinfo) + dt_clock_time = dt.to_clock_time() + zone_epoch_clock_time = zone_epoch.to_clock_time() + t = dt_clock_time - zone_epoch_clock_time + return t.seconds, t.nanoseconds + + tz = value.tzinfo + if tz is None: + # without time zone + from pytz import utc + value = utc.localize(value) + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"d", seconds, nanoseconds) + elif hasattr(tz, "zone") and tz.zone and isinstance(tz.zone, str): + # with named pytz time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"f", seconds, nanoseconds, tz.zone) + elif hasattr(tz, "key") and tz.key and isinstance(tz.key, str): + # with named zoneinfo (Python 3.9+) time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"f", seconds, nanoseconds, tz.key) + else: + # with time offset + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"F", seconds, nanoseconds, + int(tz.utcoffset(value).total_seconds())) + + +def hydrate_duration(months, days, seconds, nanoseconds): + """ Hydrator for `Duration` values. + + :param months: + :param days: + :param seconds: + :param nanoseconds: + :return: `duration` namedtuple + """ + return Duration(months=months, days=days, seconds=seconds, nanoseconds=nanoseconds) + + +def dehydrate_duration(value): + """ Dehydrator for `duration` values. + + :param value: + :type value: Duration + :return: + """ + return Structure(b"E", value.months, value.days, value.seconds, value.nanoseconds) + + +def dehydrate_timedelta(value): + """ Dehydrator for `timedelta` values. + + :param value: + :type value: timedelta + :return: + """ + months = 0 + days = value.days + seconds = value.seconds + nanoseconds = 1000 * value.microseconds + return Structure(b"E", months, days, seconds, nanoseconds) diff --git a/neo4j/_codec/hydration/v2/__init__.py b/neo4j/_codec/hydration/v2/__init__.py new file mode 100644 index 000000000..c3cd9e2e8 --- /dev/null +++ b/neo4j/_codec/hydration/v2/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hydration_handler import HydrationHandler + + +__all__ = [ + "HydrationHandler", +] diff --git a/neo4j/_codec/hydration/v2/hydration_handler.py b/neo4j/_codec/hydration/v2/hydration_handler.py new file mode 100644 index 000000000..092201a07 --- /dev/null +++ b/neo4j/_codec/hydration/v2/hydration_handler.py @@ -0,0 +1,57 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..v1.hydration_handler import * +from ..v1.hydration_handler import _GraphHydrator +from . import temporal + + +class HydrationHandler(HydrationHandlerABC): + def __init__(self): + super().__init__() + self._created_scope = False + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"X": spatial.hydrate_point, + b"Y": spatial.hydrate_point, + b"D": temporal.hydrate_date, + b"T": temporal.hydrate_time, # time zone offset + b"t": temporal.hydrate_time, # no time zone + b"I": temporal.hydrate_datetime, # time zone offset + b"i": temporal.hydrate_datetime, # time zone name + b"d": temporal.hydrate_datetime, # no time zone + b"E": temporal.hydrate_duration, + } + self.dehydration_functions = { + **self.dehydration_functions, + Point: spatial.dehydrate_point, + CartesianPoint: spatial.dehydrate_point, + WGS84Point: spatial.dehydrate_point, + Date: temporal.dehydrate_date, + date: temporal.dehydrate_date, + Time: temporal.dehydrate_time, + time: temporal.dehydrate_time, + DateTime: temporal.dehydrate_datetime, + datetime: temporal.dehydrate_datetime, + Duration: temporal.dehydrate_duration, + timedelta: temporal.dehydrate_timedelta, + } + + def new_hydration_scope(self): + self._created_scope = True + return HydrationScope(self, _GraphHydrator()) diff --git a/neo4j/_codec/hydration/v2/temporal.py b/neo4j/_codec/hydration/v2/temporal.py new file mode 100644 index 000000000..4741ce9aa --- /dev/null +++ b/neo4j/_codec/hydration/v2/temporal.py @@ -0,0 +1,92 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..v1.temporal import * + + +def hydrate_datetime(seconds, nanoseconds, tz=None): + """ Hydrator for `DateTime` and `LocalDateTime` values. + + :param seconds: + :param nanoseconds: + :param tz: + :return: datetime + """ + import pytz + + minutes, seconds = map(int, divmod(seconds, 60)) + hours, minutes = map(int, divmod(minutes, 60)) + days, hours = map(int, divmod(hours, 24)) + t = DateTime.combine( + Date.from_ordinal(get_date_unix_epoch_ordinal() + days), + Time(hours, minutes, seconds, nanoseconds) + ) + if tz is None: + return t + if isinstance(tz, int): + tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) + zone = pytz.FixedOffset(tz_offset_minutes) + else: + zone = pytz.timezone(tz) + t = t.replace(tzinfo=pytz.UTC) + return t.astimezone(zone) + + +def dehydrate_datetime(value): + """ Dehydrator for `datetime` values.
+ + :param value: + :type value: datetime + :return: + """ + + import pytz + + def seconds_and_nanoseconds(dt): + if isinstance(dt, datetime): + dt = DateTime.from_native(dt) + dt = dt.astimezone(pytz.UTC) + utc_epoch = DateTime(1970, 1, 1, tzinfo=pytz.UTC) + dt_clock_time = dt.to_clock_time() + utc_epoch_clock_time = utc_epoch.to_clock_time() + t = dt_clock_time - utc_epoch_clock_time + return t.seconds, t.nanoseconds + + tz = value.tzinfo + if tz is None: + # without time zone + value = pytz.UTC.localize(value) + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"d", seconds, nanoseconds) + elif hasattr(tz, "zone") and tz.zone and isinstance(tz.zone, str): + # with named pytz time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, nanoseconds, tz.zone) + elif hasattr(tz, "key") and tz.key and isinstance(tz.key, str): + # with named zoneinfo (Python 3.9+) time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, nanoseconds, tz.key) + else: + # with time offset + seconds, nanoseconds = seconds_and_nanoseconds(value) + offset = tz.utcoffset(value) + if offset.microseconds: + raise ValueError("Bolt protocol does not support sub-second " + "UTC offsets.") + offset_seconds = offset.days * 86400 + offset.seconds + return Structure(b"I", seconds, nanoseconds, offset_seconds) diff --git a/neo4j/_codec/packstream/__init__.py b/neo4j/_codec/packstream/__init__.py new file mode 100644 index 000000000..ba0188b0f --- /dev/null +++ b/neo4j/_codec/packstream/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._common import Structure + + +__all__ = [ + "Structure", +] diff --git a/neo4j/_codec/packstream/_common.py b/neo4j/_codec/packstream/_common.py new file mode 100644 index 000000000..84403de7c --- /dev/null +++ b/neo4j/_codec/packstream/_common.py @@ -0,0 +1,44 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
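Aside (editor's illustration, not part of the patch): the `Structure` class defined next is the generic carrier for every typed value on the Bolt wire, and the hydration hooks shown earlier dispatch on its tag byte. For instance, a 2D Cartesian point (SRID 7203) travels as:

    from neo4j._codec.packstream import Structure  # internal module added by this patch

    s = Structure(b"X", 7203, 1.5, 3.0)  # tag byte, then the raw fields
    s.tag          # b'X'
    s[0], s[1:]    # 7203 and [1.5, 3.0] -- list-like field access
    len(s)         # 3 fields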
+ + +class Structure: + + def __init__(self, tag, *fields): + self.tag = tag + self.fields = list(fields) + + def __repr__(self): + return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields))) + + def __eq__(self, other): + try: + return self.tag == other.tag and self.fields == other.fields + except AttributeError: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __len__(self): + return len(self.fields) + + def __getitem__(self, key): + return self.fields[key] + + def __setitem__(self, key, value): + self.fields[key] = value diff --git a/neo4j/_codec/packstream/v1/__init__.py b/neo4j/_codec/packstream/v1/__init__.py new file mode 100644 index 000000000..0c74b6879 --- /dev/null +++ b/neo4j/_codec/packstream/v1/__init__.py @@ -0,0 +1,462 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from codecs import decode +from contextlib import contextmanager +from struct import ( + pack as struct_pack, + unpack as struct_unpack, +) + +from .._common import Structure + + +PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)] +PACKED_UINT_16 = [struct_pack(">H", value) for value in range(0x10000)] + +UNPACKED_UINT_8 = {bytes(bytearray([x])): x for x in range(0x100)} +UNPACKED_UINT_16 = {struct_pack(">H", x): x for x in range(0x10000)} + +UNPACKED_MARKERS = {b"\xC0": None, b"\xC2": False, b"\xC3": True} +UNPACKED_MARKERS.update({bytes(bytearray([z])): z for z in range(0x00, 0x80)}) +UNPACKED_MARKERS.update({bytes(bytearray([z + 256])): z for z in range(-0x10, 0x00)}) + + +INT64_MIN = -(2 ** 63) +INT64_MAX = 2 ** 63 + + +class Packer: + + def __init__(self, stream): + self.stream = stream + self._write = self.stream.write + + def pack_raw(self, data): + self._write(data) + + def pack(self, value, dehydration_hooks=None): + write = self._write + + # None + if value is None: + write(b"\xC0") # NULL + + # Boolean + elif value is True: + write(b"\xC3") + elif value is False: + write(b"\xC2") + + # Float (only double precision is supported) + elif isinstance(value, float): + write(b"\xC1") + write(struct_pack(">d", value)) + + # Integer + elif isinstance(value, int): + if -0x10 <= value < 0x80: + write(PACKED_UINT_8[value % 0x100]) + elif -0x80 <= value < -0x10: + write(b"\xC8") + write(PACKED_UINT_8[value % 0x100]) + elif -0x8000 <= value < 0x8000: + write(b"\xC9") + write(PACKED_UINT_16[value % 0x10000]) + elif -0x80000000 <= value < 0x80000000: + write(b"\xCA") + write(struct_pack(">i", value)) + elif INT64_MIN <= value < INT64_MAX: + write(b"\xCB") + write(struct_pack(">q", value)) + else: + raise OverflowError("Integer %s out of range" % value) + + # String + elif isinstance(value, str): + encoded = value.encode("utf-8") + self.pack_string_header(len(encoded)) + self.pack_raw(encoded) + + # Bytes + elif isinstance(value, (bytes, bytearray)): + self.pack_bytes_header(len(value)) + self.pack_raw(value) + + # List + elif 
isinstance(value, list): + self.pack_list_header(len(value)) + for item in value: + self.pack(item, dehydration_hooks=dehydration_hooks) + + # Map + elif isinstance(value, dict): + self.pack_map_header(len(value)) + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError( + "Map keys must be strings, not {}".format(type(key)) + ) + self.pack(key, dehydration_hooks=dehydration_hooks) + self.pack(item, dehydration_hooks=dehydration_hooks) + + # Structure + elif isinstance(value, Structure): + self.pack_struct(value.tag, value.fields) + + # Other + elif dehydration_hooks and type(value) in dehydration_hooks: + self.pack(dehydration_hooks[type(value)](value)) + else: + raise ValueError("Values of type %s are not supported" % type(value)) + + def pack_bytes_header(self, size): + write = self._write + if size < 0x100: + write(b"\xCC") + write(PACKED_UINT_8[size]) + elif size < 0x10000: + write(b"\xCD") + write(PACKED_UINT_16[size]) + elif size < 0x100000000: + write(b"\xCE") + write(struct_pack(">I", size)) + else: + raise OverflowError("Bytes header size out of range") + + def pack_string_header(self, size): + write = self._write + if size <= 0x0F: + write(bytes((0x80 | size,))) + elif size < 0x100: + write(b"\xD0") + write(PACKED_UINT_8[size]) + elif size < 0x10000: + write(b"\xD1") + write(PACKED_UINT_16[size]) + elif size < 0x100000000: + write(b"\xD2") + write(struct_pack(">I", size)) + else: + raise OverflowError("String header size out of range") + + def pack_list_header(self, size): + write = self._write + if size <= 0x0F: + write(bytes((0x90 | size,))) + elif size < 0x100: + write(b"\xD4") + write(PACKED_UINT_8[size]) + elif size < 0x10000: + write(b"\xD5") + write(PACKED_UINT_16[size]) + elif size < 0x100000000: + write(b"\xD6") + write(struct_pack(">I", size)) + else: + raise OverflowError("List header size out of range") + + def pack_map_header(self, size): + write = self._write + if size <= 0x0F: + write(bytes((0xA0 | size,))) + elif size < 0x100: + write(b"\xD8") + write(PACKED_UINT_8[size]) + elif size < 0x10000: + write(b"\xD9") + write(PACKED_UINT_16[size]) + elif size < 0x100000000: + write(b"\xDA") + write(struct_pack(">I", size)) + else: + raise OverflowError("Map header size out of range") + + def pack_struct(self, signature, fields, dehydration_hooks=None): + if len(signature) != 1 or not isinstance(signature, bytes): + raise ValueError("Structure signature must be a single byte value") + write = self._write + size = len(fields) + if size <= 0x0F: + write(bytes((0xB0 | size,))) + else: + raise OverflowError("Structure size out of range") + write(signature) + for field in fields: + self.pack(field, dehydration_hooks=dehydration_hooks) + + @staticmethod + def new_packable_buffer(): + return PackableBuffer() + + +class PackableBuffer: + def __init__(self): + self.data = bytearray() + # export write method for packer; "inline" for performance + self.write = self.data.extend + self.clear = self.data.clear + self._tmp_buffering = 0 + + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self.data) + try: + yield + except Exception: + del self.data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + + def is_tmp_buffering(self): + return bool(self._tmp_buffering) + + +class Unpacker: + + def __init__(self, unpackable): + self.unpackable = unpackable + + def reset(self): + self.unpackable.reset() + + def read(self, n=1): + return self.unpackable.read(n) + + def read_u8(self): + return self.unpackable.read_u8() + + def 
unpack(self, hydration_hooks=None): + value = self._unpack(hydration_hooks=hydration_hooks) + if hydration_hooks and type(value) in hydration_hooks: + return hydration_hooks[type(value)](value) + return value + + def _unpack(self, hydration_hooks=None): + marker = self.read_u8() + + if marker == -1: + raise ValueError("Nothing to unpack") + + # Tiny Integer + if 0x00 <= marker <= 0x7F: + return marker + elif 0xF0 <= marker <= 0xFF: + return marker - 0x100 + + # Null + elif marker == 0xC0: + return None + + # Float + elif marker == 0xC1: + value, = struct_unpack(">d", self.read(8)) + return value + + # Boolean + elif marker == 0xC2: + return False + elif marker == 0xC3: + return True + + # Integer + elif marker == 0xC8: + return struct_unpack(">b", self.read(1))[0] + elif marker == 0xC9: + return struct_unpack(">h", self.read(2))[0] + elif marker == 0xCA: + return struct_unpack(">i", self.read(4))[0] + elif marker == 0xCB: + return struct_unpack(">q", self.read(8))[0] + + # Bytes + elif marker == 0xCC: + size, = struct_unpack(">B", self.read(1)) + return self.read(size).tobytes() + elif marker == 0xCD: + size, = struct_unpack(">H", self.read(2)) + return self.read(size).tobytes() + elif marker == 0xCE: + size, = struct_unpack(">I", self.read(4)) + return self.read(size).tobytes() + + else: + marker_high = marker & 0xF0 + # String + if marker_high == 0x80: # TINY_STRING + return decode(self.read(marker & 0x0F), "utf-8") + elif marker == 0xD0: # STRING_8: + size, = struct_unpack(">B", self.read(1)) + return decode(self.read(size), "utf-8") + elif marker == 0xD1: # STRING_16: + size, = struct_unpack(">H", self.read(2)) + return decode(self.read(size), "utf-8") + elif marker == 0xD2: # STRING_32: + size, = struct_unpack(">I", self.read(4)) + return decode(self.read(size), "utf-8") + + # List + elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD6: + return list(self._unpack_list_items( + marker, hydration_hooks=hydration_hooks) + ) + + # Map + elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDA: + return self._unpack_map( + marker, hydration_hooks=hydration_hooks + ) + + # Structure + elif 0xB0 <= marker <= 0xBF: + size, tag = self._unpack_structure_header(marker) + value = Structure(tag, *([None] * size)) + for i in range(len(value)): + value[i] = self.unpack(hydration_hooks=hydration_hooks) + return value + + else: + raise ValueError("Unknown PackStream marker %02X" % marker) + + def _unpack_list_items(self, marker, hydration_hooks=None): + marker_high = marker & 0xF0 + if marker_high == 0x90: + size = marker & 0x0F + if size == 0: + return + elif size == 1: + yield self.unpack(hydration_hooks=hydration_hooks) + else: + for _ in range(size): + yield self.unpack(hydration_hooks=hydration_hooks) + elif marker == 0xD4: # LIST_8: + size, = struct_unpack(">B", self.read(1)) + for _ in range(size): + yield self.unpack(hydration_hooks=hydration_hooks) + elif marker == 0xD5: # LIST_16: + size, = struct_unpack(">H", self.read(2)) + for _ in range(size): + yield self.unpack(hydration_hooks=hydration_hooks) + elif marker == 0xD6: # LIST_32: + size, = struct_unpack(">I", self.read(4)) + for _ in range(size): + yield self.unpack(hydration_hooks=hydration_hooks) + else: + return + + def unpack_map(self, hydration_hooks=None): + marker = self.read_u8() + return self._unpack_map(marker, hydration_hooks=hydration_hooks) + + def _unpack_map(self, marker, hydration_hooks=None): + marker_high = marker & 0xF0 + if marker_high == 0xA0: + size = marker & 0x0F + value = {} + for _ in range(size): + key = 
self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) + return value + elif marker == 0xD8: # MAP_8: + size, = struct_unpack(">B", self.read(1)) + value = {} + for _ in range(size): + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) + return value + elif marker == 0xD9: # MAP_16: + size, = struct_unpack(">H", self.read(2)) + value = {} + for _ in range(size): + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) + return value + elif marker == 0xDA: # MAP_32: + size, = struct_unpack(">I", self.read(4)) + value = {} + for _ in range(size): + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) + return value + else: + return None + + def unpack_structure_header(self): + marker = self.read_u8() + if marker == -1: + return None, None + else: + return self._unpack_structure_header(marker) + + def _unpack_structure_header(self, marker): + marker_high = marker & 0xF0 + if marker_high == 0xB0: # TINY_STRUCT + signature = self.read(1).tobytes() + return marker & 0x0F, signature + else: + raise ValueError("Expected structure, found marker %02X" % marker) + + @staticmethod + def new_unpackable_buffer(): + return UnpackableBuffer() + + +class UnpackableBuffer: + + initial_capacity = 8192 + + def __init__(self, data=None): + if data is None: + self.data = bytearray(self.initial_capacity) + self.used = 0 + else: + self.data = bytearray(data) + self.used = len(self.data) + self.p = 0 + + def reset(self): + self.used = 0 + self.p = 0 + + def read(self, n=1): + view = memoryview(self.data) + q = self.p + n + subview = view[self.p:q] + self.p = q + return subview + + def read_u8(self): + if self.used - self.p >= 1: + value = self.data[self.p] + self.p += 1 + return value + else: + return -1 + + def pop_u16(self): + """ Remove the last two bytes of data, returning them as a big-endian + 16-bit unsigned integer. + """ + if self.used >= 2: + value = 0x100 * self.data[self.used - 2] + self.data[self.used - 1] + self.used -= 2 + return value + else: + return -1 diff --git a/neo4j/_conf.py b/neo4j/_conf.py index 6237743e8..a3e290cc0 100644 --- a/neo4j/_conf.py +++ b/neo4j/_conf.py @@ -16,6 +16,37 @@ # limitations under the License. +from abc import ABCMeta +from collections.abc import Mapping + +from ._meta import ( + deprecation_warn, + get_user_agent, +) +from .api import ( + DEFAULT_DATABASE, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + WRITE_ACCESS, +) +from .exceptions import ConfigurationError + + +def iter_items(iterable): + """ Iterate through all items (key-value pairs) within an iterable + dictionary-like object. If the object has a `keys` method, this is + used along with `__getitem__` to yield each pair in turn. If no + `keys` method exists, each iterable element is assumed to be a + 2-tuple of key and value. + """ + if hasattr(iterable, "keys"): + for key in iterable.keys(): + yield key, iterable[key] + else: + for key, value in iterable: + yield key, value + + class TrustStore: # Base class for trust stores. For internal type-checking only. 
pass @@ -84,3 +115,366 @@ class TrustCustomCAs(TrustStore): """ def __init__(self, *certificates): self.certs = certificates + + +class DeprecatedAlias: + """Used when a config option has been renamed.""" + + def __init__(self, new): + self.new = new + + +class DeprecatedAlternative: + """Used for deprecated config options that have a similar alternative.""" + + def __init__(self, new, converter=None): + self.new = new + self.converter = converter + + +class ConfigType(ABCMeta): + + def __new__(mcs, name, bases, attributes): + fields = [] + deprecated_aliases = {} + deprecated_alternatives = {} + + for base in bases: + if type(base) is mcs: + fields += base.keys() + deprecated_aliases.update(base._deprecated_aliases()) + deprecated_alternatives.update(base._deprecated_alternatives()) + + for k, v in attributes.items(): + if isinstance(v, DeprecatedAlias): + deprecated_aliases[k] = v.new + elif isinstance(v, DeprecatedAlternative): + deprecated_alternatives[k] = v.new, v.converter + elif not (k.startswith("_") + or callable(v) + or isinstance(v, (staticmethod, classmethod))): + fields.append(k) + + def keys(_): + return set(fields) + + def _deprecated_keys(_): + return (set(deprecated_aliases.keys()) + | set(deprecated_alternatives.keys())) + + def _get_new(_, key): + return deprecated_aliases.get( + key, deprecated_alternatives.get(key, (None,))[0] + ) + + def _deprecated_aliases(_): + return deprecated_aliases + + def _deprecated_alternatives(_): + return deprecated_alternatives + + attributes.setdefault("keys", classmethod(keys)) + attributes.setdefault("_get_new", + classmethod(_get_new)) + attributes.setdefault("_deprecated_keys", + classmethod(_deprecated_keys)) + attributes.setdefault("_deprecated_aliases", + classmethod(_deprecated_aliases)) + attributes.setdefault("_deprecated_alternatives", + classmethod(_deprecated_alternatives)) + + return super(ConfigType, mcs).__new__( + mcs, name, bases, {k: v for k, v in attributes.items() + if k not in _deprecated_keys(None)} + ) + + +class Config(Mapping, metaclass=ConfigType): + """ Base class for all configuration containers. 
+ """ + + @staticmethod + def consume_chain(data, *config_classes): + values = [] + for config_class in config_classes: + if not issubclass(config_class, Config): + raise TypeError("%r is not a Config subclass" % config_class) + values.append(config_class._consume(data)) + if data: + raise ConfigurationError("Unexpected config keys: %s" % ", ".join(data.keys())) + return values + + @classmethod + def consume(cls, data): + config, = cls.consume_chain(data, cls) + return config + + @classmethod + def _consume(cls, data): + config = {} + if data: + for key in cls.keys() | cls._deprecated_keys(): + try: + value = data.pop(key) + except KeyError: + pass + else: + config[key] = value + return cls(config) + + def __update(self, data): + data_dict = dict(iter_items(data)) + + def set_attr(k, v): + if k in self.keys(): + setattr(self, k, v) + elif k in self._deprecated_keys(): + k0 = self._get_new(k) + if k0 in data_dict: + raise ConfigurationError( + "Cannot specify both '{}' and '{}' in config" + .format(k0, k) + ) + deprecation_warn( + "The '{}' config key is deprecated, please use '{}' " + "instead".format(k, k0) + ) + if k in self._deprecated_aliases(): + set_attr(k0, v) + else: # k in self._deprecated_alternatives: + _, converter = self._deprecated_alternatives()[k] + converter(self, v) + else: + raise AttributeError(k) + + for key, value in data_dict.items(): + if value is not None: + set_attr(key, value) + + def __init__(self, *args, **kwargs): + for arg in args: + self.__update(arg) + self.__update(kwargs) + + def __repr__(self): + attrs = [] + for key in self: + attrs.append(" %s=%r" % (key, getattr(self, key))) + return "<%s%s>" % (self.__class__.__name__, "".join(attrs)) + + def __len__(self): + return len(self.keys()) + + def __getitem__(self, key): + return getattr(self, key) + + def __iter__(self): + return iter(self.keys()) + + +def _trust_to_trusted_certificates(pool_config, trust): + if trust == TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: + pool_config.trusted_certificates = TrustSystemCAs() + elif trust == TRUST_ALL_CERTIFICATES: + pool_config.trusted_certificates = TrustAll() + + +class PoolConfig(Config): + """ Connection pool configuration. + """ + + #: Max Connection Lifetime + max_connection_lifetime = 3600 # seconds + # The maximum duration the driver will keep a connection for before being removed from the pool. + + #: Max Connection Pool Size + max_connection_pool_size = 100 + # The maximum total number of connections allowed, per host (i.e. cluster nodes), to be managed by the connection pool. + + #: Connection Timeout + connection_timeout = 30.0 # seconds + # The maximum amount of time to wait for a TCP connection to be established. + + #: Update Routing Table Timout + update_routing_table_timeout = 90.0 # seconds + # The maximum amount of time to wait for updating the routing table. + # This includes everything necessary for this to happen. + # Including opening sockets, requesting and receiving the routing table, + # etc. + + #: Trust + trust = DeprecatedAlternative( + "trusted_certificates", _trust_to_trusted_certificates + ) + # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. + + #: Custom Resolver + resolver = None + # Custom resolver function, returning list of resolved addresses. + + #: Encrypted + encrypted = False + # Specify whether to use an encrypted connection between the driver and server. 
+ + #: SSL Certificates to Trust + trusted_certificates = TrustSystemCAs() + # Specify how to determine the authenticity of encryption certificates + # provided by the Neo4j instance on connection. + # * `neo4j.TrustSystemCAs()`: Use system trust store. (default) + # * `neo4j.TrustAll()`: Trust any certificate. + # * `neo4j.TrustCustomCAs("", ...)`: + # Trust the specified certificate(s). + + #: Custom SSL context to use for wrapping sockets + ssl_context = None + # Use any custom SSL context to wrap sockets. + # Overwrites `trusted_certificates` and `encrypted`. + # The use of this option is strongly discouraged. + + #: User Agent (Python Driver Specific) + user_agent = get_user_agent() + # Specify the client agent name. + + #: Protocol Version (Python Driver Specific) + protocol_version = None # Version(4, 0) + # Specify a specific Bolt Protocol Version. + + #: Initial Connection Pool Size (Python Driver Specific) + init_size = 1 # The other drivers do not seed from the start. + # This will seed the pool with the specified number of connections. + + #: Socket Keep Alive (Python and .NET Driver Specific) + keep_alive = True + # Specify whether TCP keep-alive should be enabled. + + def get_ssl_context(self): + if self.ssl_context is not None: + return self.ssl_context + + if not self.encrypted: + return None + + import ssl + + # SSL stands for Secure Sockets Layer and was originally created by Netscape. + # SSLv2 and SSLv3 are the 2 versions of this protocol (SSLv1 was never publicly released). + # After SSLv3, SSL was renamed to TLS. + # TLS stands for Transport Layer Security and started with TLSv1.0 which is an upgraded version of SSLv3. + # SSLv2 - (Disabled) + # SSLv3 - (Disabled) + # TLS 1.0 - Released in 1999, published as RFC 2246. (Disabled) + # TLS 1.1 - Released in 2006, published as RFC 4346. (Disabled) + # TLS 1.2 - Released in 2008, published as RFC 5246. + # https://docs.python.org/3.7/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # For recommended security options see + # https://docs.python.org/3.7/library/ssl.html#protocol-versions + ssl_context.options |= ssl.OP_NO_TLSv1 # Python 3.2 + ssl_context.options |= ssl.OP_NO_TLSv1_1 # Python 3.4 + + if isinstance(self.trusted_certificates, TrustAll): + # trust any certificate + ssl_context.check_hostname = False + # https://docs.python.org/3.7/library/ssl.html#ssl.CERT_NONE + ssl_context.verify_mode = ssl.CERT_NONE + elif isinstance(self.trusted_certificates, TrustCustomCAs): + # trust the specified certificate(s) + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + for cert in self.trusted_certificates.certs: + ssl_context.load_verify_locations(cert) + else: + # default + # trust system CA certificates + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + # Must be load_default_certs, not set_default_verify_paths to + # work on Windows with system CAs. + ssl_context.load_default_certs() + + return ssl_context + + +class WorkspaceConfig(Config): + """ Workspace configuration. + """ + + #: Session Connection Timeout + session_connection_timeout = 120.0 # seconds + # The maximum amount of time to wait for a session to obtain a usable + # read/write connection. This includes everything necessary for this to + # happen: fetching routing tables, opening sockets, etc.
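Aside (illustrative only, again relying on the internal API above): `get_ssl_context` is where the pool options are turned into an `ssl.SSLContext`, e.g.:

    import ssl
    from neo4j import TrustAll
    from neo4j._conf import PoolConfig

    cfg = PoolConfig.consume({"encrypted": True,
                              "trusted_certificates": TrustAll()})
    ctx = cfg.get_ssl_context()
    assert isinstance(ctx, ssl.SSLContext)
    assert ctx.check_hostname is False       # TrustAll(): no hostname check
    assert ctx.verify_mode == ssl.CERT_NONE  # ...and no certificate check

    # Unencrypted configurations yield no SSL context at all:
    assert PoolConfig.consume({"encrypted": False}).get_ssl_context() is None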
+ + #: Connection Acquisition Timeout + connection_acquisition_timeout = 60.0 # seconds + # The maximum amount of time a session will wait when requesting a connection from the connection pool. + # Since the process of acquiring a connection may involve creating a new connection, ensure that the value + # of this configuration is higher than the configured Connection Timeout. + + #: Max Transaction Retry Time + max_transaction_retry_time = 30.0 # seconds + # The maximum amount of time that a managed transaction will retry before failing. + + #: Initial Retry Delay + initial_retry_delay = 1.0 # seconds + + #: Retry Delay Multiplier + retry_delay_multiplier = 2.0 # dimensionless multiplier + + #: Retry Delay Jitter Factor + retry_delay_jitter_factor = 0.2 # dimensionless factor + + #: Database Name + database = DEFAULT_DATABASE + # Name of the database to query. + # Note: The default database can be set on the Neo4j instance settings. + + #: Fetch Size + fetch_size = 1000 + + #: User to impersonate + impersonated_user = None + # Note that you need appropriate permissions to do so. + + +class SessionConfig(WorkspaceConfig): + """ Session configuration. + """ + + #: Bookmarks + bookmarks = None + + #: Default AccessMode + default_access_mode = WRITE_ACCESS + + +class TransactionConfig(Config): + """ Transaction configuration. This is internal for now. + + neo4j.session.begin_transaction + neo4j.Query + neo4j.unit_of_work + + all use the same settings. + """ + #: Metadata + metadata = None # dictionary + + #: Timeout + timeout = None # seconds + + +class RoutingConfig(Config): + """ Neo4jDriver routing settings. This is internal for now. + """ + + #: Routing Table Purge Delay + routing_table_purge_delay = 30.0 # seconds + # TTL + routing_table_purge_delay determines when a database routing table should be removed. + + #: Max Routing Failures + # max_routing_failures = 1 + + #: Retry Timeout Delay + # retry_timeout_delay = 5.0 # seconds diff --git a/neo4j/_data.py b/neo4j/_data.py new file mode 100644 index 000000000..9207b50f5 --- /dev/null +++ b/neo4j/_data.py @@ -0,0 +1,313 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ( + ABCMeta, + abstractmethod, +) +from collections.abc import ( + Mapping, + Sequence, + Set, +) +from functools import reduce +from operator import xor as xor_operator + +from ._conf import iter_items +from .graph import ( + Node, + Path, + Relationship, +) + + +class Record(tuple, Mapping): + """ A :class:`.Record` is an immutable ordered collection of key-value + pairs. It is generally closer to a :py:class:`namedtuple` than to a + :py:class:`OrderedDict` inasmuch as iteration of the collection will + yield values rather than keys.
+ """ + + __keys = None + + def __new__(cls, iterable=()): + keys = [] + values = [] + for key, value in iter_items(iterable): + keys.append(key) + values.append(value) + inst = tuple.__new__(cls, values) + inst.__keys = tuple(keys) + return inst + + def __repr__(self): + return "<%s %s>" % (self.__class__.__name__, + " ".join("%s=%r" % (field, self[i]) for i, field in enumerate(self.__keys))) + + def __eq__(self, other): + """ In order to be flexible regarding comparison, the equality rules + for a record permit comparison with any other Sequence or Mapping. + + :param other: + :return: + """ + compare_as_sequence = isinstance(other, Sequence) + compare_as_mapping = isinstance(other, Mapping) + if compare_as_sequence and compare_as_mapping: + return list(self) == list(other) and dict(self) == dict(other) + elif compare_as_sequence: + return list(self) == list(other) + elif compare_as_mapping: + return dict(self) == dict(other) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return reduce(xor_operator, map(hash, self.items())) + + def __getitem__(self, key): + if isinstance(key, slice): + keys = self.__keys[key] + values = super(Record, self).__getitem__(key) + return self.__class__(zip(keys, values)) + try: + index = self.index(key) + except IndexError: + return None + else: + return super(Record, self).__getitem__(index) + + def __getslice__(self, start, stop): + key = slice(start, stop) + keys = self.__keys[key] + values = tuple(self)[key] + return self.__class__(zip(keys, values)) + + def get(self, key, default=None): + """ Obtain a value from the record by key, returning a default + value if the key does not exist. + + :param key: a key + :param default: default value + :return: a value + """ + try: + index = self.__keys.index(str(key)) + except ValueError: + return default + if 0 <= index < len(self): + return super(Record, self).__getitem__(index) + else: + return default + + def index(self, key): + """ Return the index of the given item. + + :param key: a key + :return: index + :rtype: int + """ + if isinstance(key, int): + if 0 <= key < len(self.__keys): + return key + raise IndexError(key) + elif isinstance(key, str): + try: + return self.__keys.index(key) + except ValueError: + raise KeyError(key) + else: + raise TypeError(key) + + def value(self, key=0, default=None): + """ Obtain a single value from the record by index or key. If no + index or key is specified, the first value is returned. If the + specified item does not exist, the default value is returned. + + :param key: an index or key + :param default: default value + :return: a single value + """ + try: + index = self.index(key) + except (IndexError, KeyError): + return default + else: + return self[index] + + def keys(self): + """ Return the keys of the record. + + :return: list of key names + """ + return list(self.__keys) + + def values(self, *keys): + """ Return the values of the record, optionally filtering to + include only certain values by index or key. 
+ + :param keys: indexes or keys of the items to include; if none + are provided, all values will be included + :return: list of values + :rtype: list + """ + if keys: + d = [] + for key in keys: + try: + i = self.index(key) + except KeyError: + d.append(None) + else: + d.append(self[i]) + return d + return list(self) + + def items(self, *keys): + """ Return the fields of the record as a list of key and value tuples + + :return: a list of value tuples + :rtype: list + """ + if keys: + d = [] + for key in keys: + try: + i = self.index(key) + except KeyError: + d.append((key, None)) + else: + d.append((self.__keys[i], self[i])) + return d + return list((self.__keys[i], super(Record, self).__getitem__(i)) for i in range(len(self))) + + def data(self, *keys): + """ Return the keys and values of this record as a dictionary, + optionally including only certain values by index or key. Keys + provided in the items that are not in the record will be + inserted with a value of :const:`None`; indexes provided + that are out of bounds will trigger an :exc:`IndexError`. + + :param keys: indexes or keys of the items to include; if none + are provided, all values will be included + :return: dictionary of values, keyed by field name + :raises: :exc:`IndexError` if an out-of-bounds index is specified + """ + return RecordExporter().transform(dict(self.items(*keys))) + + +class DataTransformer(metaclass=ABCMeta): + """ Abstract base class for transforming data from one form into + another. + """ + + @abstractmethod + def transform(self, x): + """ Transform a value, or collection of values. + + :param x: input value + :return: output value + """ + + +class RecordExporter(DataTransformer): + """ Transformer class used by the :meth:`.Record.data` method. + """ + + def transform(self, x): + if isinstance(x, Node): + return self.transform(dict(x)) + elif isinstance(x, Relationship): + return (self.transform(dict(x.start_node)), + x.__class__.__name__, + self.transform(dict(x.end_node))) + elif isinstance(x, Path): + path = [self.transform(x.start_node)] + for i, relationship in enumerate(x.relationships): + path.append(self.transform(relationship.__class__.__name__)) + path.append(self.transform(x.nodes[i + 1])) + return path + elif isinstance(x, str): + return x + elif isinstance(x, Sequence): + t = type(x) + return t(map(self.transform, x)) + elif isinstance(x, Set): + t = type(x) + return t(map(self.transform, x)) + elif isinstance(x, Mapping): + t = type(x) + return t((k, self.transform(v)) for k, v in x.items()) + else: + return x + + +class RecordTableRowExporter(DataTransformer): + """Transformer class used by the :meth:`.Result.to_df` method.""" + + def transform(self, x): + assert isinstance(x, Mapping) + t = type(x) + return t(item + for k, v in x.items() + for item in self._transform( + v, prefix=k.replace("\\", "\\\\").replace(".", "\\.") + ).items()) + + def _transform(self, x, prefix): + if isinstance(x, Node): + res = { + "%s().element_id" % prefix: x.element_id, + "%s().labels" % prefix: x.labels, + } + res.update(("%s().prop.%s" % (prefix, k), v) for k, v in x.items()) + return res + elif isinstance(x, Relationship): + res = { + "%s->.element_id" % prefix: x.element_id, + "%s->.start.element_id" % prefix: x.start_node.element_id, + "%s->.end.element_id" % prefix: x.end_node.element_id, + "%s->.type" % prefix: x.__class__.__name__, + } + res.update(("%s->.prop.%s" % (prefix, k), v) for k, v in x.items()) + return res + elif isinstance(x, Path) or isinstance(x, str): + return {prefix: x} + elif 
isinstance(x, Sequence): + return dict( + item + for i, v in enumerate(x) + for item in self._transform( + v, prefix="%s[].%i" % (prefix, i) + ).items() + ) + elif isinstance(x, Mapping): + t = type(x) + return t( + item + for k, v in x.items() + for item in self._transform( + v, prefix="%s{}.%s" % (prefix, k.replace("\\", "\\\\") + .replace(".", "\\.")) + ).items() + ) + else: + return {prefix: x} diff --git a/neo4j/_meta.py b/neo4j/_meta.py new file mode 100644 index 000000000..307001e3c --- /dev/null +++ b/neo4j/_meta.py @@ -0,0 +1,124 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from functools import wraps +from warnings import warn + + +# Can be automatically overridden in builds +package = "neo4j" +version = "5.0.dev0" + + +def get_user_agent(): + """ Obtain the default user agent string sent to the server after + a successful handshake. + """ + from sys import ( + platform, + version_info, + ) + template = "neo4j-python/{} Python/{}.{}.{}-{}-{} ({})" + fields = (version,) + tuple(version_info) + (platform,) + return template.format(*fields) + + +def deprecation_warn(message, stack_level=1): + warn(message, category=DeprecationWarning, stacklevel=stack_level + 1) + + +def deprecated(message): + """ Decorator for deprecating functions and methods. + + :: + + @deprecated("'foo' has been deprecated in favour of 'bar'") + def foo(x): + pass + + """ + def decorator(f): + if asyncio.iscoroutinefunction(f): + @wraps(f) + async def inner(*args, **kwargs): + deprecation_warn(message, stack_level=2) + return await f(*args, **kwargs) + + return inner + else: + @wraps(f) + def inner(*args, **kwargs): + deprecation_warn(message, stack_level=2) + return f(*args, **kwargs) + + return inner + + return decorator + + +class ExperimentalWarning(Warning): + """ Base class for warnings about experimental features. + """ + + +def experimental_warn(message, stack_level=1): + warn(message, category=ExperimentalWarning, stacklevel=stack_level + 1) + + +def experimental(message): + """ Decorator for tagging experimental functions and methods. + + :: + + @experimental("'foo' is an experimental function and may be " + "removed in a future release") + def foo(x): + pass + + """ + def decorator(f): + if asyncio.iscoroutinefunction(f): + @wraps(f) + async def inner(*args, **kwargs): + experimental_warn(message, stack_level=2) + return await f(*args, **kwargs) + + return inner + else: + @wraps(f) + def inner(*args, **kwargs): + experimental_warn(message, stack_level=2) + return f(*args, **kwargs) + + return inner + + return decorator + + +def unclosed_resource_warn(obj): + import tracemalloc + from warnings import warn + msg = f"Unclosed {obj!r}." + trace = tracemalloc.get_object_traceback(obj) + if trace: + msg += "\nObject allocated at (most recent call last):\n" + msg += "\n".join(trace.format()) + else: + msg += "\nEnable tracemalloc to get the object allocation traceback." 
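+    # Note: tracemalloc records allocation tracebacks only while it is
+    # tracing (e.g. started via tracemalloc.start() or the
+    # PYTHONTRACEMALLOC environment variable).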
+ warn(msg, ResourceWarning, stacklevel=2, source=obj) diff --git a/neo4j/_routing.py b/neo4j/_routing.py new file mode 100644 index 000000000..a073dda3a --- /dev/null +++ b/neo4j/_routing.py @@ -0,0 +1,167 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections.abc import MutableSet +from logging import getLogger +from time import perf_counter + +from .addressing import Address + + +log = getLogger("neo4j") + + +class OrderedSet(MutableSet): + + def __init__(self, elements=()): + # dicts keep insertion order starting with Python 3.7 + self._elements = dict.fromkeys(elements) + self._current = None + + def __repr__(self): + return "{%s}" % ", ".join(map(repr, self._elements)) + + def __contains__(self, element): + return element in self._elements + + def __iter__(self): + return iter(self._elements) + + def __len__(self): + return len(self._elements) + + def __getitem__(self, index): + return list(self._elements.keys())[index] + + def add(self, element): + self._elements[element] = None + + def clear(self): + self._elements.clear() + + def discard(self, element): + try: + del self._elements[element] + except KeyError: + pass + + def remove(self, element): + try: + del self._elements[element] + except KeyError: + raise ValueError(element) + + def update(self, elements=()): + self._elements.update(dict.fromkeys(elements)) + + def replace(self, elements=()): + e = self._elements + e.clear() + e.update(dict.fromkeys(elements)) + + +class RoutingTable: + + @classmethod + def parse_routing_info(cls, *, database, servers, ttl): + """ Parse the records returned from the procedure call and + return a new RoutingTable instance. 
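+
+        A typical ``servers`` value has the following shape (addresses
+        here are illustrative only)::
+
+            [
+                {"role": "ROUTE", "addresses": ["localhost:7687"]},
+                {"role": "READ", "addresses": ["localhost:7688"]},
+                {"role": "WRITE", "addresses": ["localhost:7687"]},
+            ]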
+ """ + routers = [] + readers = [] + writers = [] + try: + for server in servers: + role = server["role"] + addresses = [] + for address in server["addresses"]: + addresses.append(Address.parse(address, default_port=7687)) + if role == "ROUTE": + routers.extend(addresses) + elif role == "READ": + readers.extend(addresses) + elif role == "WRITE": + writers.extend(addresses) + except (KeyError, TypeError): + raise ValueError("Cannot parse routing info") + else: + return cls(database=database, routers=routers, readers=readers, writers=writers, ttl=ttl) + + def __init__(self, *, database, routers=(), readers=(), writers=(), ttl=0): + self.initial_routers = OrderedSet(routers) + self.routers = OrderedSet(routers) + self.readers = OrderedSet(readers) + self.writers = OrderedSet(writers) + self.initialized_without_writers = not self.writers + self.last_updated_time = perf_counter() + self.ttl = ttl + self.database = database + + def __repr__(self): + return "RoutingTable(database=%r routers=%r, readers=%r, writers=%r, last_updated_time=%r, ttl=%r)" % ( + self.database, + self.routers, + self.readers, + self.writers, + self.last_updated_time, + self.ttl, + ) + + def __contains__(self, address): + return address in self.routers or address in self.readers or address in self.writers + + def is_fresh(self, readonly=False): + """ Indicator for whether routing information is still usable. + """ + assert isinstance(readonly, bool) + log.debug("[#0000] C: Checking table freshness (readonly=%r)", readonly) + expired = self.last_updated_time + self.ttl <= perf_counter() + if readonly: + has_server_for_mode = bool(self.readers) + else: + has_server_for_mode = bool(self.writers) + log.debug("[#0000] C: Table expired=%r", expired) + log.debug("[#0000] C: Table routers=%r", self.routers) + log.debug("[#0000] C: Table has_server_for_mode=%r", has_server_for_mode) + return not expired and self.routers and has_server_for_mode + + def should_be_purged_from_memory(self): + """ Check if the routing table is stale and not used for a long time and should be removed from memory. + + :return: Returns true if it is old and not used for a while. + :rtype: bool + """ + from neo4j._conf import RoutingConfig + perf_time = perf_counter() + log.debug("[#0000] C: last_updated_time=%r perf_time=%r", self.last_updated_time, perf_time) + return self.last_updated_time + self.ttl + RoutingConfig.routing_table_purge_delay <= perf_time + + def update(self, new_routing_table): + """ Update the current routing table with new routing information + from a replacement table. + """ + self.routers.replace(new_routing_table.routers) + self.readers.replace(new_routing_table.readers) + self.writers.replace(new_routing_table.writers) + self.initialized_without_writers = not self.writers + self.last_updated_time = perf_counter() + self.ttl = new_routing_table.ttl + log.debug("[#0000] S: table=%r", self) + + def servers(self): + return set(self.routers) | set(self.writers) | set(self.readers) diff --git a/neo4j/_spatial/__init__.py b/neo4j/_spatial/__init__.py new file mode 100644 index 000000000..3c84a0b0f --- /dev/null +++ b/neo4j/_spatial/__init__.py @@ -0,0 +1,106 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This module defines spatial data types.
+"""
+
+
+from threading import Lock
+
+
+# SRID to subclass mappings
+srid_table = {}
+srid_table_lock = Lock()
+
+
+class Point(tuple):
+    """Base class for spatial data.
+
+    A point within a geometric space. This type is generally used via its
+    subclasses and should not be instantiated directly unless there is no
+    subclass defined for the required SRID.
+
+    :param iterable:
+        An iterable of coordinates.
+        All items will be converted to :class:`float`.
+    """
+
+    #: The SRID (spatial reference identifier) of the spatial data.
+    #: A number that identifies the coordinate system the spatial type is to
+    #: be interpreted in.
+    #:
+    #: :type: int
+    srid = None
+
+    def __new__(cls, iterable):
+        return tuple.__new__(cls, map(float, iterable))
+
+    def __repr__(self):
+        return "POINT(%s)" % " ".join(map(str, self))
+
+    def __eq__(self, other):
+        try:
+            return type(self) is type(other) and tuple(self) == tuple(other)
+        except (AttributeError, TypeError):
+            return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        return hash(type(self)) ^ hash(tuple(self))
+
+
+def point_type(name, fields, srid_map):
+    """ Dynamically create a Point subclass.
+    """
+
+    def srid(self):
+        try:
+            return srid_map[len(self)]
+        except KeyError:
+            return None
+
+    attributes = {"srid": property(srid)}
+
+    for index, subclass_field in enumerate(fields):
+
+        def accessor(self, i=index, f=subclass_field):
+            try:
+                return self[i]
+            except IndexError:
+                raise AttributeError(f)
+
+        for field_alias in {subclass_field, "xyz"[index]}:
+            attributes[field_alias] = property(accessor)
+
+    cls = type(name, (Point,), attributes)
+
+    with srid_table_lock:
+        for dim, srid in srid_map.items():
+            srid_table[srid] = (cls, dim)
+
+    return cls
+
+
+# Point subclass definitions
+CartesianPoint = point_type("CartesianPoint", ["x", "y", "z"],
+                            {2: 7203, 3: 9157})
+WGS84Point = point_type("WGS84Point", ["longitude", "latitude", "height"],
+                        {2: 4326, 3: 4979})
diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py
index e413197ff..e67ff7c0e 100644
--- a/neo4j/_sync/driver.py
+++ b/neo4j/_sync/driver.py
@@ -16,32 +16,27 @@
 # limitations under the License.


-import warnings
-
 from .._async_compat.util import Util
 from .._conf import (
-    TrustAll,
-    TrustStore,
-)
-from ..addressing import Address
-from ..api import (
-    READ_ACCESS,
-    TRUST_ALL_CERTIFICATES,
-    TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
-)
-from ..conf import (
     Config,
     PoolConfig,
     SessionConfig,
+    TrustAll,
+    TrustStore,
     WorkspaceConfig,
 )
-from ..meta import (
+from .._meta import (
     deprecation_warn,
     experimental,
     experimental_warn,
-    ExperimentalWarning,
     unclosed_resource_warn,
 )
+from ..addressing import Address
+from ..api import (
+    READ_ACCESS,
+    TRUST_ALL_CERTIFICATES,
+    TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)


 class GraphDatabase:
@@ -145,7 +140,8 @@ def driver(cls, uri, *, auth=None, **config):
                 "Creating a direct driver (`bolt://` scheme) with routing "
                 "context (URI parameters) is deprecated. They will be "
                 "ignored. This will raise an error in a future release. "
" - 'Given URI "{}"'.format(uri) + 'Given URI "{}"'.format(uri), + stack_level=2 ) # TODO: 6.0 - raise instead of warning # raise ValueError( diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index 5093de404..b7f9ecd88 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -24,17 +24,20 @@ from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util +from ..._codec.hydration import v1 as hydration_v1 +from ..._codec.packstream import v1 as packstream_v1 +from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, SocketDeadlineExceeded, ) +from ..._meta import get_user_agent from ...addressing import Address from ...api import ( ServerInfo, Version, ) -from ...conf import PoolConfig from ...exceptions import ( AuthError, DriverError, @@ -42,11 +45,6 @@ ServiceUnavailable, SessionExpired, ) -from ...meta import get_user_agent -from ...packstream import ( - Packer, - Unpacker, -) from ._common import ( CommitResponse, Inbox, @@ -68,6 +66,13 @@ class Bolt: the handshake was carried out. """ + # TODO: let packer/unpacker know of hydration (give them hooks?) + # TODO: make sure query parameter dehydration gets clear error message. + + PACKER_CLS = packstream_v1.Packer + UNPACKER_CLS = packstream_v1.Unpacker + HYDRATION_HANDLER_CLS = hydration_v1.HydrationHandler + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" PROTOCOL_VERSION = None @@ -107,10 +112,16 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. self.configuration_hints = {} - self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) + self.patch = {} + self.outbox = Outbox( + self.socket, on_error=self._set_defunct_write, + packer_cls=self.PACKER_CLS + ) + self.inbox = Inbox( + self.socket, on_error=self._set_defunct_read, + unpacker_cls=self.UNPACKER_CLS + ) + self.hydration_handler = self.HYDRATION_HANDLER_CLS() self.responses = deque() self._max_connection_lifetime = max_connection_lifetime self._creation_timestamp = perf_counter() @@ -376,14 +387,17 @@ def der_encoded_server_certificate(self): pass @abc.abstractmethod - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a HELLO message to the outgoing queue, sends it and consumes all remaining messages. """ pass @abc.abstractmethod - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -396,13 +410,22 @@ def route(self, database=None, imp_user=None, bookmarks=None): Requires Bolt 4.4+. :param bookmarks: iterable of bookmark values after which this transaction should begin - :return: dictionary of raw routing data + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). 
Dehydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -415,36 +438,60 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a DISCARD message to the output queue. :param n: number of records to discard, default = -1 (ALL) :param qid: query ID to discard for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a PULL message to the output queue. :param n: number of records to pull, default = -1 (ALL) :param qid: query ID to pull for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) @@ -455,53 +502,99 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+ + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). 
Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object :return: Response object """ pass @abc.abstractmethod - def commit(self, **handlers): - """ Appends a COMMIT message to the output queue.""" + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a COMMIT message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass @abc.abstractmethod - def rollback(self, **handlers): - """ Appends a ROLLBACK message to the output queue.""" + def rollback(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a ROLLBACK message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything.""" pass @abc.abstractmethod - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a RESET message to the outgoing queue, sends it and consumes all remaining messages. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod - def goodbye(self): - """Append a GOODBYE message to the outgoing queue.""" + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + """Append a GOODBYE message to the outgoing queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass - def _append(self, signature, fields=(), response=None): + def new_hydration_scope(self): + return self.hydration_handler.new_hydration_scope() + + def _append(self, signature, fields=(), response=None, + dehydration_hooks=None): """ Appends a message to the outgoing queue. 
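+
+        For example, ``self._append(b"\x0F", (), response)`` would
+        enqueue a RESET message (illustrative; actual callers also pass
+        the appropriate dehydration hooks).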
         :param signature: the signature of the message
         :param fields: the fields of the message as a tuple
         :param response: a response object to handle callbacks
+        :param dehydration_hooks:
+            Hooks to dehydrate types (dict from type (class) to dehydration
+            function). Dehydration functions receive the value and return an
+            object of type understood by packstream.
         """
-        with self.outbox.tmp_buffer():
-            self.packer.pack_struct(signature, fields)
-        self.outbox.wrap_message()
+        self.outbox.append_message(signature, fields, dehydration_hooks)
         self.responses.append(response)

     def _send_all(self):
-        data = self.outbox.view()
-        if data:
-            try:
-                self.socket.sendall(data)
-            except OSError as error:
-                self._set_defunct_write(error)
-            self.outbox.clear()
+        if self.outbox.flush():
             self.idle_since = perf_counter()

     def send_all(self):
@@ -523,8 +616,7 @@ def send_all(self):
         self._send_all()

     @abc.abstractmethod
-    def _process_message(self, details, summary_signature,
-                         summary_metadata):
+    def _process_message(self, tag, fields):
         """ Receive at most one message from the server, if available.

         :return: 2-tuple of number of detail messages and number of summary
@@ -549,11 +641,10 @@ def fetch_message(self):
             return 0, 0

         # Receive exactly one message
-        details, summary_signature, summary_metadata = \
-            Util.next(self.inbox)
-        res = self._process_message(
-            details, summary_signature, summary_metadata
+        tag, fields = self.inbox.pop(
+            hydration_hooks=self.responses[0].hydration_hooks
         )
+        res = self._process_message(tag, fields)
         self.idle_since = perf_counter()
         return res
diff --git a/neo4j/_sync/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py
index ac6e61fb4..1a169f71c 100644
--- a/neo4j/_sync/io/_bolt3.py
+++ b/neo4j/_sync/io/_bolt3.py
@@ -142,7 +142,7 @@ def get_base_headers(self):
             "user_agent": self.user_agent,
         }

-    def hello(self):
+    def hello(self, dehydration_hooks=None, hydration_hooks=None):
         headers = self.get_base_headers()
         headers.update(self.auth_dict)
         logged_headers = dict(headers)
@@ -150,13 +150,17 @@ def hello(self):
             logged_headers["credentials"] = "*******"
         log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
         self._append(b"\x01", (headers,),
-                     response=InitResponse(self, "hello",
-                                           on_success=self.server_info.update))
+                     response=InitResponse(self, "hello", hydration_hooks,
+                                           on_success=self.server_info.update),
+                     dehydration_hooks=dehydration_hooks)
         self.send_all()
         self.fetch_all()
         check_supported_server_product(self.server_info.agent)

-    def route(self, database=None, imp_user=None, bookmarks=None):
+    def route(
+        self, database=None, imp_user=None, bookmarks=None,
+        dehydration_hooks=None, hydration_hooks=None
+    ):
         if database is not None:
             raise ConfigurationError(
                 "Database name parameter for selecting database is not "
@@ -183,16 +187,20 @@ def route(self, database=None, imp_user=None, bookmarks=None):
             "CALL dbms.cluster.routing.getRoutingTable($context)",  # This is an internal procedure call. Only available if Neo4j 3.5 is set up with clustering.
{"context": self.routing_context}, mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull(dehydration_hooks = None, hydration_hooks = None, + on_success=metadata.update, on_records=records.extend) self.send_all() self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -231,20 +239,29 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: DISCARD_ALL", self.local_port) - self._append(b"\x2F", (), Response(self, "discard", **handlers)) + self._append(b"\x2F", (), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. 
log.debug("[#%04X] C: PULL_ALL", self.local_port) - self._append(b"\x3F", (), Response(self, "pull", **handlers)) + self._append(b"\x3F", (), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -280,17 +297,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ @@ -299,21 +324,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """ Process at most one message from the server, if available. 
:return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data self.responses[0].on_records(details) diff --git a/neo4j/_sync/io/_bolt4.py b/neo4j/_sync/io/_bolt4.py index 2c26af8cd..609115264 100644 --- a/neo4j/_sync/io/_bolt4.py +++ b/neo4j/_sync/io/_bolt4.py @@ -34,11 +34,11 @@ NotALeader, ServiceUnavailable, ) +from ._bolt import Bolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import Bolt from ._common import ( check_supported_server_product, CommitResponse, @@ -95,7 +95,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -103,13 +103,19 @@ def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse( + self, "hello", hydration_hooks, + on_success=self.server_info.update + ), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -138,14 +144,20 @@ def route(self, database=None, imp_user=None, bookmarks=None): db=SYSTEM_DATABASE, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull( + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, + on_success=metadata.update, + on_records=records.extend + ) self.send_all() self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -179,24 +191,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -227,17 +248,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. 
""" @@ -246,21 +275,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """ Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data self.responses[0].on_records(details) @@ -341,7 +382,15 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) - def route(self, database=None, imp_user=None, bookmarks=None): + def get_base_headers(self): + headers = super().get_base_headers() + headers["patch_bolt"] = ["utc"] + return headers + + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -359,13 +408,14 @@ def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, database), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() return [metadata.get("rt")] - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -380,6 +430,9 @@ def on_success(metadata): "connection.recv_timeout_seconds (%r). 
Make sure " "the server and network is set up correctly.", self.local_port, recv_timeout) + self.patch = set(metadata.pop("patch_bolt", [])) + if "utc" in self.patch: + self.hydration_handler.patch_utc() headers = self.get_base_headers() headers.update(self.auth_dict) @@ -388,8 +441,9 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -403,7 +457,10 @@ class Bolt4x4(Bolt4x3): PROTOCOL_VERSION = Version(4, 4) - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -418,14 +475,16 @@ def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -456,10 +515,13 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -486,4 +548,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/neo4j/_sync/io/_bolt5.py b/neo4j/_sync/io/_bolt5.py index 74fe2d183..a7180e5a5 100644 --- a/neo4j/_sync/io/_bolt5.py +++ b/neo4j/_sync/io/_bolt5.py @@ -19,28 +19,24 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import Util -from ..._exceptions import ( - BoltError, - BoltProtocolError, -) +from ..._codec.hydration import v2 as hydration_v2 +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, ) from ...exceptions import ( DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, ServiceUnavailable, ) 
+from ._bolt import Bolt
 from ._bolt3 import (
     ServerStateManager,
     ServerStates,
 )
-from ._bolt import Bolt
 from ._common import (
     check_supported_server_product,
     CommitResponse,
@@ -57,6 +53,8 @@ class Bolt5x0(Bolt):

     PROTOCOL_VERSION = Version(5, 0)

+    HYDRATION_HANDLER_CLS = hydration_v2.HydrationHandler
+
     supports_multiple_results = True

     supports_multiple_databases = True
@@ -95,7 +93,7 @@ def get_base_headers(self):
             headers["routing"] = self.routing_context
         return headers

-    def hello(self):
+    def hello(self, dehydration_hooks=None, hydration_hooks=None):
         def on_success(metadata):
             self.configuration_hints.update(metadata.pop("hints", {}))
             self.server_info.update(metadata)
@@ -118,13 +116,15 @@ def on_success(metadata):
             logged_headers["credentials"] = "*******"
         log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
         self._append(b"\x01", (headers,),
-                     response=InitResponse(self, "hello",
-                                           on_success=on_success))
+                     response=InitResponse(self, "hello", hydration_hooks,
+                                           on_success=on_success),
+                     dehydration_hooks=dehydration_hooks)
         self.send_all()
         self.fetch_all()
         check_supported_server_product(self.server_info.agent)

-    def route(self, database=None, imp_user=None, bookmarks=None):
+    def route(self, database=None, imp_user=None, bookmarks=None,
+              dehydration_hooks=None, hydration_hooks=None):
         routing_context = self.routing_context or {}
         db_context = {}
         if database is not None:
@@ -139,14 +139,16 @@ def route(self, database=None, imp_user=None, bookmarks=None):
         else:
             bookmarks = list(bookmarks)
         self._append(b"\x66", (routing_context, bookmarks, db_context),
-                     response=Response(self, "route",
-                                       on_success=metadata.update))
+                     response=Response(self, "route", hydration_hooks,
+                                       on_success=metadata.update),
+                     dehydration_hooks=dehydration_hooks)
         self.send_all()
         self.fetch_all()
         return [metadata.get("rt")]

     def run(self, query, parameters=None, mode=None, bookmarks=None,
-            metadata=None, timeout=None, db=None, imp_user=None, **handlers):
+            metadata=None, timeout=None, db=None, imp_user=None,
+            dehydration_hooks=None, hydration_hooks=None, **handlers):
         if not parameters:
             parameters = {}
         extra = {}
@@ -177,24 +179,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
         fields = (query, parameters, extra)
         log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
-        self._append(b"\x10", fields, Response(self, "run", **handlers))
+        self._append(b"\x10", fields,
+                     Response(self, "run", hydration_hooks, **handlers),
+                     dehydration_hooks=dehydration_hooks)

-    def discard(self, n=-1, qid=-1, **handlers):
+    def discard(self, n=-1, qid=-1, dehydration_hooks=None,
+                hydration_hooks=None, **handlers):
         extra = {"n": n}
         if qid != -1:
             extra["qid"] = qid
         log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
-        self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))
+        self._append(b"\x2F", (extra,),
+                     Response(self, "discard", hydration_hooks, **handlers),
+                     dehydration_hooks=dehydration_hooks)

-    def pull(self, n=-1, qid=-1, **handlers):
+    def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None,
+             **handlers):
         extra = {"n": n}
         if qid != -1:
             extra["qid"] = qid
         log.debug("[#%04X] C: PULL %r", self.local_port, extra)
-        self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))
+        self._append(b"\x3F", (extra,),
+                     Response(self, "pull", hydration_hooks, **handlers),
+                     dehydration_hooks=dehydration_hooks)

     def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
-              db=None, imp_user=None, **handlers):
+              db=None, imp_user=None, dehydration_hooks=None,
+              hydration_hooks=None, **handlers):
         extra = {}
         if mode in (READ_ACCESS, "r"):
             # It will default to mode "w" if nothing is specified
@@ -221,17 +232,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
             if extra["tx_timeout"] < 0:
                 raise ValueError("Timeout must be a positive number or 0.")
         log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
-        self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
+        self._append(b"\x11", (extra,),
+                     Response(self, "begin", hydration_hooks, **handlers),
+                     dehydration_hooks=dehydration_hooks)

-    def commit(self, **handlers):
+    def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers):
         log.debug("[#%04X] C: COMMIT", self.local_port)
-        self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))
+        self._append(b"\x12", (),
+                     CommitResponse(self, "commit", hydration_hooks,
+                                    **handlers),
+                     dehydration_hooks=dehydration_hooks)

-    def rollback(self, **handlers):
+    def rollback(self, dehydration_hooks=None, hydration_hooks=None,
+                 **handlers):
         log.debug("[#%04X] C: ROLLBACK", self.local_port)
-        self._append(b"\x13", (), Response(self, "rollback", **handlers))
+        self._append(b"\x13", (),
+                     Response(self, "rollback", hydration_hooks, **handlers),
+                     dehydration_hooks=dehydration_hooks)

-    def reset(self):
+    def reset(self, dehydration_hooks=None, hydration_hooks=None):
         """Reset the connection.

         Add a RESET message to the outgoing queue, send it and consume all
@@ -243,22 +262,33 @@ def fail(metadata):
                                     self.unresolved_address)

         log.debug("[#%04X] C: RESET", self.local_port)
-        self._append(b"\x0F", response=Response(self, "reset",
-                                                on_failure=fail))
+        self._append(b"\x0F",
+                     response=Response(self, "reset", hydration_hooks,
+                                       on_failure=fail),
+                     dehydration_hooks=dehydration_hooks)
         self.send_all()
         self.fetch_all()

-    def goodbye(self):
+    def goodbye(self, dehydration_hooks=None, hydration_hooks=None):
         log.debug("[#%04X] C: GOODBYE", self.local_port)
-        self._append(b"\x02", ())
+        self._append(b"\x02", (), dehydration_hooks=dehydration_hooks)

-    def _process_message(self, details, summary_signature,
-                         summary_metadata):
+    def _process_message(self, tag, fields):
         """Process at most one message from the server, if available.

         :return: 2-tuple of number of detail messages and number of summary
                  messages fetched
         """
+        details = []
+        summary_signature = summary_metadata = None
+        if tag == b"\x71":  # RECORD
+            details = fields
+        elif fields:
+            summary_signature = tag
+            summary_metadata = fields[0]
+        else:
+            summary_signature = tag
+
         if details:
             # Do not log any data
             log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))
diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py
index 26a2b5547..1eea4ba34 100644
--- a/neo4j/_sync/io/_common.py
+++ b/neo4j/_sync/io/_common.py
@@ -17,7 +17,6 @@

 import asyncio
-from contextlib import contextmanager
 import logging
 import socket
 from struct import pack as struct_pack
@@ -30,132 +29,122 @@
     SessionExpired,
     UnsupportedServerProduct,
 )
-from ...packstream import (
-    UnpackableBuffer,
-    Unpacker,
-)


 log = logging.getLogger("neo4j")


-class MessageInbox:
+class Inbox:

-    def __init__(self, s, on_error):
+    def __init__(self, sock, on_error, unpacker_cls):
         self.on_error = on_error
-        self._local_port = s.getsockname()[1]
-        self._messages = self._yield_messages(s)
-
-    def _yield_messages(self, sock):
+        self._local_port = sock.getsockname()[1]
+        self._socket = sock
+        self._buffer = unpacker_cls.new_unpackable_buffer()
+        self._unpacker = unpacker_cls(self._buffer)
+        self._broken = False
+
+    def _buffer_one_chunk(self):
+        assert not self._broken
         try:
-            buffer = UnpackableBuffer()
-            unpacker = Unpacker(buffer)
             chunk_size = 0
             while True:
                 while chunk_size == 0:  # Determine the chunk size and skip noop
-                    receive_into_buffer(sock, buffer, 2)
-                    chunk_size = buffer.pop_u16()
+                    receive_into_buffer(self._socket, self._buffer, 2)
+                    chunk_size = self._buffer.pop_u16()
                     if chunk_size == 0:
                         log.debug("[#%04X] S: <NOOP>", self._local_port)

-                receive_into_buffer(sock, buffer, chunk_size + 2)
-                chunk_size = buffer.pop_u16()
+                receive_into_buffer(
+                    self._socket, self._buffer, chunk_size + 2
+                )
+                chunk_size = self._buffer.pop_u16()

                 if chunk_size == 0:
                     # chunk_size was the end marker for the message
-                    size, tag = unpacker.unpack_structure_header()
-                    fields = [unpacker.unpack() for _ in range(size)]
-                    yield tag, fields
-                    # Reset for new message
-                    unpacker.reset()
+                    return

         except (OSError, socket.timeout, SocketDeadlineExceeded) as error:
+            self._broken = True
             Util.callback(self.on_error, error)
+            raise

-    def pop(self):
-        return Util.next(self._messages)
-
-
-class Inbox(MessageInbox):
-
-    def __next__(self):
-        tag, fields = self.pop()
-        if tag == b"\x71":
-            return fields, None, None
-        elif fields:
-            return [], tag, fields[0]
-        else:
-            return [], tag, None
+    def pop(self, hydration_hooks):
+        self._buffer_one_chunk()
+        try:
+            size, tag = self._unpacker.unpack_structure_header()
+            fields = [self._unpacker.unpack(hydration_hooks)
+                      for _ in range(size)]
+            return tag, fields
+        finally:
+            # Reset for new message
+            self._unpacker.reset()


 class Outbox:

-    def __init__(self, max_chunk_size=16384):
+    def __init__(self, sock, on_error, packer_cls, max_chunk_size=16384):
         self._max_chunk_size = max_chunk_size
         self._chunked_data = bytearray()
-        self._raw_data = bytearray()
-        self.write = self._raw_data.extend
-        self._tmp_buffering = 0
+        self._buffer = packer_cls.new_packable_buffer()
+        self._packer = packer_cls(self._buffer)
+        self.socket = sock
+        self.on_error = on_error

     def max_chunk_size(self):
         return self._max_chunk_size

-    def clear(self):
-        if self._tmp_buffering:
-            raise RuntimeError("Cannot clear while buffering")
+    def _clear(self):
+        assert not self._buffer.is_tmp_buffering()
         self._chunked_data =
bytearray() - self._raw_data.clear() + self._buffer.clear() def _chunk_data(self): - data_len = len(self._raw_data) + data_len = len(self._buffer.data) num_full_chunks, chunk_rest = divmod( data_len, self._max_chunk_size ) num_chunks = num_full_chunks + bool(chunk_rest) - data_view = memoryview(self._raw_data) - header_start = len(self._chunked_data) - data_start = header_start + 2 - raw_data_start = 0 - for i in range(num_chunks): - chunk_size = min(data_len - raw_data_start, - self._max_chunk_size) - self._chunked_data[header_start:data_start] = struct_pack( - ">H", chunk_size - ) - self._chunked_data[data_start:(data_start + chunk_size)] = \ - data_view[raw_data_start:(raw_data_start + chunk_size)] - header_start += chunk_size + 2 + with memoryview(self._buffer.data) as data_view: + header_start = len(self._chunked_data) data_start = header_start + 2 - raw_data_start += chunk_size - del data_view - self._raw_data.clear() - - def wrap_message(self): - if self._tmp_buffering: - raise RuntimeError("Cannot wrap message while buffering") + raw_data_start = 0 + for i in range(num_chunks): + chunk_size = min(data_len - raw_data_start, + self._max_chunk_size) + self._chunked_data[header_start:data_start] = struct_pack( + ">H", chunk_size + ) + self._chunked_data[data_start:(data_start + chunk_size)] = \ + data_view[raw_data_start:(raw_data_start + chunk_size)] + header_start += chunk_size + 2 + data_start = header_start + 2 + raw_data_start += chunk_size + self._buffer.clear() + + def _wrap_message(self): + assert not self._buffer.is_tmp_buffering() self._chunk_data() self._chunked_data += b"\x00\x00" - def view(self): - if self._tmp_buffering: - raise RuntimeError("Cannot view while buffering") - self._chunk_data() - return memoryview(self._chunked_data) + def append_message(self, tag, fields, dehydration_hooks): + with self._buffer.tmp_buffer(): + self._packer.pack_struct(tag, fields, dehydration_hooks) + self._wrap_message() - @contextmanager - def tmp_buffer(self): - self._tmp_buffering += 1 - old_len = len(self._raw_data) - try: - yield - except Exception: - del self._raw_data[old_len:] - raise - finally: - self._tmp_buffering -= 1 + def flush(self): + data = self._chunked_data + if data: + try: + self.socket.sendall(data) + except OSError as error: + self.on_error(error) + return False + self._clear() + return True + return False class ConnectionErrorHandler: @@ -218,8 +207,9 @@ class Response: more detail messages followed by one summary message). 
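+
+    Handler callbacks (e.g. ``on_records``, ``on_success``, ``on_failure``)
+    are passed as keyword arguments and are invoked as the corresponding
+    messages arrive.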
""" - def __init__(self, connection, message, **handlers): + def __init__(self, connection, message, hydration_hooks, **handlers): self.connection = connection + self.hydration_hooks = hydration_hooks self.handlers = handlers self.message = message self.complete = False @@ -294,9 +284,9 @@ def receive_into_buffer(sock, buffer, n_bytes): end = buffer.used + n_bytes if end > len(buffer.data): buffer.data += bytearray(end - len(buffer.data)) - view = memoryview(buffer.data) - while buffer.used < end: - n = sock.recv_into(view[buffer.used:end], end - buffer.used) - if n == 0: - raise OSError("No data") - buffer.used += n + with memoryview(buffer.data) as view: + while buffer.used < end: + n = sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index ca25a49b2..3cd66a6d3 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -17,12 +17,12 @@ import abc +import logging from collections import ( defaultdict, deque, ) from contextlib import contextmanager -import logging from logging import getLogger from random import choice @@ -31,6 +31,10 @@ RLock, ) from ..._async_compat.network import NetworkUtil +from ..._conf import ( + PoolConfig, + WorkspaceConfig, +) from ..._deadline import ( connection_deadline, Deadline, @@ -38,14 +42,11 @@ merge_deadlines_and_timeouts, ) from ..._exceptions import BoltError +from ..._routing import RoutingTable from ...api import ( READ_ACCESS, WRITE_ACCESS, ) -from ...conf import ( - PoolConfig, - WorkspaceConfig, -) from ...exceptions import ( ClientError, ConfigurationError, @@ -56,7 +57,6 @@ SessionExpired, WriteServiceUnavailable, ) -from ...routing import RoutingTable from ._bolt import Bolt diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index 807096556..888fd2701 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -20,15 +20,15 @@ from warnings import warn from ..._async_compat.util import Util -from ...data import ( - DataDehydrator, +from ..._data import ( + Record, RecordTableRowExporter, ) +from ..._meta import experimental from ...exceptions import ( ResultConsumedError, ResultNotSingleError, ) -from ...meta import experimental from ...time import ( Date, DateTime, @@ -54,10 +54,9 @@ class Result: :meth:`.AyncSession.run` and :meth:`.Transaction.run`. 
""" - def __init__(self, connection, hydrant, fetch_size, on_closed, - on_error): + def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = ConnectionErrorHandler(connection, on_error) - self._hydrant = hydrant + self._hydration_scope = connection.new_hydration_scope() self._on_closed = on_closed self._metadata = None self._keys = None @@ -104,7 +103,7 @@ def _run( query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + parameters = dict(parameters or {}, **kwargs) self._metadata = { "query": query_text, @@ -135,6 +134,7 @@ def on_failed_attach(metadata): timeout=query_timeout, db=db, imp_user=imp_user, + dehydration_hooks=self._hydration_scope.dehydration_hooks, on_success=on_attached, on_failure=on_failed_attach, ) @@ -145,7 +145,10 @@ def on_failed_attach(metadata): def _pull(self): def on_records(records): if not self._discarding: - self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + self._record_buffer.extend(( + Record(zip(self._keys, record)) + for record in records + )) def on_summary(): self._attached = False @@ -167,6 +170,7 @@ def on_success(summary_metadata): self._connection.pull( n=self._fetch_size, qid=self._qid, + hydration_hooks=self._hydration_scope.hydration_hooks, on_records=on_records, on_success=on_success, on_failure=on_failure, @@ -479,7 +483,7 @@ def graph(self): Can raise :exc:`ResultConsumedError`. """ self._buffer_all() - return self._hydrant.graph + return self._hydration_scope.get_graph() def value(self, key=0, default=None): """Helper function that return the remainder of the result as a list of values. diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index c3300b27e..72a7d4308 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -21,13 +21,16 @@ from time import perf_counter from ..._async_compat import sleep +from ..._conf import SessionConfig +from ..._meta import ( + deprecated, + deprecation_warn, +) from ...api import ( Bookmarks, READ_ACCESS, WRITE_ACCESS, ) -from ...conf import SessionConfig -from ...data import DataHydrator from ...exceptions import ( ClientError, DriverError, @@ -36,10 +39,6 @@ SessionExpired, TransactionError, ) -from ...meta import ( - deprecated, - deprecation_warn, -) from ...work import Query from .result import Result from .transaction import ( @@ -228,10 +227,8 @@ def run(self, query, parameters=None, **kwargs): protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info - hydrant = DataHydrator() - self._auto_result = Result( - cx, hydrant, self._config.fetch_size, self._result_closed, + cx, self._config.fetch_size, self._result_closed, self._result_error ) self._auto_result._run( diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index 9b6b29c4b..95dd80332 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -19,7 +19,6 @@ from functools import wraps from ..._async_compat.util import Util -from ...data import DataHydrator from ...exceptions import TransactionError from ...work import Query from ..io import ConnectionErrorHandler @@ -123,8 +122,7 @@ def run(self, query, parameters=None, **kwparameters): self._results[-1]._buffer_all() result = Result( - self._connection, DataHydrator(), self._fetch_size, - self._result_on_closed_handler, + self._connection, self._fetch_size, self._result_on_closed_handler, self._error_handler ) 
self._results.append(result)
diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py
index a177b097c..c10fc912e 100644
--- a/neo4j/_sync/work/workspace.py
+++ b/neo4j/_sync/work/workspace.py
@@ -18,16 +18,16 @@
import asyncio
+from ..._conf import WorkspaceConfig
from ..._deadline import Deadline
-from ...conf import WorkspaceConfig
+from ..._meta import (
+ deprecation_warn,
+ unclosed_resource_warn,
+)
from ...exceptions import (
ServiceUnavailable,
SessionExpired,
)
-from ...meta import (
- deprecation_warn,
- unclosed_resource_warn,
-)
from ..io import Neo4jPool
diff --git a/neo4j/api.py b/neo4j/api.py
index 58292ad83..7930d1d4e 100644
--- a/neo4j/api.py
+++ b/neo4j/api.py
@@ -24,8 +24,8 @@
urlparse,
)
+from ._meta import deprecated
from .exceptions import ConfigurationError
-from .meta import deprecated
READ_ACCESS = "READ"
@@ -165,6 +165,7 @@ def custom_auth(principal, credentials, realm, scheme, **parameters):
return Auth(scheme, principal, credentials, realm, **parameters)
+# TODO 6.0 - remove this class
class Bookmark:
"""A Bookmark object contains an immutable list of bookmark string values.
@@ -271,6 +272,10 @@ def from_raw_values(cls, values):
if not isinstance(value, str):
raise TypeError("Raw bookmark values must be str. "
"Found {}".format(type(value)))
+ try:
+ value.encode("ascii")
+ except UnicodeEncodeError as e:
+ raise ValueError(f"The value {value} is not ASCII") from e
bookmarks.append(value)
obj._raw_values = frozenset(bookmarks)
return obj
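The `from_raw_values` hunk above now rejects non-ASCII bookmark values up front instead of letting them fail deeper in the protocol layer. A short sketch of the resulting behaviour (names exactly as in this diff):

    from neo4j import Bookmarks

    Bookmarks.from_raw_values(["bookmark-1"])   # accepted
    Bookmarks.from_raw_values(["bökmerki"])     # ValueError: ... is not ASCII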
- """ - if hasattr(iterable, "keys"): - for key in iterable.keys(): - yield key, iterable[key] - else: - for key, value in iterable: - yield key, value - - -class DeprecatedAlias: - """Used when a config option has been renamed.""" - - def __init__(self, new): - self.new = new - - -class DeprecatedAlternative: - """Used for deprecated config options that have a similar alternative.""" - - def __init__(self, new, converter=None): - self.new = new - self.converter = converter - - -class ConfigType(ABCMeta): - - def __new__(mcs, name, bases, attributes): - fields = [] - deprecated_aliases = {} - deprecated_alternatives = {} - - for base in bases: - if type(base) is mcs: - fields += base.keys() - deprecated_aliases.update(base._deprecated_aliases()) - deprecated_alternatives.update(base._deprecated_alternatives()) - - for k, v in attributes.items(): - if isinstance(v, DeprecatedAlias): - deprecated_aliases[k] = v.new - elif isinstance(v, DeprecatedAlternative): - deprecated_alternatives[k] = v.new, v.converter - elif not (k.startswith("_") - or callable(v) - or isinstance(v, (staticmethod, classmethod))): - fields.append(k) - - def keys(_): - return set(fields) - - def _deprecated_keys(_): - return (set(deprecated_aliases.keys()) - | set(deprecated_alternatives.keys())) - - def _get_new(_, key): - return deprecated_aliases.get( - key, deprecated_alternatives.get(key, (None,))[0] - ) - - def _deprecated_aliases(_): - return deprecated_aliases - - def _deprecated_alternatives(_): - return deprecated_alternatives - - attributes.setdefault("keys", classmethod(keys)) - attributes.setdefault("_get_new", - classmethod(_get_new)) - attributes.setdefault("_deprecated_keys", - classmethod(_deprecated_keys)) - attributes.setdefault("_deprecated_aliases", - classmethod(_deprecated_aliases)) - attributes.setdefault("_deprecated_alternatives", - classmethod(_deprecated_alternatives)) - - return super(ConfigType, mcs).__new__( - mcs, name, bases, {k: v for k, v in attributes.items() - if k not in _deprecated_keys(None)} - ) - - -class Config(Mapping, metaclass=ConfigType): - """ Base class for all configuration containers. 
- """ - - @staticmethod - def consume_chain(data, *config_classes): - values = [] - for config_class in config_classes: - if not issubclass(config_class, Config): - raise TypeError("%r is not a Config subclass" % config_class) - values.append(config_class._consume(data)) - if data: - raise ConfigurationError("Unexpected config keys: %s" % ", ".join(data.keys())) - return values - - @classmethod - def consume(cls, data): - config, = cls.consume_chain(data, cls) - return config - - @classmethod - def _consume(cls, data): - config = {} - if data: - for key in cls.keys() | cls._deprecated_keys(): - try: - value = data.pop(key) - except KeyError: - pass - else: - config[key] = value - return cls(config) - - def __update(self, data): - data_dict = dict(iter_items(data)) - - def set_attr(k, v): - if k in self.keys(): - setattr(self, k, v) - elif k in self._deprecated_keys(): - k0 = self._get_new(k) - if k0 in data_dict: - raise ConfigurationError( - "Cannot specify both '{}' and '{}' in config" - .format(k0, k) - ) - deprecation_warn( - "The '{}' config key is deprecated, please use '{}' " - "instead".format(k, k0) - ) - if k in self._deprecated_aliases(): - set_attr(k0, v) - else: # k in self._deprecated_alternatives: - _, converter = self._deprecated_alternatives()[k] - converter(self, v) - else: - raise AttributeError(k) - - for key, value in data_dict.items(): - if value is not None: - set_attr(key, value) - - def __init__(self, *args, **kwargs): - for arg in args: - self.__update(arg) - self.__update(kwargs) - - def __repr__(self): - attrs = [] - for key in self: - attrs.append(" %s=%r" % (key, getattr(self, key))) - return "<%s%s>" % (self.__class__.__name__, "".join(attrs)) - - def __len__(self): - return len(self.keys()) - - def __getitem__(self, key): - return getattr(self, key) - - def __iter__(self): - return iter(self.keys()) - - -def _trust_to_trusted_certificates(pool_config, trust): - if trust == TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: - pool_config.trusted_certificates = TrustSystemCAs() - elif trust == TRUST_ALL_CERTIFICATES: - pool_config.trusted_certificates = TrustAll() - - -class PoolConfig(Config): - """ Connection pool configuration. - """ - - #: Max Connection Lifetime - max_connection_lifetime = 3600 # seconds - # The maximum duration the driver will keep a connection for before being removed from the pool. - - #: Max Connection Pool Size - max_connection_pool_size = 100 - # The maximum total number of connections allowed, per host (i.e. cluster nodes), to be managed by the connection pool. - - #: Connection Timeout - connection_timeout = 30.0 # seconds - # The maximum amount of time to wait for a TCP connection to be established. - - #: Update Routing Table Timout - update_routing_table_timeout = 90.0 # seconds - # The maximum amount of time to wait for updating the routing table. - # This includes everything necessary for this to happen. - # Including opening sockets, requesting and receiving the routing table, - # etc. - - #: Trust - trust = DeprecatedAlternative( - "trusted_certificates", _trust_to_trusted_certificates - ) - # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. - - #: Custom Resolver - resolver = None - # Custom resolver function, returning list of resolved addresses. - - #: Encrypted - encrypted = False - # Specify whether to use an encrypted connection between the driver and server. 
-
- #: Custom Resolver
- resolver = None
- # Custom resolver function, returning list of resolved addresses.
-
- #: Encrypted
- encrypted = False
- # Specify whether to use an encrypted connection between the driver and server.
-
- #: SSL Certificates to Trust
- trusted_certificates = TrustSystemCAs()
- # Specify how to determine the authenticity of encryption certificates
- # provided by the Neo4j instance on connection.
- # * `neo4j.TrustSystemCAs()`: Use system trust store. (default)
- # * `neo4j.TrustAll()`: Trust any certificate.
- # * `neo4j.TrustCustomCAs("", ...)`:
- # Trust the specified certificate(s).
-
- #: Custom SSL context to use for wrapping sockets
- ssl_context = None
- # Use any custom SSL context to wrap sockets.
- # Overwrites `trusted_certificates` and `encrypted`.
- # The use of this option is strongly discouraged.
-
- #: User Agent (Python Driver Specific)
- user_agent = get_user_agent()
- # Specify the client agent name.
-
- #: Protocol Version (Python Driver Specific)
- protocol_version = None # Version(4, 0)
- # Specify a specific Bolt Protocol Version
-
- #: Initial Connection Pool Size (Python Driver Specific)
- init_size = 1 # The other drivers do not seed from the start.
- # This will seed the pool with the specified number of connections.
-
- #: Socket Keep Alive (Python and .NET Driver Specific)
- keep_alive = True
- # Specify whether TCP keep-alive should be enabled.
-
- def get_ssl_context(self):
- if self.ssl_context is not None:
- return self.ssl_context
-
- if not self.encrypted:
- return None
-
- import ssl
-
- # SSL stands for Secure Sockets Layer and was originally created by Netscape.
- # SSLv2 and SSLv3 are the 2 versions of this protocol (SSLv1 was never publicly released).
- # After SSLv3, SSL was renamed to TLS.
- # TLS stands for Transport Layer Security and started with TLSv1.0 which is an upgraded version of SSLv3.
- # SSLv2 - (Disabled)
- # SSLv3 - (Disabled)
- # TLS 1.0 - Released in 1999, published as RFC 2246. (Disabled)
- # TLS 1.1 - Released in 2006, published as RFC 4346. (Disabled)
- # TLS 1.2 - Released in 2008, published as RFC 5246.
- # https://docs.python.org/3.7/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT
- ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-
- # For recommended security options see
- # https://docs.python.org/3.7/library/ssl.html#protocol-versions
- ssl_context.options |= ssl.OP_NO_TLSv1 # Python 3.2
- ssl_context.options |= ssl.OP_NO_TLSv1_1 # Python 3.4
-
- if isinstance(self.trusted_certificates, TrustAll):
- # trust any certificate
- ssl_context.check_hostname = False
- # https://docs.python.org/3.7/library/ssl.html#ssl.CERT_NONE
- ssl_context.verify_mode = ssl.CERT_NONE
- elif isinstance(self.trusted_certificates, TrustCustomCAs):
- # trust the specified certificate(s)
- ssl_context.check_hostname = True
- ssl_context.verify_mode = ssl.CERT_REQUIRED
- for cert in self.trusted_certificates.certs:
- ssl_context.load_verify_locations(cert)
- else:
- # default
- # trust system CA certificates
- ssl_context.check_hostname = True
- ssl_context.verify_mode = ssl.CERT_REQUIRED
- # Must be load_default_certs, not set_default_verify_paths to
- # work on Windows with system CAs.
- ssl_context.load_default_certs()
-
- return ssl_context
-
-
-class WorkspaceConfig(Config):
- """ WorkSpace configuration.
- """
-
- #: Session Connection Timeout
- session_connection_timeout = 120.0 # seconds
- # The maximum amount of time to wait for a session to obtain a usable
- # read/write connection. This includes everything necessary for this to
- # happen. Including fetching routing tables, opening sockets, etc.
-
- #: Connection Acquisition Timeout
- connection_acquisition_timeout = 60.0 # seconds
- # The maximum amount of time a session will wait when requesting a connection from the connection pool.
- # Since the process of acquiring a connection may involve creating a new connection, ensure that the value
- # of this configuration is higher than the configured Connection Timeout.
-
- #: Max Transaction Retry Time
- max_transaction_retry_time = 30.0 # seconds
- # The maximum amount of time that a managed transaction will retry before failing.
-
- #: Initial Retry Delay
- initial_retry_delay = 1.0 # seconds
-
- #: Retry Delay Multiplier
- retry_delay_multiplier = 2.0 # seconds
-
- #: Retry Delay Jitter Factor
- retry_delay_jitter_factor = 0.2 # seconds
-
- #: Database Name
- database = DEFAULT_DATABASE
- # Name of the database to query.
- # Note: The default database can be set on the Neo4j instance settings.
-
- #: Fetch Size
- fetch_size = 1000
-
- #: User to impersonate
- impersonated_user = None
- # Note that you need appropriate permissions to do so.
-
-
-class SessionConfig(WorkspaceConfig):
- """ Session configuration.
- """
-
- #: Bookmarks
- bookmarks = None
-
- #: Default AccessMode
- default_access_mode = WRITE_ACCESS
-
-
-class TransactionConfig(Config):
- """ Transaction configuration. This is internal for now.
-
- neo4j.session.begin_transaction
- neo4j.Query
- neo4j.unit_of_work
-
- are both using the same settings.
- """
- #: Metadata
- metadata = None # dictionary
-
- #: Timeout
- timeout = None # seconds
-
-
-class RoutingConfig(Config):
- """ Neo4jDriver routing settings. This is internal for now.
- """
-
- #: Routing Table Purge_Delay
- routing_table_purge_delay = 30.0 # seconds
- # The TTL + routing_table_purge_delay should be used to check if the database routing table should be removed.
-
- #: Max Routing Failures
- # max_routing_failures = 1
-
- #: Retry Timeout Delay
- # retry_timeout_delay = 5.0 # seconds
diff --git a/neo4j/data.py b/neo4j/data.py
index 7ce3e712e..0713ed460 100644
--- a/neo4j/data.py
+++ b/neo4j/data.py
@@ -15,459 +15,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# TODO: 6.0 - remove this file
-from abc import (
- ABCMeta,
- abstractmethod,
-)
-from collections.abc import (
- Mapping,
- Sequence,
- Set,
-)
-from datetime import (
- date,
- datetime,
- time,
- timedelta,
-)
-from functools import reduce
-from operator import xor as xor_operator
-from .conf import iter_items
-from .graph import (
- Graph,
- Node,
- Path,
- Relationship,
-)
-from .packstream import (
- INT64_MAX,
- INT64_MIN,
- Structure,
-)
-from .spatial import (
- dehydrate_point,
- hydrate_point,
- Point,
-)
-from .time import (
- Date,
- DateTime,
- Duration,
- Time,
-)
-from .time.hydration import (
- dehydrate_date,
- dehydrate_datetime,
- dehydrate_duration,
- dehydrate_time,
- dehydrate_timedelta,
- hydrate_date,
- hydrate_datetime,
- hydrate_duration,
- hydrate_time,
+from ._data import (
+ DataTransformer,
+ Record,
+ RecordExporter,
+ RecordTableRowExporter,
)
+from ._meta import deprecation_warn
map_type = type(map(str, range(0)))
+__all__ = [
+ "map_type",
+ "Record",
+ "DataTransformer",
+ "RecordExporter",
+ "RecordTableRowExporter",
+]
-
-class Record(tuple, Mapping):
- """ A :class:`.Record` is an immutable ordered collection of key-value
- pairs. It is generally closer to a :py:class:`namedtuple` than to a
- :py:class:`OrderedDict` in as much as iteration of the collection will
- yield values rather than keys.
- """
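Since the namedtuple-versus-OrderedDict distinction in this docstring is the part that trips people up, here is the (unchanged) observable behaviour, with `Record` imported directly from `neo4j` as the shim below recommends:

    from neo4j import Record

    r = Record({"name": "Alice", "age": 33})
    list(r)       # ['Alice', 33] -- iterating yields values, like a namedtuple
    r["name"]     # 'Alice'       -- mapping-style access still works
    r[1]          # 33            -- as does sequence-style access
    r.data()      # {'name': 'Alice', 'age': 33}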
- """ - - __keys = None - - def __new__(cls, iterable=()): - keys = [] - values = [] - for key, value in iter_items(iterable): - keys.append(key) - values.append(value) - inst = tuple.__new__(cls, values) - inst.__keys = tuple(keys) - return inst - - def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, - " ".join("%s=%r" % (field, self[i]) for i, field in enumerate(self.__keys))) - - def __eq__(self, other): - """ In order to be flexible regarding comparison, the equality rules - for a record permit comparison with any other Sequence or Mapping. - - :param other: - :return: - """ - compare_as_sequence = isinstance(other, Sequence) - compare_as_mapping = isinstance(other, Mapping) - if compare_as_sequence and compare_as_mapping: - return list(self) == list(other) and dict(self) == dict(other) - elif compare_as_sequence: - return list(self) == list(other) - elif compare_as_mapping: - return dict(self) == dict(other) - else: - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return reduce(xor_operator, map(hash, self.items())) - - def __getitem__(self, key): - if isinstance(key, slice): - keys = self.__keys[key] - values = super(Record, self).__getitem__(key) - return self.__class__(zip(keys, values)) - try: - index = self.index(key) - except IndexError: - return None - else: - return super(Record, self).__getitem__(index) - - def __getslice__(self, start, stop): - key = slice(start, stop) - keys = self.__keys[key] - values = tuple(self)[key] - return self.__class__(zip(keys, values)) - - def get(self, key, default=None): - """ Obtain a value from the record by key, returning a default - value if the key does not exist. - - :param key: a key - :param default: default value - :return: a value - """ - try: - index = self.__keys.index(str(key)) - except ValueError: - return default - if 0 <= index < len(self): - return super(Record, self).__getitem__(index) - else: - return default - - def index(self, key): - """ Return the index of the given item. - - :param key: a key - :return: index - :rtype: int - """ - if isinstance(key, int): - if 0 <= key < len(self.__keys): - return key - raise IndexError(key) - elif isinstance(key, str): - try: - return self.__keys.index(key) - except ValueError: - raise KeyError(key) - else: - raise TypeError(key) - - def value(self, key=0, default=None): - """ Obtain a single value from the record by index or key. If no - index or key is specified, the first value is returned. If the - specified item does not exist, the default value is returned. - - :param key: an index or key - :param default: default value - :return: a single value - """ - try: - index = self.index(key) - except (IndexError, KeyError): - return default - else: - return self[index] - - def keys(self): - """ Return the keys of the record. - - :return: list of key names - """ - return list(self.__keys) - - def values(self, *keys): - """ Return the values of the record, optionally filtering to - include only certain values by index or key. 
-
- :param keys: indexes or keys of the items to include; if none
- are provided, all values will be included
- :return: list of values
- :rtype: list
- """
- if keys:
- d = []
- for key in keys:
- try:
- i = self.index(key)
- except KeyError:
- d.append(None)
- else:
- d.append(self[i])
- return d
- return list(self)
-
- def items(self, *keys):
- """ Return the fields of the record as a list of key and value tuples
-
- :return: a list of value tuples
- :rtype: list
- """
- if keys:
- d = []
- for key in keys:
- try:
- i = self.index(key)
- except KeyError:
- d.append((key, None))
- else:
- d.append((self.__keys[i], self[i]))
- return d
- return list((self.__keys[i], super(Record, self).__getitem__(i)) for i in range(len(self)))
-
- def data(self, *keys):
- """ Return the keys and values of this record as a dictionary,
- optionally including only certain values by index or key. Keys
- provided in the items that are not in the record will be
- inserted with a value of :const:`None`; indexes provided
- that are out of bounds will trigger an :exc:`IndexError`.
-
- :param keys: indexes or keys of the items to include; if none
- are provided, all values will be included
- :return: dictionary of values, keyed by field name
- :raises: :exc:`IndexError` if an out-of-bounds index is specified
- """
- return RecordExporter().transform(dict(self.items(*keys)))
-
-
-class DataTransformer(metaclass=ABCMeta):
- """ Abstract base class for transforming data from one form into
- another.
- """
-
- @abstractmethod
- def transform(self, x):
- """ Transform a value, or collection of values.
-
- :param x: input value
- :return: output value
- """
-
-
-class RecordExporter(DataTransformer):
- """ Transformer class used by the :meth:`.Record.data` method.
- """
-
- def transform(self, x):
- if isinstance(x, Node):
- return self.transform(dict(x))
- elif isinstance(x, Relationship):
- return (self.transform(dict(x.start_node)),
- x.__class__.__name__,
- self.transform(dict(x.end_node)))
- elif isinstance(x, Path):
- path = [self.transform(x.start_node)]
- for i, relationship in enumerate(x.relationships):
- path.append(self.transform(relationship.__class__.__name__))
- path.append(self.transform(x.nodes[i + 1]))
- return path
- elif isinstance(x, str):
- return x
- elif isinstance(x, Sequence):
- t = type(x)
- return t(map(self.transform, x))
- elif isinstance(x, Set):
- t = type(x)
- return t(map(self.transform, x))
- elif isinstance(x, Mapping):
- t = type(x)
- return t((k, self.transform(v)) for k, v in x.items())
- else:
- return x
-
-
-class RecordTableRowExporter(DataTransformer):
- """Transformer class used by the :meth:`.Result.to_df` method."""
-
- def transform(self, x):
- assert isinstance(x, Mapping)
- t = type(x)
- return t(item
- for k, v in x.items()
- for item in self._transform(
- v, prefix=k.replace("\\", "\\\\").replace(".", "\\.")
- ).items())
-
- def _transform(self, x, prefix):
- if isinstance(x, Node):
- res = {
- "%s().element_id" % prefix: x.element_id,
- "%s().labels" % prefix: x.labels,
- }
- res.update(("%s().prop.%s" % (prefix, k), v) for k, v in x.items())
- return res
- elif isinstance(x, Relationship):
- res = {
- "%s->.element_id" % prefix: x.element_id,
- "%s->.start.element_id" % prefix: x.start_node.element_id,
- "%s->.end.element_id" % prefix: x.end_node.element_id,
- "%s->.type" % prefix: x.__class__.__name__,
- }
- res.update(("%s->.prop.%s" % (prefix, k), v) for k, v in x.items())
- return res
- elif isinstance(x, Path) or isinstance(x, str):
- return {prefix: x}
- elif isinstance(x, Sequence):
- return dict(
- item
- for i, v in enumerate(x)
- for item in self._transform(
- v, prefix="%s[].%i" % (prefix, i)
- ).items()
- )
- elif isinstance(x, Mapping):
- t = type(x)
- return t(
- item
- for k, v in x.items()
- for item in self._transform(
- v, prefix="%s{}.%s" % (prefix, k.replace("\\", "\\\\")
- .replace(".", "\\."))
- ).items()
- )
- else:
- return {prefix: x}
-
-
-class DataHydrator:
- # TODO: extend DataTransformer
-
- def __init__(self):
- super(DataHydrator, self).__init__()
- self.graph = Graph()
- self.graph_hydrator = Graph.Hydrator(self.graph)
- self.hydration_functions = {
- b"N": self.graph_hydrator.hydrate_node,
- b"R": self.graph_hydrator.hydrate_relationship,
- b"r": self.graph_hydrator.hydrate_unbound_relationship,
- b"P": self.graph_hydrator.hydrate_path,
- b"X": hydrate_point,
- b"Y": hydrate_point,
- b"D": hydrate_date,
- b"T": hydrate_time, # time zone offset
- b"t": hydrate_time, # no time zone
- b"F": hydrate_datetime, # time zone offset
- b"f": hydrate_datetime, # time zone name
- b"d": hydrate_datetime, # no time zone
- b"E": hydrate_duration,
- }
-
- def hydrate(self, values):
- """ Convert PackStream values into native values.
- """
-
- def hydrate_(obj):
- if isinstance(obj, Structure):
- try:
- f = self.hydration_functions[obj.tag]
- except KeyError:
- # If we don't recognise the structure
- # type, just return it as-is
- return obj
- else:
- return f(*map(hydrate_, obj.fields))
- elif isinstance(obj, list):
- return list(map(hydrate_, obj))
- elif isinstance(obj, dict):
- return {key: hydrate_(value) for key, value in obj.items()}
- else:
- return obj
-
- return tuple(map(hydrate_, values))
-
- def hydrate_records(self, keys, record_values):
- for values in record_values:
- yield Record(zip(keys, self.hydrate(values)))
-
-
-class DataDehydrator:
- # TODO: extend DataTransformer
-
- @classmethod
- def fix_parameters(cls, parameters):
- if not parameters:
- return {}
- dehydrator = cls()
- try:
- dehydrated, = dehydrator.dehydrate([parameters])
- except TypeError as error:
- value = error.args[0]
- raise TypeError("Parameters of type {} are not supported".format(type(value).__name__))
- else:
- return dehydrated
-
- def __init__(self):
- self.dehydration_functions = {}
- self.dehydration_functions.update({
- Point: dehydrate_point,
- Date: dehydrate_date,
- date: dehydrate_date,
- Time: dehydrate_time,
- time: dehydrate_time,
- DateTime: dehydrate_datetime,
- datetime: dehydrate_datetime,
- Duration: dehydrate_duration,
- timedelta: dehydrate_timedelta,
- })
- # Allow dehydration from any direct Point subclass
- self.dehydration_functions.update({cls: dehydrate_point for cls in Point.__subclasses__()})
-
- def dehydrate(self, values):
- """ Convert native values into PackStream values.
- """ - - def dehydrate_(obj): - try: - f = self.dehydration_functions[type(obj)] - except KeyError: - pass - else: - return f(obj) - if obj is None: - return None - elif isinstance(obj, bool): - return obj - elif isinstance(obj, int): - if INT64_MIN <= obj <= INT64_MAX: - return obj - raise ValueError("Integer out of bounds (64-bit signed " - "integer values only)") - elif isinstance(obj, float): - return obj - elif isinstance(obj, str): - return obj - elif isinstance(obj, (bytes, bytearray)): - # order is important here - bytes must be checked after str - return obj - elif isinstance(obj, (list, map_type)): - return list(map(dehydrate_, obj)) - elif isinstance(obj, dict): - if any(not isinstance(key, str) for key in obj.keys()): - raise TypeError("Non-string dictionary keys are " - "not supported") - return {key: dehydrate_(value) for key, value in obj.items()} - else: - raise TypeError(obj) - - return tuple(map(dehydrate_, values)) +deprecation_warn( + "The module 'neo4j.data' was made internal and will " + "no longer be available for import in future versions. " + "`neo4j.data.Record` should be imported directly from `neo4j`.", + stack_level=2 +) diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index a614a2d2d..0939c54c5 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -31,7 +31,7 @@ from collections.abc import Mapping -from ..meta import ( +from .._meta import ( deprecated, deprecation_warn, ) @@ -74,92 +74,6 @@ def relationship_type(self, name): cls = self._relationship_types[name] = type(str(name), (Relationship,), {}) return cls - class Hydrator: - - def __init__(self, graph): - self.graph = graph - - def hydrate_node(self, id_, labels=None, - properties=None, element_id=None): - assert isinstance(self.graph, Graph) - # backwards compatibility with Neo4j < 5.0 - if element_id is None: - element_id = str(id_) - - try: - inst = self.graph._nodes[element_id] - except KeyError: - inst = Node(self.graph, element_id, id_, labels, properties) - self.graph._nodes[element_id] = inst - self.graph._legacy_nodes[id_] = inst - else: - # If we have already hydrated this node as the endpoint of - # a relationship, it won't have any labels or properties. - # Therefore, we need to add the ones we have here. 
- if labels:
- inst._labels = inst._labels.union(labels) # frozen_set
- if properties:
- inst._properties.update(properties)
- return inst
-
- def hydrate_relationship(self, id_, n0_id, n1_id, type_,
- properties=None, element_id=None,
- n0_element_id=None, n1_element_id=None):
- # backwards compatibility with Neo4j < 5.0
- if element_id is None:
- element_id = str(id_)
- if n0_element_id is None:
- n0_element_id = str(n0_id)
- if n1_element_id is None:
- n1_element_id = str(n1_id)
-
- inst = self.hydrate_unbound_relationship(id_, type_, properties,
- element_id)
- inst._start_node = self.hydrate_node(n0_id,
- element_id=n0_element_id)
- inst._end_node = self.hydrate_node(n1_id, element_id=n1_element_id)
- return inst
-
- def hydrate_unbound_relationship(self, id_, type_, properties=None,
- element_id=None):
- assert isinstance(self.graph, Graph)
- # backwards compatibility with Neo4j < 5.0
- if element_id is None:
- element_id = str(id_)
-
- try:
- inst = self.graph._relationships[element_id]
- except KeyError:
- r = self.graph.relationship_type(type_)
- inst = r(
- self.graph, element_id, id_, properties
- )
- self.graph._relationships[element_id] = inst
- self.graph._legacy_relationships[id_] = inst
- return inst
-
- def hydrate_path(self, nodes, relationships, sequence):
- assert isinstance(self.graph, Graph)
- assert len(nodes) >= 1
- assert len(sequence) % 2 == 0
- last_node = nodes[0]
- entities = [last_node]
- for i, rel_index in enumerate(sequence[::2]):
- assert rel_index != 0
- next_node = nodes[sequence[2 * i + 1]]
- if rel_index > 0:
- r = relationships[rel_index - 1]
- r._start_node = last_node
- r._end_node = next_node
- entities.append(r)
- else:
- r = relationships[-rel_index - 1]
- r._start_node = next_node
- r._end_node = last_node
- entities.append(r)
- last_node = next_node
- return Path(*entities)
-
class Entity(Mapping):
""" Base class for :class:`.Node` and :class:`.Relationship` that
diff --git a/neo4j/meta.py b/neo4j/meta.py
index 644c3536b..05b1be0cc 100644
--- a/neo4j/meta.py
+++ b/neo4j/meta.py
@@ -15,110 +15,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import asyncio
-from functools import wraps
-from warnings import warn
-
-
-# Can be automatically overridden in builds
-package = "neo4j"
-version = "5.0.dev0"
-
-
-def get_user_agent():
- """ Obtain the default user agent string sent to the server after
- a successful handshake.
- """
- from sys import (
- platform,
- version_info,
- )
- template = "neo4j-python/{} Python/{}.{}.{}-{}-{} ({})"
- fields = (version,) + tuple(version_info) + (platform,)
- return template.format(*fields)
-
-
-def deprecation_warn(message, stack_level=2):
- warn(message, category=DeprecationWarning, stacklevel=stack_level)
-
-
-def deprecated(message):
- """ Decorator for deprecating functions and methods.
-
- ::
-
- @deprecated("'foo' has been deprecated in favour of 'bar'")
- def foo(x):
- pass
-
- """
- def decorator(f):
- if asyncio.iscoroutinefunction(f):
- @wraps(f)
- async def inner(*args, **kwargs):
- deprecation_warn(message, stack_level=3)
- return await f(*args, **kwargs)
-
- return inner
- else:
- @wraps(f)
- def inner(*args, **kwargs):
- deprecation_warn(message, stack_level=3)
- return f(*args, **kwargs)
-
- return inner
-
- return decorator
-
-
-class ExperimentalWarning(Warning):
- """ Base class for warnings about experimental features.
- """
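`ExperimentalWarning` survives the move: per the shim below, it is re-exported from `neo4j` directly. Code that knowingly opts in to experimental APIs can silence it the usual `warnings` way:

    import warnings
    from neo4j import ExperimentalWarning

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ExperimentalWarning)
        ...  # call experimental driver APIs here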
- """ - - -def experimental_warn(message, stack_level=2): - warn(message, category=ExperimentalWarning, stacklevel=stack_level) - - -def experimental(message): - """ Decorator for tagging experimental functions and methods. - - :: - - @experimental("'foo' is an experimental function and may be " - "removed in a future release") - def foo(x): - pass - - """ - def decorator(f): - if asyncio.iscoroutinefunction(f): - @wraps(f) - async def inner(*args, **kwargs): - experimental_warn(message, stack_level=3) - return await f(*args, **kwargs) - - return inner - else: - @wraps(f) - def inner(*args, **kwargs): - experimental_warn(message, stack_level=3) - return f(*args, **kwargs) - - return inner - - return decorator - - -def unclosed_resource_warn(obj): - import tracemalloc - from warnings import warn - msg = f"Unclosed {obj!r}." - trace = tracemalloc.get_object_traceback(obj) - if trace: - msg += "\nObject allocated at (most recent call last):\n" - msg += "\n".join(trace.format()) - else: - msg += "\nEnable tracemalloc to get the object allocation traceback." - warn(msg, ResourceWarning, stacklevel=2, source=obj) +# TODO: 6.0 - remove this file + + +from ._meta import ( + deprecated, + deprecation_warn, + experimental, + ExperimentalWarning, + get_user_agent, + package, + version, +) + + +__all__ = [ + "package", + "version", + "get_user_agent", + "deprecation_warn", + "deprecated", + "ExperimentalWarning", + "experimental", +] + +deprecation_warn( + "The module 'neo4j.meta' was made internal and will " + "no longer be available for import in future versions." + "`ExperimentalWarning` can be imported from `neo4j` directly and " + "`neo4j.meta.version` is exposed as `neo4j.__version__`.", + stack_level=2 +) diff --git a/neo4j/packstream.py b/neo4j/packstream.py index 92bf6b96b..041b644f7 100644 --- a/neo4j/packstream.py +++ b/neo4j/packstream.py @@ -16,430 +16,43 @@ # limitations under the License. 
-from codecs import decode
-from struct import (
- pack as struct_pack,
- unpack as struct_unpack,
+# TODO: 6.0 - remove this file
+
+
+from ._codec.packstream.v1 import (
+ INT64_MAX,
+ INT64_MIN,
+ PACKED_UINT_8,
+ PACKED_UINT_16,
+ Packer,
+ Structure,
+ UnpackableBuffer,
+ UNPACKED_MARKERS,
+ UNPACKED_UINT_8,
+ UNPACKED_UINT_16,
+ Unpacker,
+)
+from ._meta import deprecation_warn
+
+
+__all__ = [
+ "PACKED_UINT_8",
+ "PACKED_UINT_16",
+ "UNPACKED_UINT_8",
+ "UNPACKED_UINT_16",
+ "UNPACKED_MARKERS",
+ "INT64_MIN",
+ "INT64_MAX",
+ "Structure",
+ "Packer",
+ "Unpacker",
+ "UnpackableBuffer",
+]
+
+deprecation_warn(
+ "The module 'neo4j.packstream' was made internal and will "
+ "no longer be available for import in future versions.",
+ stack_level=2
)
-
-
-PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)]
-PACKED_UINT_16 = [struct_pack(">H", value) for value in range(0x10000)]
-
-UNPACKED_UINT_8 = {bytes(bytearray([x])): x for x in range(0x100)}
-UNPACKED_UINT_16 = {struct_pack(">H", x): x for x in range(0x10000)}
-
-UNPACKED_MARKERS = {b"\xC0": None, b"\xC2": False, b"\xC3": True}
-UNPACKED_MARKERS.update({bytes(bytearray([z])): z for z in range(0x00, 0x80)})
-UNPACKED_MARKERS.update({bytes(bytearray([z + 256])): z for z in range(-0x10, 0x00)})
-
-
-INT64_MIN = -(2 ** 63)
-INT64_MAX = 2 ** 63
-
-
-class Structure:
-
- def __init__(self, tag, *fields):
- self.tag = tag
- self.fields = list(fields)
-
- def __repr__(self):
- return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields)))
-
- def __eq__(self, other):
- try:
- return self.tag == other.tag and self.fields == other.fields
- except AttributeError:
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __len__(self):
- return len(self.fields)
-
- def __getitem__(self, key):
- return self.fields[key]
-
- def __setitem__(self, key, value):
- self.fields[key] = value
-
-
-class Packer:
-
- def __init__(self, stream):
- self.stream = stream
- self._write = self.stream.write
-
- def pack_raw(self, data):
- self._write(data)
-
- def pack(self, value):
- return self._pack(value)
-
- def _pack(self, value):
- write = self._write
-
- # None
- if value is None:
- write(b"\xC0") # NULL
-
- # Boolean
- elif value is True:
- write(b"\xC3")
- elif value is False:
- write(b"\xC2")
-
- # Float (only double precision is supported)
- elif isinstance(value, float):
- write(b"\xC1")
- write(struct_pack(">d", value))
-
- # Integer
- elif isinstance(value, int):
- if -0x10 <= value < 0x80:
- write(PACKED_UINT_8[value % 0x100])
- elif -0x80 <= value < -0x10:
- write(b"\xC8")
- write(PACKED_UINT_8[value % 0x100])
- elif -0x8000 <= value < 0x8000:
- write(b"\xC9")
- write(PACKED_UINT_16[value % 0x10000])
- elif -0x80000000 <= value < 0x80000000:
- write(b"\xCA")
- write(struct_pack(">i", value))
- elif INT64_MIN <= value < INT64_MAX:
- write(b"\xCB")
- write(struct_pack(">q", value))
- else:
- raise OverflowError("Integer %s out of range" % value)
-
- # String
- elif isinstance(value, str):
- encoded = value.encode("utf-8")
- self.pack_string_header(len(encoded))
- self.pack_raw(encoded)
-
- # Bytes
- elif isinstance(value, (bytes, bytearray)):
- self.pack_bytes_header(len(value))
- self.pack_raw(value)
-
- # List
- elif isinstance(value, list):
- self.pack_list_header(len(value))
- for item in value:
- self._pack(item)
-
- # Map
- elif isinstance(value, dict):
- self.pack_map_header(len(value))
- for key, item in value.items():
- self._pack(key)
- self._pack(item)
-
- # Structure
- elif isinstance(value, Structure):
- self.pack_struct(value.tag, value.fields)
-
- # Other
- else:
- raise ValueError("Values of type %s are not supported" % type(value))
-
- def pack_bytes_header(self, size):
- write = self._write
- if size < 0x100:
- write(b"\xCC")
- write(PACKED_UINT_8[size])
- elif size < 0x10000:
- write(b"\xCD")
- write(PACKED_UINT_16[size])
- elif size < 0x100000000:
- write(b"\xCE")
- write(struct_pack(">I", size))
- else:
- raise OverflowError("Bytes header size out of range")
-
- def pack_string_header(self, size):
- write = self._write
- if size <= 0x0F:
- write(bytes((0x80 | size,)))
- elif size < 0x100:
- write(b"\xD0")
- write(PACKED_UINT_8[size])
- elif size < 0x10000:
- write(b"\xD1")
- write(PACKED_UINT_16[size])
- elif size < 0x100000000:
- write(b"\xD2")
- write(struct_pack(">I", size))
- else:
- raise OverflowError("String header size out of range")
-
- def pack_list_header(self, size):
- write = self._write
- if size <= 0x0F:
- write(bytes((0x90 | size,)))
- elif size < 0x100:
- write(b"\xD4")
- write(PACKED_UINT_8[size])
- elif size < 0x10000:
- write(b"\xD5")
- write(PACKED_UINT_16[size])
- elif size < 0x100000000:
- write(b"\xD6")
- write(struct_pack(">I", size))
- else:
- raise OverflowError("List header size out of range")
-
- def pack_map_header(self, size):
- write = self._write
- if size <= 0x0F:
- write(bytes((0xA0 | size,)))
- elif size < 0x100:
- write(b"\xD8")
- write(PACKED_UINT_8[size])
- elif size < 0x10000:
- write(b"\xD9")
- write(PACKED_UINT_16[size])
- elif size < 0x100000000:
- write(b"\xDA")
- write(struct_pack(">I", size))
- else:
- raise OverflowError("Map header size out of range")
-
- def pack_struct(self, signature, fields):
- if len(signature) != 1 or not isinstance(signature, bytes):
- raise ValueError("Structure signature must be a single byte value")
- write = self._write
- size = len(fields)
- if size <= 0x0F:
- write(bytes((0xB0 | size,)))
- else:
- raise OverflowError("Structure size out of range")
- write(signature)
- for field in fields:
- self._pack(field)
-
-
-class Unpacker:
-
- def __init__(self, unpackable):
- self.unpackable = unpackable
-
- def reset(self):
- self.unpackable.reset()
-
- def read(self, n=1):
- return self.unpackable.read(n)
-
- def read_u8(self):
- return self.unpackable.read_u8()
-
- def unpack(self):
- return self._unpack()
-
- def _unpack(self):
- marker = self.read_u8()
-
- if marker == -1:
- raise ValueError("Nothing to unpack")
-
- # Tiny Integer
- if 0x00 <= marker <= 0x7F:
- return marker
- elif 0xF0 <= marker <= 0xFF:
- return marker - 0x100
-
- # Null
- elif marker == 0xC0:
- return None
-
- # Float
- elif marker == 0xC1:
- value, = struct_unpack(">d", self.read(8))
- return value
-
- # Boolean
- elif marker == 0xC2:
- return False
- elif marker == 0xC3:
- return True
-
- # Integer
- elif marker == 0xC8:
- return struct_unpack(">b", self.read(1))[0]
- elif marker == 0xC9:
- return struct_unpack(">h", self.read(2))[0]
- elif marker == 0xCA:
- return struct_unpack(">i", self.read(4))[0]
- elif marker == 0xCB:
- return struct_unpack(">q", self.read(8))[0]
-
- # Bytes
- elif marker == 0xCC:
- size, = struct_unpack(">B", self.read(1))
- return self.read(size).tobytes()
- elif marker == 0xCD:
- size, = struct_unpack(">H", self.read(2))
- return self.read(size).tobytes()
- elif marker == 0xCE:
- size, = struct_unpack(">I", self.read(4))
- return self.read(size).tobytes()
-
- else:
- marker_high = marker & 0xF0
- # String
- if marker_high == 0x80: # TINY_STRING
- return decode(self.read(marker & 0x0F), "utf-8")
- elif marker == 0xD0: # STRING_8:
- size, = struct_unpack(">B", self.read(1))
- return decode(self.read(size), "utf-8")
- elif marker == 0xD1: # STRING_16:
- size, = struct_unpack(">H", self.read(2))
- return decode(self.read(size), "utf-8")
- elif marker == 0xD2: # STRING_32:
- size, = struct_unpack(">I", self.read(4))
- return decode(self.read(size), "utf-8")
-
- # List
- elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD6:
- return list(self._unpack_list_items(marker))
-
- # Map
- elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDA:
- return self._unpack_map(marker)
-
- # Structure
- elif 0xB0 <= marker <= 0xBF:
- size, tag = self._unpack_structure_header(marker)
- value = Structure(tag, *([None] * size))
- for i in range(len(value)):
- value[i] = self._unpack()
- return value
-
- else:
- raise ValueError("Unknown PackStream marker %02X" % marker)
-
- def _unpack_list_items(self, marker):
- marker_high = marker & 0xF0
- if marker_high == 0x90:
- size = marker & 0x0F
- if size == 0:
- return
- elif size == 1:
- yield self._unpack()
- else:
- for _ in range(size):
- yield self._unpack()
- elif marker == 0xD4: # LIST_8:
- size, = struct_unpack(">B", self.read(1))
- for _ in range(size):
- yield self._unpack()
- elif marker == 0xD5: # LIST_16:
- size, = struct_unpack(">H", self.read(2))
- for _ in range(size):
- yield self._unpack()
- elif marker == 0xD6: # LIST_32:
- size, = struct_unpack(">I", self.read(4))
- for _ in range(size):
- yield self._unpack()
- else:
- return
-
- def unpack_map(self):
- marker = self.read_u8()
- return self._unpack_map(marker)
-
- def _unpack_map(self, marker):
- marker_high = marker & 0xF0
- if marker_high == 0xA0:
- size = marker & 0x0F
- value = {}
- for _ in range(size):
- key = self._unpack()
- value[key] = self._unpack()
- return value
- elif marker == 0xD8: # MAP_8:
- size, = struct_unpack(">B", self.read(1))
- value = {}
- for _ in range(size):
- key = self._unpack()
- value[key] = self._unpack()
- return value
- elif marker == 0xD9: # MAP_16:
- size, = struct_unpack(">H", self.read(2))
- value = {}
- for _ in range(size):
- key = self._unpack()
- value[key] = self._unpack()
- return value
- elif marker == 0xDA: # MAP_32:
- size, = struct_unpack(">I", self.read(4))
- value = {}
- for _ in range(size):
- key = self._unpack()
- value[key] = self._unpack()
- return value
- else:
- return None
-
- def unpack_structure_header(self):
- marker = self.read_u8()
- if marker == -1:
- return None, None
- else:
- return self._unpack_structure_header(marker)
-
- def _unpack_structure_header(self, marker):
- marker_high = marker & 0xF0
- if marker_high == 0xB0: # TINY_STRUCT
- signature = self.read(1).tobytes()
- return marker & 0x0F, signature
- else:
- raise ValueError("Expected structure, found marker %02X" % marker)
-
-
-class UnpackableBuffer:
-
- initial_capacity = 8192
-
- def __init__(self, data=None):
- if data is None:
- self.data = bytearray(self.initial_capacity)
- self.used = 0
- else:
- self.data = bytearray(data)
- self.used = len(self.data)
- self.p = 0
-
- def reset(self):
- self.used = 0
- self.p = 0
-
- def read(self, n=1):
- view = memoryview(self.data)
- q = self.p + n
- subview = view[self.p:q]
- self.p = q
- return subview
-
- def read_u8(self):
- if self.used - self.p >= 1:
- value = self.data[self.p]
- self.p += 1
- return value
- else:
- return -1
-
- def pop_u16(self):
- """ Remove the last two bytes of data, returning them as a big-endian
- 16-bit unsigned integer.
- """
- if self.used >= 2:
- value = 0x100 * self.data[self.used - 2] + self.data[self.used - 1]
- self.used -= 2
- return value
- else:
- return -1
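For orientation while this implementation moves to `neo4j._codec.packstream.v1`: the marker scheme packs small values inline, which is easy to see end to end. This should still run against the deprecated shim, at the cost of a `DeprecationWarning`:

    import io
    from neo4j.packstream import Packer

    buf = io.BytesIO()
    Packer(buf).pack({"answer": 42})
    buf.getvalue()   # b'\xa1\x86answer*' -- TINY_MAP (0xA0|1), TINY_STRING (0x80|6), tiny int 42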
- """ - if self.used >= 2: - value = 0x100 * self.data[self.used - 2] + self.data[self.used - 1] - self.used -= 2 - return value - else: - return -1 diff --git a/neo4j/routing.py b/neo4j/routing.py index 99364fccc..1036d92fc 100644 --- a/neo4j/routing.py +++ b/neo4j/routing.py @@ -15,153 +15,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: 6.0 - remove this file -from collections.abc import MutableSet -from logging import getLogger -from time import perf_counter -from .addressing import Address +from ._meta import deprecation_warn as _deprecation_warn +from ._routing import ( + OrderedSet, + RoutingTable, +) -log = getLogger("neo4j") +__all__ = [ + "OrderedSet", + "RoutingTable", +] - -class OrderedSet(MutableSet): - - def __init__(self, elements=()): - # dicts keep insertion order starting with Python 3.7 - self._elements = dict.fromkeys(elements) - self._current = None - - def __repr__(self): - return "{%s}" % ", ".join(map(repr, self._elements)) - - def __contains__(self, element): - return element in self._elements - - def __iter__(self): - return iter(self._elements) - - def __len__(self): - return len(self._elements) - - def __getitem__(self, index): - return list(self._elements.keys())[index] - - def add(self, element): - self._elements[element] = None - - def clear(self): - self._elements.clear() - - def discard(self, element): - try: - del self._elements[element] - except KeyError: - pass - - def remove(self, element): - try: - del self._elements[element] - except KeyError: - raise ValueError(element) - - def update(self, elements=()): - self._elements.update(dict.fromkeys(elements)) - - def replace(self, elements=()): - e = self._elements - e.clear() - e.update(dict.fromkeys(elements)) - - -class RoutingTable: - - @classmethod - def parse_routing_info(cls, *, database, servers, ttl): - """ Parse the records returned from the procedure call and - return a new RoutingTable instance. - """ - routers = [] - readers = [] - writers = [] - try: - for server in servers: - role = server["role"] - addresses = [] - for address in server["addresses"]: - addresses.append(Address.parse(address, default_port=7687)) - if role == "ROUTE": - routers.extend(addresses) - elif role == "READ": - readers.extend(addresses) - elif role == "WRITE": - writers.extend(addresses) - except (KeyError, TypeError): - raise ValueError("Cannot parse routing info") - else: - return cls(database=database, routers=routers, readers=readers, writers=writers, ttl=ttl) - - def __init__(self, *, database, routers=(), readers=(), writers=(), ttl=0): - self.initial_routers = OrderedSet(routers) - self.routers = OrderedSet(routers) - self.readers = OrderedSet(readers) - self.writers = OrderedSet(writers) - self.initialized_without_writers = not self.writers - self.last_updated_time = perf_counter() - self.ttl = ttl - self.database = database - - def __repr__(self): - return "RoutingTable(database=%r routers=%r, readers=%r, writers=%r, last_updated_time=%r, ttl=%r)" % ( - self.database, - self.routers, - self.readers, - self.writers, - self.last_updated_time, - self.ttl, - ) - - def __contains__(self, address): - return address in self.routers or address in self.readers or address in self.writers - - def is_fresh(self, readonly=False): - """ Indicator for whether routing information is still usable. 
- """ - assert isinstance(readonly, bool) - log.debug("[#0000] C: Checking table freshness (readonly=%r)", readonly) - expired = self.last_updated_time + self.ttl <= perf_counter() - if readonly: - has_server_for_mode = bool(self.readers) - else: - has_server_for_mode = bool(self.writers) - log.debug("[#0000] C: Table expired=%r", expired) - log.debug("[#0000] C: Table routers=%r", self.routers) - log.debug("[#0000] C: Table has_server_for_mode=%r", has_server_for_mode) - return not expired and self.routers and has_server_for_mode - - def should_be_purged_from_memory(self): - """ Check if the routing table is stale and not used for a long time and should be removed from memory. - - :return: Returns true if it is old and not used for a while. - :rtype: bool - """ - from neo4j.conf import RoutingConfig - perf_time = perf_counter() - log.debug("[#0000] C: last_updated_time=%r perf_time=%r", self.last_updated_time, perf_time) - return self.last_updated_time + self.ttl + RoutingConfig.routing_table_purge_delay <= perf_time - - def update(self, new_routing_table): - """ Update the current routing table with new routing information - from a replacement table. - """ - self.routers.replace(new_routing_table.routers) - self.readers.replace(new_routing_table.readers) - self.writers.replace(new_routing_table.writers) - self.initialized_without_writers = not self.writers - self.last_updated_time = perf_counter() - self.ttl = new_routing_table.ttl - log.debug("[#0000] S: table=%r", self) - - def servers(self): - return set(self.routers) | set(self.writers) | set(self.readers) +_deprecation_warn( + "The module 'neo4j.routing' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/spatial/__init__.py b/neo4j/spatial/__init__.py index 243f72c5f..f530d3da0 100644 --- a/neo4j/spatial/__init__.py +++ b/neo4j/spatial/__init__.py @@ -30,112 +30,38 @@ "WGS84Point", ] - -from threading import Lock - -from neo4j.packstream import Structure - - -# SRID to subclass mappings -__srid_table = {} -__srid_table_lock = Lock() - - -class Point(tuple): - """Base-class for spatial data. - - A point within a geometric space. This type is generally used via its - subclasses and should not be instantiated directly unless there is no - subclass defined for the required SRID. - - :param iterable: - An iterable of coordinates. - All items will be converted to :class:`float`. - """ - - #: The SRID (spatial reference identifier) of the spatial data. - #: A number that identifies the coordinate system the spatial type is to be - #: interpreted in. - #: - #: :type: int - srid = None - - def __new__(cls, iterable): - return tuple.__new__(cls, map(float, iterable)) - - def __repr__(self): - return "POINT(%s)" % " ".join(map(str, self)) - - def __eq__(self, other): - try: - return type(self) is type(other) and tuple(self) == tuple(other) - except (AttributeError, TypeError): - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(type(self)) ^ hash(tuple(self)) - - -def point_type(name, fields, srid_map): - """ Dynamically create a Point subclass. 
- """ - - def srid(self): - try: - return srid_map[len(self)] - except KeyError: - return None - - attributes = {"srid": property(srid)} - - for index, subclass_field in enumerate(fields): - - def accessor(self, i=index, f=subclass_field): - try: - return self[i] - except IndexError: - raise AttributeError(f) - - for field_alias in {subclass_field, "xyz"[index]}: - attributes[field_alias] = property(accessor) - - cls = type(name, (Point,), attributes) - - with __srid_table_lock: - for dim, srid in srid_map.items(): - __srid_table[srid] = (cls, dim) - - return cls - - -# Point subclass definitions -CartesianPoint = point_type("CartesianPoint", ["x", "y", "z"], - {2: 7203, 3: 9157}) -WGS84Point = point_type("WGS84Point", ["longitude", "latitude", "height"], - {2: 4326, 3: 4979}) - - +from functools import wraps + +from .._codec.hydration.v1 import spatial as _hydration +from .._meta import deprecated +from .._spatial import ( + CartesianPoint, + Point, + point_type as _point_type, + WGS84Point, +) + + +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) def hydrate_point(srid, *coordinates): """ Create a new instance of a Point subclass from a raw set of fields. The subclass chosen is determined by the given SRID; a ValueError will be raised if no such subclass can be found. """ - try: - point_class, dim = __srid_table[srid] - except KeyError: - point = Point(coordinates) - point.srid = srid - return point - else: - if len(coordinates) != dim: - raise ValueError("SRID %d requires %d coordinates (%d provided)" % (srid, dim, len(coordinates))) - return point_class(coordinates) + return _hydration.hydrate_point(srid, *coordinates) +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) +@wraps(_hydration.dehydrate_point) def dehydrate_point(value): """ Dehydrator for Point data. @@ -143,10 +69,30 @@ def dehydrate_point(value): :type value: Point :return: """ - dim = len(value) - if dim == 2: - return Structure(b"X", value.srid, *value) - elif dim == 3: - return Structure(b"Y", value.srid, *value) - else: - raise ValueError("Cannot dehydrate Point with %d dimensions" % dim) + return _hydration.dehydrate_point(value) + + +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) +@wraps(_hydration.dehydrate_point) +def dehydrate_point(value): + """ Dehydrator for Point data. 
+ + :param value: + :type value: Point + :return: + """ + return _hydration.dehydrate_point(value) + + +# TODO: 6.0 remove +@deprecated( + "point_type is considered an internal function and will be removed in " + "a future version" +) +@wraps(_point_type) +def point_type(name, fields, srid_map): + return _point_type(name, fields, srid_map) diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index b0302e316..d30d705e3 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -37,19 +37,36 @@ struct_time, ) -from neo4j.time.arithmetic import ( +from ._arithmetic import ( nano_add, nano_div, round_half_to_even, symmetric_divmod, ) -from neo4j.time.metaclasses import ( +from ._metaclasses import ( DateTimeType, DateType, TimeType, ) +__all__ = [ + "MIN_INT64", + "MAX_INT64", + "MIN_YEAR", + "MAX_YEAR", + "Duration", + "Date", + "ZeroDate", + "Time", + "Midnight", + "Midday", + "DateTime", + "Never", + "UnixEpoch", +] + + MIN_INT64 = -(2 ** 63) MAX_INT64 = (2 ** 63) - 1 @@ -241,7 +258,7 @@ class Clock: def __new__(cls): if cls.__implementations is None: # Find an available clock with the best precision - import neo4j.time.clock_implementations + import neo4j.time._clock_implementations cls.__implementations = sorted((clock for clock in Clock.__subclasses__() if clock.available()), key=lambda clock: clock.precision(), reverse=True) if not cls.__implementations: diff --git a/neo4j/time/_arithmetic.py b/neo4j/time/_arithmetic.py new file mode 100644 index 000000000..93bfe8eda --- /dev/null +++ b/neo4j/time/_arithmetic.py @@ -0,0 +1,124 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
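The `point_type` factory that this file now merely re-exports still generates the same subclasses, so their behaviour is unchanged and observable through `neo4j.spatial`:

    from neo4j.spatial import CartesianPoint

    p = CartesianPoint((1.0, 2.0))
    p.x, p.y     # (1.0, 2.0) -- accessors generated per field name
    p.srid       # 7203       -- the 2D cartesian SRID from the removed table above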
diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py
index b0302e316..d30d705e3 100644
--- a/neo4j/time/__init__.py
+++ b/neo4j/time/__init__.py
@@ -37,19 +37,36 @@
struct_time,
)
-from neo4j.time.arithmetic import (
+from ._arithmetic import (
nano_add,
nano_div,
round_half_to_even,
symmetric_divmod,
)
-from neo4j.time.metaclasses import (
+from ._metaclasses import (
DateTimeType,
DateType,
TimeType,
)
+__all__ = [
+ "MIN_INT64",
+ "MAX_INT64",
+ "MIN_YEAR",
+ "MAX_YEAR",
+ "Duration",
+ "Date",
+ "ZeroDate",
+ "Time",
+ "Midnight",
+ "Midday",
+ "DateTime",
+ "Never",
+ "UnixEpoch",
+]
+
+
MIN_INT64 = -(2 ** 63)
MAX_INT64 = (2 ** 63) - 1
@@ -241,7 +258,7 @@ class Clock:
def __new__(cls):
if cls.__implementations is None:
# Find an available clock with the best precision
- import neo4j.time.clock_implementations
+ import neo4j.time._clock_implementations
cls.__implementations = sorted((clock for clock in Clock.__subclasses__() if clock.available()), key=lambda clock: clock.precision(), reverse=True)
if not cls.__implementations:
diff --git a/neo4j/time/_arithmetic.py b/neo4j/time/_arithmetic.py
new file mode 100644
index 000000000..93bfe8eda
--- /dev/null
+++ b/neo4j/time/_arithmetic.py
@@ -0,0 +1,124 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__all__ = [
+ "nano_add",
+ "nano_div",
+ "nano_divmod",
+ "symmetric_divmod",
+ "round_half_to_even",
+]
+
+
+def nano_add(x, y):
+ """
+
+ >>> 0.7 + 0.2
+ 0.8999999999999999
+ >>> -0.7 + 0.2
+ -0.49999999999999994
+ >>> nano_add(0.7, 0.2)
+ 0.9
+ >>> nano_add(-0.7, 0.2)
+ -0.5
+
+ :param x:
+ :param y:
+ :return:
+ """
+ return (int(1000000000 * x) + int(1000000000 * y)) / 1000000000
+
+
+def nano_div(x, y):
+ """
+
+ >>> 0.7 / 0.2
+ 3.4999999999999996
+ >>> -0.7 / 0.2
+ -3.4999999999999996
+ >>> nano_div(0.7, 0.2)
+ 3.5
+ >>> nano_div(-0.7, 0.2)
+ -3.5
+
+ :param x:
+ :param y:
+ :return:
+ """
+ return float(1000000000 * x) / int(1000000000 * y)
+
+
+def nano_divmod(x, y):
+ """
+
+ >>> divmod(0.7, 0.2)
+ (3.0, 0.09999999999999992)
+ >>> nano_divmod(0.7, 0.2)
+ (3, 0.1)
+
+ :param x:
+ :param y:
+ :return:
+ """
+ number = type(x)
+ nx = int(1000000000 * x)
+ ny = int(1000000000 * y)
+ q, r = divmod(nx, ny)
+ return int(q), number(r / 1000000000)
+
+
+def symmetric_divmod(dividend, divisor):
+ number = type(dividend)
+ if dividend >= 0:
+ quotient, remainder = divmod(dividend, divisor)
+ return int(quotient), number(remainder)
+ else:
+ quotient, remainder = divmod(-dividend, divisor)
+ return -int(quotient), -number(remainder)
+
+
+def round_half_to_even(n):
+ """
+
+ >>> round_half_to_even(3)
+ 3
+ >>> round_half_to_even(3.2)
+ 3
+ >>> round_half_to_even(3.5)
+ 4
+ >>> round_half_to_even(3.7)
+ 4
+ >>> round_half_to_even(4)
+ 4
+ >>> round_half_to_even(4.2)
+ 4
+ >>> round_half_to_even(4.5)
+ 4
+ >>> round_half_to_even(4.7)
+ 5
+
+ :param n:
+ :return:
+ """
+ ten_n = 10 * n
+ if ten_n == int(ten_n) and ten_n % 10 == 5:
+ up = int(n + 0.5)
+ down = int(n - 0.5)
+ return up if up % 2 == 0 else down
+ else:
+ return int(round(n))
diff --git a/neo4j/time/_clock_implementations.py b/neo4j/time/_clock_implementations.py
new file mode 100644
index 000000000..60f82cc8b
--- /dev/null
+++ b/neo4j/time/_clock_implementations.py
@@ -0,0 +1,119 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ctypes import (
+ byref,
+ c_long,
+ c_longlong,
+ CDLL,
+ Structure,
+)
+from platform import uname
+
+from . import (
+ Clock,
+ ClockTime,
+)
+from ._arithmetic import nano_divmod
+
+
+__all__ = [
+ "SafeClock",
+ "PEP564Clock",
+ "LibCClock",
+]
+
+
+class SafeClock(Clock):
+ """ Clock implementation that should work for any variant of Python.
+ This clock is guaranteed microsecond precision.
+ """
+
+ @classmethod
+ def precision(cls):
+ return 6
+
+ @classmethod
+ def available(cls):
+ return True
+
+ def utc_time(self):
+ from time import time
+ seconds, nanoseconds = nano_divmod(int(time() * 1000000), 1000000)
+ return ClockTime(seconds, nanoseconds * 1000)
+
+
+class PEP564Clock(Clock):
+ """ Clock implementation based on the PEP564 additions to Python 3.7.
+ This clock is guaranteed nanosecond precision.
+ """ + + @classmethod + def precision(cls): + return 9 + + @classmethod + def available(cls): + try: + from time import time_ns + except ImportError: + return False + else: + return True + + def utc_time(self): + from time import time_ns + t = time_ns() + seconds, nanoseconds = divmod(t, 1000000000) + return ClockTime(seconds, nanoseconds) + + +class LibCClock(Clock): + """ Clock implementation that works only on platforms that provide + libc. This clock is guaranteed nanosecond precision. + """ + + __libc = "libc.dylib" if uname()[0] == "Darwin" else "libc.so.6" + + class _TimeSpec(Structure): + _fields_ = [ + ("seconds", c_longlong), + ("nanoseconds", c_long), + ] + + @classmethod + def precision(cls): + return 9 + + @classmethod + def available(cls): + try: + _ = CDLL(cls.__libc) + except OSError: + return False + else: + return True + + def utc_time(self): + libc = CDLL(self.__libc) + ts = self._TimeSpec() + status = libc.clock_gettime(0, byref(ts)) + if status == 0: + return ClockTime(ts.seconds, ts.nanoseconds) + else: + raise RuntimeError("clock_gettime failed with status %d" % status) diff --git a/neo4j/time/_metaclasses.py b/neo4j/time/_metaclasses.py new file mode 100644 index 000000000..cf9022fbb --- /dev/null +++ b/neo4j/time/_metaclasses.py @@ -0,0 +1,66 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +__all__ = [ + "DateType", + "TimeType", + "DateTimeType", +] + + +class DateType(type): + + def __getattr__(cls, name): + try: + return { + "fromisoformat": cls.from_iso_format, + "fromordinal": cls.from_ordinal, + "fromtimestamp": cls.from_timestamp, + "utcfromtimestamp": cls.utc_from_timestamp, + }[name] + except KeyError: + raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) + + +class TimeType(type): + + def __getattr__(cls, name): + try: + return { + "fromisoformat": cls.from_iso_format, + "utcnow": cls.utc_now, + }[name] + except KeyError: + raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) + + +class DateTimeType(type): + + def __getattr__(cls, name): + try: + return { + "fromisoformat": cls.from_iso_format, + "fromordinal": cls.from_ordinal, + "fromtimestamp": cls.from_timestamp, + "strptime": cls.parse, + "today": cls.now, + "utcfromtimestamp": cls.utc_from_timestamp, + "utcnow": cls.utc_now, + }[name] + except KeyError: + raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) diff --git a/neo4j/time/arithmetic.py b/neo4j/time/arithmetic.py index 6ab7b6581..7f0961f61 100644 --- a/neo4j/time/arithmetic.py +++ b/neo4j/time/arithmetic.py @@ -15,101 +15,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
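# The metaclasses above keep the driver's temporal types duck-type compatible
# with the standard library's camelCase constructors by rerouting attribute
# access on the class itself. A simplified, hypothetical model of the same
# trick (not the driver's actual classes):
class AliasingType(type):
    _aliases = {"fromisoformat": "from_iso_format"}

    def __getattr__(cls, name):
        # only reached when normal class-attribute lookup fails
        try:
            return getattr(cls, cls._aliases[name])
        except KeyError:
            raise AttributeError(
                "%s has no attribute %r" % (cls.__name__, name)
            )


class MyDate(metaclass=AliasingType):
    @classmethod
    def from_iso_format(cls, s):
        return tuple(map(int, s.split("-")))


# Both spellings resolve to the same classmethod:
assert MyDate.fromisoformat("2020-01-01") == MyDate.from_iso_format("2020-01-01")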
- -def nano_add(x, y): - """ - - >>> 0.7 + 0.2 - 0.8999999999999999 - >>> -0.7 + 0.2 - -0.49999999999999994 - >>> nano_add(0.7, 0.2) - 0.9 - >>> nano_add(-0.7, 0.2) - -0.5 - - :param x: - :param y: - :return: - """ - return (int(1000000000 * x) + int(1000000000 * y)) / 1000000000 - - -def nano_div(x, y): - """ - - >>> 0.7 / 0.2 - 3.4999999999999996 - >>> -0.7 / 0.2 - -3.4999999999999996 - >>> nano_div(0.7, 0.2) - 3.5 - >>> nano_div(-0.7, 0.2) - -3.5 - - :param x: - :param y: - :return: - """ - return float(1000000000 * x) / int(1000000000 * y) - - -def nano_divmod(x, y): - """ - - >>> divmod(0.7, 0.2) - (3.0, 0.09999999999999992) - >>> nano_divmod(0.7, 0.2) - (3, 0.1) - - :param x: - :param y: - :return: - """ - number = type(x) - nx = int(1000000000 * x) - ny = int(1000000000 * y) - q, r = divmod(nx, ny) - return int(q), number(r / 1000000000) - - -def symmetric_divmod(dividend, divisor): - number = type(dividend) - if dividend >= 0: - quotient, remainder = divmod(dividend, divisor) - return int(quotient), number(remainder) - else: - quotient, remainder = divmod(-dividend, divisor) - return -int(quotient), -number(remainder) - - -def round_half_to_even(n): - """ - - >>> round_half_to_even(3) - 3 - >>> round_half_to_even(3.2) - 3 - >>> round_half_to_even(3.5) - 4 - >>> round_half_to_even(3.7) - 4 - >>> round_half_to_even(4) - 4 - >>> round_half_to_even(4.2) - 4 - >>> round_half_to_even(4.5) - 4 - >>> round_half_to_even(4.7) - 5 - - :param n: - :return: - """ - ten_n = 10 * n - if ten_n == int(ten_n) and ten_n % 10 == 5: - up = int(n + 0.5) - down = int(n - 0.5) - return up if up % 2 == 0 else down - else: - return int(round(n)) +# TODO: 6.0 - remove this file + + +from .._meta import deprecation_warn +from ._arithmetic import ( + nano_add, + nano_div, + nano_divmod, + round_half_to_even, + symmetric_divmod, +) + + +__all__ = [ + "nano_add", + "nano_div", + "nano_divmod", + "symmetric_divmod", + "round_half_to_even", +] + +deprecation_warn( + "The module 'neo4j.time.arithmetic' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/clock_implementations.py b/neo4j/time/clock_implementations.py index 38e93ab5b..facfa5f61 100644 --- a/neo4j/time/clock_implementations.py +++ b/neo4j/time/clock_implementations.py @@ -15,98 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: 6.0 - remove this file -from ctypes import ( - byref, - c_long, - c_longlong, - CDLL, - Structure, -) -from platform import uname -from neo4j.time import ( - Clock, - ClockTime, +from .._meta import deprecation_warn +from ._clock_implementations import ( + LibCClock, + PEP564Clock, + SafeClock, ) -from neo4j.time.arithmetic import nano_divmod - - -class SafeClock(Clock): - """ Clock implementation that should work for any variant of Python. - This clock is guaranteed microsecond precision. - """ - - @classmethod - def precision(cls): - return 6 - - @classmethod - def available(cls): - return True - - def utc_time(self): - from time import time - seconds, nanoseconds = nano_divmod(int(time() * 1000000), 1000000) - return ClockTime(seconds, nanoseconds * 1000) - - -class PEP564Clock(Clock): - """ Clock implementation based on the PEP564 additions to Python 3.7. - This clock is guaranteed nanosecond precision. 
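# The shim above keeps `neo4j.time.arithmetic` importable for one more major
# version: it re-exports the private implementations and fires
# deprecation_warn once, at import time. A sketch of what client code
# observes (assuming a driver built from this patch is installed and the
# module has not already been imported in this interpreter):
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import neo4j.time.arithmetic  # noqa: F401  # triggers the warning

assert any("made internal" in str(w.message) for w in caught)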
- """ - @classmethod - def precision(cls): - return 9 - @classmethod - def available(cls): - try: - from time import time_ns - except ImportError: - return False - else: - return True +__all__ = [ + "SafeClock", + "PEP564Clock", + "LibCClock", +] - def utc_time(self): - from time import time_ns - t = time_ns() - seconds, nanoseconds = divmod(t, 1000000000) - return ClockTime(seconds, nanoseconds) - - -class LibCClock(Clock): - """ Clock implementation that works only on platforms that provide - libc. This clock is guaranteed nanosecond precision. - """ - - __libc = "libc.dylib" if uname()[0] == "Darwin" else "libc.so.6" - - class _TimeSpec(Structure): - _fields_ = [ - ("seconds", c_longlong), - ("nanoseconds", c_long), - ] - - @classmethod - def precision(cls): - return 9 - - @classmethod - def available(cls): - try: - _ = CDLL(cls.__libc) - except OSError: - return False - else: - return True - - def utc_time(self): - libc = CDLL(self.__libc) - ts = self._TimeSpec() - status = libc.clock_gettime(0, byref(ts)) - if status == 0: - return ClockTime(ts.seconds, ts.nanoseconds) - else: - raise RuntimeError("clock_gettime failed with status %d" % status) +deprecation_warn( + "The module 'neo4j.time.clock_implementations' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py index 056cafda0..4681c7c66 100644 --- a/neo4j/time/hydration.py +++ b/neo4j/time/hydration.py @@ -15,193 +15,43 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from datetime import ( - datetime, - time, - timedelta, +# TODO: 6.0 - remove this file + + +from .._codec.hydration.v1.temporal import ( + dehydrate_date, + dehydrate_datetime, + dehydrate_duration, + dehydrate_time, + dehydrate_timedelta, + get_date_unix_epoch, + get_date_unix_epoch_ordinal, + get_datetime_unix_epoch_utc, + hydrate_date, + hydrate_datetime, + hydrate_duration, + hydrate_time, ) - -from neo4j.packstream import Structure -from neo4j.time import ( - Date, - DateTime, - Duration, - Time, +from .._meta import deprecation_warn + + +__all__ = [ + "get_date_unix_epoch", + "get_date_unix_epoch_ordinal", + "get_datetime_unix_epoch_utc", + "hydrate_date", + "dehydrate_date", + "hydrate_time", + "dehydrate_time", + "hydrate_datetime", + "dehydrate_datetime", + "hydrate_duration", + "dehydrate_duration", + "dehydrate_timedelta", +] + +deprecation_warn( + "The module 'neo4j.time.hydration' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 ) - - -def get_date_unix_epoch(): - return Date(1970, 1, 1) - - -def get_date_unix_epoch_ordinal(): - return get_date_unix_epoch().to_ordinal() - - -def get_datetime_unix_epoch_utc(): - from pytz import utc - return DateTime(1970, 1, 1, 0, 0, 0, utc) - - -def hydrate_date(days): - """ Hydrator for `Date` values. - - :param days: - :return: Date - """ - return Date.from_ordinal(get_date_unix_epoch_ordinal() + days) - - -def dehydrate_date(value): - """ Dehydrator for `date` values. - - :param value: - :type value: Date - :return: - """ - return Structure(b"D", value.toordinal() - get_date_unix_epoch().toordinal()) - - -def hydrate_time(nanoseconds, tz=None): - """ Hydrator for `Time` and `LocalTime` values. 
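# The date hydration being relocated above is plain ordinal arithmetic: the
# Bolt Date struct carries the number of days since the Unix epoch. A worked
# check, assuming the relocated helpers (still reachable through the
# deprecated shim) and that the returned Structure exposes its fields:
from neo4j.time import Date
from neo4j.time.hydration import dehydrate_date, hydrate_date

# 2020-01-01 has ordinal 737425 and 1970-01-01 has ordinal 719163;
# their difference, 18262, is the value carried on the wire.
assert hydrate_date(18262) == Date(2020, 1, 1)
assert dehydrate_date(Date(2020, 1, 1)).fields[0] == 18262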
- - :param nanoseconds: - :param tz: - :return: Time - """ - from pytz import FixedOffset - seconds, nanoseconds = map(int, divmod(nanoseconds, 1000000000)) - minutes, seconds = map(int, divmod(seconds, 60)) - hours, minutes = map(int, divmod(minutes, 60)) - t = Time(hours, minutes, seconds, nanoseconds) - if tz is None: - return t - tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) - zone = FixedOffset(tz_offset_minutes) - return zone.localize(t) - - -def dehydrate_time(value): - """ Dehydrator for `time` values. - - :param value: - :type value: Time - :return: - """ - if isinstance(value, Time): - nanoseconds = value.ticks - elif isinstance(value, time): - nanoseconds = (3600000000000 * value.hour + 60000000000 * value.minute + - 1000000000 * value.second + 1000 * value.microsecond) - else: - raise TypeError("Value must be a neo4j.time.Time or a datetime.time") - if value.tzinfo: - return Structure(b"T", nanoseconds, - int(value.tzinfo.utcoffset(value).total_seconds())) - else: - return Structure(b"t", nanoseconds) - - -def hydrate_datetime(seconds, nanoseconds, tz=None): - """ Hydrator for `DateTime` and `LocalDateTime` values. - - :param seconds: - :param nanoseconds: - :param tz: - :return: datetime - """ - from pytz import ( - FixedOffset, - timezone, - ) - minutes, seconds = map(int, divmod(seconds, 60)) - hours, minutes = map(int, divmod(minutes, 60)) - days, hours = map(int, divmod(hours, 24)) - t = DateTime.combine( - Date.from_ordinal(get_date_unix_epoch_ordinal() + days), - Time(hours, minutes, seconds, nanoseconds) - ) - if tz is None: - return t - if isinstance(tz, int): - tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) - zone = FixedOffset(tz_offset_minutes) - else: - zone = timezone(tz) - return zone.localize(t) - - -def dehydrate_datetime(value): - """ Dehydrator for `datetime` values. - - :param value: - :type value: datetime or DateTime - :return: - """ - - def seconds_and_nanoseconds(dt): - if isinstance(dt, datetime): - dt = DateTime.from_native(dt) - zone_epoch = DateTime(1970, 1, 1, tzinfo=dt.tzinfo) - dt_clock_time = dt.to_clock_time() - zone_epoch_clock_time = zone_epoch.to_clock_time() - t = dt_clock_time - zone_epoch_clock_time - return t.seconds, t.nanoseconds - - tz = value.tzinfo - if tz is None: - # without time zone - from pytz import utc - value = utc.localize(value) - seconds, nanoseconds = seconds_and_nanoseconds(value) - return Structure(b"d", seconds, nanoseconds) - elif hasattr(tz, "zone") and tz.zone and isinstance(tz.zone, str): - # with named pytz time zone - seconds, nanoseconds = seconds_and_nanoseconds(value) - return Structure(b"f", seconds, nanoseconds, tz.zone) - elif hasattr(tz, "key") and tz.key and isinstance(tz.key, str): - # with named zoneinfo (Python 3.9+) time zone - seconds, nanoseconds = seconds_and_nanoseconds(value) - return Structure(b"f", seconds, nanoseconds, tz.key) - else: - # with time offset - seconds, nanoseconds = seconds_and_nanoseconds(value) - return Structure(b"F", seconds, nanoseconds, - int(tz.utcoffset(value).total_seconds())) - - -def hydrate_duration(months, days, seconds, nanoseconds): - """ Hydrator for `Duration` values. - - :param months: - :param days: - :param seconds: - :param nanoseconds: - :return: `duration` namedtuple - """ - return Duration(months=months, days=days, seconds=seconds, nanoseconds=nanoseconds) - - -def dehydrate_duration(value): - """ Dehydrator for `duration` values. 
- - :param value: - :type value: Duration - :return: - """ - return Structure(b"E", value.months, value.days, value.seconds, value.nanoseconds) - - -def dehydrate_timedelta(value): - """ Dehydrator for `timedelta` values. - - :param value: - :type value: timedelta - :return: - """ - months = 0 - days = value.days - seconds = value.seconds - nanoseconds = 1000 * value.microseconds - return Structure(b"E", months, days, seconds, nanoseconds) diff --git a/neo4j/time/metaclasses.py b/neo4j/time/metaclasses.py index 95be7e96c..c23101f1a 100644 --- a/neo4j/time/metaclasses.py +++ b/neo4j/time/metaclasses.py @@ -16,44 +16,22 @@ # limitations under the License. -class DateType(type): - - def __getattr__(cls, name): - try: - return { - "fromisoformat": cls.from_iso_format, - "fromordinal": cls.from_ordinal, - "fromtimestamp": cls.from_timestamp, - "utcfromtimestamp": cls.utc_from_timestamp, - }[name] - except KeyError: - raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) - - -class TimeType(type): - - def __getattr__(cls, name): - try: - return { - "fromisoformat": cls.from_iso_format, - "utcnow": cls.utc_now, - }[name] - except KeyError: - raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) - - -class DateTimeType(type): - - def __getattr__(cls, name): - try: - return { - "fromisoformat": cls.from_iso_format, - "fromordinal": cls.from_ordinal, - "fromtimestamp": cls.from_timestamp, - "strptime": cls.parse, - "today": cls.now, - "utcfromtimestamp": cls.utc_from_timestamp, - "utcnow": cls.utc_now, - }[name] - except KeyError: - raise AttributeError("%s has no attribute %r" % (cls.__name__, name)) +from .._meta import deprecation_warn +from ._metaclasses import ( + DateTimeType, + DateType, + TimeType, +) + + +__all__ = [ + "DateType", + "TimeType", + "DateTimeType", +] + +deprecation_warn( + "The module 'neo4j.time.metaclasses' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/setup.cfg b/setup.cfg index 421144bec..cf2d22207 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,9 +2,13 @@ combine_as_imports=true ensure_newline_before_comments=true force_grid_wrap=2 -force_sort_within_sections=true +# breaks order of relative imports +# https://github.com/PyCQA/isort/issues/1944 +#force_sort_within_sections=true include_trailing_comma=true -#lines_before_imports=2 # currently broken +# currently broken +# https://github.com/PyCQA/isort/issues/1855 +#lines_before_imports=2 lines_after_imports=2 lines_between_sections=1 multi_line_output=3 @@ -14,3 +18,4 @@ use_parentheses=true [tool:pytest] mock_use_standalone_module = true +asyncio_mode = auto diff --git a/setup.py b/setup.py index 31ce78611..7f7f38b7f 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup, ) -from neo4j.meta import ( +from neo4j._meta import ( package, version, ) diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index fce7386af..b939e9803 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -17,6 +17,7 @@ import asyncio +import traceback from inspect import ( getmembers, isfunction, @@ -26,7 +27,6 @@ loads, ) from pathlib import Path -import traceback from neo4j._exceptions import BoltError from neo4j.exceptions import ( @@ -35,13 +35,13 @@ UnsupportedServerProduct, ) -from . import requests from .._driver_logger import ( buffer_handler, log, ) from ..backend import Request from ..exceptions import MarkdAsDriverException +from . 
import requests TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 6dbd38561..0ea955cb6 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -17,9 +17,9 @@ import json -from os import path import re import warnings +from os import path import neo4j from neo4j._async_compat.util import AsyncUtil diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 625e2ee9e..75a23d66d 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -17,6 +17,7 @@ import asyncio +import traceback from inspect import ( getmembers, isfunction, @@ -26,7 +27,6 @@ loads, ) from pathlib import Path -import traceback from neo4j._exceptions import BoltError from neo4j.exceptions import ( @@ -35,13 +35,13 @@ UnsupportedServerProduct, ) -from . import requests from .._driver_logger import ( buffer_handler, log, ) from ..backend import Request from ..exceptions import MarkdAsDriverException +from . import requests TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 798038fd0..776cfa95a 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -17,9 +17,9 @@ import json -from os import path import re import warnings +from os import path import neo4j from neo4j._async_compat.util import Util diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 6751cfb8c..da189e1cf 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -11,7 +11,7 @@ "'stub.server_side_routing.test_server_side_routing.TestServerSideRouting.test_direct_connection_with_url_params'": "Driver emits deprecation warning. Behavior will be unified in 6.0.", "neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids": - "test_subtest_skips.tz_id", + "test_subtest_skips.dt_conversion", "neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id": "test_subtest_skips.tz_id" }, @@ -40,6 +40,7 @@ "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, "Feature:Bolt:5.0": true, + "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", "Feature:TLS:1.2": true, diff --git a/testkitbackend/test_subtest_skips.py b/testkitbackend/test_subtest_skips.py index 68b02b471..6dfb6434e 100644 --- a/testkitbackend/test_subtest_skips.py +++ b/testkitbackend/test_subtest_skips.py @@ -25,6 +25,11 @@ """ +import pytz + +from . import fromtestkit + + def tz_id(**params): # We could do this automatically, but with an explicit black list we # make sure we know what we test and what we don't. 
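# setup.cfg above also turns on `asyncio_mode = auto`: pytest-asyncio will
# collect and run `async def` tests without per-test markers. A minimal,
# hypothetical test module relying on that setting (not part of the patch):
import asyncio


async def test_sleep_yields_control():
    # no @pytest.mark.asyncio marker needed in auto mode
    assert await asyncio.sleep(0) is None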
@@ -51,3 +56,11 @@ def tz_id(**params): return ( "timezone id %s is not supported by the system" % params["tz_id"] ) + + +def dt_conversion(**params): + dt = params["dt"] + try: + fromtestkit.to_param(dt) + except (pytz.UnknownTimeZoneError, ValueError) as e: + return "cannot create desired dt %s: %r" % (dt, e) diff --git a/tests/conftest.py b/tests/conftest.py index 6a62e1129..4bcc74353 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,9 @@ import asyncio +import warnings from functools import wraps from os import environ -import warnings import pytest import pytest_asyncio diff --git a/tests/env.py b/tests/env.py index ea9fb46ee..8e4b077ef 100644 --- a/tests/env.py +++ b/tests/env.py @@ -17,8 +17,8 @@ import abc -from os import environ import sys +from os import environ class _LazyEval(abc.ABC): diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 05bbe8b94..08b6e9c41 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -24,20 +24,22 @@ import pytest -from neo4j._async.io._common import AsyncMessageInbox -from neo4j.packstream import ( - Packer, - UnpackableBuffer, - Unpacker, +from neo4j._async.io._common import ( + AsyncInbox, + AsyncOutbox, ) class AsyncFakeSocket: - def __init__(self, address): + def __init__(self, address, unpacker_cls=None): self.address = address self.captured = b"" - self.messages = AsyncMessageInbox(self, on_error=print) + self.messages = None + if unpacker_cls is not None: + self.messages = AsyncInbox( + self, on_error=print, unpacker_cls=unpacker_cls + ) def getsockname(self): return "127.0.0.1", 0xFFFF @@ -59,16 +61,27 @@ def close(self): return async def pop_message(self): - return await self.messages.pop() + assert self.messages + return await self.messages.pop(None) class AsyncFakeSocket2: - def __init__(self, address=None, on_send=None): + def __init__(self, address=None, on_send=None, + packer_cls=None, unpacker_cls=None): self.address = address self.recv_buffer = bytearray() - self._messages = AsyncMessageInbox(self, on_error=print) + # self.messages = AsyncMessageInbox(self, on_error=print) self.on_send = on_send + self._outbox = self._messages = None + if packer_cls: + self._outbox = AsyncOutbox( + self, on_error=print, packer_cls=packer_cls + ) + if unpacker_cls: + self._messages = AsyncInbox( + self, on_error=print, unpacker_cls=unpacker_cls + ) def getsockname(self): return "127.0.0.1", 0xFFFF @@ -93,50 +106,25 @@ def close(self): def inject(self, data): self.recv_buffer += data - def _pop_chunk(self): - chunk_size, = struct_unpack(">H", self.recv_buffer[:2]) - print("CHUNK SIZE %r" % chunk_size) - end = 2 + chunk_size - chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] - return chunk_data - async def pop_message(self): - data = bytearray() - while True: - chunk = self._pop_chunk() - print("CHUNK %r" % chunk) - if chunk: - data.extend(chunk) - elif data: - break # end of message - else: - continue # NOOP - header = data[0] - n_fields = header % 0x10 - tag = data[1] - buffer = UnpackableBuffer(data[2:]) - unpacker = Unpacker(buffer) - fields = [unpacker.unpack() for _ in range(n_fields)] - return tag, fields + assert self._messages + return await self._messages.pop(None) async def send_message(self, tag, *fields): - data = self.encode_message(tag, *fields) - await self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00") - - @classmethod - def encode_message(cls, tag, *fields): - b = BytesIO() - packer = Packer(b) - for field in 
fields: - packer.pack(field) - return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + assert self._outbox + self._outbox.append_message(tag, fields, None) + await self._outbox.flush() class AsyncFakeSocketPair: - def __init__(self, address): - self.client = AsyncFakeSocket2(address) - self.server = AsyncFakeSocket2() + def __init__(self, address, packer_cls=None, unpacker_cls=None): + self.client = AsyncFakeSocket2( + address, packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) + self.server = AsyncFakeSocket2( + packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) self.client.on_send = self.server.inject self.server.on_send = self.client.inject diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py index bc95738a3..1c14ea202 100644 --- a/tests/unit/async_/io/test__common.py +++ b/tests/unit/async_/io/test__common.py @@ -18,33 +18,43 @@ import pytest -from neo4j._async.io._common import Outbox +from neo4j._async.io._common import AsyncOutbox +from neo4j._codec.packstream.v1 import PackableBuffer + +from ...._async_compat import mark_async_test @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( 2, - (bytes(range(10, 15)),), + bytes(range(10, 15)), bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) ), ( 2, - (bytes(range(10, 14)),), + bytes(range(10, 14)), bytes((0, 2, 10, 11, 0, 2, 12, 13)) ), ( 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + bytes((5,)), + bytes((0, 1, 5)) ), )) -def test_async_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" +@mark_async_test +async def test_async_outbox_chunking(chunk_size, data, result, mocker): + buffer = PackableBuffer() + socket_mock = mocker.AsyncMock() + packer_mock = mocker.Mock() + packer_mock.return_value = packer_mock + packer_mock.new_packable_buffer.return_value = buffer + packer_mock.pack_struct.side_effect = \ + lambda *args, **kwargs: buffer.write(data) + outbox = AsyncOutbox(socket_mock, pytest.fail, packer_mock, chunk_size) + outbox.append_message(None, None, None) + socket_mock.sendall.assert_not_called() + assert await outbox.flush() + socket_mock.sendall.assert_awaited_once_with(result + b"\x00\x00") + + assert not await outbox.flush() + socket_mock.sendall.assert_awaited_once() diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 8644aaded..aa6aac101 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt3 import AsyncBolt3 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -72,7 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_async_test async def test_simple_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() await connection.send_all() @@ -84,7 +84,7 @@ async def test_simple_discard(fake_socket): @mark_async_test async def test_simple_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = 
fake_socket(address) + socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() await connection.send_all() @@ -99,9 +99,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair( + address, AsyncBolt3.PACKER_CLS, AsyncBolt3.UNPACKER_CLS + ) sockets.client.settimeout = mocker.AsyncMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 7ba714e8a..56b15f4d0 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x0 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -125,7 +125,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -145,7 +145,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, 
PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -165,7 +165,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -178,7 +178,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -194,9 +194,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 47cca348e..4371f005e 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x1 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -126,7 +126,7 @@ async def 
test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -146,7 +146,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -167,7 +167,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -180,7 +180,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -193,15 +193,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.1.0"}) connection = AsyncBolt4x1( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.1.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index bb3921c8b..804038cb1 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x2 -from neo4j.conf import PoolConfig +from 
neo4j._conf import PoolConfig from ...._async_compat import mark_async_test @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -126,7 +126,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -146,7 +146,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -167,7 +167,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -180,7 +180,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await 
connection.send_all() @@ -193,15 +193,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.2.0"}) connection = AsyncBolt4x2( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index fff16687e..7538e127b 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -21,7 +21,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x3 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test @@ -59,7 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -72,7 +72,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -87,7 +87,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -107,7 +107,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -128,7 +128,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m 
pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -148,7 +148,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -169,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -182,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -195,15 +195,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.0"}) connection = AsyncBolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -225,10 +227,12 @@ async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) connection = AsyncBolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 5507fbc7f..285aa9744 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -21,7 +21,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x4 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test @@ -68,7 +68,7 @@ def 
test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() @@ -89,7 +89,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() @@ -101,7 +101,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -121,7 +121,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -142,7 +142,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -162,7 +162,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -183,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -196,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await 
connection.send_all() @@ -209,15 +209,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) connection = AsyncBolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -239,10 +241,12 @@ async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() await sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index d68082f49..01d37e463 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -18,13 +18,13 @@ import pytest -from neo4j import ( +from neo4j._async.io import AsyncBolt +from neo4j._async.io._pool import AsyncIOPool +from neo4j._conf import ( Config, PoolConfig, WorkspaceConfig, ) -from neo4j._async.io import AsyncBolt -from neo4j._async.io._pool import AsyncIOPool from neo4j._deadline import Deadline from neo4j.exceptions import ( ClientError, diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index b98fcd73a..44b1931f4 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -16,6 +16,8 @@ # limitations under the License. 
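# The Bolt unit tests above all follow one mechanical migration: fake sockets
# are now built with the protocol version's PACKER_CLS/UNPACKER_CLS, and
# message tags travel as single bytes (b"\x70") rather than ints (0x70).
# A condensed, hypothetical test showing the shape, assuming the
# fake_socket_pair fixture from the conftest shown earlier:
from neo4j._async.io._bolt4 import AsyncBolt4x4
from neo4j._conf import PoolConfig


async def test_hello_tag_roundtrip(fake_socket_pair):
    address = ("127.0.0.1", 7687)
    sockets = fake_socket_pair(
        address,
        packer_cls=AsyncBolt4x4.PACKER_CLS,
        unpacker_cls=AsyncBolt4x4.UNPACKER_CLS,
    )
    # server answers HELLO with SUCCESS (tag 0x70, sent as bytes)
    await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"})
    connection = AsyncBolt4x4(
        address, sockets.client, PoolConfig.max_connection_lifetime
    )
    await connection.hello()
    tag, fields = await sockets.server.pop_message()
    assert tag == b"\x01"  # HELLO is now compared as bytes, too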
+import inspect + import pytest from neo4j import ( @@ -23,20 +25,20 @@ WRITE_ACCESS, ) from neo4j._async.io import AsyncNeo4jPool -from neo4j._deadline import Deadline -from neo4j.addressing import ResolvedAddress -from neo4j.conf import ( +from neo4j._conf import ( PoolConfig, RoutingConfig, WorkspaceConfig, ) +from neo4j._deadline import Deadline +from neo4j.addressing import ResolvedAddress from neo4j.exceptions import ( ServiceUnavailable, SessionExpired, ) from ...._async_compat import mark_async_test -from ..work import async_fake_connection_generator +from ..work import async_fake_connection_generator # needed as fixture ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -44,7 +46,7 @@ WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") -@pytest.fixture() +@pytest.fixture def opener(async_fake_connection_generator, mocker): async def open_(addr, timeout): connection = async_fake_connection_generator() @@ -156,11 +158,13 @@ async def test_reuses_connection(opener): @pytest.mark.parametrize("break_on_close", (True, False)) @mark_async_test async def test_closes_stale_connections(opener, break_on_close): - def break_connection(): - pool.deactivate(cx1.addr) + async def break_connection(): + await pool.deactivate(cx1.addr) if cx_close_mock_side_effect: - cx_close_mock_side_effect() + res = cx_close_mock_side_effect() + if inspect.isawaitable(res): + return await res pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS @@ -242,8 +246,8 @@ async def test_release_does_not_resets_closed_connections(opener): cx1.is_reset_mock.reset_mock() await pool.release(cx1) cx1.closed.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @mark_async_test @@ -257,8 +261,8 @@ async def test_release_does_not_resets_defunct_connections(opener): cx1.is_reset_mock.reset_mock() await pool.release(cx1) cx1.defunct.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) @@ -271,7 +275,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( ) cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) assert cx1.addr == READER_ADDRESS - cx1.reset.asset_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index aaadb43b9..d263ca23e 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -17,6 +17,7 @@ import ssl +from functools import wraps import pytest @@ -31,6 +32,7 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._async_compat.util import AsyncUtil from neo4j.api import ( READ_ACCESS, WRITE_ACCESS, @@ -40,6 +42,21 @@ from ..._async_compat import mark_async_test +@wraps(AsyncGraphDatabase.driver) +def create_driver(*args, **kwargs): + if AsyncUtil.is_async_code: + with pytest.warns(ExperimentalWarning, match="async") as warnings: + driver = AsyncGraphDatabase.driver(*args, **kwargs) + print(warnings) + return driver + else: + return AsyncGraphDatabase.driver(*args, **kwargs) + + +def driver(*args, **kwargs): + return AsyncNeo4jDriver(*args, **kwargs) + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", 
"127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -53,7 +70,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke with pytest.warns(DeprecationWarning, match="routing context"): driver = AsyncGraphDatabase.driver(uri, auth=auth_token) else: - driver = AsyncGraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, AsyncBoltDriver) await driver.close() @@ -68,7 +85,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke @mark_async_test async def test_routing_driver_constructor(protocol, host, port, params, auth_token): uri = protocol + host + port + params - driver = AsyncGraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, AsyncNeo4jDriver) await driver.close() @@ -128,13 +145,20 @@ async def test_routing_driver_constructor(protocol, host, port, params, auth_tok async def test_driver_config_error( test_uri, test_config, expected_failure, expected_failure_message ): + def driver_builder(): + if "trust" in test_config: + with pytest.warns(DeprecationWarning, match="trust"): + return AsyncGraphDatabase.driver(test_uri, **test_config) + else: + return create_driver(test_uri, **test_config) + if "+" in test_uri: # `+s` and `+ssc` are short hand syntax for not having to configure the # encryption behavior of the driver. Specifying both is invalid. with pytest.raises(expected_failure, match=expected_failure_message): - AsyncGraphDatabase.driver(test_uri, **test_config) + driver_builder() else: - driver = AsyncGraphDatabase.driver(test_uri, **test_config) + driver = driver_builder() await driver.close() @@ -145,7 +169,7 @@ async def test_driver_config_error( )) def test_invalid_protocol(test_uri): with pytest.raises(ConfigurationError, match="scheme"): - AsyncGraphDatabase.driver(test_uri) + create_driver(test_uri) @pytest.mark.parametrize( @@ -160,7 +184,7 @@ def test_driver_trust_config_error( test_config, expected_failure, expected_failure_message ): with pytest.raises(expected_failure, match=expected_failure_message): - AsyncGraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) + create_driver("bolt://127.0.0.1:9001", **test_config) @pytest.mark.parametrize("uri", ( @@ -169,7 +193,7 @@ def test_driver_trust_config_error( )) @mark_async_test async def test_driver_opens_write_session_by_default(uri, mocker): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) from neo4j import AsyncTransaction # we set a specific db, because else the driver would try to fetch a RT @@ -208,7 +232,7 @@ async def test_driver_opens_write_session_by_default(uri, mocker): )) @mark_async_test async def test_verify_connectivity(uri, mocker): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: @@ -232,10 +256,10 @@ async def test_verify_connectivity(uri, mocker): {"fetch_size": 69}, )) @mark_async_test -async def test_verify_connectivity_parameters_are_experimental( +async def test_verify_connectivity_parameters_are_deprecated( uri, kwargs, mocker ): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: @@ -258,7 +282,7 @@ async def test_verify_connectivity_parameters_are_experimental( async def test_get_server_info_parameters_are_experimental( uri, kwargs, mocker ): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) 
mocker.patch.object(driver, "_pool", autospec=True) try: diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 30ffe5832..8981ef327 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -14,9 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from re import match -from unittest import mock + + import warnings +from unittest import mock import pandas as pd import pytest @@ -25,6 +26,7 @@ from neo4j import ( Address, AsyncResult, + ExperimentalWarning, Record, ResultSummary, ServerInfo, @@ -33,9 +35,9 @@ Version, ) from neo4j._async_compat.util import AsyncUtil -from neo4j.data import ( - DataDehydrator, - DataHydrator, +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j._data import ( Node, Relationship, ) @@ -44,7 +46,6 @@ EntitySetView, Graph, ) -from neo4j.packstream import Structure from ...._async_compat import mark_async_test @@ -52,9 +53,24 @@ class Records: def __init__(self, fields, records): self.fields = tuple(fields) + self.hydration_scope = HydrationHandler().new_hydration_scope() self.records = tuple(records) + self._hydrate_records() + assert all(len(self.fields) == len(r) for r in self.records) + def _hydrate_records(self): + def _hydrate(value): + if type(value) in self.hydration_scope.hydration_hooks: + return self.hydration_scope.hydration_hooks[type(value)](value) + if isinstance(value, (list, tuple)): + return type(value)(_hydrate(v) for v in value) + if isinstance(value, dict): + return {k: _hydrate(v) for k, v in value.items()} + return value + + self.records = tuple(_hydrate(r) for r in self.records) + def __len__(self): return self.records.__len__() @@ -113,6 +129,7 @@ def __init__(self, records=None, run_meta=None, summary_meta=None, self.summary_meta = summary_meta AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) self.unresolved_address = None + self._new_hydration_scope_called = False async def send_all(self): self.sent += self.queued @@ -187,10 +204,20 @@ def pull(self, *args, **kwargs): def defunct(self): return False + def new_hydration_scope(self): + class FakeHydrationScope: + hydration_hooks = None + dehydration_hooks = None -class HydratorStub(DataHydrator): - def hydrate(self, values): - return values + def get_graph(self): + return Graph() + + if len(self._records) > 1: + return FakeHydrationScope() + assert not self._new_hydration_scope_called + assert self._records + self._new_hydration_scope_called = True + return self._records[0].hydration_scope def noop(*_, **__): @@ -254,7 +281,7 @@ async def fetch_and_compare_all_records( @mark_async_test async def test_result_iteration(method, records): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result = AsyncResult(connection, 2, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) await fetch_and_compare_all_records(result, "x", records, method) @@ -263,7 +290,7 @@ async def test_result_iteration(method, records): async def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), 4, noop, noop) + result = AsyncResult(connection, 4, noop, noop) await result._run("CYPHER", {}, 
None, None, "r", None) iter1 = AsyncUtil.iter(result) iter2 = AsyncUtil.iter(result) @@ -299,9 +326,9 @@ async def test_parallel_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result1 = AsyncResult(connection, 2, noop, noop) await result1._run("CYPHER1", {}, None, None, "r", None) - result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result2 = AsyncResult(connection, 2, noop, noop) await result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: await fetch_and_compare_all_records( @@ -329,9 +356,9 @@ async def test_interwoven_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result1 = AsyncResult(connection, 2, noop, noop) await result1._run("CYPHER1", {}, None, None, "r", None) - result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result2 = AsyncResult(connection, 2, noop, noop) await result2._run("CYPHER2", {}, None, None, "r", None) start = 0 for n in (1, 2, 3, 1, None): @@ -358,7 +385,7 @@ async def test_interwoven_result_iteration(method, invert_fetch): @mark_async_test async def test_result_peek(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) for i in range(len(records) + 1): record = await result.peek() @@ -381,7 +408,7 @@ async def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) if len(records) == 0: assert await result.single(**kwargs) is None @@ -400,7 +427,7 @@ async def test_result_single_non_strict(records, fetch_size, default): @mark_async_test async def test_result_single_strict(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) try: record = await result.single(strict=True) @@ -427,7 +454,7 @@ async def test_result_single_strict(records, fetch_size): @mark_async_test async def test_result_single_exhausts_records(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) try: with warnings.catch_warnings(): @@ -449,7 +476,7 @@ async def test_result_single_exhausts_records(records, fetch_size, strict): @mark_async_test async def test_result_fetch(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) assert await result.fetch(0) == [] assert await 
result.fetch(-1) == [] @@ -461,7 +488,7 @@ async def test_result_fetch(records, fetch_size, strict): @mark_async_test async def test_keys_are_available_before_and_after_stream(): connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] await AsyncUtil.list(result) @@ -477,7 +504,7 @@ async def test_consume(records, consume_one, summary_meta, consume_times): connection = AsyncConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: @@ -512,7 +539,7 @@ async def test_time_in_summary(t_first, t_last): summary_meta=summary_meta ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -534,7 +561,7 @@ async def test_time_in_summary(t_first, t_last): async def test_counts_in_summary(): connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]])) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -548,7 +575,7 @@ async def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -563,7 +590,7 @@ async def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) await result._buffer_all() records = result._record_buffer.copy() @@ -578,6 +605,7 @@ async def test_data(num_records): assert record.data.called_once_with("hello", "world") +# TODO: dehydration now happens on a much lower level @pytest.mark.parametrize("records", ( Records(["n"], []), Records(["n"], [[42], [69], [420], [1337]]), @@ -603,8 +631,9 @@ async def test_result_graph(records, async_scripted_connection): "on_summary": None }), )) - result = AsyncResult(async_scripted_connection, DataHydrator(), 1, noop, - noop) + async_scripted_connection.new_hydration_scope.return_value = \ + records.hydration_scope + result = AsyncResult(async_scripted_connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) graph = await result.graph() assert isinstance(graph, Graph) @@ -702,12 +731,13 @@ async def test_result_graph(records, async_scripted_connection): @mark_async_test async def test_to_df(keys, values, types, instances, test_default_expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - if test_default_expand: - df = await result.to_df() - else: - df = await result.to_df(expand=False) + with pytest.warns(ExperimentalWarning, match="pandas"): + if 
test_default_expand: + df = await result.to_df() + else: + df = await result.to_df(expand=False) assert isinstance(df, pd.DataFrame) assert df.keys().to_list() == keys @@ -807,12 +837,12 @@ async def test_to_df(keys, values, types, instances, test_default_expand): ( ["n"], list(zip(( - Structure(b"N", 0, ["LABEL_A"], - {"a": 1, "b": 2, "d": 1}, "00"), - Structure(b"N", 2, ["LABEL_B"], - {"a": 1, "c": 1.2, "d": 2}, "02"), - Structure(b"N", 1, ["LABEL_A", "LABEL_B"], - {"a": [1, "a"], "d": 3}, "01"), + Node(None, "00", 0, ["LABEL_A"], + {"a": 1, "b": 2, "d": 1}), + Node(None, "02", 2, ["LABEL_B"], + {"a": 1, "c": 1.2, "d": 2}), + Node(None, "01", 1, ["LABEL_A", "LABEL_B"], + {"a": [1, "a"], "d": 3}), ))), [ "n().element_id", "n().labels", "n().prop.a", "n().prop.b", @@ -848,11 +878,7 @@ async def test_to_df(keys, values, types, instances, test_default_expand): ), ( ["dt"], - [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - ], + [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["dt"], [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["object"], @@ -863,9 +889,10 @@ async def test_to_df(keys, values, types, instances, test_default_expand): async def test_to_df_expand(keys, values, expected_columns, expected_rows, expected_types): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - df = await result.to_df(expand=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = await result.to_df(expand=True) assert isinstance(df, pd.DataFrame) assert len(set(expected_columns)) == len(expected_columns) @@ -895,9 +922,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["dt"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], ], pd.DataFrame( [[pd.Timestamp("2022-01-02 03:04:05.000000006")]], @@ -908,9 +933,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["d"], [ - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.Date(2222, 2, 22)], ], pd.DataFrame( [[pd.Timestamp("2222-02-22")]], @@ -921,11 +944,11 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["dt_tz"], [ - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [[ @@ -941,17 +964,13 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ["mixed"], [ [None], - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -971,18 +990,14 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], [None], - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( 
neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -1002,17 +1017,13 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),], + [neo4j_time.Date(2222, 2, 22),], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], [None], ], pd.DataFrame( @@ -1052,9 +1063,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ], [ None, - *DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), 1.234, ], ], @@ -1080,8 +1089,9 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, @mark_async_test async def test_to_df_parse_dates(keys, values, expected_df, expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - df = await result.to_df(expand=expand, parse_dates=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = await result.to_df(expand=expand, parse_dates=True) pd.testing.assert_frame_equal(df, expected_df) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 3dcb03828..117cb5cbf 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -25,26 +25,34 @@ AsyncSession, AsyncTransaction, Bookmarks, - SessionConfig, unit_of_work, ) from neo4j._async.io._pool import AsyncIOPool +from neo4j._conf import SessionConfig from ...._async_compat import mark_async_test -from ._fake_connection import async_fake_connection_generator @pytest.fixture() def pool(async_fake_connection_generator, mocker): pool = mocker.AsyncMock(spec=AsyncIOPool) - pool.acquire.side_effect = iter(async_fake_connection_generator, 0) + assert not hasattr(pool, "acquired_connection_mocks") + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + connection = async_fake_connection_generator() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect return pool @mark_async_test async def test_session_context_calls_close(mocker): s = AsyncSession(None, SessionConfig()) - mock_close = mocker.patch.object(s, 'close', autospec=True) + mock_close = mocker.patch.object(s, 'close', autospec=True, + side_effect=s.close) async with s: pass mock_close.assert_called_once_with() @@ -195,9 +203,12 @@ async def test_session_returns_bookmarks_directly(pool, bookmark_values): ) @mark_async_test async def test_session_last_bookmark_is_deprecated(pool, bookmarks): - async with AsyncSession(pool, SessionConfig( - bookmarks=bookmarks - )) as session: + if bookmarks is not None: + with pytest.warns(DeprecationWarning): + session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + else: + session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + async with session: with pytest.warns(DeprecationWarning): if bookmarks: assert (await session.last_bookmark()) == bookmarks[-1] @@ -267,57 +278,46 @@ async def test_session_tx_type(pool): assert isinstance(tx, AsyncTransaction) -@pytest.mark.parametrize(("parameters", 
"error_type"), ( - ({"x": None}, None), - ({"x": True}, None), - ({"x": False}, None), - ({"x": 123456789}, None), - ({"x": 3.1415926}, None), - ({"x": float("nan")}, None), - ({"x": float("inf")}, None), - ({"x": float("-inf")}, None), - ({"x": "foo"}, None), - ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None), - ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None), - ({"x": [1, 2, 3]}, None), - ({"x": ["a", "b", "c"]}, None), - ({"x": ["a", 2, 1.234]}, None), - ({"x": ["a", 2, ["c"]]}, None), - ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None), - ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None), - - # maps must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), +@pytest.mark.parametrize("parameters", ( + {"x": None}, + {"x": True}, + {"x": False}, + {"x": 123456789}, + {"x": 3.1415926}, + {"x": float("nan")}, + {"x": float("inf")}, + {"x": float("-inf")}, + {"x": "foo"}, + {"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, + {"x": b"\x00\x33\x66\x99\xcc\xff"}, + {"x": [1, 2, 3]}, + {"x": ["a", "b", "c"]}, + {"x": ["a", 2, 1.234]}, + {"x": ["a", 2, ["c"]]}, + {"x": {"one": "eins", "two": "zwei", "three": "drei"}}, + {"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, )) @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) @mark_async_test async def test_session_run_with_parameters( - pool, parameters, error_type, run_type + pool, parameters, run_type, mocker ): - @contextmanager - def raises(): - if error_type is not None: - with pytest.raises(error_type) as exc: - yield exc - else: - yield None - async with AsyncSession(pool, SessionConfig()) as session: if run_type == "auto": - with raises(): - await session.run("RETURN $x", **parameters) + await session.run("RETURN $x", **parameters) elif run_type == "unmanaged": tx = await session.begin_transaction() - with raises(): - await tx.run("RETURN $x", **parameters) + await tx.run("RETURN $x", **parameters) elif run_type == "managed": async def work(tx): - with raises() as exc: - await tx.run("RETURN $x", **parameters) - if exc is not None: - raise exc - with raises(): - await session.write_transaction(work) + await tx.run("RETURN $x", **parameters) + await session.write_transaction(work) else: raise ValueError(run_type) + + assert len(pool.acquired_connection_mocks) == 1 + connection_mock = pool.acquired_connection_mocks[0] + assert connection_mock.run.called_once() + call = connection_mock.run.call_args + assert call.args[0] == "RETURN $x" + assert call.kwargs["parameters"] == parameters diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 86e968cf1..7fa36ab76 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -113,23 +113,6 @@ class OopsError(RuntimeError): assert tx_.closed() -@pytest.mark.parametrize(("parameters", "error_type"), ( - # maps must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), - ({"x": uuid4()}, TypeError), -)) -@mark_async_test -async def test_transaction_run_with_invalid_parameters( - async_fake_connection, parameters, error_type -): - on_closed = MagicMock() - on_error = MagicMock() - tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) - with pytest.raises(error_type): - await tx.run("RETURN $x", **parameters) - - @mark_async_test async def 
test_transaction_run_takes_no_query_object(async_fake_connection): on_closed = MagicMock() diff --git a/tests/unit/common/codec/__init__.py b/tests/unit/common/codec/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/__init__.py b/tests/unit/common/codec/hydration/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/_common.py b/tests/unit/common/codec/hydration/_common.py new file mode 100644 index 000000000..a0c924c62 --- /dev/null +++ b/tests/unit/common/codec/hydration/_common.py @@ -0,0 +1,71 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
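+# Shared scaffolding for the hydration tests: the base class below runs its
+# assertions against both protocol handlers via a parametrized fixture.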
+ + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +import pytest + +from neo4j._codec.hydration import HydrationScope +from neo4j._codec.hydration.v1 import HydrationHandler as HydrationHandlerV1 +from neo4j._codec.hydration.v2 import HydrationHandler as HydrationHandlerV2 +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + + +class HydrationHandlerTestBase: + + @pytest.fixture(params=[HydrationHandlerV1, HydrationHandlerV2]) + def hydration_handler(self, request): + return request.param() + + def test_handler_hydration_scope(self, hydration_handler): + scope = hydration_handler.new_hydration_scope() + assert isinstance(scope, HydrationScope) + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() + + def test_scope_hydration_keys(self, hydration_scope): + hooks = hydration_scope.hydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == {Structure} + + def test_scope_dehydration_keys(self, hydration_scope): + hooks = hydration_scope.dehydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == { + date, datetime, time, timedelta, + Date, DateTime, Duration, Time, + CartesianPoint, Point, WGS84Point + } diff --git a/tests/unit/common/codec/hydration/v1/__init__.py b/tests/unit/common/codec/hydration/v1/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/v1/_base.py b/tests/unit/common/codec/hydration/v1/_base.py new file mode 100644 index 000000000..4400d3480 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/_base.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
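+# Per-version base class: concrete test modules override the
+# hydration_handler fixture; hydration_scope is derived from it.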
+ + +import pytest + + +class HydrationHandlerTestBase: + @pytest.fixture() + def hydration_handler(self): + raise NotImplementedError() + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() diff --git a/tests/unit/common/codec/hydration/v1/test_graph_hydration.py b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py new file mode 100644 index 000000000..3307c7106 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.graph import ( + Graph, + Node, + Relationship, +) + +from ._base import HydrationHandlerTestBase + + +class TestGraphHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_can_hydrate_node_structure(self, hydration_scope): + struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) + alice = hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(alice, Node) + with pytest.warns(DeprecationWarning, match="element_id"): + assert alice.id == 123 + # for backwards compatibility, the driver should compute the element_id + assert alice.element_id == "123" + assert alice.labels == {"Person"} + assert set(alice.keys()) == {"name"} + assert alice.get("name") == "Alice" + + def test_can_hydrate_relationship_structure(self, hydration_scope): + struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}) + rel = hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(rel, Relationship) + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.id == 123 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.start_node.id == 456 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.end_node.id == 789 + # for backwards compatibility, the driver should compute the element_id + assert rel.element_id == "123" + assert rel.start_node.element_id == "456" + assert rel.end_node.element_id == "789" + assert rel.type == "KNOWS" + assert set(rel.keys()) == {"since"} + assert rel.get("since") == 1999 diff --git a/tests/unit/common/codec/hydration/v1/test_hydration_handler.py b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py new file mode 100644 index 000000000..eccc10e18 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py @@ -0,0 +1,78 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +import pytest + +from neo4j._codec.hydration import HydrationScope +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.graph import Graph +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestHydrationHandler(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_handler_hydration_scope(self, hydration_handler): + scope = hydration_handler.new_hydration_scope() + assert isinstance(scope, HydrationScope) + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() + + def test_scope_hydration_keys(self, hydration_scope): + hooks = hydration_scope.hydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == {Structure} + + def test_scope_dehydration_keys(self, hydration_scope): + hooks = hydration_scope.dehydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == { + date, datetime, time, timedelta, + Date, DateTime, Duration, Time, + CartesianPoint, Point, WGS84Point + } + + def test_scope_get_graph(self, hydration_scope): + graph = hydration_scope.get_graph() + assert isinstance(graph, Graph) + assert not graph.nodes + assert not graph.relationships diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py new file mode 100644 index 000000000..6486cea52 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py @@ -0,0 +1,73 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
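+# Dehydration direction: driver-side Point values are packed into PackStream
+# structures, tag b"X" for 2D and b"Y" for 3D, with the SRID as first field.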
+ + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) + +from ._base import HydrationHandlerTestBase + + +class TestSpatialDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_cartesian_2d(self, hydration_scope): + point = CartesianPoint((1, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 7203, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_cartesian_3d(self, hydration_scope): + point = CartesianPoint((1, -2, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 9157, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_wgs84_2d(self, hydration_scope): + point = WGS84Point((1, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 4326, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_wgs84_3d(self, hydration_scope): + point = WGS84Point((1, -2, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 4979, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_custom_point_2d(self, hydration_scope): + point = Point((1, 3.1)) + point.srid = 12345 + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 12345, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_custom_point_3d(self, hydration_scope): + point = Point((1, -2, 3.1)) + point.srid = 12345 + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 12345, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py new file mode 100644 index 000000000..ef4fad6b8 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py @@ -0,0 +1,77 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
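+# Hydration direction: point structures are mapped back to CartesianPoint or
+# WGS84Point based on their SRID; unrecognized SRIDs yield a plain Point.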
+ + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) + +from ._base import HydrationHandlerTestBase + + +class TestSpatialHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_cartesian_2d(self, hydration_scope): + struct = Structure(b"X", 7203, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, CartesianPoint) + assert point.srid == 7203 + assert tuple(point) == (1.0, 3.1) + + def test_cartesian_3d(self, hydration_scope): + struct = Structure(b"Y", 9157, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, CartesianPoint) + assert point.srid == 9157 + assert tuple(point) == (1.0, -2.0, 3.1) + + def test_wgs84_2d(self, hydration_scope): + struct = Structure(b"X", 4326, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, WGS84Point) + assert point.srid == 4326 + assert tuple(point) == (1.0, 3.1) + + def test_wgs84_3d(self, hydration_scope): + struct = Structure(b"Y", 4979, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, WGS84Point) + assert point.srid == 4979 + assert tuple(point) == (1.0, -2.0, 3.1) + + def test_custom_point_2d(self, hydration_scope): + struct = Structure(b"X", 12345, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, Point) + assert point.srid == 12345 + assert tuple(point) == (1.0, 3.1) + + def test_custom_point_3d(self, hydration_scope): + struct = Structure(b"Y", 12345, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, Point) + assert point.srid == 12345 + assert tuple(point) == (1.0, -2.0, 3.1) diff --git a/tests/unit/common/codec/hydration/v1/test_time_dehydration.py b/tests/unit/common/codec/hydration/v1/test_time_dehydration.py new file mode 100644 index 000000000..8315d7081 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_time_dehydration.py @@ -0,0 +1,193 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
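+# Temporal dehydration for hydration v1: Date -> b"D", Time -> b"T"/b"t",
+# DateTime -> b"F"/b"f"/b"d" (offset/zone id/local), Duration -> b"E".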
+ + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestTimeDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_date(self, hydration_scope): + date = Date(1991, 8, 24) + struct = hydration_scope.dehydration_hooks[type(date)](date) + assert struct == Structure(b"D", 7905) + + def test_native_date(self, hydration_scope): + date = datetime.date(1991, 8, 24) + struct = hydration_scope.dehydration_hooks[type(date)](date) + assert struct == Structure(b"D", 7905) + + def test_time(self, hydration_scope): + time = Time(1, 2, 3, 4, pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"T", 3723000000004, 3600) + + def test_native_time(self, hydration_scope): + time = datetime.time(1, 2, 3, 4, pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"T", 3723000004000, 3600) + + def test_local_time(self, hydration_scope): + time = Time(1, 2, 3, 4) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"t", 3723000000004) + + def test_local_native_time(self, hydration_scope): + time = datetime.time(1, 2, 3, 4) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"t", 3723000004000) + + def test_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716862, 3600) + + def test_native_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716000, 3600) + + def test_date_time_negative_offset(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716862, -3600) + + def test_native_date_time_negative_offset(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716000, -3600) + + def test_date_time_zone_id(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"f", 1539344261, 474716862, + "Europe/Stockholm") + + def test_native_date_time_zone_id(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"f", 1539344261, 474716000, + "Europe/Stockholm") + + def test_local_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"d", 1539344261, 474716862) + + def test_native_local_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 
41, 474716) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"d", 1539344261, 474716000) + + def test_duration(self, hydration_scope): + duration = Duration(months=1, days=2, seconds=3, nanoseconds=4) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 1, 2, 3, 4) + + def test_native_duration(self, hydration_scope): + duration = datetime.timedelta(days=1, seconds=2, microseconds=3) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 0, 1, 2, 3000) + + def test_duration_mixed_sign(self, hydration_scope): + duration = Duration(months=1, days=-2, seconds=3, nanoseconds=4) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 1, -2, 3, 4) + + def test_native_duration_mixed_sign(self, hydration_scope): + duration = datetime.timedelta(days=-1, seconds=2, microseconds=3) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 0, -1, 2, 3000) + + +class TestUTCPatchedTimeDehydration(TestTimeDehydration): + @pytest.fixture + def hydration_handler(self): + handler = HydrationHandler() + handler.patch_utc() + return handler + + def test_date_time(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time( + hydration_scope + ) + + def test_native_date_time(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time( + hydration_scope + ) + + def test_date_time_negative_offset(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time_negative_offset( + hydration_scope + ) + + def test_native_date_time_negative_offset(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time_negative_offset( + hydration_scope + ) + + def test_date_time_zone_id(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time_zone_id( + hydration_scope + ) + + def test_native_date_time_zone_id(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time_zone_id( + hydration_scope + ) diff --git a/tests/unit/common/codec/hydration/v1/test_time_hydration.py b/tests/unit/common/codec/hydration/v1/test_time_hydration.py new file mode 100644 index 000000000..3c04c253f --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_time_hydration.py @@ -0,0 +1,167 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import pytz + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestTimeHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_hydrate_date_structure(self, hydration_scope): + struct = Structure(b"D", 7905) + d = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(d, Date) + assert d.year == 1991 + assert d.month == 8 + assert d.day == 24 + + def test_hydrate_time_structure(self, hydration_scope): + struct = Structure(b"T", 3723000000004, 3600) + t = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(t, Time) + assert t.hour == 1 + assert t.minute == 2 + assert t.second == 3 + assert t.nanosecond == 4 + assert t.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_local_time_structure(self, hydration_scope): + struct = Structure(b"t", 3723000000004) + t = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(t, Time) + assert t.hour == 1 + assert t.minute == 2 + assert t.second == 3 + assert t.nanosecond == 4 + assert t.tzinfo is None + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + struct = Structure(b"F", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + assert dt.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + struct = Structure(b"I", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + struct = Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + tz = pytz.timezone("Europe/Stockholm") \ + .localize(dt.replace(tzinfo=None)).tzinfo + assert dt.tzinfo == tz + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + struct = Structure(b"i", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_local_date_time_structure(self, hydration_scope): + struct = Structure(b"d", 1539344261, 474716862) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + assert dt.tzinfo is None + + def test_hydrate_duration_structure(self, hydration_scope): + struct = Structure(b"E", 1, 2, 3, 4) + d = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(d, Duration) + assert d.months == 1 + assert d.days == 2 + assert d.seconds == 3 + assert d.nanoseconds 
== 4 + + +class TestUTCPatchedTimeHydration(TestTimeHydration): + @pytest.fixture + def hydration_handler(self): + handler = HydrationHandler() + handler.patch_utc() + return handler + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_structure_v1( + hydration_scope + ) + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_structure_v2( + hydration_scope + ) + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_zone_id_structure_v1( + hydration_scope + ) + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_zone_id_structure_v2( + hydration_scope + ) diff --git a/tests/unit/common/codec/hydration/v2/__init__.py b/tests/unit/common/codec/hydration/v2/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/v2/test_graph_hydration.py b/tests/unit/common/codec/hydration/v2/test_graph_hydration.py new file mode 100644 index 000000000..1e3e41d2b --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_graph_hydration.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
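+# Hydration v2 (Bolt 5+): node and relationship structures carry the
+# server-assigned element IDs as additional string fields.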
+
+
+import pytest
+
+from neo4j._codec.hydration.v2 import HydrationHandler
+from neo4j._codec.packstream import Structure
+from neo4j.graph import (
+    Graph,
+    Node,
+    Relationship,
+)
+
+from ..v1.test_graph_hydration import TestGraphHydration as _TestGraphHydration
+
+
+class TestGraphHydration(_TestGraphHydration):
+    @pytest.fixture
+    def hydration_handler(self):
+        return HydrationHandler()
+
+    def test_can_hydrate_node_structure(self, hydration_scope):
+        struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}, "abc")
+        alice = hydration_scope.hydration_hooks[Structure](struct)
+
+        assert isinstance(alice, Node)
+        with pytest.warns(DeprecationWarning, match="element_id"):
+            assert alice.id == 123
+        assert alice.element_id == "abc"
+        assert alice.labels == {"Person"}
+        assert set(alice.keys()) == {"name"}
+        assert alice.get("name") == "Alice"
+
+    def test_can_hydrate_relationship_structure(self, hydration_scope):
+        struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999},
+                           "abc", "def", "ghi")
+        rel = hydration_scope.hydration_hooks[Structure](struct)
+
+        assert isinstance(rel, Relationship)
+        with pytest.warns(DeprecationWarning, match="element_id"):
+            assert rel.id == 123
+        with pytest.warns(DeprecationWarning, match="element_id"):
+            assert rel.start_node.id == 456
+        with pytest.warns(DeprecationWarning, match="element_id"):
+            assert rel.end_node.id == 789
+        # in v2, the element IDs are sent by the server, not computed
+        assert rel.element_id == "abc"
+        assert rel.start_node.element_id == "def"
+        assert rel.end_node.element_id == "ghi"
+        assert rel.type == "KNOWS"
+        assert set(rel.keys()) == {"since"}
+        assert rel.get("since") == 1999
diff --git a/tests/unit/common/codec/hydration/v2/test_hydration_handler.py b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py
new file mode 100644
index 000000000..c28379ea6
--- /dev/null
+++ b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py
@@ -0,0 +1,31 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._codec.hydration.v2 import HydrationHandler
+
+from ..v1.test_hydration_handler import (
+    TestHydrationHandler as TestHydrationHandlerV1,
+)
+
+
+class TestHydrationHandler(TestHydrationHandlerV1):
+    @pytest.fixture
+    def hydration_handler(self):
+        return HydrationHandler()
diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py
new file mode 100644
index 000000000..85349dc50
--- /dev/null
+++ b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py
@@ -0,0 +1,31 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_spacial_dehydration import ( + TestSpatialDehydration as _TestSpatialDehydrationV1, +) + + +class TestSpatialDehydration(_TestSpatialDehydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py new file mode 100644 index 000000000..d905965ca --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py @@ -0,0 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_spacial_hydration import ( + TestSpatialHydration as _TestSpatialHydrationV1, +) + + +class TestSpatialHydration(_TestSpatialHydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_time_dehydration.py b/tests/unit/common/codec/hydration/v2/test_time_dehydration.py new file mode 100644 index 000000000..021db2eb4 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_time_dehydration.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
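The dehydration tests below pin down the Bolt 5 temporal encoding: Structure(b"I", <UTC seconds>, <nanoseconds>, <offset seconds>) for fixed-offset datetimes, and b"i" with a zone name in place of the offset. Assuming the first field is plain UTC epoch seconds (which is what the expected values work out to), the numbers can be re-derived with the standard library alone (illustrative):

import datetime

# 2018-10-12 11:37:41 at UTC+01:00 is 10:37:41 UTC.
local = datetime.datetime(
    2018, 10, 12, 11, 37, 41,
    tzinfo=datetime.timezone(datetime.timedelta(hours=1)),
)
assert int(local.timestamp()) == 1539340661  # first field of the b"I" struct
# The same wall-clock time at UTC-01:00 lies two hours later as an instant:
assert 1539340661 + 2 * 3600 == 1539347861
# Nanoseconds (474716862) and the offset (3600) travel as separate fields.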
+ + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v2 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import DateTime + +from ..v1.test_time_dehydration import ( + TestTimeDehydration as _TestTimeDehydrationV1, +) + + +class TestTimeDehydration(_TestTimeDehydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539340661, 474716862, 3600) + + def test_native_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539340661, 474716000, 3600) + + def test_date_time_negative_offset(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539347861, 474716862, -3600) + + def test_native_date_time_negative_offset(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539347861, 474716000, -3600) + + def test_date_time_zone_id(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"i", 1539339941, 474716862, + "Europe/Stockholm") + + def test_native_date_time_zone_id(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"i", 1539339941, 474716000, + "Europe/Stockholm") diff --git a/tests/unit/common/codec/hydration/v2/test_time_hydration.py b/tests/unit/common/codec/hydration/v2/test_time_hydration.py new file mode 100644 index 000000000..7fe308ec0 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_time_hydration.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
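Hydration runs the same mapping in reverse; note that within a v2 scope the legacy v1 tags b"F" and b"f" deliberately come back as raw Structure objects instead of datetimes. For the zone-id case, the expected local values follow from converting the UTC instant into the named zone; with pytz, which these tests already use, that is (illustrative):

import datetime
import pytz

# Structure(b"i", 1539344261, ...) carries 2018-10-12 11:37:41 UTC.
utc_dt = pytz.utc.localize(datetime.datetime.utcfromtimestamp(1539344261))
stockholm = utc_dt.astimezone(pytz.timezone("Europe/Stockholm"))
assert (stockholm.hour, stockholm.minute) == (13, 37)  # CEST, UTC+2 that day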
+ + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v2 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import DateTime + +from ..v1.test_time_hydration import TestTimeHydration as _TestTimeHydrationV1 + + +class TestTimeHydration(_TestTimeHydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + struct = Structure(b"F", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + struct = Structure(b"I", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 12 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + assert dt.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + struct = Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + struct = Structure(b"i", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 13 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + tz = pytz.timezone("Europe/Stockholm") \ + .localize(dt.replace(tzinfo=None)).tzinfo + assert dt.tzinfo == tz diff --git a/tests/unit/common/codec/packstream/__init__.py b/tests/unit/common/codec/packstream/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/packstream/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/packstream/v1/__init__.py b/tests/unit/common/codec/packstream/v1/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/packstream/v1/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/packstream/v1/test_packstream.py b/tests/unit/common/codec/packstream/v1/test_packstream.py new file mode 100644 index 000000000..14f8fcfb5 --- /dev/null +++ b/tests/unit/common/codec/packstream/v1/test_packstream.py @@ -0,0 +1,326 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import struct +from io import BytesIO +from math import pi +from uuid import uuid4 + +import pytest + +from neo4j._codec.packstream import Structure +from neo4j._codec.packstream.v1 import ( + PackableBuffer, + Packer, + UnpackableBuffer, + Unpacker, +) + + +standard_ascii = [chr(i) for i in range(128)] +not_ascii = "♥O◘♦♥O◘♦" + + +class TestPackStream: + @pytest.fixture + def packer_with_buffer(self): + packable_buffer = Packer.new_packable_buffer() + return Packer(packable_buffer), packable_buffer + + @pytest.fixture + def unpacker_with_buffer(self): + unpackable_buffer = Unpacker.new_unpackable_buffer() + return Unpacker(unpackable_buffer), unpackable_buffer + + def test_packable_buffer(self, packer_with_buffer): + packer, packable_buffer = packer_with_buffer + assert isinstance(packable_buffer, PackableBuffer) + assert packable_buffer is packer.stream + + def test_unpackable_buffer(self, unpacker_with_buffer): + unpacker, unpackable_buffer = unpacker_with_buffer + assert isinstance(unpackable_buffer, UnpackableBuffer) + assert unpackable_buffer is unpacker.unpackable + + @pytest.fixture + def pack(self, packer_with_buffer): + packer, packable_buffer = packer_with_buffer + + def _pack(*values, dehydration_hooks=None): + for value in values: + packer.pack(value, dehydration_hooks=dehydration_hooks) + data = bytearray(packable_buffer.data) + packable_buffer.clear() + return data + + return _pack + + @pytest.fixture + def assert_packable(self, packer_with_buffer, unpacker_with_buffer): + def _assert(value, packed_value): + nonlocal packer_with_buffer, unpacker_with_buffer + packer, packable_buffer = packer_with_buffer + unpacker, unpackable_buffer = unpacker_with_buffer + packable_buffer.clear() + unpackable_buffer.reset() + + packer.pack(value) + packed_data = packable_buffer.data + assert packed_data == packed_value + + unpackable_buffer.data = bytearray(packed_data) + unpackable_buffer.used = len(packed_data) + unpacked_data = unpacker.unpack() + assert unpacked_data == value + + return _assert + + def test_none(self, assert_packable): + assert_packable(None, b"\xC0") + + def test_boolean(self, assert_packable): + assert_packable(True, b"\xC3") + 
assert_packable(False, b"\xC2") + + def test_negative_tiny_int(self, assert_packable): + for z in range(-16, 0): + assert_packable(z, bytes(bytearray([z + 0x100]))) + + def test_positive_tiny_int(self, assert_packable): + for z in range(0, 128): + assert_packable(z, bytes(bytearray([z]))) + + def test_negative_int8(self, assert_packable): + for z in range(-128, -16): + assert_packable(z, bytes(bytearray([0xC8, z + 0x100]))) + + def test_positive_int16(self, assert_packable): + for z in range(128, 32768): + expected = b"\xC9" + struct.pack(">h", z) + assert_packable(z, expected) + + def test_negative_int16(self, assert_packable): + for z in range(-32768, -128): + expected = b"\xC9" + struct.pack(">h", z) + assert_packable(z, expected) + + def test_positive_int32(self, assert_packable): + for e in range(15, 31): + z = 2 ** e + expected = b"\xCA" + struct.pack(">i", z) + assert_packable(z, expected) + + def test_negative_int32(self, assert_packable): + for e in range(15, 31): + z = -(2 ** e + 1) + expected = b"\xCA" + struct.pack(">i", z) + assert_packable(z, expected) + + def test_positive_int64(self, assert_packable): + for e in range(31, 63): + z = 2 ** e + expected = b"\xCB" + struct.pack(">q", z) + assert_packable(z, expected) + + def test_negative_int64(self, assert_packable): + for e in range(31, 63): + z = -(2 ** e + 1) + expected = b"\xCB" + struct.pack(">q", z) + assert_packable(z, expected) + + def test_integer_positive_overflow(self, pack, assert_packable): + with pytest.raises(OverflowError): + pack(2 ** 63 + 1) + + def test_integer_negative_overflow(self, pack, assert_packable): + with pytest.raises(OverflowError): + pack(-(2 ** 63) - 1) + + def test_zero_float64(self, assert_packable): + zero = 0.0 + expected = b"\xC1" + struct.pack(">d", zero) + assert_packable(zero, expected) + + def test_tau_float64(self, assert_packable): + tau = 2 * pi + expected = b"\xC1" + struct.pack(">d", tau) + assert_packable(tau, expected) + + def test_positive_float64(self, assert_packable): + for e in range(0, 100): + r = float(2 ** e) + 0.5 + expected = b"\xC1" + struct.pack(">d", r) + assert_packable(r, expected) + + def test_negative_float64(self, assert_packable): + for e in range(0, 100): + r = -(float(2 ** e) + 0.5) + expected = b"\xC1" + struct.pack(">d", r) + assert_packable(r, expected) + + def test_empty_bytes(self, assert_packable): + assert_packable(b"", b"\xCC\x00") + + def test_empty_bytearray(self, assert_packable): + assert_packable(bytearray(), b"\xCC\x00") + + def test_bytes_8(self, assert_packable): + assert_packable(bytearray(b"hello"), b"\xCC\x05hello") + + def test_bytes_16(self, assert_packable): + b = bytearray(40000) + assert_packable(b, b"\xCD\x9C\x40" + b) + + def test_bytes_32(self, assert_packable): + b = bytearray(80000) + assert_packable(b, b"\xCE\x00\x01\x38\x80" + b) + + def test_bytearray_size_overflow(self, assert_packable): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_bytes_header(2 ** 32) + + def test_empty_string(self, assert_packable): + assert_packable(u"", b"\x80") + + def test_tiny_strings(self, assert_packable): + for size in range(0x10): + assert_packable(u"A" * size, bytes(bytearray([0x80 + size]) + (b"A" * size))) + + def test_string_8(self, assert_packable): + t = u"A" * 40 + b = t.encode("utf-8") + assert_packable(t, b"\xD0\x28" + b) + + def test_string_16(self, assert_packable): + t = u"A" * 40000 + b = t.encode("utf-8") + assert_packable(t, b"\xD1\x9C\x40" + b) + + def 
test_string_32(self, assert_packable): + t = u"A" * 80000 + b = t.encode("utf-8") + assert_packable(t, b"\xD2\x00\x01\x38\x80" + b) + + def test_unicode_string(self, assert_packable): + t = u"héllö" + b = t.encode("utf-8") + assert_packable(t, bytes(bytearray([0x80 + len(b)])) + b) + + def test_string_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_string_header(2 ** 32) + + def test_empty_list(self, assert_packable): + assert_packable([], b"\x90") + + def test_tiny_lists(self, assert_packable): + for size in range(0x10): + data_out = bytearray([0x90 + size]) + bytearray([1] * size) + assert_packable([1] * size, bytes(data_out)) + + def test_list_8(self, assert_packable): + l = [1] * 40 + assert_packable(l, b"\xD4\x28" + (b"\x01" * 40)) + + def test_list_16(self, assert_packable): + l = [1] * 40000 + assert_packable(l, b"\xD5\x9C\x40" + (b"\x01" * 40000)) + + def test_list_32(self, assert_packable): + l = [1] * 80000 + assert_packable(l, b"\xD6\x00\x01\x38\x80" + (b"\x01" * 80000)) + + def test_nested_lists(self, assert_packable): + assert_packable([[[]]], b"\x91\x91\x90") + + def test_list_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_list_header(2 ** 32) + + def test_empty_map(self, assert_packable): + assert_packable({}, b"\xA0") + + @pytest.mark.parametrize("size", range(0x10)) + def test_tiny_maps(self, assert_packable, size): + data_in = dict() + data_out = bytearray([0xA0 + size]) + for el in range(1, size + 1): + data_in[chr(64 + el)] = el + data_out += bytearray([0x81, 64 + el, el]) + assert_packable(data_in, bytes(data_out)) + + def test_map_8(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(40)]) + b = b"".join(pack(u"A%s" % i, 1) for i in range(40)) + assert_packable(d, b"\xD8\x28" + b) + + def test_map_16(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(40000)]) + b = b"".join(pack(u"A%s" % i, 1) for i in range(40000)) + assert_packable(d, b"\xD9\x9C\x40" + b) + + def test_map_32(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(80000)]) + b = b"".join(pack(u"A%s" % i, 1) for i in range(80000)) + assert_packable(d, b"\xDA\x00\x01\x38\x80" + b) + + def test_map_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_map_header(2 ** 32) + + @pytest.mark.parametrize(("map_", "exc_type"), ( + ({1: "1"}, TypeError), + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), + )) + def test_map_key_type(self, packer_with_buffer, map_, exc_type): + # maps must have string keys + packer, packable_buffer = packer_with_buffer + with pytest.raises(exc_type, match="strings"): + packer.pack(map_) + + def test_illegal_signature(self, assert_packable): + with pytest.raises(ValueError): + assert_packable(Structure(b"XXX"), b"\xB0XXX") + + def test_empty_struct(self, assert_packable): + assert_packable(Structure(b"X"), b"\xB0X") + + def test_tiny_structs(self, assert_packable): + for size in range(0x10): + fields = [1] * size + data_in = Structure(b"A", *fields) + data_out = bytearray([0xB0 + size, 0x41] + fields) + assert_packable(data_in, bytes(data_out)) + + def test_struct_size_overflow(self, pack): + with pytest.raises(OverflowError): + fields = [1] * 16 + pack(Structure(b"X", *fields)) + + def test_illegal_uuid(self, assert_packable): + with 
pytest.raises(ValueError): + assert_packable(uuid4(), b"\xB0XXX") diff --git a/tests/unit/common/data/test_packing.py b/tests/unit/common/data/test_packing.py deleted file mode 100644 index 8b274b587..000000000 --- a/tests/unit/common/data/test_packing.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from io import BytesIO -from math import pi -import struct -from unittest import TestCase -from uuid import uuid4 - -from pytest import raises - -from neo4j.packstream import ( - Packer, - Structure, - UnpackableBuffer, - Unpacker, -) - - -class PackStreamTestCase(TestCase): - - @classmethod - def packb(cls, *values): - stream = BytesIO() - packer = Packer(stream) - for value in values: - packer.pack(value) - return stream.getvalue() - - @classmethod - def assert_packable(cls, value, packed_value): - stream_out = BytesIO() - packer = Packer(stream_out) - packer.pack(value) - packed = stream_out.getvalue() - try: - assert packed == packed_value - except AssertionError: - raise AssertionError("Packed value %r is %r instead of expected %r" % - (value, packed, packed_value)) - unpacked = Unpacker(UnpackableBuffer(packed)).unpack() - try: - assert unpacked == value - except AssertionError: - raise AssertionError("Unpacked value %r is not equal to original %r" % (unpacked, value)) - - def test_none(self): - self.assert_packable(None, b"\xC0") - - def test_boolean(self): - self.assert_packable(True, b"\xC3") - self.assert_packable(False, b"\xC2") - - def test_negative_tiny_int(self): - for z in range(-16, 0): - self.assert_packable(z, bytes(bytearray([z + 0x100]))) - - def test_positive_tiny_int(self): - for z in range(0, 128): - self.assert_packable(z, bytes(bytearray([z]))) - - def test_negative_int8(self): - for z in range(-128, -16): - self.assert_packable(z, bytes(bytearray([0xC8, z + 0x100]))) - - def test_positive_int16(self): - for z in range(128, 32768): - expected = b"\xC9" + struct.pack(">h", z) - self.assert_packable(z, expected) - - def test_negative_int16(self): - for z in range(-32768, -128): - expected = b"\xC9" + struct.pack(">h", z) - self.assert_packable(z, expected) - - def test_positive_int32(self): - for e in range(15, 31): - z = 2 ** e - expected = b"\xCA" + struct.pack(">i", z) - self.assert_packable(z, expected) - - def test_negative_int32(self): - for e in range(15, 31): - z = -(2 ** e + 1) - expected = b"\xCA" + struct.pack(">i", z) - self.assert_packable(z, expected) - - def test_positive_int64(self): - for e in range(31, 63): - z = 2 ** e - expected = b"\xCB" + struct.pack(">q", z) - self.assert_packable(z, expected) - - def test_negative_int64(self): - for e in range(31, 63): - z = -(2 ** e + 1) - expected = b"\xCB" + struct.pack(">q", z) - self.assert_packable(z, expected) - - def test_integer_positive_overflow(self): - with raises(OverflowError): - self.packb(2 ** 63 + 1) - - def test_integer_negative_overflow(self): - with 
raises(OverflowError): - self.packb(-(2 ** 63) - 1) - - def test_zero_float64(self): - zero = 0.0 - expected = b"\xC1" + struct.pack(">d", zero) - self.assert_packable(zero, expected) - - def test_tau_float64(self): - tau = 2 * pi - expected = b"\xC1" + struct.pack(">d", tau) - self.assert_packable(tau, expected) - - def test_positive_float64(self): - for e in range(0, 100): - r = float(2 ** e) + 0.5 - expected = b"\xC1" + struct.pack(">d", r) - self.assert_packable(r, expected) - - def test_negative_float64(self): - for e in range(0, 100): - r = -(float(2 ** e) + 0.5) - expected = b"\xC1" + struct.pack(">d", r) - self.assert_packable(r, expected) - - def test_empty_bytes(self): - self.assert_packable(b"", b"\xCC\x00") - - def test_empty_bytearray(self): - self.assert_packable(bytearray(), b"\xCC\x00") - - def test_bytes_8(self): - self.assert_packable(bytearray(b"hello"), b"\xCC\x05hello") - - def test_bytes_16(self): - b = bytearray(40000) - self.assert_packable(b, b"\xCD\x9C\x40" + b) - - def test_bytes_32(self): - b = bytearray(80000) - self.assert_packable(b, b"\xCE\x00\x01\x38\x80" + b) - - def test_bytearray_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_bytes_header(2 ** 32) - - def test_empty_string(self): - self.assert_packable(u"", b"\x80") - - def test_tiny_strings(self): - for size in range(0x10): - self.assert_packable(u"A" * size, bytes(bytearray([0x80 + size]) + (b"A" * size))) - - def test_string_8(self): - t = u"A" * 40 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD0\x28" + b) - - def test_string_16(self): - t = u"A" * 40000 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD1\x9C\x40" + b) - - def test_string_32(self): - t = u"A" * 80000 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD2\x00\x01\x38\x80" + b) - - def test_unicode_string(self): - t = u"héllö" - b = t.encode("utf-8") - self.assert_packable(t, bytes(bytearray([0x80 + len(b)])) + b) - - def test_string_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_string_header(2 ** 32) - - def test_empty_list(self): - self.assert_packable([], b"\x90") - - def test_tiny_lists(self): - for size in range(0x10): - data_out = bytearray([0x90 + size]) + bytearray([1] * size) - self.assert_packable([1] * size, bytes(data_out)) - - def test_list_8(self): - l = [1] * 40 - self.assert_packable(l, b"\xD4\x28" + (b"\x01" * 40)) - - def test_list_16(self): - l = [1] * 40000 - self.assert_packable(l, b"\xD5\x9C\x40" + (b"\x01" * 40000)) - - def test_list_32(self): - l = [1] * 80000 - self.assert_packable(l, b"\xD6\x00\x01\x38\x80" + (b"\x01" * 80000)) - - def test_nested_lists(self): - self.assert_packable([[[]]], b"\x91\x91\x90") - - def test_list_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_list_header(2 ** 32) - - def test_empty_map(self): - self.assert_packable({}, b"\xA0") - - def test_tiny_maps(self): - for size in range(0x10): - data_in = dict() - data_out = bytearray([0xA0 + size]) - for el in range(1, size + 1): - data_in[chr(64 + el)] = el - data_out += bytearray([0x81, 64 + el, el]) - self.assert_packable(data_in, bytes(data_out)) - - def test_map_8(self): - d = dict([(u"A%s" % i, 1) for i in range(40)]) - b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40)) - self.assert_packable(d, b"\xD8\x28" + b) - - def test_map_16(self): - d = dict([(u"A%s" % i, 1) for i in range(40000)]) - b = 
b"".join(self.packb(u"A%s" % i, 1) for i in range(40000)) - self.assert_packable(d, b"\xD9\x9C\x40" + b) - - def test_map_32(self): - d = dict([(u"A%s" % i, 1) for i in range(80000)]) - b = b"".join(self.packb(u"A%s" % i, 1) for i in range(80000)) - self.assert_packable(d, b"\xDA\x00\x01\x38\x80" + b) - - def test_map_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_map_header(2 ** 32) - - def test_illegal_signature(self): - with self.assertRaises(ValueError): - self.assert_packable(Structure(b"XXX"), b"\xB0XXX") - - def test_empty_struct(self): - self.assert_packable(Structure(b"X"), b"\xB0X") - - def test_tiny_structs(self): - for size in range(0x10): - fields = [1] * size - data_in = Structure(b"A", *fields) - data_out = bytearray([0xB0 + size, 0x41] + fields) - self.assert_packable(data_in, bytes(data_out)) - - def test_struct_size_overflow(self): - with raises(OverflowError): - fields = [1] * 16 - self.packb(Structure(b"X", *fields)) - - def test_illegal_uuid(self): - with self.assertRaises(ValueError): - self.assert_packable(uuid4(), b"\xB0XXX") diff --git a/tests/unit/common/io/test_routing.py b/tests/unit/common/io/test_routing.py index 030768c58..4dc8ecd82 100644 --- a/tests/unit/common/io/test_routing.py +++ b/tests/unit/common/io/test_routing.py @@ -18,11 +18,11 @@ import pytest -from neo4j.api import DEFAULT_DATABASE -from neo4j.routing import ( +from neo4j._routing import ( OrderedSet, RoutingTable, ) +from neo4j.api import DEFAULT_DATABASE VALID_ROUTING_RECORD = { diff --git a/tests/unit/common/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py index 742aa7b61..5ec40ac5e 100644 --- a/tests/unit/common/spatial/test_cartesian_point.py +++ b/tests/unit/common/spatial/test_cartesian_point.py @@ -16,14 +16,8 @@ # limitations under the License. -import io -import struct from unittest import TestCase -import pytest - -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer from neo4j.spatial import CartesianPoint @@ -48,33 +42,3 @@ def test_alias_2d(self): self.assertEqual(p.y, y) with self.assertRaises(AttributeError): p.z - - def test_dehydration_3d(self): - coordinates = (1, -2, 3.1) - p = CartesianPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB4Y" + - b"\xC9" + struct.pack(">h", 9157) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - - def test_dehydration_2d(self): - coordinates = (.1, 0) - p = CartesianPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 7203) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) diff --git a/tests/unit/common/spatial/test_point.py b/tests/unit/common/spatial/test_point.py index fd7f35e98..3122e2de3 100644 --- a/tests/unit/common/spatial/test_point.py +++ b/tests/unit/common/spatial/test_point.py @@ -16,13 +16,9 @@ # limitations under the License. 
-import io -import struct from unittest import TestCase -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer -from neo4j.spatial import ( +from neo4j._spatial import ( Point, point_type, ) @@ -42,22 +38,6 @@ def test_number_arguments(self): p = Point(argument) assert tuple(p) == argument - def test_dehydration(self): - MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234}) - coordinates = (.1, 0) - p = MyPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 1234) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - def test_immutable_coordinates(self): MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234}) coordinates = (.1, 0) diff --git a/tests/unit/common/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py index 43f4f251f..540cfd2c4 100644 --- a/tests/unit/common/spatial/test_wgs84_point.py +++ b/tests/unit/common/spatial/test_wgs84_point.py @@ -16,12 +16,8 @@ # limitations under the License. -import io -import struct from unittest import TestCase -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer from neo4j.spatial import WGS84Point @@ -64,33 +60,3 @@ def test_alias_2d(self): p.height with self.assertRaises(AttributeError): p.z - - def test_dehydration_3d(self): - coordinates = (1, -2, 3.1) - p = WGS84Point(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB4Y" + - b"\xC9" + struct.pack(">h", 4979) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - - def test_dehydration_2d(self): - coordinates = (.1, 0) - p = WGS84Point(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 4326) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) diff --git a/tests/unit/common/test_addressing.py b/tests/unit/common/test_addressing.py index eafe7f17f..99b730f38 100644 --- a/tests/unit/common/test_addressing.py +++ b/tests/unit/common/test_addressing.py @@ -20,7 +20,7 @@ AF_INET, AF_INET6, ) -import unittest.mock as mock +from unittest import mock import pytest diff --git a/tests/unit/common/test_api.py b/tests/unit/common/test_api.py index 2cf519b5a..a0f836796 100644 --- a/tests/unit/common/test_api.py +++ b/tests/unit/common/test_api.py @@ -16,14 +16,12 @@ # limitations under the License. 
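The removals below retire the DataDehydrator-based value checks from the public API tests; the comparable range checks now live at the PackStream layer, exercised by the new tests/unit/common/codec/packstream/v1/test_packstream.py earlier in this diff. For instance the oversized-integer case, using only names those tests import (illustrative):

import pytest

from neo4j._codec.packstream.v1 import Packer

packer = Packer(Packer.new_packable_buffer())
with pytest.raises(OverflowError):
    packer.pack(2 ** 63 + 1)  # beyond signed int64, rejected at pack time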
-from copy import deepcopy import itertools -from uuid import uuid4 +from contextlib import contextmanager import pytest import neo4j.api -from neo4j.data import DataDehydrator from neo4j.exceptions import ConfigurationError @@ -31,125 +29,6 @@ not_ascii = "♥O◘♦♥O◘♦" -def dehydrated_value(value): - return DataDehydrator.fix_parameters({"_": value})["_"] - - -def test_value_dehydration_should_allow_none(): - assert dehydrated_value(None) is None - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (True, True), - (False, False), - ] -) -def test_value_dehydration_should_allow_boolean(test_input, expected): - assert dehydrated_value(test_input) is expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0, 0), - (1, 1), - (0x7F, 0x7F), - (0x7FFF, 0x7FFF), - (0x7FFFFFFF, 0x7FFFFFFF), - (0x7FFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF), - ] -) -def test_value_dehydration_should_allow_integer(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0x10000000000000000, ValueError), - (-0x10000000000000000, ValueError), - ] -) -def test_value_dehydration_should_disallow_oversized_integer(test_input, expected): - with pytest.raises(expected): - dehydrated_value(test_input) - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0.0, 0.0), - (-0.1, -0.1), - (3.1415926, 3.1415926), - (-3.1415926, -3.1415926), - ] -) -def test_value_dehydration_should_allow_float(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (u"", u""), - (u"hello, world", u"hello, world"), - ("".join(standard_ascii), "".join(standard_ascii)), - ] -) -def test_value_dehydration_should_allow_string(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (bytearray(), bytearray()), - (bytearray([1, 2, 3]), bytearray([1, 2, 3])), - ] -) -def test_value_dehydration_should_allow_bytes(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - ([], []), - ([1, 2, 3], [1, 2, 3]), - ([1, 3.1415926, "string", None], [1, 3.1415926, "string", None]) - ] -) -def test_value_dehydration_should_allow_list(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - ({}, {}), - ({u"one": 1, u"two": 1, u"three": 1}, {u"one": 1, u"two": 1, u"three": 1}), - ({u"list": [1, 2, 3, [4, 5, 6]], u"dict": {u"a": 1, u"b": 2}}, {u"list": [1, 2, 3, [4, 5, 6]], u"dict": {u"a": 1, u"b": 2}}), - ({"alpha": [1, 3.1415926, "string", None]}, {"alpha": [1, 3.1415926, "string", None]}), - ] -) -def test_value_dehydration_should_allow_dict(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (object(), TypeError), - (uuid4(), TypeError), - ] -) -def test_value_dehydration_should_disallow_object(test_input, expected): - with pytest.raises(expected): - dehydrated_value(test_input) - - def test_bookmark_is_deprecated(): with pytest.deprecated_call(): neo4j.Bookmark() @@ -223,9 +102,27 @@ def test_bookmark_initialization_with_valid_strings(test_input, expected_values, (("bookmark1", chr(129),), ValueError), ] ) -def test_bookmark_initialization_with_invalid_strings(test_input, expected): +@pytest.mark.parametrize(("method", "deprecated", "splat_args"), ( + (neo4j.Bookmark, 
True, True), + (neo4j.Bookmarks.from_raw_values, False, False), +)) +def test_bookmark_initialization_with_invalid_strings( + test_input, expected, method, deprecated, splat_args +): + @contextmanager + def deprecation_assertion(): + if deprecated: + with pytest.warns(DeprecationWarning): + yield + else: + yield + with pytest.raises(expected): - neo4j.Bookmark(*test_input) + with deprecation_assertion(): + if splat_args: + method(*test_input) + else: + method(test_input) @pytest.mark.parametrize("test_as_generator", [True, False]) @@ -238,7 +135,6 @@ def test_bookmark_initialization_with_invalid_strings(test_input, expected): ("bookmark1", ""), ("bookmark1",), (), - (not_ascii,), )) def test_bookmarks_raw_values(test_as_generator, values): expected = frozenset(values) @@ -262,6 +158,7 @@ def test_bookmarks_raw_values(test_as_generator, values): ((set(),), TypeError), ((frozenset(),), TypeError), ((["bookmark1", "bookmark2"],), TypeError), + ((not_ascii,), ValueError), )) def test_bookmarks_invalid_raw_values(values, exc_type): with pytest.raises(exc_type): diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index db3497263..390f2cab7 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -23,18 +23,18 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._conf import ( + Config, + PoolConfig, + SessionConfig, + WorkspaceConfig, +) from neo4j.api import ( READ_ACCESS, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, WRITE_ACCESS, ) -from neo4j.conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) from neo4j.debug import watch from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/common/test_data.py b/tests/unit/common/test_data.py deleted file mode 100644 index d24c29209..000000000 --- a/tests/unit/common/test_data.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import pytest - -from neo4j.data import DataHydrator -from neo4j.packstream import Structure - - -# python -m pytest -s -v tests/unit/test_data.py - - -def test_can_hydrate_v1_node_structure(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert alice.id == 123 - # for backwards compatibility, the driver should compute the element_id - assert alice.element_id == "123" - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -def test_can_hydrate_v2_node_structure(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}, "abc") - alice, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert alice.id == 123 - assert alice.element_id == "abc" - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -def test_can_hydrate_v1_relationship_structure(): - hydrant = DataHydrator() - - struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}) - rel, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.id == 123 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.start_node.id == 456 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.end_node.id == 789 - # for backwards compatibility, the driver should compute the element_id - assert rel.element_id == "123" - assert rel.start_node.element_id == "456" - assert rel.end_node.element_id == "789" - assert rel.type == "KNOWS" - assert set(rel.keys()) == {"since"} - assert rel.get("since") == 1999 - - -def test_can_hydrate_v2_relationship_structure(): - hydrant = DataHydrator() - - struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}, - "abc", "def", "ghi") - - rel, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.id == 123 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.start_node.id == 456 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.end_node.id == 789 - # for backwards compatibility, the driver should compute the element_id - assert rel.element_id == "abc" - assert rel.start_node.element_id == "def" - assert rel.end_node.element_id == "ghi" - assert rel.type == "KNOWS" - assert set(rel.keys()) == {"since"} - assert rel.get("since") == 1999 - - -def test_hydrating_unknown_structure_returns_same(): - hydrant = DataHydrator() - - struct = Structure(b'?', "foo") - mystery, = hydrant.hydrate([struct]) - - assert mystery == struct - - -def test_can_hydrate_in_list(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice_in_list, = hydrant.hydrate([[struct]]) - - assert isinstance(alice_in_list, list) - - alice, = alice_in_list - - with pytest.warns(DeprecationWarning, match="element_id"): - assert alice.id == 123 - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -def test_can_hydrate_in_dict(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice_in_dict, = hydrant.hydrate([{"foo": struct}]) - - assert isinstance(alice_in_dict, dict) - - alice = alice_in_dict["foo"] - - with pytest.warns(DeprecationWarning, match="element_id"): - assert
alice.id == 123 - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py index 01bfd5905..aa97bea28 100644 --- a/tests/unit/common/test_import_neo4j.py +++ b/tests/unit/common/test_import_neo4j.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest def test_import_dunder_version(): @@ -105,7 +106,8 @@ def test_import_async_session(): def test_import_sessionconfig(): - from neo4j import SessionConfig + with pytest.warns(DeprecationWarning): + from neo4j import SessionConfig def test_import_query(): @@ -129,15 +131,17 @@ def test_import_unit_of_work(): def test_import_config(): - from neo4j import Config + with pytest.warns(DeprecationWarning): + from neo4j import Config def test_import_poolconfig(): - from neo4j import PoolConfig + with pytest.warns(DeprecationWarning): + from neo4j import PoolConfig def test_import_graph(): - import neo4j.graph as graph + from neo4j import graph def test_import_graph_node(): @@ -153,12 +157,12 @@ def test_import_graph_graph(): def test_import_spatial(): - import neo4j.spatial as spatial + from neo4j import spatial def test_import_time(): - import neo4j.time as time + from neo4j import time def test_import_exceptions(): - import neo4j.exceptions as exceptions + from neo4j import exceptions diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index 8ad155593..fbb79a957 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -18,10 +18,11 @@ import pytest -from neo4j.data import ( +from neo4j import Record +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j.graph import ( Graph, Node, - Record, ) @@ -283,8 +284,8 @@ def test_data(raw, keys, serialized): def test_data_relationship(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) bob = gh.hydrate_node(2, {"Person"}, {"name": "Bob", "age": 44}) alice_knows_bob = gh.hydrate_relationship(1, 1, 2, "KNOWS", @@ -302,8 +303,8 @@ def test_data_relationship(): def test_data_unbound_relationship(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator some_one_knows_some_one = gh.hydrate_relationship( 1, 42, 43, "KNOWS", {"since": 1999} ) @@ -313,8 +314,8 @@ def test_data_unbound_relationship(): @pytest.mark.parametrize("cyclic", (True, False)) def test_data_path(cyclic): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) bob = gh.hydrate_node(2, {"Person"}, {"name": "Bob", "age": 44}) if cyclic: diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index 4b97529b6..22a6fcf9e 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -20,6 +20,7 @@ import pytest +from neo4j._codec.hydration.v1 import HydrationHandler from neo4j.graph import ( Graph, Node, @@ -40,8 +41,8 @@ (None, "foobar"), )) def test_can_create_node(id_, element_id): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = 
HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator fields = [id_, {"Person"}, {"name": "Alice", "age": 33}] if element_id is not None: @@ -74,8 +75,8 @@ def test_can_create_node(id_, element_id): def test_node_with_null_properties(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator stuff = gh.hydrate_node(1, (), {"good": ["puppies", "kittens"], "bad": None}) assert isinstance(stuff, Node) @@ -121,19 +122,16 @@ def test_node_equality(g1, id1, eid1, props1, g2, id2, eid2, props2): @pytest.mark.parametrize("legacy_id", (True, False)) def test_node_hashing(legacy_id): g = Graph() - node_1 = Node(g, "1234" + ("abc" if not legacy_id else ""), - 1234) - node_2 = Node(g, "1234" + ("abc" if not legacy_id else ""), - 1234) - node_3 = Node(g, "5678" + ("abc" if not legacy_id else ""), - 5678) + node_1 = Node(g, "1234" + ("abc" if not legacy_id else ""), 1234) + node_2 = Node(g, "1234" + ("abc" if not legacy_id else ""), 1234) + node_3 = Node(g, "5678" + ("abc" if not legacy_id else ""), 5678) assert hash(node_1) == hash(node_2) assert hash(node_1) != hash(node_3) def test_node_v1_repr(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice"}) assert repr(alice) == ( "H", self.recv_buffer[:2]) - print("CHUNK SIZE %r" % chunk_size) - end = 2 + chunk_size - chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] - return chunk_data - def pop_message(self): - data = bytearray() - while True: - chunk = self._pop_chunk() - print("CHUNK %r" % chunk) - if chunk: - data.extend(chunk) - elif data: - break # end of message - else: - continue # NOOP - header = data[0] - n_fields = header % 0x10 - tag = data[1] - buffer = UnpackableBuffer(data[2:]) - unpacker = Unpacker(buffer) - fields = [unpacker.unpack() for _ in range(n_fields)] - return tag, fields + assert self._messages + return self._messages.pop(None) def send_message(self, tag, *fields): - data = self.encode_message(tag, *fields) - self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00") - - @classmethod - def encode_message(cls, tag, *fields): - b = BytesIO() - packer = Packer(b) - for field in fields: - packer.pack(field) - return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + assert self._outbox + self._outbox.append_message(tag, fields, None) + self._outbox.flush() class FakeSocketPair: - def __init__(self, address): - self.client = FakeSocket2(address) - self.server = FakeSocket2() + def __init__(self, address, packer_cls=None, unpacker_cls=None): + self.client = FakeSocket2( + address, packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) + self.server = FakeSocket2( + packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) self.client.on_send = self.server.inject self.server.on_send = self.client.inject diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py index 27dad7cb9..0298573b5 100644 --- a/tests/unit/sync/io/test__common.py +++ b/tests/unit/sync/io/test__common.py @@ -18,33 +18,43 @@ import pytest +from neo4j._codec.packstream.v1 import PackableBuffer from neo4j._sync.io._common import Outbox +from ...._async_compat import mark_sync_test + @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( 2, - (bytes(range(10, 15)),), + bytes(range(10, 15)), bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) ), ( 2, - 
(bytes(range(10, 14)),), + bytes(range(10, 14)), bytes((0, 2, 10, 11, 0, 2, 12, 13)) ), ( 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + bytes((5,)), + bytes((0, 1, 5)) ), )) -def test_async_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" +@mark_sync_test +def test_async_outbox_chunking(chunk_size, data, result, mocker): + buffer = PackableBuffer() + socket_mock = mocker.Mock() + packer_mock = mocker.Mock() + packer_mock.return_value = packer_mock + packer_mock.new_packable_buffer.return_value = buffer + packer_mock.pack_struct.side_effect = \ + lambda *args, **kwargs: buffer.write(data) + outbox = Outbox(socket_mock, pytest.fail, packer_mock, chunk_size) + outbox.append_message(None, None, None) + socket_mock.sendall.assert_not_called() + assert outbox.flush() + socket_mock.sendall.assert_called_once_with(result + b"\x00\x00") + + assert not outbox.flush() + socket_mock.sendall.assert_called_once() diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index bfa63f4fd..87f477d8d 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt3 import Bolt3 -from neo4j.conf import PoolConfig from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -72,7 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_sync_test def test_simple_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() connection.send_all() @@ -84,7 +84,7 @@ def test_simple_discard(fake_socket): @mark_sync_test def test_simple_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() connection.send_all() @@ -99,9 +99,11 @@ def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair( + address, Bolt3.PACKER_CLS, Bolt3.UNPACKER_CLS + ) sockets.client.settimeout = mocker.Mock() - sockets.server.send_message(0x70, { + sockets.server.send_message(b"\x70", { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index e1c0a5ccd..88f549936 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x0 -from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, 
                         PoolConfig.max_connection_lifetime)
     connection.begin(db="something")
     connection.send_all()
@@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket):
 @mark_sync_test
 def test_db_extra_in_run(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.run("", {}, db="something")
     connection.send_all()
@@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket):
 @mark_sync_test
 def test_n_extra_in_discard(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666)
     connection.send_all()
@@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket):
 @mark_sync_test
 def test_qid_extra_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(qid=test_input)
     connection.send_all()
@@ -125,7 +125,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666, qid=test_input)
     connection.send_all()
@@ -145,7 +145,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=test_input)
     connection.send_all()
@@ -165,7 +165,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_qid_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(qid=test_input)
     connection.send_all()
@@ -178,7 +178,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_pull(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x0.UNPACKER_CLS)
     connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=666, qid=777)
     connection.send_all()
@@ -194,9 +194,11 @@ def test_hint_recv_timeout_seconds_gets_ignored(
     fake_socket_pair, recv_timeout, mocker
 ):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x0.PACKER_CLS,
+                               unpacker_cls=Bolt4x0.UNPACKER_CLS)
     sockets.client.settimeout = mocker.MagicMock()
-    sockets.server.send_message(0x70, {
+    sockets.server.send_message(b"\x70", {
         "server": "Neo4j/4.0.0",
         "hints": {"connection.recv_timeout_seconds": recv_timeout},
     })
diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py
index 9a32fa8e3..e656cc349 100644
--- a/tests/unit/sync/io/test_class_bolt4x1.py
+++ b/tests/unit/sync/io/test_class_bolt4x1.py
@@ -18,8 +18,8 @@
 
 import pytest
 
+from neo4j._conf import PoolConfig
 from neo4j._sync.io._bolt4 import Bolt4x1
-from neo4j.conf import PoolConfig
 
 from ...._async_compat import mark_sync_test
 
@@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
 @mark_sync_test
 def test_db_extra_in_begin(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.begin(db="something")
     connection.send_all()
@@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket):
 @mark_sync_test
 def test_db_extra_in_run(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.run("", {}, db="something")
     connection.send_all()
@@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket):
 @mark_sync_test
 def test_n_extra_in_discard(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666)
     connection.send_all()
@@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket):
 @mark_sync_test
 def test_qid_extra_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(qid=test_input)
     connection.send_all()
@@ -126,7 +126,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
 def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666, qid=test_input)
     connection.send_all()
@@ -146,7 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=test_input)
     connection.send_all()
@@ -167,7 +167,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
 def test_qid_extra_in_pull(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(qid=test_input)
     connection.send_all()
@@ -180,7 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_pull(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x1.UNPACKER_CLS)
     connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=666, qid=777)
     connection.send_all()
@@ -193,15 +193,17 @@ def test_n_and_qid_extras_in_pull(fake_socket):
 
 @mark_sync_test
 def test_hello_passes_routing_metadata(fake_socket_pair):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
-    sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"})
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x1.PACKER_CLS,
+                               unpacker_cls=Bolt4x1.UNPACKER_CLS)
+    sockets.server.send_message(b"\x70", {"server": "Neo4j/4.1.0"})
     connection = Bolt4x1(
         address, sockets.client, PoolConfig.max_connection_lifetime,
         routing_context={"foo": "bar"}
     )
     connection.hello()
     tag, fields = sockets.server.pop_message()
-    assert tag == 0x01
+    assert tag == b"\x01"
     assert len(fields) == 1
     assert fields[0]["routing"] == {"foo": "bar"}
@@ -212,9 +214,11 @@ def test_hint_recv_timeout_seconds_gets_ignored(
     fake_socket_pair, recv_timeout, mocker
 ):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x1.PACKER_CLS,
+                               unpacker_cls=Bolt4x1.UNPACKER_CLS)
     sockets.client.settimeout = mocker.Mock()
-    sockets.server.send_message(0x70, {
+    sockets.server.send_message(b"\x70", {
         "server": "Neo4j/4.1.0",
         "hints": {"connection.recv_timeout_seconds": recv_timeout},
     })
diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py
index 145bc0850..d6bff9c23 100644
--- a/tests/unit/sync/io/test_class_bolt4x2.py
+++ b/tests/unit/sync/io/test_class_bolt4x2.py
@@ -18,8 +18,8 @@
 
 import pytest
 
+from neo4j._conf import PoolConfig
 from neo4j._sync.io._bolt4 import Bolt4x2
-from neo4j.conf import PoolConfig
 
 from ...._async_compat import mark_sync_test
 
@@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
 @mark_sync_test
 def test_db_extra_in_begin(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.begin(db="something")
     connection.send_all()
@@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket):
 @mark_sync_test
 def test_db_extra_in_run(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.run("", {}, db="something")
     connection.send_all()
@@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket):
 @mark_sync_test
 def test_n_extra_in_discard(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666)
     connection.send_all()
@@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket):
 @mark_sync_test
 def test_qid_extra_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(qid=test_input)
     connection.send_all()
@@ -126,7 +126,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
 def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666, qid=test_input)
     connection.send_all()
@@ -146,7 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=test_input)
     connection.send_all()
@@ -167,7 +167,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
 def test_qid_extra_in_pull(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(qid=test_input)
     connection.send_all()
@@ -180,7 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_pull(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x2.UNPACKER_CLS)
     connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=666, qid=777)
     connection.send_all()
@@ -193,15 +193,17 @@ def test_n_and_qid_extras_in_pull(fake_socket):
 
 @mark_sync_test
 def test_hello_passes_routing_metadata(fake_socket_pair):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
-    sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"})
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x2.PACKER_CLS,
+                               unpacker_cls=Bolt4x2.UNPACKER_CLS)
+    sockets.server.send_message(b"\x70", {"server": "Neo4j/4.2.0"})
     connection = Bolt4x2(
         address, sockets.client, PoolConfig.max_connection_lifetime,
         routing_context={"foo": "bar"}
     )
     connection.hello()
     tag, fields = sockets.server.pop_message()
-    assert tag == 0x01
+    assert tag == b"\x01"
     assert len(fields) == 1
     assert fields[0]["routing"] == {"foo": "bar"}
@@ -212,9 +214,11 @@ def test_hint_recv_timeout_seconds_gets_ignored(
     fake_socket_pair, recv_timeout, mocker
 ):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x2.PACKER_CLS,
+                               unpacker_cls=Bolt4x2.UNPACKER_CLS)
     sockets.client.settimeout = mocker.Mock()
-    sockets.server.send_message(0x70, {
+    sockets.server.send_message(b"\x70", {
         "server": "Neo4j/4.2.0",
         "hints": {"connection.recv_timeout_seconds": recv_timeout},
     })
diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py
index fbde3872e..474b15857 100644
--- a/tests/unit/sync/io/test_class_bolt4x3.py
+++ b/tests/unit/sync/io/test_class_bolt4x3.py
@@ -20,8 +20,8 @@
 
 import pytest
 
+from neo4j._conf import PoolConfig
 from neo4j._sync.io._bolt4 import Bolt4x3
-from neo4j.conf import PoolConfig
 
 from ...._async_compat import mark_sync_test
 
@@ -59,7 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
 @mark_sync_test
 def test_db_extra_in_begin(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.begin(db="something")
     connection.send_all()
@@ -72,7 +72,7 @@ def test_db_extra_in_begin(fake_socket):
 @mark_sync_test
 def test_db_extra_in_run(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.run("", {}, db="something")
     connection.send_all()
@@ -87,7 +87,7 @@ def test_db_extra_in_run(fake_socket):
 @mark_sync_test
 def test_n_extra_in_discard(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666)
     connection.send_all()
@@ -107,7 +107,7 @@ def test_n_extra_in_discard(fake_socket):
 @mark_sync_test
 def test_qid_extra_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(qid=test_input)
     connection.send_all()
@@ -128,7 +128,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
 def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666, qid=test_input)
     connection.send_all()
@@ -148,7 +148,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=test_input)
     connection.send_all()
@@ -169,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
 def test_qid_extra_in_pull(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(qid=test_input)
     connection.send_all()
@@ -182,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_pull(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x3.UNPACKER_CLS)
     connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=666, qid=777)
     connection.send_all()
@@ -195,15 +195,17 @@ def test_n_and_qid_extras_in_pull(fake_socket):
 
 @mark_sync_test
 def test_hello_passes_routing_metadata(fake_socket_pair):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
-    sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"})
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x3.PACKER_CLS,
+                               unpacker_cls=Bolt4x3.UNPACKER_CLS)
+    sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.0"})
     connection = Bolt4x3(
         address, sockets.client, PoolConfig.max_connection_lifetime,
         routing_context={"foo": "bar"}
     )
     connection.hello()
     tag, fields = sockets.server.pop_message()
-    assert tag == 0x01
+    assert tag == b"\x01"
     assert len(fields) == 1
     assert fields[0]["routing"] == {"foo": "bar"}
@@ -225,10 +227,12 @@ def test_hint_recv_timeout_seconds(
     fake_socket_pair, hints, valid, caplog, mocker
 ):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x3.PACKER_CLS,
+                               unpacker_cls=Bolt4x3.UNPACKER_CLS)
     sockets.client.settimeout = mocker.Mock()
     sockets.server.send_message(
-        0x70, {"server": "Neo4j/4.3.0", "hints": hints}
+        b"\x70", {"server": "Neo4j/4.3.0", "hints": hints}
     )
     connection = Bolt4x3(
         address, sockets.client, PoolConfig.max_connection_lifetime
diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py
index 665731727..564660966 100644
--- a/tests/unit/sync/io/test_class_bolt4x4.py
+++ b/tests/unit/sync/io/test_class_bolt4x4.py
@@ -20,8 +20,8 @@
 
 import pytest
 
+from neo4j._conf import PoolConfig
 from neo4j._sync.io._bolt4 import Bolt4x4
-from neo4j.conf import PoolConfig
 
 from ...._async_compat import mark_sync_test
 
@@ -68,7 +68,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
 @mark_sync_test
 def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.begin(*args, **kwargs)
     connection.send_all()
@@ -89,7 +89,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
 @mark_sync_test
 def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.run(*args, **kwargs)
     connection.send_all()
@@ -101,7 +101,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
 @mark_sync_test
 def test_n_extra_in_discard(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666)
     connection.send_all()
@@ -121,7 +121,7 @@ def test_n_extra_in_discard(fake_socket):
 @mark_sync_test
 def test_qid_extra_in_discard(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(qid=test_input)
     connection.send_all()
@@ -142,7 +142,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
 def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.discard(n=666, qid=test_input)
     connection.send_all()
@@ -162,7 +162,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_extra_in_pull(fake_socket, test_input, expected):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=test_input)
     connection.send_all()
@@ -183,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
 def test_qid_extra_in_pull(fake_socket, test_input, expected):
     # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(qid=test_input)
     connection.send_all()
@@ -196,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
 @mark_sync_test
 def test_n_and_qid_extras_in_pull(fake_socket):
     address = ("127.0.0.1", 7687)
-    socket = fake_socket(address)
+    socket = fake_socket(address, Bolt4x4.UNPACKER_CLS)
     connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
     connection.pull(n=666, qid=777)
     connection.send_all()
@@ -209,15 +209,17 @@ def test_n_and_qid_extras_in_pull(fake_socket):
 
 @mark_sync_test
 def test_hello_passes_routing_metadata(fake_socket_pair):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
-    sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"})
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x4.PACKER_CLS,
+                               unpacker_cls=Bolt4x4.UNPACKER_CLS)
+    sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"})
     connection = Bolt4x4(
         address, sockets.client, PoolConfig.max_connection_lifetime,
         routing_context={"foo": "bar"}
     )
     connection.hello()
     tag, fields = sockets.server.pop_message()
-    assert tag == 0x01
+    assert tag == b"\x01"
     assert len(fields) == 1
     assert fields[0]["routing"] == {"foo": "bar"}
@@ -239,10 +241,12 @@ def test_hint_recv_timeout_seconds(
     fake_socket_pair, hints, valid, caplog, mocker
 ):
     address = ("127.0.0.1", 7687)
-    sockets = fake_socket_pair(address)
+    sockets = fake_socket_pair(address,
+                               packer_cls=Bolt4x4.PACKER_CLS,
+                               unpacker_cls=Bolt4x4.UNPACKER_CLS)
     sockets.client.settimeout = mocker.MagicMock()
     sockets.server.send_message(
-        0x70, {"server": "Neo4j/4.3.4", "hints": hints}
+        b"\x70", {"server": "Neo4j/4.3.4", "hints": hints}
     )
     connection = Bolt4x4(
         address, sockets.client, PoolConfig.max_connection_lifetime
diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py
index 98a1c5b0e..cddbecef8 100644
--- a/tests/unit/sync/io/test_direct.py
+++ b/tests/unit/sync/io/test_direct.py
@@ -18,7 +18,7 @@
 
 import pytest
 
-from neo4j import (
+from neo4j._conf import (
     Config,
     PoolConfig,
     WorkspaceConfig,
diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
index 6c9b5db62..af10dc7da 100644
--- a/tests/unit/sync/io/test_neo4j_pool.py
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -16,27 +16,29 @@
 # limitations under the License.
 
+import inspect
+
 import pytest
 
 from neo4j import (
     READ_ACCESS,
     WRITE_ACCESS,
 )
-from neo4j._deadline import Deadline
-from neo4j._sync.io import Neo4jPool
-from neo4j.addressing import ResolvedAddress
-from neo4j.conf import (
+from neo4j._conf import (
     PoolConfig,
     RoutingConfig,
     WorkspaceConfig,
 )
+from neo4j._deadline import Deadline
+from neo4j._sync.io import Neo4jPool
+from neo4j.addressing import ResolvedAddress
 from neo4j.exceptions import (
     ServiceUnavailable,
     SessionExpired,
 )
 
 from ...._async_compat import mark_sync_test
-from ..work import fake_connection_generator
+from ..work import fake_connection_generator  # needed as fixture
 
 
 ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
@@ -44,7 +46,7 @@ WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
 
 
-@pytest.fixture()
+@pytest.fixture
 def opener(fake_connection_generator, mocker):
     def open_(addr, timeout):
         connection = fake_connection_generator()
@@ -160,7 +162,9 @@ def break_connection():
         pool.deactivate(cx1.addr)
         if cx_close_mock_side_effect:
-            cx_close_mock_side_effect()
+            res = cx_close_mock_side_effect()
+            if inspect.isawaitable(res):
+                return res
 
     pool = Neo4jPool(
         opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
@@ -242,8 +246,8 @@ def test_release_does_not_resets_closed_connections(opener):
     cx1.is_reset_mock.reset_mock()
     pool.release(cx1)
     cx1.closed.assert_called_once()
-    cx1.is_reset_mock.asset_not_called()
-    cx1.reset.asset_not_called()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
 
 
 @mark_sync_test
@@ -257,8 +261,8 @@ def test_release_does_not_resets_defunct_connections(opener):
     cx1.is_reset_mock.reset_mock()
     pool.release(cx1)
     cx1.defunct.assert_called_once()
-    cx1.is_reset_mock.asset_not_called()
-    cx1.reset.asset_not_called()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
 
 
 @pytest.mark.parametrize("liveness_timeout", (0, 1, 2))
@@ -271,7 +275,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection(
     )
     cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
     assert cx1.addr == READER_ADDRESS
-    cx1.reset.asset_not_called()
+    cx1.reset.assert_not_called()
 
 
 @pytest.mark.parametrize("liveness_timeout", (0, 1, 2))
diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py
index b7665795e..df6e904fc 100644
--- a/tests/unit/sync/test_driver.py
+++ b/tests/unit/sync/test_driver.py
@@ -17,6 +17,7 @@
 
 import ssl
+from functools import wraps
 
 import pytest
 
@@ -31,6 +32,7 @@
     TrustCustomCAs,
     TrustSystemCAs,
 )
+from neo4j._async_compat.util import Util
 from neo4j.api import (
     READ_ACCESS,
     WRITE_ACCESS,
@@ -40,6 +42,21 @@
 from ..._async_compat import mark_sync_test
 
 
+@wraps(GraphDatabase.driver)
+def create_driver(*args, **kwargs):
+    if Util.is_async_code:
+        with pytest.warns(ExperimentalWarning, match="async") as warnings:
+            driver = GraphDatabase.driver(*args, **kwargs)
+        print(warnings)
+        return driver
+    else:
+        return GraphDatabase.driver(*args, **kwargs)
+
+
+def driver(*args, **kwargs):
+    return Neo4jDriver(*args, **kwargs)
+
+
 @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://"))
 @pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
                                   "[::1]", "[0:0:0:0:0:0:0:1]"))
@@ -53,7 +70,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token):
         with pytest.warns(DeprecationWarning, match="routing context"):
             driver = GraphDatabase.driver(uri, auth=auth_token)
     else:
-        driver = GraphDatabase.driver(uri, auth=auth_token)
+        driver = create_driver(uri, auth=auth_token)
     assert isinstance(driver, BoltDriver)
     driver.close()
@@ -68,7 +85,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token):
 @mark_sync_test
 def test_routing_driver_constructor(protocol, host, port, params, auth_token):
     uri = protocol + host + port + params
-    driver = GraphDatabase.driver(uri, auth=auth_token)
+    driver = create_driver(uri, auth=auth_token)
     assert isinstance(driver, Neo4jDriver)
     driver.close()
@@ -128,13 +145,20 @@ def test_routing_driver_constructor(protocol, host, port, params, auth_token):
 def test_driver_config_error(
     test_uri, test_config, expected_failure, expected_failure_message
 ):
+    def driver_builder():
+        if "trust" in test_config:
+            with pytest.warns(DeprecationWarning, match="trust"):
+                return GraphDatabase.driver(test_uri, **test_config)
+        else:
+            return create_driver(test_uri, **test_config)
+
     if "+" in test_uri:
         # `+s` and `+ssc` are short hand syntax for not having to configure the
         # encryption behavior of the driver. Specifying both is invalid.
         with pytest.raises(expected_failure, match=expected_failure_message):
-            GraphDatabase.driver(test_uri, **test_config)
+            driver_builder()
     else:
-        driver = GraphDatabase.driver(test_uri, **test_config)
+        driver = driver_builder()
         driver.close()
@@ -145,7 +169,7 @@ def test_driver_config_error(
 ))
 def test_invalid_protocol(test_uri):
     with pytest.raises(ConfigurationError, match="scheme"):
-        GraphDatabase.driver(test_uri)
+        create_driver(test_uri)
 
 
 @pytest.mark.parametrize(
@@ -160,7 +184,7 @@ def test_driver_trust_config_error(
     test_config, expected_failure, expected_failure_message
 ):
     with pytest.raises(expected_failure, match=expected_failure_message):
-        GraphDatabase.driver("bolt://127.0.0.1:9001", **test_config)
+        create_driver("bolt://127.0.0.1:9001", **test_config)
 
 
 @pytest.mark.parametrize("uri", (
@@ -169,7 +193,7 @@ def test_driver_trust_config_error(
 ))
 @mark_sync_test
 def test_driver_opens_write_session_by_default(uri, mocker):
-    driver = GraphDatabase.driver(uri)
+    driver = create_driver(uri)
     from neo4j import Transaction
 
     # we set a specific db, because else the driver would try to fetch a RT
@@ -208,7 +232,7 @@ def test_driver_opens_write_session_by_default(uri, mocker):
 ))
 @mark_sync_test
 def test_verify_connectivity(uri, mocker):
-    driver = GraphDatabase.driver(uri)
+    driver = create_driver(uri)
     pool_mock = mocker.patch.object(driver, "_pool", autospec=True)
 
     try:
@@ -232,10 +256,10 @@ def test_verify_connectivity(uri, mocker):
     {"fetch_size": 69},
 ))
 @mark_sync_test
-def test_verify_connectivity_parameters_are_experimental(
+def test_verify_connectivity_parameters_are_deprecated(
     uri, kwargs, mocker
 ):
-    driver = GraphDatabase.driver(uri)
+    driver = create_driver(uri)
     mocker.patch.object(driver, "_pool", autospec=True)
 
     try:
@@ -258,7 +282,7 @@ def test_verify_connectivity_parameters_are_experimental(
 def test_get_server_info_parameters_are_experimental(
     uri, kwargs, mocker
 ):
-    driver = GraphDatabase.driver(uri)
+    driver = create_driver(uri)
     mocker.patch.object(driver, "_pool", autospec=True)
 
     try:
diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py
index 4edeec99e..a801ca7e7 100644
--- a/tests/unit/sync/work/test_result.py
+++ b/tests/unit/sync/work/test_result.py
@@ -14,9 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from re import match
-from unittest import mock
+
+
 import warnings
+from unittest import mock
 
 import pandas as pd
 import pytest
@@ -24,6 +25,7 @@
 
 from neo4j import (
     Address,
+    ExperimentalWarning,
     Record,
     Result,
     ResultSummary,
@@ -33,9 +35,9 @@
     Version,
 )
 from neo4j._async_compat.util import Util
-from neo4j.data import (
-    DataDehydrator,
-    DataHydrator,
+from neo4j._codec.hydration.v1 import HydrationHandler
+from neo4j._codec.packstream import Structure
+from neo4j._data import (
     Node,
     Relationship,
 )
@@ -44,7 +46,6 @@
     EntitySetView,
     Graph,
 )
-from neo4j.packstream import Structure
 
 from ...._async_compat import mark_sync_test
 
@@ -52,9 +53,24 @@ class Records:
     def __init__(self, fields, records):
         self.fields = tuple(fields)
+        self.hydration_scope = HydrationHandler().new_hydration_scope()
         self.records = tuple(records)
+        self._hydrate_records()
+
         assert all(len(self.fields) == len(r) for r in self.records)
 
+    def _hydrate_records(self):
+        def _hydrate(value):
+            if type(value) in self.hydration_scope.hydration_hooks:
+                return self.hydration_scope.hydration_hooks[type(value)](value)
+            if isinstance(value, (list, tuple)):
+                return type(value)(_hydrate(v) for v in value)
+            if isinstance(value, dict):
+                return {k: _hydrate(v) for k, v in value.items()}
+            return value
+
+        self.records = tuple(_hydrate(r) for r in self.records)
+
     def __len__(self):
         return self.records.__len__()
@@ -113,6 +129,7 @@ def __init__(self, records=None, run_meta=None, summary_meta=None,
         self.summary_meta = summary_meta
         ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"})
         self.unresolved_address = None
+        self._new_hydration_scope_called = False
 
     def send_all(self):
         self.sent += self.queued
@@ -187,10 +204,20 @@ def pull(self, *args, **kwargs):
     def defunct(self):
         return False
 
+    def new_hydration_scope(self):
+        class FakeHydrationScope:
+            hydration_hooks = None
+            dehydration_hooks = None
 
-class HydratorStub(DataHydrator):
-    def hydrate(self, values):
-        return values
+            def get_graph(self):
+                return Graph()
+
+        if len(self._records) > 1:
+            return FakeHydrationScope()
+        assert not self._new_hydration_scope_called
+        assert self._records
+        self._new_hydration_scope_called = True
+        return self._records[0].hydration_scope
 
 
 def noop(*_, **__):
@@ -254,7 +281,7 @@ def fetch_and_compare_all_records(
 @mark_sync_test
 def test_result_iteration(method, records):
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), 2, noop, noop)
+    result = Result(connection, 2, noop, noop)
    result._run("CYPHER", {}, None, None, "r", None)
    fetch_and_compare_all_records(result, "x", records, method)
@@ -263,7 +290,7 @@
 def test_result_iteration_mixed_methods():
     records = [[i] for i in range(10)]
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), 4, noop, noop)
+    result = Result(connection, 4, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     iter1 = Util.iter(result)
     iter2 = Util.iter(result)
@@ -299,9 +326,9 @@ def test_parallel_result_iteration(method, invert_fetch):
     connection = ConnectionStub(
         records=(Records(["x"], records1), Records(["x"], records2))
     )
-    result1 = Result(connection, HydratorStub(), 2, noop, noop)
+    result1 = Result(connection, 2, noop, noop)
     result1._run("CYPHER1", {}, None, None, "r", None)
-    result2 = Result(connection, HydratorStub(), 2, noop, noop)
+    result2 = Result(connection, 2, noop, noop)
     result2._run("CYPHER2", {}, None, None, "r", None)
     if invert_fetch:
         fetch_and_compare_all_records(
@@ -329,9 +356,9 @@ def test_interwoven_result_iteration(method, invert_fetch):
     connection = ConnectionStub(
         records=(Records(["x"], records1), Records(["y"], records2))
     )
-    result1 = Result(connection, HydratorStub(), 2, noop, noop)
+    result1 = Result(connection, 2, noop, noop)
     result1._run("CYPHER1", {}, None, None, "r", None)
-    result2 = Result(connection, HydratorStub(), 2, noop, noop)
+    result2 = Result(connection, 2, noop, noop)
     result2._run("CYPHER2", {}, None, None, "r", None)
     start = 0
     for n in (1, 2, 3, 1, None):
@@ -358,7 +385,7 @@ def test_interwoven_result_iteration(method, invert_fetch):
 @mark_sync_test
 def test_result_peek(records, fetch_size):
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), fetch_size, noop, noop)
+    result = Result(connection, fetch_size, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     for i in range(len(records) + 1):
         record = result.peek()
@@ -381,7 +408,7 @@ def test_result_single_non_strict(records, fetch_size, default):
         kwargs["strict"] = False
 
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), fetch_size, noop, noop)
+    result = Result(connection, fetch_size, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     if len(records) == 0:
         assert result.single(**kwargs) is None
@@ -400,7 +427,7 @@ def test_result_single_non_strict(records, fetch_size, default):
 @mark_sync_test
 def test_result_single_strict(records, fetch_size):
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), fetch_size, noop, noop)
+    result = Result(connection, fetch_size, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     try:
         record = result.single(strict=True)
@@ -427,7 +454,7 @@ def test_result_single_strict(records, fetch_size):
 @mark_sync_test
 def test_result_single_exhausts_records(records, fetch_size, strict):
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), fetch_size, noop, noop)
+    result = Result(connection, fetch_size, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     try:
         with warnings.catch_warnings():
@@ -449,7 +476,7 @@ def test_result_single_exhausts_records(records, fetch_size, strict):
 @mark_sync_test
 def test_result_fetch(records, fetch_size, strict):
     connection = ConnectionStub(records=Records(["x"], records))
-    result = Result(connection, HydratorStub(), fetch_size, noop, noop)
+    result = Result(connection, fetch_size, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     assert result.fetch(0) == []
     assert result.fetch(-1) == []
@@ -461,7 +488,7 @@ def test_result_fetch(records, fetch_size, strict):
 @mark_sync_test
 def test_keys_are_available_before_and_after_stream():
     connection = ConnectionStub(records=Records(["x"], [[1], [2]]))
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     assert list(result.keys()) == ["x"]
     Util.list(result)
@@ -477,7 +504,7 @@ def test_consume(records, consume_one, summary_meta, consume_times):
     connection = ConnectionStub(
         records=Records(["x"], records), summary_meta=summary_meta
     )
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     if consume_one:
         try:
@@ -512,7 +539,7 @@ def test_time_in_summary(t_first, t_last):
         summary_meta=summary_meta
     )
 
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     summary = result.consume()
@@ -534,7 +561,7 @@ def test_time_in_summary(t_first, t_last):
 def test_counts_in_summary():
     connection = ConnectionStub(records=Records(["n"], [[1], [2]]))
 
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     summary = result.consume()
@@ -548,7 +575,7 @@ def test_query_type(query_type):
         records=Records(["n"], [[1], [2]]),
         summary_meta={"type": query_type}
     )
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     summary = result.consume()
@@ -563,7 +590,7 @@ def test_data(num_records):
         records=Records(["n"], [[i + 1] for i in range(num_records)])
     )
 
-    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     result._buffer_all()
     records = result._record_buffer.copy()
@@ -578,6 +605,7 @@ def test_data(num_records):
         assert record.data.called_once_with("hello", "world")
 
 
+# TODO: dehydration now happens on a much lower level
 @pytest.mark.parametrize("records", (
     Records(["n"], []),
     Records(["n"], [[42], [69], [420], [1337]]),
@@ -603,8 +631,9 @@ def test_result_graph(records, scripted_connection):
             "on_summary": None
         }),
    ))
-    result = Result(scripted_connection, DataHydrator(), 1, noop,
-                    noop)
+    scripted_connection.new_hydration_scope.return_value = \
+        records.hydration_scope
+    result = Result(scripted_connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
     graph = result.graph()
     assert isinstance(graph, Graph)
@@ -702,12 +731,13 @@ def test_result_graph(records, scripted_connection):
 @mark_sync_test
 def test_to_df(keys, values, types, instances, test_default_expand):
     connection = ConnectionStub(records=Records(keys, values))
-    result = Result(connection, DataHydrator(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
-    if test_default_expand:
-        df = result.to_df()
-    else:
-        df = result.to_df(expand=False)
+    with pytest.warns(ExperimentalWarning, match="pandas"):
+        if test_default_expand:
+            df = result.to_df()
+        else:
+            df = result.to_df(expand=False)
 
     assert isinstance(df, pd.DataFrame)
     assert df.keys().to_list() == keys
@@ -807,12 +837,12 @@ def test_to_df(keys, values, types, instances, test_default_expand):
     (
         ["n"],
         list(zip((
-            Structure(b"N", 0, ["LABEL_A"],
-                      {"a": 1, "b": 2, "d": 1}, "00"),
-            Structure(b"N", 2, ["LABEL_B"],
-                      {"a": 1, "c": 1.2, "d": 2}, "02"),
-            Structure(b"N", 1, ["LABEL_A", "LABEL_B"],
-                      {"a": [1, "a"], "d": 3}, "01"),
+            Node(None, "00", 0, ["LABEL_A"],
+                 {"a": 1, "b": 2, "d": 1}),
+            Node(None, "02", 2, ["LABEL_B"],
+                 {"a": 1, "c": 1.2, "d": 2}),
+            Node(None, "01", 1, ["LABEL_A", "LABEL_B"],
+                 {"a": [1, "a"], "d": 3}),
         ))),
         [
             "n().element_id", "n().labels", "n().prop.a", "n().prop.b",
@@ -848,11 +878,7 @@ def test_to_df(keys, values, types, instances, test_default_expand):
     ),
     (
         ["dt"],
-        [
-            DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
-        ],
+        [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]],
         ["dt"],
         [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]],
         ["object"],
@@ -863,9 +889,10 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
                       expected_types):
     connection = ConnectionStub(records=Records(keys, values))
-    result = Result(connection, DataHydrator(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
-    df = result.to_df(expand=True)
+    with pytest.warns(ExperimentalWarning, match="pandas"):
+        df = result.to_df(expand=True)
 
     assert isinstance(df, pd.DataFrame)
     assert len(set(expected_columns)) == len(expected_columns)
@@ -895,9 +922,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
     (
         ["dt"],
         [
-            DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
+            [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)],
         ],
         pd.DataFrame(
             [[pd.Timestamp("2022-01-02 03:04:05.000000006")]],
@@ -908,9 +933,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
     (
         ["d"],
         [
-            DataDehydrator().dehydrate((
-                neo4j_time.Date(2222, 2, 22),
-            )),
+            [neo4j_time.Date(2222, 2, 22)],
         ],
         pd.DataFrame(
             [[pd.Timestamp("2222-02-22")]],
@@ -921,11 +944,11 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
     (
         ["dt_tz"],
         [
-            DataDehydrator().dehydrate((
+            [
                 pytz.timezone("Europe/Stockholm").localize(
                     neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0)
                 ),
-            )),
+            ],
         ],
         pd.DataFrame(
             [[
@@ -941,17 +964,13 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
         ["mixed"],
         [
             [None],
-            DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
-            DataDehydrator().dehydrate((
-                neo4j_time.Date(2222, 2, 22),
-            )),
-            DataDehydrator().dehydrate((
+            [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)],
+            [neo4j_time.Date(2222, 2, 22)],
+            [
                 pytz.timezone("Europe/Stockholm").localize(
                     neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0)
                 ),
-            )),
+            ],
         ],
         pd.DataFrame(
             [
@@ -971,18 +990,14 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
     (
         ["mixed"],
         [
-            DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
-            DataDehydrator().dehydrate((
-                neo4j_time.Date(2222, 2, 22),
-            )),
+            [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)],
+            [neo4j_time.Date(2222, 2, 22)],
             [None],
-            DataDehydrator().dehydrate((
+            [
                 pytz.timezone("Europe/Stockholm").localize(
                     neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0)
                 ),
-            )),
+            ],
         ],
         pd.DataFrame(
             [
@@ -1002,17 +1017,13 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
     (
         ["mixed"],
         [
-            DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
-            DataDehydrator().dehydrate((
-                neo4j_time.Date(2222, 2, 22),
-            )),
-            DataDehydrator().dehydrate((
+            [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),],
+            [neo4j_time.Date(2222, 2, 22),],
+            [
                 pytz.timezone("Europe/Stockholm").localize(
                     neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0)
                 ),
-            )),
+            ],
             [None],
         ],
         pd.DataFrame(
             [
@@ -1052,9 +1063,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
         ],
         [
             None,
-            *DataDehydrator().dehydrate((
-                neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
-            )),
+            neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),
             1.234,
         ],
     ],
@@ -1080,8 +1089,9 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows,
 @mark_sync_test
 def test_to_df_parse_dates(keys, values, expected_df, expand):
     connection = ConnectionStub(records=Records(keys, values))
-    result = Result(connection, DataHydrator(), 1, noop, noop)
+    result = Result(connection, 1, noop, noop)
     result._run("CYPHER", {}, None, None, "r", None)
-    df = result.to_df(expand=expand, parse_dates=True)
+    with pytest.warns(ExperimentalWarning, match="pandas"):
+        df = result.to_df(expand=expand, parse_dates=True)
 
     pd.testing.assert_frame_equal(df, expected_df)
diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py
index 92edd9aab..c93646306 100644
--- a/tests/unit/sync/work/test_session.py
+++ b/tests/unit/sync/work/test_session.py
@@ -24,27 +24,35 @@
     Bookmarks,
     ManagedTransaction,
     Session,
-    SessionConfig,
     Transaction,
     unit_of_work,
 )
+from neo4j._conf import SessionConfig
 from neo4j._sync.io._pool import IOPool
 
 from ...._async_compat import mark_sync_test
-from ._fake_connection import fake_connection_generator
 
 
 @pytest.fixture()
 def pool(fake_connection_generator, mocker):
     pool = mocker.Mock(spec=IOPool)
-    pool.acquire.side_effect = iter(fake_connection_generator, 0)
+    assert not hasattr(pool, "acquired_connection_mocks")
+    pool.acquired_connection_mocks = []
+
+    def acquire_side_effect(*_, **__):
+        connection = fake_connection_generator()
+        pool.acquired_connection_mocks.append(connection)
+        return connection
+
+    pool.acquire.side_effect = acquire_side_effect
     return pool
 
 
 @mark_sync_test
 def test_session_context_calls_close(mocker):
     s = Session(None, SessionConfig())
-    mock_close = mocker.patch.object(s, 'close', autospec=True)
+    mock_close = mocker.patch.object(s, 'close', autospec=True,
+                                     side_effect=s.close)
     with s:
         pass
     mock_close.assert_called_once_with()
@@ -195,9 +203,12 @@ def test_session_returns_bookmarks_directly(pool, bookmark_values):
 )
 @mark_sync_test
 def test_session_last_bookmark_is_deprecated(pool, bookmarks):
-    with Session(pool, SessionConfig(
-        bookmarks=bookmarks
-    )) as session:
+    if bookmarks is not None:
+        with pytest.warns(DeprecationWarning):
+            session = Session(pool, SessionConfig(bookmarks=bookmarks))
+    else:
+        session = Session(pool, SessionConfig(bookmarks=bookmarks))
+    with session:
         with pytest.warns(DeprecationWarning):
             if bookmarks:
                 assert (session.last_bookmark()) == bookmarks[-1]
@@ -267,57 +278,46 @@ def test_session_tx_type(pool):
     assert isinstance(tx, Transaction)
 
 
-@pytest.mark.parametrize(("parameters", "error_type"), (
-    ({"x": None}, None),
-    ({"x": True}, None),
-    ({"x": False}, None),
-    ({"x": 123456789}, None),
-    ({"x": 3.1415926}, None),
-    ({"x": float("nan")}, None),
-    ({"x": float("inf")}, None),
-    ({"x": float("-inf")}, None),
-    ({"x": "foo"}, None),
-    ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None),
-    ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None),
-    ({"x": [1, 2, 3]}, None),
-    ({"x": ["a", "b", "c"]}, None),
-    ({"x": ["a", 2, 1.234]}, None),
-    ({"x": ["a", 2, ["c"]]}, None),
-    ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None),
-    ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None),
-
-    # maps must have string keys
-    ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError),
-    ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError),
+@pytest.mark.parametrize("parameters", (
+    {"x": None},
+    {"x": True},
+    {"x": False},
+    {"x": 123456789},
+    {"x": 3.1415926},
+    {"x": float("nan")},
+    {"x": float("inf")},
+    {"x": float("-inf")},
+    {"x": "foo"},
+    {"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])},
+    {"x": b"\x00\x33\x66\x99\xcc\xff"},
+    {"x": [1, 2, 3]},
+    {"x": ["a", "b", "c"]},
+    {"x": ["a", 2, 1.234]},
+    {"x": ["a", 2, ["c"]]},
+    {"x": {"one": "eins", "two": "zwei", "three": "drei"}},
+    {"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}},
 ))
 @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed"))
 @mark_sync_test
 def test_session_run_with_parameters(
-    pool, parameters, error_type, run_type
+    pool, parameters, run_type, mocker
 ):
-    @contextmanager
-    def raises():
-        if error_type is not None:
-            with pytest.raises(error_type) as exc:
-                yield exc
-        else:
-            yield None
-
     with Session(pool, SessionConfig()) as session:
         if run_type == "auto":
-            with raises():
-                session.run("RETURN $x", **parameters)
+            session.run("RETURN $x", **parameters)
         elif run_type == "unmanaged":
             tx = session.begin_transaction()
-            with raises():
-                tx.run("RETURN $x", **parameters)
+            tx.run("RETURN $x", **parameters)
         elif run_type == "managed":
             def work(tx):
-                with raises() as exc:
-                    tx.run("RETURN $x", **parameters)
-                    if exc is not None:
-                        raise exc
-            with raises():
-                session.write_transaction(work)
+                tx.run("RETURN $x", **parameters)
+            session.write_transaction(work)
         else:
             raise ValueError(run_type)
+
+    assert len(pool.acquired_connection_mocks) == 1
+    connection_mock = pool.acquired_connection_mocks[0]
+    assert connection_mock.run.called_once()
+    call = connection_mock.run.call_args
+    assert call.args[0] == "RETURN $x"
+    assert call.kwargs["parameters"] == parameters
diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py
index 3c5dfcbee..9a3440faf 100644
--- a/tests/unit/sync/work/test_transaction.py
+++ b/tests/unit/sync/work/test_transaction.py
@@ -113,23 +113,6 @@ class OopsError(RuntimeError):
     assert tx_.closed()
 
 
-@pytest.mark.parametrize(("parameters", "error_type"), (
-    # maps must have string keys
-    ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError),
-    ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError),
-    ({"x": uuid4()}, TypeError),
-))
-@mark_sync_test
-def test_transaction_run_with_invalid_parameters(
-    fake_connection, parameters, error_type
-):
-    on_closed = MagicMock()
-    on_error = MagicMock()
-    tx = Transaction(fake_connection, 2, on_closed, on_error)
-    with pytest.raises(error_type):
-        tx.run("RETURN $x", **parameters)
-
-
 @mark_sync_test
 def test_transaction_run_takes_no_query_object(fake_connection):
     on_closed = MagicMock()