Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(awc): allow to set a specific sni host on the request #3522

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awc/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Update `brotli` dependency to `7`.
- Prevent panics on connection pool drop when Tokio runtime is shutdown early.
- Minimum supported Rust version (MSRV) is now 1.75.
- Allow to set a specific SNI hostname on the request for TLS connections.

## 3.5.1

Expand Down
22 changes: 14 additions & 8 deletions awc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
use actix_http::{
error::HttpError,
header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
Uri,
};
use actix_rt::net::{ActixStream, TcpStream};
use actix_service::{boxed, Service};
use base64::prelude::*;

use crate::{
client::{
ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection,
ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError,
TcpConnection,
},
connect::DefaultConnector,
error::SendRequestError,
Expand Down Expand Up @@ -46,8 +46,8 @@ impl ClientBuilder {
#[allow(clippy::new_ret_no_self)]
pub fn new() -> ClientBuilder<
impl Service<
ConnectInfo<Uri>,
Response = TcpConnection<Uri, TcpStream>,
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = TcpConnectError,
> + Clone,
(),
Expand All @@ -69,16 +69,22 @@ impl ClientBuilder {

impl<S, Io, M> ClientBuilder<S, M>
where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
Io: ActixStream + fmt::Debug + 'static,
{
/// Use custom connector service.
pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
where
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
+ Clone
S1: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone
+ 'static,
Io1: ActixStream + fmt::Debug + 'static,
{
Expand Down
126 changes: 91 additions & 35 deletions awc/src/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,51 @@
use actix_service::Service;
use actix_tls::connect::{
ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection,
Connector as TcpConnector, Resolver,
Connector as TcpConnector, Host, Resolver,
};
use futures_core::{future::LocalBoxFuture, ready};
use http::Uri;
use pin_project_lite::pin_project;

use super::{
config::ConnectorConfig,
connection::{Connection, ConnectionIo},
error::ConnectError,
pool::ConnectionPool,
Connect,
Connect, ServerName,
};

pub enum HostnameWithSni {
ForTcp(String, u16, Option<ServerName>),
ForTls(String, u16, Option<ServerName>),
}

impl Host for HostnameWithSni {
fn hostname(&self) -> &str {
match self {
HostnameWithSni::ForTcp(hostname, _, _) => hostname,
HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname),
}
}

fn port(&self) -> Option<u16> {
match self {
HostnameWithSni::ForTcp(_, port, _) => Some(*port),
HostnameWithSni::ForTls(_, port, _) => Some(*port),
}
}
}

impl HostnameWithSni {
pub fn to_tls(self) -> Self {

Check failure on line 54 in awc/src/client/connector.rs

View workflow job for this annotation

GitHub Actions / clippy

[clippy] reported by reviewdog 🐶 error: methods with the following characteristics: (`to_*` and `self` type is not `Copy`) usually take `self` by reference --> awc/src/client/connector.rs:54:19 | 54 | pub fn to_tls(self) -> Self { | ^^^^ | = help: consider choosing a less ambiguous name = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#wrong_self_convention = note: `-D clippy::wrong-self-convention` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::wrong_self_convention)]` Raw Output: awc/src/client/connector.rs:54:19:e:error: methods with the following characteristics: (`to_*` and `self` type is not `Copy`) usually take `self` by reference --> awc/src/client/connector.rs:54:19 | 54 | pub fn to_tls(self) -> Self { | ^^^^ | = help: consider choosing a less ambiguous name = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#wrong_self_convention = note: `-D clippy::wrong-self-convention` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::wrong_self_convention)]` __END__

Check failure on line 54 in awc/src/client/connector.rs

View workflow job for this annotation

GitHub Actions / clippy

[clippy] reported by reviewdog 🐶 error: methods with the following characteristics: (`to_*` and `self` type is not `Copy`) usually take `self` by reference --> awc/src/client/connector.rs:54:19 | 54 | pub fn to_tls(self) -> Self { | ^^^^ | = help: consider choosing a less ambiguous name = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#wrong_self_convention = note: `-D clippy::wrong-self-convention` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::wrong_self_convention)]` Raw Output: awc/src/client/connector.rs:54:19:e:error: methods with the following characteristics: (`to_*` and `self` type is not `Copy`) usually take `self` by reference --> awc/src/client/connector.rs:54:19 | 54 | pub fn to_tls(self) -> Self { | ^^^^ | = help: consider choosing a less ambiguous name = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#wrong_self_convention = note: `-D clippy::wrong-self-convention` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::wrong_self_convention)]` __END__
match self {
HostnameWithSni::ForTcp(hostname, port, sni) => {
HostnameWithSni::ForTls(hostname, port, sni)
}
HostnameWithSni::ForTls(_, _, _) => self,
}
}
}

enum OurTlsConnector {
#[allow(dead_code)] // only dead when no TLS feature is enabled
None,
Expand Down Expand Up @@ -95,8 +126,8 @@
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
pub fn new() -> Connector<
impl Service<
ConnectInfo<Uri>,
Response = TcpConnection<Uri, TcpStream>,
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = actix_tls::connect::ConnectError,
> + Clone,
> {
Expand Down Expand Up @@ -214,8 +245,11 @@
pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
where
Io1: ActixStream + fmt::Debug + 'static,
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError>
+ Clone,
S1: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone,
{
Connector {
connector,
Expand All @@ -235,8 +269,11 @@
// This remap is to hide ActixStream's trait methods. They are not meant to be called
// from user code.
IO: ActixStream + fmt::Debug + 'static,
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, IO>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, IO>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
/// Sets TCP connection timeout.
Expand Down Expand Up @@ -454,7 +491,7 @@
use actix_utils::future::{ready, Ready};

#[allow(non_local_definitions)]
impl IntoConnectionIo for TcpConnection<Uri, Box<dyn ConnectionIo>> {
impl IntoConnectionIo for TcpConnection<HostnameWithSni, Box<dyn ConnectionIo>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let io = self.into_parts().0;
(io, Protocol::Http2)
Expand Down Expand Up @@ -505,7 +542,7 @@
use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector};

#[allow(non_local_definitions)]
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncSslStream<IO>> {
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncSslStream<IO>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -544,7 +581,7 @@
use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -579,7 +616,7 @@
use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -617,7 +654,7 @@
use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -652,7 +689,7 @@
use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector};

#[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
Expand Down Expand Up @@ -693,15 +730,17 @@
}
}

/// tcp service for map `TcpConnection<Uri, Io>` type to `(Io, Protocol)`
/// tcp service for map `TcpConnection<HostnameWithSni, Io>` type to `(Io, Protocol)`
#[derive(Clone)]
pub struct TcpConnectorService<S: Clone> {
service: S,
}

impl<S, Io> Service<Connect> for TcpConnectorService<S>
where
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError> + Clone + 'static,
S: Service<Connect, Response = TcpConnection<HostnameWithSni, Io>, Error = ConnectError>
+ Clone
+ 'static,
{
type Response = (Io, Protocol);
type Error = ConnectError;
Expand All @@ -726,7 +765,7 @@

impl<Fut, Io> Future for TcpConnectorFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
{
type Output = Result<(Io, Protocol), ConnectError>;

Expand Down Expand Up @@ -772,9 +811,10 @@
))]
impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls>
where
Tcp:
Service<Connect, Response = TcpConnection<Uri, IO>, Error = ConnectError> + Clone + 'static,
Tls: Service<TcpConnection<Uri, IO>, Error = std::io::Error> + Clone + 'static,
Tcp: Service<Connect, Response = TcpConnection<HostnameWithSni, IO>, Error = ConnectError>
+ Clone
+ 'static,
Tls: Service<TcpConnection<HostnameWithSni, IO>, Error = std::io::Error> + Clone + 'static,
Tls::Response: IntoConnectionIo,
IO: ConnectionIo,
{
Expand Down Expand Up @@ -827,9 +867,14 @@

impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
where
S: Service<TcpConnection<Uri, Io>, Response = Res, Error = std::io::Error, Future = Fut2>,
S: Service<
TcpConnection<HostnameWithSni, Io>,
Response = Res,
Error = std::io::Error,
Future = Fut2,
>,
S::Response: IntoConnectionIo,
Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
Fut1: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
Fut2: Future<Output = Result<S::Response, S::Error>>,
Io: ConnectionIo,
{
Expand All @@ -843,10 +888,11 @@
timeout,
} => {
let res = ready!(fut.poll(cx))?;
let (io, hostname_with_sni) = res.into_parts();
let fut = tls_service
.take()
.expect("TlsConnectorFuture polled after complete")
.call(res);
.call(TcpConnection::new(hostname_with_sni.to_tls(), io));
let timeout = sleep(*timeout);
self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
self.poll(cx)
Expand Down Expand Up @@ -880,8 +926,11 @@

impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
+ Clone
S: Service<
ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
type Response = S::Response;
Expand All @@ -891,7 +940,13 @@
actix_service::forward_ready!(service);

fn call(&self, req: Connect) -> Self::Future {
let mut req = ConnectInfo::new(req.uri).set_addr(req.addr);
let mut req = ConnectInfo::new(HostnameWithSni::ForTcp(
req.hostname,
req.port,
req.sni_host,
))
.set_addr(req.addr)
.set_port(req.port);

if let Some(local_addr) = self.local_address {
req = req.set_local_addr(local_addr);
Expand All @@ -916,9 +971,9 @@

impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>,
Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, TcpConnectError>>,
{
type Output = Result<TcpConnection<Uri, Io>, ConnectError>;
type Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
Expand Down Expand Up @@ -978,16 +1033,17 @@
}

fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => match self.tls_pool {
if req.tls {
match &self.tls_pool {
None => ConnectorServiceFuture::SslIsNotSupported,
Some(ref pool) => ConnectorServiceFuture::Tls {
Some(pool) => ConnectorServiceFuture::Tls {
fut: pool.call(req),
},
},
_ => ConnectorServiceFuture::Tcp {
}
} else {
ConnectorServiceFuture::Tcp {
fut: self.tcp_pool.call(req),
},
}
}
}
}
Expand Down
Loading
Loading