Skip to content

Commit 1a6d681

Browse files
committed
Strict type checking in ToSql
cc #14
1 parent 043d231 commit 1a6d681

File tree

3 files changed

+56
-66
lines changed

3 files changed

+56
-66
lines changed

src/lib.rs

+14-10
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,16 @@ impl PostgresConnection {
263263
264264
let types = [];
265265
self.write_messages([
266-
&Parse {
267-
name: stmt_name,
268-
query: query,
269-
param_types: types
270-
},
271-
&Describe {
272-
variant: 'S' as u8,
273-
name: stmt_name
274-
},
275-
&Sync]);
266+
&Parse {
267+
name: stmt_name,
268+
query: query,
269+
param_types: types
270+
},
271+
&Describe {
272+
variant: 'S' as u8,
273+
name: stmt_name
274+
},
275+
&Sync]);
276276
277277
match_read_message!(self, {
278278
ParseComplete => (),
@@ -429,6 +429,10 @@ impl<'self> Drop for PostgresStatement<'self> {
429429
}
430430
431431
impl<'self> PostgresStatement<'self> {
432+
pub fn num_params(&self) -> uint {
433+
self.param_types.len()
434+
}
435+
432436
fn bind(&self, portal_name: &str, params: &[&ToSql])
433437
-> Option<PostgresDbError> {
434438
let mut formats = ~[];

src/test.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ fn test_query() {
4444
do test_in_transaction |trans| {
4545
trans.update("CREATE TABLE foo (id BIGINT PRIMARY KEY)", []);
4646
trans.update("INSERT INTO foo (id) VALUES ($1), ($2)",
47-
[&1 as &ToSql, &2 as &ToSql]);
47+
[&1i64 as &ToSql, &2i64 as &ToSql]);
4848
let stmt = trans.prepare("SELECT * from foo ORDER BY id");
4949
let result = stmt.query([]);
5050

@@ -118,12 +118,12 @@ fn test_binary_bool_params() {
118118
119119
#[test]
120120
fn test_binary_i16_params() {
121-
test_param_type("SMALLINT", [Some(0x0011), Some(-0x0011), None]);
121+
test_param_type("SMALLINT", [Some(0x0011i16), Some(-0x0011i16), None]);
122122
}
123123
124124
#[test]
125125
fn test_binary_i32_params() {
126-
test_param_type("INT", [Some(0x00112233), Some(-0x00112233), None]);
126+
test_param_type("INT", [Some(0x00112233i32), Some(-0x00112233i32), None]);
127127
}
128128
129129
#[test]
@@ -184,8 +184,19 @@ fn test_wrong_num_params() {
184184
Err(PostgresDbError { code: ~"08P01", _ }) => (),
185185
resp => fail!("Unexpected response: %?", resp)
186186
}
187+
}
188+
}
187189
188-
trans.set_rollback();
190+
#[test]
191+
#[should_fail]
192+
fn test_wrong_param_type() {
193+
do test_in_transaction |trans| {
194+
trans.update("CREATE TABLE foo (
195+
id SERIAL PRIMARY KEY,
196+
val BOOL
197+
)", []);
198+
trans.try_update("INSERT INTO foo (val) VALUES ($1)",
199+
[&1i32 as &ToSql]);
189200
}
190201
}
191202

src/types.rs

+27-52
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ static INT2OID: Oid = 21;
1414
static INT4OID: Oid = 23;
1515
static FLOAT4OID: Oid = 700;
1616
static FLOAT8OID: Oid = 701;
17+
static VARCHAROID: Oid = 1043;
1718

1819
pub enum Format {
1920
Text = 0,
@@ -137,23 +138,24 @@ pub trait ToSql {
137138
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>);
138139
}
139140

140-
macro_rules! to_str_impl(
141-
($t:ty) => (
142-
impl ToSql for $t {
143-
fn to_sql(&self, _ty: Oid) -> (Format, Option<~[u8]>) {
144-
(Text, Some(self.to_str().into_bytes()))
145-
}
141+
macro_rules! check_oid(
142+
($expected:ident, $actual:ident) => (
143+
if $expected != $actual {
144+
fail!("Attempted to bind an invalid type. Expected Oid %? but got \
145+
Oid %?", $expected, $actual);
146146
}
147147
)
148148
)
149149

150150
macro_rules! to_option_impl(
151-
($t:ty) => (
151+
($oid:ident, $t:ty) => (
152152
impl ToSql for Option<$t> {
153153
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
154+
check_oid!($oid, ty)
155+
154156
match *self {
155157
None => (Text, None),
156-
Some(val) => val.to_sql(ty)
158+
Some(ref val) => val.to_sql(ty)
157159
}
158160
}
159161
}
@@ -164,74 +166,47 @@ macro_rules! to_conversions_impl(
164166
($oid:ident, $t:ty, $f:ident) => (
165167
impl ToSql for $t {
166168
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
167-
if ty == $oid {
168-
let mut writer = MemWriter::new();
169-
writer.$f(*self);
170-
(Binary, Some(writer.inner()))
171-
} else {
172-
(Text, Some(self.to_str().into_bytes()))
173-
}
169+
check_oid!($oid, ty)
170+
171+
let mut writer = MemWriter::new();
172+
writer.$f(*self);
173+
(Binary, Some(writer.inner()))
174174
}
175175
}
176176
)
177177
)
178178

179179
impl ToSql for bool {
180180
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
181-
if ty == BOOLOID {
182-
(Binary, Some(~[*self as u8]))
183-
} else {
184-
(Text, Some(self.to_str().into_bytes()))
185-
}
181+
check_oid!(BOOLOID, ty)
182+
(Binary, Some(~[*self as u8]))
186183
}
187184
}
188-
to_option_impl!(bool)
185+
to_option_impl!(BOOLOID, bool)
189186

190187
to_conversions_impl!(INT2OID, i16, write_be_i16_)
191-
to_option_impl!(i16)
188+
to_option_impl!(INT2OID, i16)
192189
to_conversions_impl!(INT4OID, i32, write_be_i32_)
193-
to_option_impl!(i32)
190+
to_option_impl!(INT4OID, i32)
194191
to_conversions_impl!(INT8OID, i64, write_be_i64_)
195-
to_option_impl!(i64)
192+
to_option_impl!(INT8OID, i64)
196193
to_conversions_impl!(FLOAT4OID, f32, write_be_f32_)
197-
to_option_impl!(f32)
194+
to_option_impl!(FLOAT4OID, f32)
198195
to_conversions_impl!(FLOAT8OID, f64, write_be_f64_)
199-
to_option_impl!(f64)
200-
201-
to_str_impl!(int)
202-
to_option_impl!(int)
203-
to_str_impl!(i8)
204-
to_option_impl!(i8)
205-
to_str_impl!(uint)
206-
to_option_impl!(uint)
207-
to_str_impl!(u8)
208-
to_option_impl!(u8)
209-
to_str_impl!(u16)
210-
to_option_impl!(u16)
211-
to_str_impl!(u32)
212-
to_option_impl!(u32)
213-
to_str_impl!(u64)
214-
to_option_impl!(u64)
215-
to_str_impl!(float)
216-
to_option_impl!(float)
196+
to_option_impl!(FLOAT8OID, f64)
217197

218198
impl<'self> ToSql for &'self str {
219-
fn to_sql(&self, _ty: Oid) -> (Format, Option<~[u8]>) {
199+
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
200+
check_oid!(VARCHAROID, ty)
220201
(Text, Some(self.as_bytes().to_owned()))
221202
}
222203
}
223204

224-
impl ToSql for Option<~str> {
225-
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
226-
match *self {
227-
None => (Text, None),
228-
Some(ref val) => val.to_sql(ty)
229-
}
230-
}
231-
}
205+
to_option_impl!(VARCHAROID, ~str)
232206

233207
impl<'self> ToSql for Option<&'self str> {
234208
fn to_sql(&self, ty: Oid) -> (Format, Option<~[u8]>) {
209+
check_oid!(VARCHAROID, ty)
235210
match *self {
236211
None => (Text, None),
237212
Some(val) => val.to_sql(ty)

0 commit comments

Comments
 (0)