Skip to content

Commit

Permalink
Add support in-flight messages size back-pressure
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Mar 14, 2022
1 parent af3ad27 commit 44391c9
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 22 deletions.
4 changes: 3 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changes

## [0.8.4] - 2022-0x-xx
## [0.8.4] - 2022-03-14

* Add support in-flight messages size back-pressure

* Refactor handshake timeout handling

Expand Down
215 changes: 215 additions & 0 deletions src/inflight.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
//! Service that limits number of in-flight async requests.
use std::{cell::Cell, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};

use ntex::{service::Service, task::LocalWaker};

pub(crate) trait SizedRequest {
fn size(&self) -> u32;
}

pub(crate) struct InFlightService<S> {
count: Counter,
service: S,
}

impl<S> InFlightService<S> {
pub fn new(max_cap: u16, max_size: usize, service: S) -> Self {
Self { service, count: Counter::new(max_cap, max_size) }
}
}

impl<T, R> Service<R> for InFlightService<T>
where
T: Service<R>,
R: SizedRequest,
{
type Response = T::Response;
type Error = T::Error;
type Future = InFlightServiceResponse<T, R>;

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.service.poll_ready(cx)?.is_pending() {
Poll::Pending
} else if !self.count.available(cx) {
log::trace!("InFlight limit exceeded");
Poll::Pending
} else {
Poll::Ready(Ok(()))
}
}

#[inline]
fn poll_shutdown(&self, cx: &mut Context<'_>, is_error: bool) -> Poll<()> {
self.service.poll_shutdown(cx, is_error)
}

#[inline]
fn call(&self, req: R) -> Self::Future {
let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
InFlightServiceResponse {
_guard: self.count.get(size),
_t: marker::PhantomData,
fut: self.service.call(req),
}
}
}

pin_project_lite::pin_project! {
#[doc(hidden)]
pub struct InFlightServiceResponse<T: Service<R>, R> {
#[pin]
fut: T::Future,
_guard: CounterGuard,
_t: marker::PhantomData<R>
}
}

impl<T: Service<R>, R> Future for InFlightServiceResponse<T, R> {
type Output = Result<T::Response, T::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx)
}
}

struct Counter(Rc<CounterInner>);

struct CounterInner {
max_cap: u16,
cur_cap: Cell<u16>,
max_size: usize,
cur_size: Cell<usize>,
task: LocalWaker,
}

impl Counter {
fn new(max_cap: u16, max_size: usize) -> Self {
Counter(Rc::new(CounterInner {
max_cap,
max_size,
cur_cap: Cell::new(0),
cur_size: Cell::new(0),
task: LocalWaker::new(),
}))
}

fn get(&self, size: u32) -> CounterGuard {
CounterGuard::new(size, self.0.clone())
}

fn available(&self, cx: &mut Context<'_>) -> bool {
self.0.available(cx)
}
}

struct CounterGuard(u32, Rc<CounterInner>);

impl CounterGuard {
fn new(size: u32, inner: Rc<CounterInner>) -> Self {
inner.inc(size);
CounterGuard(size, inner)
}
}

impl Unpin for CounterGuard {}

impl Drop for CounterGuard {
fn drop(&mut self) {
self.1.dec(self.0);
}
}

impl CounterInner {
fn inc(&self, size: u32) {
self.cur_cap.set(self.cur_cap.get() + 1);
self.cur_size.set(self.cur_size.get() + size as usize);
}

fn dec(&self, size: u32) {
let num = self.cur_cap.get();
self.cur_cap.set(num - 1);

let cur_size = self.cur_size.get();
let new_size = cur_size - (size as usize);
self.cur_size.set(new_size);

if num == self.max_cap || (cur_size > self.max_size && new_size <= self.max_size) {
self.task.wake();
}
}

fn available(&self, cx: &mut Context<'_>) -> bool {
if (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
&& (self.max_size == 0 || self.cur_size.get() <= self.max_size)
{
true
} else {
self.task.register(cx.waker());
false
}
}
}

#[cfg(test)]
mod tests {
use ntex::{service::Service, time::sleep, util::lazy};
use std::{task::Context, task::Poll, time::Duration};

use super::*;

struct SleepService(Duration);

impl Service<()> for SleepService {
type Response = ();
type Error = ();
type Future = Pin<Box<dyn Future<Output = Result<(), ()>>>>;

fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&self, _: ()) -> Self::Future {
let fut = sleep(self.0);
Box::pin(async move {
let _ = fut.await;
Ok::<_, ()>(())
})
}
}

impl SizedRequest for () {
fn size(&self) -> u32 {
12
}
}

#[ntex::test]
async fn test_inflight() {
let wait_time = Duration::from_millis(50);

let srv = InFlightService::new(1, 0, SleepService(wait_time));
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

let res = srv.call(());
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

let _ = res.await;
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
assert!(lazy(|cx| srv.poll_shutdown(cx, false)).await.is_ready());
}

#[ntex::test]
async fn test_inflight2() {
let wait_time = Duration::from_millis(50);

let srv = InFlightService::new(0, 10, SleepService(wait_time));
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

let res = srv.call(());
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

let _ = res.await;
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod error;
pub mod v3;
pub mod v5;

mod inflight;
mod io;
mod server;
mod service;
Expand Down
2 changes: 1 addition & 1 deletion src/v3/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#[allow(clippy::module_inception)]
mod codec;
mod decode;
mod encode;
pub(crate) mod encode;
mod packet;

pub use self::codec::Codec;
Expand Down
16 changes: 14 additions & 2 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use super::{codec, publish::Publish, shared::Ack, sink::MqttSink, Session};
pub(super) fn factory<St, T, C, E>(
publish: T,
control: C,
inflight: usize,
inflight: u16,
inflight_size: usize,
) -> impl ServiceFactory<
DispatchItem<Rc<MqttShared>>,
Session<St>,
Expand Down Expand Up @@ -54,15 +55,26 @@ where

Ok(
// limit number of in-flight messages
InFlightService::new(
crate::inflight::InFlightService::new(
inflight,
inflight_size,
Dispatcher::<_, _, _, E>::new(cfg, publish, control),
),
)
}
})
}

impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
fn size(&self) -> u32 {
if let DispatchItem::Item(ref item) = self {
codec::encode::get_encoded_size(item) as u32
} else {
0
}
}
}

/// Mqtt protocol dispatcher
pub(crate) struct Dispatcher<St, T, C: Service<ControlMessage<E>>, E> {
session: Session<St>,
Expand Down
33 changes: 25 additions & 8 deletions src/v3/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ pub struct MqttServer<St, H, C, P> {
control: C,
publish: P,
max_size: u32,
inflight: usize,
max_inflight: u16,
max_inflight_size: usize,
handshake_timeout: Seconds,
disconnect_timeout: Seconds,
pub(super) pool: Rc<MqttSinkPool>,
Expand All @@ -72,7 +73,8 @@ where
control: DefaultControlService::default(),
publish: DefaultPublishService::default(),
max_size: 0,
inflight: 16,
max_inflight: 16,
max_inflight_size: 65535,
handshake_timeout: Seconds::ZERO,
disconnect_timeout: Seconds(3),
pool: Default::default(),
Expand Down Expand Up @@ -125,8 +127,16 @@ where
/// Number of in-flight concurrent messages.
///
/// By default in-flight is set to 16 messages
pub fn inflight(mut self, val: usize) -> Self {
self.inflight = val;
pub fn inflight(mut self, val: u16) -> Self {
self.max_inflight = val;
self
}

/// Total size of in-flight messages.
///
/// By default total in-flight size is set to 64Kb
pub fn inflight_size(mut self, val: usize) -> Self {
self.max_inflight_size = val;
self
}

Expand All @@ -146,7 +156,8 @@ where
publish: self.publish,
control: service.into_factory(),
max_size: self.max_size,
inflight: self.inflight,
max_inflight: self.max_inflight,
max_inflight_size: self.max_inflight_size,
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
Expand All @@ -166,7 +177,8 @@ where
publish: publish.into_factory(),
control: self.control,
max_size: self.max_size,
inflight: self.inflight,
max_inflight: self.max_inflight,
max_inflight_size: self.max_inflight_size,
handshake_timeout: self.handshake_timeout,
disconnect_timeout: self.disconnect_timeout,
pool: self.pool,
Expand Down Expand Up @@ -202,7 +214,7 @@ where
pool: self.pool.clone(),
_t: PhantomData,
},
factory(self.publish, self.control, self.inflight),
factory(self.publish, self.control, self.max_inflight, self.max_inflight_size),
self.disconnect_timeout,
)
}
Expand All @@ -224,7 +236,12 @@ where
ServerSelector {
check: Rc::new(check),
handshake: self.handshake,
handler: Rc::new(factory(self.publish, self.control, self.inflight)),
handler: Rc::new(factory(
self.publish,
self.control,
self.max_inflight,
self.max_inflight_size,
)),
max_size: self.max_size,
disconnect_timeout: self.disconnect_timeout,
_t: PhantomData,
Expand Down
2 changes: 1 addition & 1 deletion src/v5/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::error::EncodeError;
use crate::types::packet_type;
use crate::utils::{write_variable_length, Encode};

pub(super) trait EncodeLtd {
pub(crate) trait EncodeLtd {
fn encoded_size(&self, limit: u32) -> usize;

fn encode(&self, buf: &mut BytesMut, size: u32) -> Result<(), EncodeError>;
Expand Down
1 change: 1 addition & 0 deletions src/v5/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod encode;
mod packet;

pub use self::codec::Codec;
pub(crate) use self::encode::EncodeLtd;
pub use self::packet::*;

pub type UserProperty = (ByteString, ByteString);
Expand Down
Loading

0 comments on commit 44391c9

Please sign in to comment.