From 0c402d8c64bad9c1171b1c9a5c43436046c7e652 Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Tue, 28 Jan 2025 23:21:31 +0800 Subject: [PATCH] fix(client): Fix `HTTP/2` websocket request (#165) --- src/client/legacy/client.rs | 2 +- tests/legacy_client.rs | 82 +++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs index d3d5605..fd744d3 100644 --- a/src/client/legacy/client.rs +++ b/src/client/legacy/client.rs @@ -316,7 +316,7 @@ where } else { origin_form(req.uri_mut()); } - } else if req.method() == Method::CONNECT { + } else if req.method() == Method::CONNECT && !pooled.is_http2() { authority_form(req.uri_mut()); } diff --git a/tests/legacy_client.rs b/tests/legacy_client.rs index 0f11d77..95983df 100644 --- a/tests/legacy_client.rs +++ b/tests/legacy_client.rs @@ -807,6 +807,88 @@ fn client_upgrade() { assert_eq!(vec, b"bar=foo"); } +#[cfg(not(miri))] +#[test] +fn client_http2_upgrade() { + use http::{Method, Response, Version}; + use hyper::service::service_fn; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let _ = pretty_env_logger::try_init(); + let rt = runtime(); + let server = rt + .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) + .unwrap(); + let addr = server.local_addr().unwrap(); + let mut connector = DebugConnector::new(); + connector.alpn_h2 = true; + + let client = Client::builder(TokioExecutor::new()).build(connector); + + rt.spawn(async move { + let (stream, _) = server.accept().await.expect("accept"); + let stream = TokioIo::new(stream); + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + // IMPORTANT: This is required to advertise our support for HTTP/2 websockets to the client. + builder.http2().enable_connect_protocol(); + let _ = builder + .serve_connection_with_upgrades( + stream, + service_fn(|req| async move { + assert_eq!(req.headers().get("host"), None); + assert_eq!(req.version(), Version::HTTP_2); + assert_eq!( + req.headers().get(http::header::SEC_WEBSOCKET_VERSION), + Some(&http::header::HeaderValue::from_static("13")) + ); + assert_eq!( + req.extensions().get::(), + Some(&hyper::ext::Protocol::from_static("websocket")) + ); + + let on_upgrade = hyper::upgrade::on(req); + tokio::spawn(async move { + let upgraded = on_upgrade.await.unwrap(); + let mut io = TokioIo::new(upgraded); + + let mut vec = vec![]; + io.read_buf(&mut vec).await.unwrap(); + assert_eq!(vec, b"foo=bar"); + io.write_all(b"bar=foo").await.unwrap(); + }); + + Ok::<_, hyper::Error>(Response::new(Empty::::new())) + }), + ) + .await + .expect("server"); + }); + + let req = Request::builder() + .method(Method::CONNECT) + .uri(&*format!("http://{}/up", addr)) + .header(http::header::SEC_WEBSOCKET_VERSION, "13") + .version(Version::HTTP_2) + .extension(hyper::ext::Protocol::from_static("websocket")) + .body(Empty::::new()) + .unwrap(); + + let res = client.request(req); + let res = rt.block_on(res).unwrap(); + + assert_eq!(res.status(), http::StatusCode::OK); + assert_eq!(res.version(), Version::HTTP_2); + + let upgraded = rt.block_on(hyper::upgrade::on(res)).expect("on_upgrade"); + let mut io = TokioIo::new(upgraded); + + rt.block_on(io.write_all(b"foo=bar")).unwrap(); + let mut vec = vec![]; + rt.block_on(io.read_to_end(&mut vec)).unwrap(); + assert_eq!(vec, b"bar=foo"); +} + #[cfg(not(miri))] #[test] fn alpn_h2() {