Skip to content

Commit ef2527f

Browse files
committed
feat(mssql): fix a few bugs and implement Connection::describe
1 parent 559169c commit ef2527f

27 files changed

+424
-61
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* text=auto eol=lf

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,8 @@ required-features = [ "mssql" ]
159159
name = "mssql-types"
160160
path = "tests/mssql/types.rs"
161161
required-features = [ "mssql" ]
162+
163+
[[test]]
164+
name = "mssql-describe"
165+
path = "tests/mssql/describe.rs"
166+
required-features = [ "mssql" ]

sqlx-core/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ default = [ "runtime-async-std" ]
1919
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
2020
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
2121
sqlite = [ "libsqlite3-sys" ]
22-
mssql = [ "uuid", "encoding_rs" ]
22+
mssql = [ "uuid", "encoding_rs", "regex" ]
2323

2424
# types
2525
all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ]
@@ -65,11 +65,13 @@ log = { version = "0.4.8", default-features = false }
6565
md-5 = { version = "0.8.0", default-features = false, optional = true }
6666
memchr = { version = "2.3.3", default-features = false }
6767
num-bigint = { version = "0.2.6", default-features = false, optional = true, features = [ "std" ] }
68+
once_cell = "1.4.0"
6869
percent-encoding = "2.1.0"
6970
parking_lot = "0.10.2"
7071
threadpool = "*"
7172
phf = { version = "0.8.0", features = [ "macros" ] }
7273
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
74+
regex = { version = "1.3.9", optional = true }
7375
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
7476
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
7577
sha-1 = { version = "0.8.2", default-features = false, optional = true }

sqlx-core/src/encode.rs

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,36 +68,49 @@ where
6868
}
6969
}
7070

71-
impl<'q, T: 'q + Encode<'q, DB>, DB: Database> Encode<'q, DB> for Option<T> {
72-
#[inline]
73-
fn produces(&self) -> DB::TypeInfo {
74-
if let Some(v) = self {
75-
v.produces()
76-
} else {
77-
T::type_info()
78-
}
79-
}
71+
#[allow(unused_macros)]
72+
macro_rules! impl_encode_for_option {
73+
($DB:ident) => {
74+
impl<'q, T: 'q + crate::encode::Encode<'q, $DB>> crate::encode::Encode<'q, $DB>
75+
for Option<T>
76+
{
77+
#[inline]
78+
fn produces(&self) -> <$DB as crate::database::Database>::TypeInfo {
79+
if let Some(v) = self {
80+
v.produces()
81+
} else {
82+
T::type_info()
83+
}
84+
}
8085

81-
#[inline]
82-
fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
83-
if let Some(v) = self {
84-
v.encode(buf)
85-
} else {
86-
IsNull::Yes
87-
}
88-
}
86+
#[inline]
87+
fn encode(
88+
self,
89+
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
90+
) -> crate::encode::IsNull {
91+
if let Some(v) = self {
92+
v.encode(buf)
93+
} else {
94+
crate::encode::IsNull::Yes
95+
}
96+
}
8997

90-
#[inline]
91-
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
92-
if let Some(v) = self {
93-
v.encode_by_ref(buf)
94-
} else {
95-
IsNull::Yes
96-
}
97-
}
98+
#[inline]
99+
fn encode_by_ref(
100+
&self,
101+
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
102+
) -> crate::encode::IsNull {
103+
if let Some(v) = self {
104+
v.encode_by_ref(buf)
105+
} else {
106+
crate::encode::IsNull::Yes
107+
}
108+
}
98109

99-
#[inline]
100-
fn size_hint(&self) -> usize {
101-
self.as_ref().map_or(0, Encode::size_hint)
102-
}
110+
#[inline]
111+
fn size_hint(&self) -> usize {
112+
self.as_ref().map_or(0, crate::encode::Encode::size_hint)
113+
}
114+
}
115+
};
103116
}

sqlx-core/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ pub mod connection;
3131
#[macro_use]
3232
pub mod transaction;
3333

34+
#[macro_use]
35+
pub mod encode;
36+
3437
pub mod database;
3538
pub mod decode;
3639
pub mod describe;
37-
pub mod encode;
3840
pub mod executor;
3941
mod ext;
4042
pub mod from_row;
@@ -59,3 +61,7 @@ pub mod sqlite;
5961
#[cfg(feature = "mysql")]
6062
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
6163
pub mod mysql;
64+
65+
#[cfg(feature = "mssql")]
66+
#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
67+
pub mod mssql;

sqlx-core/src/mssql/arguments.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::arguments::Arguments;
22
use crate::encode::Encode;
33
use crate::mssql::database::MsSql;
44
use crate::mssql::io::MsSqlBufMutExt;
5+
use crate::mssql::protocol::rpc::StatusFlags;
56

67
#[derive(Default)]
78
pub struct MsSqlArguments {
@@ -31,6 +32,19 @@ impl MsSqlArguments {
3132
self.add_named("", value);
3233
}
3334

35+
pub(crate) fn declare<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, initial_value: T) {
36+
let ty = initial_value.produces();
37+
38+
let mut ty_name = String::new();
39+
ty.0.fmt(&mut ty_name);
40+
41+
self.data.put_b_varchar(name); // [ParamName]
42+
self.data.push(StatusFlags::BY_REF_VALUE.bits()); // [StatusFlags]
43+
44+
ty.0.put(&mut self.data); // [TYPE_INFO]
45+
ty.0.put_value(&mut self.data, initial_value); // [ParamLenData]
46+
}
47+
3448
pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) {
3549
self.ordinal += arguments.ordinal;
3650
self.data.append(&mut arguments.data);

sqlx-core/src/mssql/connection/establish.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ impl MsSqlConnection {
4949
server_name: "",
5050
client_interface_name: "",
5151
language: "",
52-
// FIXME: connect this to options.database
53-
database: "",
52+
database: &*options.database,
5453
client_id: [0; 6],
5554
},
5655
);

sqlx-core/src/mssql/connection/executor.rs

Lines changed: 109 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@ use either::Either;
33
use futures_core::future::BoxFuture;
44
use futures_core::stream::BoxStream;
55
use futures_util::TryStreamExt;
6+
use once_cell::sync::Lazy;
7+
use regex::Regex;
68

7-
use crate::describe::Describe;
9+
use crate::describe::{Column, Describe};
810
use crate::error::Error;
911
use crate::executor::{Execute, Executor};
10-
use crate::mssql::protocol::done::Done;
12+
use crate::mssql::protocol::col_meta_data::Flags;
13+
use crate::mssql::protocol::done::{Done, Status};
1114
use crate::mssql::protocol::message::Message;
1215
use crate::mssql::protocol::packet::PacketType;
1316
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
1417
use crate::mssql::protocol::sql_batch::SqlBatch;
15-
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
18+
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow, MsSqlTypeInfo};
1619

1720
impl MsSqlConnection {
1821
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
@@ -25,8 +28,10 @@ impl MsSqlConnection {
2528
let message = self.stream.recv_message().await?;
2629

2730
if let Message::DoneProc(done) | Message::Done(done) = message {
28-
// finished RPC procedure *OR* SQL batch
29-
self.handle_done(done);
31+
if !done.status.contains(Status::DONE_MORE) {
32+
// finished RPC procedure *OR* SQL batch
33+
self.handle_done(done);
34+
}
3035
}
3136
}
3237

@@ -106,20 +111,23 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
106111
yield v;
107112
}
108113

109-
Message::DoneProc(done) => {
110-
self.handle_done(done);
111-
break;
112-
}
114+
Message::Done(done) | Message::DoneProc(done) => {
115+
if done.status.contains(Status::DONE_COUNT) {
116+
let v = Either::Left(done.affected_rows);
117+
yield v;
118+
}
113119

114-
Message::DoneInProc(done) => {
115-
// finished SQL query *within* procedure
116-
let v = Either::Left(done.affected_rows);
117-
yield v;
120+
if !done.status.contains(Status::DONE_MORE) {
121+
self.handle_done(done);
122+
break;
123+
}
118124
}
119125

120-
Message::Done(done) => {
121-
self.handle_done(done);
122-
break;
126+
Message::DoneInProc(done) => {
127+
if done.status.contains(Status::DONE_COUNT) {
128+
let v = Either::Left(done.affected_rows);
129+
yield v;
130+
}
123131
}
124132

125133
_ => {}
@@ -157,6 +165,90 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
157165
'c: 'e,
158166
E: Execute<'q, Self::Database>,
159167
{
160-
unimplemented!()
168+
let s = query.query();
169+
170+
// [sp_prepare] will emit the column meta data
171+
// small issue is that we need to declare all the used placeholders with a "fallback" type
172+
// we currently use regex to collect them; false positives are *okay* but false
173+
// negatives would break the query
174+
let proc = Either::Right(Procedure::Prepare);
175+
176+
// NOTE: this does not support unicode identifiers; as we don't even support
177+
// named parameters (yet) this is probably fine, for now
178+
179+
static PARAMS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap());
180+
181+
let mut params = String::new();
182+
let mut num_params = 0;
183+
184+
for m in PARAMS_RE.captures_iter(s) {
185+
if !params.is_empty() {
186+
params.push_str(",");
187+
}
188+
189+
params.push_str(&m[0]);
190+
191+
// NOTE: this means that a query! of `SELECT @p1` will have the macros believe
192+
// it will return nvarchar(1); this is a greater issue with `query!` that we
193+
// we need to circle back to. This doesn't happen much in practice however.
194+
params.push_str(" nvarchar(1)");
195+
196+
num_params += 1;
197+
}
198+
199+
let params = if params.is_empty() {
200+
None
201+
} else {
202+
Some(&*params)
203+
};
204+
205+
let mut args = MsSqlArguments::default();
206+
207+
args.declare("", 0_i32);
208+
args.add_unnamed(params);
209+
args.add_unnamed(s);
210+
args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA
211+
212+
self.stream.write_packet(
213+
PacketType::Rpc,
214+
RpcRequest {
215+
transaction_descriptor: self.stream.transaction_descriptor,
216+
arguments: &args,
217+
procedure: proc,
218+
options: OptionFlags::empty(),
219+
},
220+
);
221+
222+
Box::pin(async move {
223+
self.stream.flush().await?;
224+
225+
loop {
226+
match self.stream.recv_message().await? {
227+
Message::DoneProc(done) | Message::Done(done) => {
228+
if !done.status.contains(Status::DONE_MORE) {
229+
// done with prepare
230+
break;
231+
}
232+
}
233+
234+
_ => {}
235+
}
236+
}
237+
238+
let mut columns = Vec::with_capacity(self.stream.columns.len());
239+
240+
for col in &self.stream.columns {
241+
columns.push(Column {
242+
name: col.col_name.clone(),
243+
type_info: Some(MsSqlTypeInfo(col.type_info.clone())),
244+
not_null: Some(!col.flags.contains(Flags::NULLABLE)),
245+
});
246+
}
247+
248+
Ok(Describe {
249+
params: vec![None; num_params],
250+
columns,
251+
})
252+
})
161253
}
162254
}

sqlx-core/src/mssql/connection/stream.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::mssql::protocol::login_ack::LoginAck;
1414
use crate::mssql::protocol::message::{Message, MessageType};
1515
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
1616
use crate::mssql::protocol::return_status::ReturnStatus;
17+
use crate::mssql::protocol::return_value::ReturnValue;
1718
use crate::mssql::protocol::row::Row;
1819
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
1920
use crate::net::MaybeTlsStream;
@@ -30,7 +31,7 @@ pub(crate) struct MsSqlStream {
3031

3132
// most recent column data from ColMetaData
3233
// we need to store this as its needed when decoding <Row>
33-
columns: Vec<ColumnData>,
34+
pub(crate) columns: Vec<ColumnData>,
3435
}
3536

3637
impl MsSqlStream {
@@ -112,6 +113,7 @@ impl MsSqlStream {
112113
};
113114

114115
let ty = MessageType::get(buf)?;
116+
115117
let message = match ty {
116118
MessageType::EnvChange => {
117119
match EnvChange::get(buf)? {
@@ -137,6 +139,7 @@ impl MsSqlStream {
137139
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
138140
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
139141
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
142+
MessageType::ReturnValue => Message::ReturnValue(ReturnValue::get(buf)?),
140143
MessageType::Done => Message::Done(Done::get(buf)?),
141144
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
142145
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),

0 commit comments

Comments
 (0)