Skip to content

Commit

Permalink
Add SSLContext methods and properties
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson authored Feb 20, 2023
1 parent b6db4be commit 63dc9e1
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 31 deletions.
177 changes: 153 additions & 24 deletions src/truststore/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import platform
import socket
import ssl
from typing import Any
import typing

from _ssl import ENCODING_DER # type: ignore[import]

Expand All @@ -14,12 +14,15 @@
from ._openssl import _configure_context, _verify_peercerts_impl


_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]


class SSLContext(ssl.SSLContext):
"""SSLContext API that uses system certificates on all platforms"""

def __init__(self, protocol: int = ssl.PROTOCOL_TLS) -> None:
self._ctx = ssl.SSLContext(protocol)
_configure_context(self._ctx)

class TruststoreSSLObject(ssl.SSLObject):
# This object exists because wrap_bio() doesn't
Expand All @@ -42,14 +45,18 @@ def wrap_socket(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLSocket:
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
# Use a context manager here because the
# inner SSLContext holds on to our state
# but also does the actual handshake.
with _configure_context(self._ctx):
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
try:
_verify_peercerts(ssl_sock, server_hostname=server_hostname)
except ssl.SSLError:
Expand All @@ -65,13 +72,14 @@ def wrap_bio(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLObject:
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
with _configure_context(self._ctx):
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
return ssl_obj

def load_verify_locations(
Expand All @@ -84,14 +92,135 @@ def load_verify_locations(
cafile=cafile, capath=capath, cadata=cadata
)

def __getattr__(self, name: str) -> Any:
return getattr(self._ctx, name)
def load_cert_chain(
self,
certfile: _StrOrBytesPath,
keyfile: _StrOrBytesPath | None = None,
password: _PasswordType | None = None,
) -> None:
return self._ctx.load_cert_chain(
certfile=certfile, keyfile=keyfile, password=password
)

def load_default_certs(
self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
) -> None:
return self._ctx.load_default_certs(purpose)

def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_alpn_protocols(alpn_protocols)

def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_npn_protocols(npn_protocols)

def set_ciphers(self, __cipherlist: str) -> None:
return self._ctx.set_ciphers(__cipherlist)

def get_ciphers(self) -> typing.Any:
return self._ctx.get_ciphers()

def session_stats(self) -> dict[str, int]:
return self._ctx.session_stats()

def cert_store_stats(self) -> dict[str, int]:
raise NotImplementedError()

@typing.overload
def get_ca_certs(
self, binary_form: typing.Literal[False] = ...
) -> list[typing.Any]:
...

@typing.overload
def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
...

@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
...

def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
raise NotImplementedError()

@property
def check_hostname(self) -> bool:
return self._ctx.check_hostname

@check_hostname.setter
def check_hostname(self, value: bool) -> None:
self._ctx.check_hostname = value

@property
def hostname_checks_common_name(self) -> bool:
return self._ctx.hostname_checks_common_name

@hostname_checks_common_name.setter
def hostname_checks_common_name(self, value: bool) -> None:
self._ctx.hostname_checks_common_name = value

@property
def keylog_filename(self) -> str:
return self._ctx.keylog_filename

@keylog_filename.setter
def keylog_filename(self, value: str) -> None:
self._ctx.keylog_filename = value

@property
def maximum_version(self) -> ssl.TLSVersion:
return self._ctx.maximum_version

@maximum_version.setter
def maximum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.maximum_version = value

@property
def minimum_version(self) -> ssl.TLSVersion:
return self._ctx.minimum_version

@minimum_version.setter
def minimum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.minimum_version = value

@property
def options(self) -> ssl.Options:
return self._ctx.options

@options.setter
def options(self, value: ssl.Options) -> None:
self._ctx.options = value

@property
def post_handshake_auth(self) -> bool:
return self._ctx.post_handshake_auth

@post_handshake_auth.setter
def post_handshake_auth(self, value: bool) -> None:
self._ctx.post_handshake_auth = value

@property
def protocol(self) -> ssl._SSLMethod:
return self._ctx.protocol

@property
def security_level(self) -> int:
return self._ctx.security_level # type: ignore[attr-defined,no-any-return]

@property
def verify_flags(self) -> ssl.VerifyFlags:
return self._ctx.verify_flags

@verify_flags.setter
def verify_flags(self, value: ssl.VerifyFlags) -> None:
self._ctx.verify_flags = value

@property
def verify_mode(self) -> ssl.VerifyMode:
return self._ctx.verify_mode

def __setattr__(self, name: str, value: Any) -> None:
if name == "verify_flags":
self._ctx.verify_flags = value
else:
return super().__setattr__(name, value)
@verify_mode.setter
def verify_mode(self, value: ssl.VerifyMode) -> None:
self._ctx.verify_mode = value


def _verify_peercerts(
Expand Down
37 changes: 34 additions & 3 deletions src/truststore/_macos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import ctypes
import platform
import ssl
import typing
from ctypes import (
CDLL,
POINTER,
Expand All @@ -13,7 +15,6 @@
c_void_p,
)
from ctypes.util import find_library
from typing import Any

_mac_version = platform.mac_ver()[0]
_mac_version_info = tuple(map(int, _mac_version.split(".")))
Expand Down Expand Up @@ -201,7 +202,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
raise ImportError("Error initializing ctypes") from None


def _handle_osstatus(result: OSStatus, _: Any, args: Any) -> Any:
def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any:
"""
Raises an error if the OSStatus value is non-zero.
"""
Expand Down Expand Up @@ -264,6 +265,11 @@ class CFConst:

kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)

errSecIncompleteCertRevocationCheck = -67635
errSecHostNameMismatch = -67602
errSecCertificateExpired = -67818
errSecNotTrusted = -67843


def _bytes_to_cf_data_ref(value: bytes) -> CFDataRef: # type: ignore[valid-type]
return CoreFoundation.CFDataCreate( # type: ignore[no-any-return]
Expand Down Expand Up @@ -338,9 +344,15 @@ def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # typ
return cf_array # type: ignore[no-any-return]


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
values = ctx.check_hostname, ctx.verify_mode
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
try:
yield
finally:
ctx.check_hostname, ctx.verify_mode = values


def _verify_peercerts_impl(
Expand Down Expand Up @@ -432,8 +444,27 @@ def _verify_peercerts_impl(
f"Unknown result from Security.SecTrustEvaluateWithError: {sec_trust_eval_result!r}"
)

cf_error_code = 0
if not is_trusted:
cf_error_code = CoreFoundation.CFErrorGetCode(cf_error)

# If the error is a known failure that we're
# explicitly okay with from SSLContext configuration
# we can set is_trusted accordingly.
if ssl_context.verify_mode != ssl.CERT_REQUIRED and (
cf_error_code == CFConst.errSecNotTrusted
or cf_error_code == CFConst.errSecCertificateExpired
):
is_trusted = True
elif (
not ssl_context.check_hostname
and cf_error_code == CFConst.errSecHostNameMismatch
):
is_trusted = True

# If we're still not trusted then we start to
# construct and raise the SSLCertVerificationError.
if not is_trusted:
cf_error_string_ref = None
try:
cf_error_string_ref = CoreFoundation.CFErrorCopyDescription(cf_error)
Expand Down
8 changes: 5 additions & 3 deletions src/truststore/_openssl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import os
import re
import ssl
import typing

# candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes
_CA_FILE_CANDIDATES = [
Expand All @@ -17,7 +19,8 @@
_HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$")


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
# First, check whether the default locations from OpenSSL
# seem like they will give us a usable set of CA certs.
# ssl.get_default_verify_paths already takes care of:
Expand All @@ -40,8 +43,7 @@ def _configure_context(ctx: ssl.SSLContext) -> None:
ctx.load_verify_locations(cafile=cafile)
break

ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
yield


def _capath_contains_certs(capath: str) -> bool:
Expand Down
Loading

0 comments on commit 63dc9e1

Please sign in to comment.