diff --git a/docs/reference.md b/docs/reference.md index 23058509..810dbfae 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -6552,11 +6552,13 @@ Buyer.select ## Schema -Additional tests to ensure schema mapping produces valid SQL + + If your table belongs to a schema other than the default schema of your database, you can specify this in your table definition with + `override def schemaName = "otherschema"` + ### Schema.schema.select -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.select @@ -6584,8 +6586,7 @@ Invoice.select ### Schema.schema.insert.columns -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.insert.columns( @@ -6611,8 +6612,7 @@ Invoice.insert.columns( ### Schema.schema.insert.values -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.insert @@ -6643,8 +6643,7 @@ Invoice.insert ### Schema.schema.update -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice @@ -6677,8 +6676,7 @@ Invoice ### Schema.schema.delete -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.delete(_.id === 1) @@ -6701,8 +6699,7 @@ Invoice.delete(_.id === 1) ### Schema.schema.insert into -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.insert.select( @@ -6734,8 +6731,7 @@ Invoice.insert.select( ### Schema.schema.join -If your table belongs to a schema other than the default schema of your database, -you can specify this in your table definition with table.schemaName + ```scala Invoice.select.join(Invoice)(_.id `=` _.id).map(_._1.id) @@ -6760,6 +6756,139 @@ Invoice.select.join(Invoice)(_.id `=` _.id).map(_._1.id) +## EscapedTableName + + If your table name is a reserved sql world, e.g. `order`, you can specify this in your table definition with + `override def escape = true` + +### EscapedTableName.escape table name.select + + + +```scala +Select.select +``` + + +* + ```sql + SELECT select0.id AS id, select0.name AS name + FROM "select" select0 + ``` + + + +* + ```scala + Seq.empty[Select[Sc]] + ``` + + + +### EscapedTableName.escape table name.delete + + + +```scala +Select.delete(_ => true) +``` + + +* + ```sql + DELETE FROM "select" WHERE ? + ``` + + + +* + ```scala + 0 + ``` + + + +### EscapedTableName.escape table name.join + + + +```scala +Select.select.join(Select)(_.id `=` _.id) +``` + + +* + ```sql + SELECT + select0.id AS res_0_id, + select0.name AS res_0_name, + select1.id AS res_1_id, + select1.name AS res_1_name + FROM + "select" select0 + JOIN "select" select1 ON (select0.id = select1.id) + ``` + + + +* + ```scala + Seq.empty[(Select[Sc], Select[Sc])] + ``` + + + +### EscapedTableName.escape table name.update + + + +```scala +Select.update(_ => true).set(_.name := "hello") +``` + + +* + ```sql + UPDATE "select" SET name = ? + ``` + + + +* + ```scala + 0 + ``` + + + +### EscapedTableName.escape table name.insert + + + +```scala +Select.insert.values( + Select[Sc]( + id = 0, + name = "hello" + ) +) +``` + + +* + ```sql + INSERT INTO "select" (id, name) VALUES (?, ?) + ``` + + + +* + ```scala + 1 + ``` + + + ## SubQuery Queries that explicitly use subqueries (e.g. for `JOIN`s) or require subqueries to preserve the Scala semantics of the various operators ### SubQuery.sortTakeJoin diff --git a/scalasql/core/src/Context.scala b/scalasql/core/src/Context.scala index 5df852f5..9af889cb 100644 --- a/scalasql/core/src/Context.scala +++ b/scalasql/core/src/Context.scala @@ -24,6 +24,8 @@ trait Context { */ def config: Config + def dialectConfig: DialectConfig + def withFromNaming(fromNaming: Map[Context.From, String]): Context def withExprNaming(exprNaming: Map[Expr.Identity, SqlStr]): Context } @@ -56,7 +58,8 @@ object Context { case class Impl( fromNaming: Map[From, String], exprNaming: Map[Expr.Identity, SqlStr], - config: Config + config: Config, + dialectConfig: DialectConfig ) extends Context { def withFromNaming(fromNaming: Map[From, String]): Context = copy(fromNaming = fromNaming) @@ -93,7 +96,7 @@ object Context { .map { case (e, s) => (e, sql"${SqlStr.raw(newFromNaming(t), Array(e))}.$s") } } - Context.Impl(newFromNaming, newExprNaming, prevContext.config) + Context.Impl(newFromNaming, newExprNaming, prevContext.config, prevContext.dialectConfig) } } diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index be9aa105..5db3c80a 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -123,17 +123,22 @@ trait DbApi extends AutoCloseable { object DbApi { - def unpackQueryable[R, Q](query: Q, qr: Queryable[Q, R], config: Config) = { - val ctx = Context.Impl(Map(), Map(), config) + def unpackQueryable[R, Q]( + query: Q, + qr: Queryable[Q, R], + config: Config, + dialectConfig: DialectConfig + ) = { + val ctx = Context.Impl(Map(), Map(), config, dialectConfig) val flattened = SqlStr.flatten(qr.renderSql(query, ctx)) flattened } - def renderSql[Q, R](query: Q, config: Config, castParams: Boolean = false)( + def renderSql[Q, R](query: Q, config: Config, dialectConfig: DialectConfig)( implicit qr: Queryable[Q, R] ): String = { - val flattened = unpackQueryable(query, qr, config) - flattened.renderSql(castParams) + val flattened = unpackQueryable(query, qr, config, dialectConfig) + flattened.renderSql(dialectConfig.castParams) } /** @@ -196,7 +201,7 @@ object DbApi { lineNum: sourcecode.Line ): R = { - val flattened = unpackQueryable(query, qr, config) + val flattened = unpackQueryable(query, qr, config, dialect) if (qr.isGetGeneratedKeys(query).nonEmpty) updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum) .asInstanceOf[R] @@ -225,7 +230,7 @@ object DbApi { fileName: sourcecode.FileName, lineNum: sourcecode.Line ): Generator[R] = { - val flattened = unpackQueryable(query, qr, config) + val flattened = unpackQueryable(query, qr, config, dialect) streamFlattened0( r => { qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match { @@ -276,7 +281,7 @@ object DbApi { ): Int = { val flattened = SqlStr.flatten(sql) runRawUpdate0( - flattened.renderSql(DialectConfig.castParams(dialect)), + flattened.renderSql(dialect.castParams), flattenParamPuts(flattened), fetchSize, queryTimeoutSeconds, @@ -296,7 +301,7 @@ object DbApi { ): IndexedSeq[R] = { val flattened = SqlStr.flatten(sql) runRawUpdateGetGeneratedKeys0( - flattened.renderSql(DialectConfig.castParams(dialect)), + flattened.renderSql(dialect.castParams), flattenParamPuts(flattened), fetchSize, queryTimeoutSeconds, @@ -382,7 +387,7 @@ object DbApi { lineNum: sourcecode.Line ) = streamRaw0( construct, - flattened.renderSql(DialectConfig.castParams(dialect)), + flattened.renderSql(dialect.castParams), flattenParamPuts(flattened), fetchSize, queryTimeoutSeconds, @@ -508,7 +513,7 @@ object DbApi { def renderSql[Q, R](query: Q, castParams: Boolean = false)( implicit qr: Queryable[Q, R] ): String = { - DbApi.renderSql(query, config, castParams) + DbApi.renderSql(query, config, dialect.withCastParams(castParams)) } val savepointStack = collection.mutable.ArrayDeque.empty[java.sql.Savepoint] diff --git a/scalasql/core/src/DbClient.scala b/scalasql/core/src/DbClient.scala index 2524479d..9e55694e 100644 --- a/scalasql/core/src/DbClient.scala +++ b/scalasql/core/src/DbClient.scala @@ -44,7 +44,7 @@ object DbClient { def renderSql[Q, R](query: Q, castParams: Boolean = false)( implicit qr: Queryable[Q, R] ): String = { - DbApi.renderSql(query, config, castParams) + DbApi.renderSql(query, config, dialect.withCastParams(castParams)) } def transaction[T](block: DbApi.Txn => T): T = { @@ -74,7 +74,7 @@ object DbClient { def renderSql[Q, R](query: Q, castParams: Boolean = false)( implicit qr: Queryable[Q, R] ): String = { - DbApi.renderSql(query, config, castParams) + DbApi.renderSql(query, config, dialect.withCastParams(castParams)) } private def withConnection[T](f: DbClient.Connection => T): T = { diff --git a/scalasql/core/src/DialectConfig.scala b/scalasql/core/src/DialectConfig.scala index e6964a6a..4c4f9930 100644 --- a/scalasql/core/src/DialectConfig.scala +++ b/scalasql/core/src/DialectConfig.scala @@ -1,9 +1,13 @@ package scalasql.core -trait DialectConfig { - protected def dialectCastParams: Boolean -} +trait DialectConfig { that => + def castParams: Boolean + def escape(str: String): String + + def withCastParams(params: Boolean) = new DialectConfig { + def castParams: Boolean = params + + def escape(str: String): String = that.escape(str) -object DialectConfig { - def castParams(d: DialectConfig) = d.dialectCastParams + } } diff --git a/scalasql/query/src/Delete.scala b/scalasql/query/src/Delete.scala index a118deb0..053981cf 100644 --- a/scalasql/query/src/Delete.scala +++ b/scalasql/query/src/Delete.scala @@ -24,7 +24,7 @@ object Delete { class Renderer(table: TableRef, expr: Expr[Boolean], prevContext: Context) { implicit val implicitCtx: Context = Context.compute(prevContext, Nil, Some(table)) lazy val tableNameStr = - SqlStr.raw(Table.resolve(table.value)) + SqlStr.raw(Table.fullIdentifier(table.value)) def render() = sql"DELETE FROM $tableNameStr WHERE $expr" } diff --git a/scalasql/query/src/From.scala b/scalasql/query/src/From.scala index 7ed8e82a..67f4cb17 100644 --- a/scalasql/query/src/From.scala +++ b/scalasql/query/src/From.scala @@ -15,7 +15,7 @@ class TableRef(val value: Table.Base) extends From { def fromExprAliases(prevContext: Context): Seq[(Expr.Identity, SqlStr)] = Nil def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveExprs) = { - val resolvedTable = Table.resolve(value)(prevContext) + val resolvedTable = Table.fullIdentifier(value)(prevContext) SqlStr.raw(resolvedTable + sql" " + name) } } diff --git a/scalasql/query/src/InsertColumns.scala b/scalasql/query/src/InsertColumns.scala index d166bda9..0a47f8a1 100644 --- a/scalasql/query/src/InsertColumns.scala +++ b/scalasql/query/src/InsertColumns.scala @@ -24,7 +24,7 @@ object InsertColumns { protected def expr: V[Column] = WithSqlExpr.get(insert) private[scalasql] override def renderSql(ctx: Context) = - new Renderer(columns, ctx, valuesLists, Table.resolve(table.value)(ctx)).render() + new Renderer(columns, ctx, valuesLists, Table.fullIdentifier(table.value)(ctx)).render() override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = args.get(IntType) diff --git a/scalasql/query/src/InsertSelect.scala b/scalasql/query/src/InsertSelect.scala index 7c6bb1c1..e59e2a4b 100644 --- a/scalasql/query/src/InsertSelect.scala +++ b/scalasql/query/src/InsertSelect.scala @@ -20,7 +20,12 @@ object InsertSelect { def table = insert.table private[scalasql] override def renderSql(ctx: Context) = - new Renderer(select, select.qr.walkExprs(columns), ctx, Table.resolve(table.value)(ctx)) + new Renderer( + select, + select.qr.walkExprs(columns), + ctx, + Table.fullIdentifier(table.value)(ctx) + ) .render() override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = diff --git a/scalasql/query/src/InsertValues.scala b/scalasql/query/src/InsertValues.scala index 686160fd..5ec85ed6 100644 --- a/scalasql/query/src/InsertValues.scala +++ b/scalasql/query/src/InsertValues.scala @@ -24,7 +24,7 @@ object InsertValues { override private[scalasql] def renderSql(ctx: Context): SqlStr = { new Renderer( - Table.resolve(insert.table.value)(ctx), + Table.fullIdentifier(insert.table.value)(ctx), Table.labels(insert.table.value), values, qr, diff --git a/scalasql/query/src/Table.scala b/scalasql/query/src/Table.scala index 6449d49e..65e1cda1 100644 --- a/scalasql/query/src/Table.scala +++ b/scalasql/query/src/Table.scala @@ -14,6 +14,8 @@ abstract class Table[V[_[_]]]()(implicit name: sourcecode.Name, metadata0: Table protected[scalasql] def schemaName = "" + protected[scalasql] def escape: Boolean = false + protected implicit def tableSelf: Table[V] = this protected def tableMetadata: Table.Metadata[V] = metadata0 @@ -50,11 +52,21 @@ object Table { def name(t: Table.Base) = t.tableName def labels(t: Table.Base) = t.tableLabels def columnNameOverride[V[_[_]]](t: Table.Base)(s: String) = t.tableColumnNameOverride(s) - def resolve(t: Table.Base)(implicit context: Context) = { - val mappedTableName = context.config.tableNameMapper(t.tableName) + def identifier(t: Table.Base)(implicit context: Context): String = { + context.config.tableNameMapper.andThen { str => + if (t.escape) { + context.dialectConfig.escape(str) + } else { + str + } + }(t.tableName) + } + def fullIdentifier( + t: Table.Base + )(implicit context: Context): String = { t.schemaName match { - case "" => mappedTableName - case str => s"$str." + mappedTableName + case "" => identifier(t) + case str => s"$str." + identifier(t) } } trait Base { @@ -66,6 +78,7 @@ object Table { protected[scalasql] def tableName: String protected[scalasql] def schemaName: String protected[scalasql] def tableLabels: Seq[String] + protected[scalasql] def escape: Boolean /** * Customizations to the column names of this table before processing, diff --git a/scalasql/query/src/Update.scala b/scalasql/query/src/Update.scala index eb088a95..9d187567 100644 --- a/scalasql/query/src/Update.scala +++ b/scalasql/query/src/Update.scala @@ -94,7 +94,7 @@ object Update { implicit lazy val implicitCtx: Context = Context.compute(prevContext, froms, Some(table)) lazy val tableName = - SqlStr.raw(Table.resolve(table.value)) + SqlStr.raw(Table.fullIdentifier(table.value)) lazy val updateList = set0.map { case assign => val kStr = SqlStr.raw(prevContext.config.columnNameMapper(assign.column.name)) diff --git a/scalasql/src/dialects/H2Dialect.scala b/scalasql/src/dialects/H2Dialect.scala index 139f3e5b..01441a9b 100644 --- a/scalasql/src/dialects/H2Dialect.scala +++ b/scalasql/src/dialects/H2Dialect.scala @@ -27,7 +27,9 @@ import java.sql.PreparedStatement trait H2Dialect extends Dialect { - protected def dialectCastParams = true + def castParams = true + + def escape(str: String) = s"\"${str.toUpperCase()}\"" override implicit def EnumType[T <: Enumeration#Value]( implicit constructor: String => T diff --git a/scalasql/src/dialects/MySqlDialect.scala b/scalasql/src/dialects/MySqlDialect.scala index 331345aa..0c0a2f37 100644 --- a/scalasql/src/dialects/MySqlDialect.scala +++ b/scalasql/src/dialects/MySqlDialect.scala @@ -43,7 +43,9 @@ import scala.reflect.ClassTag import scalasql.query.Select trait MySqlDialect extends Dialect { - protected def dialectCastParams = false + def castParams = false + + def escape(str: String) = s"`$str`" override implicit def ByteType: TypeMapper[Byte] = new MySqlByteType class MySqlByteType extends ByteType { override def castTypeString = "SIGNED" } diff --git a/scalasql/src/dialects/PostgresDialect.scala b/scalasql/src/dialects/PostgresDialect.scala index 6c98ca04..5df74c9c 100644 --- a/scalasql/src/dialects/PostgresDialect.scala +++ b/scalasql/src/dialects/PostgresDialect.scala @@ -17,7 +17,9 @@ import scalasql.operations.{ConcatOps, HyperbolicMathOps, MathOps, PadOps, TrimO trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps { - protected def dialectCastParams = false + def castParams = false + + def escape(str: String) = s"\"$str\"" override implicit def ByteType: TypeMapper[Byte] = new PostgresByteType class PostgresByteType extends ByteType { override def castTypeString = "INTEGER" } diff --git a/scalasql/src/dialects/SqliteDialect.scala b/scalasql/src/dialects/SqliteDialect.scala index aca56043..5673df1f 100644 --- a/scalasql/src/dialects/SqliteDialect.scala +++ b/scalasql/src/dialects/SqliteDialect.scala @@ -18,7 +18,9 @@ import scalasql.operations.TrimOps import java.time.{Instant, LocalDate, LocalDateTime} trait SqliteDialect extends Dialect with ReturningDialect with OnConflictOps { - protected def dialectCastParams = false + def castParams = false + + def escape(str: String) = s"\"$str\"" override implicit def LocalDateTimeType: TypeMapper[LocalDateTime] = new SqliteLocalDateTimeType class SqliteLocalDateTimeType extends LocalDateTimeType { diff --git a/scalasql/test/resources/h2-customer-schema.sql b/scalasql/test/resources/h2-customer-schema.sql index f254d3a8..9dbcb092 100644 --- a/scalasql/test/resources/h2-customer-schema.sql +++ b/scalasql/test/resources/h2-customer-schema.sql @@ -11,6 +11,7 @@ DROP TABLE IF EXISTS nested CASCADE; DROP TABLE IF EXISTS enclosing CASCADE; DROP TABLE IF EXISTS invoice CASCADE; DROP SCHEMA IF EXISTS otherschema CASCADE; +DROP TABLE IF EXISTS "SELECT" CASCADE; CREATE TABLE buyer ( id INTEGER AUTO_INCREMENT PRIMARY KEY, @@ -98,4 +99,9 @@ CREATE TABLE otherschema.invoice( id INTEGER AUTO_INCREMENT PRIMARY KEY, total DECIMAL(20, 2), vendor_name VARCHAR(256) -); \ No newline at end of file +); + +CREATE TABLE "SELECT"( + id INTEGER, + name VARCHAR(256) +) \ No newline at end of file diff --git a/scalasql/test/resources/mysql-customer-schema.sql b/scalasql/test/resources/mysql-customer-schema.sql index 61d1c149..7be200b7 100644 --- a/scalasql/test/resources/mysql-customer-schema.sql +++ b/scalasql/test/resources/mysql-customer-schema.sql @@ -10,6 +10,8 @@ DROP TABLE IF EXISTS `non_round_trip_types` CASCADE; DROP TABLE IF EXISTS `opt_cols` CASCADE; DROP TABLE IF EXISTS `nested` CASCADE; DROP TABLE IF EXISTS `enclosing` CASCADE; +DROP TABLE IF EXISTS `select` CASCADE; + SET FOREIGN_KEY_CHECKS = 1; CREATE TABLE buyer ( @@ -90,3 +92,8 @@ CREATE TABLE enclosing( foo_id INTEGER, my_boolean BOOLEAN ); + +CREATE TABLE `select`( + id INTEGER, + name VARCHAR(256) +); diff --git a/scalasql/test/resources/postgres-customer-schema.sql b/scalasql/test/resources/postgres-customer-schema.sql index 06ad4a8e..5eedac3f 100644 --- a/scalasql/test/resources/postgres-customer-schema.sql +++ b/scalasql/test/resources/postgres-customer-schema.sql @@ -12,6 +12,7 @@ DROP TABLE IF EXISTS enclosing CASCADE; DROP TABLE IF EXISTS invoice CASCADE; DROP TYPE IF EXISTS my_enum CASCADE; DROP SCHEMA IF EXISTS otherschema CASCADE; +DROP TABLE IF EXISTS "select" CASCADE; CREATE TABLE buyer ( id SERIAL PRIMARY KEY, @@ -103,3 +104,8 @@ CREATE TABLE otherschema.invoice( total DECIMAL(20, 2), vendor_name VARCHAR(256) ); + +CREATE TABLE "select"( + id INTEGER, + name VARCHAR(256) +); diff --git a/scalasql/test/resources/sqlite-customer-schema.sql b/scalasql/test/resources/sqlite-customer-schema.sql index 6cd2c297..9d11cb94 100644 --- a/scalasql/test/resources/sqlite-customer-schema.sql +++ b/scalasql/test/resources/sqlite-customer-schema.sql @@ -9,6 +9,7 @@ DROP TABLE IF EXISTS non_round_trip_types; DROP TABLE IF EXISTS nested; DROP TABLE IF EXISTS enclosing; DROP TABLE IF EXISTS opt_cols; +DROP TABLE IF EXISTS "select"; CREATE TABLE buyer ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -91,3 +92,8 @@ CREATE TABLE enclosing( foo_id INTEGER, my_boolean BOOLEAN ); + +CREATE TABLE "select"( + id INTEGER, + name VARCHAR(256) +) diff --git a/scalasql/test/src/ConcreteTestSuites.scala b/scalasql/test/src/ConcreteTestSuites.scala index 8b906f39..81099425 100644 --- a/scalasql/test/src/ConcreteTestSuites.scala +++ b/scalasql/test/src/ConcreteTestSuites.scala @@ -29,7 +29,8 @@ import query.{ WindowFunctionTests, GetGeneratedKeysTests, WithCteTests, - SchemaTests + SchemaTests, + EscapedTableNameTests } import scalasql.dialects.{ MySqlDialectTests, @@ -61,6 +62,7 @@ package postgres { object WindowFunctionTests extends WindowFunctionTests with PostgresSuite object GetGeneratedKeysTests extends GetGeneratedKeysTests with PostgresSuite object SchemaTests extends SchemaTests with PostgresSuite + object EscapedTableNameTests extends EscapedTableNameTests with PostgresSuite object SubQueryTests extends SubQueryTests with PostgresSuite object WithCteTests extends WithCteTests with PostgresSuite @@ -106,6 +108,7 @@ package hikari { object WindowFunctionTests extends WindowFunctionTests with HikariSuite object GetGeneratedKeysTests extends GetGeneratedKeysTests with HikariSuite object SchemaTests extends SchemaTests with HikariSuite + object EscapedTableNameTests extends EscapedTableNameTests with HikariSuite object SubQueryTests extends SubQueryTests with HikariSuite object WithCteTests extends WithCteTests with HikariSuite @@ -168,6 +171,7 @@ package mysql { object ExprMathOpsTests extends ExprMathOpsTests with MySqlSuite // In MySql, schemas are databases and this requires special treatment not yet implemented here // object SchemaTests extends SchemaTests with MySqlSuite + object EscapedTableNameTests extends EscapedTableNameTests with MySqlSuite object DataTypesTests extends datatypes.DataTypesTests with MySqlSuite object OptionalTests extends datatypes.OptionalTests with MySqlSuite @@ -215,6 +219,7 @@ package sqlite { // object ExprMathOpsTests extends ExprMathOpsTests with SqliteSuite // Sqlite doesn't support schemas // object SchemaTests extends SchemaTests with SqliteSuite + object EscapedTableNameTests extends EscapedTableNameTests with SqliteSuite object DataTypesTests extends datatypes.DataTypesTests with SqliteSuite object OptionalTests extends datatypes.OptionalTests with SqliteSuite @@ -248,6 +253,7 @@ package h2 { object WindowFunctionTests extends WindowFunctionTests with H2Suite object GetGeneratedKeysTests extends GetGeneratedKeysTests with H2Suite object SchemaTests extends SchemaTests with H2Suite + object EscapedTableNameTests extends EscapedTableNameTests with H2Suite object SubQueryTests extends SubQueryTests with H2Suite object WithCteTests extends WithCteTests with H2Suite diff --git a/scalasql/test/src/UnitTestData.scala b/scalasql/test/src/UnitTestData.scala index 6e3b5ed4..aca16d6f 100644 --- a/scalasql/test/src/UnitTestData.scala +++ b/scalasql/test/src/UnitTestData.scala @@ -14,6 +14,11 @@ object Invoice extends Table[Invoice] { override def schemaName = "otherschema" } +case class Select[T[_]](id: T[Int], name: T[String]) +object Select extends Table[Select] { + override def escape = true +} + case class ShippingInfo[T[_]](id: T[Int], buyerId: T[Int], shippingDate: T[LocalDate]) object ShippingInfo extends Table[ShippingInfo] diff --git a/scalasql/test/src/query/EscapedTableNameTests.scala b/scalasql/test/src/query/EscapedTableNameTests.scala new file mode 100644 index 00000000..4f6da0c5 --- /dev/null +++ b/scalasql/test/src/query/EscapedTableNameTests.scala @@ -0,0 +1,93 @@ +package scalasql.query + +import scalasql._ +import scalasql.core.JoinNullable +import sourcecode.Text +import utest._ +import utils.ScalaSqlSuite + +import java.time.LocalDate +import scalasql.core.Config + +trait EscapedTableNameTests extends ScalaSqlSuite { + def description = """ + If your table name is a reserved sql world, e.g. `order`, you can specify this in your table definition with + `override def escape = true` + """ + + def tests = Tests { + test("escape table name") { + val tableNameEscaped = dialectSelf.escape(Config.camelToSnake(Table.name(Select))) + test("select") { + checker( + query = Text { + Select.select + }, + sql = s""" + SELECT select0.id AS id, select0.name AS name + FROM $tableNameEscaped select0 + """, + value = Seq.empty[Select[Sc]], + docs = "" + ) + } + test("delete") { + checker( + query = Text { + Select.delete(_ => true) + }, + sql = s"DELETE FROM $tableNameEscaped WHERE ?", + value = 0, + docs = "" + ) + } + test("join") { + checker( + query = Text { + Select.select.join(Select)(_.id `=` _.id) + }, + sql = s""" + SELECT + select0.id AS res_0_id, + select0.name AS res_0_name, + select1.id AS res_1_id, + select1.name AS res_1_name + FROM + $tableNameEscaped select0 + JOIN $tableNameEscaped select1 ON (select0.id = select1.id) + """, + value = Seq.empty[(Select[Sc], Select[Sc])], + docs = "" + ) + } + test("update") { + checker( + query = Text { + Select.update(_ => true).set(_.name := "hello") + }, + sqls = Seq( + s"UPDATE $tableNameEscaped SET $tableNameEscaped.name = ?", + s"UPDATE $tableNameEscaped SET name = ?" + ), + value = 0, + docs = "" + ) + } + test("insert") { + checker( + query = Text { + Select.insert.values( + Select[Sc]( + id = 0, + name = "hello" + ) + ) + }, + sql = s"INSERT INTO $tableNameEscaped (id, name) VALUES (?, ?)", + value = 1, + docs = "" + ) + } + } + } +} diff --git a/scalasql/test/src/query/SchemaTests.scala b/scalasql/test/src/query/SchemaTests.scala index 7f7f910f..60bc4623 100644 --- a/scalasql/test/src/query/SchemaTests.scala +++ b/scalasql/test/src/query/SchemaTests.scala @@ -9,7 +9,10 @@ import utils.ScalaSqlSuite import java.time.LocalDate trait SchemaTests extends ScalaSqlSuite { - def description = "Additional tests to ensure schema mapping produces valid SQL" + def description = """ + If your table belongs to a schema other than the default schema of your database, you can specify this in your table definition with + `override def schemaName = "otherschema"` + """ def tests = Tests { test("schema") { @@ -27,10 +30,7 @@ trait SchemaTests extends ScalaSqlSuite { Invoice[Sc](id = 2, total = 213.3, vendor_name = "Samsung"), Invoice[Sc](id = 3, total = 407.2, vendor_name = "Shell") ), - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("insert.columns") { @@ -41,10 +41,7 @@ trait SchemaTests extends ScalaSqlSuite { ), sql = "INSERT INTO otherschema.invoice (total, vendor_name) VALUES (?, ?)", value = 1, - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("insert.values") { @@ -60,10 +57,7 @@ trait SchemaTests extends ScalaSqlSuite { .skipColumns(_.id), sql = "INSERT INTO otherschema.invoice (total, vendor_name) VALUES (?, ?)", value = 1, - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("update") { @@ -81,10 +75,7 @@ trait SchemaTests extends ScalaSqlSuite { WHERE (invoice.id = ?)""", value = 1, - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("delete") { @@ -92,10 +83,7 @@ trait SchemaTests extends ScalaSqlSuite { query = Invoice.delete(_.id === 1), sql = "DELETE FROM otherschema.invoice WHERE (invoice.id = ?)", value = 1, - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("insert into") { @@ -112,10 +100,7 @@ trait SchemaTests extends ScalaSqlSuite { FROM otherschema.invoice invoice0""", value = 4, - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } test("join") { @@ -129,10 +114,7 @@ trait SchemaTests extends ScalaSqlSuite { otherschema.invoice invoice0 JOIN otherschema.invoice invoice1 ON (invoice0.id = invoice1.id)""", value = Seq(2, 3, 4, 5, 6, 7, 8, 9), - docs = """ - If your table belongs to a schema other than the default schema of your database, - you can specify this in your table definition with table.schemaName - """ + docs = "" ) } }