Skip to content

Type Overrides in Column Names and Bind Arguments #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 21, 2020
61 changes: 36 additions & 25 deletions sqlx-macros/src/query/args.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use proc_macro2::TokenStream;
use syn::spanned::Spanned;
use syn::Expr;
use syn::{Expr, Type};

use quote::{quote, quote_spanned, ToTokens};
use quote::{quote, quote_spanned};
use sqlx_core::describe::Describe;

use crate::database::{DatabaseExt, ParamChecking};
Expand Down Expand Up @@ -34,26 +34,36 @@ pub fn quote_args<DB: DatabaseExt>(
// TODO: We could remove the ParamChecking flag and just filter to only test params that are non-null
let param_ty = param_ty.as_ref().unwrap();

let param_ty = get_type_override(expr)
.or_else(|| {
Some(
DB::param_type_for_id(&param_ty)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&param_ty) {
format!(
"optional feature `{}` required for type {} of param #{}",
feature_gate,
param_ty,
i + 1,
)
} else {
format!("unsupported type {} for param #{}", param_ty, i + 1)
}
})?;
let param_ty = match get_type_override(expr) {
// TODO: enable this in 1.45 when we can strip `as _`
// without stripping these we get some pretty nasty type errors
Some(Type::Infer(_)) => return Err(
syn::Error::new_spanned(
expr,
"casts to `_` are not allowed in bind parameters yet"
).into()
),
// cast or type ascription will fail to compile if the type does not match
Some(_) => return Ok(quote!()),
None => {
DB::param_type_for_id(&param_ty)
.ok_or_else(|| {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&param_ty) {
format!(
"optional feature `{}` required for type {} of param #{}",
feature_gate,
param_ty,
i + 1,
)
} else {
format!("unsupported type {} for param #{}", param_ty, i + 1)
}
})?
.parse::<proc_macro2::TokenStream>()
.map_err(|_| format!("Rust type mapping for {} not parsable", param_ty))?

}
};

Ok(quote_spanned!(expr.span() =>
// this shouldn't actually run
Expand Down Expand Up @@ -97,10 +107,11 @@ pub fn quote_args<DB: DatabaseExt>(
})
}

fn get_type_override(expr: &Expr) -> Option<TokenStream> {
fn get_type_override(expr: &Expr) -> Option<&Type> {
match expr {
Expr::Cast(cast) => Some(cast.ty.to_token_stream()),
Expr::Type(ascription) => Some(ascription.ty.to_token_stream()),
Expr::Group(group) => get_type_override(&group.expr),
Expr::Cast(cast) => Some(&cast.ty),
Expr::Type(ascription) => Some(&ascription.ty),
_ => None,
}
}
9 changes: 9 additions & 0 deletions sqlx-macros/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ where
RecordType::Generated => {
let record_name: Type = syn::parse_str("Record").unwrap();

for rust_col in &columns {
if rust_col.type_.is_none() {
return Err(
"columns may not have wildcard overrides in `query!()` or `query_as!()"
.into(),
);
}
}

let record_fields = columns.iter().map(
|&output::RustColumn {
ref ident,
Expand Down
210 changes: 150 additions & 60 deletions sqlx-macros/src/query/output.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use syn::Type;

use sqlx_core::describe::Describe;
use sqlx_core::describe::{Column, Describe};

use crate::database::DatabaseExt;

use crate::query::QueryMacroInput;
use std::fmt::{self, Display, Formatter};
use syn::parse::{Parse, ParseStream};
use syn::Token;

pub struct RustColumn {
pub(super) ident: Ident,
pub(super) type_: TokenStream,
pub(super) type_: Option<TokenStream>,
}

struct DisplayColumn<'a> {
Expand All @@ -20,6 +22,19 @@ struct DisplayColumn<'a> {
name: &'a str,
}

struct ColumnDecl {
ident: Ident,
// TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved
r#override: Option<ColumnOverride>,
}

enum ColumnOverride {
NonNull,
Nullable,
Wildcard,
Exact(Type),
}

impl Display for DisplayColumn<'_> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "column #{} ({:?})", self.idx + 1, self.name)
Expand All @@ -32,58 +47,34 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
.iter()
.enumerate()
.map(|(i, column)| -> crate::Result<_> {
let name = &*column.name;
let ident = parse_ident(name)?;

let mut type_ = if let Some(type_info) = &column.type_info {
<DB as DatabaseExt>::return_type_for_id(&type_info).map_or_else(
|| {
let message = if let Some(feature_gate) =
<DB as DatabaseExt>::get_feature_gate(&type_info)
{
format!(
"optional feature `{feat}` required for type {ty} of {col}",
ty = &type_info,
feat = feature_gate,
col = DisplayColumn {
idx: i,
name: &*column.name
}
)
} else {
format!(
"unsupported type {ty} of {col}",
ty = type_info,
col = DisplayColumn {
idx: i,
name: &*column.name
}
)
};
syn::Error::new(Span::call_site(), message).to_compile_error()
},
|t| t.parse().unwrap(),
)
} else {
syn::Error::new(
Span::call_site(),
format!(
"database couldn't tell us the type of {col}; \
this can happen for columns that are the result of an expression",
col = DisplayColumn {
idx: i,
name: &*column.name
}
),
)
.to_compile_error()
// add raw prefix to all identifiers
let decl = ColumnDecl::parse(&column.name)
.map_err(|e| format!("column name {:?} is invalid: {}", column.name, e))?;

let type_ = match decl.r#override {
Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()),
Some(ColumnOverride::Wildcard) => None,
// these three could be combined but I prefer the clarity here
Some(ColumnOverride::NonNull) => Some(get_column_type(i, column)),
Some(ColumnOverride::Nullable) => {
let type_ = get_column_type(i, column);
Some(quote! { Option<#type_> })
}
None => {
let type_ = get_column_type(i, column);

if column.not_null.unwrap_or(false) {
Some(type_)
} else {
Some(quote! { Option<#type_> })
}
}
};

if !column.not_null.unwrap_or(false) {
type_ = quote! { Option<#type_> };
}

Ok(RustColumn { ident, type_ })
Ok(RustColumn {
ident: decl.ident,
type_,
})
})
.collect::<crate::Result<Vec<_>>>()
}
Expand All @@ -103,13 +94,15 @@ pub fn quote_query_as<DB: DatabaseExt>(
..
},
)| {
// For "checked" queries, the macro checks these at compile time and using "try_get"
// would also perform pointless runtime checks

if input.checked {
quote!( #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? )
} else {
quote!( #ident: row.try_get_unchecked(#i)? )
match (input.checked, type_) {
// we guarantee the type is valid so we can skip the runtime check
(true, Some(type_)) => quote! {
#ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()?
},
// type was overridden to be a wildcard so we fallback to the runtime check
(true, None) => quote! ( #ident: row.try_get(#i)? ),
// macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
(false, _) => quote!( #ident: row.try_get_unchecked(#i)? ),
}
},
);
Expand All @@ -128,6 +121,103 @@ pub fn quote_query_as<DB: DatabaseExt>(
}
}

fn get_column_type<DB: DatabaseExt>(i: usize, column: &Column<DB>) -> TokenStream {
if let Some(type_info) = &column.type_info {
<DB as DatabaseExt>::return_type_for_id(&type_info).map_or_else(
|| {
let message =
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&type_info) {
format!(
"optional feature `{feat}` required for type {ty} of {col}",
ty = &type_info,
feat = feature_gate,
col = DisplayColumn {
idx: i,
name: &*column.name
}
)
} else {
format!(
"unsupported type {ty} of {col}",
ty = type_info,
col = DisplayColumn {
idx: i,
name: &*column.name
}
)
};
syn::Error::new(Span::call_site(), message).to_compile_error()
},
|t| t.parse().unwrap(),
)
} else {
syn::Error::new(
Span::call_site(),
format!(
"database couldn't tell us the type of {col}; \
this can happen for columns that are the result of an expression",
col = DisplayColumn {
idx: i,
name: &*column.name
}
),
)
.to_compile_error()
}
}

impl ColumnDecl {
fn parse(col_name: &str) -> crate::Result<Self> {
// find the end of the identifier because we want to use our own logic to parse it
// if we tried to feed this into `syn::parse_str()` we might get an un-great error
// for some kinds of invalid identifiers
let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!', '?'][..]) {
let (ident, remainder) = col_name.split_at(i);

(parse_ident(ident)?, remainder)
} else {
(parse_ident(col_name)?, "")
};

Ok(ColumnDecl {
ident,
r#override: if !remainder.is_empty() {
Some(syn::parse_str(remainder)?)
} else {
None
},
})
}
}

impl Parse for ColumnOverride {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();

if lookahead.peek(Token![:]) {
input.parse::<Token![:]>()?;

let ty = Type::parse(input)?;

if let Type::Infer(_) = ty {
Ok(ColumnOverride::Wildcard)
} else {
Ok(ColumnOverride::Exact(ty))
}
} else if lookahead.peek(Token![!]) {
input.parse::<Token![!]>()?;

Ok(ColumnOverride::NonNull)
} else if lookahead.peek(Token![?]) {
input.parse::<Token![?]>()?;

Ok(ColumnOverride::Nullable)
} else {
Err(lookahead.error())
}
}
}

fn parse_ident(name: &str) -> crate::Result<Ident> {
// workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
// https://github.com/dtolnay/syn/issues/749#issuecomment-575451318
Expand Down
Loading