From 9c2b4636ad28906d045cd6a06c5d1a63bc0c8262 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 21 Dec 2021 01:52:49 +0600 Subject: [PATCH] upgrade to ntex 0.5 (#80) --- CHANGES.md | 4 + Cargo.toml | 9 +- examples/basic.rs | 19 +- examples/openssl.rs | 14 +- examples/rustls.rs | 20 +- examples/session.rs | 20 +- examples/subs.rs | 9 +- examples/subs_client.rs | 4 + src/io.rs | 427 +++++++++++++++++------------------- src/server.rs | 180 +++++++-------- src/service.rs | 134 ++++------- src/v3/client/connection.rs | 67 ++---- src/v3/client/connector.rs | 79 ++++--- src/v3/client/control.rs | 10 +- src/v3/client/dispatcher.rs | 17 +- src/v3/control.rs | 20 +- src/v3/dispatcher.rs | 16 +- src/v3/handshake.rs | 50 ++--- src/v3/mod.rs | 2 +- src/v3/selector.rs | 126 +++++------ src/v3/server.rs | 217 ++++++------------ src/v3/shared.rs | 9 +- src/v3/sink.rs | 29 ++- src/v5/client/connection.rs | 89 +++----- src/v5/client/connector.rs | 77 ++++--- src/v5/client/control.rs | 23 ++ src/v5/client/dispatcher.rs | 10 +- src/v5/control.rs | 27 ++- src/v5/dispatcher.rs | 9 +- src/v5/handshake.rs | 43 ++-- src/v5/selector.rs | 128 +++++------ src/v5/server.rs | 267 +++++++--------------- src/v5/shared.rs | 9 +- src/v5/sink.rs | 37 ++-- tests/test_server.rs | 176 +++++++-------- tests/test_server_both.rs | 15 +- tests/test_server_v5.rs | 377 ++++++++++++++++--------------- 37 files changed, 1242 insertions(+), 1527 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 277d0a3b..560a9357 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.8.0-b.0] - 2021-12-21 + +* Upgrade to ntex 0.5 + ## [0.7.7] - 2021-12-17 * Wait for close control message and inner services on dispatcher shutdown #78 diff --git a/Cargo.toml b/Cargo.toml index da75bb7a..c7d641ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "0.7.7" +version = "0.8.0-b.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -12,7 +12,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config"] edition = "2018" [dependencies] -ntex = { version = "0.4.11", default-features = false } +ntex = { version = "0.5.0-b.1", default-features = false } bitflags = "1.3" derive_more = "0.99" log = "0.4" @@ -23,10 +23,9 @@ pin-project-lite = "0.2" [dev-dependencies] env_logger = "0.9" futures = "0.3" +ntex-tls = "0.1.0-b.1" rustls = "0.20" rustls-pemfile = "0.2" -tokio-rustls = "0.23" openssl = "0.10" -tokio-openssl = "0.6" -ntex = { version = "0.4", features = ["rustls", "openssl"] } +ntex = { version = "0.5.0-b.1", features = ["rustls", "openssl"] } diff --git a/examples/basic.rs b/examples/basic.rs index e4239a1a..9e03c601 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -20,9 +20,9 @@ impl std::convert::TryFrom for v5::PublishAck { } } -async fn handshake_v3( - handshake: v3::Handshake, -) -> Result, ServerError> { +async fn handshake_v3( + handshake: v3::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session, false)) } @@ -32,9 +32,9 @@ async fn publish_v3(publish: v3::Publish) -> Result<(), ServerError> { Ok(()) } -async fn handshake_v5( - handshake: v5::Handshake, -) -> Result, ServerError> { +async fn handshake_v5( + handshake: v5::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session)) } @@ -46,16 +46,11 @@ async fn publish_v5(publish: v5::Publish) -> Result #[ntex::main] async fn main() -> std::io::Result<()> { - println!("{}", std::mem::size_of::()); - println!("{}", std::mem::size_of::()); - println!("{}", std::mem::size_of::()); - println!("{}", std::mem::size_of::>()); - println!("{}", std::mem::size_of::>()); std::env::set_var("RUST_LOG", "ntex=trace,ntex_mqtt=trace,basic=trace"); env_logger::init(); ntex::server::Server::build() - .bind("mqtt", "127.0.0.1:1883", || { + .bind("mqtt", "127.0.0.1:1883", |_| { MqttServer::new() .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3)) .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5)) diff --git a/examples/openssl.rs b/examples/openssl.rs index eed76b5d..8577f902 100644 --- a/examples/openssl.rs +++ b/examples/openssl.rs @@ -1,9 +1,7 @@ -use ntex::rt::net::TcpStream; -use ntex::server::openssl::Acceptor; use ntex::service::pipeline_factory; use ntex_mqtt::{v3, v5, MqttError, MqttServer}; +use ntex_tls::openssl::Acceptor; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -use tokio_openssl::SslStream; #[derive(Clone)] struct Session; @@ -26,8 +24,8 @@ impl std::convert::TryFrom for v5::PublishAck { } async fn handshake_v3( - handshake: v3::Handshake>, -) -> Result, Session>, ServerError> { + handshake: v3::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session, false)) } @@ -38,8 +36,8 @@ async fn publish_v3(publish: v3::Publish) -> Result<(), ServerError> { } async fn handshake_v5( - handshake: v5::Handshake>, -) -> Result, Session>, ServerError> { + handshake: v5::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session)) } @@ -63,7 +61,7 @@ async fn main() -> std::io::Result<()> { let acceptor = builder.build(); ntex::server::Server::build() - .bind("mqtt", "127.0.0.1:8883", move || { + .bind("mqtt", "127.0.0.1:8883", move |_| { pipeline_factory(Acceptor::new(acceptor.clone())) .map_err(|_err| MqttError::Service(ServerError {})) .and_then( diff --git a/examples/rustls.rs b/examples/rustls.rs index 95d5ba17..9efa09ec 100644 --- a/examples/rustls.rs +++ b/examples/rustls.rs @@ -1,12 +1,10 @@ -use std::{fs::File, io::BufReader}; +use std::{fs::File, io::BufReader, sync::Arc}; -use ntex::rt::net::TcpStream; -use ntex::server::rustls::Acceptor; use ntex::service::pipeline_factory; use ntex_mqtt::{v3, v5, MqttError, MqttServer}; +use ntex_tls::rustls::Acceptor; use rustls::{Certificate, PrivateKey, ServerConfig}; use rustls_pemfile::{certs, rsa_private_keys}; -use tokio_rustls::server::TlsStream; #[derive(Clone)] struct Session; @@ -29,8 +27,8 @@ impl std::convert::TryFrom for v5::PublishAck { } async fn handshake_v3( - handshake: v3::Handshake>, -) -> Result, Session>, ServerError> { + handshake: v3::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session, false)) } @@ -41,8 +39,8 @@ async fn publish_v3(publish: v3::Publish) -> Result<(), ServerError> { } async fn handshake_v5( - handshake: v5::Handshake>, -) -> Result, Session>, ServerError> { + handshake: v5::Handshake, +) -> Result, ServerError> { log::info!("new connection: {:?}", handshake); Ok(handshake.ack(Session)) } @@ -72,11 +70,11 @@ async fn main() -> std::io::Result<()> { .with_single_cert(cert_chain, keys) .unwrap(); - let tls_acceptor = Acceptor::new(tls_config); + let tls_acceptor = Arc::new(tls_config); ntex::server::Server::build() - .bind("mqtt", "127.0.0.1:8883", move || { - pipeline_factory(tls_acceptor.clone()) + .bind("mqtt", "127.0.0.1:8883", move |_| { + pipeline_factory(Acceptor::new(tls_acceptor.clone())) .map_err(|_err| MqttError::Service(ServerError {})) .and_then( MqttServer::new() diff --git a/examples/session.rs b/examples/session.rs index 3336f251..c468b5ad 100644 --- a/examples/session.rs +++ b/examples/session.rs @@ -1,5 +1,5 @@ -use futures::future::ok; use ntex::service::{fn_factory_with_config, fn_service}; +use ntex::util::Ready; use ntex_mqtt::v5::codec::PublishAckReason; use ntex_mqtt::{v3, v5, MqttServer}; @@ -26,9 +26,9 @@ impl std::convert::TryFrom for v5::PublishAck { } } -async fn handshake_v3( - handshake: v3::Handshake, -) -> Result, MyServerError> { +async fn handshake_v3( + handshake: v3::Handshake, +) -> Result, MyServerError> { log::info!("new connection: {:?}", handshake); let session = MySession { client_id: handshake.packet().client_id.to_string() }; @@ -56,9 +56,9 @@ async fn publish_v3( } } -async fn handshake_v5( - handshake: v5::Handshake, -) -> Result, MyServerError> { +async fn handshake_v5( + handshake: v5::Handshake, +) -> Result, MyServerError> { log::info!("new connection: {:?}", handshake); let session = MySession { client_id: handshake.packet().client_id.to_string() }; @@ -93,18 +93,18 @@ async fn main() -> std::io::Result<()> { log::info!("Hello"); ntex::server::Server::build() - .bind("mqtt", "127.0.0.1:1883", || { + .bind("mqtt", "127.0.0.1:1883", |_| { MqttServer::new() .v3(v3::MqttServer::new(handshake_v3).publish(fn_factory_with_config( |session: v3::Session| { - ok::<_, MyServerError>(fn_service(move |req| { + Ready::Ok::<_, MyServerError>(fn_service(move |req| { publish_v3(session.clone(), req) })) }, ))) .v5(v5::MqttServer::new(handshake_v5).publish(fn_factory_with_config( |session: v5::Session| { - ok::<_, MyServerError>(fn_service(move |req| { + Ready::Ok::<_, MyServerError>(fn_service(move |req| { publish_v5(session.clone(), req) })) }, diff --git a/examples/subs.rs b/examples/subs.rs index 8ff1297b..bced4590 100644 --- a/examples/subs.rs +++ b/examples/subs.rs @@ -30,9 +30,9 @@ impl std::convert::TryFrom for PublishAck { } } -async fn handshake( - handshake: v5::Handshake, -) -> Result, MyServerError> { +async fn handshake( + handshake: v5::Handshake, +) -> Result, MyServerError> { log::info!("new connection: {:?}", handshake); let session = MySession { @@ -96,6 +96,7 @@ fn control_service_factory() -> impl ServiceFactory< } v5::ControlMessage::Unsubscribe(s) => Ready::Ok(s.ack()), v5::ControlMessage::Closed(c) => Ready::Ok(c.ack()), + v5::ControlMessage::PeerGone(c) => Ready::Ok(c.ack()), })) }) } @@ -106,7 +107,7 @@ async fn main() -> std::io::Result<()> { env_logger::init(); ntex::server::Server::build() - .bind("mqtt", "127.0.0.1:1883", || { + .bind("mqtt", "127.0.0.1:1883", |_| { MqttServer::new(handshake) .control(control_service_factory()) .publish(fn_factory_with_config(|session: Session| { diff --git a/examples/subs_client.rs b/examples/subs_client.rs index 81281e87..a6f3dea2 100644 --- a/examples/subs_client.rs +++ b/examples/subs_client.rs @@ -52,6 +52,10 @@ async fn main() -> std::io::Result<()> { log::error!("Protocol error: {:?}", msg); Ready::Ok(msg.ack()) } + v5::client::ControlMessage::PeerGone(msg) => { + log::warn!("Peer closed connection: {:?}", msg.error()); + Ready::Ok(msg.ack()) + } v5::client::ControlMessage::Closed(msg) => { log::warn!("Server closed connection: {:?}", msg); Ready::Ok(msg.ack()) diff --git a/src/io.rs b/src/io.rs index d78160cc..dd97aacc 100644 --- a/src/io.rs +++ b/src/io.rs @@ -2,11 +2,11 @@ use std::task::{Context, Poll}; use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc, time}; -pub(crate) use ntex::framed::{DispatchItem, ReadTask, State, Timer, Write, WriteTask}; - -use ntex::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; +use ntex::codec::{Decoder, Encoder}; +use ntex::io::{DispatchItem, Io, IoBoxed, IoRef, Timer}; use ntex::service::{IntoService, Service}; -use ntex::{time::Seconds, util::Either, util::Pool}; +use ntex::time::{now, Seconds}; +use ntex::util::{Either, Pool}; type Response = ::Item; @@ -23,13 +23,14 @@ pin_project_lite::pin_project! { { service: S, codec: U, - state: State, + io: IoBoxed, inner: Rc>>, st: IoDispatcherState, - pool: Pool, timer: Timer, updated: time::Instant, keepalive_timeout: Seconds, + ready_err: bool, + pool: Pool, #[pin] response: Option, response_idx: usize, @@ -71,12 +72,9 @@ pub(crate) enum IoDispatcherError { Service(S), } -impl From> for IoDispatcherError { - fn from(err: Either) -> Self { - match err { - Either::Left(err) => IoDispatcherError::Service(err), - Either::Right(err) => IoDispatcherError::Encoder(err), - } +impl From for IoDispatcherError { + fn from(err: S) -> Self { + IoDispatcherError::Service(err) } } @@ -109,23 +107,18 @@ where ::Item: 'static, { /// Construct new `Dispatcher` instance with outgoing messages stream. - pub(crate) fn with>( - io: T, - state: State, + pub(crate) fn new>( + io: IoBoxed, codec: U, service: F, timer: Timer, - ) -> Self - where - T: AsyncRead + AsyncWrite + Unpin + 'static, - { - let updated = timer.now(); + ) -> Self { + let updated = now(); let keepalive_timeout = Seconds(30); - let io = Rc::new(RefCell::new(io)); // register keepalive timer let expire = updated + time::Duration::from(keepalive_timeout); - timer.register(expire, expire, &state); + timer.register(expire, expire, io.as_ref()); let inner = Rc::new(RefCell::new(DispatcherState { error: None, @@ -133,18 +126,15 @@ where queue: VecDeque::new(), })); - // start support tasks - ntex::rt::spawn(ReadTask::new(io.clone(), state.clone())); - ntex::rt::spawn(WriteTask::new(io, state.clone())); - Dispatcher { st: IoDispatcherState::Processing, service: service.into_service(), response: None, response_idx: 0, - pool: state.memory_pool().pool(), + pool: io.memory_pool().pool(), + ready_err: false, inner, - state, + io, codec, timer, updated, @@ -161,10 +151,10 @@ where // register keepalive timer let prev = self.updated + time::Duration::from(self.keepalive_timeout); if timeout.is_zero() { - self.timer.unregister(prev, &self.state); + self.timer.unregister(prev, self.io.as_ref()); } else { let expire = self.updated + time::Duration::from(timeout); - self.timer.register(expire, prev, &self.state); + self.timer.register(expire, prev, self.io.as_ref()); } self.keepalive_timeout = timeout; @@ -180,16 +170,14 @@ where /// /// By default disconnect timeout is set to 1 seconds. pub(crate) fn disconnect_timeout(self, val: Seconds) -> Self { - self.state.set_disconnect_timeout(val); + self.io.set_disconnect_timeout(val.into()); self } } impl DispatcherState where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, + S: Service, Response = Option>> + 'static, U: Encoder + Decoder, ::Item: 'static, { @@ -197,7 +185,7 @@ where &mut self, item: Result, response_idx: usize, - write: Write<'_>, + write: &IoRef, codec: &U, wake: bool, ) { @@ -207,16 +195,32 @@ where if idx == 0 { let _ = self.queue.pop_front(); self.base = self.base.wrapping_add(1); - if let Err(err) = write.encode_result(item, codec) { - self.error = Some(err.into()); + match item { + Err(err) => { + self.error = Some(err.into()); + } + Ok(Some(item)) => { + if let Err(err) = write.encode(item, codec) { + self.error = Some(IoDispatcherError::Encoder(err)); + } + } + Ok(None) => (), } // check remaining response while let Some(item) = self.queue.front_mut().and_then(|v| v.take()) { let _ = self.queue.pop_front(); self.base = self.base.wrapping_add(1); - if let Err(err) = write.encode_result(item, codec) { - self.error = Some(err.into()); + match item { + Err(err) => { + self.error = Some(err.into()); + } + Ok(Some(item)) => { + if let Err(err) = write.encode(item, codec) { + self.error = Some(IoDispatcherError::Encoder(err)); + } + } + Ok(None) => (), } } @@ -239,8 +243,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.as_mut().project(); - let read = this.state.read(); - let write = this.state.write(); + let io = this.io; // log::trace!("IO-DISP poll :{:?}:", this.st); @@ -252,7 +255,7 @@ where this.inner.borrow_mut().handle_result( item, *this.response_idx, - write, + io.as_ref(), this.codec, false, ); @@ -263,113 +266,104 @@ where // handle memory pool pressure if this.pool.poll_ready(cx).is_pending() { - read.pause(cx.waker()); + io.pause(cx); return Poll::Pending; } - match this.st { - IoDispatcherState::Processing => { - loop { - // log::trace!("IO-DISP state :{:?}:", this.state.flags()); + loop { + match this.st { + IoDispatcherState::Processing => { + // log::trace!("IO-DISP state :{:?}:", io.flags()); match this.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { - let mut retry = false; - // service is ready, wake io read task - read.resume(); + io.resume(); // check keepalive timeout - if this.state.is_keepalive() { + if io.is_keepalive() { log::trace!("keepalive timeout"); let mut inner = this.inner.borrow_mut(); if inner.error.is_none() { inner.error = Some(IoDispatcherError::KeepAlive); } - this.state.dispatcher_stopped(); + io.stop_dispatcher(); } - let item = if this.state.is_dispatcher_stopped() { - retry = true; - let mut inner = this.inner.borrow_mut(); + // decode incoming bytes stream + let item = match io.poll_read_next(this.codec, cx) { + Poll::Pending => { + // log::trace!("not enough data to decode next frame, register dispatch task"); + if io.is_dispatcher_stopped() { + log::trace!("dispatcher is instructed to stop"); + let mut inner = this.inner.borrow_mut(); + + // unregister keep-alive timer + if this.keepalive_timeout.non_zero() { + this.timer.unregister( + *this.updated + + time::Duration::from( + *this.keepalive_timeout, + ), + io.as_ref(), + ); + } - // unregister keep-alive timer - if this.keepalive_timeout.non_zero() { - this.timer.unregister( - *this.updated - + time::Duration::from(*this.keepalive_timeout), - this.state, - ); + // check for errors + let item = inner + .error + .as_mut() + .and_then(|err| err.take()) + .or_else(|| { + io.take_error() + .map(|e| DispatchItem::Disconnect(Some(e))) + }); + *this.st = IoDispatcherState::Stop; + item + } else { + return Poll::Pending; + } } + Poll::Ready(Some(Ok(el))) => { + // update keep-alive timer + if this.keepalive_timeout.non_zero() { + let updated = now(); + if updated != *this.updated { + let ka = + time::Duration::from(*this.keepalive_timeout); + this.timer.register( + updated + ka, + *this.updated + ka, + io.as_ref(), + ); + *this.updated = updated; + } + } - // process unhandled data - if let Ok(Some(el)) = read.decode(this.codec) { Some(DispatchItem::Item(el)) - } else { - log::trace!("dispatcher is instructed to stop"); - - // check for errors - let item = inner - .error - .as_mut() - .and_then(|err| err.take()) - .or_else(|| { - this.state - .take_io_error() - .map(DispatchItem::IoError) - }); - *this.st = IoDispatcherState::Stop; - item } - } else { - // decode incoming bytes stream - if read.is_ready() { - match read.decode(this.codec) { - Ok(Some(el)) => { - // update keep-alive timer - if this.keepalive_timeout.non_zero() { - let updated = this.timer.now(); - if updated != *this.updated { - let ka = time::Duration::from( - *this.keepalive_timeout, - ); - this.timer.register( - updated + ka, - *this.updated + ka, - this.state, - ); - *this.updated = updated; - } - } + Poll::Ready(Some(Err(err))) => { + *this.st = IoDispatcherState::Stop; - Some(DispatchItem::Item(el)) - } - Ok(None) => { - // log::trace!("not enough data to decode next frame, register dispatch task"); - read.wake(cx.waker()); - return Poll::Pending; - } - Err(err) => { - retry = true; - *this.st = IoDispatcherState::Stop; - - // unregister keep-alive timer - if this.keepalive_timeout.non_zero() { - this.timer.unregister( - *this.updated - + time::Duration::from( - *this.keepalive_timeout, - ), - this.state, - ); - } + // unregister keep-alive timer + if this.keepalive_timeout.non_zero() { + this.timer.unregister( + *this.updated + + time::Duration::from(*this.keepalive_timeout), + io.as_ref(), + ); + } - Some(DispatchItem::DecoderError(err)) + match err { + Either::Left(e) => Some(DispatchItem::DecoderError(e)), + Either::Right(e) => { + Some(DispatchItem::Disconnect(Some(e))) } } - } else { - this.state.register_dispatcher(cx.waker()); - return Poll::Pending; + } + Poll::Ready(None) => { + *this.st = IoDispatcherState::Stop; + Some(DispatchItem::Disconnect(None)) } }; @@ -388,10 +382,20 @@ where if let Poll::Ready(res) = res { // check if current result is only response atm if inner.queue.is_empty() { - if let Err(err) = - write.encode_result(res, this.codec) - { - inner.error = Some(err.into()); + match res { + Err(err) => { + inner.error = Some(err.into()); + } + Ok(Some(item)) => { + if let Err(err) = + io.encode(item, this.codec) + { + inner.error = Some( + IoDispatcherError::Encoder(err), + ); + } + } + Ok(None) => (), } } else { *this.response_idx = response_idx; @@ -408,7 +412,7 @@ where inner.base.wrapping_add(inner.queue.len() as usize); inner.queue.push_back(ServiceResult::Pending); - let st = this.state.clone(); + let st = io.get_ref(); let codec = this.codec.clone(); let inner = this.inner.clone(); let fut = this.service.call(item); @@ -417,23 +421,18 @@ where inner.borrow_mut().handle_result( item, response_idx, - st.write(), + &st, &codec, true, ); }); } } - - // run again - if retry { - return self.poll(cx); - } } Poll::Pending => { // pause io read task log::trace!("service is not ready, register dispatch task"); - read.pause(cx.waker()); + io.pause(cx); return Poll::Pending; } Poll::Ready(Err(err)) => { @@ -442,57 +441,56 @@ where *this.st = IoDispatcherState::Stop; this.inner.borrow_mut().error = Some(IoDispatcherError::Service(err)); - this.state.dispatcher_ready_err(); + *this.ready_err = true; // unregister keep-alive timer if this.keepalive_timeout.non_zero() { this.timer.unregister( *this.updated + time::Duration::from(*this.keepalive_timeout), - this.state, + io.as_ref(), ); } - - return self.poll(cx); } } } - } - // drain service responses - IoDispatcherState::Stop => { - // service may relay on poll_ready for response results - if !this.state.is_dispatcher_ready_err() { - let _ = this.service.poll_ready(cx); - } + // drain service responses and shutdown io + IoDispatcherState::Stop => { + // service may relay on poll_ready for response results + if !*this.ready_err { + let _ = this.service.poll_ready(cx); + } - if this.inner.borrow().queue.is_empty() { - this.state.shutdown_io(); - *this.st = IoDispatcherState::Shutdown; - self.poll(cx) - } else { - this.state.register_dispatcher(cx.waker()); - Poll::Pending + if this.inner.borrow().queue.is_empty() { + if io.poll_shutdown(cx).is_ready() { + *this.st = IoDispatcherState::Shutdown; + continue; + } + } else { + io.register_dispatcher(cx); + } + return Poll::Pending; + } + // shutdown service + IoDispatcherState::Shutdown => { + let is_err = this.inner.borrow().error.is_some(); + + return if this.service.poll_shutdown(cx, is_err).is_ready() { + log::trace!("service shutdown is completed, stop"); + + Poll::Ready( + if let Some(IoDispatcherError::Service(err)) = + this.inner.borrow_mut().error.take() + { + Err(err) + } else { + Ok(()) + }, + ) + } else { + Poll::Pending + }; } - } - // shutdown service - IoDispatcherState::Shutdown => { - let is_err = this.inner.borrow().error.is_some(); - - return if this.service.poll_shutdown(cx, is_err).is_ready() { - log::trace!("service shutdown is completed, stop"); - - Poll::Ready( - if let Some(IoDispatcherError::Service(err)) = - this.inner.borrow_mut().error.take() - { - Err(err) - } else { - Ok(()) - }, - ) - } else { - Poll::Pending - }; } } } @@ -504,6 +502,7 @@ mod tests { use ntex::channel::condition::Condition; use ntex::codec::BytesCodec; + use ntex::io as nio; use ntex::testing::Io; use ntex::time::{sleep, Millis}; use ntex::util::{Bytes, Ready}; @@ -519,21 +518,16 @@ mod tests { ::Item: 'static, { /// Construct new `Dispatcher` instance - pub(crate) fn new>( - io: T, + pub(crate) fn new_debug>( + io: Io, codec: U, - state: State, service: F, - ) -> Self - where - T: AsyncRead + AsyncWrite + Unpin + 'static, - { + ) -> (Self, nio::IoRef) { let timer = Timer::new(Millis::ONE_SEC); let keepalive_timeout = Seconds(30); - let updated = timer.now(); - let io = Rc::new(RefCell::new(io)); - ntex::rt::spawn(ReadTask::new(io.clone(), state.clone())); - ntex::rt::spawn(WriteTask::new(io.clone(), state.clone())); + let updated = now(); + let io = nio::Io::new(io).into_boxed(); + let rio = io.get_ref(); let inner = Rc::new(RefCell::new(DispatcherState { error: None, @@ -541,19 +535,23 @@ mod tests { queue: VecDeque::new(), })); - Dispatcher { - service: service.into_service(), - st: IoDispatcherState::Processing, - pool: state.memory_pool().pool(), - response: None, - response_idx: 0, - inner, - state, - timer, - codec, - updated, - keepalive_timeout, - } + ( + Dispatcher { + service: service.into_service(), + st: IoDispatcherState::Processing, + response: None, + response_idx: 0, + pool: io.memory_pool().pool(), + ready_err: false, + inner, + io, + timer, + codec, + updated, + keepalive_timeout, + }, + rio, + ) } } @@ -563,10 +561,9 @@ mod tests { client.remote_buffer_cap(1024); client.write("GET /test HTTP/1\r\n\r\n"); - let disp = Dispatcher::new( + let (disp, _) = Dispatcher::new_debug( server, BytesCodec, - State::new(), ntex::service::fn_service(|msg: DispatchItem| async move { sleep(Millis(50)).await; if let DispatchItem::Item(msg) = msg { @@ -601,10 +598,9 @@ mod tests { let condition = Condition::new(); let waiter = condition.wait(); - let disp = Dispatcher::new( + let (disp, _) = Dispatcher::new_debug( server, BytesCodec, - State::new(), ntex::service::fn_service(move |msg: DispatchItem| { let waiter = waiter.clone(); async move { @@ -641,11 +637,9 @@ mod tests { client.remote_buffer_cap(1024); client.write("GET /test HTTP/1\r\n\r\n"); - let st = State::new(); - let disp = Dispatcher::new( + let (disp, io) = Dispatcher::new_debug( server, BytesCodec, - st.clone(), ntex::service::fn_service(|msg: DispatchItem| async move { if let DispatchItem::Item(msg) = msg { Ok::<_, ()>(Some(msg.freeze())) @@ -661,11 +655,11 @@ mod tests { let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - assert!(st.write().encode(Bytes::from_static(b"test"), &BytesCodec).is_ok()); + assert!(io.encode(Bytes::from_static(b"test"), &BytesCodec).is_ok()); let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"test")); - st.close(); + io.close(); sleep(Millis(1200)).await; assert!(client.is_server_dropped()); } @@ -676,11 +670,9 @@ mod tests { client.remote_buffer_cap(0); client.write("GET /test HTTP/1\r\n\r\n"); - let state = State::new(); - let disp = Dispatcher::new( + let (disp, io) = Dispatcher::new_debug( server, BytesCodec, - state.clone(), ntex::service::fn_service(|_: DispatchItem| async move { Err::, _>(()) }), @@ -689,10 +681,7 @@ mod tests { let _ = disp.await; }); - state - .write() - .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec) - .unwrap(); + io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap(); // buffer should be flushed client.remote_buffer_cap(1024); @@ -700,7 +689,7 @@ mod tests { assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); // write side must be closed, dispatcher waiting for read side to close - assert!(client.is_closed()); + assert!(!client.is_closed()); // close read side client.close().await; @@ -733,12 +722,8 @@ mod tests { } } - let state = State::new(); - let disp = Dispatcher::new(server, BytesCodec, state.clone(), Srv(counter.clone())); - state - .write() - .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &mut BytesCodec) - .unwrap(); + let (disp, io) = Dispatcher::new_debug(server, BytesCodec, Srv(counter.clone())); + io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &mut BytesCodec).unwrap(); ntex::rt::spawn(async move { let _ = disp.await; }); diff --git a/src/server.rs b/src/server.rs index 14185fd7..074cf7f7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,30 +1,28 @@ use std::task::{Context, Poll}; use std::{convert::TryFrom, fmt, future::Future, io, marker, pin::Pin, rc::Rc, time}; -use ntex::codec::{AsyncRead, AsyncWrite}; +use ntex::io::{Filter, Io, IoBoxed}; use ntex::service::{Service, ServiceFactory}; use ntex::time::{sleep, Seconds, Sleep}; use ntex::util::{join, Pool, PoolId, PoolRef, Ready}; use crate::error::{MqttError, ProtocolError}; -use crate::io::State; use crate::version::{ProtocolVersion, VersionCodec}; use crate::{v3, v5}; /// Mqtt Server -pub struct MqttServer { +pub struct MqttServer { v3: V3, v5: V5, handshake_timeout: Seconds, - pool: Pool, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(F, Err, InitErr)>, } -impl +impl MqttServer< - Io, - DefaultProtocolServer, - DefaultProtocolServer, + F, + DefaultProtocolServer, + DefaultProtocolServer, Err, InitErr, > @@ -34,18 +32,17 @@ impl MqttServer { v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3), v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5), - pool: PoolId::P5.pool(), handshake_timeout: Seconds::ZERO, _t: marker::PhantomData, } } } -impl Default +impl Default for MqttServer< - Io, - DefaultProtocolServer, - DefaultProtocolServer, + F, + DefaultProtocolServer, + DefaultProtocolServer, Err, InitErr, > @@ -55,7 +52,7 @@ impl Default } } -impl MqttServer { +impl MqttServer { /// Set handshake timeout. /// /// Handshake includes `connect` packet. @@ -64,30 +61,20 @@ impl MqttServer { self.handshake_timeout = timeout; self } - - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P5 - /// memory pool is used. - pub fn memory_pool(mut self, id: PoolId) -> Self { - self.pool = id.pool(); - self - } } -impl MqttServer +impl MqttServer where - Io: AsyncRead + AsyncWrite + Unpin + 'static, V3: ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, >, V5: ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -96,12 +83,12 @@ where /// Service to handle v3 protocol pub fn v3( self, - service: v3::MqttServer, + service: v3::MqttServer, ) -> MqttServer< - Io, + F, impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -114,8 +101,8 @@ where St: 'static, C: ServiceFactory< Config = (), - Request = v3::Handshake, - Response = v3::HandshakeAck, + Request = v3::Handshake, + Response = v3::HandshakeAck, Error = Err, InitError = InitErr, > + 'static, @@ -135,7 +122,6 @@ where MqttServer { v3: service.inner_finish(), v5: self.v5, - pool: self.pool, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } @@ -144,12 +130,12 @@ where /// Service to handle v3 protocol pub fn v3_variants( self, - service: v3::Selector, + service: v3::Selector, ) -> MqttServer< - Io, + F, impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -165,7 +151,6 @@ where MqttServer { v3: service.finish_server(), v5: self.v5, - pool: self.pool, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } @@ -174,13 +159,13 @@ where /// Service to handle v5 protocol pub fn v5( self, - service: v5::MqttServer, + service: v5::MqttServer, ) -> MqttServer< - Io, + F, V3, impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -192,8 +177,8 @@ where St: 'static, C: ServiceFactory< Config = (), - Request = v5::Handshake, - Response = v5::HandshakeAck, + Request = v5::Handshake, + Response = v5::HandshakeAck, Error = Err, InitError = InitErr, > + 'static, @@ -218,7 +203,6 @@ where MqttServer { v3: self.v3, v5: service.inner_finish(), - pool: self.pool, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } @@ -227,13 +211,13 @@ where /// Service to handle v5 protocol pub fn v5_variants( self, - service: v5::Selector, + service: v5::Selector, ) -> MqttServer< - Io, + F, V3, impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -248,26 +232,25 @@ where MqttServer { v3: self.v3, v5: service.finish_server(), - pool: self.pool, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } } } -impl ServiceFactory for MqttServer +impl ServiceFactory for MqttServer where - Io: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter, V3: ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, >, V5: ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -276,21 +259,20 @@ where V5::Future: 'static, { type Config = (); - type Request = Io; + type Request = Io; type Response = (); type Error = MqttError; - type Service = MqttServerImpl; + type Service = MqttServerImpl; type InitError = InitErr; type Future = Pin< Box< dyn Future< - Output = Result, InitErr>, + Output = Result, InitErr>, >, >, >; fn new_service(&self, _: ()) -> Self::Future { - let pool = self.pool.clone(); let handshake_timeout = self.handshake_timeout; let fut = join(self.v3.new_service(()), self.v5.new_service(())); Box::pin(async move { @@ -299,7 +281,6 @@ where let v5 = v5?; Ok(MqttServerImpl { handlers: Rc::new((v3, v5)), - pool, handshake_timeout, _t: marker::PhantomData, }) @@ -308,30 +289,28 @@ where } /// Mqtt Server -pub struct MqttServerImpl { +pub struct MqttServerImpl { handlers: Rc<(V3, V5)>, handshake_timeout: Seconds, - pool: Pool, - _t: marker::PhantomData<(Io, Err)>, + _t: marker::PhantomData<(F, Err)>, } -impl Service for MqttServerImpl +impl Service for MqttServerImpl where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - V3: Service), Response = (), Error = MqttError>, - V5: Service), Response = (), Error = MqttError>, + F: Filter, + V3: Service), Response = (), Error = MqttError>, + V5: Service), Response = (), Error = MqttError>, { - type Request = Io; + type Request = Io; type Response = (); type Error = MqttError; - type Future = MqttServerImplResponse; + type Future = MqttServerImplResponse; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let ready1 = self.handlers.0.poll_ready(cx)?.is_ready(); let ready2 = self.handlers.1.poll_ready(cx)?.is_ready(); - let ready3 = self.pool.poll_ready(cx).is_ready(); - if ready1 && ready2 && ready3 { + if ready1 && ready2 { Poll::Ready(Ok(())) } else { Poll::Pending @@ -349,57 +328,50 @@ where } } - fn call(&self, req: Io) -> Self::Future { - let pool = self.pool.pool_ref(); + fn call(&self, req: Io) -> Self::Future { + let req = req.into_boxed(); let delay = self.handshake_timeout.map(sleep); MqttServerImplResponse { state: MqttServerImplState::Version { - item: Some(( - req, - State::with_memory_pool(pool), - VersionCodec, - self.handlers.clone(), - delay, - )), + item: Some((req, VersionCodec, self.handlers.clone(), delay)), }, } } } pin_project_lite::pin_project! { - pub struct MqttServerImplResponse + pub struct MqttServerImplResponse where V3: Service< - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, >, V5: Service< - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, >, { #[pin] - state: MqttServerImplState, + state: MqttServerImplState, } } pin_project_lite::pin_project! { #[project = MqttServerImplStateProject] - pub(crate) enum MqttServerImplState { + pub(crate) enum MqttServerImplState { V3 { #[pin] fut: V3::Future }, V5 { #[pin] fut: V5::Future }, - Version { item: Option<(Io, State, VersionCodec, Rc<(V3, V5)>, Option)> }, + Version { item: Option<(IoBoxed, VersionCodec, Rc<(V3, V5)>, Option)> }, } } -impl Future for MqttServerImplResponse +impl Future for MqttServerImplResponse where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - V3: Service), Response = (), Error = MqttError>, - V5: Service), Response = (), Error = MqttError>, + V3: Service), Response = (), Error = MqttError>, + V5: Service), Response = (), Error = MqttError>, { type Output = Result<(), MqttError>; @@ -411,7 +383,7 @@ where MqttServerImplStateProject::V3 { fut } => return fut.poll(cx), MqttServerImplStateProject::V5 { fut } => return fut.poll(cx), MqttServerImplStateProject::Version { ref mut item } => { - if let Some(ref mut delay) = item.as_mut().unwrap().4 { + if let Some(ref mut delay) = item.as_mut().unwrap().3 { match Pin::new(delay).poll(cx) { Poll::Pending => (), Poll::Ready(_) => { @@ -422,28 +394,28 @@ where let st = item.as_mut().unwrap(); - match st.1.poll_next(&mut st.0, &st.2, cx) { - Poll::Ready(Ok(Some(ver))) => { - let (io, state, _, handlers, delay) = item.take().unwrap(); + match st.0.poll_read_next(&st.1, cx) { + Poll::Ready(Some(Ok(ver))) => { + let (io, _, handlers, delay) = item.take().unwrap(); this = self.as_mut().project(); match ver { ProtocolVersion::MQTT3 => { this.state.set(MqttServerImplState::V3 { - fut: handlers.0.call((io, state, delay)), + fut: handlers.0.call((io, delay)), }) } ProtocolVersion::MQTT5 => { this.state.set(MqttServerImplState::V5 { - fut: handlers.1.call((io, state, delay)), + fut: handlers.1.call((io, delay)), }) } } continue; } - Poll::Ready(Ok(None)) => { - return Poll::Ready(Err(MqttError::Disconnected)) + Poll::Ready(None) => return Poll::Ready(Err(MqttError::Disconnected)), + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Err(MqttError::from(err))) } - Poll::Ready(Err(err)) => return Poll::Ready(Err(MqttError::from(err))), Poll::Pending => return Poll::Pending, } } @@ -452,23 +424,23 @@ where } } -pub struct DefaultProtocolServer { +pub struct DefaultProtocolServer { ver: ProtocolVersion, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(Err, InitErr)>, } -impl DefaultProtocolServer { +impl DefaultProtocolServer { fn new(ver: ProtocolVersion) -> Self { Self { ver, _t: marker::PhantomData } } } -impl ServiceFactory for DefaultProtocolServer { +impl ServiceFactory for DefaultProtocolServer { type Config = (); - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; - type Service = DefaultProtocolServer; + type Service = DefaultProtocolServer; type InitError = InitErr; type Future = Ready; @@ -477,8 +449,8 @@ impl ServiceFactory for DefaultProtocolServer Service for DefaultProtocolServer { - type Request = (Io, State, Option); +impl Service for DefaultProtocolServer { + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; type Future = Ready; diff --git a/src/service.rs b/src/service.rs index dd61fc7a..3b692fa2 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,28 +1,27 @@ use std::task::{Context, Poll}; use std::{fmt, future::Future, marker::PhantomData, pin::Pin, rc::Rc}; -use ntex::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; +use ntex::codec::{Decoder, Encoder}; +use ntex::io::{DispatchItem, IoBoxed, Timer}; use ntex::service::{IntoServiceFactory, Service, ServiceFactory}; use ntex::time::{Millis, Seconds, Sleep}; use ntex::util::{select, Either, Pool}; -use super::io::{DispatchItem, Dispatcher, State, Timer}; +use crate::io::Dispatcher; type ResponseItem = Option<::Item>; -pub(crate) struct FramedService { +pub(crate) struct FramedService { connect: C, handler: Rc, disconnect_timeout: Seconds, time: Timer, - pool: Pool, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Codec)>, } -impl FramedService { - pub(crate) fn new(connect: C, service: T, pool: Pool, disconnect_timeout: Seconds) -> Self { +impl FramedService { + pub(crate) fn new(connect: C, service: T, disconnect_timeout: Seconds) -> Self { FramedService { - pool, connect, disconnect_timeout, handler: Rc::new(service), @@ -32,13 +31,11 @@ impl FramedService { } } -impl ServiceFactory for FramedService +impl ServiceFactory for FramedService where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: ServiceFactory, + C: ServiceFactory + + 'static, C::Error: fmt::Debug, - C::Future: 'static, - ::Future: 'static, T: ServiceFactory< Config = St, Request = DispatchItem, @@ -46,17 +43,15 @@ where Error = C::Error, InitError = C::Error, > + 'static, - ::Error: 'static, - ::Future: 'static, Codec: Decoder + Encoder + Clone + 'static, ::Item: 'static, { type Config = (); - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = C::Error; type InitError = C::InitError; - type Service = FramedServiceImpl; + type Service = FramedServiceImpl; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -64,14 +59,12 @@ where let handler = self.handler.clone(); let disconnect_timeout = self.disconnect_timeout; let time = self.time.clone(); - let pool = self.pool.clone(); // create connect service and then create service impl Box::pin(async move { Ok(FramedServiceImpl { handler, disconnect_timeout, - pool, time, connect: fut.await?, _t: PhantomData, @@ -80,21 +73,18 @@ where } } -pub(crate) struct FramedServiceImpl { +pub(crate) struct FramedServiceImpl { connect: C, handler: Rc, disconnect_timeout: Seconds, - pool: Pool, time: Timer, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Codec)>, } -impl Service for FramedServiceImpl +impl Service for FramedServiceImpl where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: Service, + C: Service + 'static, C::Error: fmt::Debug, - C::Future: 'static, T: ServiceFactory< Config = St, Request = DispatchItem, @@ -102,26 +92,16 @@ where Error = C::Error, InitError = C::Error, > + 'static, - ::Error: 'static, - ::Future: 'static, Codec: Decoder + Encoder + Clone + 'static, - ::Item: 'static, { - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = C::Error; type Future = Pin>>>; #[inline] fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let ready1 = self.connect.poll_ready(cx)?.is_ready(); - let ready2 = self.pool.poll_ready(cx).is_ready(); - - if ready1 && ready2 { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + self.connect.poll_ready(cx) } #[inline] @@ -130,7 +110,7 @@ where } #[inline] - fn call(&self, req: Io) -> Self::Future { + fn call(&self, req: IoBoxed) -> Self::Future { log::trace!("Start connection handshake"); let handler = self.handler.clone(); @@ -139,7 +119,7 @@ where let time = self.time.clone(); Box::pin(async move { - let (io, st, codec, session, keepalive) = handshake.await.map_err(|e| { + let (io, codec, session, keepalive) = handshake.await.map_err(|e| { log::trace!("Connection handshake failed: {:?}", e); e })?; @@ -148,7 +128,7 @@ where let handler = handler.new_service(session).await?; log::trace!("Connection handler is created, starting dispatcher"); - Dispatcher::with(io, st, codec, handler, time) + Dispatcher::new(io, codec, handler, time) .keepalive_timeout(keepalive) .disconnect_timeout(timeout) .await @@ -156,20 +136,18 @@ where } } -pub(crate) struct FramedService2 { +pub(crate) struct FramedService2 { connect: C, handler: Rc, disconnect_timeout: Seconds, - pool: Pool, time: Timer, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Codec)>, } -impl FramedService2 { - pub(crate) fn new(connect: C, service: T, pool: Pool, disconnect_timeout: Seconds) -> Self { +impl FramedService2 { + pub(crate) fn new(connect: C, service: T, disconnect_timeout: Seconds) -> Self { FramedService2 { connect, - pool, disconnect_timeout, handler: Rc::new(service), time: Timer::new(Millis::ONE_SEC), @@ -178,17 +156,11 @@ impl FramedService2 { } } -impl ServiceFactory for FramedService2 +impl ServiceFactory for FramedService2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: ServiceFactory< - Config = (), - Request = (Io, State), - Response = (Io, State, Codec, St, Seconds), - >, + C: ServiceFactory + + 'static, C::Error: fmt::Debug, - C::Future: 'static, - ::Future: 'static, T: ServiceFactory< Config = St, Request = DispatchItem, @@ -196,17 +168,14 @@ where Error = C::Error, InitError = C::Error, > + 'static, - ::Error: 'static, - ::Future: 'static, Codec: Decoder + Encoder + Clone + 'static, - ::Item: 'static, { type Config = (); - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = C::Error; type InitError = C::InitError; - type Service = FramedServiceImpl2; + type Service = FramedServiceImpl2; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -214,7 +183,6 @@ where let handler = self.handler.clone(); let disconnect_timeout = self.disconnect_timeout; let time = self.time.clone(); - let pool = self.pool.clone(); // create connect service and then create service impl Box::pin(async move { @@ -222,7 +190,6 @@ where handler, disconnect_timeout, time, - pool, connect: fut.await?, _t: PhantomData, }) @@ -230,21 +197,18 @@ where } } -pub(crate) struct FramedServiceImpl2 { +pub(crate) struct FramedServiceImpl2 { connect: C, handler: Rc, - pool: Pool, disconnect_timeout: Seconds, time: Timer, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Codec)>, } -impl Service for FramedServiceImpl2 +impl Service for FramedServiceImpl2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: Service, + C: Service + 'static, C::Error: fmt::Debug, - C::Future: 'static, T: ServiceFactory< Config = St, Request = DispatchItem, @@ -252,26 +216,16 @@ where Error = C::Error, InitError = C::Error, > + 'static, - ::Error: 'static, - ::Future: 'static, Codec: Decoder + Encoder + Clone + 'static, - ::Item: 'static, { - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = C::Error; type Future = Pin>>>; #[inline] fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let ready1 = self.connect.poll_ready(cx)?.is_ready(); - let ready2 = self.pool.poll_ready(cx).is_ready(); - - if ready1 && ready2 { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + self.connect.poll_ready(cx) } #[inline] @@ -280,20 +234,20 @@ where } #[inline] - fn call(&self, (req, state, delay): (Io, State, Option)) -> Self::Future { + fn call(&self, (req, delay): Self::Request) -> Self::Future { log::trace!("Start connection handshake"); let handler = self.handler.clone(); let timeout = self.disconnect_timeout; - let handshake = self.connect.call((req, state)); + let handshake = self.connect.call(req); let time = self.time.clone(); Box::pin(async move { - let (io, state, codec, ka, handler) = if let Some(delay) = delay { + let (io, codec, ka, handler) = if let Some(delay) = delay { let res = select( delay, Box::pin(async { - let (io, state, codec, st, ka) = handshake.await.map_err(|e| { + let (io, codec, st, ka) = handshake.await.map_err(|e| { log::trace!("Connection handshake failed: {:?}", e); e })?; @@ -302,7 +256,7 @@ where let handler = handler.new_service(st).await?; log::trace!("Connection handler is created, starting dispatcher"); - Ok::<_, C::Error>((io, state, codec, ka, handler)) + Ok::<_, C::Error>((io, codec, ka, handler)) }), ) .await; @@ -315,7 +269,7 @@ where Either::Right(item) => item?, } } else { - let (io, state, codec, st, ka) = handshake.await.map_err(|e| { + let (io, codec, st, ka) = handshake.await.map_err(|e| { log::trace!("Connection handshake failed: {:?}", e); e })?; @@ -323,10 +277,10 @@ where let handler = handler.new_service(st).await?; log::trace!("Connection handler is created, starting dispatcher"); - (io, state, codec, ka, handler) + (io, codec, ka, handler) }; - Dispatcher::with(io, state, codec, handler, time) + Dispatcher::new(io, codec, handler, time) .keepalive_timeout(ka) .disconnect_timeout(timeout) .await diff --git a/src/v3/client/connection.rs b/src/v3/client/connection.rs index 523ed34d..6ff07bd8 100644 --- a/src/v3/client/connection.rs +++ b/src/v3/client/connection.rs @@ -1,13 +1,13 @@ use std::{fmt, future::Future, marker::PhantomData, rc::Rc, time::Instant}; -use ntex::codec::{AsyncRead, AsyncWrite}; +use ntex::io::{DispatchItem, IoBoxed, Timer}; use ntex::router::{IntoPattern, Router, RouterBuilder}; use ntex::service::{apply_fn, boxed, into_service, IntoService, Service}; use ntex::time::{sleep, Millis, Seconds}; use ntex::util::{Either, Ready}; use crate::error::{MqttError, ProtocolError}; -use crate::io::{DispatchItem, Dispatcher, Timer}; +use crate::io::Dispatcher; use crate::v3::{shared::MqttShared, sink::MqttSink}; use crate::v3::{ControlResult, Publish}; @@ -15,8 +15,8 @@ use super::control::ControlMessage; use super::dispatcher::create_dispatcher; /// Mqtt client -pub struct Client { - io: Io, +pub struct Client { + io: IoBoxed, shared: Rc, keepalive: Seconds, disconnect_timeout: Seconds, @@ -24,7 +24,7 @@ pub struct Client { max_receive: usize, } -impl fmt::Debug for Client { +impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("v3::Client") .field("keepalive", &self.keepalive) @@ -35,13 +35,10 @@ impl fmt::Debug for Client { } } -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl Client { /// Construct new `Dispatcher` instance with outgoing messages stream. pub(super) fn new( - io: T, + io: IoBoxed, shared: Rc, session_present: bool, keepalive_timeout: Seconds, @@ -59,10 +56,7 @@ where } } -impl Client -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Client { #[inline] /// Get client sink pub fn sink(&self) -> MqttSink { @@ -76,7 +70,7 @@ where } /// Configure mqtt resource for a specific topic - pub fn resource(self, address: T, service: F) -> ClientRouter + pub fn resource(self, address: T, service: F) -> ClientRouter where T: IntoPattern, F: IntoService, @@ -114,9 +108,8 @@ where into_service(|msg: ControlMessage<()>| Ready::<_, ()>::Ok(msg.disconnect())), ); - let _ = Dispatcher::with( + let _ = Dispatcher::new( self.io, - self.shared.state.clone(), self.shared.clone(), dispatcher, Timer::new(Millis::ONE_SEC), @@ -144,26 +137,20 @@ where service.into_service(), ); - Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared.clone(), - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await + Dispatcher::new(self.io, self.shared.clone(), dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await } } type Handler = boxed::BoxService; /// Mqtt client with routing capabilities -pub struct ClientRouter { +pub struct ClientRouter { builder: RouterBuilder, handlers: Vec>, - io: Io, + io: IoBoxed, shared: Rc, keepalive: Seconds, disconnect_timeout: Seconds, @@ -171,7 +158,7 @@ pub struct ClientRouter { _t: PhantomData, } -impl fmt::Debug for ClientRouter { +impl fmt::Debug for ClientRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("v3::ClientRouter") .field("keepalive", &self.keepalive) @@ -181,9 +168,8 @@ impl fmt::Debug for ClientRouter { } } -impl ClientRouter +impl ClientRouter where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: From + 'static, PErr: 'static, { @@ -212,9 +198,8 @@ where into_service(|msg: ControlMessage| Ready::<_, Err>::Ok(msg.disconnect())), ); - let _ = Dispatcher::with( + let _ = Dispatcher::new( self.io, - self.shared.state.clone(), self.shared.clone(), dispatcher, Timer::new(Millis::ONE_SEC), @@ -242,16 +227,10 @@ where service.into_service(), ); - Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared.clone(), - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await + Dispatcher::new(self.io, self.shared.clone(), dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await } } diff --git a/src/v3/client/connector.rs b/src/v3/client/connector.rs index 1174c665..b8b97ea7 100644 --- a/src/v3/client/connector.rs +++ b/src/v3/client/connector.rs @@ -1,7 +1,7 @@ use std::{future::Future, rc::Rc}; -use ntex::codec::{AsyncRead, AsyncWrite}; use ntex::connect::{self, Address, Connect, Connector}; +use ntex::io::{Filter, Io, IoBoxed}; use ntex::service::Service; use ntex::time::{timeout, Millis, Seconds}; use ntex::util::{select, ByteString, Bytes, Either, PoolId}; @@ -13,7 +13,6 @@ use ntex::connect::openssl::{OpensslConnector, SslConnector}; use ntex::connect::rustls::{ClientConfig, RustlsConnector}; use super::{codec, connection::Client, error::ClientError, error::ProtocolError}; -use crate::io::State; use crate::v3::shared::{MqttShared, MqttSinkPool}; /// Mqtt client connector @@ -35,11 +34,16 @@ where { #[allow(clippy::new_ret_no_self)] /// Create new mqtt connector - pub fn new(address: A) -> MqttConnector> { + pub fn new( + address: A, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { MqttConnector { address, pkt: codec::Connect::default(), - connector: Connector::default(), + connector: Connector::default().map(|io| io.into_boxed()), max_send: 16, max_receive: 16, max_packet_size: 64 * 1024, @@ -53,8 +57,7 @@ where impl MqttConnector where A: Address + Clone, - T: Service, Error = connect::ConnectError>, - T::Response: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service, Response = IoBoxed, Error = connect::ConnectError>, { #[inline] /// Create new client and provide client id @@ -179,13 +182,19 @@ where } /// Use custom connector - pub fn connector(self, connector: U) -> MqttConnector + pub fn connector( + self, + connector: U, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > where - U: Service, Error = connect::ConnectError>, - U::Response: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter, + U: Service, Response = Io, Error = connect::ConnectError>, { MqttConnector { - connector, + connector: connector.map(|io| io.into_boxed()), pkt: self.pkt, address: self.address, max_send: self.max_send, @@ -199,14 +208,20 @@ where #[cfg(feature = "openssl")] /// Use openssl connector - pub fn openssl(self, connector: SslConnector) -> MqttConnector> { + pub fn openssl( + self, + connector: SslConnector, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { MqttConnector { pkt: self.pkt, address: self.address, max_send: self.max_send, max_receive: self.max_receive, max_packet_size: self.max_packet_size, - connector: OpensslConnector::new(connector), + connector: OpensslConnector::new(connector).map(|io| io.into_boxed()), handshake_timeout: self.handshake_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, @@ -215,16 +230,20 @@ where #[cfg(feature = "rustls")] /// Use rustls connector - pub fn rustls(self, config: ClientConfig) -> MqttConnector> { - use std::sync::Arc; - + pub fn rustls( + self, + config: ClientConfig, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { MqttConnector { pkt: self.pkt, address: self.address, max_send: self.max_send, max_receive: self.max_receive, max_packet_size: self.max_packet_size, - connector: RustlsConnector::new(Arc::new(config)), + connector: RustlsConnector::new(Arc::new(config)).map(|io| io.into_boxed()), handshake_timeout: self.handshake_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, @@ -232,7 +251,7 @@ where } /// Connect to mqtt server - pub fn connect(&self) -> impl Future, ClientError>> { + pub fn connect(&self) -> impl Future> { if self.handshake_timeout.non_zero() { let fut = timeout(self.handshake_timeout, self._connect()); Either::Left(async move { @@ -246,7 +265,7 @@ where } } - fn _connect(&self) -> impl Future, ClientError>> { + fn _connect(&self) -> impl Future> { let fut = self.connector.call(Connect::new(self.address.clone())); let pkt = self.pkt.clone(); let max_send = self.max_send; @@ -257,23 +276,21 @@ where let pool = self.pool.clone(); async move { - let mut io = fut.await?; - let state = State::with_memory_pool(pool.pool.get()); + let io = fut.await?; let codec = codec::Codec::new().max_size(max_packet_size); - state.send(&mut io, &codec, pkt.into()).await?; + io.send(pkt.into(), &codec).await?; - let packet = state - .next(&mut io, &codec) + let packet = io + .next(&codec) .await - .map_err(|e| ClientError::from(ProtocolError::from(e))) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Mqtt server is disconnected during handshake"); - ClientError::Disconnected - }) - })?; - let shared = Rc::new(MqttShared::new(state.clone(), codec, max_send, pool)); + .ok_or_else(|| { + log::trace!("Mqtt server is disconnected during handshake"); + ClientError::Disconnected + })? + .map_err(|e| ClientError::from(ProtocolError::from(e)))?; + + let shared = Rc::new(MqttShared::new(io.get_ref(), codec, max_send, pool)); match packet { codec::Packet::ConnectAck { session_present, return_code } => { diff --git a/src/v3/client/control.rs b/src/v3/client/control.rs index 27ff5c6a..af232302 100644 --- a/src/v3/client/control.rs +++ b/src/v3/client/control.rs @@ -1,4 +1,6 @@ -pub use crate::v3::control::{Closed, ControlResult, Disconnect, Error, ProtocolError}; +pub use crate::v3::control::{ + Closed, ControlResult, Disconnect, Error, PeerGone, ProtocolError, +}; use crate::v3::{codec, control::ControlResultKind, error}; pub enum ControlMessage { @@ -12,6 +14,8 @@ pub enum ControlMessage { Error(Error), /// Protocol level error ProtocolError(ProtocolError), + /// Peer is gone + PeerGone(PeerGone), } impl ControlMessage { @@ -35,6 +39,10 @@ impl ControlMessage { ControlMessage::ProtocolError(ProtocolError::new(err)) } + pub(super) fn peer_gone() -> Self { + ControlMessage::PeerGone(PeerGone) + } + pub fn disconnect(&self) -> ControlResult { ControlResult { result: ControlResultKind::Disconnect } } diff --git a/src/v3/client/dispatcher.rs b/src/v3/client/dispatcher.rs index 22fac0d7..e81eb286 100644 --- a/src/v3/client/dispatcher.rs +++ b/src/v3/client/dispatcher.rs @@ -2,12 +2,13 @@ use std::cell::RefCell; use std::task::{Context, Poll}; use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc}; +use ntex::io::DispatchItem; use ntex::service::Service; use ntex::util::{buffer::BufferService, inflight::InFlightService, Either, HashSet, Ready}; use crate::v3::shared::{Ack, MqttShared}; use crate::v3::{codec, control::ControlResultKind, publish::Publish, sink::MqttSink}; -use crate::{error::MqttError, error::ProtocolError, io::DispatchItem, types::packet_type}; +use crate::{error::MqttError, error::ProtocolError, types::packet_type}; use super::control::{ControlMessage, ControlResult}; @@ -199,10 +200,16 @@ where &self.inner, ))) } - DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( - ControlMessage::proto_error(ProtocolError::Io(err)), - &self.inner, - ))), + DispatchItem::Disconnect(err) => { + Either::Right(Either::Right(ControlResponse::new( + if let Some(err) = err { + ControlMessage::proto_error(ProtocolError::Io(err)) + } else { + ControlMessage::peer_gone() + }, + &self.inner, + ))) + } DispatchItem::KeepAliveTimeout => { Either::Right(Either::Right(ControlResponse::new( ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), diff --git a/src/v3/control.rs b/src/v3/control.rs index d08b9d73..bab849d0 100644 --- a/src/v3/control.rs +++ b/src/v3/control.rs @@ -20,6 +20,8 @@ pub enum ControlMessage { Error(Error), /// Protocol level error ProtocolError(ProtocolError), + /// Peer is gone + PeerGone(PeerGone), } #[derive(Debug)] @@ -75,13 +77,18 @@ impl ControlMessage { ControlMessage::ProtocolError(ProtocolError::new(err)) } + /// Create a new `ControlMessage` from DISCONNECT packet. + pub(super) fn peer_gone() -> Self { + ControlMessage::PeerGone(PeerGone) + } + /// Disconnects the client by sending DISCONNECT packet. pub fn disconnect(&self) -> ControlResult { ControlResult { result: ControlResultKind::Disconnect } } } -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct Ping; impl Ping { @@ -90,7 +97,7 @@ impl Ping { } } -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct Disconnect; impl Disconnect { @@ -346,3 +353,12 @@ impl Closed { ControlResult { result: ControlResultKind::Closed } } } + +#[derive(Copy, Clone, Debug)] +pub struct PeerGone; + +impl PeerGone { + pub fn ack(self) -> ControlResult { + ControlResult { result: ControlResultKind::Nothing } + } +} diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index 4380f0f1..55611229 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -2,13 +2,13 @@ use std::cell::RefCell; use std::task::{Context, Poll}; use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc}; +use ntex::io::DispatchItem; use ntex::service::{fn_factory_with_config, Service, ServiceFactory}; use ntex::util::{ buffer::BufferService, inflight::InFlightService, join, Either, HashSet, Ready, }; use crate::error::{MqttError, ProtocolError}; -use crate::io::DispatchItem; use super::control::{ ControlMessage, ControlResult, ControlResultKind, Subscribe, Unsubscribe, @@ -235,10 +235,16 @@ where &self.inner, ))) } - DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( - ControlMessage::proto_error(ProtocolError::Io(err)), - &self.inner, - ))), + DispatchItem::Disconnect(err) => { + Either::Right(Either::Right(ControlResponse::new( + if let Some(err) = err { + ControlMessage::proto_error(ProtocolError::Io(err)) + } else { + ControlMessage::peer_gone() + }, + &self.inner, + ))) + } DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { Either::Right(Either::Left(Ready::Ok(None))) } diff --git a/src/v3/handshake.rs b/src/v3/handshake.rs index 794a998c..ef6566e2 100644 --- a/src/v3/handshake.rs +++ b/src/v3/handshake.rs @@ -1,20 +1,20 @@ use std::{fmt, rc::Rc}; -use ntex::time::Seconds; +use ntex::{io::IoBoxed, time::Seconds}; use super::codec as mqtt; use super::shared::MqttShared; use super::sink::MqttSink; /// Connect message -pub struct Handshake { - io: Io, +pub struct Handshake { + io: IoBoxed, pkt: Box, shared: Rc, } -impl Handshake { - pub(crate) fn new(pkt: Box, io: Io, shared: Rc) -> Self { +impl Handshake { + pub(crate) fn new(pkt: Box, io: IoBoxed, shared: Rc) -> Self { Self { io, pkt, shared } } @@ -27,8 +27,8 @@ impl Handshake { } #[inline] - pub fn io(&mut self) -> &mut Io { - &mut self.io + pub fn io(&self) -> &IoBoxed { + &self.io } /// Returns mqtt server sink @@ -37,7 +37,7 @@ impl Handshake { } /// Ack handshake message and set state - pub fn ack(self, st: St, session_present: bool) -> HandshakeAck { + pub fn ack(self, st: St, session_present: bool) -> HandshakeAck { let Handshake { io, shared, pkt } = self; // [MQTT-3.1.2-24]. let keepalive = if pkt.keep_alive != 0 { @@ -46,9 +46,9 @@ impl Handshake { 30 }; HandshakeAck { - session_present, io, shared, + session_present, session: Some(st), keepalive: Seconds(keepalive), return_code: mqtt::ConnectAckReason::ConnectionAccepted, @@ -56,7 +56,7 @@ impl Handshake { } /// Create connect ack object with `identifier rejected` return code - pub fn identifier_rejected(self) -> HandshakeAck { + pub fn identifier_rejected(self) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -68,7 +68,7 @@ impl Handshake { } /// Create connect ack object with `bad user name or password` return code - pub fn bad_username_or_pwd(self) -> HandshakeAck { + pub fn bad_username_or_pwd(self) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -80,7 +80,7 @@ impl Handshake { } /// Create connect ack object with `not authorized` return code - pub fn not_authorized(self) -> HandshakeAck { + pub fn not_authorized(self) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -92,7 +92,7 @@ impl Handshake { } /// Create connect ack object with `service unavailable` return code - pub fn service_unavailable(self) -> HandshakeAck { + pub fn service_unavailable(self) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -104,15 +104,15 @@ impl Handshake { } } -impl fmt::Debug for Handshake { +impl fmt::Debug for Handshake { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pkt.fmt(f) } } /// Ack connect message -pub struct HandshakeAck { - pub(crate) io: Io, +pub struct HandshakeAck { + pub(crate) io: IoBoxed, pub(crate) session: Option, pub(crate) session_present: bool, pub(crate) return_code: mqtt::ConnectAckReason, @@ -120,7 +120,7 @@ pub struct HandshakeAck { pub(crate) keepalive: Seconds, } -impl HandshakeAck { +impl HandshakeAck { /// Set idle time-out for the connection in seconds /// /// By default idle time-out is set to 30 seconds. @@ -128,20 +128,4 @@ impl HandshakeAck { self.keepalive = timeout; self } - - #[doc(hidden)] - #[deprecated(since = "0.7.6", note = "Use memory pool config")] - #[inline] - /// Set read/write buffer sizes - /// - /// By default max buffer size is 4kb for both read and write buffer, - /// Min size is 256 bytes. - pub fn buffer_params( - self, - _max_read_buf: u16, - _max_write_buf: u16, - _min_buf_size: u16, - ) -> Self { - self - } } diff --git a/src/v3/mod.rs b/src/v3/mod.rs index 26acf70b..f2fec541 100644 --- a/src/v3/mod.rs +++ b/src/v3/mod.rs @@ -16,7 +16,7 @@ mod sink; pub type Session = crate::Session; -pub use self::client::Client; +pub use self::client::{Client, MqttConnector}; pub use self::control::{ControlMessage, ControlResult}; pub use self::handshake::{Handshake, HandshakeAck}; pub use self::publish::Publish; diff --git a/src/v3/selector.rs b/src/v3/selector.rs index d2a82ae7..f1d01d7a 100644 --- a/src/v3/selector.rs +++ b/src/v3/selector.rs @@ -1,12 +1,11 @@ use std::{fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll, time}; -use ntex::codec::{AsyncRead, AsyncWrite}; +use ntex::io::{DispatchItem, Io, IoBoxed, IoRef}; use ntex::service::{apply_fn_factory, boxed, IntoServiceFactory, Service, ServiceFactory}; use ntex::time::{sleep, Seconds, Sleep}; use ntex::util::{timeout::Timeout, timeout::TimeoutError, Either, PoolId, Ready}; use crate::error::{MqttError, ProtocolError}; -use crate::io::{DispatchItem, State}; use super::control::{ControlMessage, ControlResult}; use super::default::{DefaultControlService, DefaultPublishService}; @@ -14,32 +13,26 @@ use super::handshake::{Handshake, HandshakeAck}; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttServer, MqttSink, Publish, Session}; -pub(crate) type SelectItem = (Handshake, State, Option); +pub(crate) type SelectItem = (Handshake, Option); -type ServerFactory = boxed::BoxServiceFactory< - (), - SelectItem, - Either, ()>, - MqttError, - InitErr, ->; +type ServerFactory = + boxed::BoxServiceFactory<(), SelectItem, Either, MqttError, InitErr>; -type Server = - boxed::BoxService, Either, ()>, MqttError>; +type Server = boxed::BoxService, MqttError>; /// Mqtt server selector /// /// Selector allows to choose different mqtt server impls depends on /// connectt packet. -pub struct Selector { - servers: Vec>, +pub struct Selector { + servers: Vec>, max_size: u32, handshake_timeout: Seconds, pool: Rc, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(Err, InitErr)>, } -impl Selector { +impl Selector { #[allow(clippy::new_without_default)] pub fn new() -> Self { Selector { @@ -52,9 +45,8 @@ impl Selector { } } -impl Selector +impl Selector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { @@ -76,29 +68,20 @@ where self } - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P5 - /// memory pool is used. - pub fn memory_pool(self, id: PoolId) -> Self { - self.pool.pool.set(id.pool_ref()); - self - } - /// Add server variant pub fn variant( mut self, check: F, - mut server: MqttServer, + mut server: MqttServer, ) -> Self where - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future> + 'static, St: 'static, C: ServiceFactory< Config = (), - Request = Handshake, - Response = HandshakeAck, + Request = Handshake, + Response = HandshakeAck, Error = Err, InitError = InitErr, > + 'static, @@ -124,7 +107,7 @@ where self, ) -> impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -138,18 +121,17 @@ where } } -impl ServiceFactory for Selector +impl ServiceFactory for Selector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { type Config = (); - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = MqttError; type InitError = InitErr; - type Service = SelectorService; + type Service = SelectorService; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -168,19 +150,18 @@ where } } -pub struct SelectorService { - servers: Rc>>, +pub struct SelectorService { + servers: Rc>>, max_size: u32, handshake_timeout: Seconds, pool: Rc, } -impl Service for SelectorService +impl Service for SelectorService where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, { - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = MqttError; type Future = Pin>>>>; @@ -212,11 +193,10 @@ where } #[inline] - fn call(&self, mut io: Io) -> Self::Future { + fn call(&self, io: IoBoxed) -> Self::Future { let servers = self.servers.clone(); - let state = State::with_memory_pool(self.pool.pool.get()); let shared = Rc::new(MqttShared::new( - state.clone(), + io.clone(), mqtt::Codec::default().max_size(self.max_size), 16, self.pool.clone(), @@ -225,18 +205,16 @@ where Box::pin(async move { // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; let connect = match packet { @@ -251,7 +229,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, io, shared), state, delay); + let mut item = (Handshake::new(connect, io, shared), delay); for srv in servers.iter() { match srv.call(item).await? { Either::Left(result) => { @@ -266,25 +244,24 @@ where } } -pub(crate) struct Selector2 { - servers: Vec>, +pub(crate) struct Selector2 { + servers: Vec>, max_size: u32, pool: Rc, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(Err, InitErr)>, } -impl ServiceFactory for Selector2 +impl ServiceFactory for Selector2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { type Config = (); - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; type InitError = InitErr; - type Service = SelectorService2; + type Service = SelectorService2; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -302,18 +279,17 @@ where } } -pub(crate) struct SelectorService2 { - servers: Rc>>, +pub(crate) struct SelectorService2 { + servers: Rc>>, max_size: u32, pool: Rc, } -impl Service for SelectorService2 +impl Service for SelectorService2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, { - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; type Future = Pin>>>>; @@ -345,10 +321,10 @@ where } #[inline] - fn call(&self, (mut io, state, delay): Self::Request) -> Self::Future { + fn call(&self, (io, delay): Self::Request) -> Self::Future { let servers = self.servers.clone(); let shared = Rc::new(MqttShared::new( - state.clone(), + io.get_ref(), mqtt::Codec::default().max_size(self.max_size), 16, self.pool.clone(), @@ -356,18 +332,16 @@ where Box::pin(async move { // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; let connect = match packet { @@ -382,7 +356,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, io, shared), state, delay); + let mut item = (Handshake::new(connect, io, shared), delay); for srv in servers.iter() { match srv.call(item).await? { Either::Left(result) => { diff --git a/src/v3/server.rs b/src/v3/server.rs index 61a464b2..939252e9 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -1,13 +1,14 @@ use std::task::{Context, Poll}; use std::{fmt, future::Future, marker::PhantomData, pin::Pin, rc::Rc}; -use ntex::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; +use ntex::codec::{Decoder, Encoder}; +use ntex::io::{into_boxed, DispatchItem, Filter, Io, IoBoxed, IoRef, Timer}; use ntex::service::{apply_fn_factory, IntoServiceFactory, Service, ServiceFactory}; use ntex::time::{Millis, Seconds, Sleep}; use ntex::util::{timeout::Timeout, timeout::TimeoutError, Either, PoolId, Ready}; use crate::error::{MqttError, ProtocolError}; -use crate::io::{DispatchItem, Dispatcher, State, Timer}; +use crate::io::Dispatcher; use crate::service::{FramedService, FramedService2}; use super::control::{ControlMessage, ControlResult}; @@ -18,7 +19,7 @@ use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttSink, Publish, Session}; /// Mqtt v3.1.1 Server -pub struct MqttServer { +pub struct MqttServer { handshake: C, control: Cn, publish: P, @@ -27,21 +28,14 @@ pub struct MqttServer, - _t: PhantomData<(Io, St)>, + _t: PhantomData, } -impl - MqttServer< - Io, - St, - C, - DefaultControlService, - DefaultPublishService, - > +impl + MqttServer, DefaultPublishService> where St: 'static, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: fmt::Debug, { /// Create server factory and provide handshake service @@ -63,12 +57,10 @@ where } } -impl MqttServer +impl MqttServer where - Io: AsyncRead + AsyncWrite + Unpin + 'static, St: 'static, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, Cn: ServiceFactory< Config = Session, Request = ControlMessage, @@ -121,20 +113,11 @@ where self } - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P5 - /// memory pool is used. - pub fn memory_pool(self, id: PoolId) -> Self { - self.pool.pool.set(id.pool_ref()); - self - } - /// Service to handle control packets /// /// All control packets are processed sequentially, max number of buffered /// control packets is 16. - pub fn control(self, service: F) -> MqttServer + pub fn control(self, service: F) -> MqttServer where F: IntoServiceFactory, Srv: ServiceFactory< @@ -158,7 +141,7 @@ where } /// Set service to handle publish packets and create mqtt server factory - pub fn publish(self, publish: F) -> MqttServer + pub fn publish(self, publish: F) -> MqttServer where F: IntoServiceFactory + 'static, Srv: ServiceFactory, Request = Publish, Response = ()> + 'static, @@ -178,11 +161,12 @@ where } /// Finish server configuration and create mqtt server factory - pub fn finish( + pub fn finish( self, - ) -> impl ServiceFactory> + ) -> impl ServiceFactory, Response = (), Error = MqttError> + where + F: Filter, { - let pool = self.pool.pool.get().pool(); let handshake = self.handshake; let publish = self .publish @@ -192,17 +176,16 @@ where let control = self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); - FramedService::new( + into_boxed(FramedService::new( handshake_service_factory( handshake, self.max_size, self.handshake_timeout, - self.pool, + self.pool.clone(), ), factory(publish, control, self.inflight), - pool, self.disconnect_timeout, - ) + )) } /// Set service to handle publish packets and create mqtt server factory @@ -210,12 +193,11 @@ where self, ) -> impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = C::InitError, > { - let pool = self.pool.pool.get().pool(); let handshake = self.handshake; let publish = self .publish @@ -226,14 +208,13 @@ where self.control.map_err(|e| e.into()).map_init_err(|e| MqttError::Service(e.into())); FramedService2::new( - handshake_service_factory2( + handshake_service_factory( handshake, self.max_size, self.handshake_timeout, - self.pool, + self.pool.clone(), ), factory(publish, control, self.inflight), - pool, self.disconnect_timeout, ) } @@ -244,13 +225,13 @@ where check: F, ) -> impl ServiceFactory< Config = (), - Request = SelectItem, - Response = Either, ()>, + Request = SelectItem, + Response = Either, Error = MqttError, InitError = C::InitError, > where - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future> + 'static, { let publish = @@ -270,61 +251,20 @@ where } } -fn handshake_service_factory( - factory: C, - max_size: u32, - handshake_timeout: Seconds, - pool: Rc, -) -> impl ServiceFactory< - Config = (), - Request = Io, - Response = (Io, State, Rc, Session, Seconds), - Error = MqttError, -> -where - Io: AsyncRead + AsyncWrite + Unpin, - C: ServiceFactory, Response = HandshakeAck>, - C::Error: fmt::Debug, -{ - ntex::service::apply( - Timeout::new(Millis::from(handshake_timeout)), - ntex::service::fn_factory(move || { - let pool = pool.clone(); - let fut = factory.new_service(()); - async move { - let service = fut.await?; - let pool = pool.clone(); - let service = Rc::new(service.map_err(MqttError::Service)); - Ok::<_, C::InitError>(ntex::service::apply_fn( - service, - move |conn: Io, service| { - handshake(conn, None, service.clone(), max_size, pool.clone()) - }, - )) - } - }), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => MqttError::HandshakeTimeout, - }) -} - -fn handshake_service_factory2( +fn handshake_service_factory( factory: C, max_size: u32, handshake_timeout: Seconds, pool: Rc, ) -> impl ServiceFactory< Config = (), - Request = (Io, State), - Response = (Io, State, Rc, Session, Seconds), + Request = IoBoxed, + Response = (IoBoxed, Rc, Session, Seconds), Error = MqttError, InitError = C::InitError, > where - Io: AsyncRead + AsyncWrite + Unpin, - C: ServiceFactory, Response = HandshakeAck>, + C: ServiceFactory>, C::Error: fmt::Debug, { ntex::service::apply( @@ -334,10 +274,9 @@ where let fut = factory.new_service(()); async move { let service = fut.await?; - let pool = pool.clone(); let service = Rc::new(service.map_err(MqttError::Service)); - Ok(ntex::service::apply_fn(service, move |(io, state), service| { - handshake(io, Some(state), service.clone(), max_size, pool.clone()) + Ok::<_, C::InitError>(ntex::service::apply_fn(service, move |conn, service| { + handshake(conn, service.clone(), max_size, pool.clone()) })) } }), @@ -348,46 +287,41 @@ where }) } -async fn handshake( - mut io: Io, - state: Option, +async fn handshake( + io: IoBoxed, service: S, max_size: u32, pool: Rc, -) -> Result<(Io, State, Rc, Session, Seconds), S::Error> +) -> Result<(IoBoxed, Rc, Session, Seconds), S::Error> where - Io: AsyncRead + AsyncWrite + Unpin, - S: Service, Response = HandshakeAck, Error = MqttError>, + S: Service, Error = MqttError>, { log::trace!("Starting mqtt handshake"); - let state = state.unwrap_or_else(|| State::with_memory_pool(pool.pool.get())); let shared = Rc::new(MqttShared::new( - state.clone(), + io.get_ref(), mqtt::Codec::default().max_size(max_size), 16, pool, )); // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; match packet { mqtt::Packet::Connect(connect) => { // authenticate mqtt connection - let mut ack = service.call(Handshake::new(connect, io, shared)).await?; + let ack = service.call(Handshake::new(connect, io, shared)).await?; match ack.session { Some(session) => { @@ -398,10 +332,9 @@ where log::trace!("Sending success handshake ack: {:#?}", pkt); - state.send(&mut ack.io, &ack.shared.codec, pkt).await?; + ack.io.send(pkt, &ack.shared.codec).await?; Ok(( ack.io, - ack.shared.state.clone(), ack.shared.clone(), Session::new(session, MqttSink::new(ack.shared)), ack.keepalive, @@ -414,7 +347,7 @@ where }; log::trace!("Sending failed handshake ack: {:#?}", pkt); - ack.shared.state.send(&mut ack.io, &ack.shared.codec, pkt).await?; + ack.io.send(pkt, &ack.shared.codec).await?; Err(MqttError::Disconnected) } @@ -430,23 +363,21 @@ where } } -pub(crate) struct ServerSelector { +pub(crate) struct ServerSelector { connect: C, handler: Rc, disconnect_timeout: Seconds, time: Timer, check: Rc, max_size: u32, - _t: PhantomData<(St, Io, R)>, + _t: PhantomData<(St, R)>, } -impl ServiceFactory for ServerSelector +impl ServiceFactory for ServerSelector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future>, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: fmt::Debug, T: ServiceFactory< Config = Session, @@ -457,11 +388,11 @@ where > + 'static, { type Config = (); - type Request = SelectItem; - type Response = Either, ()>; + type Request = SelectItem; + type Response = Either; type Error = MqttError; type InitError = C::InitError; - type Service = ServerSelectorImpl; + type Service = ServerSelectorImpl; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -487,22 +418,21 @@ where } } -pub(crate) struct ServerSelectorImpl { +pub(crate) struct ServerSelectorImpl { check: Rc, connect: Rc, handler: Rc, disconnect_timeout: Seconds, time: Timer, max_size: u32, - _t: PhantomData<(St, Io, R)>, + _t: PhantomData<(St, R)>, } -impl Service for ServerSelectorImpl +impl Service for ServerSelectorImpl where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future>, - C: Service, Response = HandshakeAck> + 'static, + C: Service> + 'static, C::Error: fmt::Debug, T: ServiceFactory< Config = Session, @@ -512,8 +442,8 @@ where InitError = MqttError, > + 'static, { - type Request = SelectItem; - type Response = Either, ()>; + type Request = SelectItem; + type Response = Either; type Error = MqttError; type Future = Pin>>>; @@ -539,7 +469,7 @@ where let max_size = self.max_size; Box::pin(async move { - let (hnd, state, mut delay) = req; + let (hnd, mut delay) = req; let result = if let Some(ref mut delay) = delay { let fut = (&*check)(&hnd); @@ -552,10 +482,10 @@ where }; if !result.map_err(MqttError::Service)? { - Ok(Either::Left((hnd, state, delay))) + Ok(Either::Left((hnd, delay))) } else { // authenticate mqtt connection - let mut ack = if let Some(ref mut delay) = delay { + let ack = if let Some(ref mut delay) = delay { let fut = connect.call(hnd); match crate::utils::select(fut, delay).await { Either::Left(res) => res.map_err(|e| { @@ -583,25 +513,16 @@ where ); ack.shared.codec.set_max_size(max_size); - state - .send(&mut ack.io, &ack.shared.codec, pkt) - .await - .map_err(MqttError::from)?; + ack.io.send(pkt, &ack.shared.codec).await.map_err(MqttError::from)?; let session = Session::new(session, MqttSink::new(ack.shared.clone())); let handler = handler.new_service(session).await?; log::trace!("Connection handler is created, starting dispatcher"); - Dispatcher::with( - ack.io, - ack.shared.state.clone(), - ack.shared, - handler, - time, - ) - .keepalive_timeout(ack.keepalive) - .disconnect_timeout(timeout) - .await?; + Dispatcher::new(ack.io, ack.shared, handler, time) + .keepalive_timeout(ack.keepalive) + .disconnect_timeout(timeout) + .await?; Ok(Either::Right(())) } None => { @@ -611,7 +532,7 @@ where }; log::trace!("Sending failed handshake ack: {:#?}", pkt); - ack.shared.state.send(&mut ack.io, &ack.shared.codec, pkt).await?; + ack.io.send(pkt, &ack.shared.codec).await?; Err(MqttError::Disconnected) } diff --git a/src/v3/shared.rs b/src/v3/shared.rs index 7b714f10..f9f129a9 100644 --- a/src/v3/shared.rs +++ b/src/v3/shared.rs @@ -2,10 +2,11 @@ use std::{cell::Cell, cell::RefCell, collections::VecDeque, num::NonZeroU16, rc: use ntex::channel::pool; use ntex::codec::{Decoder, Encoder}; +use ntex::io::IoRef; use ntex::util::{BytesMut, HashMap, PoolId, PoolRef}; use crate::error::{DecodeError, EncodeError}; -use crate::{io::State, types::packet_type, v3::codec}; +use crate::{types::packet_type, v3::codec}; pub(super) enum Ack { Publish(NonZeroU16), @@ -37,11 +38,11 @@ impl Default for MqttSinkPool { } pub(crate) struct MqttShared { + pub(super) io: IoRef, pub(super) cap: Cell, queues: RefCell, pub(super) inflight_idx: Cell, pub(super) pool: Rc, - pub(super) state: State, pub(super) codec: codec::Codec, } @@ -53,13 +54,13 @@ pub(super) struct MqttSharedQueues { impl MqttShared { pub(super) fn new( - state: State, + io: IoRef, codec: codec::Codec, cap: usize, pool: Rc, ) -> Self { Self { - state, + io, pool, codec, cap: Cell::new(cap), diff --git a/src/v3/sink.rs b/src/v3/sink.rs index 21cf0a81..a2e39c5c 100644 --- a/src/v3/sink.rs +++ b/src/v3/sink.rs @@ -28,7 +28,7 @@ impl MqttSink { /// /// Result indicates if connection is alive pub fn ready(&self) -> impl Future { - if self.0.state.is_open() { + if !self.0.io.is_closed() { self.0 .with_queues(|q| { if q.inflight.len() >= self.0.cap.get() { @@ -47,8 +47,8 @@ impl MqttSink { /// Close mqtt connection pub fn close(&self) { - if self.0.state.is_open() { - let _ = self.0.state.close(); + if !self.0.io.is_closed() { + let _ = self.0.io.close(); } self.0.with_queues(|q| { q.inflight.clear(); @@ -59,8 +59,8 @@ impl MqttSink { /// Force close mqtt connection. mqtt dispatcher does not wait for uncompleted /// responses, but it flushes buffers. pub fn force_close(&self) { - if self.0.state.is_open() { - let _ = self.0.state.force_close(); + if !self.0.io.is_closed() { + let _ = self.0.io.force_close(); } self.0.with_queues(|q| { q.inflight.clear(); @@ -70,7 +70,7 @@ impl MqttSink { /// Send ping pub(super) fn ping(&self) -> bool { - self.0.state.write().encode(codec::Packet::PingRequest, &self.0.codec).is_ok() + self.0.io.encode(codec::Packet::PingRequest, &self.0.codec).is_ok() } /// Create publish message builder @@ -187,11 +187,10 @@ impl PublishBuilder { pub fn send_at_most_once(self) -> Result<(), SendPacketError> { let packet = self.packet; - if self.shared.state.is_open() { + if !self.shared.io.is_closed() { log::trace!("Publish (QoS-0) to {:?}", packet.topic); self.shared - .state - .write() + .io .encode(codec::Packet::Publish(packet), &self.shared.codec) .map_err(SendPacketError::Encode) .map(|_| ()) @@ -208,7 +207,7 @@ impl PublishBuilder { let mut packet = self.packet; packet.qos = codec::QoS::AtLeastOnce; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -256,7 +255,7 @@ impl PublishBuilder { log::trace!("Publish (QoS1) to {:#?}", packet); - match shared.state.write().encode(codec::Packet::Publish(packet), &shared.codec) { + match shared.io.encode(codec::Packet::Publish(packet), &shared.codec) { Ok(_) => Either::Right(async move { rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected) }), @@ -296,7 +295,7 @@ impl SubscribeBuilder { let shared = self.shared; let filters = self.topic_filters; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -323,7 +322,7 @@ impl SubscribeBuilder { // send subscribe to client log::trace!("Sending subscribe packet id: {} filters:{:?}", idx, filters); - match shared.state.write().encode( + match shared.io.encode( codec::Packet::Subscribe { packet_id: NonZeroU16::new(idx).unwrap(), topic_filters: filters, @@ -375,7 +374,7 @@ impl UnsubscribeBuilder { let shared = self.shared; let filters = self.topic_filters; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -402,7 +401,7 @@ impl UnsubscribeBuilder { // send subscribe to client log::trace!("Sending unsubscribe packet id: {} filters:{:?}", idx, filters); - match shared.state.write().encode( + match shared.io.encode( codec::Packet::Unsubscribe { packet_id: NonZeroU16::new(idx).unwrap(), topic_filters: filters, diff --git a/src/v5/client/connection.rs b/src/v5/client/connection.rs index a8dbf9e3..ee907413 100644 --- a/src/v5/client/connection.rs +++ b/src/v5/client/connection.rs @@ -3,14 +3,14 @@ use std::{ cell::RefCell, convert::TryFrom, fmt, future::Future, marker, num::NonZeroU16, rc::Rc, }; -use ntex::codec::{AsyncRead, AsyncWrite}; +use ntex::io::{IoBoxed, Timer}; use ntex::router::{IntoPattern, Path, Router, RouterBuilder}; use ntex::service::{boxed, into_service, IntoService, Service}; use ntex::time::{sleep, Millis, Seconds}; use ntex::util::{ByteString, Either, HashMap, Ready}; use crate::error::MqttError; -use crate::io::{Dispatcher, Timer}; +use crate::io::Dispatcher; use crate::v5::publish::{Publish, PublishAck}; use crate::v5::{codec, shared::MqttShared, sink::MqttSink, ControlResult}; @@ -18,8 +18,8 @@ use super::control::ControlMessage; use super::dispatcher::create_dispatcher; /// Mqtt client -pub struct Client { - io: Io, +pub struct Client { + io: IoBoxed, shared: Rc, keepalive: Seconds, disconnect_timeout: Seconds, @@ -27,7 +27,7 @@ pub struct Client { pkt: Box, } -impl fmt::Debug for Client { +impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("v5::Client") .field("keepalive", &self.keepalive) @@ -38,13 +38,10 @@ impl fmt::Debug for Client { } } -impl Client -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl Client { /// Construct new `Dispatcher` instance with outgoing messages stream. pub(super) fn new( - io: T, + io: IoBoxed, shared: Rc, pkt: Box, max_receive: u16, @@ -62,10 +59,7 @@ where } } -impl Client -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Client { #[inline] /// Get client sink pub fn sink(&self) -> MqttSink { @@ -91,7 +85,7 @@ where } /// Configure mqtt resource for a specific topic - pub fn resource(self, address: T, service: F) -> ClientRouter + pub fn resource(self, address: T, service: F) -> ClientRouter where T: IntoPattern, F: IntoService, @@ -133,16 +127,10 @@ where }), ); - let _ = Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared, - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await; + let _ = Dispatcher::new(self.io, self.shared, dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await; } /// Run client with provided control messages handler @@ -164,26 +152,20 @@ where service.into_service(), ); - Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared, - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await + Dispatcher::new(self.io, self.shared, dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await } } type Handler = boxed::BoxService; /// Mqtt client with routing capabilities -pub struct ClientRouter { +pub struct ClientRouter { + io: IoBoxed, builder: RouterBuilder, handlers: Vec>, - io: Io, shared: Rc, keepalive: Seconds, disconnect_timeout: Seconds, @@ -191,7 +173,7 @@ pub struct ClientRouter { _t: marker::PhantomData, } -impl fmt::Debug for ClientRouter { +impl fmt::Debug for ClientRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("v5::ClientRouter") .field("keepalive", &self.keepalive) @@ -201,9 +183,8 @@ impl fmt::Debug for ClientRouter { } } -impl ClientRouter +impl ClientRouter where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: From + 'static, PublishAck: TryFrom, PErr: 'static, @@ -236,16 +217,10 @@ where }), ); - let _ = Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared, - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await; + let _ = Dispatcher::new(self.io, self.shared, dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await; } /// Run client and handle control messages @@ -267,16 +242,10 @@ where service.into_service(), ); - Dispatcher::with( - self.io, - self.shared.state.clone(), - self.shared, - dispatcher, - Timer::new(Millis::ONE_SEC), - ) - .keepalive_timeout(Seconds::ZERO) - .disconnect_timeout(self.disconnect_timeout) - .await + Dispatcher::new(self.io, self.shared, dispatcher, Timer::new(Millis::ONE_SEC)) + .keepalive_timeout(Seconds::ZERO) + .disconnect_timeout(self.disconnect_timeout) + .await } } diff --git a/src/v5/client/connector.rs b/src/v5/client/connector.rs index 1a9c468d..13768fc7 100644 --- a/src/v5/client/connector.rs +++ b/src/v5/client/connector.rs @@ -1,7 +1,7 @@ use std::{future::Future, num::NonZeroU16, num::NonZeroU32, rc::Rc, time::Duration}; -use ntex::codec::{AsyncRead, AsyncWrite}; use ntex::connect::{self, Address, Connect, Connector}; +use ntex::io::{Filter, Io, IoBoxed}; use ntex::service::Service; use ntex::time::{timeout, Seconds}; use ntex::util::{select, ByteString, Bytes, Either, PoolId}; @@ -13,7 +13,6 @@ use ntex::connect::openssl::{OpensslConnector, SslConnector}; use ntex::connect::rustls::{ClientConfig, RustlsConnector}; use super::{codec, connection::Client, error::ClientError, error::ProtocolError}; -use crate::io::State; use crate::v5::shared::{MqttShared, MqttSinkPool}; /// Mqtt client connector @@ -32,11 +31,16 @@ where { #[allow(clippy::new_ret_no_self)] /// Create new mqtt connector - pub fn new(address: A) -> MqttConnector> { + pub fn new( + address: A, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { MqttConnector { address, pkt: codec::Connect::default(), - connector: Connector::default(), + connector: Connector::default().map(|io| io.into_boxed()), handshake_timeout: Seconds::ZERO, disconnect_timeout: Seconds(3), pool: Rc::new(MqttSinkPool::default()), @@ -47,8 +51,7 @@ where impl MqttConnector where A: Address + Clone, - T: Service, Error = connect::ConnectError>, - T::Response: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service, Response = IoBoxed, Error = connect::ConnectError>, { #[inline] /// Create new client and provide client id @@ -186,13 +189,19 @@ where } /// Use custom connector - pub fn connector(self, connector: U) -> MqttConnector + pub fn connector( + self, + connector: U, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > where - U: Service, Error = connect::ConnectError>, - U::Response: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter, + U: Service, Response = Io, Error = connect::ConnectError>, { MqttConnector { - connector, + connector: connector.map(|io| io.into_boxed()), pkt: self.pkt, address: self.address, handshake_timeout: self.handshake_timeout, @@ -203,11 +212,17 @@ where #[cfg(feature = "openssl")] /// Use openssl connector - pub fn openssl(self, connector: SslConnector) -> MqttConnector> { + pub fn openssl( + self, + connector: SslConnector, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { MqttConnector { pkt: self.pkt, address: self.address, - connector: OpensslConnector::new(connector), + connector: OpensslConnector::new(connector).map(|io| io.into_boxed()), handshake_timeout: self.handshake_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, @@ -216,13 +231,19 @@ where #[cfg(feature = "rustls")] /// Use rustls connector - pub fn rustls(self, config: ClientConfig) -> MqttConnector> { + pub fn rustls( + self, + config: ClientConfig, + ) -> MqttConnector< + A, + impl Service, Response = IoBoxed, Error = connect::ConnectError>, + > { use std::sync::Arc; MqttConnector { pkt: self.pkt, address: self.address, - connector: RustlsConnector::new(Arc::new(config)), + connector: RustlsConnector::new(Arc::new(config)).map(|io| io.into_boxed()), handshake_timeout: self.handshake_timeout, disconnect_timeout: self.disconnect_timeout, pool: self.pool, @@ -230,7 +251,7 @@ where } /// Connect to mqtt server - pub fn connect(&self) -> impl Future, ClientError>> { + pub fn connect(&self) -> impl Future> { if self.handshake_timeout.non_zero() { let fut = timeout(self.handshake_timeout, self._connect()); Either::Left(async move { @@ -244,7 +265,7 @@ where } } - fn _connect(&self) -> impl Future, ClientError>> { + fn _connect(&self) -> impl Future> { let fut = self.connector.call(Connect::new(self.address.clone())); let pkt = self.pkt.clone(); let keep_alive = pkt.keep_alive; @@ -254,23 +275,21 @@ where let pool = self.pool.clone(); async move { - let mut io = fut.await?; - let state = State::with_memory_pool(pool.pool.get()); + let io = fut.await?; let codec = codec::Codec::new().max_inbound_size(max_packet_size); - state.send(&mut io, &codec, codec::Packet::Connect(Box::new(pkt))).await?; + io.send(codec::Packet::Connect(Box::new(pkt)), &codec).await?; - let packet = state - .next(&mut io, &codec) + let packet = io + .next(&codec) .await - .map_err(|e| ClientError::from(ProtocolError::from(e))) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Mqtt server is disconnected during handshake"); - ClientError::Disconnected - }) - })?; - let shared = Rc::new(MqttShared::new(state.clone(), codec, 0, pool)); + .ok_or_else(|| { + log::trace!("Mqtt server is disconnected during handshake"); + ClientError::Disconnected + })? + .map_err(|e| ClientError::from(ProtocolError::from(e)))?; + + let shared = Rc::new(MqttShared::new(io.get_ref(), codec, 0, pool)); match packet { codec::Packet::ConnectAck(pkt) => { diff --git a/src/v5/client/control.rs b/src/v5/client/control.rs index 6544089c..01b8dab3 100644 --- a/src/v5/client/control.rs +++ b/src/v5/client/control.rs @@ -1,3 +1,5 @@ +use std::io; + use ntex::util::ByteString; use crate::{error, v5::codec}; @@ -15,6 +17,8 @@ pub enum ControlMessage { ProtocolError(ProtocolError), /// Connection closed Closed(Closed), + /// Peer is gone + PeerGone(PeerGone), } impl ControlMessage { @@ -38,6 +42,10 @@ impl ControlMessage { ControlMessage::ProtocolError(ProtocolError::new(err)) } + pub(super) fn peer_gone(err: Option) -> Self { + ControlMessage::PeerGone(PeerGone(err)) + } + pub fn disconnect(&self, pkt: codec::Disconnect) -> ControlResult { ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } @@ -93,3 +101,18 @@ impl Publish { } } } + +#[derive(Debug)] +pub struct PeerGone(Option); + +impl PeerGone { + /// Returns error reference + pub fn error(&self) -> Option<&io::Error> { + self.0.as_ref() + } + + /// Ack PeerGone message + pub fn ack(self) -> ControlResult { + ControlResult { packet: None, disconnect: true } + } +} diff --git a/src/v5/client/dispatcher.rs b/src/v5/client/dispatcher.rs index fdcb675a..f1245076 100644 --- a/src/v5/client/dispatcher.rs +++ b/src/v5/client/dispatcher.rs @@ -2,13 +2,14 @@ use std::cell::RefCell; use std::task::{Context, Poll}; use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc}; +use ntex::io::DispatchItem; use ntex::service::Service; use ntex::util::{buffer::BufferService, inflight::InFlightService, Either, HashSet, Ready}; use crate::error::{MqttError, ProtocolError}; +use crate::types::packet_type; use crate::v5::shared::{Ack, MqttShared}; use crate::v5::{codec, publish::Publish, publish::PublishAck, sink::MqttSink}; -use crate::{io::DispatchItem, types::packet_type}; use super::control::{ControlMessage, ControlResult}; @@ -290,10 +291,9 @@ where &self.inner, ))) } - DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( - ControlMessage::proto_error(ProtocolError::Io(err)), - &self.inner, - ))), + DispatchItem::Disconnect(err) => Either::Right(Either::Right( + ControlResponse::new(ControlMessage::peer_gone(err), &self.inner), + )), DispatchItem::KeepAliveTimeout => { Either::Right(Either::Right(ControlResponse::new( ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), diff --git a/src/v5/control.rs b/src/v5/control.rs index fba22894..680da06d 100644 --- a/src/v5/control.rs +++ b/src/v5/control.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{io, marker::PhantomData}; use ntex::util::ByteString; @@ -24,6 +24,8 @@ pub enum ControlMessage { Error(Error), /// Protocol level error ProtocolError(ProtocolError), + /// Peer is gone + PeerGone(PeerGone), } /// Control message handling result @@ -72,6 +74,10 @@ impl ControlMessage { ControlMessage::Error(Error::new(err)) } + pub(super) fn peer_gone(err: Option) -> Self { + ControlMessage::PeerGone(PeerGone(err)) + } + pub(super) fn proto_error(err: error::ProtocolError) -> Self { ControlMessage::ProtocolError(ProtocolError::new(err)) } @@ -100,7 +106,7 @@ impl ControlMessage { pub struct Auth(codec::Auth); impl Auth { - /// Returns reference to dusconnect packet + /// Returns reference to auth packet pub fn packet(&self) -> &codec::Auth { &self.0 } @@ -123,7 +129,7 @@ impl Ping { pub struct Disconnect(pub(crate) codec::Disconnect); impl Disconnect { - /// Returns reference to dusconnect packet + /// Returns reference to disconnect packet pub fn packet(&self) -> &codec::Disconnect { &self.0 } @@ -618,3 +624,18 @@ impl ProtocolError { ) } } + +#[derive(Debug)] +pub struct PeerGone(Option); + +impl PeerGone { + /// Returns error reference + pub fn error(&self) -> Option<&io::Error> { + self.0.as_ref() + } + + /// Ack PeerGone message + pub fn ack(self) -> ControlResult { + ControlResult { packet: None, disconnect: true } + } +} diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index 3fe4baa8..a1005346 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -2,13 +2,13 @@ use std::cell::RefCell; use std::task::{Context, Poll}; use std::{convert::TryFrom, future::Future, marker, num, pin::Pin, rc::Rc}; +use ntex::io::DispatchItem; use ntex::service::{fn_factory_with_config, Service, ServiceFactory}; use ntex::util::{ buffer::BufferService, inflight::InFlightService, join, Either, HashSet, Ready, }; use crate::error::{MqttError, ProtocolError}; -use crate::io::DispatchItem; use super::control::{self, ControlMessage, ControlResult}; use super::publish::{Publish, PublishAck}; @@ -326,10 +326,9 @@ where &self.inner, ))) } - DispatchItem::IoError(err) => Either::Right(Either::Right(ControlResponse::new( - ControlMessage::proto_error(ProtocolError::Io(err)), - &self.inner, - ))), + DispatchItem::Disconnect(err) => Either::Right(Either::Right( + ControlResponse::new(ControlMessage::peer_gone(err), &self.inner), + )), DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => { Either::Right(Either::Left(Ready::Ok(None))) } diff --git a/src/v5/handshake.rs b/src/v5/handshake.rs index 7a68da43..3d393b57 100644 --- a/src/v5/handshake.rs +++ b/src/v5/handshake.rs @@ -1,10 +1,11 @@ +use ntex::io::IoBoxed; use std::{fmt, num::NonZeroU16, rc::Rc}; use super::{codec, shared::MqttShared, sink::MqttSink}; /// Handshake message -pub struct Handshake { - io: Io, +pub struct Handshake { + io: IoBoxed, pkt: Box, pub(super) shared: Rc, pub(super) max_size: u32, @@ -12,10 +13,10 @@ pub struct Handshake { pub(super) max_topic_alias: u16, } -impl Handshake { +impl Handshake { pub(crate) fn new( pkt: Box, - io: Io, + io: IoBoxed, shared: Rc, max_size: u32, max_receive: u16, @@ -35,8 +36,8 @@ impl Handshake { } #[inline] - pub fn io(&mut self) -> &mut Io { - &mut self.io + pub fn io(&self) -> &IoBoxed { + &self.io } #[inline] @@ -47,7 +48,7 @@ impl Handshake { #[inline] /// Ack handshake message and set state - pub fn ack(self, st: St) -> HandshakeAck { + pub fn ack(self, st: St) -> HandshakeAck { let mut packet = codec::ConnectAck { reason_code: codec::ConnectAckReason::Success, topic_alias_max: self.max_topic_alias, @@ -72,7 +73,7 @@ impl Handshake { #[inline] /// Create handshake ack object with error - pub fn failed(self, reason_code: codec::ConnectAckReason) -> HandshakeAck { + pub fn failed(self, reason_code: codec::ConnectAckReason) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -84,7 +85,7 @@ impl Handshake { #[inline] /// Create handshake ack object with provided ConnectAck packet - pub fn fail_with(self, ack: codec::ConnectAck) -> HandshakeAck { + pub fn fail_with(self, ack: codec::ConnectAck) -> HandshakeAck { HandshakeAck { io: self.io, shared: self.shared, @@ -95,22 +96,22 @@ impl Handshake { } } -impl fmt::Debug for Handshake { +impl fmt::Debug for Handshake { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pkt.fmt(f) } } /// Handshake ack message -pub struct HandshakeAck { - pub(crate) io: Io, +pub struct HandshakeAck { + pub(crate) io: IoBoxed, pub(crate) session: Option, pub(crate) shared: Rc, pub(crate) packet: codec::ConnectAck, pub(crate) keepalive: u16, } -impl HandshakeAck { +impl HandshakeAck { #[inline] /// Set idle keep-alive for the connection in seconds. /// This method sets `server_keepalive_sec` property for `ConnectAck` @@ -125,22 +126,6 @@ impl HandshakeAck { self } - #[doc(hidden)] - #[deprecated(since = "0.7.6", note = "Use memory pool config")] - #[inline] - /// Set read/write buffer sizes - /// - /// By default max buffer size is 4kb for both read and write buffer, - /// Min size is 256 bytes. - pub fn buffer_params( - self, - _max_read_buf: u16, - _max_write_buf: u16, - _min_buf_size: u16, - ) -> Self { - self - } - /// Access to ConnectAck packet #[inline] pub fn with(mut self, f: impl FnOnce(&mut codec::ConnectAck)) -> Self { diff --git a/src/v5/selector.rs b/src/v5/selector.rs index f084328b..73fe2cf9 100644 --- a/src/v5/selector.rs +++ b/src/v5/selector.rs @@ -3,13 +3,12 @@ use std::{ time, }; -use ntex::codec::{AsyncRead, AsyncWrite}; +use ntex::io::{DispatchItem, Io, IoBoxed}; use ntex::service::{apply_fn_factory, boxed, IntoServiceFactory, Service, ServiceFactory}; use ntex::time::{sleep, Seconds, Sleep}; use ntex::util::{timeout::Timeout, timeout::TimeoutError, Either, PoolId, Ready}; use crate::error::{MqttError, ProtocolError}; -use crate::io::{DispatchItem, State}; use super::control::{ControlMessage, ControlResult}; use super::default::{DefaultControlService, DefaultPublishService}; @@ -18,32 +17,26 @@ use super::publish::{Publish, PublishAck}; use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttServer, MqttSink, Session}; -pub(crate) type SelectItem = (Handshake, State, Option); +pub(crate) type SelectItem = (Handshake, Option); -type ServerFactory = boxed::BoxServiceFactory< - (), - SelectItem, - Either, ()>, - MqttError, - InitErr, ->; +type ServerFactory = + boxed::BoxServiceFactory<(), SelectItem, Either, MqttError, InitErr>; -type Server = - boxed::BoxService, Either, ()>, MqttError>; +type Server = boxed::BoxService, MqttError>; /// Mqtt server selector /// /// Selector allows to choose different mqtt server impls depends on /// connectt packet. -pub struct Selector { - servers: Vec>, +pub struct Selector { + servers: Vec>, max_size: u32, handshake_timeout: Seconds, pool: Rc, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(Err, InitErr)>, } -impl Selector { +impl Selector { #[allow(clippy::new_without_default)] pub fn new() -> Self { Selector { @@ -56,9 +49,8 @@ impl Selector { } } -impl Selector +impl Selector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { @@ -80,29 +72,20 @@ where self } - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P5 - /// memory pool is used. - pub fn memory_pool(self, id: PoolId) -> Self { - self.pool.pool.set(id.pool_ref()); - self - } - /// Add server variant pub fn variant( mut self, check: F, - mut server: MqttServer, + mut server: MqttServer, ) -> Self where - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future> + 'static, St: 'static, C: ServiceFactory< Config = (), - Request = Handshake, - Response = HandshakeAck, + Request = Handshake, + Response = HandshakeAck, Error = Err, InitError = InitErr, > + 'static, @@ -132,7 +115,7 @@ where self, ) -> impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = InitErr, @@ -146,18 +129,17 @@ where } } -impl ServiceFactory for Selector +impl ServiceFactory for Selector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { type Config = (); - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = MqttError; type InitError = InitErr; - type Service = SelectorService; + type Service = SelectorService; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -176,19 +158,18 @@ where } } -pub struct SelectorService { - servers: Rc>>, +pub struct SelectorService { + servers: Rc>>, max_size: u32, handshake_timeout: Seconds, pool: Rc, } -impl Service for SelectorService +impl Service for SelectorService where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, { - type Request = Io; + type Request = IoBoxed; type Response = (); type Error = MqttError; type Future = Pin>>>>; @@ -220,11 +201,10 @@ where } #[inline] - fn call(&self, mut io: Io) -> Self::Future { + fn call(&self, io: IoBoxed) -> Self::Future { let servers = self.servers.clone(); - let state = State::with_memory_pool(self.pool.pool.get()); let shared = Rc::new(MqttShared::new( - state.clone(), + io.get_ref(), mqtt::Codec::default().max_inbound_size(self.max_size), 0, self.pool.clone(), @@ -233,18 +213,16 @@ where let delay = self.handshake_timeout.map(sleep); Box::pin(async move { // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; let connect = match packet { @@ -259,7 +237,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, io, shared, 0, 0, 0), state, delay); + let mut item = (Handshake::new(connect, io, shared, 0, 0, 0), delay); for srv in servers.iter() { match srv.call(item).await? { Either::Left(result) => { @@ -274,25 +252,24 @@ where } } -pub(crate) struct Selector2 { - servers: Vec>, +pub(crate) struct Selector2 { + servers: Vec>, max_size: u32, pool: Rc, - _t: marker::PhantomData<(Io, Err, InitErr)>, + _t: marker::PhantomData<(Err, InitErr)>, } -impl ServiceFactory for Selector2 +impl ServiceFactory for Selector2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, InitErr: 'static, { type Config = (); - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; type InitError = InitErr; - type Service = SelectorService2; + type Service = SelectorService2; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -310,18 +287,17 @@ where } } -pub(crate) struct SelectorService2 { - servers: Rc>>, +pub(crate) struct SelectorService2 { + servers: Rc>>, max_size: u32, pool: Rc, } -impl Service for SelectorService2 +impl Service for SelectorService2 where - Io: AsyncRead + AsyncWrite + Unpin + 'static, Err: 'static, { - type Request = (Io, State, Option); + type Request = (IoBoxed, Option); type Response = (); type Error = MqttError; type Future = Pin>>>>; @@ -353,10 +329,10 @@ where } #[inline] - fn call(&self, (mut io, state, delay): Self::Request) -> Self::Future { + fn call(&self, (io, delay): Self::Request) -> Self::Future { let servers = self.servers.clone(); let shared = Rc::new(MqttShared::new( - state.clone(), + io.get_ref(), mqtt::Codec::default().max_inbound_size(self.max_size), 0, self.pool.clone(), @@ -364,18 +340,16 @@ where Box::pin(async move { // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { - log::trace!("Error is received during mqtt handshake: {:?}", err); + // log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; let connect = match packet { @@ -390,7 +364,7 @@ where }; // call servers - let mut item = (Handshake::new(connect, io, shared, 0, 0, 0), state, delay); + let mut item = (Handshake::new(connect, io, shared, 0, 0, 0), delay); for srv in servers.iter() { match srv.call(item).await? { Either::Left(result) => { diff --git a/src/v5/server.rs b/src/v5/server.rs index 9de2405f..c22b53b6 100644 --- a/src/v5/server.rs +++ b/src/v5/server.rs @@ -1,15 +1,14 @@ use std::task::{Context, Poll}; use std::{cell::RefCell, convert::TryFrom, fmt, future::Future, marker, pin::Pin, rc::Rc}; -use ntex::codec::{AsyncRead, AsyncWrite}; -use ntex::framed::WriteTask; +use ntex::io::{into_boxed, DispatchItem, Filter, Io, IoBoxed, Timer}; use ntex::service::{IntoServiceFactory, Service, ServiceFactory}; use ntex::time::{Millis, Seconds, Sleep}; use ntex::util::timeout::{Timeout, TimeoutError}; use ntex::util::{Either, PoolId, PoolRef}; use crate::error::{MqttError, ProtocolError}; -use crate::io::{DispatchItem, Dispatcher, State, Timer}; +use crate::io::Dispatcher; use crate::service::{FramedService, FramedService2}; use crate::types::QoS; @@ -22,7 +21,7 @@ use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttSink, Session}; /// Mqtt Server -pub struct MqttServer { +pub struct MqttServer { handshake: C, srv_control: Cn, srv_publish: P, @@ -33,21 +32,14 @@ pub struct MqttServer, - _t: marker::PhantomData<(Io, St)>, + _t: marker::PhantomData, } -impl - MqttServer< - Io, - St, - C, - DefaultControlService, - DefaultPublishService, - > +impl + MqttServer, DefaultPublishService> where St: 'static, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: fmt::Debug, { /// Create server factory and provide handshake service @@ -71,12 +63,10 @@ where } } -impl MqttServer +impl MqttServer where - Io: AsyncRead + AsyncWrite + Unpin + 'static, St: 'static, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: fmt::Debug, Cn: ServiceFactory< Config = Session, @@ -141,20 +131,11 @@ where self } - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P5 - /// memory pool is used. - pub fn memory_pool(self, id: PoolId) -> Self { - self.pool.pool.set(id.pool_ref()); - self - } - /// Service to handle control packets /// /// All control packets are processed sequentially, max number of buffered /// control packets is 16. - pub fn control(self, service: F) -> MqttServer + pub fn control(self, service: F) -> MqttServer where F: IntoServiceFactory, Srv: ServiceFactory< @@ -180,7 +161,7 @@ where } /// Set service to handle publish packets and create mqtt server factory - pub fn publish(self, publish: F) -> MqttServer + pub fn publish(self, publish: F) -> MqttServer where F: IntoServiceFactory + 'static, C::Error: From + From, @@ -205,12 +186,10 @@ where } } -impl MqttServer +impl MqttServer where - Io: AsyncRead + AsyncWrite + Unpin + 'static, St: 'static, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: From + From + From @@ -226,11 +205,10 @@ where PublishAck: TryFrom, { /// Set service to handle publish packets and create mqtt server factory - pub fn finish( + pub fn finish( self, - ) -> impl ServiceFactory> + ) -> impl ServiceFactory, Response = (), Error = MqttError> { - let pool = self.pool.pool.get().pool(); let handshake = self.handshake; let publish = self.srv_publish.map_init_err(|e| MqttError::Service(e.into())); let control = self @@ -238,7 +216,7 @@ where .map_err(::from) .map_init_err(|e| MqttError::Service(e.into())); - FramedService::new( + into_boxed(FramedService::new( handshake_service_factory( handshake, self.max_size, @@ -249,9 +227,8 @@ where self.pool, ), factory(publish, control), - pool, self.disconnect_timeout, - ) + )) } /// Set service to handle publish packets and create mqtt server factory @@ -259,12 +236,11 @@ where self, ) -> impl ServiceFactory< Config = (), - Request = (Io, State, Option), + Request = (IoBoxed, Option), Response = (), Error = MqttError, InitError = C::InitError, > { - let pool = self.pool.pool.get().pool(); let handshake = self.handshake; let publish = self.srv_publish.map_init_err(|e| MqttError::Service(e.into())); let control = self @@ -273,7 +249,7 @@ where .map_init_err(|e| MqttError::Service(e.into())); FramedService2::new( - handshake_service_factory2( + handshake_service_factory( handshake, self.max_size, self.max_receive, @@ -283,7 +259,6 @@ where self.pool, ), factory(publish, control), - pool, self.disconnect_timeout, ) } @@ -294,13 +269,13 @@ where check: F, ) -> impl ServiceFactory< Config = (), - Request = SelectItem, - Response = Either, ()>, + Request = SelectItem, + Response = Either, Error = MqttError, InitError = C::InitError, > where - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future> + 'static, { let publish = self.srv_publish.map_init_err(|e| MqttError::Service(e.into())); @@ -309,7 +284,7 @@ where .map_err(::from) .map_init_err(|e| MqttError::Service(e.into())); - ServerSelector:: { + ServerSelector:: { check: Rc::new(check), connect: self.handshake, handler: Rc::new(factory(publish, control)), @@ -324,60 +299,7 @@ where } } -fn handshake_service_factory( - factory: C, - max_size: u32, - max_receive: u16, - max_topic_alias: u16, - max_qos: Option, - handshake_timeout: Seconds, - pool: Rc, -) -> impl ServiceFactory< - Config = (), - Request = Io, - Response = (Io, State, Rc, Session, Seconds), - Error = MqttError, -> -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: ServiceFactory, Response = HandshakeAck>, - C::Error: fmt::Debug, -{ - ntex::service::apply( - Timeout::new(Millis::from(handshake_timeout)), - ntex::service::fn_factory(move || { - let pool = pool.clone(); - - let fut = factory.new_service(()); - async move { - let service = fut.await?; - let pool = pool.clone(); - let service = Rc::new(service.map_err(MqttError::Service)); - Ok::<_, C::InitError>(ntex::service::apply_fn( - service, - move |io: Io, service| { - handshake( - io, - None, - service.clone(), - max_size, - max_receive, - max_topic_alias, - max_qos, - pool.clone(), - ) - }, - )) - } - }), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => MqttError::HandshakeTimeout, - }) -} - -fn handshake_service_factory2( +fn handshake_service_factory( factory: C, max_size: u32, max_receive: u16, @@ -387,40 +309,36 @@ fn handshake_service_factory2( pool: Rc, ) -> impl ServiceFactory< Config = (), - Request = (Io, State), - Response = (Io, State, Rc, Session, Seconds), + Request = IoBoxed, + Response = (IoBoxed, Rc, Session, Seconds), Error = MqttError, InitError = C::InitError, > where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - C: ServiceFactory, Response = HandshakeAck>, + C: ServiceFactory>, C::Error: fmt::Debug, { ntex::service::apply( Timeout::new(Millis::from(handshake_timeout)), ntex::service::fn_factory(move || { let pool = pool.clone(); + let fut = factory.new_service(()); async move { let service = fut.await?; let pool = pool.clone(); let service = Rc::new(service.map_err(MqttError::Service)); - Ok::<_, C::InitError>(ntex::service::apply_fn( - service, - move |(io, state), service| { - handshake( - io, - Some(state), - service.clone(), - max_size, - max_receive, - max_topic_alias, - max_qos, - pool.clone(), - ) - }, - )) + Ok::<_, C::InitError>(ntex::service::apply_fn(service, move |io, service| { + handshake( + io, + service.clone(), + max_size, + max_receive, + max_topic_alias, + max_qos, + pool.clone(), + ) + })) } }), ) @@ -431,41 +349,36 @@ where } #[allow(clippy::too_many_arguments)] -async fn handshake( - mut io: Io, - state: Option, +async fn handshake( + io: IoBoxed, service: S, max_size: u32, mut max_receive: u16, mut max_topic_alias: u16, max_qos: Option, pool: Rc, -) -> Result<(Io, State, Rc, Session, Seconds), S::Error> +) -> Result<(IoBoxed, Rc, Session, Seconds), S::Error> where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - S: Service, Response = HandshakeAck, Error = MqttError>, + S: Service, Error = MqttError>, { log::trace!("Starting mqtt v5 handshake"); - let state = state.unwrap_or_else(|| State::with_memory_pool(pool.pool.get())); - let shared = Rc::new(MqttShared::new(state.clone(), mqtt::Codec::default(), 0, pool)); + let shared = Rc::new(MqttShared::new(io.get_ref(), mqtt::Codec::default(), 0, pool)); // set max inbound (decoder) packet size shared.codec.set_max_inbound_size(max_size); // read first packet - let packet = state - .next(&mut io, &shared.codec) + let packet = io + .next(&shared.codec) .await + .ok_or_else(|| { + log::trace!("Server mqtt is disconnected during handshake"); + MqttError::Disconnected + })? .map_err(|err| { log::trace!("Error is received during mqtt handshake: {:?}", err); MqttError::from(err) - }) - .and_then(|res| { - res.ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Disconnected - }) })?; match packet { @@ -515,17 +428,12 @@ where ack.packet.server_keepalive_sec = Some(ack.keepalive as u16); } - state - .send( - &mut ack.io, - &shared.codec, - mqtt::Packet::ConnectAck(Box::new(ack.packet)), - ) + ack.io + .send(mqtt::Packet::ConnectAck(Box::new(ack.packet)), &shared.codec) .await?; Ok(( ack.io, - shared.state.clone(), shared.clone(), Session::new_v5( session, @@ -539,22 +447,16 @@ where None => { log::trace!("Failed to complete handshake: {:#?}", ack.packet); - if ack.shared.state.is_open() + if !ack.io.is_closed() && ack - .shared - .state - .write() + .io .encode( mqtt::Packet::ConnectAck(Box::new(ack.packet)), &ack.shared.codec, ) .is_ok() { - WriteTask::shutdown( - Rc::new(RefCell::new(ack.io)), - ack.shared.state.clone(), - ) - .await; + let _ = ack.io.shutdown().await; } Err(MqttError::Disconnected) } @@ -570,7 +472,7 @@ where } } -pub(crate) struct ServerSelector { +pub(crate) struct ServerSelector { connect: C, handler: Rc, time: Timer, @@ -580,16 +482,14 @@ pub(crate) struct ServerSelector { max_qos: Option, disconnect_timeout: Seconds, max_topic_alias: u16, - _t: marker::PhantomData<(St, Io, R)>, + _t: marker::PhantomData<(St, R)>, } -impl ServiceFactory for ServerSelector +impl ServiceFactory for ServerSelector where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future>, - C: ServiceFactory, Response = HandshakeAck> - + 'static, + C: ServiceFactory> + 'static, C::Error: fmt::Debug, T: ServiceFactory< Config = Session, @@ -600,11 +500,11 @@ where > + 'static, { type Config = (); - type Request = SelectItem; - type Response = Either, ()>; + type Request = SelectItem; + type Response = Either; type Error = MqttError; type InitError = C::InitError; - type Service = ServerSelectorImpl; + type Service = ServerSelectorImpl; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { @@ -636,7 +536,7 @@ where } } -pub(crate) struct ServerSelectorImpl { +pub(crate) struct ServerSelectorImpl { check: Rc, connect: Rc, handler: Rc, @@ -646,15 +546,14 @@ pub(crate) struct ServerSelectorImpl { disconnect_timeout: Seconds, max_topic_alias: u16, time: Timer, - _t: marker::PhantomData<(St, Io, R)>, + _t: marker::PhantomData<(St, R)>, } -impl Service for ServerSelectorImpl +impl Service for ServerSelectorImpl where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - F: Fn(&Handshake) -> R + 'static, + F: Fn(&Handshake) -> R + 'static, R: Future>, - C: Service, Response = HandshakeAck> + 'static, + C: Service> + 'static, C::Error: fmt::Debug, T: ServiceFactory< Config = Session, @@ -664,8 +563,8 @@ where InitError = MqttError, > + 'static, { - type Request = SelectItem; - type Response = Either, ()>; + type Request = SelectItem; + type Response = Either; type Error = MqttError; type Future = Pin>>>; @@ -694,7 +593,7 @@ where let mut max_topic_alias = self.max_topic_alias; Box::pin(async move { - let (mut hnd, state, mut delay) = req; + let (mut hnd, mut delay) = req; let result = if let Some(ref mut delay) = delay { let fut = (&*check)(&hnd); @@ -707,7 +606,7 @@ where }; if !result.map_err(MqttError::Service)? { - Ok(Either::Left((hnd, state, delay))) + Ok(Either::Left((hnd, delay))) } else { // set max outbound (encoder) packet size if let Some(size) = hnd.packet().max_packet_size { @@ -764,12 +663,8 @@ where ack.packet.server_keepalive_sec = Some(ack.keepalive as u16); } - state - .send( - &mut ack.io, - &shared.codec, - mqtt::Packet::ConnectAck(Box::new(ack.packet)), - ) + ack.io + .send(mqtt::Packet::ConnectAck(Box::new(ack.packet)), &shared.codec) .await?; let session = Session::new_v5( @@ -781,7 +676,7 @@ where let handler = handler.new_service(session).await?; log::trace!("Connection handler is created, starting dispatcher"); - Dispatcher::with(ack.io, shared.state.clone(), shared, handler, time) + Dispatcher::new(ack.io, shared, handler, time) .keepalive_timeout(Seconds(ack.keepalive)) .disconnect_timeout(timeout) .await?; @@ -790,22 +685,16 @@ where None => { log::trace!("Failed to complete handshake: {:#?}", ack.packet); - if ack.shared.state.is_open() + if !ack.io.is_closed() && ack - .shared - .state - .write() + .io .encode( mqtt::Packet::ConnectAck(Box::new(ack.packet)), &ack.shared.codec, ) .is_ok() { - WriteTask::shutdown( - Rc::new(RefCell::new(ack.io)), - ack.shared.state.clone(), - ) - .await; + let _ = ack.io.shutdown().await; } Err(MqttError::Disconnected) } diff --git a/src/v5/shared.rs b/src/v5/shared.rs index 9039ef63..3da8745c 100644 --- a/src/v5/shared.rs +++ b/src/v5/shared.rs @@ -2,17 +2,18 @@ use std::{cell::Cell, cell::RefCell, collections::VecDeque, rc::Rc}; use ntex::channel::pool; use ntex::codec::{Decoder, Encoder}; +use ntex::io::IoRef; use ntex::util::{BytesMut, HashMap, PoolId, PoolRef}; use super::codec; -use crate::{error, io::State, types::packet_type}; +use crate::{error, types::packet_type}; pub(crate) struct MqttShared { + pub(super) io: IoRef, pub(super) cap: Cell, queues: RefCell, pub(super) inflight_idx: Cell, pub(super) pool: Rc, - pub(super) state: State, pub(super) codec: codec::Codec, } @@ -40,13 +41,13 @@ impl Default for MqttSinkPool { impl MqttShared { pub(super) fn new( - state: State, + io: IoRef, codec: codec::Codec, cap: usize, pool: Rc, ) -> Self { Self { - state, + io, pool, codec, cap: Cell::new(cap), diff --git a/src/v5/sink.rs b/src/v5/sink.rs index fc590ffa..2c291f84 100644 --- a/src/v5/sink.rs +++ b/src/v5/sink.rs @@ -23,7 +23,7 @@ impl MqttSink { /// Check connection status pub fn is_open(&self) -> bool { - self.0.state.is_open() + !self.0.io.is_closed() } /// Get client's receive credit @@ -36,7 +36,7 @@ impl MqttSink { /// /// Result indicates if connection is alive pub fn ready(&self) -> impl Future { - if self.0.state.is_open() { + if !self.0.io.is_closed() { self.0 .with_queues(|q| { if q.inflight.len() >= self.0.cap.get() { @@ -58,10 +58,9 @@ impl MqttSink { if self.is_open() { let _ = self .0 - .state - .write() + .io .encode(codec::Packet::Disconnect(codec::Disconnect::default()), &self.0.codec); - self.0.state.close(); + self.0.io.close(); } self.0.with_queues(|q| { q.inflight.clear(); @@ -72,8 +71,8 @@ impl MqttSink { /// Close mqtt connection pub fn close_with_reason(&self, pkt: codec::Disconnect) { if self.is_open() { - let _ = self.0.state.write().encode(codec::Packet::Disconnect(pkt), &self.0.codec); - self.0.state.close(); + let _ = self.0.io.encode(codec::Packet::Disconnect(pkt), &self.0.codec); + self.0.io.close(); } self.0.with_queues(|q| { q.inflight.clear(); @@ -82,12 +81,12 @@ impl MqttSink { } pub(super) fn send(&self, pkt: codec::Packet) { - let _ = self.0.state.write().encode(pkt, &self.0.codec); + let _ = self.0.io.encode(pkt, &self.0.codec); } /// Send ping pub(super) fn ping(&self) -> bool { - self.0.state.write().encode(codec::Packet::PingRequest, &self.0.codec).is_ok() + self.0.io.encode(codec::Packet::PingRequest, &self.0.codec).is_ok() } /// Close mqtt connection, dont send disconnect message @@ -96,7 +95,7 @@ impl MqttSink { q.waiters.clear(); q.inflight.clear(); }); - self.0.state.close(); + self.0.io.close(); } pub(super) fn pkt_ack(&self, pkt: Ack) -> Result<(), ProtocolError> { @@ -252,11 +251,10 @@ impl PublishBuilder { pub fn send_at_most_once(self) -> Result<(), SendPacketError> { let packet = self.packet; - if self.shared.state.is_open() { + if !self.shared.io.is_closed() { log::trace!("Publish (QoS-0) to {:?}", packet.topic); self.shared - .state - .write() + .io .encode(codec::Packet::Publish(packet), &self.shared.codec) .map_err(SendPacketError::Encode) .map(|_| ()) @@ -274,7 +272,7 @@ impl PublishBuilder { let mut packet = self.packet; packet.qos = QoS::AtLeastOnce; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -324,7 +322,7 @@ impl PublishBuilder { // send publish to client log::trace!("Publish (QoS1) to {:#?}", packet); - match shared.state.write().encode(codec::Packet::Publish(packet), &shared.codec) { + match shared.io.encode(codec::Packet::Publish(packet), &shared.codec) { Ok(_) => { // wait ack from peer Either::Right(async move { @@ -383,7 +381,7 @@ impl SubscribeBuilder { let shared = self.shared; let mut packet = self.packet; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -411,7 +409,7 @@ impl SubscribeBuilder { // send subscribe to client log::trace!("Sending subscribe packet {:#?}", packet); - match shared.state.write().encode(codec::Packet::Subscribe(packet), &shared.codec) { + match shared.io.encode(codec::Packet::Subscribe(packet), &shared.codec) { Ok(_) => { // wait ack from peer rx.await @@ -463,7 +461,7 @@ impl UnsubscribeBuilder { let shared = self.shared; let mut packet = self.packet; - if shared.state.is_open() { + if !shared.io.is_closed() { // handle client receive maximum if !shared.has_credit() { let (tx, rx) = shared.pool.waiters.channel(); @@ -491,8 +489,7 @@ impl UnsubscribeBuilder { // send unsubscribe to client log::trace!("Sending unsubscribe packet {:#?}", packet); - match shared.state.write().encode(codec::Packet::Unsubscribe(packet), &shared.codec) - { + match shared.io.encode(codec::Packet::Unsubscribe(packet), &shared.codec) { Ok(_) => { // wait ack from peer rx.await diff --git a/tests/test_server.rs b/tests/test_server.rs index 5e5961be..89efaf74 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,11 +1,10 @@ use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc}; use std::{num::NonZeroU16, time::Duration}; -use futures::{future::ok, FutureExt, SinkExt, StreamExt}; -use ntex::codec::Framed; +use futures::FutureExt; use ntex::server; -use ntex::time::{sleep, Seconds}; -use ntex::util::{poll_fn, ByteString, Bytes}; +use ntex::time::{sleep, Millis, Seconds}; +use ntex::util::{ByteString, Bytes, Ready}; use ntex_mqtt::v3::{ client, codec, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, Session, @@ -13,7 +12,7 @@ use ntex_mqtt::v3::{ struct St; -async fn handshake(mut packet: Handshake) -> Result, ()> { +async fn handshake(mut packet: Handshake) -> Result, ()> { packet.packet(); packet.packet_mut(); packet.io(); @@ -23,7 +22,8 @@ async fn handshake(mut packet: Handshake) -> Result #[ntex::test] async fn test_simple() -> std::io::Result<()> { - let srv = server::test_server(|| MqttServer::new(handshake).publish(|_t| ok(())).finish()); + let srv = + server::test_server(|| MqttServer::new(handshake).publish(|_t| Ready::Ok(())).finish()); // connect to server let client = @@ -45,8 +45,8 @@ async fn test_simple() -> std::io::Result<()> { async fn test_connect_fail() -> std::io::Result<()> { // bad user name or password let srv = server::test_server(|| { - MqttServer::new(|conn: Handshake<_>| ok::<_, ()>(conn.bad_username_or_pwd::())) - .publish(|_t| ok(())) + MqttServer::new(|conn: Handshake| Ready::Ok::<_, ()>(conn.bad_username_or_pwd::())) + .publish(|_t| Ready::Ok(())) .finish() }); let err = @@ -58,8 +58,8 @@ async fn test_connect_fail() -> std::io::Result<()> { // identifier rejected let srv = server::test_server(|| { - MqttServer::new(|conn: Handshake<_>| ok::<_, ()>(conn.identifier_rejected::())) - .publish(|_t| ok(())) + MqttServer::new(|conn: Handshake| Ready::Ok::<_, ()>(conn.identifier_rejected::())) + .publish(|_t| Ready::Ok(())) .finish() }); let err = @@ -71,8 +71,8 @@ async fn test_connect_fail() -> std::io::Result<()> { // not authorized let srv = server::test_server(|| { - MqttServer::new(|conn: Handshake<_>| ok::<_, ()>(conn.not_authorized::())) - .publish(|_t| ok(())) + MqttServer::new(|conn: Handshake| Ready::Ok::<_, ()>(conn.not_authorized::())) + .publish(|_t| Ready::Ok(())) .finish() }); let err = @@ -84,8 +84,8 @@ async fn test_connect_fail() -> std::io::Result<()> { // service unavailable let srv = server::test_server(|| { - MqttServer::new(|conn: Handshake<_>| ok::<_, ()>(conn.service_unavailable::())) - .publish(|_t| ok(())) + MqttServer::new(|conn: Handshake| Ready::Ok::<_, ()>(conn.service_unavailable::())) + .publish(|_t| Ready::Ok(())) .finish() }); let err = @@ -106,30 +106,29 @@ async fn test_ping() -> std::io::Result<()> { let srv = server::test_server(move || { let ping = ping2.clone(); MqttServer::new(handshake) - .publish(|_| ok(())) + .publish(|_| Ready::Ok(())) .control(move |msg| { let ping = ping.clone(); match msg { ControlMessage::Ping(msg) => { ping.store(true, Relaxed); - ok(msg.ack()) + Ready::Ok(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), } }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed - .send(codec::Packet::Connect(codec::Connect::default().client_id("user").into())) + let codec = codec::Codec::default(); + io.send(codec::Packet::Connect(codec::Connect::default().client_id("user").into()), &codec) .await .unwrap(); - framed.next().await.unwrap().unwrap(); + io.next(&codec).await.unwrap().unwrap(); - framed.send(codec::Packet::PingRequest).await.unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + io.send(codec::Packet::PingRequest, &codec).await.unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!(pkt, codec::Packet::PingResponse); assert!(ping.load(Relaxed)); @@ -148,58 +147,60 @@ async fn test_ack_order() -> std::io::Result<()> { sub.topic(); sub.subscribe(codec::QoS::AtLeastOnce); } - ok(msg.ack()) + Ready::Ok(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed.send(codec::Connect::default().client_id("user").into()).await.unwrap(); - let _ = framed.next().await.unwrap().unwrap(); - - framed - .send( - codec::Publish { - dup: false, - retain: false, - qos: codec::QoS::AtLeastOnce, - topic: ByteString::from("test"), - packet_id: Some(NonZeroU16::new(1).unwrap()), - payload: Bytes::new(), - } - .into(), - ) - .await - .unwrap(); - framed - .send(codec::Packet::Subscribe { + let codec = codec::Codec::default(); + io.send(codec::Connect::default().client_id("user").into(), &codec).await.unwrap(); + let _ = io.next(&codec).await.unwrap().unwrap(); + + io.send( + codec::Publish { + dup: false, + retain: false, + qos: codec::QoS::AtLeastOnce, + topic: ByteString::from("test"), + packet_id: Some(NonZeroU16::new(1).unwrap()), + payload: Bytes::new(), + } + .into(), + &codec, + ) + .await + .unwrap(); + io.send( + codec::Packet::Subscribe { packet_id: NonZeroU16::new(2).unwrap(), topic_filters: vec![(ByteString::from("topic1"), codec::QoS::AtLeastOnce)], - }) - .await - .unwrap(); - framed - .send( - codec::Publish { - dup: false, - retain: false, - qos: codec::QoS::AtLeastOnce, - topic: ByteString::from("test"), - packet_id: Some(NonZeroU16::new(3).unwrap()), - payload: Bytes::new(), - } - .into(), - ) - .await - .unwrap(); + }, + &codec, + ) + .await + .unwrap(); + io.send( + codec::Publish { + dup: false, + retain: false, + qos: codec::QoS::AtLeastOnce, + topic: ByteString::from("test"), + packet_id: Some(NonZeroU16::new(3).unwrap()), + payload: Bytes::new(), + } + .into(), + &codec, + ) + .await + .unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!(pkt, codec::Packet::PublishAck { packet_id: NonZeroU16::new(1).unwrap() }); - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::SubscribeAck { @@ -208,7 +209,7 @@ async fn test_ack_order() -> std::io::Result<()> { } ); - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!(pkt, codec::Packet::PublishAck { packet_id: NonZeroU16::new(3).unwrap() }); Ok(()) @@ -247,7 +248,7 @@ async fn test_disconnect() -> std::io::Result<()> { let srv = server::test_server(|| { MqttServer::new(handshake) .publish(ntex::service::fn_factory_with_config(|session: Session| { - ok(ntex::service::fn_service(move |_: Publish| { + Ready::Ok(ntex::service::fn_service(move |_: Publish| { session.sink().force_close(); async { sleep(Duration::from_millis(100)).await; @@ -294,33 +295,34 @@ async fn test_handle_incoming() -> std::io::Result<()> { .control(move |msg| match msg { ControlMessage::Disconnect(msg) => { disconnect.store(true, Relaxed); - ok(msg.ack()) + Ready::Ok(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed.write(codec::Connect::default().client_id("user").into()).unwrap(); - framed - .write( - codec::Publish { - dup: false, - retain: false, - qos: codec::QoS::AtLeastOnce, - topic: ByteString::from("test"), - packet_id: Some(NonZeroU16::new(3).unwrap()), - payload: Bytes::new(), - } - .into(), - ) - .unwrap(); - framed.write(codec::Packet::Disconnect).unwrap(); - poll_fn(|cx| framed.flush(cx)).await.unwrap(); - drop(framed); - sleep(Duration::from_millis(500)).await; + let codec = codec::Codec::default(); + io.encode(codec::Connect::default().client_id("user").into(), &codec).unwrap(); + io.encode( + codec::Publish { + dup: false, + retain: false, + qos: codec::QoS::AtLeastOnce, + topic: ByteString::from("test"), + packet_id: Some(NonZeroU16::new(3).unwrap()), + payload: Bytes::new(), + } + .into(), + &codec, + ) + .unwrap(); + io.encode(codec::Packet::Disconnect, &codec).unwrap(); + io.write_ready(true).await.unwrap(); + sleep(Millis(50)).await; + drop(io); + sleep(Millis(50)).await; assert!(publish.load(Relaxed)); assert!(disconnect.load(Relaxed)); diff --git a/tests/test_server_both.rs b/tests/test_server_both.rs index 982b9654..bca995ce 100644 --- a/tests/test_server_both.rs +++ b/tests/test_server_both.rs @@ -1,8 +1,7 @@ use std::convert::TryFrom; -use futures::future::ok; use ntex::server; -use ntex::util::{ByteString, Bytes}; +use ntex::util::{ByteString, Bytes, Ready}; use ntex_mqtt::{v3, v5, MqttServer}; @@ -29,12 +28,14 @@ impl TryFrom for v5::PublishAck { async fn test_simple() -> std::io::Result<()> { let srv = server::test_server(|| { MqttServer::new() - .v3(v3::MqttServer::new(|con: v3::Handshake<_>| { - ok::<_, TestError>(con.ack(St, false)) + .v3(v3::MqttServer::new(|con: v3::Handshake| { + Ready::Ok::<_, TestError>(con.ack(St, false)) }) - .publish(|_| ok::<_, TestError>(()))) - .v5(v5::MqttServer::new(|con: v5::Handshake<_>| ok::<_, TestError>(con.ack(St))) - .publish(|p: v5::Publish| ok::<_, TestError>(p.ack()))) + .publish(|_| Ready::Ok::<_, TestError>(()))) + .v5(v5::MqttServer::new(|con: v5::Handshake| { + Ready::Ok::<_, TestError>(con.ack(St)) + }) + .publish(|p: v5::Publish| Ready::Ok::<_, TestError>(p.ack()))) }); // connect to v5 server diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 41496fcc..b8bd6454 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -1,11 +1,10 @@ use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc}; use std::{convert::TryFrom, num::NonZeroU16, time::Duration}; -use futures::{future::ok, FutureExt, SinkExt, StreamExt}; -use ntex::codec::Framed; +use futures::FutureExt; use ntex::server; use ntex::time::sleep; -use ntex::util::{poll_fn, ByteString, Bytes}; +use ntex::util::{ByteString, Bytes, Ready}; use ntex_mqtt::v5::{ client, codec, error, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, @@ -43,14 +42,16 @@ fn pkt_publish() -> codec::Publish { } } -async fn handshake(packet: Handshake) -> Result, TestError> { +async fn handshake(packet: Handshake) -> Result, TestError> { Ok(packet.ack(St)) } #[ntex::test] async fn test_simple() -> std::io::Result<()> { let srv = server::test_server(|| { - MqttServer::new(handshake).publish(|p: Publish| ok::<_, TestError>(p.ack())).finish() + MqttServer::new(handshake) + .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) + .finish() }); // connect to server @@ -74,7 +75,7 @@ async fn test_disconnect() -> std::io::Result<()> { let srv = server::test_server(|| { MqttServer::new(handshake) .publish(ntex::service::fn_factory_with_config(|session: Session| { - ok::<_, TestError>(ntex::service::fn_service(move |p: Publish| { + Ready::Ok::<_, TestError>(ntex::service::fn_service(move |p: Publish| { session.sink().close(); async move { sleep(Duration::from_millis(100)).await; @@ -105,7 +106,7 @@ async fn test_disconnect_with_reason() -> std::io::Result<()> { let srv = server::test_server(|| { MqttServer::new(handshake) .publish(ntex::service::fn_factory_with_config(|session: Session| { - ok::<_, TestError>(ntex::service::fn_service(move |p: Publish| { + Ready::Ok::<_, TestError>(ntex::service::fn_service(move |p: Publish| { let pkt = codec::Disconnect { reason_code: codec::DisconnectReasonCode::ServerMoved, ..Default::default() @@ -143,30 +144,32 @@ async fn test_ping() -> std::io::Result<()> { let srv = server::test_server(move || { let ping = ping2.clone(); MqttServer::new(handshake) - .publish(|p: Publish| ok::<_, TestError>(p.ack())) + .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) .control(move |msg| { let ping = ping.clone(); match msg { ControlMessage::Ping(msg) => { ping.store(true, Relaxed); - ok::<_, TestError>(msg.ack()) + Ready::Ok::<_, TestError>(msg.ack()) } - _ => ok(msg.disconnect_with(codec::Disconnect::default())), + _ => Ready::Ok(msg.disconnect_with(codec::Disconnect::default())), } }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::new()); - framed - .send(codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user")))) - .await - .unwrap(); - let _ = framed.next().await.unwrap().unwrap(); - - framed.send(codec::Packet::PingRequest).await.unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + let codec = codec::Codec::new(); + io.send( + codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), + &codec, + ) + .await + .unwrap(); + let _ = io.next(&codec).await.unwrap().unwrap(); + + io.send(codec::Packet::PingRequest, &codec).await.unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!(pkt, codec::Packet::PingResponse); assert!(ping.load(Relaxed)); @@ -187,50 +190,51 @@ async fn test_ack_order() -> std::io::Result<()> { sub.options(); sub.subscribe(codec::QoS::AtLeastOnce); } - ok::<_, TestError>(msg.ack()) + Ready::Ok::<_, TestError>(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed - .send(codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user")))) - .await - .unwrap(); - let _ = framed.next().await.unwrap().unwrap(); - - framed - .send( - codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() } - .into(), - ) - .await - .unwrap(); - framed - .send( - codec::Subscribe { - id: None, - packet_id: NonZeroU16::new(2).unwrap(), - user_properties: Default::default(), - topic_filters: vec![( - ByteString::from("topic1"), - codec::SubscriptionOptions { - qos: codec::QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: codec::RetainHandling::AtSubscribe, - }, - )], - } - .into(), - ) - .await - .unwrap(); + let codec = codec::Codec::default(); + io.send( + codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), + &codec, + ) + .await + .unwrap(); + let _ = io.next(&codec).await.unwrap().unwrap(); + + io.send( + codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() }.into(), + &codec, + ) + .await + .unwrap(); + io.send( + codec::Subscribe { + id: None, + packet_id: NonZeroU16::new(2).unwrap(), + user_properties: Default::default(), + topic_filters: vec![( + ByteString::from("topic1"), + codec::SubscriptionOptions { + qos: codec::QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: codec::RetainHandling::AtSubscribe, + }, + )], + } + .into(), + &codec, + ) + .await + .unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::PublishAck(codec::PublishAck { @@ -241,7 +245,7 @@ async fn test_ack_order() -> std::io::Result<()> { }) ); - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::SubscribeAck(codec::SubscribeAck { @@ -266,69 +270,69 @@ async fn test_dups() { }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed - .send(codec::Packet::Connect(Box::new( + let codec = codec::Codec::default(); + io.send( + codec::Packet::Connect(Box::new( codec::Connect::default().client_id("user").receive_max(2), - ))) - .await - .unwrap(); - let _ = framed.next().await.unwrap().unwrap(); - - framed - .send( - codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() } - .into(), - ) - .await - .unwrap(); + )), + &codec, + ) + .await + .unwrap(); + let _ = io.next(&codec).await.unwrap().unwrap(); + + io.send( + codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() }.into(), + &codec, + ) + .await + .unwrap(); // send packet_id dup - framed - .send( - codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() } - .into(), - ) - .await - .unwrap(); + io.send( + codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() }.into(), + &codec, + ) + .await + .unwrap(); // send subscribe dup - framed - .send( - codec::Subscribe { - id: None, - packet_id: NonZeroU16::new(1).unwrap(), - user_properties: Default::default(), - topic_filters: vec![( - ByteString::from("topic1"), - codec::SubscriptionOptions { - qos: codec::QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: codec::RetainHandling::AtSubscribe, - }, - )], - } - .into(), - ) - .await - .unwrap(); + io.send( + codec::Subscribe { + id: None, + packet_id: NonZeroU16::new(1).unwrap(), + user_properties: Default::default(), + topic_filters: vec![( + ByteString::from("topic1"), + codec::SubscriptionOptions { + qos: codec::QoS::AtLeastOnce, + no_local: false, + retain_as_published: false, + retain_handling: codec::RetainHandling::AtSubscribe, + }, + )], + } + .into(), + &codec, + ) + .await + .unwrap(); // send unsubscribe dup - framed - .send( - codec::Unsubscribe { - packet_id: NonZeroU16::new(1).unwrap(), - user_properties: Default::default(), - topic_filters: vec![ByteString::from("topic1")], - } - .into(), - ) - .await - .unwrap(); + io.send( + codec::Unsubscribe { + packet_id: NonZeroU16::new(1).unwrap(), + user_properties: Default::default(), + topic_filters: vec![ByteString::from("topic1")], + } + .into(), + &codec, + ) + .await + .unwrap(); // PublishAck - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::PublishAck(codec::PublishAck { @@ -340,7 +344,7 @@ async fn test_dups() { ); // SubscribeAck - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::SubscribeAck { @@ -353,7 +357,7 @@ async fn test_dups() { ); // UnsubscribeAck - let pkt = framed.next().await.unwrap().unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::UnsubscribeAck { @@ -376,19 +380,21 @@ async fn test_max_receive() { sleep(Duration::from_millis(10000)).map(move |_| Ok::<_, TestError>(p.ack())) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => ok::<_, TestError>(msg.ack()), - _ => ok(msg.disconnect()), + ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - - framed - .send(codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user")))) - .await - .unwrap(); - let ack = framed.next().await.unwrap().unwrap(); + let codec = codec::Codec::default(); + + io.send( + codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), + &codec, + ) + .await + .unwrap(); + let ack = io.next(&codec).await.unwrap().unwrap(); assert_eq!( ack, codec::Packet::ConnectAck(Box::new(codec::ConnectAck { @@ -400,21 +406,19 @@ async fn test_max_receive() { })) ); - framed - .send( - codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() } - .into(), - ) - .await - .unwrap(); - framed - .send( - codec::Publish { packet_id: Some(NonZeroU16::new(2).unwrap()), ..pkt_publish() } - .into(), - ) - .await - .unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + io.send( + codec::Publish { packet_id: Some(NonZeroU16::new(1).unwrap()), ..pkt_publish() }.into(), + &codec, + ) + .await + .unwrap(); + io.send( + codec::Publish { packet_id: Some(NonZeroU16::new(2).unwrap()), ..pkt_publish() }.into(), + &codec, + ) + .await + .unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::Disconnect(codec::Disconnect { @@ -435,16 +439,16 @@ async fn test_keepalive() { let srv = server::test_server(move || { let ka = ka2.clone(); - MqttServer::new(|con: Handshake<_>| async move { Ok(con.ack(St).keep_alive(1)) }) + MqttServer::new(|con: Handshake| async move { Ok(con.ack(St).keep_alive(1)) }) .publish(|p: Publish| async move { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { ControlMessage::ProtocolError(msg) => { if let &error::ProtocolError::KeepAliveTimeout = msg.get_ref() { ka.store(true, Relaxed); } - ok::<_, TestError>(msg.ack()) + Ready::Ok::<_, TestError>(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); @@ -471,16 +475,16 @@ async fn test_keepalive2() { let srv = server::test_server(move || { let ka = ka2.clone(); - MqttServer::new(|con: Handshake<_>| async move { Ok(con.ack(St).keep_alive(1)) }) + MqttServer::new(|con: Handshake| async move { Ok(con.ack(St).keep_alive(1)) }) .publish(|p: Publish| async move { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { ControlMessage::ProtocolError(msg) => { if let &error::ProtocolError::KeepAliveTimeout = msg.get_ref() { ka.store(true, Relaxed); } - ok::<_, TestError>(msg.ack()) + Ready::Ok::<_, TestError>(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); @@ -497,11 +501,11 @@ async fn test_keepalive2() { let res = sink.publish(ByteString::from_static("#"), Bytes::new()).send_at_least_once().await; assert!(res.is_ok()); - sleep(Duration::from_millis(1200)).await; + sleep(Duration::from_millis(500)).await; let res = sink.publish(ByteString::from_static("#"), Bytes::new()).send_at_least_once().await; assert!(res.is_ok()); - sleep(Duration::from_millis(2500)).await; + sleep(Duration::from_millis(2000)).await; assert!(!sink.is_open()); assert!(ka.load(Relaxed)); @@ -510,7 +514,7 @@ async fn test_keepalive2() { #[ntex::test] async fn test_sink_encoder_error_pub_qos1() { let srv = server::test_server(move || { - MqttServer::new(|con: Handshake<_>| async move { + MqttServer::new(|con: Handshake| async move { let builder = con.sink().publish("test", Bytes::new()).properties(|props| { props.user_properties.push(( "ssssssssssssssssssssssssssssssssssss".into(), @@ -530,8 +534,8 @@ async fn test_sink_encoder_error_pub_qos1() { sleep(Duration::from_millis(50)).map(move |_| Ok::<_, TestError>(p.ack())) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => ok::<_, TestError>(msg.ack()), - _ => ok(msg.disconnect()), + ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); @@ -556,7 +560,7 @@ async fn test_sink_encoder_error_pub_qos1() { #[ntex::test] async fn test_sink_encoder_error_pub_qos0() { let srv = server::test_server(move || { - MqttServer::new(|con: Handshake<_>| async move { + MqttServer::new(|con: Handshake| async move { let builder = con.sink().publish("test", Bytes::new()).properties(|props| { props.user_properties.push(( "ssssssssssssssssssssssssssssssssssss".into(), @@ -574,8 +578,8 @@ async fn test_sink_encoder_error_pub_qos0() { sleep(Duration::from_millis(50)).map(move |_| Ok::<_, TestError>(p.ack())) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => ok::<_, TestError>(msg.ack()), - _ => ok(msg.disconnect()), + ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); @@ -600,7 +604,7 @@ async fn test_sink_encoder_error_pub_qos0() { #[ntex::test] async fn test_request_problem_info() { let srv = server::test_server(move || { - MqttServer::new(|con: Handshake<_>| async move { Ok(con.ack(St)) }) + MqttServer::new(|con: Handshake| async move { Ok(con.ack(St)) }) .publish(|p: Publish| async move { Ok::<_, TestError>( p.ack() @@ -647,23 +651,25 @@ async fn test_suback_with_reason() -> std::io::Result<()> { msg.iter_mut().for_each(|mut s| { s.fail(codec::SubscribeAckReason::ImplementationSpecificError) }); - ok::<_, TestError>(msg.ack_reason("some reason".into()).ack()) + Ready::Ok::<_, TestError>(msg.ack_reason("some reason".into()).ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::new()); - framed - .send(codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user")))) - .await - .unwrap(); - let _ = framed.next().await.unwrap().unwrap(); - - framed - .send(codec::Packet::Subscribe(codec::Subscribe { + let codec = codec::Codec::new(); + io.send( + codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), + &codec, + ) + .await + .unwrap(); + let _ = io.next(&codec).await.unwrap().unwrap(); + + io.send( + codec::Packet::Subscribe(codec::Subscribe { packet_id: NonZeroU16::new(1).unwrap(), topic_filters: vec![( "topic1".into(), @@ -676,10 +682,12 @@ async fn test_suback_with_reason() -> std::io::Result<()> { )], id: None, user_properties: codec::UserProperties::default(), - })) - .await - .unwrap(); - let pkt = framed.next().await.unwrap().unwrap(); + }), + &codec, + ) + .await + .unwrap(); + let pkt = io.next(&codec).await.unwrap().unwrap(); assert_eq!( pkt, codec::Packet::SubscribeAck(codec::SubscribeAck { @@ -706,36 +714,41 @@ async fn test_handle_incoming() -> std::io::Result<()> { MqttServer::new(handshake) .publish(move |p: Publish| { publish.store(true, Relaxed); - ok::<_, TestError>(p.ack()) + Ready::Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { ControlMessage::Disconnect(msg) => { disconnect.store(true, Relaxed); - ok::<_, TestError>(msg.ack()) + Ready::Ok::<_, TestError>(msg.ack()) } - _ => ok(msg.disconnect()), + _ => Ready::Ok(msg.disconnect()), }) .finish() }); let io = srv.connect().await.unwrap(); - let mut framed = Framed::new(io, codec::Codec::default()); - framed - .write(codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user")))) - .unwrap(); - framed.write(pkt_publish().into()).unwrap(); - framed - .write(codec::Packet::Disconnect(codec::Disconnect { + let codec = codec::Codec::default(); + io.encode( + codec::Packet::Connect(Box::new(codec::Connect::default().client_id("user"))), + &codec, + ) + .unwrap(); + io.encode(pkt_publish().into(), &codec).unwrap(); + io.encode( + codec::Packet::Disconnect(codec::Disconnect { reason_code: codec::DisconnectReasonCode::ReceiveMaximumExceeded, session_expiry_interval_secs: None, server_reference: None, reason_string: None, user_properties: Default::default(), - })) - .unwrap(); - poll_fn(|cx| framed.flush(cx)).await.unwrap(); - drop(framed); - sleep(Duration::from_millis(500)).await; + }), + &codec, + ) + .unwrap(); + io.write_ready(true).await.unwrap(); + sleep(Duration::from_millis(50)).await; + drop(io); + sleep(Duration::from_millis(50)).await; assert!(publish.load(Relaxed)); assert!(disconnect.load(Relaxed));