Skip to content

Commit

Permalink
Use std::variant and std::visit to handle request data
Browse files Browse the repository at this point in the history
To account for stronger typing, this change removes the `DelayedSubmission`
class in favor of new request-specific data types. The types are then
combined in a `std::variant`, and `std::visit` is used to select the active
type and gather/process its request-specific data.
  • Loading branch information
pentschev committed Dec 8, 2023
1 parent 1537e1e commit e80759b
Show file tree
Hide file tree
Showing 20 changed files with 904 additions and 806 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ add_library(
src/log.cpp
src/request.cpp
src/request_am.cpp
src/request_data.cpp
src/request_helper.cpp
src/request_stream.cpp
src/request_tag.cpp
Expand Down
42 changes: 11 additions & 31 deletions cpp/include/ucxx/constructors.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>
#include <vector>

#include <ucxx/request_data.h>
#include <ucxx/typedefs.h>

namespace ucxx {
Expand Down Expand Up @@ -55,45 +56,24 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
const bool enableFuture);

// Transfers
//
// Factory declarations for transfer requests. Each function returns a `std::shared_ptr`
// to the request class named in its return type; only the declarations are visible here.

// Creates a send Active Message request for `buffer` (`length` bytes, located in memory
// of type `memoryType`), optionally notifying a Python future and a user callback.
std::shared_ptr<RequestAm> createRequestAmSend(std::shared_ptr<Endpoint> endpoint,
void* buffer,
size_t length,
ucs_memory_type_t memoryType,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

// Creates a receive Active Message request; no buffer is passed in by the caller.
std::shared_ptr<RequestAm> createRequestAmRecv(std::shared_ptr<Endpoint> endpoint,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);
// Creates an Active Message request from `requestData`, which carries the message type
// and its type-specific data.
std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
const data::RequestData requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

// Returns a `RequestStream`; presumably a stream transfer request -- confirm against the
// `RequestStream` documentation.
std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> endpoint,
TransferDirection transferDirection,
void* buffer,
size_t length,
const data::RequestData requestData,
const bool enablePythonFuture);

// Returns a `RequestTag`; the parent is passed as a generic `Component`
// (`endpointOrWorker`). Presumably a tag transfer request -- confirm against the
// `RequestTag` documentation.
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
TransferDirection transferDirection,
void* buffer,
size_t length,
Tag tag,
TagMask tagMask,
const data::RequestData requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

// Returns a `RequestTagMulti` built from parallel vectors (`buffer`, `size`, `isCUDA`);
// NOTE(review): the vectors are presumably expected to have equal length -- confirm
// against the definition.
std::shared_ptr<RequestTagMulti> createRequestTagMultiSend(std::shared_ptr<Endpoint> endpoint,
const std::vector<void*>& buffer,
const std::vector<size_t>& size,
const std::vector<int>& isCUDA,
const Tag tag,
const bool enablePythonFuture);

// Returns a `RequestTagMulti` for the given `tag` and `tagMask`; presumably the receive
// counterpart of `createRequestTagMultiSend` -- confirm against the definition.
std::shared_ptr<RequestTagMulti> createRequestTagMultiRecv(std::shared_ptr<Endpoint> endpoint,
const Tag tag,
const TagMask tagMask,
const bool enablePythonFuture);
// Returns a `RequestTagMulti` described by `requestData`.
std::shared_ptr<RequestTagMulti> createRequestTagMulti(std::shared_ptr<Endpoint> endpoint,
const data::RequestData requestData,
const bool enablePythonFuture);

} // namespace ucxx
123 changes: 1 addition & 122 deletions cpp/include/ucxx/delayed_submission.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,133 +18,12 @@
#include <ucs/memory/memory_type.h>

#include <ucxx/log.h>
#include <ucxx/request_data.h>

namespace ucxx {

/// Signature of a delayed-submission callback: a callable that takes no arguments and
/// returns nothing.
using DelayedSubmissionCallbackType = std::function<void()>;

/// Kind of operation a delayed submission refers to. `Undefined` is pinned to 0 and the
/// remaining enumerators take the consecutive values that follow it.
enum class DelayedSubmissionOperationType {
  Undefined = 0,
  Am,
  Stream,
  Tag,
  TagMulti,
};

/**
* @brief Active Message-specific data for `DelayedSubmission`.
*
* Holds a single immutable member: the memory type used on the operation.
*/
class DelayedSubmissionAm {
public:
const ucs_memory_type_t _memoryType; ///< Memory type used on the operation

/**
* @brief Constructor for `DelayedSubmission` Active Message-specific data.
*
* Construct an object containing Active Message-specific data for `DelayedSubmission`.
*
* @param[in] memoryType the memory type of the buffer.
*/
explicit DelayedSubmissionAm(const decltype(_memoryType) memoryType);
};

/**
* @brief Tag/multi-buffer tag-specific data for `DelayedSubmission`.
*
* Holds the tag to match and an optional tag mask; both members are immutable.
*/
class DelayedSubmissionTag {
public:
const Tag _tag; ///< Tag to match
const std::optional<TagMask> _tagMask; ///< Tag mask to use

/**
* @brief Constructor for `DelayedSubmission` tag/multi-buffer tag-specific data.
*
* Construct an object containing tag/multi-buffer tag-specific data for
* `DelayedSubmission`.
*
* @param[in] tag the tag to match.
* @param[in] tagMask the tag mask to use (only used for receive operations).
*/
explicit DelayedSubmissionTag(const decltype(_tag) tag, const decltype(_tagMask) tagMask);
};

/**
* @brief Operation type, transfer direction and operation-specific data for
* `DelayedSubmission`.
*
* The operation-specific payload is a `std::variant` that is either empty
* (`std::monostate`), a `DelayedSubmissionAm` or a `DelayedSubmissionTag`.
*/
class DelayedSubmissionData {
public:
const DelayedSubmissionOperationType _operationType{
DelayedSubmissionOperationType::Undefined}; ///< The operation type
const TransferDirection _transferDirection{}; ///< The direction of the transfer.
const std::variant<std::monostate, DelayedSubmissionAm, DelayedSubmissionTag>
_data; ///< Data used on the operation

/**
* @brief Constructor for `DelayedSubmission` operation-specific data.
*
* Construct an object containing operation-specific data for `DelayedSubmission`, which
* may also vary depending on the direction of the transfer.
*
* @param[in] operationType the type of operation the object refers to.
* @param[in] transferDirection the direction of the transfer.
* @param[in] data data for the delayed submission, required for Active
* Message, tag and multi-buffer tag, or `std::monostate`
* otherwise.
*/
explicit DelayedSubmissionData(const decltype(_operationType) operationType,
const decltype(_transferDirection) transferDirection,
const decltype(_data) data);

/**
* @brief Get the Active Message data.
*
* Get the Active Message data if the object was constructed from it, or throws
* `std::bad_variant_access` if constructed with a different data type.
*
* @throws std::bad_variant_access if the object was constructed for a type other than
* Active Message.
*
* @returns the Active Message data.
*/
DelayedSubmissionAm getAm();

/**
* @brief Get the tag or multi-buffer tag data.
*
* Get the tag or multi-buffer tag data if the object was constructed from it, or throws
* `std::bad_variant_access` if constructed with a different data type.
*
* @throws std::bad_variant_access if the object was constructed for a type other than
* tag or multi-buffer tag.
*
* @returns the tag or multi-buffer tag data.
*/
DelayedSubmissionTag getTag();
};

/**
* @brief Description of a transfer operation whose submission is delayed.
*
* Groups the transfer direction, the raw buffer pointer and its length, plus the
* operation type and operation-specific data. Default construction is disabled.
*/
class DelayedSubmission {
public:
TransferDirection _transferDirection{}; ///< The direction of transfer.
void* _buffer{nullptr}; ///< Raw pointer to data buffer
size_t _length{0}; ///< Length of the message in bytes
DelayedSubmissionData _data; ///< Operation type and operation-specific data

DelayedSubmission() = delete;

/**
* @brief Constructor for a delayed submission operation.
*
* Construct a delayed submission operation. Delayed submission means that a transfer
* operation will not be submitted immediately, but will rather be delayed for the next
* progress iteration.
*
* This may be useful to avoid any transfer operations to be executed directly in the
* application thread, delaying all of them for the worker progress thread when enabled.
* With this approach any perceived overhead will be removed from the application thread,
* and thus provide some speedup in certain situations. It may be also useful to prevent
* a multi-threaded application from blocking while waiting for the UCX spinlock, since
* all transfer operations may be pushed to the worker progress thread.
*
* The fourth argument is unnamed in this declaration; it is a `DelayedSubmissionData`
* carrying the operation type and the operation-specific data (Active Message memory
* type, or tag and optional tag mask). NOTE(review): presumably stored in `_data` --
* confirm against the definition.
*
* @param[in] transferDirection the direction of transfer.
* @param[in] buffer a raw pointer to the data being transferred.
* @param[in] length the size in bytes of the message being transferred.
*/
DelayedSubmission(const TransferDirection transferDirection,
void* buffer,
const size_t length,
const DelayedSubmissionData);
};

template <typename T>
class BaseDelayedSubmissionCollection {
protected:
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ucxx/component.h>
#include <ucxx/endpoint.h>
#include <ucxx/future.h>
#include <ucxx/request_data.h>
#include <ucxx/typedefs.h>

#define ucxx_trace_req_f(_owner, _req, _name, _message, ...) \
Expand All @@ -34,9 +35,8 @@ class Request : public Component {
std::shared_ptr<Endpoint> _endpoint{
nullptr}; ///< Endpoint that generated request (if not from worker)
std::string _ownerString{
"undetermined owner"}; ///< String to print owner (endpoint or worker) when logging
std::shared_ptr<DelayedSubmission> _delayedSubmission{
nullptr}; ///< The submission object that will dispatch the request
"undetermined owner"}; ///< String to print owner (endpoint or worker) when logging
data::RequestData _requestData{}; ///< The operation-specific data to be used in the request
std::string _operationName{
"request_undefined"}; ///< Human-readable operation name, mostly used for log messages
std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set
Expand All @@ -62,7 +62,7 @@ class Request : public Component {
* subsequently notified.
*/
Request(std::shared_ptr<Component> endpointOrWorker,
std::shared_ptr<DelayedSubmission> delayedSubmission,
const data::RequestData requestData,
const std::string operationName,
const bool enablePythonFuture = false);

Expand Down
106 changes: 24 additions & 82 deletions cpp/include/ucxx/request_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
#pragma once
#include <memory>
#include <string>
#include <utility>

#include <ucp/api/ucp.h>
Expand All @@ -24,128 +25,69 @@ class RequestAm : public Request {
private:
friend class internal::RecvAmMessage;

ucs_memory_type_t _sendHeader{}; ///< The header to send
std::shared_ptr<Buffer> _buffer{nullptr}; ///< The AM received message buffer

/**
* @brief Private constructor of `ucxx::RequestAm` send.
* @brief Private constructor of `ucxx::RequestAm`.
*
* This is the internal implementation of `ucxx::RequestAm` send constructor, made private
* This is the internal implementation of `ucxx::RequestAm` constructor, made private
* not to be called directly. This constructor is made private to ensure all UCXX objects
* are shared pointers and the correct lifetime management of each one.
*
* Instead the user should use one of the following:
*
* - `ucxx::Endpoint::amSend()`
* - `ucxx::createRequestAmSend()`
* - `ucxx::Endpoint::amReceive()`
* - `ucxx::createRequestAmReceive()`
*
* @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpoint the parent endpoint.
* @param[in] buffer a raw pointer to the data to be sent.
* @param[in] length the size in bytes of the active message to be sent.
* @param[in] memoryType the memory type of the buffer.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*/
RequestAm(std::shared_ptr<Endpoint> endpoint,
void* buffer,
size_t length,
ucs_memory_type_t memoryType,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Private constructor of `ucxx::RequestAm` receive.
*
* This is the internal implementation of `ucxx::RequestAm` receive constructor, made
* private not to be called directly. This constructor is made private to ensure all UCXX
* objects are shared pointers and the correct lifetime management of each one.
*
* Instead the user should use one of the following:
*
* - `ucxx::Endpoint::amRecv()`
* - `ucxx::createRequestAmRecv()`
*
* @throws ucxx::Error if `endpointOrWorker` is not a valid
* `std::shared_ptr<ucxx::Endpoint>` or
* `std::shared_ptr<ucxx::Worker>`.
*
* @param[in] endpointOrWorker the parent component, which may either be a
* `std::shared_ptr<Endpoint>` or
* `std::shared_ptr<Worker>`.
* @param[in] requestData container of the specified message type, including all
* type-specific data.
* @param[in] operationName a human-readable operation name to help identifying
* requests by their types when UCXX logging is enabled.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*/
RequestAm(std::shared_ptr<Component> endpointOrWorker,
const data::RequestData requestData,
const std::string operationName,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

public:
/**
* @brief Constructor for `std::shared_ptr<ucxx::RequestAm>` send.
* @brief Constructor for `std::shared_ptr<ucxx::RequestAm>`.
*
* The constructor for a `std::shared_ptr<ucxx::RequestAm>` object, creating a send active
* The constructor for a `std::shared_ptr<ucxx::RequestAm>` object, creating an active
* message request, returning a pointer to a request object that can be later awaited and
* checked for errors. This is a non-blocking operation, and the status of the transfer
* must be verified from the resulting request object before the data can be
* released.
* must be verified from the resulting request object before the data can be released if
* this is a send operation, or consumed if this is a receive operation. Received data is
* available via the `getRecvBuffer()` method if the receive transfer request completed
* successfully.
*
* @throws ucxx::Error if `endpoint` is not a valid
* `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpoint the parent endpoint.
* @param[in] buffer a raw pointer to the data to be transferred.
* @param[in] length the size in bytes of the active message to be transferred.
* @param[in] memoryType the memory type of the buffer.
* @param[in] requestData container of the specified message type, including all
* type-specific data.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns The `shared_ptr<ucxx::RequestAm>` object
*/
friend std::shared_ptr<RequestAm> createRequestAmSend(
std::shared_ptr<Endpoint> endpoint,
void* buffer,
size_t length,
ucs_memory_type_t memoryType,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

/**
* @brief Constructor for `std::shared_ptr<ucxx::RequestAm>` receive.
*
* The constructor for a `std::shared_ptr<ucxx::RequestAm>` object, creating a receive
* active message request, returning a pointer to a request object that can be later
* awaited and checked for errors. This is a non-blocking operation, and the status of
* the transfer must be verified from the resulting request object before the data can be
* consumed, the data is available via the `getRecvBuffer()` method if the transfer
* completed successfully.
*
* @throws ucxx::Error if `endpoint` is not a valid
* `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpoint the parent endpoint.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns The `shared_ptr<ucxx::RequestAm>` object
*/
friend std::shared_ptr<RequestAm> createRequestAmRecv(
std::shared_ptr<Endpoint> endpoint,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);
friend std::shared_ptr<RequestAm> createRequestAm(std::shared_ptr<Endpoint> endpoint,
const data::RequestData requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

virtual void populateDelayedSubmission();

Expand Down
Loading

0 comments on commit e80759b

Please sign in to comment.