Skip to content

Commit

Permalink
Add happy eyeballs support (RFC 8305) (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 12, 2023
1 parent 280b9a7 commit 05ee53c
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 91 deletions.
2 changes: 1 addition & 1 deletion aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _set_log_name(self) -> None:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info.sockaddr.address
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
self.log_name = build_log_name(
self.cached_name,
self.address,
Expand Down
69 changes: 48 additions & 21 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable

import aiohappyeyeballs
from google.protobuf import message

import aioesphomeapi.host_resolver as hr
Expand Down Expand Up @@ -250,7 +251,7 @@ def __init__(
self._handshake_complete = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: hr.AddrInfo | None = None
self.resolved_addr_info: list[hr.AddrInfo] = []

def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
Expand Down Expand Up @@ -319,7 +320,7 @@ def set_debug(self, enable: bool) -> None:
"""Enable or disable debug logging."""
self._debug_enabled = enable

async def _connect_resolve_host(self) -> hr.AddrInfo:
async def _connect_resolve_host(self) -> list[hr.AddrInfo]:
"""Step 1 in connect process: resolve the address."""
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
Expand All @@ -333,41 +334,67 @@ async def _connect_resolve_host(self) -> hr.AddrInfo:
f"Timeout while resolving IP address for {self.log_name}"
) from err

async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
"""Step 2 in connect process: connect the socket."""
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)

if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
addrs,
)
sockaddr = astuple(addr.sockaddr)

try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
await self._loop.sock_connect(sock, sockaddr)
except asyncio_TimeoutError as err:
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
(
addr.family,
addr.type,
addr.proto,
self._params.address,
astuple(addr.sockaddr),
)
for addr in addrs
]
last_exception: Exception | None = None
sock: socket.socket | None = None
interleave = 1
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
interleave=interleave,
loop=self._loop,
)
break
except (OSError, asyncio_TimeoutError) as err:
last_exception = err
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)

if sock is None:
if isinstance(last_exception, asyncio_TimeoutError):
raise TimeoutAPIError(
f"Timeout while connecting to {addrs}"
) from last_exception
raise SocketAPIError(
f"Error connecting to {addrs}: {last_exception}"
) from last_exception

self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)

if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
addrs,
)

async def _connect_init_frame_helper(self) -> None:
Expand Down
12 changes: 6 additions & 6 deletions aioesphomeapi/host_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ async def _async_resolve_host_zeroconf(
timeout,
)
addrs: list[AddrInfo] = []
for ip in info.ip_addresses_by_version(IPVersion.All):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore[arg-type]
for ip in info.ip_addresses_by_version(IPVersion.V6Only):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
for ip in info.ip_addresses_by_version(IPVersion.V4Only):
addrs.extend(_async_ip_address_to_addrs(ip, port)) # type: ignore
return addrs


Expand Down Expand Up @@ -182,7 +184,7 @@ async def async_resolve_host(
host: str,
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> AddrInfo:
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []

zc_error = None
Expand Down Expand Up @@ -210,6 +212,4 @@ async def async_resolve_host(
raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")

# Use first matching result
# Future: return all matches and use first working one
return addrs[0]
return addrs
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiohappyeyeballs>=2.3.0
protobuf>=3.19.0
zeroconf>=0.128.4,<1.0
chacha20poly1305-reuseable>=0.12.0
Expand Down
49 changes: 35 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def async_zeroconf():
@pytest.fixture
def resolve_host():
with patch("aioesphomeapi.host_resolver.async_resolve_host") as func:
func.return_value = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
func.return_value = [
AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr("10.0.0.512", 6052),
)
]
yield func


Expand Down Expand Up @@ -114,6 +116,13 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
return PatchableAPIConnection(connection_params, mock_on_stop, True, None)


@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM)
yield func


def _create_mock_transport_protocol(
transport: asyncio.Transport,
connected: asyncio.Event,
Expand All @@ -128,13 +137,17 @@ def _create_mock_transport_protocol(

@pytest_asyncio.fixture(name="plaintext_connect_task_no_login")
async def plaintext_connect_task_no_login(
conn: APIConnection, resolve_host, socket_socket, event_loop
conn: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()

with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
Expand All @@ -146,12 +159,16 @@ async def plaintext_connect_task_no_login(

@pytest_asyncio.fixture(name="plaintext_connect_task_expected_name")
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection, resolve_host, socket_socket, event_loop
conn_with_expected_name: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()

with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
Expand All @@ -165,12 +182,16 @@ async def plaintext_connect_task_no_login_with_expected_name(

@pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection, resolve_host, socket_socket, event_loop
conn_with_password: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
transport = MagicMock()
connected = asyncio.Event()

with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
Expand All @@ -182,7 +203,7 @@ async def plaintext_connect_task_with_login(

@pytest_asyncio.fixture(name="api_client")
async def api_client(
resolve_host, socket_socket, event_loop
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
Expand All @@ -193,7 +214,7 @@ async def api_client(
password=None,
)

with patch.object(event_loop, "sock_connect"), patch.object(
with patch.object(
event_loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
Expand Down
28 changes: 16 additions & 12 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ class PatchableApiClient(APIClient):


@pytest.mark.asyncio
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> None:
async def test_finish_connection_wraps_exceptions_as_unhandled_api_error(
aiohappyeyeballs_start_connection,
) -> None:
"""Verify finish_connect re-wraps exceptions as UnhandledAPIError."""

cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
asyncio.get_event_loop()
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
await cli.start_connection()

with patch.object(
Expand All @@ -217,9 +217,12 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error() -> No
async def test_connection_released_if_connecting_is_cancelled() -> None:
"""Verify connection is unset if connecting is cancelled."""
cli = APIClient("1.2.3.4", 1234, None)
loop = asyncio.get_event_loop()
asyncio.get_event_loop()

with patch.object(loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)):
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
):
start_task = asyncio.create_task(cli.start_connection())
await asyncio.sleep(0)
assert cli._connection is not None
Expand All @@ -229,9 +232,9 @@ async def test_connection_released_if_connecting_is_cancelled() -> None:
await start_task
assert cli._connection is None

with patch(
"aioesphomeapi.client.APIConnection", PatchableAPIConnection
), patch.object(loop, "sock_connect"):
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection"
):
await cli.start_connection()
await asyncio.sleep(0)

Expand All @@ -252,8 +255,9 @@ class PatchableApiClient(APIClient):
pass

cli = PatchableApiClient("host", 1234, None)
with patch.object(
event_loop, "sock_connect", side_effect=partial(asyncio.sleep, 1)
with patch(
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
side_effect=partial(asyncio.sleep, 1),
), patch.object(cli, "finish_connection"):
connect_task = asyncio.create_task(cli.connect())

Expand Down
Loading

0 comments on commit 05ee53c

Please sign in to comment.