Skip to content

Commit

Permalink
remove magic_endpoint_direct_conn_cb in favor of `magic_endpoint_co…
Browse files Browse the repository at this point in the history
…nn_type_cb`
  • Loading branch information
ramfox committed Apr 26, 2024
1 parent e80ba64 commit 7b0fc6d
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 236 deletions.
19 changes: 0 additions & 19 deletions irohnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,25 +533,6 @@ magic_endpoint_connect (
MagicEndpoint_t *
magic_endpoint_default (void);

/** \brief
* Run a callback once you have a direct connection to a peer
*
* Does not block. The provided callback will be called when we have a direct
* connection to the peer associated with the `node_id`, or the timeout has occurred.
*
* To wait indefinitely, provide -1 for the timeout parameter.
*
* `ctx` is passed along to the callback, to allow passing context, it must be thread safe as the callback is
* called from another thread.
*/
void
magic_endpoint_direct_conn_cb (
MagicEndpoint_t * ep,
void const * ctx,
PublicKey_t const * node_id,
ssize_t timeout,
void (*cb)(void const *, MagicEndpointResult_t));

/** \brief
* Frees the magic endpoint.
*/
Expand Down
217 changes: 0 additions & 217 deletions src/magic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,85 +599,6 @@ pub fn magic_endpoint_accept_any_cb(
});
}

/// Run a callback once you have a direct connection to a peer
///
/// Does not block. The provided callback will be called when we have a direct
/// connection to the peer associated with the `node_id`, or the timeout has occurred.
///
/// To wait indefinitely, provide -1 for the timeout parameter.
///
/// `ctx` is passed along to the callback, to allow passing context, it must be thread safe as the callback is
/// called from another thread.
#[ffi_export]
pub fn magic_endpoint_direct_conn_cb(
ep: repr_c::Box<MagicEndpoint>,
ctx: *const c_void,
node_id: &PublicKey,
timeout: isize,
cb: unsafe extern "C" fn(ctx: *const c_void, res: MagicEndpointResult),
) {
// hack around the fact that `*const c_void` is not Send
struct CtxPtr(*const c_void);
unsafe impl Send for CtxPtr {}
let ctx_ptr = CtxPtr(ctx);

let node_id: iroh_net::key::PublicKey = node_id.into();

TOKIO_EXECUTOR.spawn(async move {
// make the compiler happy
let _ = &ctx_ptr;

async fn connect(
ep: repr_c::Box<MagicEndpoint>,
node_id: iroh_net::key::PublicKey,
) -> anyhow::Result<()> {
ep.ep
.read()
.await
.as_ref()
.expect("endpoint not initalized")
.add_node_addr(iroh_net::NodeAddr::new(node_id))?;

let mut stream = ep
.ep
.read()
.await
.as_ref()
.expect("endpoint not initalized")
.conn_type_stream(&node_id)?;

while let Some(conn_type) = stream.next().await {
if matches!(conn_type, iroh_net::magicsock::ConnectionType::Direct(_)) {
return Ok(());
}
}
anyhow::bail!("stream ended before getting a direct connection");
}

let res = match timeout {
-1 => connect(ep, node_id).await,
_ => {
let timeout = Duration::from_millis(timeout as u64);
match tokio::time::timeout(timeout, connect(ep, node_id)).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(err)) => Err(err),
Err(_) => Err(anyhow::anyhow!("timeout")),
}
}
};

match res {
Ok(_) => unsafe {
cb(ctx_ptr.0, MagicEndpointResult::Ok);
},
Err(err) => unsafe {
warn!("accept failed: {:?}", err);
cb(ctx_ptr.0, MagicEndpointResult::AcceptFailed);
},
}
});
}

#[derive_ReprC]
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -1458,144 +1379,6 @@ mod tests {
.expect("receiver dropped or channel full");
}

#[test]
fn test_direct_conn_cb() {
let alpn: vec::Vec<u8> = b"/cool/alpn/1".to_vec().into();

// create config
let mut config_server = magic_endpoint_config_default();
magic_endpoint_config_add_alpn(&mut config_server, alpn.as_ref());

let mut config_client = magic_endpoint_config_default();
magic_endpoint_config_add_alpn(&mut config_client, alpn.as_ref());

let (s, r) = std::sync::mpsc::channel();
let (client_s, client_r) = std::sync::mpsc::channel();

// setup server
let alpn_s = alpn.clone();
let server_thread = std::thread::spawn(move || {
// create magic endpoint and bind
let ep = magic_endpoint_default();
let bind_res = magic_endpoint_bind(&config_server, 0, &ep);
assert_eq!(bind_res, MagicEndpointResult::Ok);

let mut node_addr = node_addr_default();
let res = magic_endpoint_my_addr(&ep, &mut node_addr);
assert_eq!(res, MagicEndpointResult::Ok);

s.send(node_addr).unwrap();

let ep = Arc::new(ep);
let alpn_s = alpn_s.clone();

// accept connection
println!("[s] accepting conn");
let conn = connection_default();
let mut alpn = vec::Vec::EMPTY;
let res = magic_endpoint_accept_any(&ep, &mut alpn, &conn);
assert_eq!(res, MagicEndpointResult::Ok);

if alpn.as_ref() != alpn_s.as_ref() {
panic!("unexpectd alpn: {:?}", alpn);
};

let mut send_stream = send_stream_default();
let mut recv_stream = recv_stream_default();
let accept_res = connection_accept_bi(&conn, &mut send_stream, &mut recv_stream);
assert_eq!(accept_res, MagicEndpointResult::Ok);

println!("[s] reading");

let mut recv_buffer = vec![0u8; 1024];
let read_res = recv_stream_read(&mut recv_stream, (&mut recv_buffer[..]).into());
assert!(read_res > 0);
assert_eq!(
std::str::from_utf8(&recv_buffer[..read_res as usize]).unwrap(),
"hello world",
);

println!("[s] sending");
let send_res = send_stream_write(&mut send_stream, "hello client".as_bytes().into());
assert_eq!(send_res, MagicEndpointResult::Ok);

let res = send_stream_finish(send_stream);
assert_eq!(res, MagicEndpointResult::Ok);
client_r.recv().unwrap();
});

let (direct_conn_s, mut direct_conn_r): (
tokio::sync::mpsc::Sender<MagicEndpointResult>,
tokio::sync::mpsc::Receiver<MagicEndpointResult>,
) = tokio::sync::mpsc::channel(1);

// setup client
let client_thread = std::thread::spawn(move || {
// create magic endpoint and bind
let ep = magic_endpoint_default();
let bind_res = magic_endpoint_bind(&config_client, 0, &ep);
assert_eq!(bind_res, MagicEndpointResult::Ok);

// wait for addr from server
let node_addr = r.recv().unwrap();

let alpn = alpn.clone();

// wait for a moment to make sure the server is ready
std::thread::sleep(std::time::Duration::from_millis(100));

println!("[c] dialing");
// connect to server
let conn = connection_default();
let connect_res = magic_endpoint_connect(&ep, alpn.as_ref(), node_addr.clone(), &conn);
assert_eq!(connect_res, MagicEndpointResult::Ok);

let mut send_stream = send_stream_default();
let mut recv_stream = recv_stream_default();
let open_res = connection_open_bi(&conn, &mut send_stream, &mut recv_stream);
assert_eq!(open_res, MagicEndpointResult::Ok);

let s_ptr: *const c_void = &direct_conn_s as *const _ as *const c_void;
magic_endpoint_direct_conn_cb(
ep,
s_ptr,
&node_addr.node_id,
5000,
direct_conn_callback,
);

println!("[c] sending");
let send_res = send_stream_write(&mut send_stream, "hello world".as_bytes().into());
assert_eq!(send_res, MagicEndpointResult::Ok);

println!("[c] reading");

let mut recv_buffer = vec![0u8; 1024];
let read_res = recv_stream_read(&mut recv_stream, (&mut recv_buffer[..]).into());
assert!(read_res > 0);
assert_eq!(
std::str::from_utf8(&recv_buffer[..read_res as usize]).unwrap(),
"hello client"
);

let finish_res = send_stream_finish(send_stream);
assert_eq!(finish_res, MagicEndpointResult::Ok);
client_s.send(()).unwrap();
});

server_thread.join().unwrap();
client_thread.join().unwrap();
let res = direct_conn_r.blocking_recv().unwrap();
match res {
MagicEndpointResult::Ok => {
println!("got direct connection!");
}
_ => {
panic!("did not get a direct connection: {res:?}");
}
}
}

type CallbackRes = (MagicEndpointResult, ConnectionType);

unsafe extern "C" fn conn_type_callback(
Expand Down

0 comments on commit 7b0fc6d

Please sign in to comment.