Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSLContext methods and properties #83

Merged
merged 3 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 153 additions & 24 deletions src/truststore/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import platform
import socket
import ssl
from typing import Any
import typing

from _ssl import ENCODING_DER # type: ignore[import]

Expand All @@ -14,12 +14,15 @@
from ._openssl import _configure_context, _verify_peercerts_impl


_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]


class SSLContext(ssl.SSLContext):
"""SSLContext API that uses system certificates on all platforms"""

def __init__(self, protocol: int = ssl.PROTOCOL_TLS) -> None:
self._ctx = ssl.SSLContext(protocol)
_configure_context(self._ctx)

class TruststoreSSLObject(ssl.SSLObject):
# This object exists because wrap_bio() doesn't
Expand All @@ -42,14 +45,18 @@ def wrap_socket(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLSocket:
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
# Use a context manager here because the
# inner SSLContext holds on to our state
# but also does the actual handshake.
with _configure_context(self._ctx):
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
try:
_verify_peercerts(ssl_sock, server_hostname=server_hostname)
except ssl.SSLError:
Expand All @@ -65,13 +72,14 @@ def wrap_bio(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLObject:
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
with _configure_context(self._ctx):
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
return ssl_obj

def load_verify_locations(
Expand All @@ -84,14 +92,135 @@ def load_verify_locations(
cafile=cafile, capath=capath, cadata=cadata
)

def __getattr__(self, name: str) -> Any:
return getattr(self._ctx, name)
def load_cert_chain(
self,
certfile: _StrOrBytesPath,
keyfile: _StrOrBytesPath | None = None,
password: _PasswordType | None = None,
) -> None:
return self._ctx.load_cert_chain(
certfile=certfile, keyfile=keyfile, password=password
)

def load_default_certs(
self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
) -> None:
return self._ctx.load_default_certs(purpose)

def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_alpn_protocols(alpn_protocols)

def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_npn_protocols(npn_protocols)

def set_ciphers(self, __cipherlist: str) -> None:
return self._ctx.set_ciphers(__cipherlist)

def get_ciphers(self) -> typing.Any:
return self._ctx.get_ciphers()

def session_stats(self) -> dict[str, int]:
return self._ctx.session_stats()

def cert_store_stats(self) -> dict[str, int]:
raise NotImplementedError()

@typing.overload
def get_ca_certs(
self, binary_form: typing.Literal[False] = ...
) -> list[typing.Any]:
...

@typing.overload
def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
...

@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
...

def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
raise NotImplementedError()

@property
def check_hostname(self) -> bool:
return self._ctx.check_hostname

@check_hostname.setter
def check_hostname(self, value: bool) -> None:
self._ctx.check_hostname = value

@property
def hostname_checks_common_name(self) -> bool:
return self._ctx.hostname_checks_common_name

@hostname_checks_common_name.setter
def hostname_checks_common_name(self, value: bool) -> None:
self._ctx.hostname_checks_common_name = value

@property
def keylog_filename(self) -> str:
return self._ctx.keylog_filename

@keylog_filename.setter
def keylog_filename(self, value: str) -> None:
self._ctx.keylog_filename = value

@property
def maximum_version(self) -> ssl.TLSVersion:
return self._ctx.maximum_version

@maximum_version.setter
def maximum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.maximum_version = value

@property
def minimum_version(self) -> ssl.TLSVersion:
return self._ctx.minimum_version

@minimum_version.setter
def minimum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.minimum_version = value

@property
def options(self) -> ssl.Options:
return self._ctx.options

@options.setter
def options(self, value: ssl.Options) -> None:
self._ctx.options = value

@property
def post_handshake_auth(self) -> bool:
return self._ctx.post_handshake_auth

@post_handshake_auth.setter
def post_handshake_auth(self, value: bool) -> None:
self._ctx.post_handshake_auth = value

@property
def protocol(self) -> ssl._SSLMethod:
return self._ctx.protocol

@property
def security_level(self) -> int:
return self._ctx.security_level # type: ignore[attr-defined,no-any-return]

@property
def verify_flags(self) -> ssl.VerifyFlags:
return self._ctx.verify_flags

@verify_flags.setter
def verify_flags(self, value: ssl.VerifyFlags) -> None:
self._ctx.verify_flags = value

@property
def verify_mode(self) -> ssl.VerifyMode:
return self._ctx.verify_mode

def __setattr__(self, name: str, value: Any) -> None:
if name == "verify_flags":
self._ctx.verify_flags = value
else:
return super().__setattr__(name, value)
@verify_mode.setter
def verify_mode(self, value: ssl.VerifyMode) -> None:
self._ctx.verify_mode = value


def _verify_peercerts(
Expand Down
37 changes: 34 additions & 3 deletions src/truststore/_macos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import ctypes
import platform
import ssl
import typing
from ctypes import (
CDLL,
POINTER,
Expand All @@ -13,7 +15,6 @@
c_void_p,
)
from ctypes.util import find_library
from typing import Any

_mac_version = platform.mac_ver()[0]
_mac_version_info = tuple(map(int, _mac_version.split(".")))
Expand Down Expand Up @@ -201,7 +202,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
raise ImportError("Error initializing ctypes") from None


def _handle_osstatus(result: OSStatus, _: Any, args: Any) -> Any:
def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any:
"""
Raises an error if the OSStatus value is non-zero.
"""
Expand Down Expand Up @@ -264,6 +265,11 @@ class CFConst:

kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)

errSecIncompleteCertRevocationCheck = -67635
errSecHostNameMismatch = -67602
errSecCertificateExpired = -67818
errSecNotTrusted = -67843


def _bytes_to_cf_data_ref(value: bytes) -> CFDataRef: # type: ignore[valid-type]
return CoreFoundation.CFDataCreate( # type: ignore[no-any-return]
Expand Down Expand Up @@ -338,9 +344,15 @@ def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # typ
return cf_array # type: ignore[no-any-return]


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
values = ctx.check_hostname, ctx.verify_mode
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
try:
yield
finally:
ctx.check_hostname, ctx.verify_mode = values


def _verify_peercerts_impl(
Expand Down Expand Up @@ -432,8 +444,27 @@ def _verify_peercerts_impl(
f"Unknown result from Security.SecTrustEvaluateWithError: {sec_trust_eval_result!r}"
)

cf_error_code = 0
if not is_trusted:
cf_error_code = CoreFoundation.CFErrorGetCode(cf_error)

# If the error is a known failure that we're
# explicitly okay with from SSLContext configuration
# we can set is_trusted accordingly.
if ssl_context.verify_mode != ssl.CERT_REQUIRED and (
cf_error_code == CFConst.errSecNotTrusted
or cf_error_code == CFConst.errSecCertificateExpired
):
is_trusted = True
elif (
not ssl_context.check_hostname
and cf_error_code == CFConst.errSecHostNameMismatch
):
is_trusted = True

# If we're still not trusted then we start to
# construct and raise the SSLCertVerificationError.
if not is_trusted:
cf_error_string_ref = None
try:
cf_error_string_ref = CoreFoundation.CFErrorCopyDescription(cf_error)
Expand Down
8 changes: 5 additions & 3 deletions src/truststore/_openssl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import os
import re
import ssl
import typing

# candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes
_CA_FILE_CANDIDATES = [
Expand All @@ -17,7 +19,8 @@
_HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$")


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
# First, check whether the default locations from OpenSSL
# seem like they will give us a usable set of CA certs.
# ssl.get_default_verify_paths already takes care of:
Expand All @@ -40,8 +43,7 @@ def _configure_context(ctx: ssl.SSLContext) -> None:
ctx.load_verify_locations(cafile=cafile)
break

ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
yield


def _capath_contains_certs(capath: str) -> bool:
Expand Down
Loading