Skip to content

Commit

Permalink
feat: add tls extension to websocket protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
brussee committed Jan 10, 2023
1 parent fbc2ce6 commit 9fdcfde
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
15 changes: 14 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import http
import logging
import sys
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union, cast

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from urllib.parse import unquote

import websockets

from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
Expand All @@ -19,10 +21,12 @@
get_local_addr,
get_path_with_query_string,
get_remote_addr,
get_tls_info,
is_ssl,
)
from uvicorn.server import ServerState


if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else: # pragma: py-lt-38
Expand Down Expand Up @@ -80,6 +84,7 @@ def __init__(
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
self.tls: Optional[Dict[str, Any]]

# Connection events
self.scope: WebSocketScope = None # type: ignore[assignment]
Expand Down Expand Up @@ -121,6 +126,11 @@ def connection_made( # type: ignore[override]
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.scheme == "wss":
self.tls = get_tls_info(transport)
if not self.tls["server_cert"]:
self.tls["server_cert"] = self.config.ssl_cert_pem

if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
Expand Down Expand Up @@ -190,7 +200,10 @@ async def process_request(
"query_string": query_string.encode("ascii"),
"headers": asgi_headers,
"subprotocols": subprotocols,
"extensions": {},
}
if self.scheme == "wss":
self.scope["extensions"]["tls"] = self.tls
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
Expand Down
13 changes: 12 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import logging
import sys
import typing

from urllib.parse import unquote

import wsproto

from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
from wsproto.extensions import Extension, PerMessageDeflate
Expand All @@ -16,10 +18,12 @@
get_local_addr,
get_path_with_query_string,
get_remote_addr,
get_tls_info,
is_ssl,
)
from uvicorn.server import ServerState


if typing.TYPE_CHECKING:
from asgiref.typing import (
ASGISendEvent,
Expand Down Expand Up @@ -97,6 +101,11 @@ def connection_made( # type: ignore[override]
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.scheme == "wss":
self.tls = get_tls_info(transport)
if not self.tls["server_cert"]:
self.tls["server_cert"] = self.config.ssl_cert_pem

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
Expand Down Expand Up @@ -183,8 +192,10 @@ def handle_connect(self, event: events.Request) -> None:
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.subprotocols,
"extensions": None,
"extensions": {},
}
if self.scheme == "wss":
self.scope["extensions"]["tls"] = self.tls
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
Expand Down

0 comments on commit 9fdcfde

Please sign in to comment.