Skip to content

Commit

Permalink
STAR-1275: Read the peers SRT version and the negotiated latency from…
Browse files Browse the repository at this point in the history
… the socket

Pass a reference to a ConnectionInformation struct back to user
when a new client connects and the called connects to a server.
  • Loading branch information
Per Moberg committed May 10, 2024
1 parent c7d91e9 commit 0f47a9f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
33 changes: 31 additions & 2 deletions SRTNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ SRTNet::ClientConnectStatus SRTNet::clientConnectToServer() {
if (result != SRT_ERROR) {
mClientConnected = true;
if (connectedToServer) {
connectedToServer(mConnectionContext, mContext);
ConnectionInformation connectionInformation = getConnectionInformation(mContext);
connectedToServer(mConnectionContext, mContext, connectionInformation);
}
// Break for-loop on first successful connect call
break;
Expand Down Expand Up @@ -315,7 +316,9 @@ bool SRTNet::waitForSRTClient(bool singleClient) {
}

SRT_LOGGER(true, LOGG_NOTIFY, "Client connected: " << newSocketCandidate);
auto ctx = clientConnected(*reinterpret_cast<sockaddr*>(&theirAddr), newSocketCandidate, mConnectionContext);

ConnectionInformation connectionInformation = getConnectionInformation(newSocketCandidate);
auto ctx = clientConnected(*reinterpret_cast<sockaddr*>(&theirAddr), newSocketCandidate, mConnectionContext, connectionInformation);

if (!ctx) {
// No ctx in return from clientConnected callback means client was rejected by user.
Expand Down Expand Up @@ -948,3 +951,29 @@ uint16_t SRTNet::getLocallyBoundPort() const {

return 0;
}

SRTNet::ConnectionInformation SRTNet::getConnectionInformation(SRTSOCKET socket) {
ConnectionInformation connectionInformation;

uint8_t clientSrtVersion[4];
int clientSrtVersionSize = sizeof(clientSrtVersion);
if (SRT_ERROR != srt_getsockflag(socket, SRTO_PEERVERSION, &clientSrtVersion, &clientSrtVersionSize)) {
// The SRT version is stored as an int (little endian), like 0x00XXYYZZ, where XX is major, YY is minor, and ZZ is patch version
connectionInformation.mPeerSrtVersion =
std::to_string((int32_t)clientSrtVersion[2]) + "." +
std::to_string((int32_t)clientSrtVersion[1]) + "." +
std::to_string((int32_t)clientSrtVersion[0]);
} else {
SRT_LOGGER(true, LOGG_ERROR, "Failed to get peer SRT version from the new connection: " << srt_getlasterror_str());
}

int32_t negotiatedLatency = 0;
int negotiatedLatencySize = sizeof(negotiatedLatency);
if (SRT_ERROR != srt_getsockflag(socket, SRTO_PEERLATENCY, &negotiatedLatency, &negotiatedLatencySize)) {
connectionInformation.mNegotiatedLatency = negotiatedLatency;
} else {
SRT_LOGGER(true, LOGG_ERROR, "Failed to get peer latency from the new connection: " << srt_getlasterror_str());
}

return connectionInformation;
}
42 changes: 28 additions & 14 deletions SRTNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,22 @@ enum SRTNetInstant : int { no, yes };

class SRTNet {
public:

enum class Mode {
unknown,
server,
client
};
enum class Mode { unknown, server, client };

// Fill this class with all information you need for the duration of the connection both client and server
class NetworkConnection {
public:
std::any mObject;
};

/**
* @brief Connection information that is fetched when a client connects to a server.
*/
struct ConnectionInformation {
std::string mPeerSrtVersion = "n/a"; // The SRT version of the peer
int32_t mNegotiatedLatency = -1; // The latency that was negotiated with the peer
};

/**
*
* @brief Constructor that can set a log prefix which will be added to the start of all log messages from this
Expand Down Expand Up @@ -231,10 +234,10 @@ class SRTNet {
std::pair<SRTSOCKET, std::shared_ptr<NetworkConnection>> getConnectedServer();

/**
*
*
* @brief Check if client is connected to remote end
* @returns True if client is connected to the the remote end, false otherwise or if this instance is in server mode
*/
*/
bool isConnectedToServer() const;

/**
Expand All @@ -260,7 +263,7 @@ class SRTNet {
* @brief Get the current operating mode.
* @returns The operating mode.
*
*/
*/
Mode getCurrentMode() const;

/**
Expand All @@ -274,7 +277,8 @@ class SRTNet {
* @param message the line to be logged
*
*/
static void defaultLogHandler(void* opaque, int level, const char* file, int line, const char* area, const char* message);
static void
defaultLogHandler(void* opaque, int level, const char* file, int line, const char* area, const char* message);

/**
*
Expand All @@ -288,7 +292,8 @@ class SRTNet {
/// Callback handling connecting clients (only server mode)
std::function<std::shared_ptr<NetworkConnection>(struct sockaddr& sin,
SRTSOCKET newSocket,
std::shared_ptr<NetworkConnection>& ctx)>
std::shared_ptr<NetworkConnection>& ctx,
const ConnectionInformation& connectionInformation)>
clientConnected = nullptr;

/// Callback receiving data type vector
Expand All @@ -310,7 +315,10 @@ class SRTNet {
std::function<void(std::shared_ptr<NetworkConnection>& ctx, SRTSOCKET lSocket)> clientDisconnected = nullptr;

/// Callback called whenever the client gets connected to the server (client mode only)
std::function<void(std::shared_ptr<NetworkConnection>& ctx, SRTSOCKET lSocket)> connectedToServer = nullptr;
std::function<void(std::shared_ptr<NetworkConnection>& ctx,
SRTSOCKET lSocket,
const ConnectionInformation& connectionInformation)>
connectedToServer = nullptr;

// delete copy and move constructors and assign operators
SRTNet(SRTNet const&) = delete; // Copy construct
Expand Down Expand Up @@ -380,9 +388,9 @@ class SRTNet {
* @brief Enum for the client connection status.
*/
enum ClientConnectStatus {
success, // Client was able to connect to the server
success, // Client was able to connect to the server
failToResolveAddress, // Client was not able to resolve the remote ip or port
failToConnect // Client was not able to connect to the server
failToConnect // Client was not able to connect to the server
};

/**
Expand Down Expand Up @@ -416,6 +424,12 @@ class SRTNet {
*/
bool createClientSocket();

/**
* @brief Fetch the connection information from the SRT socket.
* @return a ConnectionInformation struct with all the connection information that could be fetched.
*/
ConnectionInformation getConnectionInformation(SRTSOCKET socket);

static SRT_LOG_HANDLER_FN* gLogHandler;
static int gLogLevel;

Expand Down
9 changes: 6 additions & 3 deletions test/TestSrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class TestSRTFixture : public ::testing::Test {

// notice when client connects to server
mServer.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) {
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) {
{
std::lock_guard<std::mutex> lock(mConnectedMutex);
mConnected = true;
Expand Down Expand Up @@ -150,7 +151,8 @@ TEST(TestSrt, StartStop) {

// notice when client connects to server
server.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) {
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) {
{
std::lock_guard<std::mutex> lock(connectedMutex);
connected = true;
Expand Down Expand Up @@ -239,7 +241,8 @@ TEST(TestSrt, TestPsk) {

auto ctx = std::make_shared<SRTNet::NetworkConnection>();
server.clientConnected = [&](struct sockaddr& sin, SRTSOCKET newSocket,
std::shared_ptr<SRTNet::NetworkConnection>& ctx) { return ctx; };
std::shared_ptr<SRTNet::NetworkConnection>& ctx,
const SRTNet::ConnectionInformation&) { return ctx; };
ASSERT_TRUE(server.startServer("127.0.0.1", 8009, 16, 1000, 100, SRT_LIVE_MAX_PLSIZE, 5000, kValidPsk, false, ctx));
EXPECT_FALSE(client.startClient("127.0.0.1", 8009, 16, 1000, 100, ctx, SRT_LIVE_MAX_PLSIZE, false, 5000, kInvalidPsk))
<< "Expect to fail when using incorrect PSK";
Expand Down

0 comments on commit 0f47a9f

Please sign in to comment.