Skip to content

Commit

Permalink
Add support for happy eyeballs (RFC 8305) (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 12, 2023
1 parent 48ad2b9 commit 1409cd6
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 99 deletions.
99 changes: 69 additions & 30 deletions aiohomekit/controller/ip/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import asyncio
import logging
import socket
from typing import TYPE_CHECKING, Any

import aiohappyeyeballs
from async_interrupt import interrupt

from aiohomekit.crypto.chacha20poly1305 import (
Expand Down Expand Up @@ -49,6 +51,22 @@
logger = logging.getLogger(__name__)


def _convert_hosts_to_addr_infos(
hosts: list[str], port: int
) -> list[aiohappyeyeballs.AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.
The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: list[aiohappyeyeballs.AddrInfoType] = []
for host in hosts:
is_ipv6 = ":" in host
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
addr = (host, port, 0, 0) if is_ipv6 else (host, port)
addr_infos.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, host, addr))
return addr_infos


class ConnectionReady(Exception):
"""Raised when a connection is ready to be retried."""

Expand All @@ -58,7 +76,6 @@ class InsecureHomeKitProtocol(asyncio.Protocol):

def __init__(self, connection: HomeKitConnection) -> None:
self.connection = connection
self.host = ":".join((connection.host, str(connection.port)))
self.result_cbs: list[asyncio.Future[HttpResponse]] = []
self.current_response = HttpResponse()
self.loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -218,10 +235,10 @@ def data_received(self, data: bytes) -> None:

class HomeKitConnection:
def __init__(
self, owner: IpPairing, host: str, port: int, concurrency_limit: int = 1
self, owner: IpPairing, hosts: list[str], port: int, concurrency_limit: int = 1
) -> None:
self.owner = owner
self.host = host
self.hosts = hosts
self.port = port

self.closing: bool = False
Expand All @@ -241,13 +258,15 @@ def __init__(
self._concurrency_limit = asyncio.Semaphore(concurrency_limit)
self._reconnect_future: asyncio.Future[None] | None = None
self._last_connector_error: Exception | None = None
self.connected_host: str | None = None
self.host_header: str | None = None

@property
def name(self) -> str:
"""Return the name of the connection."""
if self.owner:
return self.owner.name
return f"{self.host}:{self.port}"
return f"{self.connected_host or self.hosts}:{self.port}"

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -472,12 +491,9 @@ async def request(
"Connection lost before request could be sent"
)

buffer = []
buffer.append(f"{method.upper()} {target} HTTP/1.1")

# WARNING: It is vital that a Host: header is present or some devices
# will reject the request.
buffer.append(f"Host: {self.host}")
buffer = [f"{method.upper()} {target} HTTP/1.1", self.host_header]

if headers:
for header, value in headers:
Expand All @@ -502,7 +518,7 @@ async def request(
async with self._concurrency_limit:
if not self.protocol:
raise AccessoryDisconnectedError("Tried to send while not connected")
logger.debug("%s: raw request: %r", self.host, request_bytes)
logger.debug("%s: raw request: %r", self.connected_host, request_bytes)
resp = await self.protocol.send_bytes(request_bytes)

if resp.code >= 400 and resp.code <= 499:
Expand All @@ -512,7 +528,7 @@ async def request(
response=resp,
)

logger.debug("%s: raw response: %r", self.host, resp.body)
logger.debug("%s: raw response: %r", self.connected_host, resp.body)

return resp

Expand Down Expand Up @@ -550,20 +566,41 @@ async def _connect_once(self) -> None:
"""_connect_once must only ever be called from _reconnect to ensure its done with a lock."""
loop = asyncio.get_event_loop()

logger.debug("Attempting connection to %s:%s", self.host, self.port)

try:
async with asyncio_timeout(10):
self.transport, self.protocol = await loop.create_connection(
lambda: InsecureHomeKitProtocol(self), self.host, self.port
)

except asyncio.TimeoutError:
raise TimeoutError("Timeout")
logger.debug("Attempting connection to %s:%s", self.hosts, self.port)

except OSError as e:
raise ConnectionError(str(e))
addr_infos = _convert_hosts_to_addr_infos(self.hosts, self.port)

last_exception: Exception | None = None
sock: socket.socket | None = None
interleave = 1
while addr_infos:
try:
async with asyncio_timeout(10):
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
interleave=interleave,
loop=self._loop,
)
break
except (OSError, asyncio.TimeoutError) as err:
last_exception = err
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)

if sock is None:
if isinstance(last_exception, asyncio.TimeoutError):
raise TimeoutError("Timeout") from last_exception
raise ConnectionError(str(last_exception)) from last_exception

self.transport, self.protocol = await loop.create_connection(
lambda: InsecureHomeKitProtocol(self), sock=sock
)
connected_host = sock.getpeername()[0]
self.connected_host = connected_host
if ":" in connected_host:
self.host_header = f"Host: [{connected_host}]:{self.port}"
else:
self.host_header = f"Host: {connected_host}:{self.port}"
if self.owner:
await self.owner.connection_made(False)

Expand All @@ -582,7 +619,7 @@ async def _reconnect(self) -> None:
async with self._connect_lock:
interval = 0.5

logger.debug("Starting reconnect loop to %s:%s", self.host, self.port)
logger.debug("Starting reconnect loop to %s:%s", self.hosts, self.port)

while not self.closing:
self._last_connector_error = None
Expand Down Expand Up @@ -638,7 +675,7 @@ def event_received(self, event: HttpResponse) -> None:
self.owner.event_received(parsed)

def __repr__(self) -> str:
return f"HomeKitConnection(host={self.host!r}, port={self.port!r})"
return f"HomeKitConnection(host={(self.connected_host or self.hosts)!r}, port={self.port!r})"


class SecureHomeKitConnection(HomeKitConnection):
Expand All @@ -647,7 +684,7 @@ class SecureHomeKitConnection(HomeKitConnection):
def __init__(self, owner: IpPairing, pairing_data: dict[str, Any]) -> None:
super().__init__(
owner,
pairing_data["AccessoryIP"],
pairing_data.get("AccessoryIPs", [pairing_data["AccessoryIP"]]),
pairing_data["AccessoryPort"],
)
self.pairing_data = pairing_data
Expand All @@ -663,14 +700,14 @@ async def _connect_once(self):
if self.owner and self.owner.description:
pairing = self.owner
try:
if self.host != pairing.description.address:
if set(self.hosts) != set(pairing.description.addresses):
logger.debug(
"%s: Host changed from %s to %s",
pairing.name,
self.host,
pairing.description.address,
self.hosts,
pairing.description.addresses,
)
self.host = pairing.description.address
self.hosts = pairing.description.addresses

if self.port != pairing.description.port:
logger.debug(
Expand Down Expand Up @@ -714,7 +751,9 @@ async def _connect_once(self):

self.is_secure = True

logger.debug("Secure connection to %s:%s established", self.host, self.port)
logger.debug(
"Secure connection to %s:%s established", self.connected_host, self.port
)

if self.owner:
await self.owner.connection_made(True)
5 changes: 4 additions & 1 deletion aiohomekit/controller/ip/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class IpDiscovery(ZeroconfDiscovery):
def __init__(self, controller, description: HomeKitService):
super().__init__(description)
self.controller = controller
self.connection = HomeKitConnection(None, description.address, description.port)
self.connection = HomeKitConnection(
None, description.addresses, description.port
)

def __repr__(self):
return f"IPDiscovery(host={self.description.address}, port={self.description.port})"
Expand Down Expand Up @@ -92,6 +94,7 @@ async def finish_pairing(pin: str) -> IpPairing:
break

pairing["AccessoryIP"] = self.description.address
pairing["AccessoryIPs"] = self.description.addresses
pairing["AccessoryPort"] = self.description.port
pairing["Connection"] = "IP"

Expand Down
6 changes: 4 additions & 2 deletions aiohomekit/controller/ip/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def poll_interval(self) -> timedelta:
@property
def name(self) -> str:
"""Return the name of the pairing with the address."""
connection = self.connection
host = connection.connected_host or connection.hosts
if self.description:
return f"{self.description.name} [{self.connection.host}:{self.connection.port}] (id={self.id})"
return f"[{self.connection.host}:{self.connection.port}] (id={self.id})"
return f"{self.description.name} [{host}:{connection.port}] (id={self.id})"
return f"[{host}:{connection.port}] (id={self.id})"

def event_received(self, event):
self._callback_listeners(format_characteristic_list(event))
Expand Down
22 changes: 13 additions & 9 deletions aiohomekit/zeroconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ def from_service_info(cls, service: AsyncServiceInfo) -> HomeKitService:
# This means the first address will always be the most recently added
# address of the given IP version.
#
for ip_addr in addresses:
if not ip_addr.is_link_local and not ip_addr.is_unspecified:
address = str(ip_addr)
break
if not address:
valid_addresses = [
str(ip_addr)
for ip_addr in addresses
if not ip_addr.is_link_local and not ip_addr.is_unspecified
]
if not valid_addresses:
raise ValueError(
"Invalid HomeKit Zeroconf record: Missing non-link-local or unspecified address"
)
address = valid_addresses[0]

props: dict[str, str] = {
k.decode("utf-8").lower(): v.decode("utf-8")
Expand All @@ -118,7 +120,7 @@ def from_service_info(cls, service: AsyncServiceInfo) -> HomeKitService:
protocol_version=props.get("pv", "1.0"),
type=service.type,
address=address,
addresses=[str(ip_addr) for ip_addr in addresses],
addresses=valid_addresses,
port=service.port,
)

Expand All @@ -127,13 +129,13 @@ class ZeroconfServiceListener(ServiceListener):
"""An empty service listener."""

def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been added."""

def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been removed."""

def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been updated."""


def find_brower_for_hap_type(azc: AsyncZeroconf, hap_type: str) -> AsyncServiceBrowser:
Expand All @@ -158,6 +160,8 @@ def _update_from_discovery(self, description: HomeKitService):


class ZeroconfPairing(AbstractPairing):
description: HomeKitService

def _async_endpoint_changed(self) -> None:
"""The IP and/or port of the accessory has changed."""
pass
Expand Down
Loading

0 comments on commit 1409cd6

Please sign in to comment.