Skip to content

Commit

Permalink
Fix use of SSLContext with sniffing (#199)
Browse files Browse the repository at this point in the history
* Fix use of SSLContext with sniffing

* Fix lint

* Ignore enabled_cleanup_closed warning

It's not triggered for normal code.

* Remove warning filter

(cherry picked from commit af34992)
  • Loading branch information
pquentin authored and github-actions[bot] committed Jan 6, 2025
1 parent 211a0e1 commit e5eba82
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion elastic_transport/_node/_urllib3_chain_certs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _validate_conn(self, conn: HTTPSConnection) -> None: # type: ignore[overrid
if sys.version_info >= (3, 13):
fingerprints = [
hash_func(cert).digest()
for cert in conn.sock.get_verified_chain()
for cert in conn.sock.get_verified_chain() # type: ignore
]
else:
# 'get_verified_chain()' and 'Certificate.public_bytes()' are private APIs
Expand Down
8 changes: 4 additions & 4 deletions elastic_transport/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,13 @@ def validate_sniffing_options(

def warn_if_varying_node_config_options(node_configs: List[NodeConfig]) -> None:
"""Function which detects situations when sniffing may produce incorrect configs"""
exempt_attrs = {"host", "port", "connections_per_node", "_extras"}
exempt_attrs = {"host", "port", "connections_per_node", "_extras", "ssl_context"}
match_attr_dict = None
for node_config in node_configs:
attr_dict = {
k: v
for k, v in dataclasses.asdict(node_config).items()
if k not in exempt_attrs
field.name: getattr(node_config, field.name)
for field in dataclasses.fields(node_config)
if field.name not in exempt_attrs
}
if match_attr_dict is None:
match_attr_dict = attr_dict
Expand Down
15 changes: 12 additions & 3 deletions tests/async_/test_async_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import random
import re
import ssl
import sys
import time
import warnings
Expand Down Expand Up @@ -505,13 +506,21 @@ async def test_error_sniffing_callback_without_sniffing_enabled():
@pytest.mark.asyncio
async def test_heterogeneous_node_config_warning_with_sniffing():
with warnings.catch_warnings(record=True) as w:
# SSLContext objects cannot be compared and are thus ignored
context = ssl.create_default_context()
AsyncTransport(
[
NodeConfig("http", "localhost", 80, path_prefix="/a"),
NodeConfig("http", "localhost", 81, path_prefix="/b"),
NodeConfig(
"https", "localhost", 80, path_prefix="/a", ssl_context=context
),
NodeConfig(
"https", "localhost", 81, path_prefix="/b", ssl_context=context
),
],
sniff_on_start=True,
sniff_callback=lambda *_: [],
sniff_callback=lambda *_: [
NodeConfig("https", "localhost", 80, path_prefix="/a")
],
)

assert len(w) == 1
Expand Down
12 changes: 9 additions & 3 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import random
import re
import ssl
import threading
import time
import warnings
Expand Down Expand Up @@ -537,14 +538,19 @@ def test_error_sniffing_callback_without_sniffing_enabled():

def test_heterogeneous_node_config_warning_with_sniffing():
with warnings.catch_warnings(record=True) as w:
context = ssl.create_default_context()
Transport(
[
NodeConfig("http", "localhost", 80, path_prefix="/a"),
NodeConfig("http", "localhost", 81, path_prefix="/b"),
NodeConfig(
"https", "localhost", 80, path_prefix="/a", ssl_context=context
),
NodeConfig(
"https", "localhost", 81, path_prefix="/b", ssl_context=context
),
],
sniff_on_start=True,
sniff_callback=lambda *_: [
NodeConfig("http", "localhost", 80, path_prefix="/a")
NodeConfig("https", "localhost", 80, path_prefix="/a")
],
)

Expand Down

0 comments on commit e5eba82

Please sign in to comment.