From 8a8a7df09a9770fc254261606d8c8d3b768cff26 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 12 Apr 2024 11:53:28 +0800 Subject: [PATCH 1/4] first updateGetGeneratedKeysRaw and updateGetGeneratedKeysSql tests pass --- mill | 2 +- scalasql/core/src/DbApi.scala | 80 ++++++++++++++++++++++++ scalasql/core/src/Queryable.scala | 7 +++ scalasql/query/src/Query.scala | 2 + scalasql/test/src/Main.scala | 84 ++++++++++++++++---------- scalasql/test/src/api/DbApiTests.scala | 51 ++++++++++++++++ 6 files changed, 194 insertions(+), 32 deletions(-) diff --git a/mill b/mill index 0c5078a0..4723ab68 100755 --- a/mill +++ b/mill @@ -7,7 +7,7 @@ set -e if [ -z "${DEFAULT_MILL_VERSION}" ] ; then - DEFAULT_MILL_VERSION=0.11.6 + DEFAULT_MILL_VERSION=0.11.7-29-f2e220 fi if [ -z "$MILL_VERSION" ] ; then diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index 019de4e5..9201da1c 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -103,6 +103,21 @@ trait DbApi extends AutoCloseable { queryTimeoutSeconds: Int = -1 )(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int + def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] + + def updateGetGeneratedKeysRaw[R]( + sql: String, + variables: Seq[Any] = Nil, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )(implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] } object DbApi { @@ -268,6 +283,23 @@ object DbApi { ) } + def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] = { + val flattened = SqlStr.flatten(sql) + runRawUpdateGetGeneratedKeys0( + flattened.renderSql(DialectConfig.castParams(dialect)), + flattenParamPuts(flattened), + fetchSize, + queryTimeoutSeconds, + fileName, + lineNum, + qr + ) + } + def runRaw[R]( sql: String, variables: Seq[Any] = Nil, @@ -316,6 +348,24 @@ object DbApi { lineNum ) + def updateGetGeneratedKeysRaw[R]( + sql: String, + variables: Seq[Any] = Nil, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )(implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line + ): IndexedSeq[R] = runRawUpdateGetGeneratedKeys0( + sql, + anySeqPuts(variables), + fetchSize, + queryTimeoutSeconds, + fileName, + lineNum, + qr + ) + def streamFlattened0[R]( construct: Queryable.ResultSetIterator => R, flattened: SqlStr.Flattened, @@ -400,6 +450,36 @@ object DbApi { } } + def runRawUpdateGetGeneratedKeys0[R]( + sql: String, + variables: Seq[(PreparedStatement, Int) => Unit], + fetchSize: Int, + queryTimeoutSeconds: Int, + fileName: sourcecode.FileName, + lineNum: sourcecode.Line, + qr: Queryable.Row[_, R] + ): IndexedSeq[R] = { + val statement = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + for ((v, i) <- variables.iterator.zipWithIndex) v(statement, i + 1) + configureRunCloseStatement( + statement, + fetchSize, + queryTimeoutSeconds, + sql, + fileName, + lineNum + ){ stmt => + stmt.executeUpdate() + val resultSet = stmt.getGeneratedKeys + val output = Vector.newBuilder[R] + while (resultSet.next()) { + val rowRes = qr.construct(new Queryable.ResultSetIterator(resultSet)) + output.addOne(rowRes) + } + output.result() + } + } + def configureRunCloseStatement[P <: Statement, T]( statement: P, fetchSize: Int, diff --git a/scalasql/core/src/Queryable.scala b/scalasql/core/src/Queryable.scala index c2d7417a..d4ad3e09 100644 --- a/scalasql/core/src/Queryable.scala +++ b/scalasql/core/src/Queryable.scala @@ -13,6 +13,12 @@ import java.sql.ResultSet */ trait Queryable[-Q, R] { + /** + * Whether this queryable value is executed using `java.sql.Statement.getGeneratedKeys` + * instead of `.executeQuery`. + */ + def isGetGeneratedKeys(q: Q): Boolean + /** * Whether this queryable value is executed using `java.sql.Statement.executeUpdate` * instead of `.executeQuery`. Note that this needs to be known ahead of time, and @@ -82,6 +88,7 @@ object Queryable { * available, there is no `Queryable.Row[Select[Q]]`, as `Select[Q]` returns multiple rows */ trait Row[Q, R] extends Queryable[Q, R] { + def isGetGeneratedKeys(q: Q): Boolean = false def isExecuteUpdate(q: Q): Boolean = false def isSingleRow(q: Q): Boolean = true def walkLabels(): Seq[List[String]] diff --git a/scalasql/query/src/Query.scala b/scalasql/query/src/Query.scala index 759d297b..a33e63a4 100644 --- a/scalasql/query/src/Query.scala +++ b/scalasql/query/src/Query.scala @@ -11,6 +11,7 @@ trait Query[R] extends Renderable { protected def queryWalkLabels(): Seq[List[String]] protected def queryWalkExprs(): Seq[Expr[_]] protected def queryIsSingleRow: Boolean + protected def queryGetGeneratedKeys: Boolean = false protected def queryIsExecuteUpdate: Boolean = false protected def queryConstruct(args: Queryable.ResultSetIterator): R @@ -64,6 +65,7 @@ object Query { * overrides and subclassing of [[Query]] classes */ class QueryQueryable[Q <: Query[R], R]() extends scalasql.core.Queryable[Q, R] { + override def isGetGeneratedKeys(q: Q) = q.queryGetGeneratedKeys override def isExecuteUpdate(q: Q) = q.queryIsExecuteUpdate override def walkLabels(q: Q) = q.queryWalkLabels() override def walkExprs(q: Q) = q.queryWalkExprs() diff --git a/scalasql/test/src/Main.scala b/scalasql/test/src/Main.scala index 6a6d4833..594014f4 100644 --- a/scalasql/test/src/Main.scala +++ b/scalasql/test/src/Main.scala @@ -1,38 +1,60 @@ package scalasql -import java.sql.DriverManager -import scalasql.H2Dialect._ -object Main { +import java.sql.{Connection, DriverManager, PreparedStatement, ResultSet, SQLException} - case class Example[T[_]](bytes: T[geny.Bytes]) +object Main extends App { + var connection: Connection = null + var preparedStatement: PreparedStatement = null + var resultSet: ResultSet = null - object Example extends Table[Example] + try { + // 1. Connect to the database + connection = DriverManager.getConnection("jdbc:h2:~/test", "username", "password") - // The example H2 database comes from the library `com.h2database:h2:2.2.224` - val conn = DriverManager.getConnection("jdbc:h2:mem:mydb") - - def main(args: Array[String]): Unit = { - conn - .createStatement() - .executeUpdate( - """ - CREATE TABLE data_types ( - my_var_binary VARBINARY(256) - ); + // 2. Create the table if it doesn't exist + val createTableSql = """ - ) - - val prepared = conn.prepareStatement("INSERT INTO data_types (my_var_binary) VALUES (?)") - prepared.setBytes(1, Array[Byte](1, 2, 3, 4)) - prepared.executeUpdate() - - val results = conn - .createStatement() - .executeQuery( - "SELECT data_types0.my_var_binary AS my_var_binary FROM data_types data_types0" - ) - - results.next() - pprint.log(results.getBytes(1)) + |CREATE TABLE IF NOT EXISTS MY_TABLE ( + | ID INT AUTO_INCREMENT PRIMARY KEY, + | COLUMN1 VARCHAR(255), + | COLUMN2 VARCHAR(255) + |) + |""".stripMargin + preparedStatement = connection.prepareStatement(createTableSql) + preparedStatement.execute() + + // 3. Prepare a statement with generated keys option + val insertSql = "INSERT INTO MY_TABLE (COLUMN1, COLUMN2) VALUES (?, ?)" + preparedStatement = connection.prepareStatement(insertSql, java.sql.Statement.RETURN_GENERATED_KEYS) + + // 4. Set values for placeholders + preparedStatement.setString(1, "value1") + preparedStatement.setString(2, "value2") + + // 5. Execute the insert statement + val affectedRows = preparedStatement.executeUpdate() + + if (affectedRows == 0) { + println("Insertion failed, no rows affected.") + } else { + // 6. Retrieve generated keys + resultSet = preparedStatement.getGeneratedKeys() + if (resultSet.next()) { + println("Generated Key: " + resultSet.getLong(1)) + } else { + println("No generated keys were retrieved.") + } + } + } catch { + case e: SQLException => e.printStackTrace() + } finally { + // 7. Close resources + try { + if (resultSet != null) resultSet.close() + if (preparedStatement != null) preparedStatement.close() + if (connection != null) connection.close() + } catch { + case e: SQLException => e.printStackTrace() + } } -} +} \ No newline at end of file diff --git a/scalasql/test/src/api/DbApiTests.scala b/scalasql/test/src/api/DbApiTests.scala index 4d99ad98..56380ddd 100644 --- a/scalasql/test/src/api/DbApiTests.scala +++ b/scalasql/test/src/api/DbApiTests.scala @@ -123,6 +123,32 @@ trait DbApiTests extends ScalaSqlSuite { } } ) + test("updateGetGeneratedKeysSql") - checker.recorded( + """ + ??? + """, + Text { + + dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) + + assert(generatedIds == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + ) + } + } + ) test("runRaw") - checker.recorded( """ @@ -159,6 +185,31 @@ trait DbApiTests extends ScalaSqlSuite { } } ) + test("updateGetGeneratedKeysRaw") - checker.recorded( + """ + ??? + """, + Text { + dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", LocalDate.parse("2000-01-01"), + "Moo Moo Cow", LocalDate.parse("2000-01-01") + ) + ) + assert(generatedKeys == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } + } + ) test("stream") - checker.recorded( """ From 0f21e2fa35b35f1ff5ff66c8f7a2d5778b24ee6b Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 12 Apr 2024 13:10:49 +0800 Subject: [PATCH 2/4] . --- scalasql/core/src/DbApi.scala | 5 +- scalasql/core/src/Queryable.scala | 4 +- scalasql/query/src/GetGeneratedKeys.scala | 32 ++++ scalasql/query/src/InsertValues.scala | 6 +- scalasql/query/src/Query.scala | 2 +- scalasql/query/src/Returning.scala | 11 +- scalasql/test/src/ConcreteTestSuites.scala | 7 + scalasql/test/src/api/DbApiTests.scala | 98 ++++++----- .../src/query/GetGeneratedKeysTests.scala | 157 ++++++++++++++++++ 9 files changed, 269 insertions(+), 53 deletions(-) create mode 100644 scalasql/query/src/GetGeneratedKeys.scala create mode 100644 scalasql/test/src/query/GetGeneratedKeysTests.scala diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index 9201da1c..97a8010d 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -164,7 +164,7 @@ object DbApi { def rollback(): Unit } - // Call hierechy the various DbApi.Impl methods, both public and private: + // Call hierarchy the various DbApi.Impl methods, both public and private: // // run // | @@ -196,7 +196,8 @@ object DbApi { ): R = { val flattened = unpackQueryable(query, qr, config) - if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R] + if (qr.isGetGeneratedKeys(query).nonEmpty) updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum).asInstanceOf[R] + else if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R] else { try { val res = stream(query, fetchSize, queryTimeoutSeconds)( diff --git a/scalasql/core/src/Queryable.scala b/scalasql/core/src/Queryable.scala index d4ad3e09..37690747 100644 --- a/scalasql/core/src/Queryable.scala +++ b/scalasql/core/src/Queryable.scala @@ -17,7 +17,7 @@ trait Queryable[-Q, R] { * Whether this queryable value is executed using `java.sql.Statement.getGeneratedKeys` * instead of `.executeQuery`. */ - def isGetGeneratedKeys(q: Q): Boolean + def isGetGeneratedKeys(q: Q): Option[Queryable.Row[_, _]] /** * Whether this queryable value is executed using `java.sql.Statement.executeUpdate` @@ -88,7 +88,7 @@ object Queryable { * available, there is no `Queryable.Row[Select[Q]]`, as `Select[Q]` returns multiple rows */ trait Row[Q, R] extends Queryable[Q, R] { - def isGetGeneratedKeys(q: Q): Boolean = false + def isGetGeneratedKeys(q: Q): Option[Queryable.Row[_, _]] = None def isExecuteUpdate(q: Q): Boolean = false def isSingleRow(q: Q): Boolean = true def walkLabels(): Seq[List[String]] diff --git a/scalasql/query/src/GetGeneratedKeys.scala b/scalasql/query/src/GetGeneratedKeys.scala new file mode 100644 index 00000000..119fb0d9 --- /dev/null +++ b/scalasql/query/src/GetGeneratedKeys.scala @@ -0,0 +1,32 @@ +package scalasql.query + +import scalasql.core.SqlStr.Renderable +import scalasql.core.{Context, Queryable, SqlStr, WithSqlExpr} + +/** + * Represents an [[Insert]] query that you want to call `JdbcStatement.getGeneratedKeys` + * on to retrieve any auto-generated primary key values from the results + */ +trait GetGeneratedKeys[Q, R] extends Query[Seq[R]] { + def single: Query.Single[R] = new Query.Single(this) +} + +object GetGeneratedKeys { + + class Impl[Q, R](base: Returning.InsertBase[Q])(implicit qr: Queryable.Row[_, R]) extends GetGeneratedKeys[Q, R]{ + + def expr = WithSqlExpr.get(base) + override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] = { + Seq(qr.construct(args)) + } + + protected def queryWalkLabels() = Nil + protected def queryWalkExprs() = Nil + protected override def queryIsSingleRow = false + protected override def queryIsExecuteUpdate = true + + override protected def renderSql(ctx: Context): SqlStr = Renderable.renderSql(base)(ctx) + + override protected def queryGetGeneratedKeys: Option[Queryable.Row[_, _]] = Some(qr) + } +} diff --git a/scalasql/query/src/InsertValues.scala b/scalasql/query/src/InsertValues.scala index d0e1a95d..8b350deb 100644 --- a/scalasql/query/src/InsertValues.scala +++ b/scalasql/query/src/InsertValues.scala @@ -1,9 +1,9 @@ package scalasql.query -import scalasql.core.{Context, DialectTypeMappers, Queryable, SqlStr, WithSqlExpr} +import scalasql.core.{Context, DialectTypeMappers, Expr, Queryable, SqlStr, WithSqlExpr} import scalasql.core.SqlStr.SqlStringSyntax -trait InsertValues[V[_[_]], R] extends Query.ExecuteUpdate[Int] { +trait InsertValues[V[_[_]], R] extends Returning.InsertBase[V[Expr]] with Query.ExecuteUpdate[Int] { def skipColumns(x: (V[Column] => Column[_])*): InsertValues[V, R] } object InsertValues { @@ -15,6 +15,8 @@ object InsertValues { skippedColumns: Seq[Column[_]] ) extends InsertValues[V, R] { + def table = insert.table + protected def expr = WithSqlExpr.get(insert).asInstanceOf[V[Expr]] override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = args.get(dialect.IntType) diff --git a/scalasql/query/src/Query.scala b/scalasql/query/src/Query.scala index a33e63a4..43d2ea7d 100644 --- a/scalasql/query/src/Query.scala +++ b/scalasql/query/src/Query.scala @@ -11,7 +11,7 @@ trait Query[R] extends Renderable { protected def queryWalkLabels(): Seq[List[String]] protected def queryWalkExprs(): Seq[Expr[_]] protected def queryIsSingleRow: Boolean - protected def queryGetGeneratedKeys: Boolean = false + protected def queryGetGeneratedKeys: Option[Queryable.Row[_, _]] = None protected def queryIsExecuteUpdate: Boolean = false protected def queryConstruct(args: Queryable.ResultSetIterator): R diff --git a/scalasql/query/src/Returning.scala b/scalasql/query/src/Returning.scala index 43dfac94..c3f99c14 100644 --- a/scalasql/query/src/Returning.scala +++ b/scalasql/query/src/Returning.scala @@ -20,7 +20,16 @@ object Returning { def table: TableRef } - trait InsertBase[Q] extends Base[Q] + trait InsertBase[Q] extends Base[Q]{ + /** + * Makes this `INSERT` query call `JdbcStatement.getGeneratedKeys` when it is executed, + * returning a `Seq[R]` where `R` is a Scala type compatible with the auto-generated + * primary key type (typically something like `Int` or `Long`) + */ + def getGeneratedKeys[R](implicit qr: Queryable.Row[_, R]): GetGeneratedKeys[Q, R] = { + new GetGeneratedKeys.Impl(this) + } + } class InsertImpl[Q, R](returnable: InsertBase[_], returning: Q)( implicit qr: Queryable.Row[Q, R] diff --git a/scalasql/test/src/ConcreteTestSuites.scala b/scalasql/test/src/ConcreteTestSuites.scala index 36fe3e82..b4cff752 100644 --- a/scalasql/test/src/ConcreteTestSuites.scala +++ b/scalasql/test/src/ConcreteTestSuites.scala @@ -27,6 +27,7 @@ import query.{ ValuesTests, LateralJoinTests, WindowFunctionTests, + GetGeneratedKeysTests, WithCteTests } import scalasql.dialects.{ @@ -57,6 +58,7 @@ package postgres { object ValuesTests extends ValuesTests with PostgresSuite object LateralJoinTests extends LateralJoinTests with PostgresSuite object WindowFunctionTests extends WindowFunctionTests with PostgresSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with PostgresSuite object SubQueryTests extends SubQueryTests with PostgresSuite object WithCteTests extends WithCteTests with PostgresSuite @@ -100,6 +102,7 @@ package hikari { object ValuesTests extends ValuesTests with HikariSuite object LateralJoinTests extends LateralJoinTests with HikariSuite object WindowFunctionTests extends WindowFunctionTests with HikariSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with HikariSuite object SubQueryTests extends SubQueryTests with HikariSuite object WithCteTests extends WithCteTests with HikariSuite @@ -146,6 +149,7 @@ package mysql { object ValuesTests extends ValuesTests with MySqlSuite object LateralJoinTests extends LateralJoinTests with MySqlSuite object WindowFunctionTests extends WindowFunctionTests with MySqlSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with MySqlSuite object SubQueryTests extends SubQueryTests with MySqlSuite object WithCteTests extends WithCteTests with MySqlSuite @@ -188,6 +192,8 @@ package sqlite { // Sqlite does not support lateral joins // object LateralJoinTests extends LateralJoinTests with SqliteSuite object WindowFunctionTests extends WindowFunctionTests with SqliteSuite + // Sqlite does not support getGeneratedKeys https://github.com/xerial/sqlite-jdbc/issues/980 + // object GetGeneratedKeysTests extends GetGeneratedKeysTests with SqliteSuite object SubQueryTests extends SubQueryTests with SqliteSuite object WithCteTests extends WithCteTests with SqliteSuite @@ -233,6 +239,7 @@ package h2 { // H2 does not support lateral joins // object LateralJoinTests extends LateralJoinTests with H2Suite object WindowFunctionTests extends WindowFunctionTests with H2Suite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with H2Suite object SubQueryTests extends SubQueryTests with H2Suite object WithCteTests extends WithCteTests with H2Suite diff --git a/scalasql/test/src/api/DbApiTests.scala b/scalasql/test/src/api/DbApiTests.scala index 56380ddd..a631135f 100644 --- a/scalasql/test/src/api/DbApiTests.scala +++ b/scalasql/test/src/api/DbApiTests.scala @@ -3,7 +3,7 @@ package scalasql.api import geny.Generator import scalasql.core.SqlStr.SqlStringSyntax import scalasql.{Buyer, Sc} -import scalasql.utils.{MySqlSuite, ScalaSqlSuite} +import scalasql.utils.{MySqlSuite, ScalaSqlSuite, SqliteSuite} import sourcecode.Text import utest._ @@ -123,32 +123,36 @@ trait DbApiTests extends ScalaSqlSuite { } } ) - test("updateGetGeneratedKeysSql") - checker.recorded( - """ - ??? - """, - Text { + test("updateGetGeneratedKeysSql") - { + if (!this.isInstanceOf[SqliteSuite]) checker.recorded( + """ + Allows you to fetch the primary keys that were auto-generated for an INSERT + defined as a `SqlStr`. + Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + """, + Text { - dbClient.transaction { db => - val newName = "Moo Moo Cow" - val newDateOfBirth = LocalDate.parse("2000-01-01") - val generatedIds = db - .updateGetGeneratedKeysSql[Int]( - sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" - ) + dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) - assert(generatedIds == Seq(4, 5)) + assert(generatedIds == Seq(4, 5)) - db.run(Buyer.select) ==> List( - Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), - Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), - Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), - Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - ) + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + ) + } } - } - ) + ) + } test("runRaw") - checker.recorded( """ @@ -185,31 +189,35 @@ trait DbApiTests extends ScalaSqlSuite { } } ) - test("updateGetGeneratedKeysRaw") - checker.recorded( - """ - ??? - """, - Text { - dbClient.transaction { db => - val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( - "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", - Seq( - "Moo Moo Cow", LocalDate.parse("2000-01-01"), - "Moo Moo Cow", LocalDate.parse("2000-01-01") + test("updateGetGeneratedKeysRaw") - { + if (!this.isInstanceOf[SqliteSuite]) checker.recorded( + """ + Allows you to fetch the primary keys that were auto-generated for an INSERT + defined using a raw `java.lang.String` and variables. + Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + """, + Text { + dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", LocalDate.parse("2000-01-01"), + "Moo Moo Cow", LocalDate.parse("2000-01-01") + ) ) - ) - assert(generatedKeys == Seq(4, 5)) + assert(generatedKeys == Seq(4, 5)) - db.run(Buyer.select) ==> List( - Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), - Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), - Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), - Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) - ) + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } } - } - ) + ) + } test("stream") - checker.recorded( """ diff --git a/scalasql/test/src/query/GetGeneratedKeysTests.scala b/scalasql/test/src/query/GetGeneratedKeysTests.scala new file mode 100644 index 00000000..30913982 --- /dev/null +++ b/scalasql/test/src/query/GetGeneratedKeysTests.scala @@ -0,0 +1,157 @@ +package scalasql.query + +import scalasql._ +import scalasql.utils.ScalaSqlSuite +import utest._ + +import java.time.LocalDate + +trait GetGeneratedKeysTests extends ScalaSqlSuite { + def description = "`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980)" + override def utestBeforeEach(path: Seq[String]): Unit = checker.reset() + def tests = Tests { + test("single") { + test("values") - { + checker( + query = Buyer.insert + .values( + Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09")) + ) + .getGeneratedKeys[Int], + sql = "INSERT INTO buyer (id, name, date_of_birth) VALUES (?, ?, ?)", + value = Seq(17), + docs = """ + `getGeneratedKeys` on an `insert` returns the primary key, even if it was provided + explicitly. + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + value = Seq(Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + + test("columns") - { + checker( + query = Buyer.insert + .columns( + _.name := "test buyer", + _.dateOfBirth := LocalDate.parse("2023-09-09"), + _.id := 4 + ) + .getGeneratedKeys[Long], + sql = "INSERT INTO buyer (name, date_of_birth, id) VALUES (?, ?, ?)", + value = Seq(4L), + docs = """ + All styles of `INSERT` query support `.getGeneratedKeys`, with this example + using `insert.columns` rather than `insert.values`. You can also retrieve + the generated primary keys using any compatible type, here shown using `Long` + rather than `Int` + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + value = Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + + test("partial") - { + checker( + query = Buyer.insert + .columns(_.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09")) + .getGeneratedKeys[Int], + sql = "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?)", + value = Seq(4), + docs = """ + If the primary key was not provided but was auto-generated by the database, + `getGeneratedKeys` returns the generated value + """ + ) + + checker( + query = Buyer.select.filter(_.name `=` "test buyer"), + // id=4 comes from auto increment + value = Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ) + } + } + + test("batch") { + + test("partial") - { + checker( + query = Buyer.insert + .batched(_.name, _.dateOfBirth)( + ("test buyer A", LocalDate.parse("2001-04-07")), + ("test buyer B", LocalDate.parse("2002-05-08")), + ("test buyer C", LocalDate.parse("2003-06-09")) + ) + .getGeneratedKeys[Int], + sql = """ + INSERT INTO buyer (name, date_of_birth) + VALUES (?, ?), (?, ?), (?, ?) + """, + value = Seq(4, 5, 6), + docs = """ + `getGeneratedKeys` can return multiple generated primary key values for + a batch insert statement + """ + ) + + checker( + query = Buyer.select, + value = Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5,6 comes from auto increment + Buyer[Sc](4, "test buyer A", LocalDate.parse("2001-04-07")), + Buyer[Sc](5, "test buyer B", LocalDate.parse("2002-05-08")), + Buyer[Sc](6, "test buyer C", LocalDate.parse("2003-06-09")) + ) + ) + } + + } + + test("select") { + + test("simple") { + checker( + query = Buyer.insert + .select( + x => (x.name, x.dateOfBirth), + Buyer.select.map(x => (x.name, x.dateOfBirth)).filter(_._1 <> "Li Haoyi") + ) + .getGeneratedKeys[Int], + sql = """ + INSERT INTO buyer (name, date_of_birth) + SELECT buyer0.name AS res_0, buyer0.date_of_birth AS res_1 + FROM buyer buyer0 + WHERE (buyer0.name <> ?) + """, + value = Seq(4, 5), + docs = """ + `getGeneratedKeys` can return multiple generated primary key values for + an `insert` based on a `select` + """ + ) + + checker( + query = Buyer.select, + value = Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5 comes from auto increment, 6 is filtered out in the select + Buyer[Sc](4, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](5, "叉烧包", LocalDate.parse("1923-11-12")) + ) + ) + } + + } + } +} From ae9e9e3c2df0292d77a1baca6bb00e690ffab97a Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 12 Apr 2024 13:11:24 +0800 Subject: [PATCH 3/4] . --- scalasql/test/src/Main.scala | 84 +++++++++++++----------------------- 1 file changed, 31 insertions(+), 53 deletions(-) diff --git a/scalasql/test/src/Main.scala b/scalasql/test/src/Main.scala index 594014f4..6a6d4833 100644 --- a/scalasql/test/src/Main.scala +++ b/scalasql/test/src/Main.scala @@ -1,60 +1,38 @@ package scalasql -import java.sql.{Connection, DriverManager, PreparedStatement, ResultSet, SQLException} +import java.sql.DriverManager +import scalasql.H2Dialect._ +object Main { -object Main extends App { - var connection: Connection = null - var preparedStatement: PreparedStatement = null - var resultSet: ResultSet = null + case class Example[T[_]](bytes: T[geny.Bytes]) - try { - // 1. Connect to the database - connection = DriverManager.getConnection("jdbc:h2:~/test", "username", "password") + object Example extends Table[Example] - // 2. Create the table if it doesn't exist - val createTableSql = + // The example H2 database comes from the library `com.h2database:h2:2.2.224` + val conn = DriverManager.getConnection("jdbc:h2:mem:mydb") + + def main(args: Array[String]): Unit = { + conn + .createStatement() + .executeUpdate( + """ + CREATE TABLE data_types ( + my_var_binary VARBINARY(256) + ); """ - |CREATE TABLE IF NOT EXISTS MY_TABLE ( - | ID INT AUTO_INCREMENT PRIMARY KEY, - | COLUMN1 VARCHAR(255), - | COLUMN2 VARCHAR(255) - |) - |""".stripMargin - preparedStatement = connection.prepareStatement(createTableSql) - preparedStatement.execute() - - // 3. Prepare a statement with generated keys option - val insertSql = "INSERT INTO MY_TABLE (COLUMN1, COLUMN2) VALUES (?, ?)" - preparedStatement = connection.prepareStatement(insertSql, java.sql.Statement.RETURN_GENERATED_KEYS) - - // 4. Set values for placeholders - preparedStatement.setString(1, "value1") - preparedStatement.setString(2, "value2") - - // 5. Execute the insert statement - val affectedRows = preparedStatement.executeUpdate() - - if (affectedRows == 0) { - println("Insertion failed, no rows affected.") - } else { - // 6. Retrieve generated keys - resultSet = preparedStatement.getGeneratedKeys() - if (resultSet.next()) { - println("Generated Key: " + resultSet.getLong(1)) - } else { - println("No generated keys were retrieved.") - } - } - } catch { - case e: SQLException => e.printStackTrace() - } finally { - // 7. Close resources - try { - if (resultSet != null) resultSet.close() - if (preparedStatement != null) preparedStatement.close() - if (connection != null) connection.close() - } catch { - case e: SQLException => e.printStackTrace() - } + ) + + val prepared = conn.prepareStatement("INSERT INTO data_types (my_var_binary) VALUES (?)") + prepared.setBytes(1, Array[Byte](1, 2, 3, 4)) + prepared.executeUpdate() + + val results = conn + .createStatement() + .executeQuery( + "SELECT data_types0.my_var_binary AS my_var_binary FROM data_types data_types0" + ) + + results.next() + pprint.log(results.getBytes(1)) } -} \ No newline at end of file +} From 1b1074b7faa80494136aa3e2084c4fc1ea35ada5 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 12 Apr 2024 13:30:21 +0800 Subject: [PATCH 4/4] update-docs --- docs/reference.md | 322 ++++++++++++++++++ scalasql/core/src/DbApi.scala | 32 +- scalasql/query/src/GetGeneratedKeys.scala | 3 +- scalasql/query/src/Returning.scala | 3 +- scalasql/test/src/api/DbApiTests.scala | 84 ++--- .../src/query/GetGeneratedKeysTests.scala | 3 +- 6 files changed, 392 insertions(+), 55 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index 7754b9dd..bc6d28b6 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -139,6 +139,38 @@ dbClient.transaction { db => +### DbApi.updateGetGeneratedKeysSql + +Allows you to fetch the primary keys that were auto-generated for an INSERT +defined as a `SqlStr`. +Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + +```scala +dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) + + assert(generatedIds == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) +} +``` + + + + + + ### DbApi.runRaw `runRawQuery` is similar to `runQuery` but allows you to pass in the SQL strings @@ -183,6 +215,40 @@ dbClient.transaction { db => +### DbApi.updateGetGeneratedKeysRaw + +Allows you to fetch the primary keys that were auto-generated for an INSERT +defined using a raw `java.lang.String` and variables. +Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 + +```scala +dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", + LocalDate.parse("2000-01-01"), + "Moo Moo Cow", + LocalDate.parse("2000-01-01") + ) + ) + assert(generatedKeys == Seq(4, 5)) + + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) +} +``` + + + + + + ### DbApi.stream `db.stream` can be run on queries that return `Seq[T]`s, and makes them @@ -5877,6 +5943,262 @@ Purchase.select.mapAggregate((p, ps) => +## GetGeneratedKeys +`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980) +### GetGeneratedKeys.single.values + +`getGeneratedKeys` on an `insert` returns the primary key, even if it was provided + explicitly. + +```scala +Buyer.insert + .values( + Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09")) + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (id, name, date_of_birth) VALUES (?, ?, ?) + ``` + + + +* + ```scala + Seq(17) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.single.columns + +All styles of `INSERT` query support `.getGeneratedKeys`, with this example +using `insert.columns` rather than `insert.values`. You can also retrieve +the generated primary keys using any compatible type, here shown using `Long` +rather than `Int` + +```scala +Buyer.insert + .columns( + _.name := "test buyer", + _.dateOfBirth := LocalDate.parse("2023-09-09"), + _.id := 4 + ) + .getGeneratedKeys[Long] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth, id) VALUES (?, ?, ?) + ``` + + + +* + ```scala + Seq(4L) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.single.partial + +If the primary key was not provided but was auto-generated by the database, +`getGeneratedKeys` returns the generated value + +```scala +Buyer.insert + .columns(_.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09")) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) VALUES (?, ?) + ``` + + + +* + ```scala + Seq(4) + ``` + + + +---- + + + +```scala +Buyer.select.filter(_.name `=` "test buyer") +``` + + + + +* + ```scala + Seq(Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))) + ``` + + + +### GetGeneratedKeys.batch.partial + +`getGeneratedKeys` can return multiple generated primary key values for +a batch insert statement + +```scala +Buyer.insert + .batched(_.name, _.dateOfBirth)( + ("test buyer A", LocalDate.parse("2001-04-07")), + ("test buyer B", LocalDate.parse("2002-05-08")), + ("test buyer C", LocalDate.parse("2003-06-09")) + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) + VALUES (?, ?), (?, ?), (?, ?) + ``` + + + +* + ```scala + Seq(4, 5, 6) + ``` + + + +---- + + + +```scala +Buyer.select +``` + + + + +* + ```scala + Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5,6 comes from auto increment + Buyer[Sc](4, "test buyer A", LocalDate.parse("2001-04-07")), + Buyer[Sc](5, "test buyer B", LocalDate.parse("2002-05-08")), + Buyer[Sc](6, "test buyer C", LocalDate.parse("2003-06-09")) + ) + ``` + + + +### GetGeneratedKeys.select.simple + +`getGeneratedKeys` can return multiple generated primary key values for +an `insert` based on a `select` + +```scala +Buyer.insert + .select( + x => (x.name, x.dateOfBirth), + Buyer.select.map(x => (x.name, x.dateOfBirth)).filter(_._1 <> "Li Haoyi") + ) + .getGeneratedKeys[Int] +``` + + +* + ```sql + INSERT INTO buyer (name, date_of_birth) + SELECT buyer0.name AS res_0, buyer0.date_of_birth AS res_1 + FROM buyer buyer0 + WHERE (buyer0.name <> ?) + ``` + + + +* + ```scala + Seq(4, 5) + ``` + + + +---- + + + +```scala +Buyer.select +``` + + + + +* + ```scala + Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + // id=4,5 comes from auto increment, 6 is filtered out in the select + Buyer[Sc](4, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](5, "叉烧包", LocalDate.parse("1923-11-12")) + ) + ``` + + + ## SubQuery Queries that explicitly use subqueries (e.g. for `JOIN`s) or require subqueries to preserve the Scala semantics of the various operators ### SubQuery.sortTakeJoin diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index 97a8010d..c8de1f7c 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -104,9 +104,9 @@ trait DbApi extends AutoCloseable { )(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( - implicit qr: Queryable.Row[_, R], - fileName: sourcecode.FileName, - lineNum: sourcecode.Line + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line ): IndexedSeq[R] def updateGetGeneratedKeysRaw[R]( @@ -114,9 +114,10 @@ trait DbApi extends AutoCloseable { variables: Seq[Any] = Nil, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1 - )(implicit qr: Queryable.Row[_, R], - fileName: sourcecode.FileName, - lineNum: sourcecode.Line + )( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line ): IndexedSeq[R] } @@ -196,7 +197,9 @@ object DbApi { ): R = { val flattened = unpackQueryable(query, qr, config) - if (qr.isGetGeneratedKeys(query).nonEmpty) updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum).asInstanceOf[R] + if (qr.isGetGeneratedKeys(query).nonEmpty) + updateGetGeneratedKeysSql(flattened)(qr.isGetGeneratedKeys(query).get, fileName, lineNum) + .asInstanceOf[R] else if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R] else { try { @@ -284,7 +287,11 @@ object DbApi { ) } - def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( + def updateGetGeneratedKeysSql[R]( + sql: SqlStr, + fetchSize: Int = -1, + queryTimeoutSeconds: Int = -1 + )( implicit qr: Queryable.Row[_, R], fileName: sourcecode.FileName, lineNum: sourcecode.Line @@ -354,9 +361,10 @@ object DbApi { variables: Seq[Any] = Nil, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1 - )(implicit qr: Queryable.Row[_, R], - fileName: sourcecode.FileName, - lineNum: sourcecode.Line + )( + implicit qr: Queryable.Row[_, R], + fileName: sourcecode.FileName, + lineNum: sourcecode.Line ): IndexedSeq[R] = runRawUpdateGetGeneratedKeys0( sql, anySeqPuts(variables), @@ -469,7 +477,7 @@ object DbApi { sql, fileName, lineNum - ){ stmt => + ) { stmt => stmt.executeUpdate() val resultSet = stmt.getGeneratedKeys val output = Vector.newBuilder[R] diff --git a/scalasql/query/src/GetGeneratedKeys.scala b/scalasql/query/src/GetGeneratedKeys.scala index 119fb0d9..13f9ef4c 100644 --- a/scalasql/query/src/GetGeneratedKeys.scala +++ b/scalasql/query/src/GetGeneratedKeys.scala @@ -13,7 +13,8 @@ trait GetGeneratedKeys[Q, R] extends Query[Seq[R]] { object GetGeneratedKeys { - class Impl[Q, R](base: Returning.InsertBase[Q])(implicit qr: Queryable.Row[_, R]) extends GetGeneratedKeys[Q, R]{ + class Impl[Q, R](base: Returning.InsertBase[Q])(implicit qr: Queryable.Row[_, R]) + extends GetGeneratedKeys[Q, R] { def expr = WithSqlExpr.get(base) override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] = { diff --git a/scalasql/query/src/Returning.scala b/scalasql/query/src/Returning.scala index c3f99c14..0f0a1d44 100644 --- a/scalasql/query/src/Returning.scala +++ b/scalasql/query/src/Returning.scala @@ -20,7 +20,8 @@ object Returning { def table: TableRef } - trait InsertBase[Q] extends Base[Q]{ + trait InsertBase[Q] extends Base[Q] { + /** * Makes this `INSERT` query call `JdbcStatement.getGeneratedKeys` when it is executed, * returning a `Seq[R]` where `R` is a Scala type compatible with the auto-generated diff --git a/scalasql/test/src/api/DbApiTests.scala b/scalasql/test/src/api/DbApiTests.scala index a631135f..dd840eac 100644 --- a/scalasql/test/src/api/DbApiTests.scala +++ b/scalasql/test/src/api/DbApiTests.scala @@ -124,34 +124,35 @@ trait DbApiTests extends ScalaSqlSuite { } ) test("updateGetGeneratedKeysSql") - { - if (!this.isInstanceOf[SqliteSuite]) checker.recorded( - """ + if (!this.isInstanceOf[SqliteSuite]) + checker.recorded( + """ Allows you to fetch the primary keys that were auto-generated for an INSERT defined as a `SqlStr`. Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 """, - Text { + Text { - dbClient.transaction { db => - val newName = "Moo Moo Cow" - val newDateOfBirth = LocalDate.parse("2000-01-01") - val generatedIds = db - .updateGetGeneratedKeysSql[Int]( - sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" - ) + dbClient.transaction { db => + val newName = "Moo Moo Cow" + val newDateOfBirth = LocalDate.parse("2000-01-01") + val generatedIds = db + .updateGetGeneratedKeysSql[Int]( + sql"INSERT INTO buyer (name, date_of_birth) VALUES ($newName, $newDateOfBirth), ($newName, $newDateOfBirth)" + ) - assert(generatedIds == Seq(4, 5)) + assert(generatedIds == Seq(4, 5)) - db.run(Buyer.select) ==> List( - Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), - Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), - Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), - Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - ) + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } } - } - ) + ) } test("runRaw") - checker.recorded( @@ -190,33 +191,36 @@ trait DbApiTests extends ScalaSqlSuite { } ) test("updateGetGeneratedKeysRaw") - { - if (!this.isInstanceOf[SqliteSuite]) checker.recorded( - """ + if (!this.isInstanceOf[SqliteSuite]) + checker.recorded( + """ Allows you to fetch the primary keys that were auto-generated for an INSERT defined using a raw `java.lang.String` and variables. Note: not supported by Sqlite https://github.com/xerial/sqlite-jdbc/issues/980 """, - Text { - dbClient.transaction { db => - val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( - "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", - Seq( - "Moo Moo Cow", LocalDate.parse("2000-01-01"), - "Moo Moo Cow", LocalDate.parse("2000-01-01") + Text { + dbClient.transaction { db => + val generatedKeys = db.updateGetGeneratedKeysRaw[Int]( + "INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?)", + Seq( + "Moo Moo Cow", + LocalDate.parse("2000-01-01"), + "Moo Moo Cow", + LocalDate.parse("2000-01-01") + ) ) - ) - assert(generatedKeys == Seq(4, 5)) + assert(generatedKeys == Seq(4, 5)) - db.run(Buyer.select) ==> List( - Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), - Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), - Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), - Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), - Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) - ) + db.run(Buyer.select) ==> List( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + Buyer[Sc](3, "Li Haoyi", LocalDate.parse("1965-08-09")), + Buyer[Sc](4, "Moo Moo Cow", LocalDate.parse("2000-01-01")), + Buyer[Sc](5, "Moo Moo Cow", LocalDate.parse("2000-01-01")) + ) + } } - } - ) + ) } test("stream") - checker.recorded( diff --git a/scalasql/test/src/query/GetGeneratedKeysTests.scala b/scalasql/test/src/query/GetGeneratedKeysTests.scala index 30913982..d30801bb 100644 --- a/scalasql/test/src/query/GetGeneratedKeysTests.scala +++ b/scalasql/test/src/query/GetGeneratedKeysTests.scala @@ -7,7 +7,8 @@ import utest._ import java.time.LocalDate trait GetGeneratedKeysTests extends ScalaSqlSuite { - def description = "`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980)" + def description = + "`INSERT` operations with `.getGeneratedKeys`. Not supported by Sqlite (see https://github.com/xerial/sqlite-jdbc/issues/980)" override def utestBeforeEach(path: Seq[String]): Unit = checker.reset() def tests = Tests { test("single") {