From 0d7f8a2d16aa95749d2e0f6eac0c343ebc6619e3 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Sun, 22 Jan 2023 14:52:52 -0600 Subject: [PATCH] Change the inner SSLContext configuration to undo post-handshake --- src/truststore/_api.py | 36 +++++++++++++++++-------------- src/truststore/_macos.py | 13 +++++++++--- src/truststore/_openssl.py | 8 ++++--- src/truststore/_windows.py | 43 +++++++++++++++++++++++++++++++++++++- tests/test_sslcontext.py | 22 +++++++++++++++++-- 5 files changed, 97 insertions(+), 25 deletions(-) diff --git a/src/truststore/_api.py b/src/truststore/_api.py index 4f2030d..18f7e33 100644 --- a/src/truststore/_api.py +++ b/src/truststore/_api.py @@ -23,7 +23,6 @@ class SSLContext(ssl.SSLContext): def __init__(self, protocol: int = ssl.PROTOCOL_TLS) -> None: self._ctx = ssl.SSLContext(protocol) - _configure_context(self._ctx) class TruststoreSSLObject(ssl.SSLObject): # This object exists because wrap_bio() doesn't @@ -46,14 +45,18 @@ def wrap_socket( server_hostname: str | None = None, session: ssl.SSLSession | None = None, ) -> ssl.SSLSocket: - ssl_sock = self._ctx.wrap_socket( - sock, - server_side=server_side, - server_hostname=server_hostname, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - session=session, - ) + # Use a context manager here because the + # inner SSLContext holds on to our state + # but also does the actual handshake. + with _configure_context(self._ctx): + ssl_sock = self._ctx.wrap_socket( + sock, + server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + session=session, + ) try: _verify_peercerts(ssl_sock, server_hostname=server_hostname) except ssl.SSLError: @@ -69,13 +72,14 @@ def wrap_bio( server_hostname: str | None = None, session: ssl.SSLSession | None = None, ) -> ssl.SSLObject: - ssl_obj = self._ctx.wrap_bio( - incoming, - outgoing, - server_hostname=server_hostname, - server_side=server_side, - session=session, - ) + with _configure_context(self._ctx): + ssl_obj = self._ctx.wrap_bio( + incoming, + outgoing, + server_hostname=server_hostname, + server_side=server_side, + session=session, + ) return ssl_obj def load_verify_locations( diff --git a/src/truststore/_macos.py b/src/truststore/_macos.py index 5554dea..e23cfd2 100644 --- a/src/truststore/_macos.py +++ b/src/truststore/_macos.py @@ -1,6 +1,8 @@ +import contextlib import ctypes import platform import ssl +import typing from ctypes import ( CDLL, POINTER, @@ -13,7 +15,6 @@ c_void_p, ) from ctypes.util import find_library -from typing import Any _mac_version = platform.mac_ver()[0] _mac_version_info = tuple(map(int, _mac_version.split("."))) @@ -201,7 +202,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: raise ImportError("Error initializing ctypes") from None -def _handle_osstatus(result: OSStatus, _: Any, args: Any) -> Any: +def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any: """ Raises an error if the OSStatus value is non-zero. """ @@ -338,9 +339,15 @@ def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # typ return cf_array # type: ignore[no-any-return] -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: + values = ctx.check_hostname, ctx.verify_mode ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + try: + yield + finally: + ctx.check_hostname, ctx.verify_mode = values def _verify_peercerts_impl( diff --git a/src/truststore/_openssl.py b/src/truststore/_openssl.py index 86f37ee..9951cf7 100644 --- a/src/truststore/_openssl.py +++ b/src/truststore/_openssl.py @@ -1,6 +1,8 @@ +import contextlib import os import re import ssl +import typing # candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes _CA_FILE_CANDIDATES = [ @@ -17,7 +19,8 @@ _HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$") -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: # First, check whether the default locations from OpenSSL # seem like they will give us a usable set of CA certs. # ssl.get_default_verify_paths already takes care of: @@ -40,8 +43,7 @@ def _configure_context(ctx: ssl.SSLContext) -> None: ctx.load_verify_locations(cafile=cafile) break - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.check_hostname = True + yield def _capath_contains_certs(capath: str) -> bool: diff --git a/src/truststore/_windows.py b/src/truststore/_windows.py index c92fdb2..9596570 100644 --- a/src/truststore/_windows.py +++ b/src/truststore/_windows.py @@ -1,4 +1,6 @@ +import contextlib import ssl +import typing from ctypes import WinDLL # type: ignore from ctypes import WinError # type: ignore from ctypes import ( @@ -199,11 +201,33 @@ class CERT_CHAIN_ENGINE_CONFIG(Structure): OID_PKIX_KP_SERVER_AUTH = c_char_p(b"1.3.6.1.5.5.7.3.1") CERT_CHAIN_REVOCATION_CHECK_END_CERT = 0x10000000 CERT_CHAIN_REVOCATION_CHECK_CHAIN = 0x20000000 +CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS = 0x00000007 +CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG = 0x00000008 +CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG = 0x00000010 +CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG = 0x00000040 +CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG = 0x00000020 +CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG = 0x00000080 +CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS = 0x00000F00 +CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG = 0x00008000 +CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG = 0x00004000 AUTHTYPE_SERVER = 2 CERT_CHAIN_POLICY_SSL = 4 FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000 FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200 +# Flags to set for SSLContext.verify_mode=CERT_NONE +CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS = ( + CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS + | CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG + | CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG + | CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG + | CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG + | CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG + | CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS + | CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG + | CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG +) + wincrypt = WinDLL("crypt32.dll") kernel32 = WinDLL("kernel32.dll") @@ -341,6 +365,7 @@ def _verify_peercerts_impl( # First attempt to verify using the default Windows system trust roots # (default chain engine). _get_and_verify_cert_chain( + ssl_context, None, hIntermediateCertStore, pCertContext, @@ -358,6 +383,7 @@ def _verify_peercerts_impl( ) if custom_ca_certs: _verify_using_custom_ca_certs( + ssl_context, custom_ca_certs, hIntermediateCertStore, pCertContext, @@ -374,6 +400,7 @@ def _verify_peercerts_impl( def _get_and_verify_cert_chain( + ssl_context: ssl.SSLContext, hChainEngine: HCERTCHAINENGINE | None, hIntermediateCertStore: HCERTSTORE, pPeerCertContext: c_void_p, @@ -406,11 +433,17 @@ def _get_and_verify_cert_chain( ssl_extra_cert_chain_policy_para.fdwChecks = 0 if server_hostname: ssl_extra_cert_chain_policy_para.pwszServerName = c_wchar_p(server_hostname) + chain_policy = CERT_CHAIN_POLICY_PARA() chain_policy.pvExtraPolicyPara = cast( pointer(ssl_extra_cert_chain_policy_para), c_void_p ) + if ssl_context.verify_mode == ssl.CERT_NONE: + chain_policy.dwFlags |= CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS + if not ssl_context.check_hostname: + chain_policy.dwFlags |= CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG chain_policy.cbSize = sizeof(chain_policy) + pPolicyPara = pointer(chain_policy) policy_status = CERT_CHAIN_POLICY_STATUS() policy_status.cbSize = sizeof(policy_status) @@ -456,6 +489,7 @@ def _get_and_verify_cert_chain( def _verify_using_custom_ca_certs( + ssl_context: ssl.SSLContext, custom_ca_certs: list[bytes], hIntermediateCertStore: HCERTSTORE, pPeerCertContext: c_void_p, @@ -492,6 +526,7 @@ def _verify_using_custom_ca_certs( # Get and verify a cert chain using the custom chain engine _get_and_verify_cert_chain( + ssl_context, hChainEngine, hIntermediateCertStore, pPeerCertContext, @@ -505,6 +540,12 @@ def _verify_using_custom_ca_certs( CertCloseStore(hRootCertStore, 0) -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: + values = ctx.check_hostname, ctx.verify_mode ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + try: + yield + finally: + ctx.check_hostname, ctx.verify_mode = values diff --git a/tests/test_sslcontext.py b/tests/test_sslcontext.py index 5241abc..af84294 100644 --- a/tests/test_sslcontext.py +++ b/tests/test_sslcontext.py @@ -3,7 +3,7 @@ import pytest import urllib3 -from urllib3.exceptions import InsecureRequestWarning +from urllib3.exceptions import InsecureRequestWarning, SSLError import truststore @@ -11,6 +11,7 @@ def test_minimum_maximum_version(): ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + with urllib3.PoolManager(ssl_context=ctx) as http: resp = http.request("GET", "https://howsmyssl.com/a/check") @@ -24,11 +25,28 @@ def test_minimum_maximum_version(): assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2 -def test_disable_verification(): +def test_check_hostname_false(): ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + + with urllib3.PoolManager(ssl_context=ctx, retries=False) as http: + with pytest.raises(SSLError) as e: + http.request("GET", "https://wrong.host.badssl.com/") + assert "match" in str(e.value) + + +def test_verify_mode_cert_none(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + assert ctx.check_hostname is False + assert ctx.verify_mode == ssl.CERT_NONE + with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns( InsecureRequestWarning ) as w: