Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute user's callback after setting request status #32

Merged
merged 10 commits into from
May 2, 2023
4 changes: 3 additions & 1 deletion build_and_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ run_py_benchmark() {
}

if [[ $RUN_CPP_TESTS != 0 ]]; then
${BINARY_PATH}/gtests/libucxx/UCXX_TEST
# UCX_TCP_CM_REUSEADDR=y to be able to bind immediately to the same port before
# `TIME_WAIT` timeout
UCX_TCP_CM_REUSEADDR=y ${BINARY_PATH}/gtests/libucxx/UCXX_TEST
fi
if [[ $RUN_CPP_BENCH != 0 ]]; then
# run_cpp_benchmark PROGRESS_MODE
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,6 @@ void Request::process()
status,
ucs_status_string(status));

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);

if (status != UCS_OK) {
ucxx_error(
"error on %s with status %d (%s)", _operationName.c_str(), status, ucs_status_string(status));
Expand All @@ -154,6 +147,13 @@ void Request::process()
}

setStatus(status);

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);
wence- marked this conversation as resolved.
Show resolved Hide resolved
}

void Request::setStatus(ucs_status_t status)
Expand Down
43 changes: 43 additions & 0 deletions cpp/tests/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,49 @@ TEST_P(RequestTest, ProgressTagMulti)
ASSERT_THAT(_recv[i], ContainerEq(_send[i]));
}

TEST_P(RequestTest, TagUserCallback)
{
allocate();

std::vector<std::shared_ptr<ucxx::Request>> requests(2);
std::vector<ucs_status_t> requestStatus(2, UCS_INPROGRESS);

auto checkStatus = [&requests, &requestStatus](std::shared_ptr<void> data) {
auto idx = *std::static_pointer_cast<size_t>(data);
if (requests[idx] == nullptr) {
/**
* Unfortunately, we can't check the status this way if the request completes
* immediately, as `request[idx]` will only be assigned after completion, and thus,
* after the callback has been executed.
*
* TODO: find a better way to test this.
*/
requestStatus[idx] = UCS_OK;
} else {
requestStatus[idx] = requests[idx]->getStatus();
}
};

auto sendIndex = std::make_shared<size_t>(0u);
auto recvIndex = std::make_shared<size_t>(1u);

// Submit and wait for transfers to complete
requests[0] = _ep->tagSend(_sendPtr[0], _messageSize, 0, false, checkStatus, sendIndex);
requests[1] = _ep->tagRecv(_recvPtr[0], _messageSize, 0, false, checkStatus, recvIndex);
waitRequests(_worker, requests, _progressWorker);

copyResults();

// Assert status was set before user callback is executed
for (const auto status : requestStatus)
ASSERT_THAT(status, UCS_OK);
for (const auto request : requests)
ASSERT_THAT(request->getStatus(), UCS_OK);
wence- marked this conversation as resolved.
Show resolved Hide resolved

// Assert data correctness
ASSERT_THAT(_recv[0], ContainerEq(_send[0]));
}

INSTANTIATE_TEST_SUITE_P(ProgressModes,
RequestTest,
Combine(Values(ucxx::BufferType::Host),
Expand Down