From 481a16be6f2041717db0cc9d76281cda69285bd6 Mon Sep 17 00:00:00 2001 From: TAHRI Ahmed R Date: Thu, 1 Feb 2024 19:22:27 +0100 Subject: [PATCH] :bookmark: Release 0.15.0 (#21) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Changed** - Highly simplified ``_crypto`` module based on upstream work PR 457 - Bump upper bound ``cryptography`` version to 42.x **Fixed** - Mitigate deprecation originating from ``cryptography`` about datetime naïve timezone --- CHANGELOG.rst | 10 + pyproject.toml | 2 +- src/qh3/__init__.py | 2 +- src/qh3/_crypto.py | 402 ++++++++-------------------------- src/qh3/h3/connection.py | 8 +- src/qh3/quic/configuration.py | 6 +- src/qh3/quic/connection.py | 14 +- src/qh3/quic/logger.py | 8 +- src/qh3/tls.py | 30 ++- tests/test_crypto.py | 50 ----- tests/test_tls.py | 18 +- tests/utils.py | 7 +- 12 files changed, 159 insertions(+), 398 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 47be167dd..952ac58ce 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,13 @@ +0.15.0 (2023-02-01) +=================== + +**Changed** +- Highly simplified ``_crypto`` module based on upstream work https://github.com/aiortc/aioquic/pull/457 +- Bump upper bound ``cryptography`` version to 42.x + +**Fixed** +- Mitigate deprecation originating from ``cryptography`` about datetime naïve timezone. + 0.14.0 (2023-11-11) =================== diff --git a/pyproject.toml b/pyproject.toml index 895749696..c62026f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP", ] dependencies = [ - "cryptography>=41.0.0,<42.0.0", + "cryptography>=41.0.0,<43", ] dynamic = ["version"] diff --git a/src/qh3/__init__.py b/src/qh3/__init__.py index 9e78220f9..9da2f8fcc 100644 --- a/src/qh3/__init__.py +++ b/src/qh3/__init__.py @@ -1 +1 @@ -__version__ = "0.14.0" +__version__ = "0.15.0" diff --git a/src/qh3/_crypto.py b/src/qh3/_crypto.py index 3a55fa0b0..84921f598 100644 --- a/src/qh3/_crypto.py +++ b/src/qh3/_crypto.py @@ -1,355 +1,139 @@ -from typing import Tuple +import struct +from typing import Tuple, Union -from cryptography.hazmat.bindings.openssl.binding import Binding +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.ciphers import ( + Cipher, + aead, + algorithms, + modes, +) -AEAD_KEY_LENGTH_MAX = 32 AEAD_NONCE_LENGTH = 12 AEAD_TAG_LENGTH = 16 -PACKET_LENGTH_MAX = 1500 -SAMPLE_LENGTH = 16 +CHACHA20_ZEROS = bytes(5) PACKET_NUMBER_LENGTH_MAX = 4 +SAMPLE_LENGTH = 16 +AEAD_KEY_LENGTH_MAX = 32 class CryptoError(ValueError): pass -def _get_cipher_by_name(binding: Binding, cipher_name: bytes): # -> EVP_CIPHER - evp_cipher = binding.lib.EVP_get_cipherbyname(cipher_name) - if evp_cipher == binding.ffi.NULL: - raise CryptoError(f"Invalid cipher name: {cipher_name.decode()}") - return evp_cipher - - -class _CryptoBase: - def __init__(self) -> None: - self._binding = Binding() - - def _handle_openssl_failure(self) -> bool: - self._binding.lib.ERR_clear_error() - raise CryptoError("OpenSSL call failed") +class AEAD: + _aead: Union[aead.AESGCM, aead.ChaCha20Poly1305] - -class AEAD(_CryptoBase): - def __init__(self, cipher_name: bytes, key: bytes, iv: bytes) -> None: - super().__init__() + def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): + if cipher_name not in (b"aes-128-gcm", b"aes-256-gcm", b"chacha20-poly1305"): + raise CryptoError(f"Invalid cipher name: {cipher_name.decode()}") # check and store key and iv if len(key) > AEAD_KEY_LENGTH_MAX: raise CryptoError("Invalid key length") - self._key = key + if len(iv) != AEAD_NONCE_LENGTH: raise CryptoError("Invalid iv length") - self._iv = iv - - # create cipher contexts - evp_cipher = _get_cipher_by_name(self._binding, cipher_name) - self._decrypt_ctx = self._create_ctx(evp_cipher, operation=0) - self._encrypt_ctx = self._create_ctx(evp_cipher, operation=1) - - # allocate buffers - self._nonce = self._binding.ffi.new("unsigned char[]", AEAD_NONCE_LENGTH) - self._buffer = self._binding.ffi.new("unsigned char[]", PACKET_LENGTH_MAX) - self._buffer_view = self._binding.ffi.buffer(self._buffer) - self._outlen = self._binding.ffi.new("int *") - self._dummy_outlen = self._binding.ffi.new("int *") - - def _create_ctx(self, evp_cipher, operation: int): # -> EVP_CIPHER_CTX * - # create a cipher context with the given type and operation mode - ctx = self._binding.ffi.gc( - self._binding.lib.EVP_CIPHER_CTX_new(), - self._binding.lib.EVP_CIPHER_CTX_free, - ) - ctx != self._binding.ffi.NULL or self._handle_openssl_failure() - self._binding.lib.EVP_CipherInit_ex( - ctx, # EVP_CIPHER_CTX *ctx - evp_cipher, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - self._binding.ffi.NULL, # const unsigned char *key - self._binding.ffi.NULL, # const unsigned char *iv - operation, # int enc - ) == 1 or self._handle_openssl_failure() - - # specify key and initialization vector length - self._binding.lib.EVP_CIPHER_CTX_set_key_length( - ctx, # EVP_CIPHER_CTX *ctx - len(self._key), # int keylen - ) == 1 or self._handle_openssl_failure() - self._binding.lib.EVP_CIPHER_CTX_ctrl( - ctx, # EVP_CIPHER_CTX *ctx - self._binding.lib.EVP_CTRL_AEAD_SET_IVLEN, # int cmd - AEAD_NONCE_LENGTH, # int ivlen - self._binding.ffi.NULL, # void *NULL - ) == 1 or self._handle_openssl_failure() - return ctx - def _init_nonce(self, packet_number: int) -> None: - # reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.3 + if cipher_name == b"chacha20-poly1305": + self._aead = aead.ChaCha20Poly1305(key) + else: + self._aead = aead.AESGCM(key) - # left-pad the reconstructed packet number (62 bits ~ 8 bytes) - # and XOR it with the IV - self._binding.ffi.memmove(self._nonce, self._iv, AEAD_NONCE_LENGTH) - for i in range(8): - if packet_number == 0: - break - self._nonce[AEAD_NONCE_LENGTH - 1 - i] ^= packet_number & 0xFF - packet_number >>= 8 + self._iv = iv def decrypt(self, data: bytes, associated_data: bytes, packet_number: int) -> bytes: - if len(data) < AEAD_TAG_LENGTH or len(data) > PACKET_LENGTH_MAX: - raise CryptoError("Invalid payload length") - self._init_nonce(packet_number) - - # get the appended AEAD tag (data = cipher text + tag) - cipher_text_len = len(data) - AEAD_TAG_LENGTH - self._binding.lib.EVP_CIPHER_CTX_ctrl( - self._decrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.lib.EVP_CTRL_AEAD_SET_TAG, # int cmd - AEAD_TAG_LENGTH, # int taglen - data[cipher_text_len:], # void *tag - ) == 1 or self._handle_openssl_failure() - - # set key and nonce - self._binding.lib.EVP_CipherInit_ex( - self._decrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - self._key, # const unsigned char *key - self._nonce, # const unsigned char *iv - 0, # int enc - ) == 1 or self._handle_openssl_failure() - - # specify the header as additional authenticated data (AAD) - self._binding.lib.EVP_CipherUpdate( - self._decrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # unsigned char *out - self._dummy_outlen, # int *outl - associated_data, # const unsigned char *in - len(associated_data), # int inl - ) == 1 or self._handle_openssl_failure() - - # decrypt the cipher text (i.e. received data excluding the appended tag) - self._binding.lib.EVP_CipherUpdate( - self._decrypt_ctx, # EVP_CIPHER_CTX *ctx - self._buffer, # unsigned char *out - self._outlen, # int *outl - data, # const unsigned char *in - cipher_text_len, # int inl - ) == 1 or self._handle_openssl_failure() - - # finalize the operation - self._binding.lib.EVP_CipherFinal_ex( - self._decrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # unsigned char *outm - self._dummy_outlen, # int *outl - ) == 1 or self._handle_openssl_failure() - - # return the decrypted data - return self._buffer_view[: self._outlen[0]] + try: + return self._aead.decrypt( + self._nonce(packet_number), + data, + associated_data, + ) + except InvalidTag as exc: + raise CryptoError(str(exc)) def encrypt(self, data: bytes, associated_data: bytes, packet_number: int) -> bytes: - if len(data) > PACKET_LENGTH_MAX: - raise CryptoError("Invalid payload length") - self._init_nonce(packet_number) - - # set key and nonce - self._binding.lib.EVP_CipherInit_ex( - self._encrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - self._key, # const unsigned char *key - self._nonce, # const unsigned char *iv - 1, # int enc - ) == 1 or self._handle_openssl_failure() - - # specify the header as additional authenticated data (AAD) - self._binding.lib.EVP_CipherUpdate( - self._encrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # unsigned char *out - self._dummy_outlen, # int *outl - associated_data, # const unsigned char *in - len(associated_data), # int inl - ) == 1 or self._handle_openssl_failure() - - # encrypt the data - self._binding.lib.EVP_CipherUpdate( - self._encrypt_ctx, # EVP_CIPHER_CTX *ctx - self._buffer, # unsigned char *out - self._outlen, # int *outl - data, # const unsigned char *in - len(data), # int inl - ) == 1 or self._handle_openssl_failure() - - # finalize the operation - self._binding.lib.EVP_CipherFinal_ex( - self._encrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # unsigned char *outm - self._dummy_outlen, # int *outl - ) == 1 and self._dummy_outlen[0] == 0 or self._handle_openssl_failure() + return self._aead.encrypt( + self._nonce(packet_number), + data, + associated_data, + ) - # append the AEAD tag to the cipher text - outlen_with_tag = self._outlen[0] + AEAD_TAG_LENGTH - if outlen_with_tag > PACKET_LENGTH_MAX: - raise CryptoError("Invalid payload length") - self._binding.lib.EVP_CIPHER_CTX_ctrl( - self._encrypt_ctx, # EVP_CIPHER_CTX *ctx - self._binding.lib.EVP_CTRL_AEAD_GET_TAG, # int cmd - AEAD_TAG_LENGTH, # int taglen - self._buffer + self._outlen[0], # void *tag - ) == 1 or self._handle_openssl_failure() + def _nonce(self, packet_number: int) -> bytes: + return self._iv[0:4] + struct.pack( + ">Q", struct.unpack(">Q", self._iv[4:12])[0] ^ packet_number + ) - # return the encrypted cipher text and AEAD tag - return self._buffer_view[:outlen_with_tag] +class HeaderProtection: + def __init__(self, cipher_name: bytes, key: bytes): + if cipher_name not in (b"aes-128-ecb", b"aes-256-ecb", b"chacha20"): + raise CryptoError(f"Invalid cipher name: {cipher_name.decode()}") -class HeaderProtection(_CryptoBase): - def __init__(self, cipher_name: bytes, key: bytes) -> None: - super().__init__() - self._is_chacha20 = cipher_name == b"chacha20" if len(key) > AEAD_KEY_LENGTH_MAX: raise CryptoError("Invalid key length") - # create cipher with given type - evp_cipher = _get_cipher_by_name(self._binding, cipher_name) - self._ctx = self._binding.ffi.gc( - self._binding.lib.EVP_CIPHER_CTX_new(), - self._binding.lib.EVP_CIPHER_CTX_free, - ) - self._ctx != self._binding.ffi.NULL or self._handle_openssl_failure() - self._binding.lib.EVP_CipherInit_ex( - self._ctx, # EVP_CIPHER_CTX *ctx - evp_cipher, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - self._binding.ffi.NULL, # const unsigned char *key - self._binding.ffi.NULL, # const unsigned char *iv - 1, # int enc - ) == 1 or self._handle_openssl_failure() - - # set cipher key - self._binding.lib.EVP_CIPHER_CTX_set_key_length( - self._ctx, # EVP_CIPHER_CTX *ctx - len(key), # int keylen - ) == 1 or self._handle_openssl_failure() - self._binding.lib.EVP_CipherInit_ex( - self._ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - key, # const unsigned char *key - self._binding.ffi.NULL, # const unsigned char *iv - 1, # int enc - ) == 1 or self._handle_openssl_failure() - - # allocate buffers - self._buffer = self._binding.ffi.new("unsigned char[]", PACKET_LENGTH_MAX) - self._buffer_view = self._binding.ffi.buffer(self._buffer) - self._dummy_outlen = self._binding.ffi.new("int *") - self._mask = self._binding.ffi.new("unsigned char[]", 31) - self._zero = self._binding.ffi.new("unsigned char[]", 5) - - def _update_mask(self, pn_offset: int, buffer_len: int) -> None: - # reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.2 - - # sample data starts 4 bytes after the beginning of the Packet Number field - # (regardless of its length) - sample_offset = pn_offset + 4 - assert pn_offset + SAMPLE_LENGTH <= buffer_len - - if self._is_chacha20: - # reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.4 - - # the first four bytes after pn_offset are block counter, - # the next 12 bytes are the nonce - self._binding.lib.EVP_CipherInit_ex( - self._ctx, # EVP_CIPHER_CTX *ctx - self._binding.ffi.NULL, # const EVP_CIPHER *type - self._binding.ffi.NULL, # ENGINE *impl - self._binding.ffi.NULL, # const unsigned char *key - self._buffer + sample_offset, # const unsigned char *iv - 1, # int enc - ) == 1 or self._handle_openssl_failure() - - # ChaCha20 is used to protect 5 zero bytes - self._binding.lib.EVP_CipherUpdate( - self._ctx, # EVP_CIPHER_CTX *ctx - self._mask, # unsigned char *out - self._dummy_outlen, # int *outl - self._zero, # const unsigned char *in - len(self._zero), # int inl - ) == 1 or self._handle_openssl_failure() - - else: - # reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.3 - - # AES-based header protected simply samples 16 bytes as input for AES-ECB - self._binding.lib.EVP_CipherUpdate( - self._ctx, # EVP_CIPHER_CTX *ctx - self._mask, # unsigned char *out - self._dummy_outlen, # int *outl - self._buffer + sample_offset, # const unsigned char *in - SAMPLE_LENGTH, # int inl - ) == 1 or self._handle_openssl_failure() - - def _mask_header(self) -> None: - # use one byte to mask 4 bits for long headers, and 5 bits for short ones - if self._buffer[0] & 0x80: - self._buffer[0] ^= self._mask[0] & 0x0F + if cipher_name == b"chacha20": + self._encryptor = None else: - self._buffer[0] ^= self._mask[0] & 0x1F + try: + self._encryptor = Cipher( + algorithm=algorithms.AES(key), + mode=modes.ECB(), + ).encryptor() + except ValueError as e: + raise CryptoError(str(e)) from e - def _mask_packet_number(self, pn_offset: int, pn_length: int) -> int: - # use the remaining (c.f. _mask_header) bytes to mask the packet number field - # and calculate the truncated packet number - pn_truncated = 0 - for i in range(pn_length): - value = self._buffer[pn_offset + i] ^ self._mask[1 + i] - self._buffer[pn_offset + i] = value - pn_truncated = value | (pn_truncated << 8) - return pn_truncated + self._key = key def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: - # Reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.1 - buffer_len = len(plain_header) + len(protected_payload) - if buffer_len > PACKET_LENGTH_MAX: - raise CryptoError("Invalid payload length") - - # read the Packet Number Length from the header pn_length = (plain_header[0] & 0x03) + 1 - - # the Packet Number is the last field of the header, calculate it's offset pn_offset = len(plain_header) - pn_length - # copy header and payload into the buffer - self._binding.ffi.memmove(self._buffer, plain_header, len(plain_header)) - self._binding.ffi.memmove( - self._buffer + len(plain_header), protected_payload, len(protected_payload) + sample_offset = PACKET_NUMBER_LENGTH_MAX - pn_length + mask = self._mask( + protected_payload[sample_offset : sample_offset + SAMPLE_LENGTH] ) - # build the mask and use it - self._update_mask(pn_offset, buffer_len) - self._mask_header() - self._mask_packet_number(pn_offset, pn_length) - - return self._buffer_view[:buffer_len] + buffer = bytearray(plain_header + protected_payload) + if buffer[0] & 0x80: + buffer[0] ^= mask[0] & 0x0F + else: + buffer[0] ^= mask[0] & 0x1F - def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: - # Reference: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.1 - if len(packet) > PACKET_LENGTH_MAX: - raise CryptoError("Invalid payload length") + for i in range(pn_length): + buffer[pn_offset + i] ^= mask[1 + i] - # copy the packet into the buffer - self._binding.ffi.memmove(self._buffer, packet, len(packet)) + return bytes(buffer) - # build the mask and use it to unmask the header first - self._update_mask(encrypted_offset, len(packet)) - self._mask_header() + def remove(self, packet: bytes, pn_offset: int) -> Tuple[bytes, int]: + sample_offset = pn_offset + PACKET_NUMBER_LENGTH_MAX + mask = self._mask(packet[sample_offset : sample_offset + SAMPLE_LENGTH]) - # get the packet number length and unmask it as well - pn_length = (self._buffer[0] & 0x03) + 1 - pn_truncated = self._mask_packet_number(encrypted_offset, pn_length) + buffer = bytearray(packet) + if buffer[0] & 0x80: + buffer[0] ^= mask[0] & 0x0F + else: + buffer[0] ^= mask[0] & 0x1F - # return the header and the truncated packet number - return ( - self._buffer_view[: encrypted_offset + pn_length], - pn_truncated, - ) + pn_length = (buffer[0] & 0x03) + 1 + pn_truncated = 0 + for i in range(pn_length): + buffer[pn_offset + i] ^= mask[1 + i] + pn_truncated = buffer[pn_offset + i] | (pn_truncated << 8) + + return bytes(buffer[: pn_offset + pn_length]), pn_truncated + + def _mask(self, sample: bytes) -> bytes: + if self._encryptor is None: + return ( + Cipher( + algorithm=algorithms.ChaCha20(self._key, sample), + mode=None, + ) + .encryptor() + .update(CHACHA20_ZEROS) + ) + else: + return self._encryptor.update(sample) diff --git a/src/qh3/h3/connection.py b/src/qh3/h3/connection.py index 6b4dd4466..653481568 100644 --- a/src/qh3/h3/connection.py +++ b/src/qh3/h3/connection.py @@ -652,9 +652,11 @@ def _handle_request_or_push_frame( category="http", event="frame_parsed", data=self._quic_logger.encode_http3_headers_frame( - length=stream.blocked_frame_size - if frame_data is None - else len(frame_data), + length=( + stream.blocked_frame_size + if frame_data is None + else len(frame_data) + ), headers=headers, stream_id=stream.stream_id, ), diff --git a/src/qh3/quic/configuration.py b/src/qh3/quic/configuration.py index 0ff09c62b..0e8a8ec71 100644 --- a/src/qh3/quic/configuration.py +++ b/src/qh3/quic/configuration.py @@ -137,9 +137,9 @@ def load_cert_chain( if keyfile is not None: self.private_key = load_pem_private_key( keyfile, - password=password.encode("utf8") - if isinstance(password, str) - else password, + password=( + password.encode("utf8") if isinstance(password, str) else password + ), ) def load_verify_locations( diff --git a/src/qh3/quic/connection.py b/src/qh3/quic/connection.py index b0786f236..78d0fa997 100644 --- a/src/qh3/quic/connection.py +++ b/src/qh3/quic/connection.py @@ -608,9 +608,11 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: "packet_type": self._quic_logger.packet_type( packet.packet_type ), - "scid": dump_cid(self.host_cid) - if is_long_header(packet.packet_type) - else "", + "scid": ( + dump_cid(self.host_cid) + if is_long_header(packet.packet_type) + else "" + ), "dcid": dump_cid(self._peer_cid.cid), }, "raw": {"length": packet.sent_bytes}, @@ -2560,9 +2562,9 @@ def _serialize_transport_parameters(self) -> bytes: initial_source_connection_id=self._local_initial_source_connection_id, max_ack_delay=25, max_datagram_frame_size=self._configuration.max_datagram_frame_size, - quantum_readiness=b"Q" * 1200 - if self._configuration.quantum_readiness_test - else None, + quantum_readiness=( + b"Q" * 1200 if self._configuration.quantum_readiness_test else None + ), stateless_reset_token=self._host_cids[0].stateless_reset_token, ) if not self._is_client: diff --git a/src/qh3/quic/logger.py b/src/qh3/quic/logger.py index 67b349d4c..08c3ef42d 100644 --- a/src/qh3/quic/logger.py +++ b/src/qh3/quic/logger.py @@ -84,9 +84,11 @@ def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict: return { "frame_type": "max_streams", "maximum": maximum, - "stream_type": "unidirectional" - if frame_type == QuicFrameType.MAX_STREAMS_UNI - else "bidirectional", + "stream_type": ( + "unidirectional" + if frame_type == QuicFrameType.MAX_STREAMS_UNI + else "bidirectional" + ), } def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict: diff --git a/src/qh3/tls.py b/src/qh3/tls.py index 50af2410e..222f0d6ff 100644 --- a/src/qh3/tls.py +++ b/src/qh3/tls.py @@ -71,8 +71,9 @@ # facilitate mocking for the test suite -def utcnow() -> datetime.datetime: - return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) +def utcnow(remove_tz: bool = True) -> datetime.datetime: + dt = datetime.datetime.now(datetime.timezone.utc) + return dt.replace(tzinfo=None) if remove_tz else dt class AlertDescription(IntEnum): @@ -403,12 +404,23 @@ def verify_certificate( if chain is None: chain = [] + use_naive_dt = hasattr(certificate, "not_valid_before_utc") is False + # verify dates - now = utcnow() + now = utcnow(remove_tz=use_naive_dt) + + not_valid_before = ( + certificate.not_valid_before + if use_naive_dt + else certificate.not_valid_before_utc + ) + not_valid_after = ( + certificate.not_valid_after if use_naive_dt else certificate.not_valid_after_utc + ) - if now < certificate.not_valid_before: + if now < not_valid_before: raise AlertCertificateExpired("Certificate is not valid yet") - if now > certificate.not_valid_after: + if now > not_valid_after: raise AlertCertificateExpired("Certificate is no longer valid") # load CAs @@ -1604,9 +1616,11 @@ def _client_send_hello(self, output_buf: Buffer) -> None: legacy_compression_methods=self._legacy_compression_methods, alpn_protocols=self._alpn_protocols, key_share=key_share, - psk_key_exchange_modes=self._psk_key_exchange_modes - if (self.session_ticket or self.new_session_ticket_cb is not None) - else None, + psk_key_exchange_modes=( + self._psk_key_exchange_modes + if (self.session_ticket or self.new_session_ticket_cb is not None) + else None + ), server_name=self._server_name, signature_algorithms=self._signature_algorithms, supported_groups=supported_groups, diff --git a/tests/test_crypto.py b/tests/test_crypto.py index f114f2aa9..cfacbab12 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1,6 +1,5 @@ import binascii from unittest import TestCase, skipIf -from unittest.mock import patch from qh3.buffer import Buffer from qh3.quic.crypto import ( @@ -382,34 +381,6 @@ def test_aead_init_args_validation(self): self.create_aead(iv=bytes(13)) self.assertEqual(str(cm.exception), "Invalid iv length") - def test_aead_data_length_validation(self): - aead = self.create_aead() - iv = aead._iv - zero_nonce = bytes(12) - - # too large for encrypt, early abort - self.assertEqual(bytes(aead._nonce), zero_nonce) - with self.assertRaises(CryptoError) as cm: - aead.encrypt(bytes(1501), associated_data=b"", packet_number=0) - self.assertEqual(str(cm.exception), "Invalid payload length") - self.assertEqual(bytes(aead._nonce), zero_nonce) - - # too large for encrypt, late tag check - with self.assertRaises(CryptoError) as cm: - aead.encrypt(bytes(1500), associated_data=b"", packet_number=0) - self.assertEqual(str(cm.exception), "Invalid payload length") - self.assertEqual(bytes(aead._nonce), iv) - - # too small for decrypt - with self.assertRaises(CryptoError) as cm: - aead.decrypt(bytes(11), associated_data=b"", packet_number=0) - self.assertEqual(str(cm.exception), "Invalid payload length") - - # too large for decrypt - with self.assertRaises(CryptoError) as cm: - aead.decrypt(bytes(1501), associated_data=b"", packet_number=0) - self.assertEqual(str(cm.exception), "Invalid payload length") - def test_hp_init_args_validation(self): # invalid cipher with self.assertRaises(CryptoError) as cm: @@ -420,24 +391,3 @@ def test_hp_init_args_validation(self): with self.assertRaises(CryptoError) as cm: self.create_hp(key=bytes(33)) self.assertEqual(str(cm.exception), "Invalid key length") - - def test_hp_data_length_validation(self): - hp = self.create_hp() - - # too large for apply - with self.assertRaises(CryptoError) as cm: - hp.apply(plain_header=bytes(501), protected_payload=bytes(1000)) - self.assertEqual(str(cm.exception), "Invalid payload length") - - # too large for remove - with self.assertRaises(CryptoError) as cm: - hp.remove(packet=bytes(1501), encrypted_offset=0) - self.assertEqual(str(cm.exception), "Invalid payload length") - - def test_handle_openssl_failure(self): - # ensure errors are cleared - aead = self.create_aead() - with patch.object(aead._binding.lib, "ERR_clear_error") as mock: - with self.assertRaises(CryptoError): - aead._handle_openssl_failure() - mock.assert_called_once() diff --git a/tests/test_tls.py b/tests/test_tls.py index 8edca51a4..63070304f 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -1294,7 +1294,7 @@ def test_verify_certificate_chain(self): certificate = load_pem_x509_certificates(fp.read())[0] with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_before + mock_utcnow.return_value = certificate.not_valid_before_utc # fail with self.assertRaises(tls.AlertBadCertificate) as cm: @@ -1315,7 +1315,7 @@ def test_verify_certificate_chain_self_signed(self): ) with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_before + mock_utcnow.return_value = certificate.not_valid_before_utc # fail with self.assertRaises(tls.AlertBadCertificate) as cm: @@ -1345,7 +1345,7 @@ def test_verify_dates(self): #  too early with patch("qh3.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = ( - certificate.not_valid_before - datetime.timedelta(seconds=1) + certificate.not_valid_before_utc - datetime.timedelta(seconds=1) ) with self.assertRaises(tls.AlertCertificateExpired) as cm: verify_certificate(cadata=cadata, certificate=certificate) @@ -1353,17 +1353,17 @@ def test_verify_dates(self): # valid with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_before + mock_utcnow.return_value = certificate.not_valid_before_utc verify_certificate(cadata=cadata, certificate=certificate) with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_after + mock_utcnow.return_value = certificate.not_valid_after_utc verify_certificate(cadata=cadata, certificate=certificate) # too late with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_after + datetime.timedelta( - seconds=1 + mock_utcnow.return_value = ( + certificate.not_valid_after_utc + datetime.timedelta(seconds=1) ) with self.assertRaises(tls.AlertCertificateExpired) as cm: verify_certificate(cadata=cadata, certificate=certificate) @@ -1378,7 +1378,7 @@ def test_verify_subject(self): cadata = certificate.public_bytes(serialization.Encoding.PEM) with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_before + mock_utcnow.return_value = certificate.not_valid_before_utc # both valid match_hostname( @@ -1427,7 +1427,7 @@ def test_verify_subject_with_subjaltname(self): cadata = certificate.public_bytes(serialization.Encoding.PEM) with patch("qh3.tls.utcnow") as mock_utcnow: - mock_utcnow.return_value = certificate.not_valid_before + mock_utcnow.return_value = certificate.not_valid_before_utc # valid match_hostname( diff --git a/tests/utils.py b/tests/utils.py index a3844b012..2276644f7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,12 +28,9 @@ def generate_certificate(*, alternative_names, common_name, hash_algorithm, key) .issuer_name(issuer) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before( - datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - ) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) .not_valid_after( - datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - + datetime.timedelta(days=10) + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=10) ) ) if alternative_names: