Skip to content

Commit

Permalink
Cleanup v3/v5 client connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Jan 10, 2022
1 parent 446bbac commit 655db48
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 170 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.8.3] - 2022-01-10

* Cleanup v3/v5 client connectors

## [0.8.2] - 2022-01-04

* Optimize compilation times
Expand Down
10 changes: 4 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.8.2"
version = "0.8.3"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand All @@ -12,8 +12,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2018"

[dependencies]
ntex = "0.5.6"
ntex-util = "0.1.7"
ntex = "0.5.8"
bitflags = "1.3"
derive_more = "0.99"
log = "0.4"
Expand All @@ -23,10 +22,9 @@ pin-project-lite = "0.2"

[dev-dependencies]
env_logger = "0.9"
futures = "0.3"
ntex-tls = "0.1.0"
ntex-tls = "0.1.1"
rustls = "0.20"
rustls-pemfile = "0.2"
openssl = "0.10"

ntex = { version = "0.5.6", features = ["tokio", "rustls", "openssl"] }
ntex = { version = "0.5", features = ["tokio", "rustls", "openssl"] }
4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub enum DecodeError {
// MQTT v3 only
PacketIdRequired,
MaxSizeExceeded,
Utf8Error(std::str::Utf8Error),
Utf8Error,
}

impl error::Error for DecodeError {}
Expand Down Expand Up @@ -134,7 +134,7 @@ impl PartialEq for DecodeError {
(DecodeError::PacketIdRequired, DecodeError::PacketIdRequired) => true,
(DecodeError::MaxSizeExceeded, DecodeError::MaxSizeExceeded) => true,
(DecodeError::MalformedPacket, DecodeError::MalformedPacket) => true,
(DecodeError::Utf8Error(_), _) => false,
(DecodeError::Utf8Error, DecodeError::Utf8Error) => true,
_ => false,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl Decode for Bytes {
impl Decode for ByteString {
fn decode(src: &mut Bytes) -> Result<Self, DecodeError> {
let bytes = Bytes::decode(src)?;
Ok(ByteString::try_from(bytes)?)
Ok(ByteString::try_from(bytes).map_err(|_| DecodeError::Utf8Error)?)
}
}

Expand Down
96 changes: 24 additions & 72 deletions src/v3/client/connector.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
use std::{future::Future, rc::Rc};

use ntex::connect::{self, Address, Connect, Connector};
use ntex::io::{utils::Boxed, Filter, Io, IoBoxed};
use ntex::io::IoBoxed;
use ntex::service::{IntoService, Service};
use ntex::time::{timeout, Seconds};
use ntex::util::{ByteString, Bytes, Either, PoolId};

#[cfg(feature = "openssl")]
use ntex::connect::openssl::{OpensslConnector, SslConnector};

#[cfg(feature = "rustls")]
use ntex::connect::rustls::{ClientConfig, RustlsConnector};
use ntex::time::{timeout_checked, Seconds};
use ntex::util::{ByteString, Bytes, PoolId};

use super::{codec, connection::Client, error::ClientError, error::ProtocolError};
use crate::v3::shared::{MqttShared, MqttSinkPool};
Expand All @@ -34,11 +28,11 @@ where
{
#[allow(clippy::new_ret_no_self)]
/// Create new mqtt connector
pub fn new(address: A) -> MqttConnector<A, Boxed<Connector<A>, Connect<A>>> {
pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
MqttConnector {
address,
pkt: codec::Connect::default(),
connector: Boxed::new(Connector::default()),
connector: Connector::default(),
max_send: 16,
max_receive: 16,
max_packet_size: 64 * 1024,
Expand All @@ -52,7 +46,6 @@ where
impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
T: Service<Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
{
#[inline]
/// Create new client and provide client id
Expand Down Expand Up @@ -177,81 +170,40 @@ where
}

/// Use custom connector
pub fn connector<S, U, F>(self, connector: U) -> MqttConnector<A, Boxed<S, Connect<A>>>
pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
where
F: Filter,
U: IntoService<S, Connect<A>>,
S: Service<Connect<A>, Response = Io<F>, Error = connect::ConnectError>,
F: IntoService<U, Connect<A>>,
U: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<U::Response>,
{
MqttConnector {
connector: Boxed::new(connector.into_service()),
pkt: self.pkt,
address: self.address,
max_send: self.max_send,
max_receive: self.max_receive,
max_packet_size: self.max_packet_size,
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}

#[cfg(feature = "openssl")]
/// Use openssl connector
pub fn openssl(
self,
connector: SslConnector,
) -> MqttConnector<
A,
impl Service<Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
> {
MqttConnector {
pkt: self.pkt,
address: self.address,
max_send: self.max_send,
max_receive: self.max_receive,
max_packet_size: self.max_packet_size,
connector: OpensslConnector::new(connector).map(|io| io.into_boxed()),
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}

#[cfg(feature = "rustls")]
/// Use rustls connector
pub fn rustls(
self,
config: ClientConfig,
) -> MqttConnector<
A,
impl Service<Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
> {
MqttConnector {
connector: connector.into_service(),
pkt: self.pkt,
address: self.address,
max_send: self.max_send,
max_receive: self.max_receive,
max_packet_size: self.max_packet_size,
connector: RustlsConnector::new(Arc::new(config)).map(|io| io.into_boxed()),
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}
}

impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
T: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<T::Response>,
{
/// Connect to mqtt server
pub fn connect(&self) -> impl Future<Output = Result<Client, ClientError>> {
if self.handshake_timeout.non_zero() {
let fut = timeout(self.handshake_timeout, self._connect());
Either::Left(async move {
match fut.await {
Ok(res) => res.map_err(From::from),
Err(_) => Err(ClientError::HandshakeTimeout),
}
})
} else {
Either::Right(self._connect())
let fut = timeout_checked(self.handshake_timeout, self._connect());
async move {
match fut.await {
Ok(res) => res.map_err(From::from),
Err(_) => Err(ClientError::HandshakeTimeout),
}
}
}

Expand All @@ -266,7 +218,7 @@ where
let pool = self.pool.clone();

async move {
let io = fut.await?;
let io = IoBoxed::from(fut.await?);
let codec = codec::Codec::new().max_size(max_packet_size);

io.send(pkt.into(), &codec).await?;
Expand Down
93 changes: 25 additions & 68 deletions src/v5/client/connector.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
use std::{future::Future, num::NonZeroU16, num::NonZeroU32, rc::Rc};

use ntex::connect::{self, Address, Connect, Connector};
use ntex::io::{utils::Boxed, Filter, Io, IoBoxed};
use ntex::service::Service;
use ntex::time::{timeout, Seconds};
use ntex::util::{ByteString, Bytes, Either, PoolId};

#[cfg(feature = "openssl")]
use ntex::connect::openssl::{OpensslConnector, SslConnector};

#[cfg(feature = "rustls")]
use ntex::connect::rustls::{ClientConfig, RustlsConnector};
use ntex::io::IoBoxed;
use ntex::service::{IntoService, Service};
use ntex::time::{timeout_checked, Seconds};
use ntex::util::{ByteString, Bytes, PoolId};

use super::{codec, connection::Client, error::ClientError, error::ProtocolError};
use crate::v5::shared::{MqttShared, MqttSinkPool};
Expand All @@ -31,11 +25,11 @@ where
{
#[allow(clippy::new_ret_no_self)]
/// Create new mqtt connector
pub fn new(address: A) -> MqttConnector<A, Boxed<Connector<A>, Connect<A>>> {
pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
MqttConnector {
address,
pkt: codec::Connect::default(),
connector: Boxed::new(Connector::default()),
connector: Connector::default(),
handshake_timeout: Seconds::ZERO,
disconnect_timeout: Seconds(3),
pool: Rc::new(MqttSinkPool::default()),
Expand All @@ -46,7 +40,6 @@ where
impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
T: Service<Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
{
#[inline]
/// Create new client and provide client id
Expand Down Expand Up @@ -184,73 +177,37 @@ where
}

/// Use custom connector
pub fn connector<U, F>(self, connector: U) -> MqttConnector<A, Boxed<U, Connect<A>>>
pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
where
F: Filter,
U: Service<Connect<A>, Response = Io<F>, Error = connect::ConnectError>,
F: IntoService<U, Connect<A>>,
U: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<U::Response>,
{
MqttConnector {
connector: Boxed::new(connector),
pkt: self.pkt,
address: self.address,
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}

#[cfg(feature = "openssl")]
/// Use openssl connector
pub fn openssl(
self,
connector: SslConnector,
) -> MqttConnector<
A,
impl Service<Request = Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
> {
MqttConnector {
pkt: self.pkt,
address: self.address,
connector: OpensslConnector::new(connector).map(|io| io.into_boxed()),
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}

#[cfg(feature = "rustls")]
/// Use rustls connector
pub fn rustls(
self,
config: ClientConfig,
) -> MqttConnector<
A,
impl Service<Request = Connect<A>, Response = IoBoxed, Error = connect::ConnectError>,
> {
use std::sync::Arc;

MqttConnector {
connector: connector.into_service(),
pkt: self.pkt,
address: self.address,
connector: RustlsConnector::new(Arc::new(config)).map(|io| io.into_boxed()),
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
}
}
}

impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
T: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<T::Response>,
{
/// Connect to mqtt server
pub fn connect(&self) -> impl Future<Output = Result<Client, ClientError>> {
if self.handshake_timeout.non_zero() {
let fut = timeout(self.handshake_timeout, self._connect());
Either::Left(async move {
match fut.await {
Ok(res) => res.map_err(From::from),
Err(_) => Err(ClientError::HandshakeTimeout),
}
})
} else {
Either::Right(self._connect())
let fut = timeout_checked(self.handshake_timeout, self._connect());
async move {
match fut.await {
Ok(res) => res.map_err(From::from),
Err(_) => Err(ClientError::HandshakeTimeout),
}
}
}

Expand All @@ -264,7 +221,7 @@ where
let pool = self.pool.clone();

async move {
let io = fut.await?;
let io = IoBoxed::from(fut.await?);
let codec = codec::Codec::new().max_inbound_size(max_packet_size);

io.send(codec::Packet::Connect(Box::new(pkt)), &codec).await?;
Expand Down
3 changes: 1 addition & 2 deletions src/v5/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ pub struct MqttServer<St, C, Cn, P> {
impl<St, C>
MqttServer<St, C, DefaultControlService<St, C::Error>, DefaultPublishService<St, C::Error>>
where
St: 'static,
C: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
C: ServiceFactory<Handshake, Response = HandshakeAck<St>>,
C::Error: fmt::Debug,
{
/// Create server factory and provide handshake service
Expand Down
Loading

0 comments on commit 655db48

Please sign in to comment.