Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-0.35' into update-deps
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Oct 23, 2023
2 parents 1da88f0 + 40e30e8 commit 859e668
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
19 changes: 13 additions & 6 deletions python/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.exceptions import UCXMessageTruncatedError

from .continuous_ucx_progress import PollingMode, ThreadMode
from .endpoint import Endpoint
Expand Down Expand Up @@ -271,12 +272,18 @@ async def create_endpoint(self, ip_address, port, endpoint_error_handling=True):
seed = os.urandom(16)
msg_tag = hash64bits("msg_tag", seed, ucx_ep.handle)
ctrl_tag = hash64bits("ctrl_tag", seed, ucx_ep.handle)
peer_info = await exchange_peer_info(
endpoint=ucx_ep,
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=False,
)
try:
peer_info = await exchange_peer_info(
endpoint=ucx_ep,
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=False,
)
except UCXMessageTruncatedError:
# A truncated message occurs if the remote endpoint closed before
# exchanging peer info; in that case, we should raise the endpoint
# error instead.
ucx_ep.raise_on_error()
tags = {
"msg_send": peer_info["msg_tag"],
"msg_recv": msg_tag,
Expand Down
11 changes: 6 additions & 5 deletions python/ucxx/_lib_async/exchange_peer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause


import asyncio
import logging
import struct

Expand All @@ -12,7 +13,7 @@
logger = logging.getLogger("ucx")


async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener):
async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, stream_timeout=5.0):
"""Help function that exchange endpoint information"""

# Pack peer information incl. a checksum
Expand All @@ -26,14 +27,14 @@ async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener):
# streaming calls (see <https://github.com/rapidsai/ucx-py/pull/509>)
if listener is True:
req = endpoint.stream_send(my_info_arr)
await req.wait()
await asyncio.wait_for(req.wait(), timeout=stream_timeout)
req = endpoint.stream_recv(peer_info_arr)
await req.wait()
await asyncio.wait_for(req.wait(), timeout=stream_timeout)
else:
req = endpoint.stream_recv(peer_info_arr)
await req.wait()
await asyncio.wait_for(req.wait(), timeout=stream_timeout)
req = endpoint.stream_send(my_info_arr)
await req.wait()
await asyncio.wait_for(req.wait(), timeout=stream_timeout)

# Unpacking and sanity check of the peer information
ret = {}
Expand Down
19 changes: 13 additions & 6 deletions python/ucxx/_lib_async/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading

import ucxx._lib.libucxx as ucx_api
from ucxx.exceptions import UCXMessageTruncatedError

from .endpoint import Endpoint
from .exchange_peer_info import exchange_peer_info
Expand Down Expand Up @@ -132,12 +133,18 @@ async def _listener_handler_coroutine(
msg_tag = hash64bits("msg_tag", seed, endpoint.handle)
ctrl_tag = hash64bits("ctrl_tag", seed, endpoint.handle)

peer_info = await exchange_peer_info(
endpoint=endpoint,
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=True,
)
try:
peer_info = await exchange_peer_info(
endpoint=endpoint,
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=True,
)
except UCXMessageTruncatedError:
# A truncated message occurs if the remote endpoint closed before
# exchanging peer info; in that case, we should raise the endpoint
# error instead.
endpoint.raise_on_error()
tags = {
"msg_send": peer_info["msg_tag"],
"msg_recv": msg_tag,
Expand Down

0 comments on commit 859e668

Please sign in to comment.