Skip to content

Commit

Permalink
cleanup dispatchers generic constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Jan 4, 2022
1 parent 382cf2a commit 446bbac
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 83 deletions.
40 changes: 19 additions & 21 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,34 @@ pub(super) fn factory<St, T, C, E>(
InitError = MqttError<E>,
>
where
E: 'static,
St: 'static,
T: ServiceFactory<Publish, Session<St>, Response = (), Error = E, InitError = MqttError<E>>
+ 'static,
C: ServiceFactory<
ControlMessage<E>,
Session<St>,
Response = ControlResult,
Error = E,
InitError = MqttError<E>,
> + 'static,
T: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
C: ServiceFactory<ControlMessage<E>, Session<St>, Response = ControlResult> + 'static,
E: From<C::Error> + From<C::InitError> + From<T::Error> + From<T::InitError> + 'static,
{
fn_factory_with_config(move |cfg: Session<St>| {
// create services
let fut = join(publish.new_service(cfg.clone()), control.new_service(cfg.clone()));

async move {
let (publish, control) = fut.await;
let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
let control = control
.map_err(|e| MqttError::Service(e.into()))?
.map_err(|e| MqttError::Service(E::from(e)));

let control = BufferService::new(
16,
|| MqttError::<C::Error>::Disconnected(None),
|| MqttError::<E>::Disconnected(None),
// limit number of in-flight messages
InFlightService::new(1, control?.map_err(MqttError::Service)),
InFlightService::new(1, control),
);

Ok(
// limit number of in-flight messages
InFlightService::new(
inflight,
Dispatcher::<_, _, _, E>::new(cfg, publish?, control),
Dispatcher::<_, _, _, E>::new(cfg, publish, control),
),
)
}
Expand All @@ -83,7 +80,8 @@ struct Inner<C> {

impl<St, T, C, E> Dispatcher<St, T, C, E>
where
T: Service<Publish, Response = (), Error = E>,
E: From<T::Error>,
T: Service<Publish, Response = ()>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
pub(crate) fn new(session: Session<St>, publish: T, control: C) -> Self {
Expand All @@ -101,10 +99,9 @@ where

impl<St, T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<St, T, C, E>
where
T: Service<Publish, Response = (), Error = E>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
C::Future: 'static,
E: 'static,
E: From<T::Error> + 'static,
T: Service<Publish, Response = ()>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>> + 'static,
{
type Response = Option<codec::Packet>;
type Error = MqttError<E>;
Expand All @@ -114,7 +111,7 @@ where
>;

fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let res1 = self.publish.poll_ready(cx).map_err(MqttError::Service)?;
let res1 = self.publish.poll_ready(cx).map_err(|e| MqttError::Service(e.into()))?;
let res2 = self.inner.control.poll_ready(cx)?;

if res1.is_pending() || res2.is_pending() {
Expand Down Expand Up @@ -259,7 +256,8 @@ pin_project_lite::pin_project! {

impl<T, C, E> Future for PublishResponse<T, C, E>
where
T: Service<Publish, Response = (), Error = E>,
E: From<T::Error>,
T: Service<Publish, Response = ()>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;
Expand All @@ -283,7 +281,7 @@ where
}
Poll::Ready(Err(e)) => {
this.state.set(PublishResponseState::Control {
fut: ControlResponse::new(ControlMessage::error(e), this.inner),
fut: ControlResponse::new(ControlMessage::error(e.into()), this.inner),
});
self.poll(cx)
}
Expand Down
20 changes: 3 additions & 17 deletions src/v3/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,15 @@ where
>,
Rc<MqttShared>,
> {
let handshake = self.handshake;
let publish = self
.publish
.into_factory()
.map_err(|e| e.into())
.map_init_err(|e| MqttError::Service(e.into()));
let control =
self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into()));

service::MqttServer::new(
HandshakeFactory {
factory: handshake,
factory: self.handshake,
max_size: self.max_size,
handshake_timeout: self.handshake_timeout,
pool: self.pool.clone(),
_t: PhantomData,
},
factory(publish, control, self.inflight),
factory(self.publish, self.control, self.inflight),
self.disconnect_timeout,
)
}
Expand All @@ -207,15 +198,10 @@ where
F: Fn(&Handshake) -> R + 'static,
R: Future<Output = Result<bool, C::Error>> + 'static,
{
let publish =
self.publish.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into()));
let control =
self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into()));

ServerSelector {
check: Rc::new(check),
connect: self.handshake,
handler: Rc::new(factory(publish, control, self.inflight)),
handler: Rc::new(factory(self.publish, self.control, self.inflight)),
max_size: self.max_size,
disconnect_timeout: self.disconnect_timeout,
_t: PhantomData,
Expand Down
58 changes: 27 additions & 31 deletions src/v5/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,9 @@ pub(super) fn factory<St, T, C, E>(
>
where
St: 'static,
E: From<T::Error> + 'static,
T: ServiceFactory<Publish, Session<St>, Response = PublishAck, InitError = MqttError<E>>
+ 'static,
C: ServiceFactory<
ControlMessage<E>,
Session<St>,
Response = ControlResult,
Error = E,
InitError = MqttError<E>,
> + 'static,
E: From<T::Error> + From<T::InitError> + From<C::Error> + From<C::InitError> + 'static,
T: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
C: ServiceFactory<ControlMessage<E>, Session<St>, Response = ControlResult> + 'static,
PublishAck: TryFrom<T::Error, Error = E>,
{
fn_factory_with_config(move |cfg: Session<St>| {
Expand All @@ -49,34 +42,38 @@ where

async move {
let (publish, control) = fut.await;
let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
let control = control
.map_err(|e| MqttError::Service(e.into()))?
.map_err(|e| MqttError::Service(E::from(e)));

let control = BufferService::new(
16,
|| MqttError::<C::Error>::Disconnected(None),
|| MqttError::<E>::Disconnected(None),
// limit number of in-flight messages
InFlightService::new(1, control?.map_err(MqttError::Service)),
InFlightService::new(1, control),
);

Ok(Dispatcher::<_, _, E, T::Error>::new(
Ok(Dispatcher::<_, _, E>::new(
cfg.sink().clone(),
max_receive as usize,
max_topic_alias,
publish?,
publish,
control,
))
}
})
}

/// Mqtt protocol dispatcher
pub(crate) struct Dispatcher<T, C: Service<ControlMessage<E>>, E, E2> {
pub(crate) struct Dispatcher<T, C: Service<ControlMessage<E>>, E> {
sink: MqttSink,
publish: T,
shutdown: RefCell<Option<Pin<Box<C::Future>>>>,
max_receive: usize,
max_topic_alias: u16,
inner: Rc<Inner<C>>,
_t: marker::PhantomData<(E, E2)>,
_t: marker::PhantomData<E>,
}

struct Inner<C> {
Expand All @@ -90,10 +87,11 @@ struct PublishInfo {
aliases: HashSet<num::NonZeroU16>,
}

impl<T, C, E, E2> Dispatcher<T, C, E, E2>
impl<T, C, E> Dispatcher<T, C, E>
where
T: Service<Publish, Response = PublishAck, Error = E2>,
PublishAck: TryFrom<E2, Error = E>,
E: From<T::Error>,
T: Service<Publish, Response = PublishAck>,
PublishAck: TryFrom<T::Error, Error = E>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
fn new(
Expand Down Expand Up @@ -122,17 +120,17 @@ where
}
}

impl<T, C, E, E2> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E, E2>
impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
where
T: Service<Publish, Response = PublishAck, Error = E2>,
PublishAck: TryFrom<E2, Error = E>,
E: From<T::Error>,
T: Service<Publish, Response = PublishAck>,
PublishAck: TryFrom<T::Error, Error = E>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>> + 'static,
E: From<E2>,
{
type Response = Option<codec::Packet>;
type Error = MqttError<E>;
type Future = Either<
PublishResponse<T, C, E, E2>,
PublishResponse<T, C, E>,
Either<Ready<Self::Response, MqttError<E>>, ControlResponse<C, E>>,
>;

Expand Down Expand Up @@ -235,7 +233,6 @@ where
state: PublishResponseState::Publish {
fut: self.publish.call(Publish::new(publish)),
},
_t: marker::PhantomData,
})
}
DispatchItem::Item(codec::Packet::PublishAck(packet)) => {
Expand Down Expand Up @@ -332,12 +329,11 @@ where

pin_project_lite::pin_project! {
/// Publish service response future
pub(crate) struct PublishResponse<T: Service<Publish>, C: Service<ControlMessage<E>>, E, E2> {
pub(crate) struct PublishResponse<T: Service<Publish>, C: Service<ControlMessage<E>>, E> {
#[pin]
state: PublishResponseState<T, C, E>,
packet_id: u16,
inner: Rc<Inner<C>>,
_t: marker::PhantomData<(E, E2)>,
}
}

Expand All @@ -349,11 +345,11 @@ pin_project_lite::pin_project! {
}
}

impl<T, C, E, E2> Future for PublishResponse<T, C, E, E2>
impl<T, C, E> Future for PublishResponse<T, C, E>
where
E: From<E2>,
T: Service<Publish, Response = PublishAck, Error = E2>,
PublishAck: TryFrom<E2, Error = E>,
E: From<T::Error>,
T: Service<Publish, Response = PublishAck>,
PublishAck: TryFrom<T::Error, Error = E>,
C: Service<ControlMessage<E>, Response = ControlResult, Error = MqttError<E>>,
{
type Output = Result<Option<codec::Packet>, MqttError<E>>;
Expand Down
16 changes: 2 additions & 14 deletions src/v5/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,6 @@ where
>,
Rc<MqttShared>,
> {
let publish = self.srv_publish.map_init_err(|e| MqttError::Service(e.into()));
let control = self
.srv_control
.map_err(<C::Error>::from)
.map_init_err(|e| MqttError::Service(e.into()));

service::MqttServer::new(
HandshakeFactory {
factory: self.handshake,
Expand All @@ -228,7 +222,7 @@ where
pool: self.pool,
_t: PhantomData,
},
factory(publish, control),
factory(self.srv_publish, self.srv_control),
self.disconnect_timeout,
)
}
Expand All @@ -247,16 +241,10 @@ where
F: Fn(&Handshake) -> R + 'static,
R: Future<Output = Result<bool, C::Error>> + 'static,
{
let publish = self.srv_publish.map_init_err(|e| MqttError::Service(e.into()));
let control = self
.srv_control
.map_err(<C::Error>::from)
.map_init_err(|e| MqttError::Service(e.into()));

ServerSelector::<St, _, _, _, _> {
check: Rc::new(check),
connect: self.handshake,
handler: Rc::new(factory(publish, control)),
handler: Rc::new(factory(self.srv_publish, self.srv_control)),
max_size: self.max_size,
max_receive: self.max_receive,
max_topic_alias: self.max_topic_alias,
Expand Down

0 comments on commit 446bbac

Please sign in to comment.