From 44aee612a48c5bd0b2f9f4a1140216dd9bc2b413 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:08:01 -1000 Subject: [PATCH 01/20] Add support for passing multiple addresses to the client If we have multiple IP addresses for the ESPHome device, and we do not know which one we should connect to, they should be passed as `addresses` when creating the `APIClient` --- aioesphomeapi/client.py | 20 ++++++++++++++------ aioesphomeapi/connection.pxd | 3 ++- aioesphomeapi/connection.py | 19 ++++++++++++------- aioesphomeapi/host_resolver.py | 24 +++++++++++++----------- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 4691ecfd..007bc117 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -220,6 +220,7 @@ def __init__( zeroconf_instance: ZeroconfInstanceType | None = None, noise_psk: str | None = None, expected_name: str | None = None, + addresses: list[str] | None = None, ) -> None: """Create a client, this object is shared across sessions. @@ -235,10 +236,14 @@ def __init__( :param expected_name: Require the devices name to match the given expected name. Can be used to prevent accidentally connecting to a different device if IP passed as address but DHCP reassigned IP. + :param addresses: Optional list of IP addresses to connect to which takes + precedence over the address parameter. This is most commonly used when + the device has dual stack IPv4 and IPv6 addresses and you do not know + which one to connect to. """ self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) self._params = ConnectionParams( - address=str(address), + addresses=addresses if addresses else [str(address)], port=port, password=password, client_info=client_info, @@ -274,17 +279,20 @@ def expected_name(self, value: str | None) -> None: @property def address(self) -> str: - return self._params.address + return self._params.addresses[0] def _set_log_name(self) -> None: """Set the log name of the device.""" - resolved_address: str | None = None - if self._connection and self._connection.resolved_addr_info: - resolved_address = self._connection.resolved_addr_info[0].sockaddr.address + ip_address: str | None = None + if self._connection: + if self._connection.connected_address: + ip_address = self._connection.connected_address + elif self._connection.resolved_addr_info: + ip_address = self._connection.resolved_addr_info[0].sockaddr.address self.log_name = build_log_name( self.cached_name, self.address, - resolved_address, + ip_address, ) if self._connection: self._connection.set_log_name(self.log_name) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 62aa99c3..af556aed 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -74,7 +74,7 @@ cdef object _handle_complex_message @cython.dataclasses.dataclass cdef class ConnectionParams: - cdef public str address + cdef public list addresses cdef public object port cdef public object password cdef public object client_info @@ -109,6 +109,7 @@ cdef class APIConnection: cdef bint _debug_enabled cdef public str received_name cdef public object resolved_addr_info + cdef public str connected_address cpdef void send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index b97565c8..efd9cf49 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -107,7 +107,7 @@ @dataclass class ConnectionParams: - address: str + addresses: list[str] port: int password: str | None client_info: str @@ -208,6 +208,7 @@ class APIConnection: "_debug_enabled", "received_name", "resolved_addr_info", + "connected_address", ) def __init__( @@ -230,7 +231,7 @@ def __init__( # Message handlers currently subscribed to incoming messages self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} # The friendly name to show for this connection in the logs - self.log_name = log_name or params.address + self.log_name = log_name or params.addresses # futures currently subscribed to exceptions in the read task self._read_exception_futures: set[asyncio.Future[None]] = set() @@ -252,6 +253,7 @@ def __init__( self._debug_enabled = debug_enabled self.received_name: str = "" self.resolved_addr_info: list[hr.AddrInfo] = [] + self.connected_address: str | None = None def set_log_name(self, name: str) -> None: """Set the friendly log name for this connection.""" @@ -325,7 +327,7 @@ async def _connect_resolve_host(self) -> list[hr.AddrInfo]: try: async with asyncio_timeout(RESOLVE_TIMEOUT): return await hr.async_resolve_host( - self._params.address, + self._params.addresses, self._params.port, self._params.zeroconf_manager, ) @@ -340,7 +342,7 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: _LOGGER.debug( "%s: Connecting to %s:%s (%s)", self.log_name, - self._params.address, + self._params.addresses, self._params.port, addrs, ) @@ -350,7 +352,7 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: addr.family, addr.type, addr.proto, - self._params.address, + "", astuple(addr.sockaddr), ) for addr in addrs @@ -361,9 +363,11 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: while addr_infos: try: async with asyncio_timeout(TCP_CONNECT_TIMEOUT): + # Devices are likely on the local network so we + # only use a 100ms happy eyeballs delay sock = await aiohappyeyeballs.start_connection( addr_infos, - happy_eyeballs_delay=0.25, + happy_eyeballs_delay=0.1, interleave=interleave, loop=self._loop, ) @@ -387,12 +391,13 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: # Try to reduce the pressure on esphome device as it measures # ram in bytes and we measure ram in megabytes. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + self.connected_address = sock.getpeername()[0] if self._debug_enabled: _LOGGER.debug( "%s: Opened socket to %s:%s (%s)", self.log_name, - self._params.address, + self.connected_address, self._params.port, addrs, ) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 153f6efb..bc873df0 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import logging import socket from dataclasses import dataclass @@ -181,14 +180,23 @@ def _async_ip_address_to_addrs( async def async_resolve_host( - host: str, + hosts: list[str], port: int, zeroconf_manager: ZeroconfManager | None = None, ) -> list[AddrInfo]: addrs: list[AddrInfo] = [] + zc_error: Exception | None = None + + for host in hosts: + host_is_name = host_is_name_part(host) or address_is_local(host) + + if not host_is_name: + try: + addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + except ValueError: + # Not an IP address + continue - zc_error = None - if host_is_name_part(host) or address_is_local(host): name = host.partition(".")[0] try: addrs.extend( @@ -198,13 +206,7 @@ async def async_resolve_host( ) except ResolveAPIError as err: zc_error = err - - else: - with contextlib.suppress(ValueError): - addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) - - if not addrs: - addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) if not addrs: if zc_error: From e762ec3e1e72d828cd3ef7e78bdc9e940bb3a00c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:16:03 -1000 Subject: [PATCH 02/20] tweaks --- aioesphomeapi/host_resolver.py | 28 +++++++++++++++++----------- tests/conftest.py | 2 +- tests/test_host_resolver.py | 14 +++++++------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index bc873df0..d387ca08 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -188,25 +188,31 @@ async def async_resolve_host( zc_error: Exception | None = None for host in hosts: + host_addrs: list[AddrInfo] = [] host_is_name = host_is_name_part(host) or address_is_local(host) + if host_is_name: + name = host.partition(".")[0] + try: + host_addrs.extend( + await _async_resolve_host_zeroconf( + name, port, zeroconf_manager=zeroconf_manager + ) + ) + except ResolveAPIError as err: + zc_error = err + if not host_is_name: try: - addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) + host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) except ValueError: # Not an IP address continue - name = host.partition(".")[0] - try: - addrs.extend( - await _async_resolve_host_zeroconf( - name, port, zeroconf_manager=zeroconf_manager - ) - ) - except ResolveAPIError as err: - zc_error = err - addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + if not host_addrs: + host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) + + addrs.extend(host_addrs) if not addrs: if zc_error: diff --git a/tests/conftest.py b/tests/conftest.py index 35bb2644..cc5f489d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,7 +71,7 @@ class PatchableAPIClient(APIClient): def get_mock_connection_params() -> ConnectionParams: return ConnectionParams( - address="fake.address", + addresses=["fake.address"], port=6052, password=None, client_info="Tests client", diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 322f3789..d492cb5a 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", side_effect=Exception("no buffers"), ), pytest.raises(ResolveAPIError, match="no buffers"): - await hr.async_resolve_host("asdf.local", 6052) + await hr.async_resolve_host(["asdf.local"], 6052) @pytest.mark.asyncio @@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = addr_infos - ret = await hr.async_resolve_host("example.local", 6052) + ret = await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_not_called() @@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): resolve_zc.return_value = [] resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("example.local", 6052) + ret = await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None) resolve_addr.assert_called_once_with("example.local", 6052) @@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): resolve_addr.return_value = addr_infos with pytest.raises(ResolveAPIError): - await hr.async_resolve_host("example.local", 6052) + await hr.async_resolve_host(["example.local"], 6052) @pytest.mark.asyncio @@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("example.com", 6052) + ret = await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) @@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = [] with pytest.raises(APIConnectionError): - await hr.async_resolve_host("example.com", 6052) + await hr.async_resolve_host(["example.local"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) @@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos) async def test_resolve_host_with_address(resolve_addr, resolve_zc): resolve_zc.return_value = [] resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host("127.0.0.1", 6052) + ret = await hr.async_resolve_host(["127.0.0.1"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_not_called() From ba54643355445582412d8f1470a57e6ec01b528a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:16:48 -1000 Subject: [PATCH 03/20] tweaks --- aioesphomeapi/host_resolver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index d387ca08..2acf8882 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -189,9 +189,9 @@ async def async_resolve_host( for host in hosts: host_addrs: list[AddrInfo] = [] - host_is_name = host_is_name_part(host) or address_is_local(host) + host_is_local_name = host_is_name_part(host) or address_is_local(host) - if host_is_name: + if host_is_local_name: name = host.partition(".")[0] try: host_addrs.extend( @@ -202,7 +202,7 @@ async def async_resolve_host( except ResolveAPIError as err: zc_error = err - if not host_is_name: + if not host_is_local_name: try: host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) except ValueError: From ddacc3e41249179d2f0eea4b55cb94d28b8d4413 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:17:24 -1000 Subject: [PATCH 04/20] tweaks --- aioesphomeapi/host_resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 2acf8882..14bf8b05 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -207,7 +207,7 @@ async def async_resolve_host( host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port)) except ValueError: # Not an IP address - continue + pass if not host_addrs: host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port)) From ee41f045d0ea0c3f4e4699ee455e975068f205ed Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:28:53 -1000 Subject: [PATCH 05/20] fix mocking --- tests/conftest.py | 8 ++++++-- tests/test_client.py | 25 +++++++++++++++++++------ tests/test_connection.py | 3 ++- tests/test_host_resolver.py | 4 ++-- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cc5f489d..92aa7ab5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from dataclasses import replace from functools import partial from typing import Callable -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest import pytest_asyncio @@ -119,7 +119,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio @pytest.fixture() def aiohappyeyeballs_start_connection(): with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: - func.return_value = MagicMock(type=socket.SOCK_STREAM) + mock_socket = Mock() + mock_socket.type = socket.SOCK_STREAM + mock_socket.fileno.return_value = 1 + mock_socket.getpeername.return_value = ("10.0.0.512", 323) + func.return_value = mock_socket yield func diff --git a/tests/test_client.py b/tests/test_client.py index 313559de..e7686a0f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,9 +4,10 @@ import contextlib import itertools import logging +import socket from functools import partial from typing import Any -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch import pytest from google.protobuf import message @@ -219,9 +220,15 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: cli = APIClient("1.2.3.4", 1234, None) asyncio.get_event_loop() + async def _start_connection_with_delay(*args, **kwargs): + await asyncio.sleep(1) + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = ("4.3.3.3", 323) + return mock_socket + with patch( "aioesphomeapi.connection.aiohappyeyeballs.start_connection", - side_effect=partial(asyncio.sleep, 1), + _start_connection_with_delay, ): start_task = asyncio.create_task(cli.start_connection()) await asyncio.sleep(0) @@ -232,8 +239,14 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: await start_task assert cli._connection is None + async def _start_connection_without_delay(*args, **kwargs): + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = ("4.3.3.3", 323) + return mock_socket + with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection" + "aioesphomeapi.connection.aiohappyeyeballs.start_connection", + _start_connection_without_delay, ): await cli.start_connection() await asyncio.sleep(0) @@ -894,7 +907,7 @@ class PatchableAPIClient(APIClient): ) # Make sure its not a subclassed string assert type(cli._params.noise_psk) is str - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert type(cli._params.expected_name) is str rl = ReconnectLogic( @@ -930,7 +943,7 @@ async def test_no_noise_psk(): ) # Make sure its not a subclassed string assert cli._params.noise_psk is None - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert type(cli._params.expected_name) is str @@ -945,7 +958,7 @@ async def test_empty_noise_psk_or_expected_name(): expected_name="", ) assert cli._params.noise_psk is None - assert type(cli._params.address) is str + assert type(cli._params.addresses[0]) is str assert cli._params.expected_name is None diff --git a/tests/test_connection.py b/tests/test_connection.py index 73cee88a..d3ebf65e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -571,7 +571,8 @@ async def test_connect_resolver_times_out( "create_connection", side_effect=partial(_create_mock_transport_protocol, transport, connected), ), pytest.raises( - ResolveAPIError, match="Timeout while resolving IP address for fake.address" + ResolveAPIError, + match=r"Timeout while resolving IP address for \['fake.address'\]", ): await connect(conn, login=False) diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index d492cb5a..82f10950 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos): @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo") async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = addr_infos - ret = await hr.async_resolve_host(["example.local"], 6052) + ret = await hr.async_resolve_host(["example.com"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) @@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos): async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos): resolve_addr.return_value = [] with pytest.raises(APIConnectionError): - await hr.async_resolve_host(["example.local"], 6052) + await hr.async_resolve_host(["example.com"], 6052) resolve_zc.assert_not_called() resolve_addr.assert_called_once_with("example.com", 6052) From af7e966e989f5181340dec07ab5c9a4e41d82ebf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:31:06 -1000 Subject: [PATCH 06/20] fix mocking --- aioesphomeapi/host_resolver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 14bf8b05..14e34c92 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -218,6 +218,8 @@ async def async_resolve_host( if zc_error: # Only show ZC error if getaddrinfo also didn't work raise zc_error - raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS") + raise ResolveAPIError( + f"Could not resolve host {hosts} - got no results from OS" + ) return addrs From 7c753c768b71a8c45fefdcfe5dfbe4e1404d4404 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:32:20 -1000 Subject: [PATCH 07/20] fix mocking --- aioesphomeapi/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index efd9cf49..20b25bbe 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -231,7 +231,7 @@ def __init__( # Message handlers currently subscribed to incoming messages self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {} # The friendly name to show for this connection in the logs - self.log_name = log_name or params.addresses + self.log_name = log_name or ",".join(params.addresses) # futures currently subscribed to exceptions in the read task self._read_exception_futures: set[asyncio.Future[None]] = set() From 3023767440e61364a2413e3cbd405aeb52b07237 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:34:17 -1000 Subject: [PATCH 08/20] revert --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index d3ebf65e..e66d7d4f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -572,7 +572,7 @@ async def test_connect_resolver_times_out( side_effect=partial(_create_mock_transport_protocol, transport, connected), ), pytest.raises( ResolveAPIError, - match=r"Timeout while resolving IP address for \['fake.address'\]", + match="Timeout while resolving IP address for fake.address", ): await connect(conn, login=False) From 36a350ee7f6b0e8a4adb258e650114bda856f879 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:37:19 -1000 Subject: [PATCH 09/20] remove unneeded code --- aioesphomeapi/client.py | 11 ++++------- aioesphomeapi/connection.pxd | 1 - aioesphomeapi/connection.py | 5 +---- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 007bc117..2433eaa5 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -284,11 +284,8 @@ def address(self) -> str: def _set_log_name(self) -> None: """Set the log name of the device.""" ip_address: str | None = None - if self._connection: - if self._connection.connected_address: - ip_address = self._connection.connected_address - elif self._connection.resolved_addr_info: - ip_address = self._connection.resolved_addr_info[0].sockaddr.address + if self._connection and self._connection.connected_address: + ip_address = self._connection.connected_address self.log_name = build_log_name( self.cached_name, self.address, @@ -336,8 +333,8 @@ async def start_connection( self.log_name, ) await self._execute_connection_coro(self._connection.start_connection()) - # If we resolved the address, we should set the log name now - if self._connection.resolved_addr_info: + # If we connected, we should set the log name now + if self._connection.connected_address: self._set_log_name() async def finish_connection( diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index af556aed..4c356b59 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -108,7 +108,6 @@ cdef class APIConnection: cdef bint _handshake_complete cdef bint _debug_enabled cdef public str received_name - cdef public object resolved_addr_info cdef public str connected_address cpdef void send_message(self, object msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 20b25bbe..f134b754 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -207,7 +207,6 @@ class APIConnection: "_handshake_complete", "_debug_enabled", "received_name", - "resolved_addr_info", "connected_address", ) @@ -252,7 +251,6 @@ def __init__( self._handshake_complete = False self._debug_enabled = debug_enabled self.received_name: str = "" - self.resolved_addr_info: list[hr.AddrInfo] = [] self.connected_address: str | None = None def set_log_name(self, name: str) -> None: @@ -572,8 +570,7 @@ def _async_pong_not_received(self) -> None: async def _do_connect(self) -> None: """Do the actual connect process.""" - self.resolved_addr_info = await self._connect_resolve_host() - await self._connect_socket_connect(self.resolved_addr_info) + await self._connect_socket_connect(await self._connect_resolve_host()) async def start_connection(self) -> None: """Start the connection process. From fe22b360972c289e1371d2461a9ead81ac0e6662 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:40:32 -1000 Subject: [PATCH 10/20] remove unneeded code --- aioesphomeapi/client.py | 2 +- aioesphomeapi/util.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 2433eaa5..fd706bfb 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -288,7 +288,7 @@ def _set_log_name(self) -> None: ip_address = self._connection.connected_address self.log_name = build_log_name( self.cached_name, - self.address, + self._params.addresses, ip_address, ) if self._connection: diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index ae226bb4..3dcbd8ad 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -36,11 +36,15 @@ def address_is_local(address: str) -> bool: return address.removesuffix(".").endswith(".local") -def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str: +def build_log_name( + name: str | None, addresses: list[str], connected_address: str | None +) -> str: """Return a log name for a connection.""" - if not name and address_is_local(address) or host_is_name_part(address): - name = address.partition(".")[0] - preferred_address = resolved_address or address + for address in addresses: + if not name and address_is_local(address) or host_is_name_part(address): + name = address.partition(".")[0] + break + preferred_address = connected_address or address if ( name and name != preferred_address From 0820ed45989e8ab6d28d358d54d9a2bc08f9eb39 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:41:01 -1000 Subject: [PATCH 11/20] remove unneeded code --- aioesphomeapi/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index fd706bfb..cc699eed 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -283,13 +283,13 @@ def address(self) -> str: def _set_log_name(self) -> None: """Set the log name of the device.""" - ip_address: str | None = None + connected_address: str | None = None if self._connection and self._connection.connected_address: - ip_address = self._connection.connected_address + connected_address = self._connection.connected_address self.log_name = build_log_name( self.cached_name, self._params.addresses, - ip_address, + connected_address, ) if self._connection: self._connection.set_log_name(self.log_name) From 2805bf78a328eb4a32c045a3e2d586af22de57d9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:42:08 -1000 Subject: [PATCH 12/20] remove unneeded code --- aioesphomeapi/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index f134b754..7925edf5 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -393,11 +393,10 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: if self._debug_enabled: _LOGGER.debug( - "%s: Opened socket to %s:%s (%s)", + "%s: Opened socket to %s:%s", self.log_name, self.connected_address, self._params.port, - addrs, ) async def _connect_init_frame_helper(self) -> None: From e7c0e669ca61081b817379d68d89cb2544a8c90a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:44:31 -1000 Subject: [PATCH 13/20] remove unneeded code --- aioesphomeapi/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index 3dcbd8ad..ad3399e3 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -40,11 +40,12 @@ def build_log_name( name: str | None, addresses: list[str], connected_address: str | None ) -> str: """Return a log name for a connection.""" + preferred_address = connected_address for address in addresses: if not name and address_is_local(address) or host_is_name_part(address): name = address.partition(".")[0] - break - preferred_address = connected_address or address + elif not preferred_address: + preferred_address = address if ( name and name != preferred_address From 033ccc85ff311de2d95c7c32bc49eff3a9a29ef8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:46:26 -1000 Subject: [PATCH 14/20] fixes --- aioesphomeapi/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index ad3399e3..0ca8dae8 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -46,6 +46,8 @@ def build_log_name( name = address.partition(".")[0] elif not preferred_address: preferred_address = address + if not preferred_address: + return name or addresses[0] if ( name and name != preferred_address From 10c76a76b2cde47c978b85701d6fa10b32c54f35 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 10:50:55 -1000 Subject: [PATCH 15/20] logging tweaks --- aioesphomeapi/connection.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 7925edf5..d371b6f3 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -338,11 +338,9 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: """Step 2 in connect process: connect the socket.""" if self._debug_enabled: _LOGGER.debug( - "%s: Connecting to %s:%s (%s)", + "%s: Connecting to %s", self.log_name, - self._params.addresses, - self._params.port, - addrs, + ", ".join(str(addr.sockaddr) for addr in addrs), ) addr_infos: list[aiohappyeyeballs.AddrInfoType] = [ From df4366b4e87b68154981d353ae0f27d4dfa85f62 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 11:00:26 -1000 Subject: [PATCH 16/20] fix test --- tests/test__frame_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index f9f28183..79d2c5e1 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -199,7 +199,8 @@ def mock_write_frame(self, frame: bytes) -> None: ), ], ) -def test_plaintext_frame_helper( +@pytest.mark.asyncio +async def test_plaintext_frame_helper( in_bytes: bytes, pkt_data: bytes, pkt_type: int ) -> None: for _ in range(3): From 04dcd13b2731c2c0915ce627fbe917af8738e71b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 11:04:17 -1000 Subject: [PATCH 17/20] debug tests --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 55539dac..0a3bed95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,7 +86,7 @@ jobs: - run: mypy aioesphomeapi name: Check typing with mypy if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }} - - run: pytest -vv --cov=aioesphomeapi --cov-report=xml --tb=native tests + - run: pytest -vvvs --cov=aioesphomeapi --cov-report=xml --tb=native tests name: Run tests with pytest - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 From 38d4d1d1c38f9771dc2b38b1be5c148e6eb8e067 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 11:08:37 -1000 Subject: [PATCH 18/20] fix socket mocking --- tests/conftest.py | 15 +++------------ tests/test_client.py | 3 ++- tests/test_connection.py | 20 +++++++------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 92aa7ab5..4e25771f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from dataclasses import replace from functools import partial from typing import Callable -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest import pytest_asyncio @@ -50,12 +50,6 @@ def resolve_host(): yield func -@pytest.fixture -def socket_socket(): - with patch("socket.socket") as func: - yield func - - @pytest.fixture def patchable_api_client() -> APIClient: class PatchableAPIClient(APIClient): @@ -119,7 +113,7 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio @pytest.fixture() def aiohappyeyeballs_start_connection(): with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func: - mock_socket = Mock() + mock_socket = create_autospec(socket.socket, spec_set=True, instance=True) mock_socket.type = socket.SOCK_STREAM mock_socket.fileno.return_value = 1 mock_socket.getpeername.return_value = ("10.0.0.512", 323) @@ -143,7 +137,6 @@ def _create_mock_transport_protocol( async def plaintext_connect_task_no_login( conn: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -165,7 +158,6 @@ async def plaintext_connect_task_no_login( async def plaintext_connect_task_no_login_with_expected_name( conn_with_expected_name: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -188,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name( async def plaintext_connect_task_with_login( conn_with_password: APIConnection, resolve_host, - socket_socket, event_loop, aiohappyeyeballs_start_connection, ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: @@ -207,7 +198,7 @@ async def plaintext_connect_task_with_login( @pytest_asyncio.fixture(name="api_client") async def api_client( - resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection + resolve_host, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]: protocol: APIPlaintextFrameHelper | None = None transport = MagicMock() diff --git a/tests/test_client.py b/tests/test_client.py index e7686a0f..70fe312f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -170,7 +170,8 @@ def patch_api_version(client: APIClient, version: APIVersion): client._connection.api_version = version -def test_expected_name(auth_client: APIClient) -> None: +@pytest.mark.asyncio +async def test_expected_name(auth_client: APIClient) -> None: """Ensure expected name can be set externally.""" assert auth_client.expected_name is None auth_client.expected_name = "awesome" diff --git a/tests/test_connection.py b/tests/test_connection.py index e66d7d4f..fd4d255e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -221,7 +221,7 @@ def on_msg(msg): @pytest.mark.asyncio async def test_start_connection_socket_error( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of socket error during start connection.""" loop = asyncio.get_event_loop() @@ -238,7 +238,7 @@ async def test_start_connection_socket_error( @pytest.mark.asyncio async def test_start_connection_times_out( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of start connection timing out.""" asyncio.get_event_loop() @@ -264,9 +264,7 @@ async def _mock_socket_connect(*args, **kwargs): @pytest.mark.asyncio -async def test_start_connection_os_error( - conn: APIConnection, resolve_host, socket_socket -): +async def test_start_connection_os_error(conn: APIConnection, resolve_host): """Test handling of start connection has an OSError.""" asyncio.get_event_loop() @@ -284,9 +282,7 @@ async def test_start_connection_os_error( @pytest.mark.asyncio -async def test_start_connection_is_cancelled( - conn: APIConnection, resolve_host, socket_socket -): +async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host): """Test handling of start connection is cancelled.""" asyncio.get_event_loop() @@ -305,7 +301,7 @@ async def test_start_connection_is_cancelled( @pytest.mark.asyncio async def test_finish_connection_is_cancelled( - conn: APIConnection, resolve_host, socket_socket + conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection ): """Test handling of finishing connection being cancelled.""" loop = asyncio.get_event_loop() @@ -368,7 +364,7 @@ def on_msg(msg): async def test_plaintext_connection_fails_handshake( conn: APIConnection, resolve_host: AsyncMock, - socket_socket: MagicMock, + aiohappyeyeballs_start_connection: MagicMock, exception_map: tuple[Exception, Exception], ) -> None: """Test that the frame helper is closed before the underlying socket. @@ -558,7 +554,7 @@ async def test_force_disconnect_fails( @pytest.mark.asyncio async def test_connect_resolver_times_out( - conn: APIConnection, socket_socket, event_loop, aiohappyeyeballs_start_connection + conn: APIConnection, event_loop, aiohappyeyeballs_start_connection ) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]: transport = MagicMock() connected = asyncio.Event() @@ -582,7 +578,6 @@ async def test_disconnect_fails_to_send_response( connection_params: ConnectionParams, event_loop: asyncio.AbstractEventLoop, resolve_host, - socket_socket, aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() @@ -633,7 +628,6 @@ async def test_disconnect_success_case( connection_params: ConnectionParams, event_loop: asyncio.AbstractEventLoop, resolve_host, - socket_socket, aiohappyeyeballs_start_connection, ) -> None: loop = asyncio.get_event_loop() From 3f743f365a17739a58959b54e0af5a501ea7d0db Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 11:10:43 -1000 Subject: [PATCH 19/20] more missing mocking --- tests/test__frame_helper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 79d2c5e1..640f48cf 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -593,7 +593,9 @@ def _writer(data: bytes): @pytest.mark.asyncio -async def test_init_plaintext_with_wrong_preamble(conn: APIConnection): +async def test_init_plaintext_with_wrong_preamble( + conn: APIConnection, aiohappyeyeballs_start_connection +): loop = asyncio.get_event_loop() protocol = get_mock_protocol(conn) with patch.object(loop, "create_connection") as create_connection: From de6242f3b8cea68353615d1c320124768bd1b563 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 12 Dec 2023 11:11:28 -1000 Subject: [PATCH 20/20] Update .github/workflows/ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a3bed95..55539dac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,7 +86,7 @@ jobs: - run: mypy aioesphomeapi name: Check typing with mypy if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }} - - run: pytest -vvvs --cov=aioesphomeapi --cov-report=xml --tb=native tests + - run: pytest -vv --cov=aioesphomeapi --cov-report=xml --tb=native tests name: Run tests with pytest - name: Upload coverage to Codecov uses: codecov/codecov-action@v3