Skip to content

Commit

Permalink
Avoid creating tasks for starting/finishing the connection (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Feb 17, 2024
1 parent 939c529 commit e2bbbf4
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 94 deletions.
8 changes: 6 additions & 2 deletions aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ cdef class APIConnection:
cdef object _pong_timer
cdef float _keep_alive_interval
cdef float _keep_alive_timeout
cdef object _start_connect_task
cdef object _finish_connect_task
cdef object _start_connect_future
cdef object _finish_connect_future
cdef public Exception _fatal_exception
cdef bint _expected_disconnect
cdef object _loop
Expand Down Expand Up @@ -154,3 +154,7 @@ cdef class APIConnection:
cdef void _register_internal_message_handlers(self)

cdef void _increase_recv_buffer_size(self)

cdef void _set_start_connect_future(self)

cdef void _set_finish_connect_future(self)
106 changes: 54 additions & 52 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Any, Callable

import aiohappyeyeballs
from async_interrupt import interrupt
from google.protobuf import message

import aioesphomeapi.host_resolver as hr
Expand Down Expand Up @@ -106,6 +107,10 @@
_float = float


class ConnectionInterruptedError(Exception):
"""An error that is raised when a connection is interrupted."""


@dataclass
class ConnectionParams:
addresses: list[str]
Expand Down Expand Up @@ -198,8 +203,8 @@ class APIConnection:
"_pong_timer",
"_keep_alive_interval",
"_keep_alive_timeout",
"_start_connect_task",
"_finish_connect_task",
"_start_connect_future",
"_finish_connect_future",
"_fatal_exception",
"_expected_disconnect",
"_loop",
Expand Down Expand Up @@ -242,8 +247,8 @@ def __init__(
self._keep_alive_interval = keepalive
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO

self._start_connect_task: asyncio.Task[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None
self._start_connect_future: asyncio.Future[None] | None = None
self._finish_connect_future: asyncio.Future[None] | None = None
self._fatal_exception: Exception | None = None
self._expected_disconnect = False
self._send_pending_ping = False
Expand Down Expand Up @@ -276,28 +281,13 @@ def _cleanup(self) -> None:
err = self._fatal_exception or APIConnectionError("Connection closed")
new_exc = err
if not isinstance(err, APIConnectionError):
new_exc = ReadFailedAPIError("Read failed")
new_exc = ReadFailedAPIError(str(err) or "Read failed")
new_exc.__cause__ = err
fut.set_exception(new_exc)
self._read_exception_futures.clear()
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
current_task = asyncio.current_task()

if (
self._start_connect_task is not None
and self._start_connect_task is not current_task
):
self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None

if (
self._finish_connect_task is not None
and self._finish_connect_task is not current_task
):
self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None
self._set_start_connect_future()
self._set_finish_connect_future()

if self._frame_helper is not None:
self._frame_helper.close()
Expand Down Expand Up @@ -460,7 +450,9 @@ async def _connect_init_frame_helper(self) -> None:
try:
await self._frame_helper.ready_future
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
raise TimeoutAPIError(
f"Handshake timed out after {HANDSHAKE_TIMEOUT}s"
) from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
finally:
Expand All @@ -475,19 +467,14 @@ async def _connect_hello_login(self, login: bool) -> None:
messages.append(self._make_connect_request())
msg_types.append(ConnectResponse)

try:
responses = await self.send_messages_await_response_complex(
tuple(messages),
None,
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
is msg_types[-1],
tuple(msg_types),
CONNECT_REQUEST_TIMEOUT,
)
except TimeoutAPIError as err:
self.report_fatal_error(err)
raise TimeoutAPIError("Hello timed out") from err

responses = await self.send_messages_await_response_complex(
tuple(messages),
None,
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
is msg_types[-1],
tuple(msg_types),
CONNECT_REQUEST_TIMEOUT,
)
resp = responses.pop(0)
self._process_hello_resp(resp)
if login:
Expand Down Expand Up @@ -605,29 +592,37 @@ async def start_connection(self) -> None:
"Connection can only be used once, connection is not in init state"
)

start_connect_task = asyncio.create_task(
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
)
self._start_connect_task = start_connect_task
self._start_connect_future = self._loop.create_future()
try:
await start_connect_task
async with interrupt(
self._start_connect_future, ConnectionInterruptedError, None
):
await self._do_connect()
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
raise self._wrap_fatal_connection_exception("starting", ex)
finally:
self._start_connect_task = None
self._set_start_connect_future()
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)

def _set_start_connect_future(self) -> None:
if (
self._start_connect_future is not None
and not self._start_connect_future.done()
):
self._start_connect_future.set_result(None)
self._start_connect_future = None

def _wrap_fatal_connection_exception(
self, action: str, ex: BaseException
) -> APIConnectionError:
"""Ensure a fatal exception is wrapped as as an APIConnectionError."""
if isinstance(ex, APIConnectionError):
return ex
cause: BaseException | None = None
if isinstance(ex, CancelledError):
if isinstance(ex, (ConnectionInterruptedError, CancelledError)):
err_str = f"{action.title()} connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
Expand Down Expand Up @@ -664,22 +659,29 @@ async def finish_connection(self, *, login: bool) -> None:
raise RuntimeError(
"Connection must be in SOCKET_OPENED state to finish connection"
)
finish_connect_task = asyncio.create_task(
self._do_finish_connect(login),
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
)
self._finish_connect_task = finish_connect_task
self._finish_connect_future = self._loop.create_future()
try:
await self._finish_connect_task
async with interrupt(
self._finish_connect_future, ConnectionInterruptedError, None
):
await self._do_finish_connect(login)
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
raise self._wrap_fatal_connection_exception("finishing", ex)
finally:
self._finish_connect_task = None
self._set_finish_connect_future()
self._set_connection_state(CONNECTION_STATE_CONNECTED)

def _set_finish_connect_future(self) -> None:
if (
self._finish_connect_future is not None
and not self._finish_connect_future.done()
):
self._finish_connect_future.set_result(None)
self._finish_connect_future = None

def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change."""
self.connection_state = state
Expand Down Expand Up @@ -969,12 +971,12 @@ def _handle_get_time_request_internal( # pylint: disable=unused-argument

async def disconnect(self) -> None:
"""Disconnect from the API."""
if self._finish_connect_task is not None:
if self._finish_connect_future is not None:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.
_, pending = await asyncio.wait(
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
[self._finish_connect_future], timeout=DISCONNECT_CONNECT_TIMEOUT
)
if pending:
self._set_fatal_exception_if_unset(
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
aiohappyeyeballs>=2.3.0
async-interrupt>=1.1.1
protobuf>=3.19.0
zeroconf>=0.128.4,<1.0
chacha20poly1305-reuseable>=0.12.1
Expand Down
15 changes: 13 additions & 2 deletions tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
ReadFailedAPIError,
SocketClosedAPIError,
)

Expand Down Expand Up @@ -725,18 +726,28 @@ async def test_eof_received_closes_connection(
await connect_task


@pytest.mark.parametrize(
("exception_map"),
[
(OSError("original message"), ReadFailedAPIError),
(APIConnectionError("original message"), APIConnectionError),
(SocketClosedAPIError("original message"), SocketClosedAPIError),
],
)
@pytest.mark.asyncio
async def test_connection_lost_closes_connection_and_logs(
caplog: pytest.LogCaptureFixture,
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
exception_map: tuple[Exception, Exception],
) -> None:
exception, raised_exception = exception_map
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
protocol.connection_lost(OSError("original message"))
protocol.connection_lost(exception)
assert conn.is_connected is False
assert "original message" in caplog.text
with pytest.raises(APIConnectionError, match="original message"):
with pytest.raises(raised_exception, match="original message"):
await connect_task


Expand Down
Loading

0 comments on commit e2bbbf4

Please sign in to comment.