Skip to content

Commit

Permalink
Serialize control message handling
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Oct 1, 2021
1 parent f0fb414 commit 1fcd7d3
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 81 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.7.2] - 2021-10-01

* Serialize control message handling

## [0.7.1] - 2021-09-18

* Allow to extract error from control message
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.7.1"
version = "0.7.2"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand Down
35 changes: 21 additions & 14 deletions src/v3/client/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::task::{Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};

use ntex::service::Service;
use ntex::util::{inflight::InFlightService, Either, HashSet, Ready};
use ntex::util::{buffer::BufferService, inflight::InFlightService, Either, HashSet, Ready};

use crate::v3::shared::{Ack, MqttShared};
use crate::v3::{codec, control::ControlResultKind, publish::Publish, sink::MqttSink};
Expand All @@ -24,12 +24,19 @@ pub(super) fn create_dispatcher<T, C, E>(
>
where
E: 'static,
T: Service<Request = Publish, Response = ntex::util::Either<(), Publish>, Error = E>
+ 'static,
T: Service<Request = Publish, Response = Either<(), Publish>, Error = E> + 'static,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E> + 'static,
{
// limit inflight control messages
let control = BufferService::new(
16,
|| MqttError::<E>::Disconnected,
// limit number of in-flight messages
InFlightService::new(1, control.map_err(MqttError::Service)),
);

// limit number of in-flight messages
InFlightService::new(inflight, Dispatcher::<T, C, E>::new(sink, publish, control))
InFlightService::new(inflight, Dispatcher::new(sink, publish, control))
}

/// Mqtt protocol dispatcher
Expand All @@ -49,8 +56,8 @@ struct Inner<C> {

impl<T, C, E> Dispatcher<T, C, E>
where
T: Service<Request = Publish, Response = ntex::util::Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
pub(crate) fn new(sink: MqttSink, publish: T, control: C) -> Self {
Self {
Expand All @@ -65,8 +72,8 @@ where

impl<T, C, E> Service for Dispatcher<T, C, E>
where
T: Service<Request = Publish, Response = ntex::util::Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
C::Future: 'static,
E: 'static,
{
Expand All @@ -80,7 +87,7 @@ where

fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx)?;

if res1.is_pending() || res2.is_pending() {
Poll::Pending
Expand Down Expand Up @@ -218,8 +225,8 @@ pin_project_lite::pin_project! {

impl<T, C, E> Future for PublishResponse<T, C, E>
where
T: Service<Request = Publish, Response = ntex::util::Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<(), Publish>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

Expand Down Expand Up @@ -273,7 +280,7 @@ pin_project_lite::pin_project! {

impl<C, E> ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
fn new(msg: ControlMessage<E>, inner: &Rc<Inner<C>>) -> Self {
Self { fut: inner.control.call(msg), inner: inner.clone(), _t: PhantomData }
Expand All @@ -282,14 +289,14 @@ where

impl<C, E> Future for ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

let packet = match this.fut.poll(cx).map_err(MqttError::Service)? {
let packet = match this.fut.poll(cx)? {
Poll::Ready(item) => match item.result {
ControlResultKind::Ping => Some(codec::Packet::PingResponse),
ControlResultKind::PublishAck(id) => {
Expand Down
40 changes: 27 additions & 13 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::task::{Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};

use ntex::service::{fn_factory_with_config, Service, ServiceFactory};
use ntex::util::{inflight::InFlightService, join, Either, HashSet, Ready};
use ntex::util::{
buffer::BufferService, inflight::InFlightService, join, Either, HashSet, Ready,
};

use crate::error::{MqttError, ProtocolError};
use crate::io::DispatchItem;
Expand Down Expand Up @@ -51,11 +53,18 @@ where
async move {
let (publish, control) = fut.await;

let control = BufferService::new(
16,
|| MqttError::<C::Error>::Disconnected,
// limit number of in-flight messages
InFlightService::new(1, control?.map_err(MqttError::Service)),
);

Ok(
// limit number of in-flight messages
InFlightService::new(
inflight,
Dispatcher::<_, _, _, E>::new(cfg, publish?, control?),
Dispatcher::<_, _, _, E>::new(cfg, publish?, control),
),
)
}
Expand All @@ -80,7 +89,7 @@ struct Inner<C> {
impl<St, T, C, E> Dispatcher<St, T, C, E>
where
T: Service<Request = Publish, Response = (), Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
pub(crate) fn new(session: Session<St>, publish: T, control: C) -> Self {
let sink = session.sink().clone();
Expand All @@ -98,7 +107,7 @@ where
impl<St, T, C, E> Service for Dispatcher<St, T, C, E>
where
T: Service<Request = Publish, Response = (), Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
C::Future: 'static,
E: 'static,
{
Expand All @@ -112,7 +121,7 @@ where

fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx)?;

if res1.is_pending() || res2.is_pending() {
Poll::Pending
Expand Down Expand Up @@ -252,7 +261,7 @@ pin_project_lite::pin_project! {
impl<T, C, E> Future for PublishResponse<T, C, E>
where
T: Service<Request = Publish, Response = (), Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

Expand Down Expand Up @@ -300,7 +309,7 @@ pin_project_lite::pin_project! {

impl<C: Service, E> ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
#[allow(clippy::match_like_matches_macro)]
fn new(pkt: ControlMessage<E>, inner: &Rc<Inner<C>>) -> Self {
Expand All @@ -315,7 +324,7 @@ where

impl<C, E> Future for ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

Expand Down Expand Up @@ -350,13 +359,18 @@ where
Poll::Ready(Err(err)) => {
// do not handle nested error
if *this.error {
Poll::Ready(Err(MqttError::Service(err)))
Poll::Ready(Err(err))
} else {
// handle error from control service
*this.error = true;
let fut = this.inner.control.call(ControlMessage::error(err));
self.as_mut().project().fut.set(fut);
self.poll(cx)
match err {
MqttError::Service(err) => {
*this.error = true;
let fut = this.inner.control.call(ControlMessage::error(err));
self.as_mut().project().fut.set(fut);
self.poll(cx)
}
_ => Poll::Ready(Err(err)),
}
}
}
Poll::Pending => Poll::Pending,
Expand Down
2 changes: 1 addition & 1 deletion src/v3/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ where

/// Service to handle control packets
///
/// All control packets are processed sequentially, max buffered
/// All control packets are processed sequentially, max number of buffered
/// control packets is 16.
pub fn control<F, Srv>(self, service: F) -> MqttServer<Io, St, C, Srv, P>
where
Expand Down
68 changes: 32 additions & 36 deletions src/v5/client/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::task::{Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};

use ntex::service::Service;
use ntex::util::{Either, HashSet, Ready};
use ntex::util::{buffer::BufferService, inflight::InFlightService, Either, HashSet, Ready};

use crate::error::{MqttError, ProtocolError};
use crate::v5::shared::{Ack, MqttShared};
Expand All @@ -26,13 +26,16 @@ pub(super) fn create_dispatcher<T, C, E>(
>
where
E: From<T::Error> + 'static,
T: Service<
Request = Publish,
Response = ntex::util::Either<Publish, PublishAck>,
Error = E,
> + 'static,
T: Service<Request = Publish, Response = Either<Publish, PublishAck>, Error = E> + 'static,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E> + 'static,
{
let control = BufferService::new(
16,
|| MqttError::<C::Error>::Disconnected,
// limit number of in-flight messages
InFlightService::new(1, control.map_err(MqttError::Service)),
);

Dispatcher::<_, _, E>::new(sink, max_receive as usize, max_topic_alias, publish, control)
}

Expand All @@ -59,12 +62,8 @@ struct PublishInfo {

impl<T, C, E> Dispatcher<T, C, E>
where
T: Service<
Request = Publish,
Response = ntex::util::Either<Publish, PublishAck>,
Error = E,
>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<Publish, PublishAck>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
fn new(
sink: MqttSink,
Expand Down Expand Up @@ -93,12 +92,8 @@ where

impl<T, C, E> Service for Dispatcher<T, C, E>
where
T: Service<
Request = Publish,
Response = ntex::util::Either<Publish, PublishAck>,
Error = E,
>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<Publish, PublishAck>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
C::Future: 'static,
{
type Request = DispatchItem<Rc<MqttShared>>;
Expand All @@ -111,7 +106,7 @@ where

fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx).map_err(MqttError::Service)?;
let res2 = self.inner.control.poll_ready(cx)?;

if res1.is_pending() || res2.is_pending() {
Poll::Pending
Expand Down Expand Up @@ -327,12 +322,8 @@ pin_project_lite::pin_project! {

impl<T, C, E> Future for PublishResponse<T, C, E>
where
T: Service<
Request = Publish,
Response = ntex::util::Either<Publish, PublishAck>,
Error = E,
>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
T: Service<Request = Publish, Response = Either<Publish, PublishAck>, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

Expand All @@ -343,8 +334,8 @@ where
PublishResponseStateProject::Publish { fut } => {
let ack = match fut.poll(cx) {
Poll::Ready(Ok(res)) => match res {
ntex::util::Either::Right(ack) => ack,
ntex::util::Either::Left(pkt) => {
Either::Right(ack) => ack,
Either::Left(pkt) => {
this.state.set(PublishResponseState::Control {
fut: ControlResponse::new(
ControlMessage::publish(pkt.into_inner()),
Expand Down Expand Up @@ -397,7 +388,7 @@ pin_project_lite::pin_project! {

impl<C: Service, E> ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
#[allow(clippy::match_like_matches_macro)]
fn new(pkt: ControlMessage<E>, inner: &Rc<Inner<C>>) -> Self {
Expand All @@ -423,7 +414,7 @@ where

impl<C, E> Future for ControlResponse<C, E>
where
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = E>,
C: Service<Request = ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;

Expand All @@ -439,15 +430,20 @@ where
}
Poll::Ready(Err(err)) => {
// do not handle nested error
if *this.error {
return Poll::Ready(Err(MqttError::Service(err)));
return if *this.error {
Poll::Ready(Err(err))
} else {
// handle error from control service
*this.error = true;
let fut = this.inner.control.call(ControlMessage::error(err));
self.as_mut().project().fut.set(fut);
return self.poll(cx);
}
match err {
MqttError::Service(err) => {
*this.error = true;
let fut = this.inner.control.call(ControlMessage::error(err));
self.as_mut().project().fut.set(fut);
self.poll(cx)
}
_ => Poll::Ready(Err(err)),
}
};
}
Poll::Pending => return Poll::Pending,
};
Expand Down
Loading

0 comments on commit 1fcd7d3

Please sign in to comment.