Skip to content

Commit

Permalink
remove magic_endpoint_direct_conn_cb in favor of `magic_endpoint_co…
Browse files Browse the repository at this point in the history
…nn_type_cb`
  • Loading branch information
ramfox authored and “ramfox” committed Jan 31, 2025
1 parent c1d8a4a commit c6b79c3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 252 deletions.
23 changes: 4 additions & 19 deletions irohnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ enum EndpointResult {
* It is common to simply log this error and move on.
*/
ENDPOINT_RESULT_INCOMING_ERROR,
/** \brief
* Unable to find connection for the given `NodeId`
*/
ENDPOINT_RESULT_CONNECTION_TYPE_ERROR,
}
#ifndef DOXYGEN
; typedef uint8_t
Expand Down Expand Up @@ -626,25 +630,6 @@ endpoint_connect (
Endpoint_t *
endpoint_default (void);

/** \brief
* Run a callback once you have a direct connection to a peer
*
* Does not block. The provided callback will be called when we have a direct
* connection to the peer associated with the `node_id`, or the timeout has occurred.
*
* To wait indefinitely, provide -1 for the timeout parameter.
*
* `ctx` is passed along to the callback, to allow passing context, it must be thread safe as the callback is
* called from another thread.
*/
void
endpoint_direct_conn_cb (
Endpoint_t * ep,
void const * ctx,
PublicKey_t const * node_id,
ssize_t timeout,
void (*cb)(void const *, EndpointResult_t));

/** \brief
* Frees the iroh endpoint.
*/
Expand Down
240 changes: 7 additions & 233 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ pub enum EndpointResult {
///
/// It is common to simply log this error and move on.
IncomingError,
/// Unable to find connection for the given `NodeId`
ConnectionTypeError,
}

/// Attempts to bind the endpoint to the provided IPv4 and IPv6 address.
Expand Down Expand Up @@ -713,83 +715,6 @@ pub fn endpoint_accept_any_cb(
});
}

/// Run a callback once you have a direct connection to a peer
///
/// Does not block. The provided callback will be called when we have a direct
/// connection to the peer associated with the `node_id`, or the timeout has occurred.
///
/// To wait indefinitely, provide -1 for the timeout parameter.
///
/// `ctx` is passed along to the callback, to allow passing context, it must be thread safe as the callback is
/// called from another thread.
#[ffi_export]
pub fn endpoint_direct_conn_cb(
ep: repr_c::Box<Endpoint>,
ctx: *const c_void,
node_id: &PublicKey,
timeout: isize,
cb: unsafe extern "C" fn(ctx: *const c_void, res: EndpointResult),
) {
// hack around the fact that `*const c_void` is not Send
struct CtxPtr(*const c_void);
unsafe impl Send for CtxPtr {}
let ctx_ptr = CtxPtr(ctx);

let node_id: NodeId = node_id.into();

TOKIO_EXECUTOR.spawn(async move {
// make the compiler happy
let _ = &ctx_ptr;

async fn connect(ep: repr_c::Box<Endpoint>, node_id: NodeId) -> anyhow::Result<()> {
ep.ep
.read()
.await
.as_ref()
.expect("endpoint not initalized")
.add_node_addr(iroh::NodeAddr::new(node_id))?;

let mut stream = ep
.ep
.read()
.await
.as_ref()
.expect("endpoint not initalized")
.conn_type(node_id)?
.stream();

while let Some(conn_type) = stream.next().await {
if matches!(conn_type, iroh::endpoint::ConnectionType::Direct(_)) {
return Ok(());
}
}
anyhow::bail!("stream ended before getting a direct connection");
}

let res = match timeout {
-1 => connect(ep, node_id).await,
_ => {
let timeout = Duration::from_millis(timeout as u64);
match tokio::time::timeout(timeout, connect(ep, node_id)).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(err)) => Err(err),
Err(_) => Err(anyhow::anyhow!("timeout")),
}
}
};

match res {
Ok(_) => unsafe {
cb(ctx_ptr.0, EndpointResult::Ok);
},
Err(err) => unsafe {
warn!("accept failed: {:?}", err);
cb(ctx_ptr.0, EndpointResult::AcceptFailed);
},
}
});
}

#[derive_ReprC]
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -845,20 +770,6 @@ pub fn endpoint_conn_type_cb(
// make the compiler happy
let _ = &ctx_ptr;

let res = ep
.ep
.read()
.await
.as_ref()
.expect("endpoint not initalized")
.add_node_addr(iroh::NodeAddr::new(node_id));
if res.is_err() {
unsafe {
cb(ctx_ptr.0, EndpointResult::AddrError, ConnectionType::None);
}
return;
}

let mut stream = match ep
.ep
.read()
Expand All @@ -869,7 +780,11 @@ pub fn endpoint_conn_type_cb(
{
Err(_) => {
unsafe {
cb(ctx_ptr.0, EndpointResult::AddrError, ConnectionType::None);
cb(
ctx_ptr.0,
EndpointResult::ConnectionTypeError,
ConnectionType::None,
);
}
return;
}
Expand Down Expand Up @@ -1619,147 +1534,6 @@ mod tests {
client_thread.join().unwrap();
}

unsafe extern "C" fn direct_conn_callback(ctx: *const c_void, res: EndpointResult) {
// unsafe b/c dereferencing a raw pointer
let sender: &tokio::sync::mpsc::Sender<EndpointResult> =
unsafe { &(*(ctx as *const tokio::sync::mpsc::Sender<EndpointResult>)) };
sender
.try_send(res)
.expect("receiver dropped or channel full");
}

#[test]
fn test_direct_conn_cb() {
let alpn: vec::Vec<u8> = b"/cool/alpn/1".to_vec().into();

// create config
let mut config_server = endpoint_config_default();
endpoint_config_add_alpn(&mut config_server, alpn.as_ref());

let mut config_client = endpoint_config_default();
endpoint_config_add_alpn(&mut config_client, alpn.as_ref());

let (s, r) = std::sync::mpsc::channel();
let (client_s, client_r) = std::sync::mpsc::channel();

// setup server
let alpn_s = alpn.clone();
let server_thread = std::thread::spawn(move || {
// create magic endpoint and bind
let ep = endpoint_default();
let bind_res = endpoint_bind(&config_server, None, None, &ep);
assert_eq!(bind_res, EndpointResult::Ok);

let mut node_addr = node_addr_default();
let res = endpoint_node_addr(&ep, &mut node_addr);
assert_eq!(res, EndpointResult::Ok);

s.send(node_addr).unwrap();

let ep = Arc::new(ep);
let alpn_s = alpn_s.clone();

// accept connection
println!("[s] accepting conn");
let conn = connection_default();
let mut alpn = vec::Vec::EMPTY;
let res = endpoint_accept_any(&ep, &mut alpn, &conn);
assert_eq!(res, EndpointResult::Ok);

if alpn.as_ref() != alpn_s.as_ref() {
panic!("unexpectd alpn: {:?}", alpn);
};

let mut send_stream = send_stream_default();
let mut recv_stream = recv_stream_default();
let accept_res = connection_accept_bi(&conn, &mut send_stream, &mut recv_stream);
assert_eq!(accept_res, EndpointResult::Ok);

println!("[s] reading");

let mut recv_buffer = vec![0u8; 1024];
let read_res = recv_stream_read(&mut recv_stream, (&mut recv_buffer[..]).into());
assert!(read_res > 0);
assert_eq!(
std::str::from_utf8(&recv_buffer[..read_res as usize]).unwrap(),
"hello world",
);

println!("[s] sending");
let send_res = send_stream_write(&mut send_stream, "hello client".as_bytes().into());
assert_eq!(send_res, EndpointResult::Ok);

let res = send_stream_finish(send_stream);
assert_eq!(res, EndpointResult::Ok);
client_r.recv().unwrap();
});

let (direct_conn_s, mut direct_conn_r): (
tokio::sync::mpsc::Sender<EndpointResult>,
tokio::sync::mpsc::Receiver<EndpointResult>,
) = tokio::sync::mpsc::channel(1);

// setup client
let client_thread = std::thread::spawn(move || {
// create magic endpoint and bind
let ep = endpoint_default();
let bind_res = endpoint_bind(&config_client, None, None, &ep);
assert_eq!(bind_res, EndpointResult::Ok);

// wait for addr from server
let node_addr = r.recv().unwrap();

let alpn = alpn.clone();

// wait for a moment to make sure the server is ready
std::thread::sleep(std::time::Duration::from_millis(100));

println!("[c] dialing");
// connect to server
let conn = connection_default();
let connect_res = endpoint_connect(&ep, alpn.as_ref(), node_addr.clone(), &conn);
assert_eq!(connect_res, EndpointResult::Ok);

let mut send_stream = send_stream_default();
let mut recv_stream = recv_stream_default();
let open_res = connection_open_bi(&conn, &mut send_stream, &mut recv_stream);
assert_eq!(open_res, EndpointResult::Ok);

let s_ptr: *const c_void = &direct_conn_s as *const _ as *const c_void;
endpoint_direct_conn_cb(ep, s_ptr, &node_addr.node_id, 5000, direct_conn_callback);

println!("[c] sending");
let send_res = send_stream_write(&mut send_stream, "hello world".as_bytes().into());
assert_eq!(send_res, EndpointResult::Ok);

println!("[c] reading");

let mut recv_buffer = vec![0u8; 1024];
let read_res = recv_stream_read(&mut recv_stream, (&mut recv_buffer[..]).into());
assert!(read_res > 0);
assert_eq!(
std::str::from_utf8(&recv_buffer[..read_res as usize]).unwrap(),
"hello client"
);

let finish_res = send_stream_finish(send_stream);
assert_eq!(finish_res, EndpointResult::Ok);
client_s.send(()).unwrap();
});

server_thread.join().unwrap();
client_thread.join().unwrap();
let res = direct_conn_r.blocking_recv().unwrap();
match res {
EndpointResult::Ok => {
println!("got direct connection!");
}
_ => {
panic!("did not get a direct connection: {res:?}");
}
}
}

type CallbackRes = (EndpointResult, ConnectionType);

unsafe extern "C" fn conn_type_callback(
Expand Down

0 comments on commit c6b79c3

Please sign in to comment.