Skip to content

Commit

Permalink
STAR-1275: Read the peers SRT version and the negotiated latency from…
Browse files Browse the repository at this point in the history
… the socket

Pass a reference to a ConnectionInformation struct back to user
when a new client connects and the called connects to a server.
  • Loading branch information
Per Moberg committed May 7, 2024
1 parent c7d91e9 commit 4b698c7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
23 changes: 21 additions & 2 deletions SRTNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ SRTNet::ClientConnectStatus SRTNet::clientConnectToServer() {
if (result != SRT_ERROR) {
mClientConnected = true;
if (connectedToServer) {
connectedToServer(mConnectionContext, mContext);
ConnectionInformation connectionInformation = getConnectionInformation(mContext);
connectedToServer(mConnectionContext, mContext, connectionInformation);
}
// Break for-loop on first successful connect call
break;
Expand Down Expand Up @@ -315,7 +316,9 @@ bool SRTNet::waitForSRTClient(bool singleClient) {
}

SRT_LOGGER(true, LOGG_NOTIFY, "Client connected: " << newSocketCandidate);
auto ctx = clientConnected(*reinterpret_cast<sockaddr*>(&theirAddr), newSocketCandidate, mConnectionContext);

ConnectionInformation connectionInformation = getConnectionInformation(newSocketCandidate);
auto ctx = clientConnected(*reinterpret_cast<sockaddr*>(&theirAddr), newSocketCandidate, mConnectionContext, connectionInformation);

if (!ctx) {
// No ctx in return from clientConnected callback means client was rejected by user.
Expand Down Expand Up @@ -948,3 +951,19 @@ uint16_t SRTNet::getLocallyBoundPort() const {

return 0;
}

SRTNet::ConnectionInformation SRTNet::getConnectionInformation(SRTSOCKET socket) {
uint8_t clientSrtVersion[4];
int clientSrtVersionSize = sizeof(clientSrtVersion);
srt_getsockflag(socket, SRTO_PEERVERSION, &clientSrtVersion, &clientSrtVersionSize);

int32_t negotiatedLatency = 0;
int negotiatedLatencySize = sizeof(negotiatedLatency);
srt_getsockflag(socket, SRTO_PEERLATENCY, &negotiatedLatency, &negotiatedLatencySize);

ConnectionInformation connectionInformation;
connectionInformation.mPeerSrtVersion = std::string(std::to_string((int32_t)clientSrtVersion[2]) + "." + std::to_string((int32_t)clientSrtVersion[1]) + "." + std::to_string((int32_t)clientSrtVersion[0]));
connectionInformation.mNegotiatedLatency = negotiatedLatency;

return std::move(connectionInformation);
}
11 changes: 9 additions & 2 deletions SRTNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class SRTNet {
std::any mObject;
};

struct ConnectionInformation {
std::string mPeerSrtVersion;
int32_t mNegotiatedLatency;
};

/**
*
* @brief Constructor that can set a log prefix which will be added to the start of all log messages from this
Expand Down Expand Up @@ -288,7 +293,8 @@ class SRTNet {
/// Callback handling connecting clients (only server mode)
std::function<std::shared_ptr<NetworkConnection>(struct sockaddr& sin,
SRTSOCKET newSocket,
std::shared_ptr<NetworkConnection>& ctx)>
std::shared_ptr<NetworkConnection>& ctx,
const ConnectionInformation& connectionInformation)>
clientConnected = nullptr;

/// Callback receiving data type vector
Expand All @@ -310,7 +316,7 @@ class SRTNet {
std::function<void(std::shared_ptr<NetworkConnection>& ctx, SRTSOCKET lSocket)> clientDisconnected = nullptr;

/// Callback called whenever the client gets connected to the server (client mode only)
std::function<void(std::shared_ptr<NetworkConnection>& ctx, SRTSOCKET lSocket)> connectedToServer = nullptr;
std::function<void(std::shared_ptr<NetworkConnection>& ctx, SRTSOCKET lSocket, const ConnectionInformation& connectionInformation)> connectedToServer = nullptr;

// delete copy and move constructors and assign operators
SRTNet(SRTNet const&) = delete; // Copy construct
Expand Down Expand Up @@ -443,4 +449,5 @@ class SRTNet {

const std::chrono::milliseconds kConnectionTimeout{1000};
const int64_t kEpollTimeoutMs{500};
static ConnectionInformation getConnectionInformation(SRTSOCKET socket);
};
9 changes: 6 additions & 3 deletions test/TestSrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class TestSRTFixture : public ::testing::Test {

// notice when client connects to server
mServer.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) {
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) {
{
std::lock_guard<std::mutex> lock(mConnectedMutex);
mConnected = true;
Expand Down Expand Up @@ -150,7 +151,8 @@ TEST(TestSrt, StartStop) {

// notice when client connects to server
server.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) {
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) {
{
std::lock_guard<std::mutex> lock(connectedMutex);
connected = true;
Expand Down Expand Up @@ -239,7 +241,8 @@ TEST(TestSrt, TestPsk) {

auto ctx = std::make_shared<SRTNet::NetworkConnection>();
server.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) { return ctx; };
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) { return ctx; };
ASSERT_TRUE(server.startServer("127.0.0.1", 8009, 16, 1000, 100, SRT_LIVE_MAX_PLSIZE, 5000, kValidPsk, false, ctx));
EXPECT_FALSE(client.startClient("127.0.0.1", 8009, 16, 1000, 100, ctx, SRT_LIVE_MAX_PLSIZE, false, 5000, kInvalidPsk))
<< "Expect to fail when using incorrect PSK";
Expand Down

0 comments on commit 4b698c7

Please sign in to comment.