|
| 1 | +# Copyright (c) "Neo4j" |
| 2 | +# Neo4j Sweden AB [https://neo4j.com] |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import typing as t |
| 20 | +from contextlib import suppress as _suppress |
| 21 | +from socket import ( |
| 22 | + AddressFamily, |
| 23 | + AF_INET, |
| 24 | + AF_INET6, |
| 25 | + getservbyname, |
| 26 | +) |
| 27 | + |
| 28 | + |
| 29 | +if t.TYPE_CHECKING: |
| 30 | + import typing_extensions as te |
| 31 | + |
| 32 | + |
| 33 | +_T = t.TypeVar("_T") |
| 34 | + |
| 35 | + |
| 36 | +if t.TYPE_CHECKING: |
| 37 | + |
| 38 | + class _WithPeerName(te.Protocol): |
| 39 | + def getpeername(self) -> tuple: ... |
| 40 | + |
| 41 | + |
| 42 | +__all__ = [ |
| 43 | + "Address", |
| 44 | + "IPv4Address", |
| 45 | + "IPv6Address", |
| 46 | + "ResolvedAddress", |
| 47 | + "ResolvedIPv4Address", |
| 48 | + "ResolvedIPv6Address", |
| 49 | +] |
| 50 | + |
| 51 | + |
| 52 | +class _AddressMeta(type(tuple)): # type: ignore[misc] |
| 53 | + def __init__(cls, *args, **kwargs): |
| 54 | + super().__init__(*args, **kwargs) |
| 55 | + cls._ipv4_cls = None |
| 56 | + cls._ipv6_cls = None |
| 57 | + |
| 58 | + def _subclass_by_family(cls, family): |
| 59 | + subclasses = [ |
| 60 | + sc |
| 61 | + for sc in cls.__subclasses__() |
| 62 | + if ( |
| 63 | + sc.__module__ == cls.__module__ |
| 64 | + and getattr(sc, "family", None) == family |
| 65 | + ) |
| 66 | + ] |
| 67 | + if len(subclasses) != 1: |
| 68 | + raise ValueError( |
| 69 | + f"Class {cls} needs exactly one direct subclass with " |
| 70 | + f"attribute `family == {family}` within this module. " |
| 71 | + f"Found: {subclasses}" |
| 72 | + ) |
| 73 | + return subclasses[0] |
| 74 | + |
| 75 | + @property |
| 76 | + def ipv4_cls(cls): |
| 77 | + if cls._ipv4_cls is None: |
| 78 | + cls._ipv4_cls = cls._subclass_by_family(AF_INET) |
| 79 | + return cls._ipv4_cls |
| 80 | + |
| 81 | + @property |
| 82 | + def ipv6_cls(cls): |
| 83 | + if cls._ipv6_cls is None: |
| 84 | + cls._ipv6_cls = cls._subclass_by_family(AF_INET6) |
| 85 | + return cls._ipv6_cls |
| 86 | + |
| 87 | + |
| 88 | +class Address(tuple, metaclass=_AddressMeta): |
| 89 | + """ |
| 90 | + Base class to represent server addresses within the driver. |
| 91 | +
|
| 92 | + A tuple of two (IPv4) or four (IPv6) elements, representing the address |
| 93 | + parts. See also python's :mod:`socket` module for more information. |
| 94 | +
|
| 95 | + >>> Address(("example.com", 7687)) |
| 96 | + IPv4Address(('example.com', 7687)) |
| 97 | + >>> Address(("127.0.0.1", 7687)) |
| 98 | + IPv4Address(('127.0.0.1', 7687)) |
| 99 | + >>> Address(("::1", 7687, 0, 0)) |
| 100 | + IPv6Address(('::1', 7687, 0, 0)) |
| 101 | +
|
| 102 | + :param iterable: A collection of two or four elements creating an |
| 103 | + :class:`.IPv4Address` or :class:`.IPv6Address` instance respectively. |
| 104 | + """ |
| 105 | + |
| 106 | + #: Address family (:data:`socket.AF_INET` or :data:`socket.AF_INET6`). |
| 107 | + family: AddressFamily | None = None |
| 108 | + |
| 109 | + def __new__(cls, iterable: t.Collection) -> Address: |
| 110 | + if isinstance(iterable, cls): |
| 111 | + return iterable |
| 112 | + n_parts = len(iterable) |
| 113 | + inst = tuple.__new__(cls, iterable) |
| 114 | + if n_parts == 2: |
| 115 | + inst.__class__ = cls.ipv4_cls |
| 116 | + elif n_parts == 4: |
| 117 | + inst.__class__ = cls.ipv6_cls |
| 118 | + else: |
| 119 | + raise ValueError( |
| 120 | + "Addresses must consist of either " |
| 121 | + "two parts (IPv4) or four parts (IPv6)" |
| 122 | + ) |
| 123 | + return inst |
| 124 | + |
| 125 | + @classmethod |
| 126 | + def from_socket(cls, socket: _WithPeerName) -> Address: |
| 127 | + """ |
| 128 | + Create an address from a socket object. |
| 129 | +
|
| 130 | + Uses the socket's ``getpeername`` method to retrieve the remote |
| 131 | + address the socket is connected to. |
| 132 | + """ |
| 133 | + address = socket.getpeername() |
| 134 | + return cls(address) |
| 135 | + |
| 136 | + @classmethod |
| 137 | + def parse( |
| 138 | + cls, |
| 139 | + s: str, |
| 140 | + default_host: str | None = None, |
| 141 | + default_port: int | None = None, |
| 142 | + ) -> Address: |
| 143 | + """ |
| 144 | + Parse a string into an address. |
| 145 | +
|
| 146 | + The string must be in the format ``host:port`` (IPv4) or |
| 147 | + ``[host]:port`` (IPv6). |
| 148 | + If no port is specified, or is empty, ``default_port`` will be used. |
| 149 | + If no host is specified, or is empty, ``default_host`` will be used. |
| 150 | +
|
| 151 | + >>> Address.parse("localhost:7687") |
| 152 | + IPv4Address(('localhost', 7687)) |
| 153 | + >>> Address.parse("[::1]:7687") |
| 154 | + IPv6Address(('::1', 7687, 0, 0)) |
| 155 | + >>> Address.parse("localhost") |
| 156 | + IPv4Address(('localhost', 0)) |
| 157 | + >>> Address.parse("localhost", default_port=1234) |
| 158 | + IPv4Address(('localhost', 1234)) |
| 159 | +
|
| 160 | + :param s: The string to parse. |
| 161 | + :param default_host: The default host to use if none is specified. |
| 162 | + :data:`None` indicates to use ``"localhost"`` as default. |
| 163 | + :param default_port: The default port to use if none is specified. |
| 164 | + :data:`None` indicates to use ``0`` as default. |
| 165 | +
|
| 166 | + :returns: The parsed address. |
| 167 | + """ |
| 168 | + if not isinstance(s, str): |
| 169 | + raise TypeError("Address.parse requires a string argument") |
| 170 | + if s.startswith("["): |
| 171 | + # IPv6 |
| 172 | + port: str | int |
| 173 | + host, _, port = s[1:].rpartition("]") |
| 174 | + port = port.lstrip(":") |
| 175 | + with _suppress(TypeError, ValueError): |
| 176 | + port = int(port) |
| 177 | + host = host or default_host or "localhost" |
| 178 | + port = port or default_port or 0 |
| 179 | + return cls((host, port, 0, 0)) |
| 180 | + else: |
| 181 | + # IPv4 |
| 182 | + host, _, port = s.partition(":") |
| 183 | + with _suppress(TypeError, ValueError): |
| 184 | + port = int(port) |
| 185 | + host = host or default_host or "localhost" |
| 186 | + port = port or default_port or 0 |
| 187 | + return cls((host, port)) |
| 188 | + |
| 189 | + @classmethod |
| 190 | + def parse_list( |
| 191 | + cls, |
| 192 | + *s: str, |
| 193 | + default_host: str | None = None, |
| 194 | + default_port: int | None = None, |
| 195 | + ) -> list[Address]: |
| 196 | + """ |
| 197 | + Parse multiple addresses into a list. |
| 198 | +
|
| 199 | + See :meth:`.parse` for details on the string format. |
| 200 | +
|
| 201 | + Either a whitespace-separated list of strings or multiple strings |
| 202 | + can be used. |
| 203 | +
|
| 204 | + >>> Address.parse_list("localhost:7687", "[::1]:7687") |
| 205 | + [IPv4Address(('localhost', 7687)), IPv6Address(('::1', 7687, 0, 0))] |
| 206 | + >>> Address.parse_list("localhost:7687 [::1]:7687") |
| 207 | + [IPv4Address(('localhost', 7687)), IPv6Address(('::1', 7687, 0, 0))] |
| 208 | +
|
| 209 | + :param s: The string(s) to parse. |
| 210 | + :param default_host: The default host to use if none is specified. |
| 211 | + :data:`None` indicates to use ``"localhost"`` as default. |
| 212 | + :param default_port: The default port to use if none is specified. |
| 213 | + :data:`None` indicates to use ``0`` as default. |
| 214 | +
|
| 215 | + :returns: The list of parsed addresses. |
| 216 | + """ # noqa: E501 can't split the doctest lines |
| 217 | + if not all(isinstance(s0, str) for s0 in s): |
| 218 | + raise TypeError("Address.parse_list requires a string argument") |
| 219 | + return [ |
| 220 | + cls.parse(a, default_host, default_port) |
| 221 | + for a in " ".join(s).split() |
| 222 | + ] |
| 223 | + |
| 224 | + def __repr__(self): |
| 225 | + return f"{self.__class__.__name__}({tuple(self)!r})" |
| 226 | + |
| 227 | + @property |
| 228 | + def _host_name(self) -> t.Any: |
| 229 | + return self[0] |
| 230 | + |
| 231 | + @property |
| 232 | + def host(self) -> t.Any: |
| 233 | + """ |
| 234 | + The host part of the address. |
| 235 | +
|
| 236 | + This is the first part of the address tuple. |
| 237 | +
|
| 238 | + >>> Address(("localhost", 7687)).host |
| 239 | + 'localhost' |
| 240 | + """ |
| 241 | + return self[0] |
| 242 | + |
| 243 | + @property |
| 244 | + def port(self) -> t.Any: |
| 245 | + """ |
| 246 | + The port part of the address. |
| 247 | +
|
| 248 | + This is the second part of the address tuple. |
| 249 | +
|
| 250 | + >>> Address(("localhost", 7687)).port |
| 251 | + 7687 |
| 252 | + >>> Address(("localhost", 7687, 0, 0)).port |
| 253 | + 7687 |
| 254 | + >>> Address(("localhost", "7687")).port |
| 255 | + '7687' |
| 256 | + >>> Address(("localhost", "http")).port |
| 257 | + 'http' |
| 258 | + """ |
| 259 | + return self[1] |
| 260 | + |
| 261 | + @property |
| 262 | + def _unresolved(self) -> Address: |
| 263 | + return self |
| 264 | + |
| 265 | + @property |
| 266 | + def port_number(self) -> int: |
| 267 | + """ |
| 268 | + The port part of the address as an integer. |
| 269 | +
|
| 270 | + First try to resolve the port as an integer, using |
| 271 | + :func:`socket.getservbyname`. If that fails, fall back to parsing the |
| 272 | + port as an integer. |
| 273 | +
|
| 274 | + >>> Address(("localhost", 7687)).port_number |
| 275 | + 7687 |
| 276 | + >>> Address(("localhost", "http")).port_number |
| 277 | + 80 |
| 278 | + >>> Address(("localhost", "7687")).port_number |
| 279 | + 7687 |
| 280 | + >>> Address(("localhost", [])).port_number |
| 281 | + Traceback (most recent call last): |
| 282 | + ... |
| 283 | + TypeError: Unknown port value [] |
| 284 | + >>> Address(("localhost", "banana-protocol")).port_number |
| 285 | + Traceback (most recent call last): |
| 286 | + ... |
| 287 | + ValueError: Unknown port value 'banana-protocol' |
| 288 | +
|
| 289 | + :returns: The resolved port number. |
| 290 | +
|
| 291 | + :raise ValueError: If the port cannot be resolved. |
| 292 | + :raise TypeError: If the port cannot be resolved. |
| 293 | + """ |
| 294 | + error_cls: type = TypeError |
| 295 | + |
| 296 | + try: |
| 297 | + return getservbyname(self[1]) |
| 298 | + except OSError: |
| 299 | + # OSError: service/proto not found |
| 300 | + error_cls = ValueError |
| 301 | + except TypeError: |
| 302 | + # TypeError: getservbyname() argument 1 must be str, not X |
| 303 | + pass |
| 304 | + try: |
| 305 | + return int(self[1]) |
| 306 | + except ValueError: |
| 307 | + error_cls = ValueError |
| 308 | + except TypeError: |
| 309 | + pass |
| 310 | + raise error_cls(f"Unknown port value {self[1]!r}") |
| 311 | + |
| 312 | + |
| 313 | +class IPv4Address(Address): |
| 314 | + """ |
| 315 | + An IPv4 address (family ``AF_INET``). |
| 316 | +
|
| 317 | + This class is also used for addresses that specify a host name instead of |
| 318 | + an IP address. E.g., |
| 319 | +
|
| 320 | + >>> Address(("example.com", 7687)) |
| 321 | + IPv4Address(('example.com', 7687)) |
| 322 | +
|
| 323 | + This class should not be instantiated directly. Instead, use |
| 324 | + :class:`.Address` or one of its factory methods. |
| 325 | + """ |
| 326 | + |
| 327 | + family = AF_INET |
| 328 | + |
| 329 | + def __str__(self) -> str: |
| 330 | + return "{}:{}".format(*self) |
| 331 | + |
| 332 | + |
| 333 | +class IPv6Address(Address): |
| 334 | + """ |
| 335 | + An IPv6 address (family ``AF_INET6``). |
| 336 | +
|
| 337 | + This class should not be instantiated directly. Instead, use |
| 338 | + :class:`.Address` or one of its factory methods. |
| 339 | + """ |
| 340 | + |
| 341 | + family = AF_INET6 |
| 342 | + |
| 343 | + def __str__(self) -> str: |
| 344 | + return "[{}]:{}".format(*self) |
| 345 | + |
| 346 | + |
| 347 | +# TODO: 6.0 - make this class private |
| 348 | +class ResolvedAddress(Address): |
| 349 | + _unresolved_host_name: str |
| 350 | + |
| 351 | + @property |
| 352 | + def _host_name(self) -> str: |
| 353 | + return self._unresolved_host_name |
| 354 | + |
| 355 | + @property |
| 356 | + def _unresolved(self) -> Address: |
| 357 | + return super().__new__(Address, (self._host_name, *self[1:])) |
| 358 | + |
| 359 | + def __new__(cls, iterable, *, host_name: str) -> ResolvedAddress: |
| 360 | + new = super().__new__(cls, iterable) |
| 361 | + new = t.cast(ResolvedAddress, new) |
| 362 | + new._unresolved_host_name = host_name |
| 363 | + return new |
| 364 | + |
| 365 | + |
| 366 | +# TODO: 6.0 - make this class private |
| 367 | +class ResolvedIPv4Address(IPv4Address, ResolvedAddress): |
| 368 | + pass |
| 369 | + |
| 370 | + |
| 371 | +# TODO: 6.0 - make this class private |
| 372 | +class ResolvedIPv6Address(IPv6Address, ResolvedAddress): |
| 373 | + pass |
0 commit comments