Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Active Message receiver callbacks #186

Merged
merged 14 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,26 +325,37 @@ class Endpoint : public Component {
* the status of the transfer must be verified from the resulting request object before
* the data can be released.
*
* An optional `receiverCallbackInfo` may be specified, in which case the remote worker
* must have registered a callback with the same `receiverCallbackInfo` so that the
* callback can be executed when the active message is received. When this is
* specified, `amRecv()` will _NOT_ match this message, which is instead handled by the
* remote worker's callback.
*
* Using a Python future may be requested by specifying `enablePythonFuture`. If a
* Python future is requested, the Python application must then await on this future to
* ensure the transfer has completed. Requires UCXX Python support.
*
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the tag message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the active message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] receiverCallbackInfo the owner name and unique identifier of the receiver
* callback.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon
* completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> amSend(void* buffer,
size_t length,
ucs_memory_type_t memoryType,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);
std::shared_ptr<Request> amSend(
void* buffer,
const size_t length,
const ucs_memory_type_t memoryType,
const std::optional<AmReceiverCallbackInfo> receiverCallbackInfo = std::nullopt,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue an active message receive operation.
Expand Down
24 changes: 17 additions & 7 deletions cpp/include/ucxx/internal/request_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,18 @@ class RecvAmMessage {
*
* Construct the object, setting attributes that are later needed by the callback.
*
* @param[in] amData active messages worker data.
* @param[in] ep handle containing address of the reply endpoint (i.e., endpoint
* where user is requesting to receive).
* @param[in] request request to be later notified/delivered to user.
* @param[in] buffer buffer containing the received data
* @param[in] amData active messages worker data.
* @param[in] ep handle containing address of the reply endpoint (i.e.,
* endpoint where user is requesting to receive).
* @param[in] request request to be later notified/delivered to user.
* @param[in] buffer buffer containing the received data
* @param[in] receiverCallback receiver callback to execute when request completes.
*/
RecvAmMessage(internal::AmData* amData,
ucp_ep_h ep,
std::shared_ptr<RequestAm> request,
std::shared_ptr<Buffer> buffer);
std::shared_ptr<Buffer> buffer,
AmReceiverCallbackType receiverCallback = AmReceiverCallbackType());

/**
* @brief Set the UCP request.
Expand All @@ -86,6 +88,11 @@ class RecvAmMessage {
typedef std::unordered_map<ucp_ep_h, std::queue<std::shared_ptr<RequestAm>>> AmPoolType;
typedef std::unordered_map<RequestAm*, std::shared_ptr<RecvAmMessage>> RecvAmMessageMapType;

typedef std::unordered_map<AmReceiverCallbackIdType, AmReceiverCallbackType>
AmReceiverCallbackMapType;
typedef std::unordered_map<AmReceiverCallbackOwnerType, AmReceiverCallbackMapType>
AmReceiverCallbackOwnerMapType;

/**
* @brief Active Message data owned by a `ucxx::Worker`.
*
Expand All @@ -101,7 +108,10 @@ class AmData {
AmPoolType _recvWait{}; ///< The pool of user receive requests (waiting for message arrival)
RecvAmMessageMapType
_recvAmMessageMap{}; ///< The active messages waiting to be handled by callback
std::mutex _mutex{}; ///< Mutex to provide access to pools/maps
AmReceiverCallbackOwnerMapType
_receiverCallbacks{}; ///< Receiver callbacks to handle specialized Active Messages without a
///< pool.
std::mutex _mutex{}; ///< Mutex to provide access to pools/maps
std::function<void(std::shared_ptr<Request>)>
_registerInflightRequest{}; ///< Worker function to register inflight requests with
std::unordered_map<ucs_memory_type_t, AmAllocatorType>
Expand Down
14 changes: 10 additions & 4 deletions cpp/include/ucxx/request_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <memory>
#include <optional>
#include <variant>
#include <vector>

Expand All @@ -29,19 +30,24 @@ class AmSend {
const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored.
const size_t _length{0}; ///< The length of the message.
const ucs_memory_type_t _memoryType{UCS_MEMORY_TYPE_HOST}; ///< Memory type used on the operation
const std::optional<AmReceiverCallbackInfo> _receiverCallbackInfo{
std::nullopt}; ///< Owner name and unique identifier of the receiver callback.

/**
* @brief Constructor for Active Message-specific send data.
*
* Construct an object containing Active Message-specific send data.
*
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] receiverCallbackInfo the owner name and unique identifier of the receiver
* callback.
*/
explicit AmSend(const decltype(_buffer) buffer,
const decltype(_length) length,
const decltype(_memoryType) memoryType = UCS_MEMORY_TYPE_HOST);
const decltype(_memoryType) memoryType = UCS_MEMORY_TYPE_HOST,
const decltype(_receiverCallbackInfo) receiverCallbackInfo = std::nullopt);

AmSend() = delete;
};
Expand Down
41 changes: 41 additions & 0 deletions cpp/include/ucxx/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace ucxx {

class Buffer;
class Request;
class RequestAm;

/**
* @brief Available logging levels.
Expand Down Expand Up @@ -119,6 +120,46 @@ typedef RequestCallbackUserData EndpointCloseCallbackUserData;
*/
typedef std::function<std::shared_ptr<Buffer>(size_t)> AmAllocatorType;

/**
* @brief Active Message receiver callback.
*
* Type for a custom Active Message receiver callback, executed by the remote worker upon
* Active Message request completion.
*/
typedef std::function<void(std::shared_ptr<Request>)> AmReceiverCallbackType;

/**
* @brief Active Message receiver callback owner name.
*
* A string containing the owner's name of an Active Message receiver callback. The owner
* should be a reasonably unique name, usually identifying the application, to allow other
* applications to coexist and register their own receiver callbacks.
*/
typedef std::string AmReceiverCallbackOwnerType;

/**
* @brief Active Message receiver callback identifier.
*
* A 64-bit unsigned integer unique identifier type of an Active Message receiver callback.
*/
typedef uint64_t AmReceiverCallbackIdType;

typedef const std::string AmReceiverCallbackInfoSerialized;

/**
* @brief Information of an Active Message receiver callback.
*
* Type identifying an Active Message receiver callback's owner name and unique identifier.
*/
class AmReceiverCallbackInfo {
public:
const AmReceiverCallbackOwnerType owner;  ///< Owner name of the receiver callback.
const AmReceiverCallbackIdType id;  ///< Unique identifier of the receiver callback.

// No default constructor: an owner name and an identifier must always be provided.
AmReceiverCallbackInfo() = delete;

/**
 * @brief Constructor for the Active Message receiver callback information.
 *
 * @param[in] owner the owner name of the receiver callback.
 * @param[in] id    the unique identifier of the receiver callback.
 */
AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id);
};

typedef const std::string SerializedRemoteKey;

} // namespace ucxx
40 changes: 40 additions & 0 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <ucxx/future.h>
#include <ucxx/inflight_requests.h>
#include <ucxx/notifier.h>
#include <ucxx/typedefs.h>
#include <ucxx/worker_progress_thread.h>

namespace ucxx {
Expand Down Expand Up @@ -799,6 +800,45 @@ class Worker : public Component {
*/
void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator);

/**
* @brief Register receiver callback for active messages.
*
* Register a new receiver callback for active messages. By default, active messages do
* not execute any callbacks on the receiving end unless one is specified when sending
* the message. If the message sender specifies a callback receiver identifier then the
* remote receiver needs to have a callback registered with the same identifier to
* execute when the request completes. To ensure multiple applications that do not know
* about each other can have coexisting callbacks where receiver identifiers may have
* the same value, an owner must be specified as well, which has the form of a string and
* should be reasonably unique to prevent accidentally calling callbacks from a separate
* application, thus names like "A" or "UCX" are discouraged in favor of more descriptive
* names such as "MyFastCommsProject", and the name "ucxx" is reserved.
wence- marked this conversation as resolved.
Show resolved Hide resolved
*
* Because it would be impossible to predict which callback would be called if two were
* registered with the same owner and identifier, a registered callback cannot be
* changed, thus calling this method with an owner and identifier that are already
* registered will throw `std::runtime_error`.
*
*
* @code{.cpp}
* // `worker` is `std::shared_ptr<ucxx::Worker>`
* auto callback = [](std::shared_ptr<ucxx::Request> req) {
* std::cout << "The UCXX request address is " << (void*)req.get() << std::endl;
* };
*
* worker->registerAmReceiverCallback({"MyFastApp", 0}, callback);
* @endcode
*
* @throws std::runtime_error if a callback with same given owner and identifier is
* already registered, or if the reserved owner name "ucxx"
* is specified.
*
* @param[in] info the owner name and unique identifier of the receiver
* callback.
* @param[in] callback the callback to execute when the active message is
* received.
*/
void registerAmReceiverCallback(AmReceiverCallbackInfo info, AmReceiverCallbackType callback);

/**
* @brief Check for uncaught active messages.
*
Expand Down
25 changes: 14 additions & 11 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,19 +365,22 @@ size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAtt

size_t Endpoint::getCancelingSize() const { return _inflightRequests->getCancelingSize(); }

std::shared_ptr<Request> Endpoint::amSend(void* buffer,
size_t length,
ucs_memory_type_t memoryType,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
std::shared_ptr<Request> Endpoint::amSend(
void* buffer,
const size_t length,
const ucs_memory_type_t memoryType,
const std::optional<AmReceiverCallbackInfo> receiverCallbackInfo,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestAm(endpoint,
data::AmSend(buffer, length, memoryType),
enablePythonFuture,
callbackFunction,
callbackData));
return registerInflightRequest(
createRequestAm(endpoint,
data::AmSend(buffer, length, memoryType, receiverCallbackInfo),
enablePythonFuture,
callbackFunction,
callbackData));
}

std::shared_ptr<Request> Endpoint::amRecv(const bool enablePythonFuture,
Expand Down
10 changes: 9 additions & 1 deletion cpp/src/internal/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ucxx/delayed_submission.h>
#include <ucxx/internal/request_am.h>
#include <ucxx/request_am.h>
#include <ucxx/typedefs.h>

namespace ucxx {

Expand All @@ -14,14 +15,21 @@ namespace internal {
RecvAmMessage::RecvAmMessage(internal::AmData* amData,
ucp_ep_h ep,
std::shared_ptr<RequestAm> request,
std::shared_ptr<Buffer> buffer)
std::shared_ptr<Buffer> buffer,
AmReceiverCallbackType receiverCallback)
: _amData(amData), _ep(ep), _request(request)
{
std::visit(data::dispatch{
[this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; },
[](auto) { throw std::runtime_error("Unreachable"); },
},
_request->_requestData);

if (receiverCallback) {
_request->_callback = [this, receiverCallback](ucs_status_t, std::shared_ptr<void>) {
receiverCallback(_request);
};
}
}

void RecvAmMessage::setUcpRequest(void* request) { _request->_request = request; }
Expand Down
Loading
Loading