From e5eba822688ac9b678c53d55728677f43aba10e4 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Mon, 6 Jan 2025 16:45:28 +0400 Subject: [PATCH] Fix use of SSLContext with sniffing (#199) * Fix use of SSLContext with sniffing * Fix lint * Ignore enabled_cleanup_closed warning It's not triggered for normal code. * Remove warning filter (cherry picked from commit af34992196d2f5daa8699d6ec7f162c3bb8b388a) --- elastic_transport/_node/_urllib3_chain_certs.py | 2 +- elastic_transport/_transport.py | 8 ++++---- tests/async_/test_async_transport.py | 15 ++++++++++++--- tests/test_transport.py | 12 +++++++++--- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/elastic_transport/_node/_urllib3_chain_certs.py b/elastic_transport/_node/_urllib3_chain_certs.py index e36449b..9da6dd4 100644 --- a/elastic_transport/_node/_urllib3_chain_certs.py +++ b/elastic_transport/_node/_urllib3_chain_certs.py @@ -108,7 +108,7 @@ def _validate_conn(self, conn: HTTPSConnection) -> None: # type: ignore[overrid if sys.version_info >= (3, 13): fingerprints = [ hash_func(cert).digest() - for cert in conn.sock.get_verified_chain() + for cert in conn.sock.get_verified_chain() # type: ignore ] else: # 'get_verified_chain()' and 'Certificate.public_bytes()' are private APIs diff --git a/elastic_transport/_transport.py b/elastic_transport/_transport.py index bf1de58..3219e52 100644 --- a/elastic_transport/_transport.py +++ b/elastic_transport/_transport.py @@ -540,13 +540,13 @@ def validate_sniffing_options( def warn_if_varying_node_config_options(node_configs: List[NodeConfig]) -> None: """Function which detects situations when sniffing may produce incorrect configs""" - exempt_attrs = {"host", "port", "connections_per_node", "_extras"} + exempt_attrs = {"host", "port", "connections_per_node", "_extras", "ssl_context"} match_attr_dict = None for node_config in node_configs: attr_dict = { - k: v - for k, v in dataclasses.asdict(node_config).items() - if k not in exempt_attrs + field.name: getattr(node_config, field.name) + for field in dataclasses.fields(node_config) + if field.name not in exempt_attrs } if match_attr_dict is None: match_attr_dict = attr_dict diff --git a/tests/async_/test_async_transport.py b/tests/async_/test_async_transport.py index 2e8a884..2e288e2 100644 --- a/tests/async_/test_async_transport.py +++ b/tests/async_/test_async_transport.py @@ -18,6 +18,7 @@ import asyncio import random import re +import ssl import sys import time import warnings @@ -505,13 +506,21 @@ async def test_error_sniffing_callback_without_sniffing_enabled(): @pytest.mark.asyncio async def test_heterogeneous_node_config_warning_with_sniffing(): with warnings.catch_warnings(record=True) as w: + # SSLContext objects cannot be compared and are thus ignored + context = ssl.create_default_context() AsyncTransport( [ - NodeConfig("http", "localhost", 80, path_prefix="/a"), - NodeConfig("http", "localhost", 81, path_prefix="/b"), + NodeConfig( + "https", "localhost", 80, path_prefix="/a", ssl_context=context + ), + NodeConfig( + "https", "localhost", 81, path_prefix="/b", ssl_context=context + ), ], sniff_on_start=True, - sniff_callback=lambda *_: [], + sniff_callback=lambda *_: [ + NodeConfig("https", "localhost", 80, path_prefix="/a") + ], ) assert len(w) == 1 diff --git a/tests/test_transport.py b/tests/test_transport.py index 4545399..07d063d 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -17,6 +17,7 @@ import random import re +import ssl import threading import time import warnings @@ -537,14 +538,19 @@ def test_error_sniffing_callback_without_sniffing_enabled(): def test_heterogeneous_node_config_warning_with_sniffing(): with warnings.catch_warnings(record=True) as w: + context = ssl.create_default_context() Transport( [ - NodeConfig("http", "localhost", 80, path_prefix="/a"), - NodeConfig("http", "localhost", 81, path_prefix="/b"), + NodeConfig( + "https", "localhost", 80, path_prefix="/a", ssl_context=context + ), + NodeConfig( + "https", "localhost", 81, path_prefix="/b", ssl_context=context + ), ], sniff_on_start=True, sniff_callback=lambda *_: [ - NodeConfig("http", "localhost", 80, path_prefix="/a") + NodeConfig("https", "localhost", 80, path_prefix="/a") ], )