diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5458b69ad5..bce829b152 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,4 +32,4 @@ repos: - id: mypy args: [--check-untyped-defs] exclude: 'tests/' - additional_dependencies: ['charset_normalizer', 'urllib3.future>=2.0.934', 'wassima>=1.0.1', 'idna', 'kiss_headers'] + additional_dependencies: ['charset_normalizer', 'urllib3.future>=2.1.900', 'wassima>=1.0.1', 'idna', 'kiss_headers'] diff --git a/HISTORY.md b/HISTORY.md index 3f3deefbfe..b7a18d2777 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,12 +1,52 @@ Release History =============== -3.0.3 (2023-10-??) +3.1.0 (2023-10-10) ------------------ **Misc** - Static typing has been improved to provide a better development experience. +**Added** +- Certificate revocation verification via the OCSP protocol. + + This feature is broadly available and is enabled by default when `verify=True`. + We decided to follow what browsers do by default, so Niquests follows by being non-strict. + OCSP responses are expected to arrive in less than 200ms, otherwise ignored (e.g. OCSP is dropped). + Niquests keeps in-memory the results until the size exceed 2,048 entries, then an algorithm choose an entry + to be deleted (oldest request or the first one that ended in error). + + You can at your own discretion enable strict OCSP checks by passing the environment variable `NIQUESTS_STRICT_OCSP` + with anything inside but `0`. In strict mode the maximum delay for response passes from 200ms to 1,000ms and + raises an error or explicit warning. + + In non-strict mode, this security measure will be deactivated automatically if your usage is unreasonable. + e.g. Making a hundred of requests to a hundred of domains, thus consuming resources that should have been + allocated to browser users. This was made available for users with a limited target of domains to get + a complementary security measure. + + Unless in strict-mode, the proxy configuration will be respected when given, as long as it specify + a plain `http` proxy. This is meant for people who want privacy. + + This feature may not be available if the `cryptography` package is missing from your environment. + Verify the availability after Niquests upgrade by running `python -m niquests.help`. + + There is several downside of using OCSP, Niquests knows it. It is not a silver bullet solution. But better than nothing. + It does not apply to HTTPS proxies themselves. For now. + +- Add property `ocsp_verified` in both `PreparedRequest`, and `Response` to have a clue on the post handshake verification. + + Will be `None` if no verification took place, `True` if the verification leads to a confirmation from the OCSP server + that the certificate is valid, `False` otherwise. + +**Changed** +- Bump lower version requirement for `urllib3.future` to 2.1.900 to ensure compatibility with newer features. +- Internal in-memory QUIC capabilities is now thread safe and limited to 12,288 entries. +- Pickling a `Session` object no-longer dump adapters or the QUIC in-memory capabilities, they are reset on setstate. + +**Fixed** +- `conn_info` was unset if the response came after a redirect. + 3.0.2 (2023-10-01) ------------------ diff --git a/README.md b/README.md index e3d584b9c5..aa06c9b246 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ **Niquests** is a simple, yet elegant, HTTP library. It is a drop-in replacement for **Requests** that is no longer under feature freeze. -We will support and maintain v2.x.y that only ships with possible minor breaking changes. All breaking changes are issued in the v3.x that should be available as a pre-release on PyPI. Why did we pursue this? We don't have to reinvent the wheel all over again, HTTP client **Requests** is well established and really plaisant in its usage. We believe that **Requests** have the most inclusive, and developer friendly interfaces. We @@ -48,6 +47,7 @@ Niquests is ready for the demands of building robust and reliable HTTP–speakin - Automatic Content Decompression and Decoding - OS truststore by default, no more certifi! +- OCSP Certificate Revocation Verification - Browser-style TLS/SSL Verification - Sessions with Cookie Persistence - Keep-Alive & Connection Pooling @@ -62,12 +62,8 @@ Niquests is ready for the demands of building robust and reliable HTTP–speakin - SOCKS Proxy Support - Connection Timeouts - Streaming Downloads +- HTTP/2 by default - HTTP/3 over QUIC -- HTTP/2 - -## API Reference and User Guide available on [Read the Docs](https://niquests.readthedocs.io) - -[![Read the Docs](https://raw.githubusercontent.com/jawah/niquests/main/ext/ss.png)](https://niquests.readthedocs.io) --- diff --git a/docs/index.rst b/docs/index.rst index bd4124b54c..77b03ff951 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -59,6 +59,7 @@ Niquests is ready for today's web. - Automatic Content Decompression and Decoding - OS truststore by default, no more certifi! +- OCSP Certificate Revocation Verification - Browser-style TLS/SSL Verification - Sessions with Cookie Persistence - Keep-Alive & Connection Pooling @@ -73,8 +74,8 @@ Niquests is ready for today's web. - SOCKS Proxy Support - Connection Timeouts - Streaming Downloads +- HTTP/2 by default - HTTP/3 over QUIC -- HTTP/2 Niquests officially supports Python 3.7+, and runs great on PyPy. diff --git a/docs/user/advanced.rst b/docs/user/advanced.rst index 329896b146..1f9bc7bee3 100644 --- a/docs/user/advanced.rst +++ b/docs/user/advanced.rst @@ -1139,3 +1139,42 @@ It is also possible to use the ``Timeout`` class from ``urllib3`` directly:: r = niquests.get('https://github.com', timeout=Timeout(3, 9)) .. _`connect()`: https://linux.die.net/man/2/connect + +OCSP or Certificate Revocation +------------------------------ + +Difficult subject. Short story, when a HTTP client establish a secure connection, +it verify that the certificate is valid. The problem is that a certificate +can be both valid and revoked due its immutability, the revocation status must +be taken from an outside source, most of the revocation are linked to a hack/security violation. + +Niquests try to protect you from the evoked problem by doing a post-handshake verification +using the OCSP protocols via plain HTTP. + +Unfortunately, at this moment, no bullet proof solution has emerged against revoked certificate. +We are aware of this. But still, it is better than nothing! + +By default, Niquests operate a soft-fail verification, or non-strict if you prefer. + +This feature is broadly available and is enabled by default when ``verify=True``. +We decided to follow what browsers do by default, so Niquests follows by being non-strict. +OCSP responses are expected to arrive in less than 200ms, otherwise ignored (e.g. OCSP is dropped). +Niquests keeps in-memory the results until the size exceed 2,048 entries, then an algorithm choose an entry +to be deleted (oldest request or the first one that ended in error). + +You can at your own discretion enable strict OCSP checks by passing the environment variable ``NIQUESTS_STRICT_OCSP`` +with anything inside but ``0``. In strict mode the maximum delay for response passes from 200ms to 1,000ms and +raises an error or explicit warning. + +In non-strict mode, this security measure will be deactivated automatically if your usage is unreasonable. +e.g. Making a hundred of requests to a hundred of domains, thus consuming resources that should have been +allocated to browser users. This was made available for users with a limited target of domains to get +a complementary security measure. + +Unless in strict-mode, the proxy configuration will be respected when given, as long as it specify +a plain ``http`` proxy. This is meant for people who want privacy. + +This feature may not be available if the ``cryptography`` package is missing from your environment. +Verify the availability by running ``python -m niquests.help``. + +.. note:: Access property ``ocsp_verified`` in both ``PreparedRequest``, and ``Response`` to have information about this post handshake verification. diff --git a/pyproject.toml b/pyproject.toml index 42b7ff360b..010da416d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ description = "Python HTTP for Humans." readme = "README.md" license-files = { paths = ["LICENSE"] } license = "Apache-2.0" -keywords = ["requests", "http/2", "http/3", "QUIC", "http", "https", "http client", "http/1.1"] +keywords = ["requests", "http/2", "http/3", "QUIC", "http", "https", "http client", "http/1.1", "ocsp", "revocation", "tls"] authors = [ {name = "Kenneth Reitz", email = "me@kennethreitz.org"} ] @@ -41,7 +41,7 @@ dynamic = ["version"] dependencies = [ "charset_normalizer>=2,<4", "idna>=2.5,<4", - "urllib3.future>=2.0.936,<3", + "urllib3.future>=2.1.900,<3", "wassima>=1.0.1,<2", "kiss_headers>=2,<4", ] diff --git a/src/niquests/__version__.py b/src/niquests/__version__.py index dbb2637e3f..7ea85aa0d6 100644 --- a/src/niquests/__version__.py +++ b/src/niquests/__version__.py @@ -9,9 +9,9 @@ __url__: str = "https://niquests.readthedocs.io" __version__: str -__version__ = "3.0.2" +__version__ = "3.1.0" -__build__: int = 0x030002 +__build__: int = 0x030100 __author__: str = "Kenneth Reitz" __author_email__: str = "me@kennethreitz.org" __license__: str = "Apache-2.0" diff --git a/src/niquests/api.py b/src/niquests/api.py index 13fcb12f47..049f96753d 100644 --- a/src/niquests/api.py +++ b/src/niquests/api.py @@ -31,9 +31,9 @@ TLSVerifyType, ) from .models import PreparedRequest, Response +from .structures import SharableLimitedDict -#: This is a non-thread safe in-memory cache for the AltSvc / h3 -_SHARED_QUIC_CACHE: CacheLayerAltSvcType = {} +_SHARED_QUIC_CACHE: CacheLayerAltSvcType = SharableLimitedDict(max_size=12_288) def request( diff --git a/src/niquests/extensions/__init__.py b/src/niquests/extensions/__init__.py new file mode 100644 index 0000000000..a3231de591 --- /dev/null +++ b/src/niquests/extensions/__init__.py @@ -0,0 +1,4 @@ +""" +This subpackage hold anything that is very relevant +to the HTTP ecosystem but not per-say Niquests core logic. +""" diff --git a/src/niquests/extensions/_ocsp.py b/src/niquests/extensions/_ocsp.py new file mode 100644 index 0000000000..da7fd2f281 --- /dev/null +++ b/src/niquests/extensions/_ocsp.py @@ -0,0 +1,511 @@ +from __future__ import annotations + +import datetime +import typing +import warnings +from random import randint +from statistics import mean + +if typing.TYPE_CHECKING: + from ..sessions import Session + +import hmac +import socket +import threading +from hashlib import sha256 + +import wassima +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.hashes import SHA1 +from cryptography.x509 import ( + Certificate, + load_der_x509_certificate, + load_pem_x509_certificate, + ocsp, +) +from urllib3 import ConnectionInfo +from urllib3.exceptions import SecurityWarning +from urllib3.util.url import parse_url + +from ..exceptions import RequestException, SSLError +from ..models import PreparedRequest +from ._picotls import ( + CHANGE_CIPHER, + HANDSHAKE, + derive_secret, + gen_client_hello, + handle_encrypted_extensions, + handle_server_cert, + handle_server_hello, + multiply_num_on_ec_point, + num_to_bytes, + recv_tls, + recv_tls_and_decrypt, + send_tls, +) + + +def _str_fingerprint_of(certificate: Certificate) -> str: + return ":".join([format(i, "02x") for i in certificate.fingerprint(SHA1())]) + + +def _infer_issuer_from(certificate: Certificate) -> Certificate | None: + issuer: Certificate | None = None + + for der_cert in wassima.root_der_certificates() + _SharedStaplingCache.issuers: + if isinstance(der_cert, Certificate): + possible_issuer = der_cert + else: + possible_issuer = load_der_x509_certificate(der_cert) + + try: + certificate.verify_directly_issued_by(possible_issuer) + except ValueError: + continue + else: + issuer = possible_issuer + break + + return issuer + + +def _ask_nicely_for_issuer( + hostname: str, dst_address: tuple[str, int], timeout: int | float = 0.2 +) -> Certificate | None: + """When encountering a problem in development, one should always say that there is many solutions. + From dirtiest to the cleanest, not always known but with progressive effort, we'll eventually land at the cleanest. + + This function do a manual TLS 1.2+ handshake till we extract certificates from the remote peer. Does not + need to be secure, we just have to retrieve the issuer cert if any.""" + if dst_address[0].count(".") == 3: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + else: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + + sock.connect(dst_address) + sock.settimeout(timeout) + + SECP256R1_P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF + SECP256R1_A = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC + SECP256R1_G = ( + 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296, + 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5, + ) + + randelem = [b"\xAC", b"\xDC", b"\xFA", b"\xAF"] + client_random = b"".join([randelem[randint(0, 3)] for e in range(32)]) + our_ecdh_privkey = randint(42, 98) + our_ecdh_pubkey_x, our_ecdh_pubkey_y = multiply_num_on_ec_point( + our_ecdh_privkey, SECP256R1_G[0], SECP256R1_G[1], SECP256R1_A, SECP256R1_P + ) + + client_hello = gen_client_hello( + hostname, client_random, our_ecdh_pubkey_x, our_ecdh_pubkey_y + ) + + send_tls(sock, HANDSHAKE, client_hello) + + rec_type, server_hello = recv_tls(sock) + + if not rec_type == HANDSHAKE: + sock.close() + return None + + ( + server_random, + session_id, + server_ecdh_pubkey_x, + server_ecdh_pubkey_y, + ) = handle_server_hello(server_hello) + + rec_type, server_change_cipher = recv_tls(sock) + + if not rec_type == CHANGE_CIPHER: + sock.close() + return None + + our_secret_point_x = multiply_num_on_ec_point( + our_ecdh_privkey, + server_ecdh_pubkey_x, + server_ecdh_pubkey_y, + SECP256R1_A, + SECP256R1_P, + )[0] + our_secret = num_to_bytes(our_secret_point_x, 32) + + early_secret = hmac.new(b"", b"\x00" * 32, sha256).digest() + preextractsec = derive_secret( + b"derived", key=early_secret, data=sha256(b"").digest(), hash_len=32 + ) + handshake_secret = hmac.new(preextractsec, our_secret, sha256).digest() + hello_hash = sha256(client_hello + server_hello).digest() + server_hs_secret = derive_secret( + b"s hs traffic", key=handshake_secret, data=hello_hash, hash_len=32 + ) + server_write_key = derive_secret( + b"key", key=server_hs_secret, data=b"", hash_len=16 + ) + server_write_iv = derive_secret(b"iv", key=server_hs_secret, data=b"", hash_len=12) + + server_seq_num = 0 + + rec_type, encrypted_extensions = recv_tls_and_decrypt( + sock, server_write_key, server_write_iv, server_seq_num + ) + + if not rec_type == HANDSHAKE: + sock.close() + return None + + server_seq_num += 1 + + remaining_bytes = handle_encrypted_extensions(encrypted_extensions) + + if not remaining_bytes: + rec_type, server_cert = recv_tls_and_decrypt( + sock, server_write_key, server_write_iv, server_seq_num + ) + else: + rec_type, server_cert = rec_type, remaining_bytes + + if not rec_type == HANDSHAKE: + sock.close() + return None + + server_seq_num += 1 + + der_certificates = handle_server_cert(server_cert) + certificates = [] + + for der in der_certificates: + certificates.append(load_der_x509_certificate(der)) + + sock.close() + + if len(certificates) <= 1: + return None + + # kept in order, the immediate issuer come just after the leaf one. + return certificates[1] + + +class InMemoryRevocationStatus: + def __init__(self, max_size: int = 2048): + self._max_size: int = max_size + self._store: dict[str, ocsp.OCSPResponse] = {} + self._issuers: list[Certificate] = [] + self._timings: list[datetime.datetime] = [] + self._access_lock: threading.Lock = threading.Lock() + self.hold: bool = False + + @property + def issuers(self) -> list[Certificate]: + with self._access_lock: + return self._issuers + + def __len__(self) -> int: + with self._access_lock: + return len(self._store) + + def rate(self): + with self._access_lock: + previous_dt: datetime.datetime | None = None + delays: list[float] = [] + + for dt in self._timings: + if previous_dt is None: + previous_dt = dt + continue + delays.append((dt - previous_dt).total_seconds()) + previous_dt = dt + + return mean(delays) if delays else 0.0 + + def check(self, peer_certificate: Certificate) -> ocsp.OCSPResponse | None: + with self._access_lock: + fingerprint: str = _str_fingerprint_of(peer_certificate) + + if fingerprint not in self._store: + return None + + cached_response = self._store[fingerprint] + + if cached_response.certificate_status == ocsp.OCSPCertStatus.GOOD: + if ( + cached_response.next_update + and datetime.datetime.now().timestamp() + >= cached_response.next_update.timestamp() + ): + del self._store[fingerprint] + return None + return cached_response + + return cached_response + + def save( + self, + peer_certificate: Certificate, + issuer_certificate: Certificate, + ocsp_response: ocsp.OCSPResponse, + ) -> None: + with self._access_lock: + if len(self._store) >= self._max_size: + tbd_key: str | None = None + closest_next_update: datetime.datetime | None = None + + for k in self._store: + if ( + self._store[k].response_status + != ocsp.OCSPResponseStatus.SUCCESSFUL + ): + tbd_key = k + break + + if self._store[k].certificate_status != ocsp.OCSPCertStatus.REVOKED: + if closest_next_update is None: + closest_next_update = self._store[k].next_update + tbd_key = k + continue + if self._store[k].next_update > closest_next_update: # type: ignore + closest_next_update = self._store[k].next_update + tbd_key = k + + if tbd_key: + del self._store[tbd_key] + else: + del self._store[list(self._store.keys())[0]] + + self._store[_str_fingerprint_of(peer_certificate)] = ocsp_response + + issuer_fingerprint = _str_fingerprint_of(issuer_certificate) + + if not any( + _str_fingerprint_of(c) == issuer_fingerprint for c in self._issuers + ): + self._issuers.append(issuer_certificate) + + if len(self._issuers) >= self._max_size: + self._issuers.pop(0) + + self._timings.append(datetime.datetime.now()) + + if len(self._timings) >= self._max_size: + self._timings.pop(0) + + +_SharedStaplingCache = InMemoryRevocationStatus() + + +def verify( + session: Session, + r: PreparedRequest, + strict: bool = False, + timeout: float | int = 0.2, +) -> None: + conn_info: ConnectionInfo | None = r.conn_info + + # we can't do anything in that case. + if ( + conn_info is None + or conn_info.certificate_der is None + or conn_info.certificate_dict is None + ): + return + + endpoints: list[str] = [ # type: ignore + # exclude non-HTTP endpoint. like ldap. + ep # type: ignore + for ep in list(conn_info.certificate_dict.get("OCSP", [])) # type: ignore + if ep.startswith("http://") # type: ignore + ] + + # well... not all issued certificate have a OCSP entry. e.g. mkcert. + if not endpoints: + return + + # this feature, by default, is reserved for a reasonable usage. + if not strict: + mean_rate_sec = _SharedStaplingCache.rate() + cache_count = len(_SharedStaplingCache) + + if cache_count >= 10 and mean_rate_sec <= 1.0: + _SharedStaplingCache.hold = True + + if _SharedStaplingCache.hold: + return + + peer_certificate = load_der_x509_certificate(conn_info.certificate_der) + cached_response = _SharedStaplingCache.check(peer_certificate) + + if cached_response is not None: + issuer_certificate = _infer_issuer_from(peer_certificate) + + if issuer_certificate: + conn_info.issuer_certificate_der = issuer_certificate.public_bytes( + serialization.Encoding.DER + ) + + if cached_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL: + if cached_response.certificate_status == ocsp.OCSPCertStatus.REVOKED: + r.ocsp_verified = False + raise SSLError( + f"""Unable to establish a secure connection to {r.url} because the certificate has been revoked + by issuer ({cached_response.revocation_reason}). + You should avoid trying to request anything from it as the remote has been compromised. + See https://en.wikipedia.org/wiki/OCSP_stapling for more information.""" + ) + elif cached_response.certificate_status == ocsp.OCSPCertStatus.UNKNOWN: + r.ocsp_verified = False + if strict is True: + raise SSLError( + f"""Unable to establish a secure connection to {r.url} because the issuer does not know whether + certificate is valid or not. This error occurred because you enabled strict mode for + the OCSP / Revocation check.""" + ) + else: + r.ocsp_verified = True + + return + + # When using Python native capabilities, you won't have the issuerCA DER by default. + # Unfortunately! But no worries, we can circumvent it! + # Three ways are valid to fetch it (in order of preference, safest to riskiest): + # - The issuer can be (but unlikely) a root CA. + # - Retrieve it by asking it from the TLS layer. + # - Downloading it using specified caIssuers from the peer certificate. + if conn_info.issuer_certificate_der is None: + # It could be a root (self-signed) certificate. Or a previously seen issuer. + issuer_certificate = _infer_issuer_from(peer_certificate) + + # If not, try to ask nicely the remote to give us the certificate chain, and extract + # from it the immediate issuer. + if issuer_certificate is None: + try: + if r.url is None: + raise ValueError + + url_parsed = parse_url(r.url) + + if url_parsed.hostname is None or conn_info.destination_address is None: + raise ValueError + + issuer_certificate = _ask_nicely_for_issuer( + url_parsed.hostname, + conn_info.destination_address, + timeout, + ) + + if issuer_certificate is not None: + peer_certificate.verify_directly_issued_by(issuer_certificate) + + except (socket.gaierror, TimeoutError, ConnectionError): + pass + except ValueError: + issuer_certificate = None + + hint_ca_issuers: list[str] = [ep for ep in list(conn_info.certificate_dict.get("caIssuers", [])) if ep.startswith("http://")] # type: ignore + + if issuer_certificate is None and hint_ca_issuers: + try: + raw_intermediary_response = session.get(hint_ca_issuers[0]) + except RequestException: + pass + else: + if ( + raw_intermediary_response.status_code + and 300 > raw_intermediary_response.status_code >= 200 + ): + raw_intermediary_content = raw_intermediary_response.content + + if raw_intermediary_content is not None: + # binary DER + if ( + b"-----BEGIN CERTIFICATE-----" + not in raw_intermediary_content + ): + issuer_certificate = load_der_x509_certificate( + raw_intermediary_content + ) + # b64 PEM + elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content: + issuer_certificate = load_pem_x509_certificate( + raw_intermediary_content + ) + + # Well! We're out of luck. No further should we go. + if issuer_certificate is None: + if strict: + warnings.warn( + f"""Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. + You are seeing this warning due to enabling strict mode for OCSP / Revocation check. + Reason: Remote did not provide any intermediaries certificate.""", + SecurityWarning, + ) + return + + conn_info.issuer_certificate_der = issuer_certificate.public_bytes( + serialization.Encoding.DER + ) + else: + issuer_certificate = load_der_x509_certificate(conn_info.issuer_certificate_der) + + builder = ocsp.OCSPRequestBuilder() + builder = builder.add_certificate(peer_certificate, issuer_certificate, SHA1()) + + req = builder.build() + + try: + ocsp_http_response = session.post( + endpoints[randint(0, len(endpoints) - 1)], + data=req.public_bytes(serialization.Encoding.DER), + headers={"Content-Type": "application/ocsp-request"}, + timeout=timeout, + ) + except RequestException as e: + if strict: + warnings.warn( + f"""Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. + You are seeing this warning due to enabling strict mode for OCSP / Revocation check. + Reason: {e}""", + SecurityWarning, + ) + return + + if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200: + if ocsp_http_response.content is None: + return + + ocsp_resp = ocsp.load_der_ocsp_response(ocsp_http_response.content) + + _SharedStaplingCache.save(peer_certificate, issuer_certificate, ocsp_resp) + + if ocsp_resp.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL: + if ocsp_resp.certificate_status == ocsp.OCSPCertStatus.REVOKED: + r.ocsp_verified = False + raise SSLError( + f"""Unable to establish a secure connection to {r.url} because the certificate has been revoked + by issuer ({ocsp_resp.revocation_reason}). + You should avoid trying to request anything from it as the remote has been compromised. + See https://en.wikipedia.org/wiki/OCSP_stapling for more information.""" + ) + if ocsp_resp.certificate_status == ocsp.OCSPCertStatus.UNKNOWN: + r.ocsp_verified = False + if strict is True: + raise SSLError( + f"""Unable to establish a secure connection to {r.url} because the issuer does not know whether + certificate is valid or not. This error occurred because you enabled strict mode for + the OCSP / Revocation check.""" + ) + else: + r.ocsp_verified = True + else: + if strict: + warnings.warn( + f"""Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. + You are seeing this warning due to enabling strict mode for OCSP / Revocation check. + OCSP Server Status: {ocsp_resp.response_status}""", + SecurityWarning, + ) + + +__all__ = ("verify",) diff --git a/src/niquests/extensions/_picotls.py b/src/niquests/extensions/_picotls.py new file mode 100644 index 0000000000..1d0da9f503 --- /dev/null +++ b/src/niquests/extensions/_picotls.py @@ -0,0 +1,738 @@ +""" +This module purpose is to have a "super" minimalist way to +speak with a TLS 1.2+ server. The goal of this is to extract +the certificate chain to be used for OCSP stapling / revocation. +It's not meant to establish a secure connection. Never! +""" +from __future__ import annotations + +import hmac +from hashlib import sha256 + +import idna + +LEGACY_TLS_VERSION = b"\x03\x03" +TLS_AES_128_GCM_SHA256 = b"\x13\x01" + +CHANGE_CIPHER = b"\x14" +ALERT = b"\x15" +HANDSHAKE = b"\x16" +APPLICATION_DATA = b"\x17" + +# SYMMETRIC CIPHERS +AES_ROUNDS = 10 + +# AES_SBOX is some permutation of numbers 0-255 +AES_SBOX = [ + 99, + 124, + 119, + 123, + 242, + 107, + 111, + 197, + 48, + 1, + 103, + 43, + 254, + 215, + 171, + 118, + 202, + 130, + 201, + 125, + 250, + 89, + 71, + 240, + 173, + 212, + 162, + 175, + 156, + 164, + 114, + 192, + 183, + 253, + 147, + 38, + 54, + 63, + 247, + 204, + 52, + 165, + 229, + 241, + 113, + 216, + 49, + 21, + 4, + 199, + 35, + 195, + 24, + 150, + 5, + 154, + 7, + 18, + 128, + 226, + 235, + 39, + 178, + 117, + 9, + 131, + 44, + 26, + 27, + 110, + 90, + 160, + 82, + 59, + 214, + 179, + 41, + 227, + 47, + 132, + 83, + 209, + 0, + 237, + 32, + 252, + 177, + 91, + 106, + 203, + 190, + 57, + 74, + 76, + 88, + 207, + 208, + 239, + 170, + 251, + 67, + 77, + 51, + 133, + 69, + 249, + 2, + 127, + 80, + 60, + 159, + 168, + 81, + 163, + 64, + 143, + 146, + 157, + 56, + 245, + 188, + 182, + 218, + 33, + 16, + 255, + 243, + 210, + 205, + 12, + 19, + 236, + 95, + 151, + 68, + 23, + 196, + 167, + 126, + 61, + 100, + 93, + 25, + 115, + 96, + 129, + 79, + 220, + 34, + 42, + 144, + 136, + 70, + 238, + 184, + 20, + 222, + 94, + 11, + 219, + 224, + 50, + 58, + 10, + 73, + 6, + 36, + 92, + 194, + 211, + 172, + 98, + 145, + 149, + 228, + 121, + 231, + 200, + 55, + 109, + 141, + 213, + 78, + 169, + 108, + 86, + 244, + 234, + 101, + 122, + 174, + 8, + 186, + 120, + 37, + 46, + 28, + 166, + 180, + 198, + 232, + 221, + 116, + 31, + 75, + 189, + 139, + 138, + 112, + 62, + 181, + 102, + 72, + 3, + 246, + 14, + 97, + 53, + 87, + 185, + 134, + 193, + 29, + 158, + 225, + 248, + 152, + 17, + 105, + 217, + 142, + 148, + 155, + 30, + 135, + 233, + 206, + 85, + 40, + 223, + 140, + 161, + 137, + 13, + 191, + 230, + 66, + 104, + 65, + 153, + 45, + 15, + 176, + 84, + 187, + 22, +] + + +def bytes_to_num(b): + return int.from_bytes(b, "big") + + +def num_to_bytes(num, bytes_len): + return int.to_bytes(num, bytes_len, "big") + + +def xor(a, b): + return bytes(i ^ j for i, j in zip(a, b)) + + +def egcd(a, b): + if a == 0: + return 0, 1 + y, x = egcd(b % a, a) + return x - (b // a) * y, y + + +def mod_inv(a, p): + return egcd(a, p)[0] if a >= 0 else p - egcd(-a, p)[0] + + +def add_two_ec_points(p1_x, p1_y, p2_x, p2_y, a, p): + if p1_x == p2_x and p1_y == p2_y: + s = (3 * p1_x * p1_x + a) * mod_inv(2 * p2_y, p) + elif p1_x != p2_x: + s = (p1_y - p2_y) * mod_inv(p1_x - p2_x, p) + else: + raise NotImplementedError + + x = s * s - p1_x - p2_x + y = -p1_y + s * (p1_x - x) + return x % p, y % p + + +def multiply_num_on_ec_point(num, g_x, g_y, a, p): + x, y = None, None + while num: + if num & 1: + x, y = add_two_ec_points(x, y, g_x, g_y, a, p) if x else (g_x, g_y) + g_x, g_y = add_two_ec_points(g_x, g_y, g_x, g_y, a, p) + num >>= 1 + return x, y + + +def gen_client_hello(hostname, client_random, ecdh_pubkey_x, ecdh_pubkey_y): + CLIENT_HELLO = b"\x01" + + session_id = b"" + compression_method = b"\x00" # no compression + + if not hostname.isascii(): + hostname = idna.encode(hostname) + else: + hostname = hostname.encode("ascii") + + hostname_prefix = b"\x00\x00" + hostname_list_length = num_to_bytes(len(hostname) + 5, 2) + hostname_item_length = num_to_bytes(len(hostname) + 3, 2) + hostname_length = num_to_bytes(len(hostname), 2) + + hostname_extension = ( + hostname_prefix + + hostname_list_length + + hostname_item_length + + b"\x00" + + hostname_length + + hostname + ) + + supported_versions = b"\x00\x2b" + supported_versions_length = b"\x00\x03" + another_supported_versions_length = b"\x02" + tls1_3_version = b"\x03\x04" + supported_version_extension = ( + supported_versions + + supported_versions_length + + another_supported_versions_length + + tls1_3_version + ) + + signature_algos = b"\x00\x0d" + signature_algos_length = b"\x00\x06" + another_signature_algos_length = b"\x00\x04" + rsa_pss_rsae_sha256_algo = b"\x08\x04" + # ECDSA/SECP256r1/SHA256 (e.g. for EV certs sig, mandatory) + ecdsa_secp256r1_sha256_algo = b"\x04\x03" + signature_algos_extension = ( + signature_algos + + signature_algos_length + + another_signature_algos_length + + rsa_pss_rsae_sha256_algo + + ecdsa_secp256r1_sha256_algo + ) + + supported_groups = b"\x00\x0a" + supported_groups_length = b"\x00\x04" + another_supported_groups_length = b"\x00\x02" + secp256r1_group = b"\x00\x17" + supported_groups_extension = ( + supported_groups + + supported_groups_length + + another_supported_groups_length + + secp256r1_group + ) + + ecdh_pubkey = ( + b"\x04" + num_to_bytes(ecdh_pubkey_x, 32) + num_to_bytes(ecdh_pubkey_y, 32) + ) + + key_share = b"\x00\x33" + key_share_length = num_to_bytes(len(ecdh_pubkey) + 4 + 2, 2) + another_key_share_length = num_to_bytes(len(ecdh_pubkey) + 4, 2) + key_exchange_len = num_to_bytes(len(ecdh_pubkey), 2) + key_share_extension = ( + key_share + + key_share_length + + another_key_share_length + + secp256r1_group + + key_exchange_len + + ecdh_pubkey + ) + + extensions = ( + hostname_extension + + supported_version_extension + + signature_algos_extension + + supported_groups_extension + + key_share_extension + ) + + client_hello_data = ( + LEGACY_TLS_VERSION + + client_random + + num_to_bytes(len(session_id), 1) + + session_id + + num_to_bytes(len(TLS_AES_128_GCM_SHA256), 2) + + TLS_AES_128_GCM_SHA256 + + num_to_bytes(len(compression_method), 1) + + compression_method + + num_to_bytes(len(extensions), 2) + ) + extensions + + client_hello_len_bytes = num_to_bytes(len(client_hello_data), 3) + client_hello_tlv = CLIENT_HELLO + client_hello_len_bytes + client_hello_data + + return client_hello_tlv + + +def handle_server_hello(server_hello): + handshake_type = server_hello[0] + + SERVER_HELLO = 0x2 + assert handshake_type == SERVER_HELLO + + # server_hello_len = server_hello[1:4] + # server_version = server_hello[4:6] + + server_random = server_hello[6:38] + + session_id_len = bytes_to_num(server_hello[38:39]) + session_id = server_hello[39 : 39 + session_id_len] + + cipher_suite = server_hello[39 + session_id_len : 39 + session_id_len + 2] + assert cipher_suite == TLS_AES_128_GCM_SHA256 + + # compression_method = server_hello[39 + session_id_len + 2 : 39 + session_id_len + 3] + + extensions_length = bytes_to_num( + server_hello[39 + session_id_len + 3 : 39 + session_id_len + 3 + 2] + ) + extensions = server_hello[ + 39 + session_id_len + 3 + 2 : 39 + session_id_len + 3 + 2 + extensions_length + ] + + public_ec_key = b"" + ptr = 0 + while ptr < extensions_length: + extension_type = extensions[ptr : ptr + 2] + extension_length = bytes_to_num(extensions[ptr + 2 : ptr + 4]) + KEY_SHARE = b"\x00\x33" + if extension_type != KEY_SHARE: + ptr += extension_length + 4 + continue + group = extensions[ptr + 4 : ptr + 6] + SECP256R1_GROUP = b"\x00\x17" + assert group == SECP256R1_GROUP + key_exchange_len = bytes_to_num(extensions[ptr + 6 : ptr + 8]) + + public_ec_key = extensions[ptr + 8 : ptr + 8 + key_exchange_len] + break + + if not public_ec_key: + raise ValueError("No public ECDH key in server hello") + + public_ec_key_x = bytes_to_num(public_ec_key[1:33]) + public_ec_key_y = bytes_to_num(public_ec_key[33:]) + + return server_random, session_id, public_ec_key_x, public_ec_key_y + + +def mutliply_blocks(x, y): + z = 0 + for i in range(128): + if x & (1 << (127 - i)): + z ^= y + y = (y >> 1) ^ (0xE1 << 120) if y & 1 else y >> 1 + return z + + +def ghash(h, data): + CHUNK_LEN = 16 + + y = 0 + for pos in range(0, len(data), CHUNK_LEN): + chunk = bytes_to_num(data[pos : pos + CHUNK_LEN]) + y = mutliply_blocks(y ^ chunk, h) + return y + + +def derive_secret(label, key, data, hash_len): + full_label = b"tls13 " + label + packed_data = ( + num_to_bytes(hash_len, 2) + + num_to_bytes(len(full_label), 1) + + full_label + + num_to_bytes(len(data), 1) + + data + ) + + secret = bytearray() + i = 1 + while len(secret) < hash_len: + secret += hmac.new( + key, secret[-32:] + packed_data + num_to_bytes(i, 1), sha256 + ).digest() + i += 1 + return bytes(secret[:hash_len]) + + +def aes128_expand_key(key): + RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36] + + enc_keys = [[0, 0, 0, 0] for i in range(AES_ROUNDS + 1)] + enc_keys[0] = [bytes_to_num(key[i : i + 4]) for i in [0, 4, 8, 12]] + + for t in range(1, AES_ROUNDS + 1): + prev_key = enc_keys[t - 1] + enc_keys[t][0] = ( + (AES_SBOX[(prev_key[3] >> 8 * 2) & 0xFF] << 8 * 3) + ^ (AES_SBOX[(prev_key[3] >> 8 * 1) & 0xFF] << 8 * 2) + ^ (AES_SBOX[(prev_key[3] >> 8 * 0) & 0xFF] << 8 * 1) + ^ (AES_SBOX[(prev_key[3] >> 8 * 3) & 0xFF] << 8 * 0) + ^ (RCON[t - 1] << 8 * 3) + ^ prev_key[0] + ) + + for i in range(1, 4): + enc_keys[t][i] = enc_keys[t][i - 1] ^ prev_key[i] + return enc_keys + + +def aes128_encrypt(key, plaintext): + TWOTIMES = [2 * num if 2 * num < 256 else 2 * num & 0xFF ^ 27 for num in range(256)] + + enc_keys = aes128_expand_key(key) + + t = [bytes_to_num(plaintext[4 * i : 4 * i + 4]) ^ enc_keys[0][i] for i in range(4)] + for r in range(1, AES_ROUNDS): + t = [ + [ + AES_SBOX[(t[(i + 0) % 4] >> 8 * 3) & 0xFF], + AES_SBOX[(t[(i + 1) % 4] >> 8 * 2) & 0xFF], + AES_SBOX[(t[(i + 2) % 4] >> 8 * 1) & 0xFF], + AES_SBOX[(t[(i + 3) % 4] >> 8 * 0) & 0xFF], + ] + for i in range(4) + ] + + t = [ + [ + c[1] ^ c[2] ^ c[3] ^ TWOTIMES[c[0] ^ c[1]], + c[0] ^ c[2] ^ c[3] ^ TWOTIMES[c[1] ^ c[2]], + c[0] ^ c[1] ^ c[3] ^ TWOTIMES[c[2] ^ c[3]], + c[0] ^ c[1] ^ c[2] ^ TWOTIMES[c[3] ^ c[0]], + ] + for c in t + ] + + t = [bytes_to_num(t[i]) ^ enc_keys[r][i] for i in range(4)] + + result = [ + bytes( + [ + AES_SBOX[(t[(i + 0) % 4] >> 8 * 3) & 0xFF] + ^ (enc_keys[-1][i] >> 8 * 3) & 0xFF, + AES_SBOX[(t[(i + 1) % 4] >> 8 * 2) & 0xFF] + ^ (enc_keys[-1][i] >> 8 * 2) & 0xFF, + AES_SBOX[(t[(i + 2) % 4] >> 8 * 1) & 0xFF] + ^ (enc_keys[-1][i] >> 8 * 1) & 0xFF, + AES_SBOX[(t[(i + 3) % 4] >> 8 * 0) & 0xFF] + ^ (enc_keys[-1][i] >> 8 * 0) & 0xFF, + ] + ) + for i in range(4) + ] + return b"".join(result) + + +def aes128_ctr_encrypt(key, msg, nonce, counter_start_val): + BLOCK_SIZE = 16 + + ans = [] + counter = counter_start_val + for s in range(0, len(msg), BLOCK_SIZE): + chunk = msg[s : s + BLOCK_SIZE] + + chunk_nonce = nonce + num_to_bytes(counter, 4) + encrypted_chunk_nonce = aes128_encrypt(key, chunk_nonce) + + decrypted_chunk = xor(chunk, encrypted_chunk_nonce) + ans.append(decrypted_chunk) + + counter += 1 + return b"".join(ans) + + +def aes128_ctr_decrypt(key, msg, nonce, counter_start_val): + return aes128_ctr_encrypt(key, msg, nonce, counter_start_val) + + +def calc_pretag(key, encrypted_msg, associated_data): + v = b"\x00" * (16 * ((len(associated_data) + 15) // 16) - len(associated_data)) + u = b"\x00" * (16 * ((len(encrypted_msg) + 15) // 16) - len(encrypted_msg)) + + h = bytes_to_num(aes128_encrypt(key, b"\x00" * 16)) + data = ( + associated_data + + v + + encrypted_msg + + u + + num_to_bytes(len(associated_data) * 8, 8) + + num_to_bytes(len(encrypted_msg) * 8, 8) + ) + return num_to_bytes(ghash(h, data), 16) + + +def aes128_gcm_decrypt(key, msg, nonce, associated_data): + TAG_LEN = 16 + + encrypted_msg, tag = msg[:-TAG_LEN], msg[-TAG_LEN:] + + pretag = calc_pretag(key, encrypted_msg, associated_data) + check_tag = aes128_ctr_encrypt(key, pretag, nonce, counter_start_val=1) + if check_tag != tag: + raise ValueError("Decrypt error, bad tag") + return aes128_ctr_decrypt(key, encrypted_msg, nonce, counter_start_val=2) + + +def do_authenticated_decryption(key, nonce_start, seq_num, msg_type, payload): + nonce = xor(nonce_start, num_to_bytes(seq_num, 12)) + + data = msg_type + LEGACY_TLS_VERSION + num_to_bytes(len(payload), 2) + msg = aes128_gcm_decrypt(key, payload, nonce, associated_data=data) + + msg_type, msg_data = msg[-1:], msg[:-1] + return msg_type, msg_data + + +def handle_server_cert(server_cert_data): + handshake_type = server_cert_data[0] + + CERTIFICATE = 0x0B + assert handshake_type == CERTIFICATE + + # certificate_payload_len = bytes_to_num(server_cert_data[1:4]) + certificate_list_len = bytes_to_num(server_cert_data[5:8]) + + certificates = [] + + cert_string_left = server_cert_data[8 : 8 + certificate_list_len] + + while cert_string_left: + cert_len = bytes_to_num(cert_string_left[0:3]) + certificates.append(cert_string_left[3 : 3 + cert_len]) + + cert_string_left = cert_string_left[3 + cert_len + 2 :] + + return certificates + + +def handle_encrypted_extensions(msg): + ENCRYPTED_EXTENSIONS = 0x8 + + assert msg[0] == ENCRYPTED_EXTENSIONS + extensions_length = bytes_to_num(msg[1:4]) + assert len(msg[4:]) >= extensions_length + return msg[4 + extensions_length :] + # ignore the rest + + +def recv_tls_and_decrypt(s, key, nonce, seq_num): + rec_type, encrypted_msg = recv_tls(s) + assert rec_type == APPLICATION_DATA + + msg_type, msg = do_authenticated_decryption( + key, nonce, seq_num, APPLICATION_DATA, encrypted_msg + ) + return msg_type, msg + + +def recv_num_bytes(s, num): + ret = bytearray() + while len(ret) < num: + data = s.recv(min(4096, num - len(ret))) + if not data: + raise BrokenPipeError + ret += data + return bytes(ret) + + +def recv_tls(s): + rec_type = recv_num_bytes(s, 1) + tls_version = recv_num_bytes(s, 2) + + assert tls_version == LEGACY_TLS_VERSION + + rec_len = bytes_to_num(recv_num_bytes(s, 2)) + rec = recv_num_bytes(s, rec_len) + return rec_type, rec + + +def send_tls(s, rec_type, msg): + tls_record = rec_type + LEGACY_TLS_VERSION + num_to_bytes(len(msg), 2) + msg + s.sendall(tls_record) + + +__all__ = ( + "multiply_num_on_ec_point", + "gen_client_hello", + "send_tls", + "recv_tls", + "handle_server_hello", + "num_to_bytes", + "derive_secret", + "recv_tls_and_decrypt", + "handle_encrypted_extensions", + "handle_server_cert", + "HANDSHAKE", + "ALERT", + "CHANGE_CIPHER", +) diff --git a/src/niquests/help.py b/src/niquests/help.py index 80a2320ad6..226ecabe59 100644 --- a/src/niquests/help.py +++ b/src/niquests/help.py @@ -30,6 +30,11 @@ except ImportError: cryptography = None # type: ignore +try: + from .extensions._ocsp import verify as ocsp_verify +except ImportError: + ocsp_verify = None # type: ignore + def _implementation(): """Return a dict with the Python implementation and version. @@ -120,6 +125,7 @@ def info(): "certifi_fallback": wassima.RUSTLS_LOADED is False and certifi is not None, "version": wassima.__version__, }, + "ocsp": {"enabled": ocsp_verify is not None}, } diff --git a/src/niquests/models.py b/src/niquests/models.py index 014bd39952..514e1cbbd9 100644 --- a/src/niquests/models.py +++ b/src/niquests/models.py @@ -254,6 +254,8 @@ def __init__(self) -> None: self._body_position: int | object | None = None #: valuable intel about the opened connection. self.conn_info: ConnectionInfo | None = None + #: marker about if OCSP post-handshake verification took place. + self.ocsp_verified: bool | None = None def prepare( self, @@ -933,10 +935,18 @@ def next(self) -> PreparedRequest | None: @property def conn_info(self) -> ConnectionInfo | None: + """Provide context to the established connection that was used to perform the request.""" if self.request and hasattr(self.request, "conn_info"): return self.request.conn_info return None + @property + def ocsp_verified(self) -> bool | None: + """Marker that can inform you of the OCSP verification.""" + if self.request and hasattr(self.request, "ocsp_verified"): + return self.request.ocsp_verified + return None + def iter_content( self, chunk_size: int = 1, decode_unicode: bool = False ) -> typing.Generator[bytes | str, None, None]: diff --git a/src/niquests/sessions.py b/src/niquests/sessions.py index a37ef435fc..0a154718d6 100644 --- a/src/niquests/sessions.py +++ b/src/niquests/sessions.py @@ -56,6 +56,11 @@ ) from .hooks import HOOKS, default_hooks, dispatch_hook +try: + from .extensions._ocsp import verify as ocsp_verify +except ImportError: + ocsp_verify = None # type: ignore[assignment] + # formerly defined here, reexposed here for backward compatibility from .models import ( # noqa: F401 DEFAULT_REDIRECT_LIMIT, @@ -65,7 +70,7 @@ Response, ) from .status_codes import codes -from .structures import CaseInsensitiveDict +from .structures import CaseInsensitiveDict, SharableLimitedDict from .utils import ( # noqa: F401 DEFAULT_PORTS, default_headers, @@ -188,11 +193,10 @@ class Session: "params", "verify", "cert", - "adapters", "stream", "trust_env", "max_redirects", - "quic_cache_layer", + "retries", ] def __init__( @@ -201,6 +205,9 @@ def __init__( quic_cache_layer: CacheLayerAltSvcType | None = None, retries: RetryType = DEFAULT_RETRIES, ): + #: Configured retries for current Session + self.retries = retries + #: A case-insensitive dictionary of headers to be sent on each #: :class:`Request ` sent from this #: :class:`Session `. @@ -259,7 +266,11 @@ def __init__( #: A simple dict that allows us to persist which server support QUIC #: It is simply forwarded to urllib3.future that handle the caching logic. #: Can be any mutable mapping. - self.quic_cache_layer = quic_cache_layer + self.quic_cache_layer = ( + quic_cache_layer + if quic_cache_layer is not None + else SharableLimitedDict(max_size=12_288) + ) # Default connection adapters. self.adapters: OrderedDict[str, BaseAdapter] = OrderedDict() @@ -935,10 +946,39 @@ def send(self, request: PreparedRequest, **kwargs: typing.Any) -> Response: stream = kwargs.get("stream") hooks = request.hooks + ptr_request = request + def on_post_connection(conn_info: ConnectionInfo) -> None: - nonlocal request - request.conn_info = conn_info - dispatch_hook("pre_send", hooks, request) # type: ignore[arg-type] + """This function will be called by urllib3.future just after establishing the connection.""" + nonlocal ptr_request, request, kwargs + ptr_request.conn_info = conn_info + + if ( + request.url + and request.url.startswith("https://") + and ocsp_verify is not None + and kwargs["verify"] + ): + strict_ocsp_enabled: bool = ( + os.environ.get("NIQUESTS_STRICT_OCSP", "0") != "0" + ) + + with Session() as ocsp_session: + ocsp_session.trust_env = False + + if not strict_ocsp_enabled: + ocsp_session.proxies = kwargs["proxies"] + + ocsp_verify( + ocsp_session, + ptr_request, + strict_ocsp_enabled, + 0.2 if not strict_ocsp_enabled else 1.0, + ) + + # don't trigger pre_send for redirects + if ptr_request == request: + dispatch_hook("pre_send", hooks, ptr_request) # type: ignore[arg-type] kwargs.setdefault("on_post_connection", on_post_connection) @@ -971,8 +1011,16 @@ def on_post_connection(conn_info: ConnectionInfo) -> None: # Resolve redirects if allowed. if allow_redirects: # Redirect resolving generator. - gen = self.resolve_redirects(r, request, **kwargs) - history = [resp for resp in gen if isinstance(resp, Response)] + gen = self.resolve_redirects( + r, request, yield_requests_trail=True, **kwargs + ) + history = [] + + for resp_or_req in gen: + if isinstance(resp_or_req, Response): + history.append(resp_or_req) + continue + ptr_request = resp_or_req else: history = [] @@ -1068,6 +1116,17 @@ def __setstate__(self, state): for attr, value in state.items(): setattr(self, attr, value) + self.quic_cache_layer = SharableLimitedDict(max_size=12_288) + + self.adapters = OrderedDict() + self.mount( + "https://", + HTTPAdapter( + quic_cache_layer=self.quic_cache_layer, max_retries=self.retries + ), + ) + self.mount("http://", HTTPAdapter(max_retries=self.retries)) + def get_redirect_target(self, resp: Response) -> str | None: """Receives a Response. Returns a redirect URI or ``None``""" # Due to the nature of how requests processes redirects this method will @@ -1147,6 +1206,7 @@ def resolve_redirects( cert: TLSClientCertType | None = None, proxies: ProxyType | None = None, yield_requests: bool = False, + yield_requests_trail: bool = False, **adapter_kwargs: typing.Any, ) -> typing.Generator[Response | PreparedRequest, None, None]: """Receives a Response. Returns a generator of Responses or Requests.""" @@ -1255,6 +1315,8 @@ def resolve_redirects( if yield_requests: yield req else: + if yield_requests_trail: + yield req resp = self.send( req, stream=stream, diff --git a/src/niquests/structures.py b/src/niquests/structures.py index 0547d72818..896bbe8aec 100644 --- a/src/niquests/structures.py +++ b/src/niquests/structures.py @@ -7,6 +7,7 @@ from __future__ import annotations +import threading import typing from collections import OrderedDict from collections.abc import Mapping, MutableMapping @@ -101,3 +102,33 @@ def __getitem__(self, key): def get(self, key, default=None): return self.__dict__.get(key, default) + + +class SharableLimitedDict(typing.MutableMapping): + def __init__(self, max_size: int) -> None: + self._store: typing.MutableMapping[typing.Any, typing.Any] = {} + self._max_size = max_size + self._lock = threading.Lock() + + def __delitem__(self, __key) -> None: + with self._lock: + del self._store[__key] + + def __len__(self) -> int: + with self._lock: + return len(self._store) + + def __iter__(self) -> typing.Iterator: + with self._lock: + return iter(self._store) + + def __setitem__(self, key, value): + with self._lock: + if len(self._store) >= self._max_size: + self._store.popitem() + + self._store[key] = value + + def __getitem__(self, item): + with self._lock: + return self._store[item]