@@ -22,21 +22,23 @@ use crate::common::utils::make_primitive_batch;
22
22
23
23
use arrow_array:: RecordBatch ;
24
24
use arrow_flight:: decode:: FlightRecordBatchStream ;
25
+ use arrow_flight:: encode:: FlightDataEncoderBuilder ;
25
26
use arrow_flight:: error:: FlightError ;
26
27
use arrow_flight:: flight_service_server:: FlightServiceServer ;
27
28
use arrow_flight:: sql:: client:: FlightSqlServiceClient ;
28
29
use arrow_flight:: sql:: server:: { FlightSqlService , PeekableFlightDataStream } ;
29
30
use arrow_flight:: sql:: {
30
31
ActionBeginTransactionRequest , ActionBeginTransactionResult , ActionEndTransactionRequest ,
31
- CommandStatementIngest , EndTransaction , SqlInfo , TableDefinitionOptions , TableExistsOption ,
32
- TableNotExistOption ,
32
+ CommandStatementIngest , EndTransaction , FallibleRequestStream , ProstMessageExt , SqlInfo ,
33
+ TableDefinitionOptions , TableExistsOption , TableNotExistOption ,
33
34
} ;
34
- use arrow_flight:: Action ;
35
+ use arrow_flight:: { Action , FlightData , FlightDescriptor } ;
35
36
use futures:: { StreamExt , TryStreamExt } ;
37
+ use prost:: Message ;
36
38
use std:: collections:: HashMap ;
37
39
use std:: sync:: Arc ;
38
40
use tokio:: sync:: Mutex ;
39
- use tonic:: { Request , Status } ;
41
+ use tonic:: { IntoStreamingRequest , Request , Status } ;
40
42
use uuid:: Uuid ;
41
43
42
44
#[ tokio:: test]
@@ -116,6 +118,89 @@ pub async fn test_execute_ingest_error() {
116
118
) ;
117
119
}
118
120
121
+ #[ tokio:: test]
122
+ pub async fn test_do_put_empty_stream ( ) {
123
+ // Test for https://github.com/apache/arrow-rs/issues/7329
124
+
125
+ let test_server = FlightSqlServiceImpl :: new ( ) ;
126
+ let fixture = TestFixture :: new ( test_server. service ( ) ) . await ;
127
+ let channel = fixture. channel ( ) . await ;
128
+ let mut flight_sql_client = FlightSqlServiceClient :: new ( channel) ;
129
+ let cmd = make_ingest_command ( ) ;
130
+
131
+ // Create an empty request stream
132
+ let input_data = futures:: stream:: iter ( vec ! [ ] ) ;
133
+ let flight_descriptor = FlightDescriptor :: new_cmd ( cmd. as_any ( ) . encode_to_vec ( ) ) ;
134
+ let flight_data_encoder = FlightDataEncoderBuilder :: default ( )
135
+ . with_flight_descriptor ( Some ( flight_descriptor) )
136
+ . build ( input_data) ;
137
+ let flight_data: Vec < FlightData > = Box :: pin ( flight_data_encoder) . try_collect ( ) . await . unwrap ( ) ;
138
+ let request_stream = futures:: stream:: iter ( flight_data) ;
139
+
140
+ // Execute a `do_put` and verify that the server error contains the expected message
141
+ let err = flight_sql_client. do_put ( request_stream) . await . unwrap_err ( ) ;
142
+ assert ! ( err
143
+ . to_string( )
144
+ . contains( "Unhandled Error: Command is missing." ) , ) ;
145
+ }
146
+
147
+ #[ tokio:: test]
148
+ pub async fn test_do_put_first_element_err ( ) {
149
+ // Test for https://github.com/apache/arrow-rs/issues/7329
150
+
151
+ let test_server = FlightSqlServiceImpl :: new ( ) ;
152
+ let fixture = TestFixture :: new ( test_server. service ( ) ) . await ;
153
+ let channel = fixture. channel ( ) . await ;
154
+ let mut flight_sql_client = FlightSqlServiceClient :: new ( channel) ;
155
+ let cmd = make_ingest_command ( ) ;
156
+
157
+ let ( sender, _receiver) = futures:: channel:: oneshot:: channel ( ) ;
158
+
159
+ // Create a fallible request stream such that the 1st element is a FlightError
160
+ let input_data = futures:: stream:: iter ( vec ! [
161
+ Err ( FlightError :: NotYetImplemented ( "random error" . to_string( ) ) ) ,
162
+ Ok ( make_primitive_batch( 5 ) ) ,
163
+ ] ) ;
164
+ let flight_descriptor = FlightDescriptor :: new_cmd ( cmd. as_any ( ) . encode_to_vec ( ) ) ;
165
+ let flight_data_encoder = FlightDataEncoderBuilder :: default ( )
166
+ . with_flight_descriptor ( Some ( flight_descriptor) )
167
+ . build ( input_data) ;
168
+ let flight_data: FallibleRequestStream < FlightData , FlightError > =
169
+ FallibleRequestStream :: new ( sender, Box :: pin ( flight_data_encoder) ) ;
170
+ let request_stream = flight_data. into_streaming_request ( ) ;
171
+
172
+ // Execute a `do_put` and verify that the server error contains the expected message
173
+ let err = flight_sql_client. do_put ( request_stream) . await . unwrap_err ( ) ;
174
+
175
+ assert ! ( err
176
+ . to_string( )
177
+ . contains( "Unhandled Error: Command is missing." ) , ) ;
178
+ }
179
+
180
+ #[ tokio:: test]
181
+ pub async fn test_do_put_missing_flight_descriptor ( ) {
182
+ // Test for https://github.com/apache/arrow-rs/issues/7329
183
+
184
+ let test_server = FlightSqlServiceImpl :: new ( ) ;
185
+ let fixture = TestFixture :: new ( test_server. service ( ) ) . await ;
186
+ let channel = fixture. channel ( ) . await ;
187
+ let mut flight_sql_client = FlightSqlServiceClient :: new ( channel) ;
188
+
189
+ // Create a request stream such that the flight descriptor is missing
190
+ let stream = futures:: stream:: iter ( vec ! [ Ok ( make_primitive_batch( 5 ) ) ] ) ;
191
+ let flight_data_encoder = FlightDataEncoderBuilder :: default ( )
192
+ . with_flight_descriptor ( None )
193
+ . build ( stream) ;
194
+ let flight_data: Vec < FlightData > = Box :: pin ( flight_data_encoder) . try_collect ( ) . await . unwrap ( ) ;
195
+ let request_stream = futures:: stream:: iter ( flight_data) ;
196
+
197
+ // Execute a `do_put` and verify that the server error contains the expected message
198
+ let err = flight_sql_client. do_put ( request_stream) . await . unwrap_err ( ) ;
199
+ assert ! ( err
200
+ . to_string( )
201
+ . contains( "Unhandled Error: Flight descriptor is missing." ) , ) ;
202
+ }
203
+
119
204
fn make_ingest_command ( ) -> CommandStatementIngest {
120
205
CommandStatementIngest {
121
206
table_definition_options : Some ( TableDefinitionOptions {
0 commit comments