Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make failure to increase the recv buffer size non-fatal #802

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,5 @@ cdef class APIConnection:
cdef void _set_fatal_exception_if_unset(self, Exception err)

cdef void _register_internal_message_handlers(self)

cdef void _increase_recv_buffer_size(self)
33 changes: 29 additions & 4 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@

_LOGGER = logging.getLogger(__name__)

BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use

DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
Expand Down Expand Up @@ -384,9 +385,7 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self._increase_recv_buffer_size()
self.connected_address = sock.getpeername()[0]

if self._debug_enabled:
Expand All @@ -397,6 +396,32 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
self._params.port,
)

def _increase_recv_buffer_size(self) -> None:
"""Increase the recv buffer size."""
if TYPE_CHECKING:
assert self._socket is not None
new_buffer_size = PREFERRED_BUFFER_SIZE
while True:
# Try to reduce the pressure on ESPHome device as it measures
# ram in bytes and we measure ram in megabytes.
try:
self._socket.setsockopt(
socket.SOL_SOCKET, socket.SO_RCVBUF, new_buffer_size
)
return
except OSError as err:
if new_buffer_size <= MIN_BUFFER_SIZE:
_LOGGER.warning(
"%s: Unable to increase the socket receive buffer size to %s; "
"The connection may unstable if the ESPHome device sends "
"data at volume (ex. a Bluetooth proxy or camera): %s",
self.log_name,
new_buffer_size,
err,
)
return
new_buffer_size //= 2

async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
Expand Down
100 changes: 99 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import asyncio
import logging
import socket
from datetime import timedelta
from functools import partial
from typing import Callable, cast
from unittest.mock import AsyncMock, MagicMock, call, patch
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch

import pytest
from google.protobuf import message
Expand Down Expand Up @@ -236,6 +237,103 @@ async def test_start_connection_socket_error(
await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_start_connection_cannot_increase_recv_buffer(
conn: APIConnection,
resolve_host,
aiohappyeyeballs_start_connection: MagicMock,
caplog: pytest.LogCaptureFixture,
):
"""Test failing to increase the recv buffer."""
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()
tried_sizes = []

def _setsockopt(*args, **kwargs):
if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF:
size = args[2]
tried_sizes.append(size)
raise OSError("Socket error")

mock_socket: socket.socket = create_autospec(
socket.socket, spec_set=True, instance=True, name="bad_buffer_socket"
)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
mock_socket.setsockopt = _setsockopt
mock_socket.sendmsg.side_effect = OSError("Socket error")
aiohappyeyeballs_start_connection.return_value = mock_socket

with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
await connected.wait()
protocol = conn._frame_helper
send_plaintext_hello(protocol)
await connect_task

assert "Unable to increase the socket receive buffer size to 131072" in caplog.text
assert tried_sizes == [2097152, 1048576, 524288, 262144, 131072]

# Failure to increase the buffer size should not cause the connection to fail
assert conn.is_connected


@pytest.mark.asyncio
async def test_start_connection_can_only_increase_buffer_size_to_262144(
conn: APIConnection,
resolve_host,
aiohappyeyeballs_start_connection: MagicMock,
caplog: pytest.LogCaptureFixture,
):
"""Test the receive buffer can only be increased to 262144."""
loop = asyncio.get_event_loop()
transport = MagicMock()
connected = asyncio.Event()
tried_sizes = []

def _setsockopt(*args, **kwargs):
if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF:
size = args[2]
tried_sizes.append(size)
if size != 262144:
raise OSError("Socket error")

mock_socket: socket.socket = create_autospec(
socket.socket, spec_set=True, instance=True, name="bad_buffer_socket"
)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
mock_socket.setsockopt = _setsockopt
mock_socket.sendmsg.side_effect = OSError("Socket error")
aiohappyeyeballs_start_connection.return_value = mock_socket

with patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
await connected.wait()
protocol = conn._frame_helper
send_plaintext_hello(protocol)
await connect_task

assert "Unable to increase the socket receive buffer size" not in caplog.text
assert tried_sizes == [2097152, 1048576, 524288, 262144]

# Failure to increase the buffer size should not cause the connection to fail
assert conn.is_connected


@pytest.mark.asyncio
async def test_start_connection_times_out(
conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
Expand Down