|
16 | 16 | # limitations under the License.
|
17 | 17 |
|
18 | 18 |
|
| 19 | +import ssl |
| 20 | + |
19 | 21 | import pytest
|
20 | 22 |
|
21 | 23 | from neo4j import (
|
@@ -266,3 +268,150 @@ def test_init_session_config_with_not_valid_key():
|
266 | 268 | _ = SessionConfig.consume(test_config_b)
|
267 | 269 |
|
268 | 270 | assert session_config.connection_acquisition_timeout == 333
|
| 271 | + |
| 272 | + |
| 273 | +@pytest.mark.parametrize("config", ( |
| 274 | + {}, |
| 275 | + {"encrypted": False}, |
| 276 | + {"trusted_certificates": TrustSystemCAs()}, |
| 277 | + {"trusted_certificates": TrustAll()}, |
| 278 | + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, |
| 279 | +)) |
| 280 | +def test_no_ssl_mock(config, mocker): |
| 281 | + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) |
| 282 | + pool_config = PoolConfig.consume(config) |
| 283 | + assert pool_config.encrypted is False |
| 284 | + assert pool_config.get_ssl_context() is None |
| 285 | + ssl_context_mock.assert_not_called() |
| 286 | + |
| 287 | + |
| 288 | +@pytest.mark.parametrize("config", ( |
| 289 | + {"encrypted": True}, |
| 290 | + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, |
| 291 | +)) |
| 292 | +def test_trust_system_cas_mock(config, mocker): |
| 293 | + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) |
| 294 | + pool_config = PoolConfig.consume(config) |
| 295 | + assert pool_config.encrypted is True |
| 296 | + ssl_context = pool_config.get_ssl_context() |
| 297 | + _assert_mock_tls_1_2(ssl_context_mock) |
| 298 | + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 |
| 299 | + ssl_context_mock.return_value.load_default_certs.assert_called_once_with() |
| 300 | + ssl_context_mock.return_value.load_verify_locations.assert_not_called() |
| 301 | + assert ssl_context.check_hostname is True |
| 302 | + assert ssl_context.verify_mode == ssl.CERT_REQUIRED |
| 303 | + |
| 304 | + |
| 305 | +@pytest.mark.parametrize("config", ( |
| 306 | + {"encrypted": True, "trusted_certificates": TrustCustomCAs("foo", "bar")}, |
| 307 | + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, |
| 308 | +)) |
| 309 | +def test_trust_custom_cas_mock(config, mocker): |
| 310 | + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) |
| 311 | + certs = config["trusted_certificates"].certs |
| 312 | + pool_config = PoolConfig.consume(config) |
| 313 | + assert pool_config.encrypted is True |
| 314 | + ssl_context = pool_config.get_ssl_context() |
| 315 | + _assert_mock_tls_1_2(ssl_context_mock) |
| 316 | + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 |
| 317 | + ssl_context_mock.return_value.load_default_certs.assert_not_called() |
| 318 | + assert ( |
| 319 | + ssl_context_mock.return_value.load_verify_locations.call_args_list |
| 320 | + == [((cert,), {}) for cert in certs] |
| 321 | + ) |
| 322 | + assert ssl_context.check_hostname is True |
| 323 | + assert ssl_context.verify_mode == ssl.CERT_REQUIRED |
| 324 | + |
| 325 | + |
| 326 | +@pytest.mark.parametrize("config", ( |
| 327 | + {"encrypted": True, "trusted_certificates": TrustAll()}, |
| 328 | +)) |
| 329 | +def test_trust_all_mock(config, mocker): |
| 330 | + ssl_context_mock = mocker.patch("ssl.SSLContext", autospec=True) |
| 331 | + pool_config = PoolConfig.consume(config) |
| 332 | + assert pool_config.encrypted is True |
| 333 | + ssl_context = pool_config.get_ssl_context() |
| 334 | + _assert_mock_tls_1_2(ssl_context_mock) |
| 335 | + assert ssl_context.minimum_version == ssl.TLSVersion.TLSv1_2 |
| 336 | + ssl_context_mock.return_value.load_default_certs.assert_not_called() |
| 337 | + ssl_context_mock.return_value.load_verify_locations.assert_not_called() |
| 338 | + assert ssl_context.check_hostname is False |
| 339 | + assert ssl_context.verify_mode is ssl.CERT_NONE |
| 340 | + |
| 341 | + |
| 342 | +def _assert_mock_tls_1_2(mock): |
| 343 | + mock.assert_called_once_with(ssl.PROTOCOL_TLS_CLIENT) |
| 344 | + assert mock.return_value.minimum_version == ssl.TLSVersion.TLSv1_2 |
| 345 | + |
| 346 | + |
| 347 | +@pytest.mark.parametrize("config", ( |
| 348 | + {}, |
| 349 | + {"encrypted": False}, |
| 350 | + {"trusted_certificates": TrustSystemCAs()}, |
| 351 | + {"trusted_certificates": TrustAll()}, |
| 352 | + {"trusted_certificates": TrustCustomCAs("foo", "bar")}, |
| 353 | +)) |
| 354 | +def test_no_ssl(config): |
| 355 | + pool_config = PoolConfig.consume(config) |
| 356 | + assert pool_config.encrypted is False |
| 357 | + assert pool_config.get_ssl_context() is None |
| 358 | + |
| 359 | + |
| 360 | +@pytest.mark.parametrize("config", ( |
| 361 | + {"encrypted": True}, |
| 362 | + {"encrypted": True, "trusted_certificates": TrustSystemCAs()}, |
| 363 | +)) |
| 364 | +def test_trust_system_cas(config): |
| 365 | + pool_config = PoolConfig.consume(config) |
| 366 | + assert pool_config.encrypted is True |
| 367 | + ssl_context = pool_config.get_ssl_context() |
| 368 | + assert isinstance(ssl_context, ssl.SSLContext) |
| 369 | + _assert_context_tls_1_2(ssl_context) |
| 370 | + assert ssl_context.check_hostname is True |
| 371 | + assert ssl_context.verify_mode == ssl.CERT_REQUIRED |
| 372 | + |
| 373 | + |
| 374 | +@pytest.mark.parametrize("config", ( |
| 375 | + {"encrypted": True, "trusted_certificates": TrustCustomCAs()}, |
| 376 | +)) |
| 377 | +def test_trust_custom_cas(config): |
| 378 | + pool_config = PoolConfig.consume(config) |
| 379 | + assert pool_config.encrypted is True |
| 380 | + ssl_context = pool_config.get_ssl_context() |
| 381 | + assert isinstance(ssl_context, ssl.SSLContext) |
| 382 | + _assert_context_tls_1_2(ssl_context) |
| 383 | + assert ssl_context.check_hostname is True |
| 384 | + assert ssl_context.verify_mode == ssl.CERT_REQUIRED |
| 385 | + |
| 386 | + |
| 387 | +@pytest.mark.parametrize("config", ( |
| 388 | + {"encrypted": True, "trusted_certificates": TrustAll()}, |
| 389 | +)) |
| 390 | +def test_trust_all(config): |
| 391 | + pool_config = PoolConfig.consume(config) |
| 392 | + assert pool_config.encrypted is True |
| 393 | + ssl_context = pool_config.get_ssl_context() |
| 394 | + assert isinstance(ssl_context, ssl.SSLContext) |
| 395 | + _assert_context_tls_1_2(ssl_context) |
| 396 | + assert ssl_context.check_hostname is False |
| 397 | + assert ssl_context.verify_mode is ssl.CERT_NONE |
| 398 | + |
| 399 | + |
| 400 | +def _assert_context_tls_1_2(ctx): |
| 401 | + assert ctx.protocol == ssl.PROTOCOL_TLS_CLIENT |
| 402 | + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 |
| 403 | + |
| 404 | + |
| 405 | +@pytest.mark.parametrize("encrypted", (True, False)) |
| 406 | +@pytest.mark.parametrize("trusted_certificates", ( |
| 407 | + TrustSystemCAs(), TrustAll(), TrustCustomCAs() |
| 408 | +)) |
| 409 | +def test_custom_ssl_context(encrypted, trusted_certificates): |
| 410 | + custom_ssl_context = object() |
| 411 | + pool_config = PoolConfig.consume({ |
| 412 | + "encrypted": encrypted, |
| 413 | + "trusted_certificates": trusted_certificates, |
| 414 | + "ssl_context": custom_ssl_context, |
| 415 | + }) |
| 416 | + assert pool_config.encrypted is encrypted |
| 417 | + assert pool_config.get_ssl_context() is custom_ssl_context |
0 commit comments