From 63f75622e75de88da5bf48d89c0d9ad6ff48a9a5 Mon Sep 17 00:00:00 2001 From: Matt Gilene Date: Sun, 11 Jul 2021 15:08:03 -0400 Subject: [PATCH 1/7] Implement asgiref tls extension --- tests/conftest.py | 17 ++++++ tests/test_ssl.py | 52 ++++++++++++++++ uvicorn/config.py | 3 + uvicorn/protocols/http/h11_impl.py | 24 +++++++- uvicorn/protocols/http/httptools_impl.py | 25 +++++++- uvicorn/protocols/utils.py | 76 +++++++++++++++++++++++- 6 files changed, 194 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d26d39432..1bc6d1d97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os import socket import ssl + from copy import deepcopy from hashlib import md5 from pathlib import Path @@ -12,11 +13,13 @@ import pytest import trustme + from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from uvicorn.config import LOGGING_CONFIG + # Note: We explicitly turn the propagate on just for tests, because pytest # caplog not able to capture no-propagate loggers. # @@ -43,6 +46,13 @@ def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: ) +@pytest.fixture +def tls_client_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: + return tls_certificate_authority.issue_cert( + "client@example.com", common_name="uvicorn client" + ) + + @pytest.fixture def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA): with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: @@ -96,6 +106,13 @@ def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: return ssl_ctx +@pytest.fixture +def tls_client_certificate_pem_path(tls_client_certificate: trustme.LeafCert): + private_key_and_cert_chain = tls_client_certificate.private_key_and_cert_chain_pem + with private_key_and_cert_chain.tempfile() as client_cert_pem: + yield client_cert_pem + + @pytest.fixture(scope="package") def reload_directory_structure(tmp_path_factory: pytest.TempPathFactory): """ diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d60bcf54e..b113d89b4 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,3 +1,5 @@ +import ssl + import httpx import pytest @@ -34,6 +36,56 @@ async def test_run( assert response.status_code == 204 +@pytest.mark.anyio +async def test_run_httptools_client_cert( + tls_ca_ssl_context, + tls_ca_certificate_pem_path, + tls_ca_certificate_private_key_path, + tls_client_certificate_pem_path, +): + config = Config( + app=app, + loop="asyncio", + http="httptools", + limit_max_requests=1, + ssl_keyfile=tls_ca_certificate_private_key_path, + ssl_certfile=tls_ca_certificate_pem_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + ) + async with run_server(config): + async with httpx.AsyncClient( + verify=tls_ca_ssl_context, cert=tls_client_certificate_pem_path + ) as client: + response = await client.get("https://127.0.0.1:8000") + assert response.status_code == 204 + + +@pytest.mark.anyio +async def test_run_h11_client_cert( + tls_ca_ssl_context, + tls_ca_certificate_pem_path, + tls_ca_certificate_private_key_path, + tls_client_certificate_pem_path, +): + config = Config( + app=app, + loop="asyncio", + http="h11", + limit_max_requests=1, + ssl_keyfile=tls_ca_certificate_private_key_path, + ssl_certfile=tls_ca_certificate_pem_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + ) + async with run_server(config): + async with httpx.AsyncClient( + verify=tls_ca_ssl_context, cert=tls_client_certificate_pem_path + ) as client: + response = await client.get("https://127.0.0.1:8000") + assert response.status_code == 204 + + @pytest.mark.anyio async def test_run_chain( tls_ca_ssl_context, diff --git a/uvicorn/config.py b/uvicorn/config.py index 0ebc562c1..e7993f1d9 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -285,6 +285,7 @@ def __init__( self.callback_notify = callback_notify self.ssl_keyfile = ssl_keyfile self.ssl_certfile = ssl_certfile + self.ssl_cert_pem: Optional[str] = None self.ssl_keyfile_password = ssl_keyfile_password self.ssl_version = ssl_version self.ssl_cert_reqs = ssl_cert_reqs @@ -446,6 +447,8 @@ def load(self) -> None: ca_certs=self.ssl_ca_certs, ciphers=self.ssl_ciphers, ) + with open(self.ssl_certfile) as file: + self.ssl_cert_pem = file.read() else: self.ssl = None diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index c2764b028..63e8c7653 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -2,7 +2,18 @@ import http import logging import sys -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union, cast + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) from urllib.parse import unquote import h11 @@ -20,10 +31,12 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState + if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 @@ -99,6 +112,7 @@ def __init__( self.server: Optional[Tuple[str, int]] = None self.client: Optional[Tuple[str, int]] = None self.scheme: Optional[Literal["http", "https"]] = None + self.tls: Optional[Dict[str, Any]] = None # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] @@ -117,6 +131,11 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" + if self.config.is_ssl: + self.tls = get_tls_info(transport) + if self.tls: + self.tls["server_cert"] = self.config.ssl_cert_pem + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) @@ -223,6 +242,9 @@ def handle_events(self) -> None: "raw_path": raw_path, "query_string": query_string, "headers": self.headers, + "extensions": { + "tls": self.tls, + }, } upgrade = self._get_upgrade() diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 734e8945d..e9c945ffb 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -4,9 +4,21 @@ import re import sys import urllib + from asyncio.events import TimerHandle from collections import deque -from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) import httptools @@ -23,10 +35,12 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState + if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 @@ -98,6 +112,7 @@ def __init__( self.client: Optional[Tuple[str, int]] = None self.scheme: Optional[Literal["http", "https"]] = None self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque() + self.tls: Optional[Dict[str, Any]] = None # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] @@ -117,6 +132,11 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" + if self.config.is_ssl: + self.tls = get_tls_info(transport) + if self.tls: + self.tls["server_cert"] = self.config.ssl_cert_pem + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) @@ -237,6 +257,9 @@ def on_message_begin(self) -> None: "scheme": self.scheme, "root_path": self.root_path, "headers": self.headers, + "extensions": { + "tls": self.tls, + }, } # Parser callbacks diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index fbd4b4d5d..b6e7a662a 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -1,10 +1,32 @@ import asyncio +import ssl import urllib.parse -from typing import TYPE_CHECKING, Optional, Tuple + +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + if TYPE_CHECKING: from asgiref.typing import WWWScope +RDNS_MAPPING: Dict[str, str] = { + "commonName": "CN", + "localityName": "L", + "stateOrProvinceName": "ST", + "organizationName": "O", + "organizationalUnitName": "OU", + "countryName": "C", + "streetAddress": "STREET", + "domainComponent": "DC", + "userId": "UID", +} + +TLS_VERSION_MAP: Dict[str, int] = { + "TLSv1": 0x0301, + "TLSv1.1": 0x0302, + "TLSv1.2": 0x0303, + "TLSv1.3": 0x0304, +} + def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]: socket_info = transport.get_extra_info("socket") @@ -53,3 +75,55 @@ def get_path_with_query_string(scope: "WWWScope") -> str: path_with_query_string, scope["query_string"].decode("ascii") ) return path_with_query_string + + +def get_tls_info(transport: asyncio.Transport) -> Optional[Dict]: + + ### + # server_cert: Unable to set from transport information + # client_cert_chain: Just the peercert, currently no access to the full cert chain + # client_cert_name: + # client_cert_error: No access to this + # tls_version: + # cipher_suite: Too hard to convert without direct access to openssl + ### + + ssl_info: Dict[str, Any] = { + "server_cert": None, + "client_cert_chain": [], + "client_cert_name": None, + "client_cert_error": None, + "tls_version": None, + "cipher_suite": None, + } + + ssl_object = transport.get_extra_info("ssl_object", default=None) + peercert = ssl_object.getpeercert() + + if peercert: + rdn_strings = [] + for rdn in peercert["subject"]: + rdn_strings.append( + "+".join( + [ + "%s = %s" % (RDNS_MAPPING[entry[0]], entry[1]) + for entry in reversed(rdn) + if entry[0] in RDNS_MAPPING + ] + ) + ) + + ssl_info["client_cert_chain"] = [ + ssl.DER_cert_to_PEM_cert(ssl_object.getpeercert(binary_form=True)) + ] + ssl_info["client_cert_name"] = ", ".join(rdn_strings) if rdn_strings else "" + ssl_info["tls_version"] = ( + TLS_VERSION_MAP[ssl_object.version()] + if ssl_object.version() in TLS_VERSION_MAP + else None + ) + ssl_info["cipher_suite"] = list(ssl_object.cipher()) + + return ssl_info + + return None From 35bb6896271398f513ba94ea3760b613799b9701 Mon Sep 17 00:00:00 2001 From: Matt Gilen Date: Sun, 18 Jul 2021 16:10:35 +0000 Subject: [PATCH 2/7] Only add tls extension if connection is over tls --- uvicorn/protocols/http/h11_impl.py | 6 +++--- uvicorn/protocols/http/httptools_impl.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 63e8c7653..dd99a7e42 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -242,10 +242,10 @@ def handle_events(self) -> None: "raw_path": raw_path, "query_string": query_string, "headers": self.headers, - "extensions": { - "tls": self.tls, - }, + "extensions": {}, } + if self.config.is_ssl: + self.scope["extensions"]["tls"] = self.tls upgrade = self._get_upgrade() if upgrade == b"websocket" and self._should_upgrade_to_ws(): diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e9c945ffb..a98348afe 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -257,10 +257,10 @@ def on_message_begin(self) -> None: "scheme": self.scheme, "root_path": self.root_path, "headers": self.headers, - "extensions": { - "tls": self.tls, - }, + "extensions": {}, } + if self.config.is_ssl: + self.scope["extensions"]["tls"] = self.tls # Parser callbacks def on_url(self, url: bytes) -> None: From fa4a6be2eb0d9c33d9fa296d10e5cb0ecf31134c Mon Sep 17 00:00:00 2001 From: Paul Brussee Date: Tue, 10 Jan 2023 20:56:41 +0100 Subject: [PATCH 3/7] feat: use connection scheme workaround: pass server_cert from config --- uvicorn/protocols/http/h11_impl.py | 8 +++----- uvicorn/protocols/http/httptools_impl.py | 8 +++----- uvicorn/protocols/utils.py | 23 ++++++++++++----------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index dd99a7e42..a2e0f6189 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -131,10 +131,8 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" - if self.config.is_ssl: - self.tls = get_tls_info(transport) - if self.tls: - self.tls["server_cert"] = self.config.ssl_cert_pem + if self.scheme == "https": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" @@ -244,7 +242,7 @@ def handle_events(self) -> None: "headers": self.headers, "extensions": {}, } - if self.config.is_ssl: + if self.scheme == "https": self.scope["extensions"]["tls"] = self.tls upgrade = self._get_upgrade() diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index a98348afe..b85d81afb 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -132,10 +132,8 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" - if self.config.is_ssl: - self.tls = get_tls_info(transport) - if self.tls: - self.tls["server_cert"] = self.config.ssl_cert_pem + if self.scheme == "https": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" @@ -259,7 +257,7 @@ def on_message_begin(self) -> None: "headers": self.headers, "extensions": {}, } - if self.config.is_ssl: + if self.scheme == "https": self.scope["extensions"]["tls"] = self.tls # Parser callbacks diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index b6e7a662a..9b22d9fa7 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -77,10 +77,12 @@ def get_path_with_query_string(scope: "WWWScope") -> str: return path_with_query_string -def get_tls_info(transport: asyncio.Transport) -> Optional[Dict]: +def get_tls_info( + transport: asyncio.Transport, server_pem: Optional[str] = None +) -> Dict: ### - # server_cert: Unable to set from transport information + # server_cert: Unable to set from transport information, need to read from config # client_cert_chain: Just the peercert, currently no access to the full cert chain # client_cert_name: # client_cert_error: No access to this @@ -89,7 +91,7 @@ def get_tls_info(transport: asyncio.Transport) -> Optional[Dict]: ### ssl_info: Dict[str, Any] = { - "server_cert": None, + "server_cert": server_pem, "client_cert_chain": [], "client_cert_name": None, "client_cert_error": None, @@ -117,13 +119,12 @@ def get_tls_info(transport: asyncio.Transport) -> Optional[Dict]: ssl.DER_cert_to_PEM_cert(ssl_object.getpeercert(binary_form=True)) ] ssl_info["client_cert_name"] = ", ".join(rdn_strings) if rdn_strings else "" - ssl_info["tls_version"] = ( - TLS_VERSION_MAP[ssl_object.version()] - if ssl_object.version() in TLS_VERSION_MAP - else None - ) - ssl_info["cipher_suite"] = list(ssl_object.cipher()) - return ssl_info + ssl_info["tls_version"] = ( + TLS_VERSION_MAP[ssl_object.version()] + if ssl_object.version() in TLS_VERSION_MAP + else None + ) + ssl_info["cipher_suite"] = list(ssl_object.cipher()) - return None + return ssl_info From 67646a05d32509c7f9631392df57cd878f3f7e8c Mon Sep 17 00:00:00 2001 From: Paul Brussee Date: Tue, 10 Jan 2023 20:58:35 +0100 Subject: [PATCH 4/7] feat: add tls extension to websocket protocols --- .../protocols/websockets/websockets_impl.py | 23 ++++++++++++++++++- uvicorn/protocols/websockets/wsproto_impl.py | 12 +++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 297203ec6..1335733cd 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -2,10 +2,22 @@ import http import logging import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union, cast + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) from urllib.parse import unquote import websockets + from websockets.datastructures import Headers from websockets.exceptions import ConnectionClosed from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory @@ -19,10 +31,12 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState + if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 @@ -80,6 +94,7 @@ def __init__( self.server: Optional[Tuple[str, int]] = None self.client: Optional[Tuple[str, int]] = None self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + self.tls: Optional[Dict[str, Any]] = None # Connection events self.scope: WebSocketScope = None # type: ignore[assignment] @@ -121,6 +136,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "wss" if is_ssl(transport) else "ws" + if self.scheme == "wss": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) @@ -190,7 +208,10 @@ async def process_request( "query_string": query_string.encode("ascii"), "headers": asgi_headers, "subprotocols": subprotocols, + "extensions": {}, } + if self.scheme == "wss": + self.scope["extensions"]["tls"] = self.tls task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) self.tasks.add(task) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 1d76f3a88..ed13b0c9c 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -2,9 +2,11 @@ import logging import sys import typing + from urllib.parse import unquote import wsproto + from wsproto import ConnectionType, events from wsproto.connection import ConnectionState from wsproto.extensions import Extension, PerMessageDeflate @@ -16,10 +18,12 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState + if typing.TYPE_CHECKING: from asgiref.typing import ( ASGISendEvent, @@ -70,6 +74,7 @@ def __init__( self.server: typing.Optional[typing.Tuple[str, int]] = None self.client: typing.Optional[typing.Tuple[str, int]] = None self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + self.tls: typing.Optional[typing.Dict[str, typing.Any]] = None # WebSocket state self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() @@ -97,6 +102,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "wss" if is_ssl(transport) else "ws" + if self.scheme == "wss": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) @@ -183,8 +191,10 @@ def handle_connect(self, event: events.Request) -> None: "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.subprotocols, - "extensions": None, + "extensions": {}, } + if self.scheme == "wss": + self.scope["extensions"]["tls"] = self.tls self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) From aacd0bdc0ddd0cd3da9862203b47b70f2072eea3 Mon Sep 17 00:00:00 2001 From: Paul Brussee Date: Tue, 10 Jan 2023 20:20:07 +0100 Subject: [PATCH 5/7] doc: add to changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e76f31e26..d1b5edf61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Change Log +- Add [ASGI TLS Extension version: 0.2](https://github.com/jonfoster/asgiref/blob/master/specs/tls.rst) to h11, httptools, websockets, and wsproto impl (#1119) + ## 0.20.0 - 2022-11-20 ### Added From 3b1766633d30f2d06426c02dcfccd50b12928b47 Mon Sep 17 00:00:00 2001 From: Paul Brussee Date: Tue, 10 Jan 2023 21:33:03 +0100 Subject: [PATCH 6/7] type: ignore type at assignment --- uvicorn/protocols/http/h11_impl.py | 4 ++-- uvicorn/protocols/http/httptools_impl.py | 2 +- uvicorn/protocols/websockets/websockets_impl.py | 3 ++- uvicorn/protocols/websockets/wsproto_impl.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index a2e0f6189..6cb1aba54 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -233,7 +233,7 @@ def handle_events(self) -> None: "http_version": event.http_version.decode("ascii"), "server": self.server, "client": self.client, - "scheme": self.scheme, + "scheme": self.scheme, # type: ignore[typeddict-item] "method": event.method.decode("ascii"), "root_path": self.root_path, "path": unquote(raw_path.decode("ascii")), @@ -243,7 +243,7 @@ def handle_events(self) -> None: "extensions": {}, } if self.scheme == "https": - self.scope["extensions"]["tls"] = self.tls + self.scope["extensions"]["tls"] = self.tls # type: ignore upgrade = self._get_upgrade() if upgrade == b"websocket" and self._should_upgrade_to_ws(): diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index b85d81afb..bc545ab92 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -258,7 +258,7 @@ def on_message_begin(self) -> None: "extensions": {}, } if self.scheme == "https": - self.scope["extensions"]["tls"] = self.tls + self.scope["extensions"]["tls"] = self.tls # type: ignore # Parser callbacks def on_url(self, url: bytes) -> None: diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 1335733cd..36e1f02ee 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -211,7 +211,8 @@ async def process_request( "extensions": {}, } if self.scheme == "wss": - self.scope["extensions"]["tls"] = self.tls + self.scope["extensions"]["tls"] = self.tls # type: ignore + task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) self.tasks.add(task) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index ed13b0c9c..a373509af 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -194,7 +194,8 @@ def handle_connect(self, event: events.Request) -> None: "extensions": {}, } if self.scheme == "wss": - self.scope["extensions"]["tls"] = self.tls + self.scope["extensions"]["tls"] = self.tls # type: ignore + self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) From 0f92ca796bdcd4fdafcaec68c4b22dc84672d189 Mon Sep 17 00:00:00 2001 From: Paul Brussee Date: Tue, 10 Jan 2023 22:09:51 +0100 Subject: [PATCH 7/7] style: remove newlines --- tests/conftest.py | 3 --- uvicorn/protocols/http/h11_impl.py | 2 -- uvicorn/protocols/http/httptools_impl.py | 2 -- uvicorn/protocols/utils.py | 2 -- uvicorn/protocols/websockets/websockets_impl.py | 3 --- uvicorn/protocols/websockets/wsproto_impl.py | 3 --- 6 files changed, 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1bc6d1d97..9485b1d82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ import os import socket import ssl - from copy import deepcopy from hashlib import md5 from pathlib import Path @@ -13,13 +12,11 @@ import pytest import trustme - from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from uvicorn.config import LOGGING_CONFIG - # Note: We explicitly turn the propagate on just for tests, because pytest # caplog not able to capture no-propagate loggers. # diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 6cb1aba54..430b9ba74 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -2,7 +2,6 @@ import http import logging import sys - from typing import ( TYPE_CHECKING, Any, @@ -36,7 +35,6 @@ ) from uvicorn.server import ServerState - if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index bc545ab92..64bd6f899 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -4,7 +4,6 @@ import re import sys import urllib - from asyncio.events import TimerHandle from collections import deque from typing import ( @@ -40,7 +39,6 @@ ) from uvicorn.server import ServerState - if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index 9b22d9fa7..96b9ac09d 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -1,10 +1,8 @@ import asyncio import ssl import urllib.parse - from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple - if TYPE_CHECKING: from asgiref.typing import WWWScope diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 36e1f02ee..6142dca9e 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -2,7 +2,6 @@ import http import logging import sys - from typing import ( TYPE_CHECKING, Any, @@ -17,7 +16,6 @@ from urllib.parse import unquote import websockets - from websockets.datastructures import Headers from websockets.exceptions import ConnectionClosed from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory @@ -36,7 +34,6 @@ ) from uvicorn.server import ServerState - if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal else: # pragma: py-lt-38 diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a373509af..951f13e07 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -2,11 +2,9 @@ import logging import sys import typing - from urllib.parse import unquote import wsproto - from wsproto import ConnectionType, events from wsproto.connection import ConnectionState from wsproto.extensions import Extension, PerMessageDeflate @@ -23,7 +21,6 @@ ) from uvicorn.server import ServerState - if typing.TYPE_CHECKING: from asgiref.typing import ( ASGISendEvent,