From 19fce07879f92c6f9dd8d9ca704dcc6327431387 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 7 Sep 2021 20:10:38 +0600 Subject: [PATCH] v3: add ControlMessage::Error and ControlMessage::ProtocolError --- CHANGES.md | 4 + Cargo.toml | 2 +- src/server.rs | 2 +- src/topic.rs | 10 +- src/v3/client/connection.rs | 97 ++----------- src/v3/client/control.rs | 20 ++- src/v3/client/dispatcher.rs | 155 +++++++++++--------- src/v3/control.rs | 66 ++++++++- src/v3/default.rs | 29 ++-- src/v3/dispatcher.rs | 280 ++++++++++++++++++++++-------------- src/v3/publish.rs | 4 + src/v3/selector.rs | 2 +- src/v3/server.rs | 103 +++---------- src/v5/control.rs | 4 +- src/v5/default.rs | 2 +- src/v5/dispatcher.rs | 2 +- 16 files changed, 407 insertions(+), 375 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 65ccc4cd..325972a1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.7.0-b.10] - 2021-09-07 + +* v3: add ControlMessage::Error and ControlMessage::ProtocolError + ## [0.7.0-b.9] - 2021-09-07 * v5: add helper methods to client control publish message diff --git a/Cargo.toml b/Cargo.toml index 0fd4d9b1..5162ab77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "0.7.0-b.9" +version = "0.7.0-b.10" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" diff --git a/src/server.rs b/src/server.rs index b086b8fc..aef5a75d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -110,7 +110,7 @@ where > + 'static, Cn: ServiceFactory< Config = v3::Session, - Request = v3::ControlMessage, + Request = v3::ControlMessage, Response = v3::ControlResult, > + 'static, P: ServiceFactory, Request = v3::Publish, Response = ()> diff --git a/src/topic.rs b/src/topic.rs index 505a0396..769fb42c 100644 --- a/src/topic.rs +++ b/src/topic.rs @@ -87,13 +87,13 @@ macro_rules! matches { for rhs in $levels { match lhs.next() { - Some(&Level::SingleWildcard) => { - if !rhs.match_level(&Level::SingleWildcard) { + Some(&$crate::topic::Level::SingleWildcard) => { + if !rhs.match_level(&$crate::topic::Level::SingleWildcard) { break; } } - Some(&Level::MultiWildcard) => { - return rhs.match_level(&Level::MultiWildcard); + Some(&$crate::topic::Level::MultiWildcard) => { + return rhs.match_level(&$crate::topic::Level::MultiWildcard); } Some(level) if rhs.match_level(level) => continue, _ => return false, @@ -101,7 +101,7 @@ macro_rules! matches { } match lhs.next() { - Some(&Level::MultiWildcard) => true, + Some(&$crate::topic::Level::MultiWildcard) => true, Some(_) => false, None => true, } diff --git a/src/v3/client/connection.rs b/src/v3/client/connection.rs index da4d7c14..ff516668 100644 --- a/src/v3/client/connection.rs +++ b/src/v3/client/connection.rs @@ -100,31 +100,14 @@ where MqttSink::new(self.shared.clone()), self.max_receive, into_service(|pkt| Ready::Ok(Either::Right(pkt))), - into_service(|msg: ControlMessage| Ready::<_, MqttError<()>>::Ok(msg.disconnect())), + into_service(|msg: ControlMessage<()>| Ready::<_, ()>::Ok(msg.disconnect())), ); let _ = Dispatcher::with( self.io, self.shared.state.clone(), self.shared.clone(), - apply_fn(dispatcher, |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { - Either::Right(Ready::Ok(None)) - } - }), + dispatcher, Timer::new(Millis::ONE_SEC), ) .keepalive_timeout(Seconds::ZERO) @@ -137,7 +120,7 @@ where where E: 'static, F: IntoService + 'static, - S: Service + 'static, + S: Service, Response = ControlResult, Error = E> + 'static, { if self.keepalive.non_zero() { ntex::rt::spawn(keepalive(MqttSink::new(self.shared.clone()), self.keepalive)); @@ -147,31 +130,14 @@ where MqttSink::new(self.shared.clone()), self.max_receive, into_service(|pkt| Ready::Ok(Either::Right(pkt))), - service.into_service().map_err(MqttError::Service), + service.into_service(), ); Dispatcher::with( self.io, self.shared.state.clone(), self.shared.clone(), - apply_fn(dispatcher, |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { - Either::Right(Ready::Ok(None)) - } - }), + dispatcher, Timer::new(Millis::ONE_SEC), ) .keepalive_timeout(Seconds::ZERO) @@ -222,33 +188,14 @@ where MqttSink::new(self.shared.clone()), self.max_receive, dispatch(self.builder.finish(), self.handlers), - into_service(|msg: ControlMessage| { - Ready::<_, MqttError>::Ok(msg.disconnect()) - }), + into_service(|msg: ControlMessage| Ready::<_, Err>::Ok(msg.disconnect())), ); let _ = Dispatcher::with( self.io, self.shared.state.clone(), self.shared.clone(), - apply_fn(dispatcher, |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { - Either::Right(Ready::Ok(None)) - } - }), + dispatcher, Timer::new(Millis::ONE_SEC), ) .keepalive_timeout(Seconds::ZERO) @@ -260,7 +207,8 @@ where pub async fn start(self, service: F) -> Result<(), MqttError> where F: IntoService + 'static, - S: Service + 'static, + S: Service, Response = ControlResult, Error = Err> + + 'static, { if self.keepalive.non_zero() { ntex::rt::spawn(keepalive(MqttSink::new(self.shared.clone()), self.keepalive)); @@ -270,31 +218,14 @@ where MqttSink::new(self.shared.clone()), self.max_receive, dispatch(self.builder.finish(), self.handlers), - service.into_service().map_err(MqttError::Service), + service.into_service(), ); Dispatcher::with( self.io, self.shared.state.clone(), self.shared.clone(), - apply_fn(dispatcher, |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { - Either::Right(Ready::Ok(None)) - } - }), + dispatcher, Timer::new(Millis::ONE_SEC), ) .keepalive_timeout(Seconds::ZERO) @@ -306,7 +237,7 @@ where fn dispatch( router: Router, handlers: Vec>, -) -> impl Service, Error = MqttError> +) -> impl Service, Error = Err> where PErr: 'static, Err: From, @@ -315,9 +246,9 @@ where if let Some((idx, _info)) = router.recognize(req.topic_mut()) { // exec handler let fut = call(req, &handlers[*idx]); - Either::Left(async move { fut.await.map_err(MqttError::Service) }) + Either::Left(async move { fut.await }) } else { - Either::Right(Ready::<_, MqttError>::Ok(Either::Right(req))) + Either::Right(Ready::<_, Err>::Ok(Either::Right(req))) } }) } diff --git a/src/v3/client/control.rs b/src/v3/client/control.rs index e5752967..27ff5c6a 100644 --- a/src/v3/client/control.rs +++ b/src/v3/client/control.rs @@ -1,16 +1,20 @@ -pub use crate::v3::control::{Closed, ControlResult, Disconnect}; -use crate::v3::{codec, control::ControlResultKind}; +pub use crate::v3::control::{Closed, ControlResult, Disconnect, Error, ProtocolError}; +use crate::v3::{codec, control::ControlResultKind, error}; -pub enum ControlMessage { +pub enum ControlMessage { /// Unhandled publish packet Publish(Publish), /// Disconnect packet Disconnect(Disconnect), /// Connection closed Closed(Closed), + /// Application level error from resources and control services + Error(Error), + /// Protocol level error + ProtocolError(ProtocolError), } -impl ControlMessage { +impl ControlMessage { pub(super) fn publish(pkt: codec::Publish) -> Self { ControlMessage::Publish(Publish(pkt)) } @@ -23,6 +27,14 @@ impl ControlMessage { ControlMessage::Closed(Closed::new(is_error)) } + pub(super) fn error(err: E) -> Self { + ControlMessage::Error(Error::new(err)) + } + + pub(super) fn proto_error(err: error::ProtocolError) -> Self { + ControlMessage::ProtocolError(ProtocolError::new(err)) + } + pub fn disconnect(&self) -> ControlResult { ControlResult { result: ControlResultKind::Disconnect } } diff --git a/src/v3/client/dispatcher.rs b/src/v3/client/dispatcher.rs index 1b0cea61..1367b433 100644 --- a/src/v3/client/dispatcher.rs +++ b/src/v3/client/dispatcher.rs @@ -5,9 +5,9 @@ use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc use ntex::service::Service; use ntex::util::{inflight::InFlightService, Either, HashSet, Ready}; -use crate::v3::shared::Ack; +use crate::v3::shared::{Ack, MqttShared}; use crate::v3::{codec, control::ControlResultKind, publish::Publish, sink::MqttSink}; -use crate::{error::MqttError, error::ProtocolError, types::packet_type}; +use crate::{error::MqttError, error::ProtocolError, io::DispatchItem, types::packet_type}; use super::control::{ControlMessage, ControlResult}; @@ -17,27 +17,28 @@ pub(super) fn create_dispatcher( inflight: usize, publish: T, control: C, -) -> impl Service, Error = MqttError> +) -> impl Service< + Request = DispatchItem>, + Response = Option, + Error = MqttError, +> where E: 'static, - T: Service< - Request = Publish, - Response = ntex::util::Either<(), Publish>, - Error = MqttError, - > + 'static, - C: Service> + T: Service, Error = E> + 'static, + C: Service, Response = ControlResult, Error = E> + 'static, { // limit number of in-flight messages - InFlightService::new(inflight, Dispatcher::<_, _, E>::new(sink, publish, control)) + InFlightService::new(inflight, Dispatcher::::new(sink, publish, control)) } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, C, E> { +pub(crate) struct Dispatcher { sink: MqttSink, publish: T, shutdown: Cell, inner: Rc>, + _t: PhantomData, } struct Inner { @@ -48,12 +49,8 @@ struct Inner { impl Dispatcher where - T: Service< - Request = Publish, - Response = ntex::util::Either<(), Publish>, - Error = MqttError, - >, - C: Service>, + T: Service, Error = E>, + C: Service, Response = ControlResult, Error = E>, { pub(crate) fn new(sink: MqttSink, publish: T, control: C) -> Self { Self { @@ -61,22 +58,19 @@ where sink: sink.clone(), shutdown: Cell::new(false), inner: Rc::new(Inner { sink, control, inflight: RefCell::new(HashSet::default()) }), + _t: PhantomData, } } } impl Service for Dispatcher where - T: Service< - Request = Publish, - Response = ntex::util::Either<(), Publish>, - Error = MqttError, - >, - C: Service>, + T: Service, Error = E>, + C: Service, Response = ControlResult, Error = E>, C::Future: 'static, E: 'static, { - type Request = codec::Packet; + type Request = DispatchItem>; type Response = Option; type Error = MqttError; type Future = Either< @@ -85,8 +79,8 @@ where >; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let res1 = self.publish.poll_ready(cx)?; - let res2 = self.inner.control.poll_ready(cx)?; + let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?; + let res2 = self.inner.control.poll_ready(cx).map_err(MqttError::Service)?; if res1.is_pending() || res2.is_pending() { Poll::Pending @@ -107,10 +101,10 @@ where Poll::Ready(()) } - fn call(&self, packet: codec::Packet) -> Self::Future { + fn call(&self, packet: Self::Request) -> Self::Future { log::trace!("Dispatch packet: {:#?}", packet); match packet { - codec::Packet::Publish(publish) => { + DispatchItem::Item(codec::Packet::Publish(publish)) => { let inner = self.inner.clone(); let packet_id = publish.packet_id; @@ -131,49 +125,80 @@ where _t: PhantomData, }) } - codec::Packet::PublishAck { packet_id } => { + DispatchItem::Item(codec::Packet::PublishAck { packet_id }) => { if let Err(e) = self.sink.pkt_ack(Ack::Publish(packet_id)) { Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } - codec::Packet::PingRequest => { + DispatchItem::Item(codec::Packet::PingRequest) => { Either::Right(Either::Left(Ready::Ok(Some(codec::Packet::PingResponse)))) } - codec::Packet::Disconnect => Either::Right(Either::Right(ControlResponse::new( - self.inner.control.call(ControlMessage::dis()), - &self.inner, - ))), - codec::Packet::SubscribeAck { packet_id, status } => { + DispatchItem::Item(codec::Packet::Disconnect) => Either::Right(Either::Right( + ControlResponse::new(ControlMessage::dis(), &self.inner), + )), + DispatchItem::Item(codec::Packet::SubscribeAck { packet_id, status }) => { if let Err(e) = self.sink.pkt_ack(Ack::Subscribe { packet_id, status }) { Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } - codec::Packet::UnsubscribeAck { packet_id } => { + DispatchItem::Item(codec::Packet::UnsubscribeAck { packet_id }) => { if let Err(e) = self.sink.pkt_ack(Ack::Unsubscribe(packet_id)) { Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } - codec::Packet::Subscribe { .. } => Either::Right(Either::Left(Ready::Err( - ProtocolError::Unexpected( - packet_type::SUBSCRIBE, - "Subscribe packet is not supported", - ) - .into(), - ))), - codec::Packet::Unsubscribe { .. } => Either::Right(Either::Left(Ready::Err( - ProtocolError::Unexpected( - packet_type::UNSUBSCRIBE, - "Unsubscribe packet is not supported", - ) - .into(), + DispatchItem::Item(codec::Packet::Subscribe { .. }) => { + Either::Right(Either::Left(Ready::Err( + ProtocolError::Unexpected( + packet_type::SUBSCRIBE, + "Subscribe packet is not supported", + ) + .into(), + ))) + } + DispatchItem::Item(codec::Packet::Unsubscribe { .. }) => { + Either::Right(Either::Left(Ready::Err( + ProtocolError::Unexpected( + packet_type::UNSUBSCRIBE, + "Unsubscribe packet is not supported", + ) + .into(), + ))) + } + DispatchItem::Item(pkt) => { + log::debug!("Unsupported packet: {:?}", pkt); + Either::Right(Either::Left(Ready::Ok(None))) + } + DispatchItem::EncoderError(err) => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Encode(err)), + &self.inner, + ))) + } + DispatchItem::DecoderError(err) => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Decode(err)), + &self.inner, + ))) + } + DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Io(err)), + &self.inner, ))), - _ => Either::Right(Either::Left(Ready::Ok(None))), + DispatchItem::KeepAliveTimeout => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), + &self.inner, + ))) + } + DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { + Either::Right(Either::Left(Ready::Ok(None))) + } } } } @@ -193,12 +218,8 @@ pin_project_lite::pin_project! { impl Future for PublishResponse where - T: Service< - Request = Publish, - Response = ntex::util::Either<(), Publish>, - Error = MqttError, - >, - C: Service>, + T: Service, Error = E>, + C: Service, Response = ControlResult, Error = E>, { type Output = Result, MqttError>; @@ -208,10 +229,16 @@ where } let mut this = self.as_mut().project(); - let res = match this.fut.poll(cx)? { - Poll::Ready(item) => item, + let res = match this.fut.poll(cx) { + Poll::Ready(Ok(item)) => item, + Poll::Ready(Err(e)) => { + this.fut_c + .set(Some(ControlResponse::new(ControlMessage::error(e), &*this.inner))); + return self.poll(cx); + } Poll::Pending => return Poll::Pending, }; + match res { Either::Left(_) => { log::trace!("Publish result for packet {:?} is ready", this.packet_id); @@ -225,7 +252,7 @@ where } Either::Right(pkt) => { this.fut_c.set(Some(ControlResponse::new( - this.inner.control.call(ControlMessage::publish(pkt.into_inner())), + ControlMessage::publish(pkt.into_inner()), &*this.inner, ))); self.poll(cx) @@ -246,23 +273,23 @@ pin_project_lite::pin_project! { impl ControlResponse where - C: Service>, + C: Service, Response = ControlResult, Error = E>, { - fn new(fut: C::Future, inner: &Rc>) -> Self { - Self { fut, inner: inner.clone(), _t: PhantomData } + fn new(msg: ControlMessage, inner: &Rc>) -> Self { + Self { fut: inner.control.call(msg), inner: inner.clone(), _t: PhantomData } } } impl Future for ControlResponse where - C: Service>, + C: Service, Response = ControlResult, Error = E>, { type Output = Result, MqttError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let packet = match this.fut.poll(cx)? { + let packet = match this.fut.poll(cx).map_err(MqttError::Service)? { Poll::Ready(item) => match item.result { ControlResultKind::Ping => Some(codec::Packet::PingResponse), ControlResultKind::PublishAck(id) => { diff --git a/src/v3/control.rs b/src/v3/control.rs index 0f3cce1c..8131a6b6 100644 --- a/src/v3/control.rs +++ b/src/v3/control.rs @@ -2,10 +2,10 @@ use ntex::util::ByteString; use std::{marker::PhantomData, num::NonZeroU16}; use super::codec; -use crate::types::QoS; +use crate::{error, types::QoS}; #[derive(Debug)] -pub enum ControlMessage { +pub enum ControlMessage { /// Ping packet Ping(Ping), /// Disconnect packet @@ -16,6 +16,10 @@ pub enum ControlMessage { Unsubscribe(Unsubscribe), /// Connection dropped Closed(Closed), + /// Service level error + Error(Error), + /// Protocol level error + ProtocolError(ProtocolError), } #[derive(Debug)] @@ -34,7 +38,7 @@ pub(crate) enum ControlResultKind { Closed, } -impl ControlMessage { +impl ControlMessage { pub(crate) fn ping() -> Self { ControlMessage::Ping(Ping) } @@ -47,6 +51,14 @@ impl ControlMessage { ControlMessage::Closed(Closed::new(is_error)) } + pub(super) fn error(err: E) -> Self { + ControlMessage::Error(Error::new(err)) + } + + pub(super) fn proto_error(err: error::ProtocolError) -> Self { + ControlMessage::ProtocolError(ProtocolError::new(err)) + } + pub fn disconnect(&self) -> ControlResult { ControlResult { result: ControlResultKind::Disconnect } } @@ -70,6 +82,54 @@ impl Disconnect { } } +/// Service level error +#[derive(Debug)] +pub struct Error { + err: E, +} + +impl Error { + pub fn new(err: E) -> Self { + Self { err } + } + + #[inline] + /// Returns reference to mqtt error + pub fn get_ref(&self) -> &E { + &self.err + } + + #[inline] + /// Ack service error, return disconnect packet and close connection. + pub fn ack(self) -> ControlResult { + ControlResult { result: ControlResultKind::Disconnect } + } +} + +/// Protocol level error +#[derive(Debug)] +pub struct ProtocolError { + err: error::ProtocolError, +} + +impl ProtocolError { + pub fn new(err: error::ProtocolError) -> Self { + Self { err } + } + + #[inline] + /// Returns reference to a protocol error + pub fn get_ref(&self) -> &error::ProtocolError { + &self.err + } + + #[inline] + /// Ack protocol error, return disconnect packet and close connection. + pub fn ack(self) -> ControlResult { + ControlResult { result: ControlResultKind::Disconnect } + } +} + /// Subscribe message #[derive(Debug)] pub struct Subscribe { diff --git a/src/v3/default.rs b/src/v3/default.rs index 96193898..5e1baacd 100644 --- a/src/v3/default.rs +++ b/src/v3/default.rs @@ -1,10 +1,9 @@ -use std::marker::PhantomData; -use std::task::{Context, Poll}; +use std::{fmt, marker::PhantomData, task::Context, task::Poll}; use ntex::service::{Service, ServiceFactory}; use ntex::util::Ready; -use super::control::{ControlMessage, ControlResult}; +use super::control::{ControlMessage, ControlResult, ControlResultKind}; use super::publish::Publish; use super::Session; @@ -58,9 +57,9 @@ impl Default for DefaultControlService { } } -impl ServiceFactory for DefaultControlService { +impl ServiceFactory for DefaultControlService { type Config = Session; - type Request = ControlMessage; + type Request = ControlMessage; type Response = ControlResult; type Error = E; type InitError = E; @@ -72,8 +71,8 @@ impl ServiceFactory for DefaultControlService { } } -impl Service for DefaultControlService { - type Request = ControlMessage; +impl Service for DefaultControlService { + type Request = ControlMessage; type Response = ControlResult; type Error = E; type Future = Ready; @@ -84,21 +83,17 @@ impl Service for DefaultControlService { } #[inline] - fn call(&self, subs: ControlMessage) -> Self::Future { + fn call(&self, pkt: Self::Request) -> Self::Future { log::warn!("MQTT Subscribe is not supported"); - Ready::Ok(match subs { + Ready::Ok(match pkt { ControlMessage::Ping(ping) => ping.ack(), ControlMessage::Disconnect(disc) => disc.ack(), - ControlMessage::Subscribe(subs) => { - log::warn!("MQTT Subscribe is not supported"); - subs.ack() - } - ControlMessage::Unsubscribe(unsubs) => { - log::warn!("MQTT Unsubscribe is not supported"); - unsubs.ack() - } ControlMessage::Closed(msg) => msg.ack(), + _ => { + log::warn!("MQTT3 Control service is not configured, pkt: {:?}", pkt); + ControlResult { result: ControlResultKind::Disconnect } + } }) } } diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index 07eb22b2..fe712a07 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -5,11 +5,13 @@ use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc use ntex::service::{fn_factory_with_config, Service, ServiceFactory}; use ntex::util::{inflight::InFlightService, join, Either, HashSet, Ready}; -use crate::error::MqttError; +use crate::error::{MqttError, ProtocolError}; +use crate::io::DispatchItem; use super::control::{ ControlMessage, ControlResult, ControlResultKind, Subscribe, Unsubscribe, }; +use super::shared::MqttShared; use super::{codec, publish::Publish, shared::Ack, sink::MqttSink, Session}; /// mqtt3 protocol dispatcher @@ -19,7 +21,7 @@ pub(super) fn factory( inflight: usize, ) -> impl ServiceFactory< Config = Session, - Request = codec::Packet, + Request = DispatchItem>, Response = Option, Error = MqttError, InitError = MqttError, @@ -31,14 +33,14 @@ where Config = Session, Request = Publish, Response = (), - Error = MqttError, + Error = E, InitError = MqttError, > + 'static, C: ServiceFactory< Config = Session, - Request = ControlMessage, + Request = ControlMessage, Response = ControlResult, - Error = MqttError, + Error = E, InitError = MqttError, > + 'static, { @@ -61,23 +63,24 @@ where } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, C, E> { +pub(crate) struct Dispatcher { session: Session, publish: T, - control: C, shutdown: Cell, - inner: Rc, + inner: Rc>, + _t: PhantomData<(E,)>, } -struct Inner { +struct Inner { + control: C, sink: MqttSink, inflight: RefCell>, } impl Dispatcher where - T: Service>, - C: Service>, + T: Service, + C: Service, Response = ControlResult, Error = E>, { pub(crate) fn new(session: Session, publish: T, control: C) -> Self { let sink = session.sink().clone(); @@ -85,31 +88,31 @@ where Self { session, publish, - control, shutdown: Cell::new(false), - inner: Rc::new(Inner { sink, inflight: RefCell::new(HashSet::default()) }), + inner: Rc::new(Inner { sink, control, inflight: RefCell::new(HashSet::default()) }), + _t: PhantomData, } } } impl Service for Dispatcher where - T: Service>, - C: Service>, + T: Service, + C: Service, Response = ControlResult, Error = E>, C::Future: 'static, E: 'static, { - type Request = codec::Packet; + type Request = DispatchItem>; type Response = Option; type Error = MqttError; type Future = Either< - PublishResponse>, - Either>, ControlResponse>, + PublishResponse, + Either>, ControlResponse>, >; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let res1 = self.publish.poll_ready(cx)?; - let res2 = self.control.poll_ready(cx)?; + let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?; + let res2 = self.inner.control.poll_ready(cx).map_err(MqttError::Service)?; if res1.is_pending() || res2.is_pending() { Poll::Pending @@ -122,7 +125,7 @@ where if !self.shutdown.get() { self.inner.sink.close(); self.shutdown.set(true); - let fut = self.control.call(ControlMessage::closed(is_error)); + let fut = self.inner.control.call(ControlMessage::closed(is_error)); ntex::rt::spawn(async move { let _ = fut.await; }); @@ -130,10 +133,11 @@ where Poll::Ready(()) } - fn call(&self, packet: codec::Packet) -> Self::Future { - log::trace!("Dispatch packet: {:#?}", packet); - match packet { - codec::Packet::Publish(publish) => { + fn call(&self, req: DispatchItem>) -> Self::Future { + log::trace!("Dispatch v3 packet: {:#?}", req); + + match req { + DispatchItem::Item(codec::Packet::Publish(publish)) => { let inner = self.inner.clone(); let packet_id = publish.packet_id; @@ -141,34 +145,34 @@ where if let Some(pid) = packet_id { if !inner.inflight.borrow_mut().insert(pid) { log::trace!("Duplicated packet id for publish packet: {:?}", pid); - return Either::Right(Either::Left(Ready::Err( - MqttError::ServerError("Duplicated packet id for publish packet"), + return Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::ReceiveMaximumExceeded), + &self.inner, ))); } } Either::Left(PublishResponse { packet_id, inner, - fut: self.publish.call(Publish::new(publish)), - _t: PhantomData, + state: PublishResponseState::Publish { + fut: self.publish.call(Publish::new(publish)), + }, }) } - codec::Packet::PublishAck { packet_id } => { + DispatchItem::Item(codec::Packet::PublishAck { packet_id }) => { if let Err(e) = self.session.sink().pkt_ack(Ack::Publish(packet_id)) { - Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e)))) + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(e), + &self.inner, + ))) } else { Either::Right(Either::Left(Ready::Ok(None))) } } - codec::Packet::PingRequest => Either::Right(Either::Right(ControlResponse::new( - self.control.call(ControlMessage::ping()), - &self.inner, - ))), - codec::Packet::Disconnect => Either::Right(Either::Right(ControlResponse::new( - self.control.call(ControlMessage::pkt_disconnect()), - &self.inner, - ))), - codec::Packet::Subscribe { packet_id, topic_filters } => { + DispatchItem::Item(codec::Packet::PingRequest) => Either::Right(Either::Right( + ControlResponse::new(ControlMessage::ping(), &self.inner), + )), + DispatchItem::Item(codec::Packet::Subscribe { packet_id, topic_filters }) => { if !self.inner.inflight.borrow_mut().insert(packet_id) { log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id); return Either::Right(Either::Left(Ready::Err(MqttError::ServerError( @@ -177,14 +181,11 @@ where } Either::Right(Either::Right(ControlResponse::new( - self.control.call(ControlMessage::Subscribe(Subscribe::new( - packet_id, - topic_filters, - ))), + ControlMessage::Subscribe(Subscribe::new(packet_id, topic_filters)), &self.inner, ))) } - codec::Packet::Unsubscribe { packet_id, topic_filters } => { + DispatchItem::Item(codec::Packet::Unsubscribe { packet_id, topic_filters }) => { if !self.inner.inflight.borrow_mut().insert(packet_id) { log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id); return Either::Right(Either::Left(Ready::Err(MqttError::ServerError( @@ -193,109 +194,172 @@ where } Either::Right(Either::Right(ControlResponse::new( - self.control.call(ControlMessage::Unsubscribe(Unsubscribe::new( - packet_id, - topic_filters, - ))), + ControlMessage::Unsubscribe(Unsubscribe::new(packet_id, topic_filters)), + &self.inner, + ))) + } + DispatchItem::Item(codec::Packet::Disconnect) => Either::Right(Either::Right( + ControlResponse::new(ControlMessage::pkt_disconnect(), &self.inner), + )), + DispatchItem::Item(_) => Either::Right(Either::Left(Ready::Ok(None))), + DispatchItem::EncoderError(err) => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Encode(err)), + &self.inner, + ))) + } + DispatchItem::KeepAliveTimeout => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ))) } - _ => Either::Right(Either::Left(Ready::Ok(None))), + DispatchItem::DecoderError(err) => { + Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Decode(err)), + &self.inner, + ))) + } + DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( + ControlMessage::proto_error(ProtocolError::Io(err)), + &self.inner, + ))), + DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { + Either::Right(Either::Left(Ready::Ok(None))) + } } } } pin_project_lite::pin_project! { /// Publish service response future - pub(crate) struct PublishResponse { + pub(crate) struct PublishResponse { #[pin] - fut: T, + state: PublishResponseState, packet_id: Option, - inner: Rc, - _t: PhantomData, + inner: Rc>, + } +} + +pin_project_lite::pin_project! { + #[project = PublishResponseStateProject] + enum PublishResponseState { + Publish { #[pin] fut: T::Future }, + Control { #[pin] fut: ControlResponse }, } } -impl Future for PublishResponse +impl Future for PublishResponse where - T: Future>, + T: Service, + C: Service, Response = ControlResult, Error = E>, { - type Output = Result, E>; + type Output = Result, MqttError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); - match this.fut.poll(cx) { - Poll::Ready(result) => result?, - Poll::Pending => return Poll::Pending, - }; + match this.state.as_mut().project() { + PublishResponseStateProject::Publish { fut } => match fut.poll(cx) { + Poll::Ready(Ok(_)) => { + log::trace!("Publish result for packet {:?} is ready", this.packet_id); - log::trace!("Publish result for packet {:?} is ready", this.packet_id); - - if let Some(packet_id) = this.packet_id { - this.inner.inflight.borrow_mut().remove(packet_id); - Poll::Ready(Ok(Some(codec::Packet::PublishAck { packet_id: *packet_id }))) - } else { - Poll::Ready(Ok(None)) + if let Some(packet_id) = this.packet_id { + this.inner.inflight.borrow_mut().remove(packet_id); + Poll::Ready(Ok(Some(codec::Packet::PublishAck { + packet_id: *packet_id, + }))) + } else { + Poll::Ready(Ok(None)) + } + } + Poll::Ready(Err(e)) => { + this.state.set(PublishResponseState::Control { + fut: ControlResponse::new(ControlMessage::error(e), this.inner), + }); + self.poll(cx) + } + Poll::Pending => Poll::Pending, + }, + PublishResponseStateProject::Control { fut } => fut.poll(cx), } } } pin_project_lite::pin_project! { /// Control service response future - pub(crate) struct ControlResponse - where - T: Future>>, + pub(crate) struct ControlResponse { #[pin] - fut: T, - inner: Rc, + fut: C::Future, + inner: Rc>, + error: bool, + _t: PhantomData, } } -impl ControlResponse +impl ControlResponse where - T: Future>>, + C: Service, Response = ControlResult, Error = E>, { - fn new(fut: T, inner: &Rc) -> Self { - Self { fut, inner: inner.clone() } + #[allow(clippy::match_like_matches_macro)] + fn new(pkt: ControlMessage, inner: &Rc>) -> Self { + let error = match pkt { + ControlMessage::Error(_) | ControlMessage::ProtocolError(_) => true, + _ => false, + }; + + Self { error, fut: inner.control.call(pkt), inner: inner.clone(), _t: PhantomData } } } -impl Future for ControlResponse +impl Future for ControlResponse where - T: Future>>, + C: Service, Response = ControlResult, Error = E>, { type Output = Result, MqttError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); - let packet = match this.fut.poll(cx)? { - Poll::Ready(item) => match item.result { - ControlResultKind::Ping => Some(codec::Packet::PingResponse), - ControlResultKind::Subscribe(res) => { - this.inner.inflight.borrow_mut().remove(&res.packet_id); - Some(codec::Packet::SubscribeAck { - status: res.codes, - packet_id: res.packet_id, - }) - } - ControlResultKind::Unsubscribe(res) => { - this.inner.inflight.borrow_mut().remove(&res.packet_id); - Some(codec::Packet::UnsubscribeAck { packet_id: res.packet_id }) - } - ControlResultKind::Disconnect - | ControlResultKind::Closed - | ControlResultKind::Nothing => { - this.inner.sink.close(); - None + match this.fut.poll(cx) { + Poll::Ready(Ok(item)) => { + let packet = match item.result { + ControlResultKind::Ping => Some(codec::Packet::PingResponse), + ControlResultKind::Subscribe(res) => { + this.inner.inflight.borrow_mut().remove(&res.packet_id); + Some(codec::Packet::SubscribeAck { + status: res.codes, + packet_id: res.packet_id, + }) + } + ControlResultKind::Unsubscribe(res) => { + this.inner.inflight.borrow_mut().remove(&res.packet_id); + Some(codec::Packet::UnsubscribeAck { packet_id: res.packet_id }) + } + ControlResultKind::Disconnect + | ControlResultKind::Closed + | ControlResultKind::Nothing => { + this.inner.sink.close(); + None + } + ControlResultKind::PublishAck(_) => unreachable!(), + }; + Poll::Ready(Ok(packet)) + } + Poll::Ready(Err(err)) => { + // do not handle nested error + if *this.error { + Poll::Ready(Err(MqttError::Service(err))) + } else { + // handle error from control service + *this.error = true; + let fut = this.inner.control.call(ControlMessage::error(err)); + self.as_mut().project().fut.set(fut); + self.poll(cx) } - ControlResultKind::PublishAck(_) => unreachable!(), - }, - Poll::Pending => return Poll::Pending, - }; - - Poll::Ready(Ok(packet)) + } + Poll::Pending => Poll::Pending, + } } } diff --git a/src/v3/publish.rs b/src/v3/publish.rs index ac50daa4..f10983cf 100644 --- a/src/v3/publish.rs +++ b/src/v3/publish.rs @@ -13,6 +13,10 @@ pub struct Publish { topic: Path, } +#[derive(Debug)] +/// Publish ack +pub struct PublishAck; + impl Publish { pub(crate) fn new(publish: codec::Publish) -> Self { Self { topic: Path::new(publish.topic.clone()), publish } diff --git a/src/v3/selector.rs b/src/v3/selector.rs index 502aed17..cba0258f 100644 --- a/src/v3/selector.rs +++ b/src/v3/selector.rs @@ -95,7 +95,7 @@ where > + 'static, Cn: ServiceFactory< Config = Session, - Request = ControlMessage, + Request = ControlMessage, Response = ControlResult, > + 'static, P: ServiceFactory, Request = Publish, Response = ()> + 'static, diff --git a/src/v3/server.rs b/src/v3/server.rs index 7be6b009..c7cb257f 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -69,8 +69,11 @@ where St: 'static, C: ServiceFactory, Response = HandshakeAck> + 'static, - Cn: ServiceFactory, Request = ControlMessage, Response = ControlResult> - + 'static, + Cn: ServiceFactory< + Config = Session, + Request = ControlMessage, + Response = ControlResult, + > + 'static, P: ServiceFactory, Request = Publish, Response = ()> + 'static, C::Error: From @@ -127,7 +130,7 @@ where F: IntoServiceFactory, Srv: ServiceFactory< Config = Session, - Request = ControlMessage, + Request = ControlMessage, Response = ControlResult, > + 'static, C::Error: From + From, @@ -174,12 +177,10 @@ where let publish = self .publish .into_factory() - .map_err(|e| MqttError::Service(e.into())) - .map_init_err(|e| MqttError::Service(e.into())); - let control = self - .control - .map_err(|e| MqttError::Service(e.into())) + .map_err(|e| e.into()) .map_init_err(|e| MqttError::Service(e.into())); + let control = + self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); FramedService::new( handshake_service_factory( @@ -188,26 +189,7 @@ where self.handshake_timeout, self.pool, ), - apply_fn_factory( - factory(publish, control, self.inflight), - |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled - | DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)), - }, - ), + factory(publish, control, self.inflight), self.disconnect_timeout, ) } @@ -226,12 +208,10 @@ where let publish = self .publish .into_factory() - .map_err(|e| MqttError::Service(e.into())) - .map_init_err(|e| MqttError::Service(e.into())); - let control = self - .control - .map_err(|e| MqttError::Service(e.into())) + .map_err(|e| e.into()) .map_init_err(|e| MqttError::Service(e.into())); + let control = + self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); FramedService2::new( handshake_service_factory2( @@ -240,26 +220,7 @@ where self.handshake_timeout, self.pool, ), - apply_fn_factory( - factory(publish, control, self.inflight), - |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled - | DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)), - }, - ), + factory(publish, control, self.inflight), self.disconnect_timeout, ) } @@ -279,41 +240,15 @@ where F: Fn(&Handshake) -> R + 'static, R: Future> + 'static, { - let publish = self - .publish - .map_err(|e| MqttError::Service(e.into())) - .map_init_err(|e| MqttError::Service(e.into())); - let control = self - .control - .map_err(|e| MqttError::Service(e.into())) - .map_init_err(|e| MqttError::Service(e.into())); - - let handler = apply_fn_factory( - factory(publish, control, self.inflight), - |req: DispatchItem>, srv| match req { - DispatchItem::Item(req) => Either::Left(srv.call(req)), - DispatchItem::KeepAliveTimeout => Either::Right(Ready::Err( - MqttError::Protocol(ProtocolError::KeepAliveTimeout), - )), - DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Encode(e)))) - } - DispatchItem::DecoderError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Decode(e)))) - } - DispatchItem::IoError(e) => { - Either::Right(Ready::Err(MqttError::Protocol(ProtocolError::Io(e)))) - } - DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { - Either::Right(Ready::Ok(None)) - } - }, - ); + let publish = + self.publish.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); + let control = + self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); ServerSelector { check: Rc::new(check), connect: self.handshake, - handler: Rc::new(handler), + handler: Rc::new(factory(publish, control, self.inflight)), max_size: self.max_size, disconnect_timeout: self.disconnect_timeout, time: Timer::new(Millis::ONE_SEC), diff --git a/src/v5/control.rs b/src/v5/control.rs index 2171ec59..8d3aeec2 100644 --- a/src/v5/control.rs +++ b/src/v5/control.rs @@ -439,7 +439,7 @@ impl Error { #[inline] /// Returns reference to mqtt error - pub fn get_err(&self) -> &E { + pub fn get_ref(&self) -> &E { &self.err } @@ -485,7 +485,7 @@ impl Error { } } -/// Connection failed message +/// Protocol level error #[derive(Debug)] pub struct ProtocolError { err: error::ProtocolError, diff --git a/src/v5/default.rs b/src/v5/default.rs index c2223f36..9977f566 100644 --- a/src/v5/default.rs +++ b/src/v5/default.rs @@ -88,7 +88,7 @@ impl Service for DefaultControlService { ControlMessage::Ping(pkt) => Ready::Ok(pkt.ack()), ControlMessage::Disconnect(pkt) => Ready::Ok(pkt.ack()), _ => { - log::warn!("MQTT Control service is not configured, pkt: {:?}", pkt); + log::warn!("MQTT5 Control service is not configured, pkt: {:?}", pkt); Ready::Ok(pkt.disconnect_with(super::codec::Disconnect::new( super::codec::DisconnectReasonCode::UnspecifiedError, ))) diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index 1c21affa..6eb544fd 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -157,7 +157,7 @@ where } fn call(&self, request: Self::Request) -> Self::Future { - log::trace!("Dispatch packet: {:#?}", request); + log::trace!("Dispatch v5 packet: {:#?}", request); match request { DispatchItem::Item(codec::Packet::Publish(publish)) => {