Skip to content

Prevent FlightSQL server panics for do_put when stream is empty or 1st stream element is an Err #7492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arrow-flight/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ pub mod client;
pub mod metadata;
pub mod server;

pub use crate::streams::FallibleRequestStream;

/// ProstMessageExt are useful utility methods for prost::Message types
pub trait ProstMessageExt: prost::Message + Default {
/// type_url for this Message
Expand Down
54 changes: 47 additions & 7 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@

//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL.

use std::fmt::{Display, Formatter};
use std::pin::Pin;

use futures::{stream::Peekable, Stream, StreamExt};
use prost::Message;
use tonic::{Request, Response, Status, Streaming};

use super::{
ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
Expand All @@ -41,6 +38,9 @@ use crate::{
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
SchemaResult, Ticket,
};
use futures::{stream::Peekable, Stream, StreamExt};
use prost::Message;
use tonic::{Request, Response, Status, Streaming};

pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
Expand Down Expand Up @@ -386,6 +386,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
)))
}

/// Hook invoked when `do_put` cannot dispatch a request, e.g. because the
/// request stream produced no first message or that message carried no
/// flight descriptor.
///
/// The default implementation ignores the request and reports the error as
/// an `unimplemented` [`Status`]. Implementors may override this method to
/// handle such requests differently.
async fn do_put_error_callback(
    &self,
    _request: Request<PeekableFlightDataStream>,
    error: DoPutError,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
    let message = format!("Unhandled Error: {error}");
    Err(Status::unimplemented(message))
}

/// Execute an update SQL statement.
async fn do_put_statement_update(
&self,
Expand Down Expand Up @@ -710,10 +719,21 @@ where
// we wrap this stream in a `Peekable` one, which allows us to peek at
// the first message without discarding it.
let mut request = request.map(PeekableFlightDataStream::new);
let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;
let mut stream = Pin::new(request.get_mut());

let peeked_item = stream.peek().await.cloned();
let Some(cmd) = peeked_item else {
return self
.do_put_error_callback(request, DoPutError::MissingCommand)
.await;
};

let message =
Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?;
let Some(flight_descriptor) = cmd?.flight_descriptor else {
return self
.do_put_error_callback(request, DoPutError::MissingFlightDescriptor)
.await;
};
let message = Any::decode(flight_descriptor.cmd).map_err(decode_error_to_status)?;
match Command::try_from(message).map_err(arrow_error_to_status)? {
Command::CommandStatementUpdate(command) => {
let record_count = self.do_put_statement_update(command, request).await?;
Expand Down Expand Up @@ -968,6 +988,26 @@ where
}
}

/// Unrecoverable errors associated with `do_put` requests.
///
/// Values of this type are handed to `do_put_error_callback` when the
/// command of a `do_put` request cannot be determined from the first
/// message of the request stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DoPutError {
    /// The request stream did not yield a first element, so there is no
    /// command to dispatch on.
    MissingCommand,
    /// The first element in the request stream is missing the flight descriptor
    MissingFlightDescriptor,
}
impl Display for DoPutError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DoPutError::MissingCommand => {
write!(f, "Command is missing.")
}
DoPutError::MissingFlightDescriptor => {
write!(f, "Flight descriptor is missing.")
}
}
}
}

/// Converts a protobuf decoding failure into an `invalid_argument` [`Status`]
/// whose message is the `Debug` rendering of the error.
fn decode_error_to_status(err: prost::DecodeError) -> Status {
    let detail = format!("{err:?}");
    Status::invalid_argument(detail)
}
Expand Down
5 changes: 3 additions & 2 deletions arrow-flight/src/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ use std::task::{ready, Poll};
///
/// This can be used to accept a stream of `Result<_>` from a client API and send
/// them to the remote server that wants only the successful results.
pub(crate) struct FallibleRequestStream<T, E> {
pub struct FallibleRequestStream<T, E> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the visibility of this in order to create the test that validates the bug fix.

/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
pub(crate) fn new(
/// Create a FallibleRequestStream
pub fn new(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: visibility change.

sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
) -> Self {
Expand Down
93 changes: 89 additions & 4 deletions arrow-flight/tests/flight_sql_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,23 @@ use crate::common::utils::make_primitive_batch;

use arrow_array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightServiceServer;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
use arrow_flight::sql::{
ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest,
CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, TableExistsOption,
TableNotExistOption,
CommandStatementIngest, EndTransaction, FallibleRequestStream, ProstMessageExt, SqlInfo,
TableDefinitionOptions, TableExistsOption, TableNotExistOption,
};
use arrow_flight::Action;
use arrow_flight::{Action, FlightData, FlightDescriptor};
use futures::{StreamExt, TryStreamExt};
use prost::Message;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tonic::{Request, Status};
use tonic::{IntoStreamingRequest, Request, Status};
use uuid::Uuid;

#[tokio::test]
Expand Down Expand Up @@ -116,6 +118,89 @@ pub async fn test_execute_ingest_error() {
);
}

#[tokio::test]
pub async fn test_do_put_empty_stream() {
    // Regression test for https://github.com/apache/arrow-rs/issues/7329

    let server = FlightSqlServiceImpl::new();
    let fixture = TestFixture::new(server.service()).await;
    let mut client = FlightSqlServiceClient::new(fixture.channel().await);

    // Encode an empty input stream, tagged with an ingest command descriptor.
    let descriptor = FlightDescriptor::new_cmd(make_ingest_command().as_any().encode_to_vec());
    let encoder = FlightDataEncoderBuilder::default()
        .with_flight_descriptor(Some(descriptor))
        .build(futures::stream::iter(vec![]));
    let encoded: Vec<FlightData> = Box::pin(encoder).try_collect().await.unwrap();

    // The `do_put` call must fail with the server's "missing command" message.
    let err = client
        .do_put(futures::stream::iter(encoded))
        .await
        .unwrap_err();
    let message = err.to_string();
    assert!(message.contains("Unhandled Error: Command is missing."));
}

#[tokio::test]
pub async fn test_do_put_first_element_err() {
    // Regression test for https://github.com/apache/arrow-rs/issues/7329

    let server = FlightSqlServiceImpl::new();
    let fixture = TestFixture::new(server.service()).await;
    let mut client = FlightSqlServiceClient::new(fixture.channel().await);

    // The receiving half stays bound (but unread) for the duration of the test.
    let (sender, _receiver) = futures::channel::oneshot::channel();

    // Input whose very first item is a FlightError, followed by a valid batch.
    let batches = futures::stream::iter(vec![
        Err(FlightError::NotYetImplemented("random error".to_string())),
        Ok(make_primitive_batch(5)),
    ]);
    let descriptor = FlightDescriptor::new_cmd(make_ingest_command().as_any().encode_to_vec());
    let encoder = FlightDataEncoderBuilder::default()
        .with_flight_descriptor(Some(descriptor))
        .build(batches);
    let fallible: FallibleRequestStream<FlightData, FlightError> =
        FallibleRequestStream::new(sender, Box::pin(encoder));

    // The `do_put` call must fail with the server's "missing command" message.
    let err = client
        .do_put(fallible.into_streaming_request())
        .await
        .unwrap_err();
    let message = err.to_string();
    assert!(message.contains("Unhandled Error: Command is missing."));
}

#[tokio::test]
pub async fn test_do_put_missing_flight_descriptor() {
    // Regression test for https://github.com/apache/arrow-rs/issues/7329

    let server = FlightSqlServiceImpl::new();
    let fixture = TestFixture::new(server.service()).await;
    let mut client = FlightSqlServiceClient::new(fixture.channel().await);

    // Encode one batch without attaching any flight descriptor.
    let batches = futures::stream::iter(vec![Ok(make_primitive_batch(5))]);
    let encoder = FlightDataEncoderBuilder::default()
        .with_flight_descriptor(None)
        .build(batches);
    let encoded: Vec<FlightData> = Box::pin(encoder).try_collect().await.unwrap();

    // The `do_put` call must fail with the server's "missing descriptor" message.
    let err = client
        .do_put(futures::stream::iter(encoded))
        .await
        .unwrap_err();
    let message = err.to_string();
    assert!(message.contains("Unhandled Error: Flight descriptor is missing."));
}

fn make_ingest_command() -> CommandStatementIngest {
CommandStatementIngest {
table_definition_options: Some(TableDefinitionOptions {
Expand Down
Loading