Skip to content

Commit 88d59f8

Browse files
Prevent FlightSQL server panics for do_put when stream is empty or 1st stream element is an Err (#7492)
* Fix bug and add tests to verify
* Remove unnecessary let assignment.
* Add error callback for `do_put`
* cargo fmt

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 1c380b8 commit 88d59f8

File tree

4 files changed

+141
-13
lines changed

4 files changed

+141
-13
lines changed

arrow-flight/src/sql/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ pub mod client;
113113
pub mod metadata;
114114
pub mod server;
115115

116+
pub use crate::streams::FallibleRequestStream;
117+
116118
/// ProstMessageExt are useful utility methods for prost::Message types
117119
pub trait ProstMessageExt: prost::Message + Default {
118120
/// type_url for this Message

arrow-flight/src/sql/server.rs

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,9 @@
1717

1818
//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL.
1919
20+
use std::fmt::{Display, Formatter};
2021
use std::pin::Pin;
2122

22-
use futures::{stream::Peekable, Stream, StreamExt};
23-
use prost::Message;
24-
use tonic::{Request, Response, Status, Streaming};
25-
2623
use super::{
2724
ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
2825
ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
@@ -41,6 +38,9 @@ use crate::{
4138
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
4239
SchemaResult, Ticket,
4340
};
41+
use futures::{stream::Peekable, Stream, StreamExt};
42+
use prost::Message;
43+
use tonic::{Request, Response, Status, Streaming};
4444

4545
pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
4646
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
@@ -386,6 +386,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
386386
)))
387387
}
388388

389+
/// Implementors may override to handle do_put errors
390+
async fn do_put_error_callback(
391+
&self,
392+
_request: Request<PeekableFlightDataStream>,
393+
error: DoPutError,
394+
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
395+
Err(Status::unimplemented(format!("Unhandled Error: {}", error)))
396+
}
397+
389398
/// Execute an update SQL statement.
390399
async fn do_put_statement_update(
391400
&self,
@@ -710,10 +719,21 @@ where
710719
// we wrap this stream in a `Peekable` one, which allows us to peek at
711720
// the first message without discarding it.
712721
let mut request = request.map(PeekableFlightDataStream::new);
713-
let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;
722+
let mut stream = Pin::new(request.get_mut());
723+
724+
let peeked_item = stream.peek().await.cloned();
725+
let Some(cmd) = peeked_item else {
726+
return self
727+
.do_put_error_callback(request, DoPutError::MissingCommand)
728+
.await;
729+
};
714730

715-
let message =
716-
Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?;
731+
let Some(flight_descriptor) = cmd?.flight_descriptor else {
732+
return self
733+
.do_put_error_callback(request, DoPutError::MissingFlightDescriptor)
734+
.await;
735+
};
736+
let message = Any::decode(flight_descriptor.cmd).map_err(decode_error_to_status)?;
717737
match Command::try_from(message).map_err(arrow_error_to_status)? {
718738
Command::CommandStatementUpdate(command) => {
719739
let record_count = self.do_put_statement_update(command, request).await?;
@@ -968,6 +988,26 @@ where
968988
}
969989
}
970990

991+
/// Unrecoverable errors associated with `do_put` requests
992+
pub enum DoPutError {
993+
/// The request stream yielded no first element, so there is no command to dispatch on
994+
MissingCommand,
995+
/// The first element in the request stream is missing the flight descriptor
996+
MissingFlightDescriptor,
997+
}
998+
impl Display for DoPutError {
999+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1000+
match self {
1001+
DoPutError::MissingCommand => {
1002+
write!(f, "Command is missing.")
1003+
}
1004+
DoPutError::MissingFlightDescriptor => {
1005+
write!(f, "Flight descriptor is missing.")
1006+
}
1007+
}
1008+
}
1009+
}
1010+
9711011
fn decode_error_to_status(err: prost::DecodeError) -> Status {
9721012
Status::invalid_argument(format!("{err:?}"))
9731013
}

arrow-flight/src/streams.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,16 @@ use std::task::{ready, Poll};
3232
///
3333
/// This can be used to accept a stream of `Result<_>` from a client API and send
3434
/// them to the remote server that wants only the successful results.
35-
pub(crate) struct FallibleRequestStream<T, E> {
35+
pub struct FallibleRequestStream<T, E> {
3636
/// sender to notify error
3737
sender: Option<Sender<E>>,
3838
/// fallible stream
3939
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
4040
}
4141

4242
impl<T, E> FallibleRequestStream<T, E> {
43-
pub(crate) fn new(
43+
/// Create a FallibleRequestStream
44+
pub fn new(
4445
sender: Sender<E>,
4546
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
4647
) -> Self {

arrow-flight/tests/flight_sql_client.rs

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@ use crate::common::utils::make_primitive_batch;
2222

2323
use arrow_array::RecordBatch;
2424
use arrow_flight::decode::FlightRecordBatchStream;
25+
use arrow_flight::encode::FlightDataEncoderBuilder;
2526
use arrow_flight::error::FlightError;
2627
use arrow_flight::flight_service_server::FlightServiceServer;
2728
use arrow_flight::sql::client::FlightSqlServiceClient;
2829
use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
2930
use arrow_flight::sql::{
3031
ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest,
31-
CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, TableExistsOption,
32-
TableNotExistOption,
32+
CommandStatementIngest, EndTransaction, FallibleRequestStream, ProstMessageExt, SqlInfo,
33+
TableDefinitionOptions, TableExistsOption, TableNotExistOption,
3334
};
34-
use arrow_flight::Action;
35+
use arrow_flight::{Action, FlightData, FlightDescriptor};
3536
use futures::{StreamExt, TryStreamExt};
37+
use prost::Message;
3638
use std::collections::HashMap;
3739
use std::sync::Arc;
3840
use tokio::sync::Mutex;
39-
use tonic::{Request, Status};
41+
use tonic::{IntoStreamingRequest, Request, Status};
4042
use uuid::Uuid;
4143

4244
#[tokio::test]
@@ -116,6 +118,89 @@ pub async fn test_execute_ingest_error() {
116118
);
117119
}
118120

121+
#[tokio::test]
122+
pub async fn test_do_put_empty_stream() {
123+
// Test for https://github.com/apache/arrow-rs/issues/7329
124+
125+
let test_server = FlightSqlServiceImpl::new();
126+
let fixture = TestFixture::new(test_server.service()).await;
127+
let channel = fixture.channel().await;
128+
let mut flight_sql_client = FlightSqlServiceClient::new(channel);
129+
let cmd = make_ingest_command();
130+
131+
// Create an empty request stream
132+
let input_data = futures::stream::iter(vec![]);
133+
let flight_descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
134+
let flight_data_encoder = FlightDataEncoderBuilder::default()
135+
.with_flight_descriptor(Some(flight_descriptor))
136+
.build(input_data);
137+
let flight_data: Vec<FlightData> = Box::pin(flight_data_encoder).try_collect().await.unwrap();
138+
let request_stream = futures::stream::iter(flight_data);
139+
140+
// Execute a `do_put` and verify that the server error contains the expected message
141+
let err = flight_sql_client.do_put(request_stream).await.unwrap_err();
142+
assert!(err
143+
.to_string()
144+
.contains("Unhandled Error: Command is missing."),);
145+
}
146+
147+
#[tokio::test]
148+
pub async fn test_do_put_first_element_err() {
149+
// Test for https://github.com/apache/arrow-rs/issues/7329
150+
151+
let test_server = FlightSqlServiceImpl::new();
152+
let fixture = TestFixture::new(test_server.service()).await;
153+
let channel = fixture.channel().await;
154+
let mut flight_sql_client = FlightSqlServiceClient::new(channel);
155+
let cmd = make_ingest_command();
156+
157+
let (sender, _receiver) = futures::channel::oneshot::channel();
158+
159+
// Create a fallible request stream such that the 1st element is a FlightError
160+
let input_data = futures::stream::iter(vec![
161+
Err(FlightError::NotYetImplemented("random error".to_string())),
162+
Ok(make_primitive_batch(5)),
163+
]);
164+
let flight_descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
165+
let flight_data_encoder = FlightDataEncoderBuilder::default()
166+
.with_flight_descriptor(Some(flight_descriptor))
167+
.build(input_data);
168+
let flight_data: FallibleRequestStream<FlightData, FlightError> =
169+
FallibleRequestStream::new(sender, Box::pin(flight_data_encoder));
170+
let request_stream = flight_data.into_streaming_request();
171+
172+
// Execute a `do_put` and verify that the server error contains the expected message
173+
let err = flight_sql_client.do_put(request_stream).await.unwrap_err();
174+
175+
assert!(err
176+
.to_string()
177+
.contains("Unhandled Error: Command is missing."),);
178+
}
179+
180+
#[tokio::test]
181+
pub async fn test_do_put_missing_flight_descriptor() {
182+
// Test for https://github.com/apache/arrow-rs/issues/7329
183+
184+
let test_server = FlightSqlServiceImpl::new();
185+
let fixture = TestFixture::new(test_server.service()).await;
186+
let channel = fixture.channel().await;
187+
let mut flight_sql_client = FlightSqlServiceClient::new(channel);
188+
189+
// Create a request stream such that the flight descriptor is missing
190+
let stream = futures::stream::iter(vec![Ok(make_primitive_batch(5))]);
191+
let flight_data_encoder = FlightDataEncoderBuilder::default()
192+
.with_flight_descriptor(None)
193+
.build(stream);
194+
let flight_data: Vec<FlightData> = Box::pin(flight_data_encoder).try_collect().await.unwrap();
195+
let request_stream = futures::stream::iter(flight_data);
196+
197+
// Execute a `do_put` and verify that the server error contains the expected message
198+
let err = flight_sql_client.do_put(request_stream).await.unwrap_err();
199+
assert!(err
200+
.to_string()
201+
.contains("Unhandled Error: Flight descriptor is missing."),);
202+
}
203+
119204
fn make_ingest_command() -> CommandStatementIngest {
120205
CommandStatementIngest {
121206
table_definition_options: Some(TableDefinitionOptions {

0 commit comments

Comments
 (0)