diff --git a/docs/reference.md b/docs/reference.md index 7754b9dd..bc6d28b6 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -139,6 +139,38 @@ dbClient.transaction { db => +### DbApi.updateGetGeneratedKeysSql + +Allows you to fetch the primary keys that were auto-generated for an INSERT +defined as a `SqlStr`. +Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + +```scala +dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) + + assert(generatedIds == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) +} +``` + + + + + + ### DbApi.runRaw `runRawQuery` is similar to `runQuery` but allows you to pass in the SQL strings @@ -183,6 +215,40 @@ dbClient.transaction { db => +### DbApi.updateGetGeneratedKeysRaw + +Allows you to fetch the primary keys that were auto-generated for an INSERT +defined using a raw `java.lang.String` and variables. +Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + +```scala +dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", + LocalDate.parse("2000-01-01"), + "Moo Moo Cow", + LocalDate.parse("2000-01-01") + ) + ) + assert(generatedKeys == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) +} +``` + + + + + + ### DbApi.stream `db.stream` can be run on queries that return `Seq[T]`s, and makes them @@ -5877,6 +5943,262 @@ Purchase.select.mapAggregate((p, ps) => +## GetGeneratedKeys +`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980) +### GetGeneratedKeys.single.values + +`getGeneratedKeys` on an `insert` returns the primary key, even if it was provided + explicitly. + +```scala +Buyer.insert + .values( + Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09")) + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (id, name, date_of_birth) VALUES (?, ?, ?) + ``` + + + +* + ```scala + Seq(17) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.single.columns + +All styles of `INSERT` query support `.getGeneratedKeys`, with this example +using `insert.columns` rather than `insert.values`. You can also retrieve +the generated primary keys using any compatible type, here shown using `Long` +rather than `Int` + +```scala +Buyer.insert + .columns( + _.name := "test buyer", + _.dateOfBirth := LocalDate.parse("2023-09-09"), + _.id := 4 + ) + .getGeneratedKeys[Long] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth, id) VALUES (?, ?, ?) + ``` + + + +* + ```scala + Seq(4L) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.single.partial + +If the primary key was not provided but was auto-generated by the database, +`getGeneratedKeys` returns the generated value + +```scala +Buyer.insert + .columns(_.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09")) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) VALUES (?, ?) + ``` + + + +* + ```scala + Seq(4) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.batch.partial + +`getGeneratedKeys` can return multiple generated primary key values for +a batch insert statement + +```scala +Buyer.insert + .batched(_.name, _.dateOfBirth)( + ("test buyer A", LocalDate.parse("2001-04-07")), + ("test buyer B", LocalDate.parse("2002-05-08")), + ("test buyer C", LocalDate.parse("2003-06-09")) + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) + VALUES (?, ?), (?, ?), (?, ?) + ``` + + + +* + ```scala + Seq(4, 5, 6) + ``` + + + +---- + + + +```scala +Buyer.select +``` + + + + +* + ```scala + Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5,6 comes from auto increment + Buyer[Sc](4, "test buyer A", LocalDate.parse("2001-04-07")), + Buyer[Sc](5, "test buyer B", LocalDate.parse("2002-05-08")), + Buyer[Sc](6, "test buyer C", LocalDate.parse("2003-06-09")) + ) + ``` + + + +### GetGeneratedKeys.select.simple + +`getGeneratedKeys` can return multiple generated primary key values for +an `insert` based on a `select` + +```scala +Buyer.insert + .select( + x => (x.name, x.dateOfBirth), + Buyer.select.map(x => (x.name, x.dateOfBirth)).filter(_._1 <> "Li Haoyi") + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) + SELECT buyer0.name AS res_0, buyer0.date_of_birth AS res_1 + FROM buyer buyer0 + WHERE (buyer0.name <> ?) + ``` + + + +* + ```scala + Seq(4, 5) + ``` + + + +---- + + + +```scala +Buyer.select +``` + + + + +* + ```scala + Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5 comes from auto increment, 6 is filtered out in the select + Buyer[Sc](4, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](5, "叉烧包", LocalDate.parse("1923-11-12")) + ) + ``` + + + ## SubQuery Queries that explicitly use subqueries (e.g. for `JOIN`s) or require subqueries to preserve the Scala semantics of the various operators ### SubQuery.sortTakeJoin diff --git a/mill b/mill index 0c5078a0..4723ab68 100755 --- a/mill +++ b/mill @@ -7,7 +7,7 @@ set -e if [ -z "${DEFAULT_MILL_VERSION}" ] ; then - DEFAULT_MILL_VERSION=0.11.6 + DEFAULT_MILL_VERSION=0.11.7-29-f2e220 fi if [ -z "$MILL_VERSION" ] ; then diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index 019de4e5..c8de1f7c 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -103,6 +103,22 @@ trait DbApi extends AutoCloseable { queryTimeoutSeconds: Int = -1 )(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int + def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] + + def updateGetGeneratedKeysRaw[R]( + sql: String, + variables: Seq[Any] = Nil, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] } object DbApi { @@ -149,7 +165,7 @@ object DbApi { def rollback(): Unit } - // Call hierechy the various DbApi.Impl methods, both public and private: + // Call hierarchy the various DbApi.Impl methods, both public and private: // // run // | @@ -181,7 +197,10 @@ object DbApi { ): R = { val flattened = unpackQueryable(query, qr, config) - if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R] + if (qr.isGetGeneratedKeys(query).nonEmpty) + updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum) + .asInstanceOf[R] + else if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R] else { try { val res = stream(query, fetchSize, queryTimeoutSeconds)( @@ -268,6 +287,27 @@ object DbApi { ) } + def updateGetGeneratedKeysSql[R]( + sql: SqlStr, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] = { + val flattened = SqlStr.flatten(sql) + runRawUpdateGetGeneratedKeys0( + flattened.renderSql(DialectConfig.castParams(dialect)), + flattenParamPuts(flattened), + fetchSize, + queryTimeoutSeconds, + fileName, + lineNum, + qr + ) + } + def runRaw[R]( sql: String, variables: Seq[Any] = Nil, @@ -316,6 +356,25 @@ object DbApi { lineNum ) + def updateGetGeneratedKeysRaw[R]( + sql: String, + variables: Seq[Any] = Nil, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] = runRawUpdateGetGeneratedKeys0( + sql, + anySeqPuts(variables), + fetchSize, + queryTimeoutSeconds, + fileName, + lineNum, + qr + ) + def streamFlattened0[R]( construct: Queryable.ResultSetIterator => R, flattened: SqlStr.Flattened, @@ -400,6 +459,36 @@ object DbApi { } } + def runRawUpdateGetGeneratedKeys0[R]( + sql: String, + variables: Seq[(PreparedStatement, Int) => Unit], + fetchSize: Int, + queryTimeoutSeconds: Int, + fileName: sourcecode.FileName, + lineNum: sourcecode.Line, + qr: Queryable.Row[_, R] + ): IndexedSeq[R] = { + val statement = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + for ((v, i) <- variables.iterator.zipWithIndex) v(statement, i + 1) + configureRunCloseStatement( + statement, + fetchSize, + queryTimeoutSeconds, + sql, + fileName, + lineNum + ) { stmt => + stmt.executeUpdate() + val resultSet = stmt.getGeneratedKeys + val output = Vector.newBuilder[R] + while (resultSet.next()) { + val rowRes = qr.construct(new Queryable.ResultSetIterator(resultSet)) + output.addOne(rowRes) + } + output.result() + } + } + def configureRunCloseStatement[P <: Statement, T]( statement: P, fetchSize: Int, diff --git a/scalasql/core/src/Queryable.scala b/scalasql/core/src/Queryable.scala index c2d7417a..37690747 100644 --- a/scalasql/core/src/Queryable.scala +++ b/scalasql/core/src/Queryable.scala @@ -13,6 +13,12 @@ import java.sql.ResultSet */ trait Queryable[-Q, R] { + /** + * Whether this queryable value is executed using `java.sql.Statement.getGeneratedKeys` + * instead of `.executeQuery`. + */ + def isGetGeneratedKeys(q: Q): Option[Queryable.Row[_, _]] + /** * Whether this queryable value is executed using `java.sql.Statement.executeUpdate` * instead of `.executeQuery`. Note that this needs to be known ahead of time, and @@ -82,6 +88,7 @@ object Queryable { * available, there is no `Queryable.Row[Select[Q]]`, as `Select[Q]` returns multiple rows */ trait Row[Q, R] extends Queryable[Q, R] { + def isGetGeneratedKeys(q: Q): Option[Queryable.Row[_, _]] = None def isExecuteUpdate(q: Q): Boolean = false def isSingleRow(q: Q): Boolean = true def walkLabels(): Seq[List[String]] diff --git a/scalasql/query/src/GetGeneratedKeys.scala b/scalasql/query/src/GetGeneratedKeys.scala new file mode 100644 index 00000000..13f9ef4c --- /dev/null +++ b/scalasql/query/src/GetGeneratedKeys.scala @@ -0,0 +1,33 @@ +package scalasql.query + +import scalasql.core.SqlStr.Renderable +import scalasql.core.{Context, Queryable, SqlStr, WithSqlExpr} + +/** + * Represents an [[Insert]] query that you want to call `JdbcStatement.getGeneratedKeys` + * on to retrieve any auto-generated primary key values from the results + */ +trait GetGeneratedKeys[Q, R] extends Query[Seq[R]] { + def single: Query.Single[R] = new Query.Single(this) +} + +object GetGeneratedKeys { + + class Impl[Q, R](base: Returning.InsertBase[Q])(implicit qr: Queryable.Row[_, R]) + extends GetGeneratedKeys[Q, R] { + + def expr = WithSqlExpr.get(base) + override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] = { + Seq(qr.construct(args)) + } + + protected def queryWalkLabels() = Nil + protected def queryWalkExprs() = Nil + protected override def queryIsSingleRow = false + protected override def queryIsExecuteUpdate = true + + override protected def renderSql(ctx: Context): SqlStr = Renderable.renderSql(base)(ctx) + + override protected def queryGetGeneratedKeys: Option[Queryable.Row[_, _]] = Some(qr) + } +} diff --git a/scalasql/query/src/InsertValues.scala b/scalasql/query/src/InsertValues.scala index d0e1a95d..8b350deb 100644 --- a/scalasql/query/src/InsertValues.scala +++ b/scalasql/query/src/InsertValues.scala @@ -1,9 +1,9 @@ package scalasql.query -import scalasql.core.{Context, DialectTypeMappers, Queryable, SqlStr, WithSqlExpr} +import scalasql.core.{Context, DialectTypeMappers, Expr, Queryable, SqlStr, WithSqlExpr} import scalasql.core.SqlStr.SqlStringSyntax -trait InsertValues[V[_[_]], R] extends Query.ExecuteUpdate[Int] { +trait InsertValues[V[_[_]], R] extends Returning.InsertBase[V[Expr]] with Query.ExecuteUpdate[Int] { def skipColumns(x: (V[Column] => Column[_])*): InsertValues[V, R] } object InsertValues { @@ -15,6 +15,8 @@ object InsertValues { skippedColumns: Seq[Column[_]] ) extends InsertValues[V, R] { + def table = insert.table + protected def expr = WithSqlExpr.get(insert).asInstanceOf[V[Expr]] override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = args.get(dialect.IntType) diff --git a/scalasql/query/src/Query.scala b/scalasql/query/src/Query.scala index 759d297b..43d2ea7d 100644 --- a/scalasql/query/src/Query.scala +++ b/scalasql/query/src/Query.scala @@ -11,6 +11,7 @@ trait Query[R] extends Renderable { protected def queryWalkLabels(): Seq[List[String]] protected def queryWalkExprs(): Seq[Expr[_]] protected def queryIsSingleRow: Boolean + protected def queryGetGeneratedKeys: Option[Queryable.Row[_, _]] = None protected def queryIsExecuteUpdate: Boolean = false protected def queryConstruct(args: Queryable.ResultSetIterator): R @@ -64,6 +65,7 @@ object Query { * overrides and subclassing of [[Query]] classes */ class QueryQueryable[Q <: Query[R], R]() extends scalasql.core.Queryable[Q, R] { + override def isGetGeneratedKeys(q: Q) = q.queryGetGeneratedKeys override def isExecuteUpdate(q: Q) = q.queryIsExecuteUpdate override def walkLabels(q: Q) = q.queryWalkLabels() override def walkExprs(q: Q) = q.queryWalkExprs() diff --git a/scalasql/query/src/Returning.scala b/scalasql/query/src/Returning.scala index 43dfac94..0f0a1d44 100644 --- a/scalasql/query/src/Returning.scala +++ b/scalasql/query/src/Returning.scala @@ -20,7 +20,17 @@ object Returning { def table: TableRef } - trait InsertBase[Q] extends Base[Q] + trait InsertBase[Q] extends Base[Q] { + + /** + * Makes this `INSERT` query call `JdbcStatement.getGeneratedKeys` when it is executed, + * returning a `Seq[R]` where `R` is a Scala type compatible with the auto-generated + * primary key type (typically something like `Int` or `Long`) + */ + def getGeneratedKeys[R](implicit qr: Queryable.Row[_, R]): GetGeneratedKeys[Q, R] = { + new GetGeneratedKeys.Impl(this) + } + } class InsertImpl[Q, R](returnable: InsertBase[_], returning: Q)( implicit qr: Queryable.Row[Q, R] diff --git a/scalasql/test/src/ConcreteTestSuites.scala b/scalasql/test/src/ConcreteTestSuites.scala index 36fe3e82..b4cff752 100644 --- a/scalasql/test/src/ConcreteTestSuites.scala +++ b/scalasql/test/src/ConcreteTestSuites.scala @@ -27,6 +27,7 @@ import query.{ ValuesTests, LateralJoinTests, WindowFunctionTests, + GetGeneratedKeysTests, WithCteTests } import scalasql.dialects.{ @@ -57,6 +58,7 @@ package postgres { object ValuesTests extends ValuesTests with PostgresSuite object LateralJoinTests extends LateralJoinTests with PostgresSuite object WindowFunctionTests extends WindowFunctionTests with PostgresSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with PostgresSuite object SubQueryTests extends SubQueryTests with PostgresSuite object WithCteTests extends WithCteTests with PostgresSuite @@ -100,6 +102,7 @@ package hikari { object ValuesTests extends ValuesTests with HikariSuite object LateralJoinTests extends LateralJoinTests with HikariSuite object WindowFunctionTests extends WindowFunctionTests with HikariSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with HikariSuite object SubQueryTests extends SubQueryTests with HikariSuite object WithCteTests extends WithCteTests with HikariSuite @@ -146,6 +149,7 @@ package mysql { object ValuesTests extends ValuesTests with MySqlSuite object LateralJoinTests extends LateralJoinTests with MySqlSuite object WindowFunctionTests extends WindowFunctionTests with MySqlSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with MySqlSuite object SubQueryTests extends SubQueryTests with MySqlSuite object WithCteTests extends WithCteTests with MySqlSuite @@ -188,6 +192,8 @@ package sqlite { // Sqlite does not support lateral joins // object LateralJoinTests extends LateralJoinTests with SqliteSuite object WindowFunctionTests extends WindowFunctionTests with SqliteSuite + // Sqlite does not support getGeneratedKeys https://github.com/xerial/sqlite-jdbc/issues/980 + // object GetGeneratedKeysTests extends GetGeneratedKeysTests with SqliteSuite object SubQueryTests extends SubQueryTests with SqliteSuite object WithCteTests extends WithCteTests with SqliteSuite @@ -233,6 +239,7 @@ package h2 { // H2 does not support lateral joins // object LateralJoinTests extends LateralJoinTests with H2Suite object WindowFunctionTests extends WindowFunctionTests with H2Suite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with H2Suite object SubQueryTests extends SubQueryTests with H2Suite object WithCteTests extends WithCteTests with H2Suite diff --git a/scalasql/test/src/api/DbApiTests.scala b/scalasql/test/src/api/DbApiTests.scala index 4d99ad98..dd840eac 100644 --- a/scalasql/test/src/api/DbApiTests.scala +++ b/scalasql/test/src/api/DbApiTests.scala @@ -3,7 +3,7 @@ package scalasql.api import geny.Generator import scalasql.core.SqlStr.SqlStringSyntax import scalasql.{Buyer, Sc} -import scalasql.utils.{MySqlSuite, ScalaSqlSuite} +import scalasql.utils.{MySqlSuite, ScalaSqlSuite, SqliteSuite} import sourcecode.Text import utest._ @@ -123,6 +123,37 @@ trait DbApiTests extends ScalaSqlSuite { } } ) + test("updateGetGeneratedKeysSql") - { + if (!this.isInstanceOf[SqliteSuite]) + checker.recorded( + """ + Allows you to fetch the primary keys that were auto-generated for an INSERT + defined as a `SqlStr`. + Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + """, + Text { + + dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) + + assert(generatedIds == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } + } + ) + } test("runRaw") - checker.recorded( """ @@ -159,6 +190,38 @@ trait DbApiTests extends ScalaSqlSuite { } } ) + test("updateGetGeneratedKeysRaw") - { + if (!this.isInstanceOf[SqliteSuite]) + checker.recorded( + """ + Allows you to fetch the primary keys that were auto-generated for an INSERT + defined using a raw `java.lang.String` and variables. + Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + """, + Text { + dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", + LocalDate.parse("2000-01-01"), + "Moo Moo Cow", + LocalDate.parse("2000-01-01") + ) + ) + assert(generatedKeys == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } + } + ) + } test("stream") - checker.recorded( """ diff --git a/scalasql/test/src/query/GetGeneratedKeysTests.scala b/scalasql/test/src/query/GetGeneratedKeysTests.scala new file mode 100644 index 00000000..d30801bb --- /dev/null +++ b/scalasql/test/src/query/GetGeneratedKeysTests.scala @@ -0,0 +1,158 @@ +package scalasql.query + +import scalasql._ +import scalasql.utils.ScalaSqlSuite +import utest._ + +import java.time.LocalDate + +trait GetGeneratedKeysTests extends ScalaSqlSuite { + def description = + "`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980)" + override def utestBeforeEach(path: Seq[String]): Unit = checker.reset() + def tests = Tests { + test("single") { + test("values") - { + checker( + query = Buyer.insert + .values( + Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09")) + ) + .getGeneratedKeys[Int], + sql = "INSERT INTO buyer (id, name, date_of_birth) VALUES (?, ?, ?)", + value = Seq(17), + docs = """ + `getGeneratedKeys` on an `insert` returns the primary key, even if it was provided + explicitly. + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + value = Seq(Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + + test("columns") - { + checker( + query = Buyer.insert + .columns( + _.name := "test buyer", + _.dateOfBirth := LocalDate.parse("2023-09-09"), + _.id := 4 + ) + .getGeneratedKeys[Long], + sql = "INSERT INTO buyer (name, date_of_birth, id) VALUES (?, ?, ?)", + value = Seq(4L), + docs = """ + All styles of `INSERT` query support `.getGeneratedKeys`, with this example + using `insert.columns` rather than `insert.values`. You can also retrieve + the generated primary keys using any compatible type, here shown using `Long` + rather than `Int` + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + value = Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + + test("partial") - { + checker( + query = Buyer.insert + .columns(_.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09")) + .getGeneratedKeys[Int], + sql = "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?)", + value = Seq(4), + docs = """ + If the primary key was not provided but was auto-generated by the database, + `getGeneratedKeys` returns the generated value + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + // id=4 comes from auto increment + value = Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + } + + test("batch") { + + test("partial") - { + checker( + query = Buyer.insert + .batched(_.name, _.dateOfBirth)( + ("test buyer A", LocalDate.parse("2001-04-07")), + ("test buyer B", LocalDate.parse("2002-05-08")), + ("test buyer C", LocalDate.parse("2003-06-09")) + ) + .getGeneratedKeys[Int], + sql = """ + INSERT INTO buyer (name, date_of_birth) + VALUES (?, ?), (?, ?), (?, ?) + """, + value = Seq(4, 5, 6), + docs = """ + `getGeneratedKeys` can return multiple generated primary key values for + a batch insert statement + """ + ) + + checker( + query = Buyer.select, + value = Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5,6 comes from auto increment + Buyer[Sc](4, "test buyer A", LocalDate.parse("2001-04-07")), + Buyer[Sc](5, "test buyer B", LocalDate.parse("2002-05-08")), + Buyer[Sc](6, "test buyer C", LocalDate.parse("2003-06-09")) + ) + ) + } + + } + + test("select") { + + test("simple") { + checker( + query = Buyer.insert + .select( + x => (x.name, x.dateOfBirth), + Buyer.select.map(x => (x.name, x.dateOfBirth)).filter(_._1 <> "Li Haoyi") + ) + .getGeneratedKeys[Int], + sql = """ + INSERT INTO buyer (name, date_of_birth) + SELECT buyer0.name AS res_0, buyer0.date_of_birth AS res_1 + FROM buyer buyer0 + WHERE (buyer0.name <> ?) + """, + value = Seq(4, 5), + docs = """ + `getGeneratedKeys` can return multiple generated primary key values for + an `insert` based on a `select` + """ + ) + + checker( + query = Buyer.select, + value = Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5 comes from auto increment, 6 is filtered out in the select + Buyer[Sc](4, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](5, "叉烧包", LocalDate.parse("1923-11-12")) + ) + ) + } + + } + } +}