diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 287303e2..455d24a3 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -99,7 +99,10 @@ async def _async_resolve_host_zeroconf( return [] addrs: List[AddrInfo] = [] - for raw in info.addresses_by_version(zeroconf.IPVersion.All): + rawAddrs = info.addresses_by_version( + zeroconf.IPVersion.V6Only + ) + info.addresses_by_version(zeroconf.IPVersion.V4Only) + for raw in rawAddrs: is_ipv6 = len(raw) == 16 sockaddr: Sockaddr if is_ipv6: @@ -135,7 +138,8 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo except OSError as err: raise APIConnectionError(f"Error resolving IP address: {err}") - addrs: List[AddrInfo] = [] + addrsV4: List[AddrInfo] = [] + addrsV6: List[AddrInfo] = [] for family, type_, proto, _, raw in res: sockaddr: Sockaddr if family == socket.AF_INET: @@ -152,10 +156,13 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo # Unknown family continue - addrs.append( - AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr) - ) - return addrs + addrInfo = AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr) + if family == socket.AF_INET6: + addrsV6.append(addrInfo) + else: + addrsV4.append(addrInfo) + + return addrsV6 + addrsV4 def _async_ip_address_to_addrs(host: str, port: int) -> List[AddrInfo]: diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 541461e3..7ace347d 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -1,7 +1,9 @@ import asyncio import socket +from typing import List import pytest +import zeroconf from mock import AsyncMock, MagicMock, patch import aioesphomeapi.host_resolver as hr @@ -17,12 +19,6 @@ def async_zeroconf(): @pytest.fixture def addr_infos(): return [ - hr.AddrInfo( - family=socket.AF_INET, - type=socket.SOCK_STREAM, - proto=socket.IPPROTO_TCP, - sockaddr=hr.IPv4Sockaddr(address="10.0.0.42", port=6052), - ), hr.AddrInfo( family=socket.AF_INET6, type=socket.SOCK_STREAM, @@ -34,16 +30,28 @@ def addr_infos(): scope_id=0, ), ), + hr.AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=hr.IPv4Sockaddr(address="10.0.0.42", port=6052), + ), ] @pytest.mark.asyncio async def test_resolve_host_zeroconf(async_zeroconf, addr_infos): + def address_for_version(version: zeroconf.IPVersion) -> List[bytes]: + if version == zeroconf.IPVersion.V6Only: + return [ + b"\x20\x01\x0d\xb8\x85\xa3\x00\x00\x00\x00\x8a\x2e\x03\x70\x73\x34", + ] + return [ + b"\x0A\x00\x00\x2A", + ] + info = MagicMock() - info.addresses_by_version.return_value = [ - b"\n\x00\x00*", - b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4", - ] + info.addresses_by_version.side_effect = address_for_version async_zeroconf.async_get_service_info = AsyncMock(return_value=info) async_zeroconf.async_close = AsyncMock() @@ -71,18 +79,18 @@ async def test_resolve_host_getaddrinfo(event_loop, addr_infos): with patch.object(event_loop, "getaddrinfo") as mock: mock.return_value = [ ( - socket.AF_INET, + socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, - "canon1", - ("10.0.0.42", 6052), + "canon2", + ("2001:db8:85a3::8a2e:370:7334", 6052, 0, 0), ), ( - socket.AF_INET6, + socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, - "canon2", - ("2001:db8:85a3::8a2e:370:7334", 6052, 0, 0), + "canon1", + ("10.0.0.42", 6052), ), (-1, socket.SOCK_STREAM, socket.IPPROTO_TCP, "canon3", ("10.0.0.42", 6052)), ]