Skip to content

Commit

Permalink
memory pools support
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Dec 2, 2021
1 parent cd7cb65 commit 454cc1f
Show file tree
Hide file tree
Showing 18 changed files with 176 additions and 145 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.7.6] - 2021-12-02

* Add memory pools support

## [0.7.5] - 2021-11-04

* v5: Use variable length byte to encode the subscription ID #73
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2018"

[dependencies]
ntex = { version = "0.4.9", default-features = false }
ntex = { version = "0.4.11", default-features = false }
bitflags = "1.3"
derive_more = "0.99"
log = "0.4"
Expand Down
11 changes: 10 additions & 1 deletion src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub(crate) use ntex::framed::{DispatchItem, ReadTask, State, Timer, Write, Write

use ntex::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use ntex::service::{IntoService, Service};
use ntex::{time::Seconds, util::Either};
use ntex::{time::Seconds, util::Either, util::Pool};

type Response<U> = <U as Encoder>::Item;

Expand All @@ -26,6 +26,7 @@ pin_project_lite::pin_project! {
state: State,
inner: Rc<RefCell<DispatcherState<S, U>>>,
st: IoDispatcherState,
pool: Pool,
timer: Timer,
updated: time::Instant,
keepalive_timeout: Seconds,
Expand Down Expand Up @@ -141,6 +142,7 @@ where
service: service.into_service(),
response: None,
response_idx: 0,
pool: state.memory_pool().pool(),
inner,
state,
codec,
Expand Down Expand Up @@ -259,6 +261,12 @@ where
}
}

// handle memory pool pressure
if this.pool.poll_ready(cx).is_pending() {
read.pause(cx.waker());
return Poll::Pending;
}

match this.st {
IoDispatcherState::Processing => {
loop {
Expand Down Expand Up @@ -536,6 +544,7 @@ mod tests {
Dispatcher {
service: service.into_service(),
st: IoDispatcherState::Processing,
pool: state.memory_pool().pool(),
response: None,
response_idx: 0,
inner,
Expand Down
32 changes: 29 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{convert::TryFrom, fmt, future::Future, io, marker, pin::Pin, rc::Rc, t
use ntex::codec::{AsyncRead, AsyncWrite};
use ntex::service::{Service, ServiceFactory};
use ntex::time::{sleep, Seconds, Sleep};
use ntex::util::{join, Ready};
use ntex::util::{join, Pool, PoolId, PoolRef, Ready};

use crate::error::{MqttError, ProtocolError};
use crate::io::State;
Expand All @@ -16,6 +16,7 @@ pub struct MqttServer<Io, V3, V5, Err, InitErr> {
v3: V3,
v5: V5,
handshake_timeout: Seconds,
pool: Pool,
_t: marker::PhantomData<(Io, Err, InitErr)>,
}

Expand All @@ -33,6 +34,7 @@ impl<Io, Err, InitErr>
MqttServer {
v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
pool: PoolId::P5.pool(),
handshake_timeout: Seconds::ZERO,
_t: marker::PhantomData,
}
Expand Down Expand Up @@ -62,6 +64,15 @@ impl<Io, V3, V5, Err, InitErr> MqttServer<Io, V3, V5, Err, InitErr> {
self.handshake_timeout = timeout;
self
}

/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P5
/// memory pool is used.
pub fn memory_pool(mut self, id: PoolId) -> Self {
self.pool = id.pool();
self
}
}

impl<Io, V3, V5, Err, InitErr> MqttServer<Io, V3, V5, Err, InitErr>
Expand Down Expand Up @@ -124,6 +135,7 @@ where
MqttServer {
v3: service.inner_finish(),
v5: self.v5,
pool: self.pool,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
Expand Down Expand Up @@ -153,6 +165,7 @@ where
MqttServer {
v3: service.finish_server(),
v5: self.v5,
pool: self.pool,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
Expand Down Expand Up @@ -205,6 +218,7 @@ where
MqttServer {
v3: self.v3,
v5: service.inner_finish(),
pool: self.pool,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
Expand Down Expand Up @@ -234,6 +248,7 @@ where
MqttServer {
v3: self.v3,
v5: service.finish_server(),
pool: self.pool,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
Expand Down Expand Up @@ -275,6 +290,7 @@ where
>;

fn new_service(&self, _: ()) -> Self::Future {
let pool = self.pool.clone();
let handshake_timeout = self.handshake_timeout;
let fut = join(self.v3.new_service(()), self.v5.new_service(()));
Box::pin(async move {
Expand All @@ -283,6 +299,7 @@ where
let v5 = v5?;
Ok(MqttServerImpl {
handlers: Rc::new((v3, v5)),
pool,
handshake_timeout,
_t: marker::PhantomData,
})
Expand All @@ -294,6 +311,7 @@ where
pub struct MqttServerImpl<Io, V3, V5, Err> {
handlers: Rc<(V3, V5)>,
handshake_timeout: Seconds,
pool: Pool,
_t: marker::PhantomData<(Io, Err)>,
}

Expand All @@ -311,8 +329,9 @@ where
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let ready1 = self.handlers.0.poll_ready(cx)?.is_ready();
let ready2 = self.handlers.1.poll_ready(cx)?.is_ready();
let ready3 = self.pool.poll_ready(cx).is_ready();

if ready1 && ready2 {
if ready1 && ready2 && ready3 {
Poll::Ready(Ok(()))
} else {
Poll::Pending
Expand All @@ -331,11 +350,18 @@ where
}

fn call(&self, req: Io) -> Self::Future {
let pool = self.pool.pool_ref();
let delay = self.handshake_timeout.map(sleep);

MqttServerImplResponse {
state: MqttServerImplState::Version {
item: Some((req, State::new(), VersionCodec, self.handlers.clone(), delay)),
item: Some((
req,
State::with_memory_pool(pool),
VersionCodec,
self.handlers.clone(),
delay,
)),
},
}
}
Expand Down
34 changes: 29 additions & 5 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{fmt, future::Future, marker::PhantomData, pin::Pin, rc::Rc};
use ntex::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use ntex::service::{IntoServiceFactory, Service, ServiceFactory};
use ntex::time::{Millis, Seconds, Sleep};
use ntex::util::{select, Either};
use ntex::util::{select, Either, Pool};

use super::io::{DispatchItem, Dispatcher, State, Timer};

Expand All @@ -15,12 +15,14 @@ pub(crate) struct FramedService<St, C, T, Io, Codec> {
handler: Rc<T>,
disconnect_timeout: Seconds,
time: Timer,
pool: Pool,
_t: PhantomData<(St, Io, Codec)>,
}

impl<St, C, T, Io, Codec> FramedService<St, C, T, Io, Codec> {
pub(crate) fn new(connect: C, service: T, disconnect_timeout: Seconds) -> Self {
pub(crate) fn new(connect: C, service: T, pool: Pool, disconnect_timeout: Seconds) -> Self {
FramedService {
pool,
connect,
disconnect_timeout,
handler: Rc::new(service),
Expand Down Expand Up @@ -62,12 +64,14 @@ where
let handler = self.handler.clone();
let disconnect_timeout = self.disconnect_timeout;
let time = self.time.clone();
let pool = self.pool.clone();

// create connect service and then create service impl
Box::pin(async move {
Ok(FramedServiceImpl {
handler,
disconnect_timeout,
pool,
time,
connect: fut.await?,
_t: PhantomData,
Expand All @@ -80,6 +84,7 @@ pub(crate) struct FramedServiceImpl<St, C, T, Io, Codec> {
connect: C,
handler: Rc<T>,
disconnect_timeout: Seconds,
pool: Pool,
time: Timer,
_t: PhantomData<(St, Io, Codec)>,
}
Expand Down Expand Up @@ -109,7 +114,14 @@ where

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.connect.poll_ready(cx)
let ready1 = self.connect.poll_ready(cx)?.is_ready();
let ready2 = self.pool.poll_ready(cx).is_ready();

if ready1 && ready2 {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}

#[inline]
Expand Down Expand Up @@ -148,14 +160,16 @@ pub(crate) struct FramedService2<St, C, T, Io, Codec> {
connect: C,
handler: Rc<T>,
disconnect_timeout: Seconds,
pool: Pool,
time: Timer,
_t: PhantomData<(St, Io, Codec)>,
}

impl<St, C, T, Io, Codec> FramedService2<St, C, T, Io, Codec> {
pub(crate) fn new(connect: C, service: T, disconnect_timeout: Seconds) -> Self {
pub(crate) fn new(connect: C, service: T, pool: Pool, disconnect_timeout: Seconds) -> Self {
FramedService2 {
connect,
pool,
disconnect_timeout,
handler: Rc::new(service),
time: Timer::new(Millis::ONE_SEC),
Expand Down Expand Up @@ -200,13 +214,15 @@ where
let handler = self.handler.clone();
let disconnect_timeout = self.disconnect_timeout;
let time = self.time.clone();
let pool = self.pool.clone();

// create connect service and then create service impl
Box::pin(async move {
Ok(FramedServiceImpl2 {
handler,
disconnect_timeout,
time,
pool,
connect: fut.await?,
_t: PhantomData,
})
Expand All @@ -217,6 +233,7 @@ where
pub(crate) struct FramedServiceImpl2<St, C, T, Io, Codec> {
connect: C,
handler: Rc<T>,
pool: Pool,
disconnect_timeout: Seconds,
time: Timer,
_t: PhantomData<(St, Io, Codec)>,
Expand Down Expand Up @@ -247,7 +264,14 @@ where

#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.connect.poll_ready(cx)
let ready1 = self.connect.poll_ready(cx)?.is_ready();
let ready2 = self.pool.poll_ready(cx).is_ready();

if ready1 && ready2 {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/topic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl FromStr for Topic {
#[inline]
fn from_str(s: &str) -> Result<Self, TopicError> {
s.split('/')
.map(|level| Level::from_str(level))
.map(Level::from_str)
.collect::<Result<Vec<_>, TopicError>>()
.map(Topic)
.and_then(
Expand Down
13 changes: 11 additions & 2 deletions src/v3/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ntex::codec::{AsyncRead, AsyncWrite};
use ntex::connect::{self, Address, Connect, Connector};
use ntex::service::Service;
use ntex::time::{timeout, Millis, Seconds};
use ntex::util::{select, ByteString, Bytes, Either};
use ntex::util::{select, ByteString, Bytes, Either, PoolId};

#[cfg(feature = "openssl")]
use ntex::connect::openssl::{OpensslConnector, SslConnector};
Expand Down Expand Up @@ -169,6 +169,15 @@ where
self
}

/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P5
/// memory pool is used.
pub fn memory_pool(self, id: PoolId) -> Self {
self.pool.pool.set(id.pool_ref());
self
}

/// Use custom connector
pub fn connector<U>(self, connector: U) -> MqttConnector<A, U>
where
Expand Down Expand Up @@ -249,7 +258,7 @@ where

async move {
let mut io = fut.await?;
let state = State::new();
let state = State::with_memory_pool(pool.pool.get());
let codec = codec::Codec::new().max_size(max_packet_size);

state.send(&mut io, &codec, pkt.into()).await?;
Expand Down
Loading

0 comments on commit 454cc1f

Please sign in to comment.