Skip to content

Commit

Permalink
Add specialized send/receive tag classes for DelayedSubmission
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Dec 7, 2023
1 parent 1537e1e commit e0180aa
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 60 deletions.
90 changes: 82 additions & 8 deletions cpp/include/ucxx/delayed_submission.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@ namespace ucxx {

typedef std::function<void()> DelayedSubmissionCallbackType;

enum class DelayedSubmissionOperationType { Undefined = 0, Am, Stream, Tag, TagMulti };
enum class DelayedSubmissionOperationType {
Undefined = 0,
Am,
Stream,
Tag,
TagSend,
TagReceive,
TagMulti,
TagMultiSend,
TagMultiReceive
};

class DelayedSubmissionAm {
public:
Expand Down Expand Up @@ -56,12 +66,61 @@ class DelayedSubmissionTag {
explicit DelayedSubmissionTag(const decltype(_tag) tag, const decltype(_tagMask) tagMask);
};

class DelayedSubmissionTagSend {
public:
const void* _buffer; ///< The raw pointer where data to be sent is stored.
const size_t _length; ///< The length of the message.
const Tag _tag; ///< Tag to match

/**
* @brief Constructor for `DelayedSubmission` tag/multi-buffer tag-specific data.
*
* Construct an object containing tag/multi-buffer tag-specific data for
* `DelayedSubmission`.
*
* @param[in] buffer a raw pointer to the data to be transferred.
* @param[in] length the size in bytes of the tag message to be transferred.
* @param[in] tag the tag to match.
*/
explicit DelayedSubmissionTagSend(const decltype(_buffer) buffer,
const decltype(_length) length,
const decltype(_tag) tag);
};

class DelayedSubmissionTagReceive {
public:
void* _buffer; ///< The raw pointer where received data should be stored.
const size_t _length; ///< The length of the message.
const Tag _tag; ///< Tag to match
const TagMask _tagMask; ///< Tag mask to use

/**
* @brief Constructor for `DelayedSubmission` tag/multi-buffer tag-specific data.
*
* Construct an object containing tag/multi-buffer tag-specific data for
* `DelayedSubmission`.
*
* @param[in] buffer a raw pointer to the data to be transferred.
* @param[in] length the size in bytes of the tag message to be transferred.
* @param[in] tag the tag to match.
* @param[in] tagMask the tag mask to use (only used for receive operations).
*/
explicit DelayedSubmissionTagReceive(decltype(_buffer) buffer,
const decltype(_length) length,
const decltype(_tag) tag,
const decltype(_tagMask) tagMask);
};

class DelayedSubmissionData {
public:
const DelayedSubmissionOperationType _operationType{
DelayedSubmissionOperationType::Undefined}; ///< The operation type
const TransferDirection _transferDirection{}; ///< The direction of the transfer.
const std::variant<std::monostate, DelayedSubmissionAm, DelayedSubmissionTag>
const std::variant<std::monostate,
DelayedSubmissionAm,
DelayedSubmissionTagSend,
DelayedSubmissionTagReceive,
DelayedSubmissionTag>
_data; ///< Data used on the operation

/**
Expand Down Expand Up @@ -94,16 +153,31 @@ class DelayedSubmissionData {
DelayedSubmissionAm getAm();

/**
* @brief Get the tag or multi-buffer tag data.
* @brief Get the send tag or multi-buffer tag data.
*
* Get the tag or multi-buffer tag data if the object was constructed from it, or throws
* `std::bad_variant` if constructed with different data type.
* Get the tag or multi-buffer tag send data if the object was constructed from it, or
* throws * `std::bad_variant` if constructed with different data type.
*
* @throws std::bad_variant if the object was constructed for a type other than send tag
* or multi-buffer tag.
*
* @returns the send tag or multi-buffer tag data.
*/
DelayedSubmissionTagSend getTagSend();

/**
* @brief Get the receive tag or multi-buffer tag data.
*
* @throws std::bad_variant if the object was constructed for a type other than tag or
* multi-buffer tag.
* Get the tag or multi-buffer tag receive data if the object was constructed from it, or
* throws * `std::bad_variant` if constructed with different data type.
*
* @returns the tag or multi-buffer tag data.
* @throws std::bad_variant if the object was constructed for a type other than receive
* tag or multi-buffer tag.
*
* @returns the receive tag or multi-buffer tag data.
*/
DelayedSubmissionTagReceive getTagReceive();

DelayedSubmissionTag getTag();
};

Expand Down
57 changes: 42 additions & 15 deletions cpp/src/delayed_submission.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/**:
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
Expand All @@ -22,10 +22,28 @@ DelayedSubmissionTag::DelayedSubmissionTag(const Tag tag, const std::optional<Ta
{
}

DelayedSubmissionData::DelayedSubmissionData(
const DelayedSubmissionOperationType operationType,
const TransferDirection transferDirection,
const std::variant<std::monostate, DelayedSubmissionAm, DelayedSubmissionTag> data)
DelayedSubmissionTagSend::DelayedSubmissionTagSend(const void* buffer,
const size_t length,
const Tag tag)
: _buffer(buffer), _length(length), _tag(tag)
{
}

DelayedSubmissionTagReceive::DelayedSubmissionTagReceive(void* buffer,
const size_t length,
const Tag tag,
const TagMask tagMask)
: _buffer(buffer), _length(length), _tag(tag), _tagMask(tagMask)
{
}

DelayedSubmissionData::DelayedSubmissionData(const DelayedSubmissionOperationType operationType,
const TransferDirection transferDirection,
const std::variant<std::monostate,
DelayedSubmissionAm,
DelayedSubmissionTagSend,
DelayedSubmissionTagReceive,
DelayedSubmissionTag> data)
: _operationType(operationType), _transferDirection(transferDirection), _data(data)
{
if (_operationType == DelayedSubmissionOperationType::Am) {
Expand All @@ -37,17 +55,16 @@ DelayedSubmissionData::DelayedSubmissionData(
!std::holds_alternative<std::monostate>(data))
throw std::runtime_error(
"Receive Am operations do not support data value other than `std::monostate`.");
} else if (_operationType == DelayedSubmissionOperationType::Tag ||
_operationType == DelayedSubmissionOperationType::TagMulti) {
if (!std::holds_alternative<DelayedSubmissionTag>(data))
} else if (_operationType == DelayedSubmissionOperationType::TagSend ||
_operationType == DelayedSubmissionOperationType::TagMultiSend) {
if (!std::holds_alternative<DelayedSubmissionTagSend>(data))
throw std::runtime_error(
"Operations Tag and TagMulti require data to be of type `DelayedSubmissionTag`.");
if (transferDirection == TransferDirection::Send &&
std::get<DelayedSubmissionTag>(data)._tagMask)
throw std::runtime_error("Send Tag and TagMulti operations do not take a tag mask.");
else if (transferDirection == TransferDirection::Receive &&
!std::get<DelayedSubmissionTag>(data)._tagMask)
throw std::runtime_error("Receive Tag and TagMulti operations require a tag mask.");
"Operations Tag and TagMulti require data to be of type `DelayedSubmissionTagSend`.");
} else if (_operationType == DelayedSubmissionOperationType::TagReceive ||
_operationType == DelayedSubmissionOperationType::TagMultiReceive) {
if (!std::holds_alternative<DelayedSubmissionTagReceive>(data))
throw std::runtime_error(
"Operations Tag and TagMulti require data to be of type `DelayedSubmissionTagReceive`.");
} else {
if (!std::holds_alternative<std::monostate>(data))
throw std::runtime_error("Type does not support data value other than `std::monostate`.");
Expand All @@ -56,6 +73,16 @@ DelayedSubmissionData::DelayedSubmissionData(

DelayedSubmissionAm DelayedSubmissionData::getAm() { return std::get<DelayedSubmissionAm>(_data); }

DelayedSubmissionTagSend DelayedSubmissionData::getTagSend()
{
return std::get<DelayedSubmissionTagSend>(_data);
}

DelayedSubmissionTagReceive DelayedSubmissionData::getTagReceive()
{
return std::get<DelayedSubmissionTagReceive>(_data);
}

DelayedSubmissionTag DelayedSubmissionData::getTag()
{
return std::get<DelayedSubmissionTag>(_data);
Expand Down
9 changes: 8 additions & 1 deletion cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,14 @@ void Request::setStatus(ucs_status_t status)
status,
ucs_status_string(status));

if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set");
if (_status != UCS_INPROGRESS)
ucxx_error(
"setStatus called on request: %p with status: %d (%s) but status: %d (%s) was already set",
this,
status,
ucs_status_string(status),
_status,
ucs_status_string(_status));
_status = status;

if (_enablePythonFuture) {
Expand Down
92 changes: 56 additions & 36 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
transferDirection,
buffer,
length,
DelayedSubmissionData(DelayedSubmissionOperationType::Tag,
transferDirection,
transferDirection == TransferDirection::Send
? DelayedSubmissionTag(tag, std::nullopt)
: DelayedSubmissionTag(tag, tagMask))),
transferDirection == TransferDirection::Send
? DelayedSubmissionData(DelayedSubmissionOperationType::Tag,
transferDirection,
DelayedSubmissionTagSend(buffer, length, tag))
: DelayedSubmissionData(DelayedSubmissionOperationType::Tag,
transferDirection,
DelayedSubmissionTagReceive(buffer, length, tag, tagMask))),
std::string(transferDirection == TransferDirection::Send ? "tagSend" : "tagRecv"),
enablePythonFuture),
_length(length)
Expand Down Expand Up @@ -113,25 +115,26 @@ void RequestTag::request()

if (_delayedSubmission->_transferDirection == TransferDirection::Send) {
param.cb.send = tagSendCallback;
request = ucp_tag_send_nbx(_endpoint->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_data.getTag()._tag,
&param);
auto tagSend = _delayedSubmission->_data.getTagSend();
request = ucp_tag_send_nbx(
_endpoint->getHandle(), tagSend._buffer, tagSend._length, tagSend._tag, &param);
} else {
param.cb.recv = tagRecvCallback;
request = ucp_tag_recv_nbx(_worker->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_data.getTag()._tag,
*_delayedSubmission->_data.getTag()._tagMask,
param.cb.recv = tagRecvCallback;
auto tagReceive = _delayedSubmission->_data.getTagReceive();
request = ucp_tag_recv_nbx(_worker->getHandle(),
tagReceive._buffer,
tagReceive._length,
tagReceive._tag,
tagReceive._tagMask,
&param);
}

std::lock_guard<std::recursive_mutex> lock(_mutex);
_request = request;
}

static void logPopulateDelayedSubmission() {}

void RequestTag::populateDelayedSubmission()
{
if (_delayedSubmission->_transferDirection == TransferDirection::Send &&
Expand All @@ -148,26 +151,43 @@ void RequestTag::populateDelayedSubmission()

request();

if (_enablePythonFuture)
ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"tag 0x%lx, tagMask: 0x%lx, buffer %p, size %lu, future %p, future handle %p, "
"populateDelayedSubmission",
_delayedSubmission->_data.getTag()._tag,
_delayedSubmission->_data.getTag()._tagMask,
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_future.get(),
_future->getHandle());
else
ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"tag 0x%lx, buffer %p, size %lu, populateDelayedSubmission",
_delayedSubmission->_data.getTag()._tag,
_delayedSubmission->_buffer,
_delayedSubmission->_length);
auto log = [this](const void* buffer, const size_t length, const Tag tag, const TagMask tagMask) {
if (_enablePythonFuture)
ucxx_trace_req_f(
_ownerString.c_str(),
_request,
_operationName.c_str(),
"buffer: %p, size: %lu, tag 0x%lx, tagMask: 0x%lx, future %p, future handle %p, "
"populateDelayedSubmission",
buffer,
length,
tag,
tagMask,
_future.get(),
_future->getHandle());
else
ucxx_trace_req_f(
_ownerString.c_str(),
_request,
_operationName.c_str(),
"buffer: %p, size: %lu, tag 0x%lx, tagMask: 0x%lx, populateDelayedSubmission",
buffer,
length,
tag,
tagMask);
};

try {
auto tagSend = _delayedSubmission->_data.getTagSend();
log(tagSend._buffer, tagSend._length, tagSend._tag, TagMaskFull);
} catch (const std::bad_variant_access& e) {
try {
auto tagReceive = _delayedSubmission->_data.getTagReceive();
log(tagReceive._buffer, tagReceive._length, tagReceive._tag, tagReceive._tagMask);
} catch (const std::bad_variant_access& e) {
ucxx_error("Impossible to get transfer data.");
}
}

process();
}
Expand Down

0 comments on commit e0180aa

Please sign in to comment.