From 01db3c2369cf1d4006286a7ab83a0ee09cbf9667 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 4 Jan 2024 14:40:41 -1000 Subject: [PATCH] Implement happy eyeballs (RFC 8305) (#7954) (cherry picked from commit c4ec3f130af3a53040b883aa2fab06fd215f9088) --- CHANGES/7954.feature | 1 + aiohttp/connector.py | 66 +++++++-- docs/client_reference.rst | 21 ++- requirements/base.txt | 2 + requirements/constraints.txt | 2 + requirements/dev.txt | 2 + requirements/runtime-deps.in | 1 + requirements/runtime-deps.txt | 2 + requirements/test.txt | 2 + setup.cfg | 1 + tests/conftest.py | 11 ++ tests/test_connector.py | 268 ++++++++++++++++++++++++++++++++-- tests/test_proxy.py | 148 ++++++++++++++++--- 13 files changed, 482 insertions(+), 45 deletions(-) create mode 100644 CHANGES/7954.feature diff --git a/CHANGES/7954.feature b/CHANGES/7954.feature new file mode 100644 index 00000000000..e536ee4b1c4 --- /dev/null +++ b/CHANGES/7954.feature @@ -0,0 +1 @@ +Implement happy eyeballs diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 73f58b1a451..fd4ad6d1741 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1,6 +1,7 @@ import asyncio import functools import random +import socket import sys import traceback import warnings @@ -29,6 +30,7 @@ cast, ) +import aiohappyeyeballs import attr from . import hdrs, helpers @@ -750,6 +752,10 @@ class TCPConnector(BaseConnector): limit_per_host - Number of simultaneous connections to one host. enable_cleanup_closed - Enables clean-up closed ssl transports. Disabled by default. + happy_eyeballs_delay - This is the “Connection Attempt Delay” + as defined in RFC 8305. To disable + the happy eyeballs algorithm, set to None. + interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. """ @@ -772,6 +778,8 @@ def __init__( enable_cleanup_closed: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, + happy_eyeballs_delay: Optional[float] = 0.25, + interleave: Optional[int] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -792,7 +800,9 @@ def __init__( self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {} self._family = family - self._local_addr = local_addr + self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) + self._happy_eyeballs_delay = happy_eyeballs_delay + self._interleave = interleave def close(self) -> Awaitable[None]: """Close all ongoing DNS calls.""" @@ -980,6 +990,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, + addr_infos: List[aiohappyeyeballs.AddrInfoType], req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, @@ -989,7 +1000,14 @@ async def _wrap_create_connection( async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): - return await self._loop.create_connection(*args, **kwargs) + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=self._local_addr_infos, + happy_eyeballs_delay=self._happy_eyeballs_delay, + interleave=self._interleave, + loop=self._loop, + ) + return await self._loop.create_connection(*args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1143,6 +1161,27 @@ async def _start_tls_connection( return tls_transport, tls_proto + def _convert_hosts_to_addr_infos( + self, hosts: List[Dict[str, Any]] + ) -> List[aiohappyeyeballs.AddrInfoType]: + """Converts the list of hosts to a list of addr_infos. + + The list of hosts is the result of a DNS lookup. The list of + addr_infos is the result of a call to `socket.getaddrinfo()`. + """ + addr_infos: List[aiohappyeyeballs.AddrInfoType] = [] + for hinfo in hosts: + host = hinfo["host"] + is_ipv6 = ":" in host + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + if self._family and self._family != family: + continue + addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"]) + addr_infos.append( + (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr) + ) + return addr_infos + async def _create_direct_connection( self, req: ClientRequest, @@ -1187,36 +1226,27 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: raise ClientConnectorError(req.connection_key, exc) from exc last_exc: Optional[Exception] = None - - for hinfo in hosts: - host = hinfo["host"] - port = hinfo["port"] - + addr_infos = self._convert_hosts_to_addr_infos(hosts) + while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. # See https://github.com/aio-libs/aiohttp/issues/3636 server_hostname = ( - (req.server_hostname or hinfo["hostname"]).rstrip(".") - if sslcontext - else None + (req.server_hostname or host).rstrip(".") if sslcontext else None ) try: transp, proto = await self._wrap_create_connection( self._factory, - host, - port, timeout=timeout, ssl=sslcontext, - family=hinfo["family"], - proto=hinfo["proto"], - flags=hinfo["flags"], + addr_infos=addr_infos, server_hostname=server_hostname, - local_addr=self._local_addr, req=req, client_error=client_error, ) except ClientConnectorError as exc: last_exc = exc + aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) continue if req.is_ssl() and fingerprint: @@ -1227,6 +1257,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: if not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transp) last_exc = exc + # Remove the bad peer from the list of addr_infos + sock: socket.socket = transp.get_extra_info("socket") + bad_peer = sock.getpeername() + aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer) continue return transp, proto diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 57e96f2a070..93b3459ba7c 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1071,7 +1071,8 @@ is controlled by *force_close* constructor's parameter). family=0, ssl_context=None, local_addr=None, \ resolver=None, keepalive_timeout=sentinel, \ force_close=False, limit=100, limit_per_host=0, \ - enable_cleanup_closed=False, loop=None) + enable_cleanup_closed=False, timeout_ceil_threshold=5, \ + happy_eyeballs_delay=0.25, interleave=None, loop=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1174,6 +1175,24 @@ is controlled by *force_close* constructor's parameter). If this parameter is set to True, aiohttp additionally aborts underlining transport after 2 seconds. It is off by default. + :param float happy_eyeballs_delay: The amount of time in seconds to wait for a + connection attempt to complete, before starting the next attempt in parallel. + This is the “Connection Attempt Delay” as defined in RFC 8305. To disable + Happy Eyeballs, set this to ``None``. The default value recommended by the + RFC is 0.25 (250 milliseconds). + + .. versionadded:: 3.10 + + :param int interleave: controls address reordering when a host name resolves + to multiple IP addresses. If ``0`` or unspecified, no reordering is done, and + addresses are tried in the order returned by the resolver. If a positive + integer is specified, the addresses are interleaved by address family, and + the given integer is interpreted as “First Address Family Count” as defined + in RFC 8305. The default is ``0`` if happy_eyeballs_delay is not specified, and + ``1`` if it is. + + .. versionadded:: 3.10 + .. attribute:: family *TCP* socket family e.g. :data:`socket.AF_INET` or diff --git a/requirements/base.txt b/requirements/base.txt index 4640010dca2..99f1b5ab9d7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,6 +6,8 @@ # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.0 + # via -r requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in async-timeout==4.0.3 ; python_version < "3.11" diff --git a/requirements/constraints.txt b/requirements/constraints.txt index c177e50bbd9..b6d150fa076 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -6,6 +6,8 @@ # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.0 + # via -r requirements/runtime-deps.in aiohttp-theme==0.1.6 # via -r requirements/doc.in aioredis==2.0.1 diff --git a/requirements/dev.txt b/requirements/dev.txt index 5939bcd8fef..f88c11d3033 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -6,6 +6,8 @@ # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.0 + # via -r requirements/runtime-deps.in aiohttp-theme==0.1.6 # via -r requirements/doc.in aioredis==2.0.1 diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index b2df16f1680..70bd75bd99d 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,6 +1,7 @@ # Extracted from `setup.cfg` via `make sync-direct-runtime-deps` aiodns; sys_platform=="linux" or sys_platform=="darwin" +aiohappyeyeballs >= 2.3.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 5.0 ; python_version < "3.11" attrs >= 17.3.0 diff --git a/requirements/runtime-deps.txt b/requirements/runtime-deps.txt index 2d4df7df38c..6c1e407eec3 100644 --- a/requirements/runtime-deps.txt +++ b/requirements/runtime-deps.txt @@ -6,6 +6,8 @@ # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.0 + # via -r requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in async-timeout==4.0.3 ; python_version < "3.11" diff --git a/requirements/test.txt b/requirements/test.txt index f08c7fd1788..72fb6a40e56 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,6 +6,8 @@ # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.0 + # via -r requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in annotated-types==0.5.0 diff --git a/setup.cfg b/setup.cfg index 331c80e154a..71dc26c9789 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ zip_safe = False include_package_data = True install_requires = + aiohappyeyeballs >= 2.3.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 5.0 ; python_version < "3.11" attrs >= 17.3.0 diff --git a/tests/conftest.py b/tests/conftest.py index 44e5fb7285c..fcdb482a59f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from hashlib import md5, sha256 from pathlib import Path from tempfile import TemporaryDirectory +from unittest import mock from uuid import uuid4 import pytest @@ -197,3 +198,13 @@ def netrc_contents( monkeypatch.setenv("NETRC", str(netrc_file_path)) return netrc_file_path + + +@pytest.fixture +def start_connection(): + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) as start_connection_mock: + yield start_connection_mock diff --git a/tests/test_connector.py b/tests/test_connector.py index f27d4131049..1faec002487 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -10,9 +10,11 @@ import uuid from collections import deque from contextlib import closing +from typing import Any, List, Optional from unittest import mock import pytest +from aiohappyeyeballs import AddrInfoType from yarl import URL import aiohttp @@ -539,7 +541,9 @@ async def test__drop_acquire_per_host3(loop) -> None: assert conn._acquired_per_host[123] == {789} -async def test_tcp_connector_certificate_error(loop) -> None: +async def test_tcp_connector_certificate_error( + loop: Any, start_connection: mock.AsyncMock +) -> None: req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) async def certificate_error(*args, **kwargs): @@ -556,8 +560,10 @@ async def certificate_error(*args, **kwargs): assert isinstance(ctx.value, aiohttp.ClientSSLError) -async def test_tcp_connector_server_hostname_default(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test_tcp_connector_server_hostname_default( + loop: Any, start_connection: mock.AsyncMock +) -> None: + conn = aiohttp.TCPConnector() with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True @@ -570,8 +576,10 @@ async def test_tcp_connector_server_hostname_default(loop) -> None: assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1" -async def test_tcp_connector_server_hostname_override(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test_tcp_connector_server_hostname_override( + loop: Any, start_connection: mock.AsyncMock +) -> None: + conn = aiohttp.TCPConnector() with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True @@ -595,6 +603,7 @@ async def test_tcp_connector_multiple_hosts_errors(loop) -> None: ip4 = "192.168.1.4" ip5 = "192.168.1.5" ips = [ip1, ip2, ip3, ip4, ip5] + addrs_tried = [] ips_tried = [] fingerprint = hashlib.sha256(b"foo").digest() @@ -624,11 +633,24 @@ async def _resolve_host(host, port, traces=None): os_error = certificate_error = ssl_error = fingerprint_error = False connected = False + async def start_connection(*args, **kwargs): + addr_infos: List[AddrInfoType] = kwargs["addr_infos"] + + first_addr_info = addr_infos[0] + first_addr_info_addr = first_addr_info[-1] + addrs_tried.append(first_addr_info_addr) + + mock_socket = mock.create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = first_addr_info_addr + return mock_socket + async def create_connection(*args, **kwargs): nonlocal os_error, certificate_error, ssl_error, fingerprint_error nonlocal connected - ip = args[1] + sock = kwargs["sock"] + addr_info = sock.getpeername() + ip = addr_info[0] ips_tried.append(ip) @@ -645,6 +667,12 @@ async def create_connection(*args, **kwargs): raise ssl.SSLError if ip == ip4: + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + fingerprint_error = True tr, pr = mock.Mock(), mock.Mock() @@ -660,12 +688,21 @@ def get_extra_info(param): if param == "peername": return ("192.168.1.5", 12345) + if param == "socket": + return sock + assert False, param tr.get_extra_info = get_extra_info return tr, pr if ip == ip5: + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + connected = True tr, pr = mock.Mock(), mock.Mock() @@ -687,8 +724,13 @@ def get_extra_info(param): conn._loop.create_connection = create_connection - established_connection = await conn.connect(req, [], ClientTimeout()) - assert ips == ips_tried + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", start_connection + ): + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert ips_tried == ips + assert addrs_tried == [(ip, 443) for ip in ips] assert os_error assert certificate_error @@ -699,8 +741,214 @@ def get_extra_info(param): established_connection.close() -async def test_tcp_connector_resolve_host(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True) +@pytest.mark.parametrize( + ("happy_eyeballs_delay"), + [0.1, 0.25, None], +) +async def test_tcp_connector_happy_eyeballs( + loop: Any, happy_eyeballs_delay: Optional[float] +) -> None: + conn = aiohttp.TCPConnector(happy_eyeballs_delay=happy_eyeballs_delay) + + ip1 = "dead::beef::" + ip2 = "192.168.1.1" + ips = [ip1, ip2] + addrs_tried = [] + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + + os_error = False + connected = False + + async def sock_connect(*args, **kwargs): + addr = args[1] + nonlocal os_error + + addrs_tried.append(addr) + + if addr[0] == ip1: + os_error = True + raise OSError + + async def create_connection(*args, **kwargs): + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + + nonlocal connected + connected = True + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.sock_connect = sock_connect + conn._loop.create_connection = create_connection + + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert addrs_tried == [(ip1, 443, 0, 0), (ip2, 443)] + + assert os_error + assert connected + + established_connection.close() + + +async def test_tcp_connector_interleave(loop: Any) -> None: + conn = aiohttp.TCPConnector(interleave=2) + + ip1 = "192.168.1.1" + ip2 = "192.168.1.2" + ip3 = "dead::beef::" + ip4 = "aaaa::beef::" + ip5 = "192.168.1.5" + ips = [ip1, ip2, ip3, ip4, ip5] + success_ips = [] + interleave = None + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + + async def start_connection(*args, **kwargs): + nonlocal interleave + addr_infos: List[AddrInfoType] = kwargs["addr_infos"] + interleave = kwargs["interleave"] + # Mock the 4th host connecting successfully + fourth_addr_info = addr_infos[3] + fourth_addr_info_addr = fourth_addr_info[-1] + mock_socket = mock.create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = fourth_addr_info_addr + return mock_socket + + async def create_connection(*args, **kwargs): + sock = kwargs["sock"] + addr_info = sock.getpeername() + ip = addr_info[0] + + success_ips.append(ip) + + sock: socket.socket = kwargs["sock"] + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.create_connection = create_connection + + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", start_connection + ): + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert success_ips == [ip4] + assert interleave == 2 + established_connection.close() + + +async def test_tcp_connector_family_is_respected(loop: Any) -> None: + conn = aiohttp.TCPConnector(family=socket.AF_INET) + + ip1 = "dead::beef::" + ip2 = "192.168.1.1" + ips = [ip1, ip2] + addrs_tried = [] + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + connected = False + + async def sock_connect(*args, **kwargs): + addr = args[1] + addrs_tried.append(addr) + + async def create_connection(*args, **kwargs): + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + + nonlocal connected + connected = True + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.sock_connect = sock_connect + conn._loop.create_connection = create_connection + + established_connection = await conn.connect(req, [], ClientTimeout()) + + # We should only try the IPv4 address since we specified + # the family to be AF_INET + assert addrs_tried == [(ip2, 443)] + + assert connected + + established_connection.close() + + +async def test_tcp_connector_resolve_host(loop: Any) -> None: + conn = aiohttp.TCPConnector(use_dns_cache=True) res = await conn._resolve_host("localhost", 8080) assert res diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 1ff53e3f899..2a8643f5047 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -4,6 +4,7 @@ import ssl import sys import unittest +from typing import Any from unittest import mock import pytest @@ -40,7 +41,12 @@ def tearDown(self): gc.collect() @mock.patch("aiohttp.connector.ClientRequest") - def test_connect(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_connect(self, start_connection: Any, ClientRequestMock: Any) -> None: req = ClientRequest( "GET", URL("http://www.python.org"), @@ -54,7 +60,18 @@ async def make_conn(): return aiohttp.TCPConnector() connector = self.loop.run_until_complete(make_conn()) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) proto = mock.Mock( **{ @@ -81,7 +98,12 @@ async def make_conn(): conn.close() @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_headers(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_headers(self, start_connection: Any, ClientRequestMock: Any) -> None: req = ClientRequest( "GET", URL("http://www.python.org"), @@ -96,7 +118,18 @@ async def make_conn(): return aiohttp.TCPConnector() connector = self.loop.run_until_complete(make_conn()) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) proto = mock.Mock( **{ @@ -122,7 +155,12 @@ async def make_conn(): conn.close() - def test_proxy_auth(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_auth(self, start_connection: Any) -> None: with self.assertRaises(ValueError) as ctx: ClientRequest( "GET", @@ -136,11 +174,16 @@ def test_proxy_auth(self) -> None: "proxy_auth must be None or BasicAuth() tuple", ) - def test_proxy_dns_error(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_dns_error(self, start_connection: Any) -> None: async def make_conn(): return aiohttp.TCPConnector() - connector = self.loop.run_until_complete(make_conn()) + connector: aiohttp.TCPConnector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( raise_exception=OSError("dont take it serious") ) @@ -159,7 +202,12 @@ async def make_conn(): self.assertEqual(req.url.path, "/") self.assertEqual(dict(req.headers), expected_headers) - def test_proxy_connection_error(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_connection_error(self, start_connection: Any) -> None: async def make_conn(): return aiohttp.TCPConnector() @@ -192,7 +240,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_server_hostname_default( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -252,7 +307,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_server_hostname_override( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), @@ -316,7 +378,12 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -376,7 +443,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_certificate_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_certificate_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -430,7 +504,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_ssl_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_ssl_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -486,7 +567,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_http_proxy_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -545,7 +633,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_resp_start_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_resp_start_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -598,7 +693,12 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_request_port(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_request_port(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -656,7 +756,14 @@ def test_proxy_auth_property_default(self) -> None: self.assertIsNone(req.proxy_auth) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_pass_ssl_context( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -724,7 +831,12 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_auth(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_auth(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"),