diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 4c356b59..941c7992 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -152,3 +152,5 @@ cdef class APIConnection: cdef void _set_fatal_exception_if_unset(self, Exception err) cdef void _register_internal_message_handlers(self) + + cdef void _increase_recv_buffer_size(self) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index d371b6f3..14403b7f 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -62,7 +62,8 @@ _LOGGER = logging.getLogger(__name__) -BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB +PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB +MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use DISCONNECT_REQUEST_MESSAGE = DisconnectRequest() DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),) @@ -384,9 +385,7 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: self._socket = sock sock.setblocking(False) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # Try to reduce the pressure on esphome device as it measures - # ram in bytes and we measure ram in megabytes. - sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE) + self._increase_recv_buffer_size() self.connected_address = sock.getpeername()[0] if self._debug_enabled: @@ -397,6 +396,32 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None: self._params.port, ) + def _increase_recv_buffer_size(self) -> None: + """Increase the recv buffer size.""" + if TYPE_CHECKING: + assert self._socket is not None + new_buffer_size = PREFERRED_BUFFER_SIZE + while True: + # Try to reduce the pressure on ESPHome device as it measures + # ram in bytes and we measure ram in megabytes. + try: + self._socket.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, new_buffer_size + ) + return + except OSError as err: + if new_buffer_size <= MIN_BUFFER_SIZE: + _LOGGER.warning( + "%s: Unable to increase the socket receive buffer size to %s; " + "The connection may unstable if the ESPHome device sends " + "data at volume (ex. a Bluetooth proxy or camera): %s", + self.log_name, + new_buffer_size, + err, + ) + return + new_buffer_size //= 2 + async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" fh: APIPlaintextFrameHelper | APINoiseFrameHelper diff --git a/tests/test_connection.py b/tests/test_connection.py index fd4d255e..0e601309 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,10 +2,11 @@ import asyncio import logging +import socket from datetime import timedelta from functools import partial from typing import Callable, cast -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch import pytest from google.protobuf import message @@ -236,6 +237,103 @@ async def test_start_connection_socket_error( await asyncio.sleep(0) +@pytest.mark.asyncio +async def test_start_connection_cannot_increase_recv_buffer( + conn: APIConnection, + resolve_host, + aiohappyeyeballs_start_connection: MagicMock, + caplog: pytest.LogCaptureFixture, +): + """Test failing to increase the recv buffer.""" + loop = asyncio.get_event_loop() + transport = MagicMock() + connected = asyncio.Event() + tried_sizes = [] + + def _setsockopt(*args, **kwargs): + if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF: + size = args[2] + tried_sizes.append(size) + raise OSError("Socket error") + + mock_socket: socket.socket = create_autospec( + socket.socket, spec_set=True, instance=True, name="bad_buffer_socket" + ) + mock_socket.type = socket.SOCK_STREAM + mock_socket.fileno.return_value = 1 + mock_socket.getpeername.return_value = ("10.0.0.512", 323) + mock_socket.setsockopt = _setsockopt + mock_socket.sendmsg.side_effect = OSError("Socket error") + aiohappyeyeballs_start_connection.return_value = mock_socket + + with patch.object( + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await asyncio.sleep(0) + await connected.wait() + protocol = conn._frame_helper + send_plaintext_hello(protocol) + await connect_task + + assert "Unable to increase the socket receive buffer size to 131072" in caplog.text + assert tried_sizes == [2097152, 1048576, 524288, 262144, 131072] + + # Failure to increase the buffer size should not cause the connection to fail + assert conn.is_connected + + +@pytest.mark.asyncio +async def test_start_connection_can_only_increase_buffer_size_to_262144( + conn: APIConnection, + resolve_host, + aiohappyeyeballs_start_connection: MagicMock, + caplog: pytest.LogCaptureFixture, +): + """Test the receive buffer can only be increased to 262144.""" + loop = asyncio.get_event_loop() + transport = MagicMock() + connected = asyncio.Event() + tried_sizes = [] + + def _setsockopt(*args, **kwargs): + if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF: + size = args[2] + tried_sizes.append(size) + if size != 262144: + raise OSError("Socket error") + + mock_socket: socket.socket = create_autospec( + socket.socket, spec_set=True, instance=True, name="bad_buffer_socket" + ) + mock_socket.type = socket.SOCK_STREAM + mock_socket.fileno.return_value = 1 + mock_socket.getpeername.return_value = ("10.0.0.512", 323) + mock_socket.setsockopt = _setsockopt + mock_socket.sendmsg.side_effect = OSError("Socket error") + aiohappyeyeballs_start_connection.return_value = mock_socket + + with patch.object( + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ): + connect_task = asyncio.create_task(connect(conn, login=False)) + await asyncio.sleep(0) + await connected.wait() + protocol = conn._frame_helper + send_plaintext_hello(protocol) + await connect_task + + assert "Unable to increase the socket receive buffer size" not in caplog.text + assert tried_sizes == [2097152, 1048576, 524288, 262144] + + # Failure to increase the buffer size should not cause the connection to fail + assert conn.is_connected + + @pytest.mark.asyncio async def test_start_connection_times_out( conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection