From efa5c7b935666e3702bff71af77adcbf81acd01f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Mon, 3 Feb 2025 19:10:31 +0800 Subject: [PATCH 1/8] interface for poll notice --- src/frontend/src/session.rs | 18 ++++++++++ src/utils/pgwire/src/pg_protocol.rs | 6 ++-- src/utils/pgwire/src/pg_server.rs | 56 +++++++++++++++++++++++++---- 3 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index dbc9f4f91a010..a43eb91fe44b9 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -17,6 +17,7 @@ use std::io::{Error, ErrorKind}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use anyhow::anyhow; @@ -79,6 +80,7 @@ use risingwave_sqlparser::ast::{ObjectName, Statement}; use risingwave_sqlparser::parser::Parser; use thiserror::Error; use tokio::runtime::Builder; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot::Sender; use tokio::sync::watch; use tokio::task::JoinHandle; @@ -647,8 +649,12 @@ pub struct SessionImpl { /// Stores the value of configurations. config_map: Arc>, /// buffer the Notices to users, + #[deprecated] notices: RwLock>, + notice_tx: UnboundedSender, + notice_rx: Mutex>, + /// Identified by `process_id`, `secret_key`. Corresponds to `SessionManager`. id: (i32, i32), @@ -741,6 +747,8 @@ impl SessionImpl { session_config: SessionConfig, ) -> Self { let cursor_metrics = env.cursor_metrics.clone(); + let (notice_tx, notice_rx) = mpsc::unbounded_channel(); + Self { env, auth_context: Arc::new(RwLock::new(auth_context)), @@ -751,6 +759,8 @@ impl SessionImpl { txn: Default::default(), current_query_cancel_flag: Mutex::new(None), notices: Default::default(), + notice_tx, + notice_rx: Mutex::new(notice_rx), exec_context: Mutex::new(None), last_idle_instant: Default::default(), cursor_manager: Arc::new(CursorManager::new(cursor_metrics)), @@ -761,6 +771,8 @@ impl SessionImpl { #[cfg(test)] pub fn mock() -> Self { let env = FrontendEnv::mock(); + let (notice_tx, notice_rx) = mpsc::unbounded_channel(); + Self { env: FrontendEnv::mock(), auth_context: Arc::new(RwLock::new(AuthContext::new( @@ -775,6 +787,8 @@ impl SessionImpl { txn: Default::default(), current_query_cancel_flag: Mutex::new(None), notices: Default::default(), + notice_tx, + notice_rx: Mutex::new(notice_rx), exec_context: Mutex::new(None), peer_addr: Address::Tcp(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), @@ -1591,6 +1605,10 @@ impl Session for SessionImpl { std::mem::take(inner) } + fn poll_next_notice(self: Arc, ctx: &mut Context<'_>) -> Poll> { + self.notice_rx.lock().poll_recv(ctx) + } + fn transaction_status(&self) -> TransactionStatus { match &*self.txn.lock() { transaction::State::Initial | transaction::State::Implicit(_) => { diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 122017d7cbe78..6bfd96ad9f596 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -74,14 +74,14 @@ where SM: SessionManager, { /// Used for write/read pg messages. - stream: PgStream, + pub stream: PgStream, /// Current states of pg connection. state: PgProtocolState, /// Whether the connection is terminated. is_terminate: bool, session_mgr: Arc, - session: Option>, + pub session: Option>, result_cache: HashMap::ValuesStream>>, unnamed_prepare_statement: Option<::PreparedStatement>, @@ -1097,7 +1097,7 @@ where Ok(()) } - async fn flush(&mut self) -> io::Result<()> { + pub async fn flush(&mut self) -> io::Result<()> { let mut stream = self.stream.lock().await; match &mut *stream { PgStreamInner::Placeholder => unreachable!(), diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 0b0576780d026..99e4bb330e38d 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,9 +16,11 @@ use std::collections::HashMap; use std::future::Future; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Instant; use bytes::Bytes; +use futures::StreamExt; use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; @@ -99,6 +101,8 @@ pub trait Session: Send + Sync { /// try to take the current notices from the session fn take_notices(self: Arc) -> Vec; + fn poll_next_notice(self: Arc, ctx: &mut Context<'_>) -> Poll>; + fn bind( self: Arc, prepare_statement: Self::PreparedStatement, @@ -341,16 +345,49 @@ pub async fn handle_connection( peer_addr, redact_sql_option_keywords, ); + + let mut pg_stream = pg_proto.stream.clone(); + loop { - let msg = match pg_proto.read_message().await { - Ok(msg) => msg, - Err(e) => { - tracing::error!(error = %e.as_report(), "error when reading message"); - break; + let session = pg_proto.session.clone(); + + let mut process = std::pin::pin!(async { + let msg = match pg_proto.read_message().await { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e.as_report(), "error when reading message"); + return false; + } + }; + tracing::trace!(?msg, "received message"); + pg_proto.process(msg).await + }); + + let ret = if let Some(session) = session { + let mut notice_stream = + futures::stream::poll_fn(move |ctx| session.clone().poll_next_notice(ctx)) + .ready_chunks(16); + + loop { + tokio::select! { + notices = notice_stream.next() => { + if let Some(notices) = notices { + for notice in notices { + pg_stream.write_no_flush(&crate::pg_message::BeMessage::NoticeResponse(¬ice)).ok(); + } + pg_stream.flush().await.ok(); + } + } + + ret = &mut process => { + break ret; + } + } } + } else { + process.await }; - tracing::trace!("Received message: {:?}", msg); - let ret = pg_proto.process(msg).await; + if ret { break; } @@ -361,6 +398,7 @@ pub async fn handle_connection( mod tests { use std::error::Error; use std::sync::Arc; + use std::task::{Context, Poll}; use std::time::Instant; use bytes::Bytes; @@ -509,6 +547,10 @@ mod tests { vec![] } + fn poll_next_notice(self: Arc, _ctx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + fn transaction_status(&self) -> TransactionStatus { TransactionStatus::Idle } From 5a40dc8cb22e16ecf01252c310b7bd7411c6e22a Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 4 Feb 2025 14:40:27 +0800 Subject: [PATCH 2/8] do not use poll interface, refactor code Signed-off-by: Bugen Zhao --- src/frontend/src/session.rs | 27 +++-------- src/utils/pgwire/src/pg_protocol.rs | 59 ++++++++++++++++++++++-- src/utils/pgwire/src/pg_server.rs | 71 +++-------------------------- 3 files changed, 68 insertions(+), 89 deletions(-) diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index a43eb91fe44b9..69ed7195edcb6 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -17,7 +17,6 @@ use std::io::{Error, ErrorKind}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::{Arc, Weak}; -use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use anyhow::anyhow; @@ -648,9 +647,6 @@ pub struct SessionImpl { user_authenticator: UserAuthenticator, /// Stores the value of configurations. config_map: Arc>, - /// buffer the Notices to users, - #[deprecated] - notices: RwLock>, notice_tx: UnboundedSender, notice_rx: Mutex>, @@ -758,7 +754,6 @@ impl SessionImpl { peer_addr, txn: Default::default(), current_query_cancel_flag: Mutex::new(None), - notices: Default::default(), notice_tx, notice_rx: Mutex::new(notice_rx), exec_context: Mutex::new(None), @@ -786,7 +781,6 @@ impl SessionImpl { id: (0, 0), txn: Default::default(), current_query_cancel_flag: Mutex::new(None), - notices: Default::default(), notice_tx, notice_rx: Mutex::new(notice_rx), exec_context: Mutex::new(None), @@ -1157,10 +1151,6 @@ impl SessionImpl { shutdown_rx } - fn clear_notices(&self) { - *self.notices.write() = vec![]; - } - pub fn cancel_current_query(&self) { let mut flag_guard = self.current_query_cancel_flag.lock(); if let Some(sender) = flag_guard.take() { @@ -1172,12 +1162,10 @@ impl SessionImpl { info!("Trying to cancel query in distributed mode."); self.env.query_manager().cancel_queries_in_session(self.id) } - self.clear_notices() } pub fn cancel_current_creating_job(&self) { self.env.creating_streaming_job_tracker.abort_jobs(self.id); - self.clear_notices() } /// This function only used for test now. @@ -1209,7 +1197,9 @@ impl SessionImpl { pub fn notice_to_user(&self, str: impl Into) { let notice = str.into(); tracing::trace!(notice, "notice to user"); - self.notices.write().push(notice); + self.notice_tx + .send(notice) + .expect("notice channel should not be closed"); } pub fn is_barrier_read(&self) -> bool { @@ -1600,13 +1590,10 @@ impl Session for SessionImpl { Self::set_config(self, key, value).map_err(Into::into) } - fn take_notices(self: Arc) -> Vec { - let inner = &mut (*self.notices.write()); - std::mem::take(inner) - } - - fn poll_next_notice(self: Arc, ctx: &mut Context<'_>) -> Poll> { - self.notice_rx.lock().poll_recv(ctx) + async fn next_notice(self: &Arc) -> String { + std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx)) + .await + .expect("notice channel should not be closed") } fn transaction_status(&self) -> TransactionStatus { diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 6bfd96ad9f596..a46b88bdb9143 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -23,6 +23,7 @@ use std::{io, str}; use bytes::{Bytes, BytesMut}; use futures::stream::StreamExt; +use futures::FutureExt; use itertools::Itertools; use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod}; use risingwave_common::types::DataType; @@ -74,14 +75,14 @@ where SM: SessionManager, { /// Used for write/read pg messages. - pub stream: PgStream, + stream: PgStream, /// Current states of pg connection. state: PgProtocolState, /// Whether the connection is terminated. is_terminate: bool, session_mgr: Arc, - pub session: Option>, + session: Option>, result_cache: HashMap::ValuesStream>>, unnamed_prepare_statement: Option<::PreparedStatement>, @@ -213,6 +214,51 @@ where } } + /// Run the protocol to serve the connection. + pub async fn run(mut self) { + let mut notice_stream = self.stream.clone(); + + loop { + let session = self.session.clone(); + + let mut process = std::pin::pin!(async { + let msg = match self.read_message().await { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e.as_report(), "error when reading message"); + return false; + } + }; + tracing::trace!(?msg, "received message"); + self.process(msg).await + }); + + let terminated = if let Some(session) = session { + // If a session is present, subscribe and send notices asynchronously + // while processing the message. + loop { + tokio::select! { + notice = session.next_notice() => { + notice_stream.write(&BeMessage::NoticeResponse(¬ice)).await.inspect_err(|e| { + tracing::error!(error = %e.as_report(), notice, "failed to send notice"); + }).ok(); + } + terminated = &mut process => { + break terminated; + } + } + } + } else { + // Otherwise, just process the message. + process.await + }; + + if terminated { + break; + } + } + } + /// Processes one message. Returns true if the connection is terminated. pub async fn process(&mut self, msg: FeMessage) -> bool { self.do_process(msg).await.is_none() || self.is_terminate @@ -615,10 +661,13 @@ where .clone() .run_one_query(stmt.clone(), Format::Text) .await; - for notice in session.take_notices() { + + // Take all remaining notices (if any) and send them before `CommandComplete`. + while let Some(notice) = session.next_notice().now_or_never() { self.stream .write_no_flush(&BeMessage::NoticeResponse(¬ice))?; } + let mut res = res.map_err(PsqlError::SimpleQueryError)?; for notice in res.notices() { @@ -1091,13 +1140,13 @@ where BeMessage::write(&mut self.write_buf, message) } - async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> { + pub async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> { self.write_no_flush(message)?; self.flush().await?; Ok(()) } - pub async fn flush(&mut self) -> io::Result<()> { + async fn flush(&mut self) -> io::Result<()> { let mut stream = self.stream.lock().await; match &mut *stream { PgStreamInner::Placeholder => unreachable!(), diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 99e4bb330e38d..9488c4f3512f2 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,11 +16,9 @@ use std::collections::HashMap; use std::future::Future; use std::str::FromStr; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::Instant; use bytes::Bytes; -use futures::StreamExt; use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; @@ -97,11 +95,7 @@ pub trait Session: Send + Sync { params_types: Vec>, ) -> impl Future> + Send; - // TODO: maybe this function should be async and return the notice more timely - /// try to take the current notices from the session - fn take_notices(self: Arc) -> Vec; - - fn poll_next_notice(self: Arc, ctx: &mut Context<'_>) -> Poll>; + fn next_notice(self: &Arc) -> impl Future + Send; fn bind( self: Arc, @@ -338,67 +332,20 @@ pub async fn handle_connection( S: AsyncWrite + AsyncRead + Unpin, SM: SessionManager, { - let mut pg_proto = PgProtocol::new( + PgProtocol::new( stream, session_mgr, tls_config, peer_addr, redact_sql_option_keywords, - ); - - let mut pg_stream = pg_proto.stream.clone(); - - loop { - let session = pg_proto.session.clone(); - - let mut process = std::pin::pin!(async { - let msg = match pg_proto.read_message().await { - Ok(msg) => msg, - Err(e) => { - tracing::error!(error = %e.as_report(), "error when reading message"); - return false; - } - }; - tracing::trace!(?msg, "received message"); - pg_proto.process(msg).await - }); - - let ret = if let Some(session) = session { - let mut notice_stream = - futures::stream::poll_fn(move |ctx| session.clone().poll_next_notice(ctx)) - .ready_chunks(16); - - loop { - tokio::select! { - notices = notice_stream.next() => { - if let Some(notices) = notices { - for notice in notices { - pg_stream.write_no_flush(&crate::pg_message::BeMessage::NoticeResponse(¬ice)).ok(); - } - pg_stream.flush().await.ok(); - } - } - - ret = &mut process => { - break ret; - } - } - } - } else { - process.await - }; - - if ret { - break; - } - } + ) + .run() + .await; } - #[cfg(test)] mod tests { use std::error::Error; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Instant; use bytes::Bytes; @@ -543,12 +490,8 @@ mod tests { Ok("".to_owned()) } - fn take_notices(self: Arc) -> Vec { - vec![] - } - - fn poll_next_notice(self: Arc, _ctx: &mut Context<'_>) -> Poll> { - Poll::Pending + async fn next_notice(self: &Arc) -> String { + std::future::pending().await } fn transaction_status(&self) -> TransactionStatus { From 179e4b0ee7ed2ba2fc34b197f95d902a2a4d6fe2 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 4 Feb 2025 14:53:16 +0800 Subject: [PATCH 3/8] add docs Signed-off-by: Bugen Zhao --- src/frontend/src/session.rs | 2 ++ src/utils/pgwire/src/pg_protocol.rs | 2 +- src/utils/pgwire/src/pg_server.rs | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 69ed7195edcb6..27c18595cec44 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -648,7 +648,9 @@ pub struct SessionImpl { /// Stores the value of configurations. config_map: Arc>, + /// Channel sender for frontend handler to send notices. notice_tx: UnboundedSender, + /// Channel receiver for pgwire to take notices and send to clients. notice_rx: Mutex>, /// Identified by `process_id`, `secret_key`. Corresponds to `SessionManager`. diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index a46b88bdb9143..f06dc156d0530 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -1140,7 +1140,7 @@ where BeMessage::write(&mut self.write_buf, message) } - pub async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> { + async fn write(&mut self, message: &BeMessage<'_>) -> io::Result<()> { self.write_no_flush(message)?; self.flush().await?; Ok(()) diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 9488c4f3512f2..1f0de89a4f3c6 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -95,6 +95,9 @@ pub trait Session: Send + Sync { params_types: Vec>, ) -> impl Future> + Send; + /// Receive the next notice message to send to the client. + /// + /// This function should be cancellation-safe. fn next_notice(self: &Arc) -> impl Future + Send; fn bind( From 7c0056ef633b58962e194fcd2c1274081a3f20ce Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Feb 2025 15:44:50 +0800 Subject: [PATCH 4/8] avoid potential deadlock Signed-off-by: Bugen Zhao --- src/utils/pgwire/src/pg_protocol.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index f06dc156d0530..59de15d317660 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -237,12 +237,16 @@ where // If a session is present, subscribe and send notices asynchronously // while processing the message. loop { + let next_notice = async { + let notice = session.next_notice().await; + notice_stream.write(&BeMessage::NoticeResponse(¬ice)).await.inspect_err(|e| { + tracing::error!(error = %e.as_report(), notice, "failed to send notice"); + }).ok(); + }; + tokio::select! { - notice = session.next_notice() => { - notice_stream.write(&BeMessage::NoticeResponse(¬ice)).await.inspect_err(|e| { - tracing::error!(error = %e.as_report(), notice, "failed to send notice"); - }).ok(); - } + _ = next_notice => {} + terminated = &mut process => { break terminated; } From 82cd94461ba4d546fda0d2a168e9e48bbbfb7c4d Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Feb 2025 15:45:03 +0800 Subject: [PATCH 5/8] fix unexpected eof Signed-off-by: Bugen Zhao --- src/utils/pgwire/src/pg_protocol.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 59de15d317660..0b9b73c434111 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -226,7 +226,7 @@ where Ok(msg) => msg, Err(e) => { tracing::error!(error = %e.as_report(), "error when reading message"); - return false; + return true; // terminate the connection } }; tracing::trace!(?msg, "received message"); From d7c62d55beb42981d4044320eb8ec2c37cad6048 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Feb 2025 15:29:35 +0800 Subject: [PATCH 6/8] spawn a task instead to resolve cancellation safety or deadlock Signed-off-by: Bugen Zhao --- Cargo.lock | 8 ++-- Cargo.toml | 2 +- src/utils/pgwire/Cargo.toml | 1 + src/utils/pgwire/src/lib.rs | 1 + src/utils/pgwire/src/pg_extended.rs | 5 +-- src/utils/pgwire/src/pg_protocol.rs | 67 +++++++++++++---------------- src/utils/pgwire/src/pg_server.rs | 5 +-- 7 files changed, 43 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c440030ce471..d9f91424394ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8840,6 +8840,7 @@ dependencies = [ "thiserror-ext", "tokio-openssl", "tokio-postgres", + "tokio-util", "tracing", "workspace-hack", ] @@ -14489,17 +14490,18 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.9" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", "futures-io", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", - "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index bf44f574d86ac..c15c8086e1fc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -217,7 +217,7 @@ tokio-stream = { git = "https://github.com/madsim-rs/tokio.git", rev = "0dd1055" "net", "fs", ] } -tokio-util = "0.7" +tokio-util = "0.7.12" tracing-opentelemetry = "0.25" rand = { version = "0.8", features = ["small_rng"] } governor = { version = "0.8", default-features = false, features = ["std"] } diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 76074e78fa442..8c26969548d4e 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -34,6 +34,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "macros", ] } tokio-openssl = "0.6.3" +tokio-util = { workspace = true, features = ["rt"] } tracing = "0.1" [target.'cfg(not(madsim))'.dependencies] diff --git a/src/utils/pgwire/src/lib.rs b/src/utils/pgwire/src/lib.rs index cb1b3370a4768..770db3c1c3597 100644 --- a/src/utils/pgwire/src/lib.rs +++ b/src/utils/pgwire/src/lib.rs @@ -18,6 +18,7 @@ #![feature(buf_read_has_data_left)] #![feature(round_char_boundary)] #![feature(never_type)] +#![feature(let_chains)] #![expect(clippy::doc_markdown, reason = "FIXME: later")] pub mod error; diff --git a/src/utils/pgwire/src/pg_extended.rs b/src/utils/pgwire/src/pg_extended.rs index 50d46ef1f7b65..0925cf491bd17 100644 --- a/src/utils/pgwire/src/pg_extended.rs +++ b/src/utils/pgwire/src/pg_extended.rs @@ -17,11 +17,10 @@ use std::vec::IntoIter; use futures::stream::FusedStream; use futures::{StreamExt, TryStreamExt}; use postgres_types::FromSql; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::error::{PsqlError, PsqlResult}; use crate::pg_message::{BeCommandCompleteMessage, BeMessage}; -use crate::pg_protocol::PgStream; +use crate::pg_protocol::{PgByteStream, PgStream}; use crate::pg_response::{PgResponse, ValuesStream}; use crate::types::{Format, Row}; @@ -45,7 +44,7 @@ where } /// Return indicate whether the result is consumed completely. - pub async fn consume( + pub async fn consume( &mut self, row_limit: usize, msg_stream: &mut PgStream, diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 0b9b73c434111..e047f5a598cf0 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -36,6 +36,7 @@ use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::sync::Mutex; use tokio_openssl::SslStream; +use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use crate::error::{PsqlError, PsqlResult}; @@ -183,7 +184,7 @@ fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String { impl PgProtocol where - S: AsyncWrite + AsyncRead + Unpin, + S: PgByteStream, SM: SessionManager, { pub fn new( @@ -215,47 +216,37 @@ where } /// Run the protocol to serve the connection. - pub async fn run(mut self) { - let mut notice_stream = self.stream.clone(); + pub async fn run(&mut self) { + let mut notice_task = None; loop { - let session = self.session.clone(); - - let mut process = std::pin::pin!(async { - let msg = match self.read_message().await { - Ok(msg) => msg, - Err(e) => { - tracing::error!(error = %e.as_report(), "error when reading message"); - return true; // terminate the connection - } - }; - tracing::trace!(?msg, "received message"); - self.process(msg).await - }); - - let terminated = if let Some(session) = session { - // If a session is present, subscribe and send notices asynchronously - // while processing the message. - loop { - let next_notice = async { + // If a session is present, spawn a task to subscribe and send notices asynchronously. + if notice_task.is_none() + && let Some(session) = self.session.clone() + { + let mut stream = self.stream.clone(); + let handle = tokio::spawn(async move { + loop { let notice = session.next_notice().await; - notice_stream.write(&BeMessage::NoticeResponse(¬ice)).await.inspect_err(|e| { + if let Err(e) = stream.write(&BeMessage::NoticeResponse(¬ice)).await { tracing::error!(error = %e.as_report(), notice, "failed to send notice"); - }).ok(); - }; - - tokio::select! { - _ = next_notice => {} - - terminated = &mut process => { - break terminated; + break; } } + }); + notice_task = Some(AbortOnDropHandle::new(handle)); + } + + // Read and process messages. + let msg = match self.read_message().await { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e.as_report(), "error when reading message"); + break; // terminate the connection } - } else { - // Otherwise, just process the message. - process.await }; + tracing::trace!(?msg, "received message"); + let terminated = self.process(msg).await; if terminated { break; @@ -1047,6 +1038,10 @@ enum PgStreamInner { Ssl(SslStream), } +/// Trait for a byte stream that can be used for pg protocol. +pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {} +impl PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {} + /// Wraps a byte stream and read/write pg messages. /// /// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer, @@ -1102,7 +1097,7 @@ pub struct ParameterStatus { impl PgStream where - S: AsyncWrite + AsyncRead + Unpin, + S: PgByteStream, { async fn read_startup(&mut self) -> io::Result { let mut stream = self.stream.lock().await; @@ -1170,7 +1165,7 @@ where impl PgStream where - S: AsyncWrite + AsyncRead + Unpin, + S: PgByteStream, { /// Convert the underlying stream to ssl stream based on the given context. async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> { diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 1f0de89a4f3c6..28e772859c8ed 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -27,13 +27,12 @@ use risingwave_common::util::tokio_util::sync::CancellationToken; use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement}; use serde::Deserialize; use thiserror_ext::AsReport; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::error::{PsqlError, PsqlResult}; use crate::net::{AddressRef, Listener, TcpKeepalive}; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::TransactionStatus; -use crate::pg_protocol::{PgProtocol, TlsConfig}; +use crate::pg_protocol::{PgByteStream, PgProtocol, TlsConfig}; use crate::pg_response::{PgResponse, ValuesStream}; use crate::types::Format; @@ -332,7 +331,7 @@ pub async fn handle_connection( peer_addr: AddressRef, redact_sql_option_keywords: Option, ) where - S: AsyncWrite + AsyncRead + Unpin, + S: PgByteStream, SM: SessionManager, { PgProtocol::new( From a9a395e32b5dff6e5ca196a571d1887973e8f06d Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Feb 2025 15:38:36 +0800 Subject: [PATCH 7/8] manually abort the task Signed-off-by: Bugen Zhao --- Cargo.lock | 8 +++----- Cargo.toml | 2 +- src/utils/pgwire/Cargo.toml | 1 - src/utils/pgwire/src/pg_protocol.rs | 8 ++++++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d9f91424394ef..3c440030ce471 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8840,7 +8840,6 @@ dependencies = [ "thiserror-ext", "tokio-openssl", "tokio-postgres", - "tokio-util", "tracing", "workspace-hack", ] @@ -14490,18 +14489,17 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", "futures-io", "futures-sink", - "futures-util", - "hashbrown 0.14.5", "pin-project-lite", "tokio", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c15c8086e1fc0..bf44f574d86ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -217,7 +217,7 @@ tokio-stream = { git = "https://github.com/madsim-rs/tokio.git", rev = "0dd1055" "net", "fs", ] } -tokio-util = "0.7.12" +tokio-util = "0.7" tracing-opentelemetry = "0.25" rand = { version = "0.8", features = ["small_rng"] } governor = { version = "0.8", default-features = false, features = ["std"] } diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 8c26969548d4e..76074e78fa442 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -34,7 +34,6 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "macros", ] } tokio-openssl = "0.6.3" -tokio-util = { workspace = true, features = ["rt"] } tracing = "0.1" [target.'cfg(not(madsim))'.dependencies] diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index e047f5a598cf0..85fcdffff7c81 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -36,7 +36,6 @@ use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::sync::Mutex; use tokio_openssl::SslStream; -use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use crate::error::{PsqlError, PsqlResult}; @@ -234,7 +233,7 @@ where } } }); - notice_task = Some(AbortOnDropHandle::new(handle)); + notice_task = Some(handle); } // Read and process messages. @@ -252,6 +251,11 @@ where break; } } + + // Abort the notice task to release ref count of the session and the stream. + if let Some(task) = notice_task { + task.abort(); + } } /// Processes one message. Returns true if the connection is terminated. From 91f9e6f6c94578ae7499892c5a62bcabbb4bb2d3 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Feb 2025 15:52:23 +0800 Subject: [PATCH 8/8] do not spawn, select instead Signed-off-by: Bugen Zhao --- src/utils/pgwire/src/pg_protocol.rs | 42 ++++++++++++++++------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 85fcdffff7c81..84d1db31e2955 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -216,46 +216,50 @@ where /// Run the protocol to serve the connection. pub async fn run(&mut self) { - let mut notice_task = None; + let mut notice_fut = None; loop { - // If a session is present, spawn a task to subscribe and send notices asynchronously. - if notice_task.is_none() + // Once a session is present, create a future to subscribe and send notices asynchronously. + if notice_fut.is_none() && let Some(session) = self.session.clone() { let mut stream = self.stream.clone(); - let handle = tokio::spawn(async move { + notice_fut = Some(Box::pin(async move { loop { let notice = session.next_notice().await; if let Err(e) = stream.write(&BeMessage::NoticeResponse(¬ice)).await { tracing::error!(error = %e.as_report(), notice, "failed to send notice"); - break; } } - }); - notice_task = Some(handle); + })); } // Read and process messages. - let msg = match self.read_message().await { - Ok(msg) => msg, - Err(e) => { - tracing::error!(error = %e.as_report(), "error when reading message"); - break; // terminate the connection + let process = std::pin::pin!(async { + let msg = match self.read_message().await { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e.as_report(), "error when reading message"); + return true; // terminate the connection + } + }; + tracing::trace!(?msg, "received message"); + self.process(msg).await + }); + + let terminated = if let Some(notice_fut) = notice_fut.as_mut() { + tokio::select! { + _ = notice_fut => unreachable!(), + terminated = process => terminated, } + } else { + process.await }; - tracing::trace!(?msg, "received message"); - let terminated = self.process(msg).await; if terminated { break; } } - - // Abort the notice task to release ref count of the session and the stream. - if let Some(task) = notice_task { - task.abort(); - } } /// Processes one message. Returns true if the connection is terminated.