From 1df489515dbf0235cbd94946405ec76feaf65559 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Sun, 22 Jan 2023 13:59:01 -0600 Subject: [PATCH 1/3] Add SSLContext methods and properties --- src/truststore/_api.py | 141 ++++++++++++++++++++++++++++++++++++--- tests/test_sslcontext.py | 36 ++++++++++ 2 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 tests/test_sslcontext.py diff --git a/src/truststore/_api.py b/src/truststore/_api.py index 463bdbf..4f2030d 100644 --- a/src/truststore/_api.py +++ b/src/truststore/_api.py @@ -2,7 +2,7 @@ import platform import socket import ssl -from typing import Any +import typing from _ssl import ENCODING_DER # type: ignore[import] @@ -14,6 +14,10 @@ from ._openssl import _configure_context, _verify_peercerts_impl +_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes] +_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes] + + class SSLContext(ssl.SSLContext): """SSLContext API that uses system certificates on all platforms""" @@ -84,14 +88,135 @@ def load_verify_locations( cafile=cafile, capath=capath, cadata=cadata ) - def __getattr__(self, name: str) -> Any: - return getattr(self._ctx, name) + def load_cert_chain( + self, + certfile: _StrOrBytesPath, + keyfile: _StrOrBytesPath | None = None, + password: _PasswordType | None = None, + ) -> None: + return self._ctx.load_cert_chain( + certfile=certfile, keyfile=keyfile, password=password + ) + + def load_default_certs( + self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH + ) -> None: + return self._ctx.load_default_certs(purpose) + + def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None: + return self._ctx.set_alpn_protocols(alpn_protocols) + + def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None: + return self._ctx.set_npn_protocols(npn_protocols) + + def set_ciphers(self, __cipherlist: str) -> None: + return self._ctx.set_ciphers(__cipherlist) + + def get_ciphers(self) -> typing.Any: + return self._ctx.get_ciphers() + + def session_stats(self) -> dict[str, int]: + return self._ctx.session_stats() + + def cert_store_stats(self) -> dict[str, int]: + raise NotImplementedError() + + @typing.overload + def get_ca_certs( + self, binary_form: typing.Literal[False] = ... + ) -> list[typing.Any]: + ... + + @typing.overload + def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]: + ... + + @typing.overload + def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: + ... + + def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]: + raise NotImplementedError() + + @property + def check_hostname(self) -> bool: + return self._ctx.check_hostname + + @check_hostname.setter + def check_hostname(self, value: bool) -> None: + self._ctx.check_hostname = value + + @property + def hostname_checks_common_name(self) -> bool: + return self._ctx.hostname_checks_common_name + + @hostname_checks_common_name.setter + def hostname_checks_common_name(self, value: bool) -> None: + self._ctx.hostname_checks_common_name = value + + @property + def keylog_filename(self) -> str: + return self._ctx.keylog_filename + + @keylog_filename.setter + def keylog_filename(self, value: str) -> None: + self._ctx.keylog_filename = value + + @property + def maximum_version(self) -> ssl.TLSVersion: + return self._ctx.maximum_version + + @maximum_version.setter + def maximum_version(self, value: ssl.TLSVersion) -> None: + self._ctx.maximum_version = value + + @property + def minimum_version(self) -> ssl.TLSVersion: + return self._ctx.minimum_version + + @minimum_version.setter + def minimum_version(self, value: ssl.TLSVersion) -> None: + self._ctx.minimum_version = value + + @property + def options(self) -> ssl.Options: + return self._ctx.options + + @options.setter + def options(self, value: ssl.Options) -> None: + self._ctx.options = value + + @property + def post_handshake_auth(self) -> bool: + return self._ctx.post_handshake_auth + + @post_handshake_auth.setter + def post_handshake_auth(self, value: bool) -> None: + self._ctx.post_handshake_auth = value + + @property + def protocol(self) -> ssl._SSLMethod: + return self._ctx.protocol + + @property + def security_level(self) -> int: + return self._ctx.security_level # type: ignore[attr-defined,no-any-return] + + @property + def verify_flags(self) -> ssl.VerifyFlags: + return self._ctx.verify_flags + + @verify_flags.setter + def verify_flags(self, value: ssl.VerifyFlags) -> None: + self._ctx.verify_flags = value + + @property + def verify_mode(self) -> ssl.VerifyMode: + return self._ctx.verify_mode - def __setattr__(self, name: str, value: Any) -> None: - if name == "verify_flags": - self._ctx.verify_flags = value - else: - return super().__setattr__(name, value) + @verify_mode.setter + def verify_mode(self, value: ssl.VerifyMode) -> None: + self._ctx.verify_mode = value def _verify_peercerts( diff --git a/tests/test_sslcontext.py b/tests/test_sslcontext.py new file mode 100644 index 0000000..5241abc --- /dev/null +++ b/tests/test_sslcontext.py @@ -0,0 +1,36 @@ +import json +import ssl + +import pytest +import urllib3 +from urllib3.exceptions import InsecureRequestWarning + +import truststore + + +def test_minimum_maximum_version(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + with urllib3.PoolManager(ssl_context=ctx) as http: + + resp = http.request("GET", "https://howsmyssl.com/a/check") + data = json.loads(resp.data) + assert data["tls_version"] == "TLS 1.2" + + assert ctx.minimum_version in ( + ssl.TLSVersion.TLSv1_2, + ssl.TLSVersion.MINIMUM_SUPPORTED, + ) + assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2 + + +def test_disable_verification(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns( + InsecureRequestWarning + ) as w: + http.request("GET", "https://expired.badssl.com/") + assert len(w) == 1 From 0d7f8a2d16aa95749d2e0f6eac0c343ebc6619e3 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Sun, 22 Jan 2023 14:52:52 -0600 Subject: [PATCH 2/3] Change the inner SSLContext configuration to undo post-handshake --- src/truststore/_api.py | 36 +++++++++++++++++-------------- src/truststore/_macos.py | 13 +++++++++--- src/truststore/_openssl.py | 8 ++++--- src/truststore/_windows.py | 43 +++++++++++++++++++++++++++++++++++++- tests/test_sslcontext.py | 22 +++++++++++++++++-- 5 files changed, 97 insertions(+), 25 deletions(-) diff --git a/src/truststore/_api.py b/src/truststore/_api.py index 4f2030d..18f7e33 100644 --- a/src/truststore/_api.py +++ b/src/truststore/_api.py @@ -23,7 +23,6 @@ class SSLContext(ssl.SSLContext): def __init__(self, protocol: int = ssl.PROTOCOL_TLS) -> None: self._ctx = ssl.SSLContext(protocol) - _configure_context(self._ctx) class TruststoreSSLObject(ssl.SSLObject): # This object exists because wrap_bio() doesn't @@ -46,14 +45,18 @@ def wrap_socket( server_hostname: str | None = None, session: ssl.SSLSession | None = None, ) -> ssl.SSLSocket: - ssl_sock = self._ctx.wrap_socket( - sock, - server_side=server_side, - server_hostname=server_hostname, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - session=session, - ) + # Use a context manager here because the + # inner SSLContext holds on to our state + # but also does the actual handshake. + with _configure_context(self._ctx): + ssl_sock = self._ctx.wrap_socket( + sock, + server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + session=session, + ) try: _verify_peercerts(ssl_sock, server_hostname=server_hostname) except ssl.SSLError: @@ -69,13 +72,14 @@ def wrap_bio( server_hostname: str | None = None, session: ssl.SSLSession | None = None, ) -> ssl.SSLObject: - ssl_obj = self._ctx.wrap_bio( - incoming, - outgoing, - server_hostname=server_hostname, - server_side=server_side, - session=session, - ) + with _configure_context(self._ctx): + ssl_obj = self._ctx.wrap_bio( + incoming, + outgoing, + server_hostname=server_hostname, + server_side=server_side, + session=session, + ) return ssl_obj def load_verify_locations( diff --git a/src/truststore/_macos.py b/src/truststore/_macos.py index 5554dea..e23cfd2 100644 --- a/src/truststore/_macos.py +++ b/src/truststore/_macos.py @@ -1,6 +1,8 @@ +import contextlib import ctypes import platform import ssl +import typing from ctypes import ( CDLL, POINTER, @@ -13,7 +15,6 @@ c_void_p, ) from ctypes.util import find_library -from typing import Any _mac_version = platform.mac_ver()[0] _mac_version_info = tuple(map(int, _mac_version.split("."))) @@ -201,7 +202,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: raise ImportError("Error initializing ctypes") from None -def _handle_osstatus(result: OSStatus, _: Any, args: Any) -> Any: +def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any: """ Raises an error if the OSStatus value is non-zero. """ @@ -338,9 +339,15 @@ def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # typ return cf_array # type: ignore[no-any-return] -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: + values = ctx.check_hostname, ctx.verify_mode ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + try: + yield + finally: + ctx.check_hostname, ctx.verify_mode = values def _verify_peercerts_impl( diff --git a/src/truststore/_openssl.py b/src/truststore/_openssl.py index 86f37ee..9951cf7 100644 --- a/src/truststore/_openssl.py +++ b/src/truststore/_openssl.py @@ -1,6 +1,8 @@ +import contextlib import os import re import ssl +import typing # candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes _CA_FILE_CANDIDATES = [ @@ -17,7 +19,8 @@ _HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$") -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: # First, check whether the default locations from OpenSSL # seem like they will give us a usable set of CA certs. # ssl.get_default_verify_paths already takes care of: @@ -40,8 +43,7 @@ def _configure_context(ctx: ssl.SSLContext) -> None: ctx.load_verify_locations(cafile=cafile) break - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.check_hostname = True + yield def _capath_contains_certs(capath: str) -> bool: diff --git a/src/truststore/_windows.py b/src/truststore/_windows.py index c92fdb2..9596570 100644 --- a/src/truststore/_windows.py +++ b/src/truststore/_windows.py @@ -1,4 +1,6 @@ +import contextlib import ssl +import typing from ctypes import WinDLL # type: ignore from ctypes import WinError # type: ignore from ctypes import ( @@ -199,11 +201,33 @@ class CERT_CHAIN_ENGINE_CONFIG(Structure): OID_PKIX_KP_SERVER_AUTH = c_char_p(b"1.3.6.1.5.5.7.3.1") CERT_CHAIN_REVOCATION_CHECK_END_CERT = 0x10000000 CERT_CHAIN_REVOCATION_CHECK_CHAIN = 0x20000000 +CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS = 0x00000007 +CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG = 0x00000008 +CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG = 0x00000010 +CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG = 0x00000040 +CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG = 0x00000020 +CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG = 0x00000080 +CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS = 0x00000F00 +CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG = 0x00008000 +CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG = 0x00004000 AUTHTYPE_SERVER = 2 CERT_CHAIN_POLICY_SSL = 4 FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000 FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200 +# Flags to set for SSLContext.verify_mode=CERT_NONE +CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS = ( + CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS + | CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG + | CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG + | CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG + | CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG + | CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG + | CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS + | CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG + | CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG +) + wincrypt = WinDLL("crypt32.dll") kernel32 = WinDLL("kernel32.dll") @@ -341,6 +365,7 @@ def _verify_peercerts_impl( # First attempt to verify using the default Windows system trust roots # (default chain engine). _get_and_verify_cert_chain( + ssl_context, None, hIntermediateCertStore, pCertContext, @@ -358,6 +383,7 @@ def _verify_peercerts_impl( ) if custom_ca_certs: _verify_using_custom_ca_certs( + ssl_context, custom_ca_certs, hIntermediateCertStore, pCertContext, @@ -374,6 +400,7 @@ def _verify_peercerts_impl( def _get_and_verify_cert_chain( + ssl_context: ssl.SSLContext, hChainEngine: HCERTCHAINENGINE | None, hIntermediateCertStore: HCERTSTORE, pPeerCertContext: c_void_p, @@ -406,11 +433,17 @@ def _get_and_verify_cert_chain( ssl_extra_cert_chain_policy_para.fdwChecks = 0 if server_hostname: ssl_extra_cert_chain_policy_para.pwszServerName = c_wchar_p(server_hostname) + chain_policy = CERT_CHAIN_POLICY_PARA() chain_policy.pvExtraPolicyPara = cast( pointer(ssl_extra_cert_chain_policy_para), c_void_p ) + if ssl_context.verify_mode == ssl.CERT_NONE: + chain_policy.dwFlags |= CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS + if not ssl_context.check_hostname: + chain_policy.dwFlags |= CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG chain_policy.cbSize = sizeof(chain_policy) + pPolicyPara = pointer(chain_policy) policy_status = CERT_CHAIN_POLICY_STATUS() policy_status.cbSize = sizeof(policy_status) @@ -456,6 +489,7 @@ def _get_and_verify_cert_chain( def _verify_using_custom_ca_certs( + ssl_context: ssl.SSLContext, custom_ca_certs: list[bytes], hIntermediateCertStore: HCERTSTORE, pPeerCertContext: c_void_p, @@ -492,6 +526,7 @@ def _verify_using_custom_ca_certs( # Get and verify a cert chain using the custom chain engine _get_and_verify_cert_chain( + ssl_context, hChainEngine, hIntermediateCertStore, pPeerCertContext, @@ -505,6 +540,12 @@ def _verify_using_custom_ca_certs( CertCloseStore(hRootCertStore, 0) -def _configure_context(ctx: ssl.SSLContext) -> None: +@contextlib.contextmanager +def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]: + values = ctx.check_hostname, ctx.verify_mode ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + try: + yield + finally: + ctx.check_hostname, ctx.verify_mode = values diff --git a/tests/test_sslcontext.py b/tests/test_sslcontext.py index 5241abc..af84294 100644 --- a/tests/test_sslcontext.py +++ b/tests/test_sslcontext.py @@ -3,7 +3,7 @@ import pytest import urllib3 -from urllib3.exceptions import InsecureRequestWarning +from urllib3.exceptions import InsecureRequestWarning, SSLError import truststore @@ -11,6 +11,7 @@ def test_minimum_maximum_version(): ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + with urllib3.PoolManager(ssl_context=ctx) as http: resp = http.request("GET", "https://howsmyssl.com/a/check") @@ -24,11 +25,28 @@ def test_minimum_maximum_version(): assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2 -def test_disable_verification(): +def test_check_hostname_false(): ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + + with urllib3.PoolManager(ssl_context=ctx, retries=False) as http: + with pytest.raises(SSLError) as e: + http.request("GET", "https://wrong.host.badssl.com/") + assert "match" in str(e.value) + + +def test_verify_mode_cert_none(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE + assert ctx.check_hostname is False + assert ctx.verify_mode == ssl.CERT_NONE + with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns( InsecureRequestWarning ) as w: From c26b914f81a007bd826a3f5cd5f4f94633d89765 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 24 Jan 2023 21:49:23 -0600 Subject: [PATCH 3/3] Add support for macOS verification disabling too --- src/truststore/_macos.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/truststore/_macos.py b/src/truststore/_macos.py index e23cfd2..b8fd57e 100644 --- a/src/truststore/_macos.py +++ b/src/truststore/_macos.py @@ -265,6 +265,11 @@ class CFConst: kCFStringEncodingUTF8 = CFStringEncoding(0x08000100) + errSecIncompleteCertRevocationCheck = -67635 + errSecHostNameMismatch = -67602 + errSecCertificateExpired = -67818 + errSecNotTrusted = -67843 + def _bytes_to_cf_data_ref(value: bytes) -> CFDataRef: # type: ignore[valid-type] return CoreFoundation.CFDataCreate( # type: ignore[no-any-return] @@ -439,8 +444,27 @@ def _verify_peercerts_impl( f"Unknown result from Security.SecTrustEvaluateWithError: {sec_trust_eval_result!r}" ) + cf_error_code = 0 if not is_trusted: cf_error_code = CoreFoundation.CFErrorGetCode(cf_error) + + # If the error is a known failure that we're + # explicitly okay with from SSLContext configuration + # we can set is_trusted accordingly. + if ssl_context.verify_mode != ssl.CERT_REQUIRED and ( + cf_error_code == CFConst.errSecNotTrusted + or cf_error_code == CFConst.errSecCertificateExpired + ): + is_trusted = True + elif ( + not ssl_context.check_hostname + and cf_error_code == CFConst.errSecHostNameMismatch + ): + is_trusted = True + + # If we're still not trusted then we start to + # construct and raise the SSLCertVerificationError. + if not is_trusted: cf_error_string_ref = None try: cf_error_string_ref = CoreFoundation.CFErrorCopyDescription(cf_error)