Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pgwire): send notice asynchronously and promptly #20374

Merged
merged 8 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ use risingwave_sqlparser::ast::{ObjectName, Statement};
use risingwave_sqlparser::parser::Parser;
use thiserror::Error;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot::Sender;
use tokio::sync::watch;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -646,8 +647,11 @@ pub struct SessionImpl {
user_authenticator: UserAuthenticator,
/// Stores the value of configurations.
config_map: Arc<RwLock<SessionConfig>>,
/// buffer the Notices to users,
notices: RwLock<Vec<String>>,

/// Channel sender for frontend handler to send notices.
notice_tx: UnboundedSender<String>,
/// Channel receiver for pgwire to take notices and send to clients.
notice_rx: Mutex<UnboundedReceiver<String>>,

/// Identified by `process_id`, `secret_key`. Corresponds to `SessionManager`.
id: (i32, i32),
Expand Down Expand Up @@ -741,6 +745,8 @@ impl SessionImpl {
session_config: SessionConfig,
) -> Self {
let cursor_metrics = env.cursor_metrics.clone();
let (notice_tx, notice_rx) = mpsc::unbounded_channel();

Self {
env,
auth_context: Arc::new(RwLock::new(auth_context)),
Expand All @@ -750,7 +756,8 @@ impl SessionImpl {
peer_addr,
txn: Default::default(),
current_query_cancel_flag: Mutex::new(None),
notices: Default::default(),
notice_tx,
notice_rx: Mutex::new(notice_rx),
exec_context: Mutex::new(None),
last_idle_instant: Default::default(),
cursor_manager: Arc::new(CursorManager::new(cursor_metrics)),
Expand All @@ -761,6 +768,8 @@ impl SessionImpl {
#[cfg(test)]
pub fn mock() -> Self {
let env = FrontendEnv::mock();
let (notice_tx, notice_rx) = mpsc::unbounded_channel();

Self {
env: FrontendEnv::mock(),
auth_context: Arc::new(RwLock::new(AuthContext::new(
Expand All @@ -774,7 +783,8 @@ impl SessionImpl {
id: (0, 0),
txn: Default::default(),
current_query_cancel_flag: Mutex::new(None),
notices: Default::default(),
notice_tx,
notice_rx: Mutex::new(notice_rx),
exec_context: Mutex::new(None),
peer_addr: Address::Tcp(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
Expand Down Expand Up @@ -1143,10 +1153,6 @@ impl SessionImpl {
shutdown_rx
}

fn clear_notices(&self) {
*self.notices.write() = vec![];
}

pub fn cancel_current_query(&self) {
let mut flag_guard = self.current_query_cancel_flag.lock();
if let Some(sender) = flag_guard.take() {
Expand All @@ -1158,12 +1164,10 @@ impl SessionImpl {
info!("Trying to cancel query in distributed mode.");
self.env.query_manager().cancel_queries_in_session(self.id)
}
self.clear_notices()
}

pub fn cancel_current_creating_job(&self) {
self.env.creating_streaming_job_tracker.abort_jobs(self.id);
self.clear_notices()
}

/// This function only used for test now.
Expand Down Expand Up @@ -1195,7 +1199,9 @@ impl SessionImpl {
pub fn notice_to_user(&self, str: impl Into<String>) {
let notice = str.into();
tracing::trace!(notice, "notice to user");
self.notices.write().push(notice);
self.notice_tx
.send(notice)
.expect("notice channel should not be closed");
}

pub fn is_barrier_read(&self) -> bool {
Expand Down Expand Up @@ -1586,9 +1592,10 @@ impl Session for SessionImpl {
Self::set_config(self, key, value).map_err(Into::into)
}

fn take_notices(self: Arc<Self>) -> Vec<String> {
let inner = &mut (*self.notices.write());
std::mem::take(inner)
async fn next_notice(self: &Arc<Self>) -> String {
std::future::poll_fn(|cx| self.clone().notice_rx.lock().poll_recv(cx))
.await
.expect("notice channel should not be closed")
}

fn transaction_status(&self) -> TransactionStatus {
Expand Down
1 change: 1 addition & 0 deletions src/utils/pgwire/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#![feature(buf_read_has_data_left)]
#![feature(round_char_boundary)]
#![feature(never_type)]
#![feature(let_chains)]
#![expect(clippy::doc_markdown, reason = "FIXME: later")]

pub mod error;
Expand Down
5 changes: 2 additions & 3 deletions src/utils/pgwire/src/pg_extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ use std::vec::IntoIter;
use futures::stream::FusedStream;
use futures::{StreamExt, TryStreamExt};
use postgres_types::FromSql;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::error::{PsqlError, PsqlResult};
use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
use crate::pg_protocol::PgStream;
use crate::pg_protocol::{PgByteStream, PgStream};
use crate::pg_response::{PgResponse, ValuesStream};
use crate::types::{Format, Row};

Expand All @@ -45,7 +44,7 @@ where
}

/// Return indicate whether the result is consumed completely.
pub async fn consume<S: AsyncWrite + AsyncRead + Unpin>(
pub async fn consume<S: PgByteStream>(
&mut self,
row_limit: usize,
msg_stream: &mut PgStream<S>,
Expand Down
64 changes: 60 additions & 4 deletions src/utils/pgwire/src/pg_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::{io, str};

use bytes::{Bytes, BytesMut};
use futures::stream::StreamExt;
use futures::FutureExt;
use itertools::Itertools;
use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
use risingwave_common::types::DataType;
Expand Down Expand Up @@ -182,7 +183,7 @@ fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {

impl<S, SM> PgProtocol<S, SM>
where
S: AsyncWrite + AsyncRead + Unpin,
S: PgByteStream,
SM: SessionManager,
{
pub fn new(
Expand Down Expand Up @@ -213,6 +214,54 @@ where
}
}

/// Run the protocol to serve the connection.
pub async fn run(&mut self) {
let mut notice_fut = None;

loop {
// Once a session is present, create a future to subscribe and send notices asynchronously.
if notice_fut.is_none()
&& let Some(session) = self.session.clone()
{
let mut stream = self.stream.clone();
notice_fut = Some(Box::pin(async move {
loop {
let notice = session.next_notice().await;
if let Err(e) = stream.write(&BeMessage::NoticeResponse(&notice)).await {
tracing::error!(error = %e.as_report(), notice, "failed to send notice");
}
}
}));
}

// Read and process messages.
let process = std::pin::pin!(async {
let msg = match self.read_message().await {
Ok(msg) => msg,
Err(e) => {
tracing::error!(error = %e.as_report(), "error when reading message");
return true; // terminate the connection
}
};
tracing::trace!(?msg, "received message");
self.process(msg).await
});

let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
tokio::select! {
_ = notice_fut => unreachable!(),
terminated = process => terminated,
}
} else {
process.await
};

if terminated {
break;
}
}
}

/// Processes one message. Returns true if the connection is terminated.
pub async fn process(&mut self, msg: FeMessage) -> bool {
self.do_process(msg).await.is_none() || self.is_terminate
Expand Down Expand Up @@ -615,10 +664,13 @@ where
.clone()
.run_one_query(stmt.clone(), Format::Text)
.await;
for notice in session.take_notices() {

// Take all remaining notices (if any) and send them before `CommandComplete`.
while let Some(notice) = session.next_notice().now_or_never() {
self.stream
.write_no_flush(&BeMessage::NoticeResponse(&notice))?;
}

let mut res = res.map_err(PsqlError::SimpleQueryError)?;

for notice in res.notices() {
Expand Down Expand Up @@ -994,6 +1046,10 @@ enum PgStreamInner<S> {
Ssl(SslStream<S>),
}

/// Trait for a byte stream that can be used for pg protocol.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}

/// Wraps a byte stream and read/write pg messages.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
Expand Down Expand Up @@ -1049,7 +1105,7 @@ pub struct ParameterStatus {

impl<S> PgStream<S>
where
S: AsyncWrite + AsyncRead + Unpin,
S: PgByteStream,
{
async fn read_startup(&mut self) -> io::Result<FeMessage> {
let mut stream = self.stream.lock().await;
Expand Down Expand Up @@ -1117,7 +1173,7 @@ where

impl<S> PgStream<S>
where
S: AsyncWrite + AsyncRead + Unpin,
S: PgByteStream,
{
/// Convert the underlying stream to ssl stream based on the given context.
async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
Expand Down
37 changes: 12 additions & 25 deletions src/utils/pgwire/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ use risingwave_common::util::tokio_util::sync::CancellationToken;
use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
use serde::Deserialize;
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::error::{PsqlError, PsqlResult};
use crate::net::{AddressRef, Listener, TcpKeepalive};
use crate::pg_field_descriptor::PgFieldDescriptor;
use crate::pg_message::TransactionStatus;
use crate::pg_protocol::{PgProtocol, TlsConfig};
use crate::pg_protocol::{PgByteStream, PgProtocol, TlsConfig};
use crate::pg_response::{PgResponse, ValuesStream};
use crate::types::Format;

Expand Down Expand Up @@ -95,9 +94,10 @@ pub trait Session: Send + Sync {
params_types: Vec<Option<DataType>>,
) -> impl Future<Output = Result<Self::PreparedStatement, BoxedError>> + Send;

// TODO: maybe this function should be async and return the notice more timely
/// try to take the current notices from the session
fn take_notices(self: Arc<Self>) -> Vec<String>;
/// Receive the next notice message to send to the client.
///
/// This function should be cancellation-safe.
fn next_notice(self: &Arc<Self>) -> impl Future<Output = String> + Send;

fn bind(
self: Arc<Self>,
Expand Down Expand Up @@ -331,32 +331,19 @@ pub async fn handle_connection<S, SM>(
peer_addr: AddressRef,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
) where
S: AsyncWrite + AsyncRead + Unpin,
S: PgByteStream,
SM: SessionManager,
{
let mut pg_proto = PgProtocol::new(
PgProtocol::new(
stream,
session_mgr,
tls_config,
peer_addr,
redact_sql_option_keywords,
);
loop {
let msg = match pg_proto.read_message().await {
Ok(msg) => msg,
Err(e) => {
tracing::error!(error = %e.as_report(), "error when reading message");
break;
}
};
tracing::trace!("Received message: {:?}", msg);
let ret = pg_proto.process(msg).await;
if ret {
break;
}
}
)
.run()
.await;
}

#[cfg(test)]
mod tests {
use std::error::Error;
Expand Down Expand Up @@ -505,8 +492,8 @@ mod tests {
Ok("".to_owned())
}

fn take_notices(self: Arc<Self>) -> Vec<String> {
vec![]
async fn next_notice(self: &Arc<Self>) -> String {
std::future::pending().await
}

fn transaction_status(&self) -> TransactionStatus {
Expand Down
Loading