Skip to content

Commit

Permalink
Improve tagProbe and AM receiver callback (#348)
Browse files Browse the repository at this point in the history
Improve `tagProbe` by accepting a tag mask for matching and returning the probed tag information. Also expose the sender endpoint handle to the AM receiver callback so that the callback can identify the origin of the message.

Additionally, fix C++ request tests that were being unintentionally skipped.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)

URL: #348
  • Loading branch information
pentschev authored Jan 16, 2025
1 parent b0027cf commit 93edd75
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 33 deletions.
19 changes: 17 additions & 2 deletions cpp/include/ucxx/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ enum TagMask : ucp_tag_t {};
*/
static constexpr TagMask TagMaskFull{std::numeric_limits<std::underlying_type_t<TagMask>>::max()};

/**
* @brief Information about a probed tag message.
*
* Holds the information reported by a tag probe about a tag message that has been
* received by the worker but not yet consumed.
*/
class TagRecvInfo {
public:
Tag senderTag; ///< Tag the sender attached to the message
size_t length; ///< Size of the received data, in bytes

/**
* @brief Construct from the UCP tag receive information.
*
* The `sender_tag` and `length` fields of the given `ucp_tag_recv_info_t` are copied
* into `senderTag` and `length`, respectively.
*/
explicit TagRecvInfo(const ucp_tag_recv_info_t&);
};

/**
* @brief A UCP configuration map.
*
Expand Down Expand Up @@ -124,9 +138,10 @@ typedef std::function<std::shared_ptr<Buffer>(size_t)> AmAllocatorType;
* @brief Active Message receiver callback.
*
* Type for a custom Active Message receiver callback, executed by the remote worker upon
* Active Message request completion.
* Active Message request completion. The first parameter is the request that completed,
* the second is the handle of the UCX endpoint of the sender.
*/
typedef std::function<void(std::shared_ptr<Request>)> AmReceiverCallbackType;
typedef std::function<void(std::shared_ptr<Request>, ucp_ep_h)> AmReceiverCallbackType;

/**
* @brief Active Message receiver callback owner name.
Expand Down
19 changes: 14 additions & 5 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <queue>
#include <string>
#include <thread>
#include <utility>

#include <ucp/api/ucp.h>

Expand Down Expand Up @@ -684,21 +685,29 @@ class Worker : public Component {
*
* Checks the worker for any uncaught tag messages. An uncaught tag message is any
* tag message that has been fully or partially received by the worker, but not matched
* by a corresponding `ucp_tag_recv_*` call.
* by a corresponding `ucp_tag_recv_*` call. Additionally, returns information about the
* tag message.
*
* @code{.cpp}
* // `worker` is `std::shared_ptr<ucxx::Worker>`
* assert(!worker->tagProbe(0));
* auto probe = worker->tagProbe(0);
* assert(!probe.first);
*
* // `ep` is a remote `std::shared_ptr<ucxx::Endpoint>` to the local `worker`
* ep->tagSend(buffer, length, 0);
*
* assert(worker->tagProbe(0));
* probe = worker->tagProbe(0);
* assert(probe.first);
* assert(probe.second.senderTag == 0);
* assert(probe.second.length == length);
* @endcode
*
* @returns `true` if any uncaught messages were received, `false` otherwise.
* @returns pair whose first element is `true` if any uncaught messages were received,
* `false` otherwise, and whose second element contains the information from
* the tag receive.
*/
[[nodiscard]] bool tagProbe(const Tag tag);
[[nodiscard]] std::pair<bool, TagRecvInfo> tagProbe(const Tag tag,
const TagMask tagMask = TagMaskFull);

/**
* @brief Enqueue a tag receive operation.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/internal/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData,

if (receiverCallback) {
_request->_callback = [this, receiverCallback](ucs_status_t, std::shared_ptr<void>) {
receiverCallback(_request);
receiverCallback(_request, _ep);
};
}
}
Expand Down
11 changes: 8 additions & 3 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,12 @@ void Worker::removeInflightRequest(const Request* const request)
}
}

bool Worker::tagProbe(const Tag tag)
TagRecvInfo::TagRecvInfo(const ucp_tag_recv_info_t& info)
: senderTag(Tag(info.sender_tag)), length(info.length)
{
}

std::pair<bool, TagRecvInfo> Worker::tagProbe(const Tag tag, const TagMask tagMask)
{
if (!isProgressThreadRunning()) {
progress();
Expand All @@ -592,9 +597,9 @@ bool Worker::tagProbe(const Tag tag)
}

ucp_tag_recv_info_t info;
ucp_tag_message_h tag_message = ucp_tag_probe_nb(_handle, tag, TagMaskFull, 0, &info);
ucp_tag_message_h tag_message = ucp_tag_probe_nb(_handle, tag, tagMask, 0, &info);

return tag_message != NULL;
return {tag_message != NULL, TagRecvInfo(info)};
}

std::shared_ptr<Request> Worker::tagRecv(void* buffer,
Expand Down
25 changes: 13 additions & 12 deletions cpp/tests/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,18 @@ class RequestTest : public ::testing::TestWithParam<

void SetUp()
{
std::tie(_bufferType,
_registerCustomAmAllocator,
_enableDelayedSubmission,
_progressMode,
_messageLength) = GetParam();

if (_bufferType == ucxx::BufferType::RMM) {
#if !UCXX_ENABLE_RMM
GTEST_SKIP() << "UCXX was not built with RMM support";
#endif
}

std::tie(_bufferType,
_registerCustomAmAllocator,
_enableDelayedSubmission,
_progressMode,
_messageLength) = GetParam();
_memoryType =
(_bufferType == ucxx::BufferType::RMM) ? UCS_MEMORY_TYPE_CUDA : UCS_MEMORY_TYPE_HOST;
_messageSize = _messageLength * sizeof(int);
Expand Down Expand Up @@ -168,13 +169,14 @@ TEST_P(RequestTest, ProgressAm)
GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible";
}

if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) {
#if !UCXX_ENABLE_RMM
GTEST_SKIP() << "UCXX was not built with RMM support";
GTEST_SKIP() << "UCXX was not built with RMM support";
#else
if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) {
_worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) {
return std::make_shared<ucxx::RMMBuffer>(length);
});
#endif
}

allocate(1, false);
Expand All @@ -198,7 +200,6 @@ TEST_P(RequestTest, ProgressAm)

// Assert data correctness
ASSERT_THAT(_recv[0], ContainerEq(_send[0]));
#endif
}

TEST_P(RequestTest, ProgressAmReceiverCallback)
Expand All @@ -207,13 +208,14 @@ TEST_P(RequestTest, ProgressAmReceiverCallback)
GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible";
}

if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) {
#if !UCXX_ENABLE_RMM
GTEST_SKIP() << "UCXX was not built with RMM support";
GTEST_SKIP() << "UCXX was not built with RMM support";
#else
if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) {
_worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) {
return std::make_shared<ucxx::RMMBuffer>(length);
});
#endif
}

// Define AM receiver callback's owner and id for callback
Expand All @@ -226,7 +228,7 @@ TEST_P(RequestTest, ProgressAmReceiverCallback)
// Define AM receiver callback and register with worker
std::vector<std::shared_ptr<ucxx::Request>> receivedRequests;
auto callback = ucxx::AmReceiverCallbackType(
[this, &receivedRequests, &mutex](std::shared_ptr<ucxx::Request> req) {
[this, &receivedRequests, &mutex](std::shared_ptr<ucxx::Request> req, ucp_ep_h) {
{
std::lock_guard<std::mutex> lock(mutex);
receivedRequests.push_back(req);
Expand Down Expand Up @@ -260,7 +262,6 @@ TEST_P(RequestTest, ProgressAmReceiverCallback)

// Assert data correctness
ASSERT_THAT(_recv[0], ContainerEq(_send[0]));
#endif
}

TEST_P(RequestTest, ProgressStream)
Expand Down
13 changes: 9 additions & 4 deletions cpp/tests/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ TEST_F(WorkerTest, TagProbe)
auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling);
auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress());

ASSERT_FALSE(_worker->tagProbe(ucxx::Tag{0}));
auto probed = _worker->tagProbe(ucxx::Tag{0});
ASSERT_FALSE(probed.first);

std::vector<int> buf{123};
std::vector<std::shared_ptr<ucxx::Request>> requests;
Expand All @@ -119,10 +120,14 @@ TEST_F(WorkerTest, TagProbe)

loopWithTimeout(std::chrono::milliseconds(5000), [this, progressWorker]() {
progressWorker();
return _worker->tagProbe(ucxx::Tag{0});
auto probed = _worker->tagProbe(ucxx::Tag{0});
return probed.first;
});

ASSERT_TRUE(_worker->tagProbe(ucxx::Tag{0}));
probed = _worker->tagProbe(ucxx::Tag{0});
ASSERT_TRUE(probed.first);
ASSERT_EQ(probed.second.senderTag, ucxx::Tag{0});
ASSERT_EQ(probed.second.length, buf.size() * sizeof(int));
}

TEST_F(WorkerTest, AmProbe)
Expand Down Expand Up @@ -189,7 +194,7 @@ TEST_P(WorkerProgressTest, ProgressAmReceiverCallback)
// Define AM receiver callback and register with worker
std::vector<std::shared_ptr<ucxx::Request>> receivedRequests;
auto callback = ucxx::AmReceiverCallbackType(
[this, &receivedRequests, &mutex](std::shared_ptr<ucxx::Request> req) {
[this, &receivedRequests, &mutex](std::shared_ptr<ucxx::Request> req, ucp_ep_h) {
{
std::lock_guard<std::mutex> lock(mutex);
receivedRequests.push_back(req);
Expand Down
19 changes: 16 additions & 3 deletions python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from libcpp.memory cimport (
unique_ptr,
)
from libcpp.optional cimport nullopt
from libcpp.pair cimport pair
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector
Expand Down Expand Up @@ -684,13 +685,25 @@ cdef class UCXWorker():

return num_canceled

def tag_probe(self, UCXXTag tag) -> bool:
cdef bint tag_matched
def tag_probe(self, UCXXTag tag, UCXXTagMask tag_mask = UCXXTagMaskFull) -> bool:
cdef Tag cpp_tag = <Tag><size_t>tag.value
cdef TagMask cpp_tag_mask = <TagMask><size_t>tag_mask.value
cdef ucp_tag_recv_info_t empty_tag_recv_info
cdef pair[bint, TagRecvInfo]* probed
cdef bint tag_matched = False

with nogil:
tag_matched = self._worker.get().tagProbe(cpp_tag)
# TagRecvInfo is not default-constructible, therefore we need to use a
# pointer, allocating it using a temporary ucp_tag_recv_info_t object
probed = new pair[bint, TagRecvInfo](
False,
TagRecvInfo(empty_tag_recv_info)
)
probed[0] = self._worker.get().tagProbe(cpp_tag, cpp_tag_mask)
tag_matched = probed[0].first
del probed

# TODO: Come up with good interface to expose TagRecvInfo as well
return tag_matched

def set_progress_thread_start_callback(
Expand Down
12 changes: 9 additions & 3 deletions python/ucxx/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from libcpp cimport bool as cpp_bool
from libcpp.functional cimport function
from libcpp.memory cimport shared_ptr, unique_ptr
from libcpp.optional cimport nullopt_t, optional
from libcpp.pair cimport pair
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map as cpp_unordered_map
from libcpp.vector cimport vector
Expand Down Expand Up @@ -54,6 +55,9 @@ cdef extern from "ucp/api/ucp.h" nogil:

ctypedef uint64_t ucp_tag_t

ctypedef struct ucp_tag_recv_info_t:
pass

ctypedef enum ucs_status_t:
pass

Expand Down Expand Up @@ -174,10 +178,12 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
pass
cdef enum TagMask:
pass
cdef cppclass TagRecvInfo:
TagRecvInfo(const ucp_tag_recv_info_t&)
Tag senderTag
size_t length
cdef cppclass AmReceiverCallbackInfo:
pass
# ctypedef Tag CppTag
# ctypedef TagMask CppTagMask

# Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations.
# The workaround is to use a raw C function pointer and let it be parsed by the
Expand Down Expand Up @@ -241,7 +247,7 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
size_t cancelInflightRequests(
uint64_t period, uint64_t maxAttempts
) except +raise_py_error
bint tagProbe(const Tag) const
pair[bint, TagRecvInfo] tagProbe(const Tag, const TagMask) const
void setProgressThreadStartCallback(
function[void(void*)] callback, void* callbackArg
)
Expand Down

0 comments on commit 93edd75

Please sign in to comment.