diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index be9aa105..67e0dfe3 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -1,5 +1,7 @@ package scalasql.core +import DbClient.notifyListeners + import geny.Generator import java.sql.{PreparedStatement, Statement} @@ -136,6 +138,50 @@ object DbApi { flattened.renderSql(castParams) } + /** + * A listener that can be added to a [[DbApi.Txn]] to be notified of commit and rollback events. + * + * The default implementations of these methods do nothing, but you can override them to + * implement your own behavior. + */ + trait TransactionListener { + + /** + * Called when a new transaction is started. + */ + def begin(): Unit = () + + /** + * Called before the transaction is committed. + * + * If this method throws an exception, the transaction will be rolled back and the exception + * will be propagated. + */ + def beforeCommit(): Unit = () + + /** + * Called after the transaction is committed. + * + * If this method throws an exception, it will be propagated. + */ + def afterCommit(): Unit = () + + /** + * Called before the transaction is rolled back. + * + * If this method throws an exception, the transaction will be rolled back and the exception + * will be propagated to the caller of rollback(). + */ + def beforeRollback(): Unit = () + + /** + * Called after the transaction is rolled back. + * + * If this method throws an exception, it will be propagated to the caller of rollback(). + */ + def afterRollback(): Unit = () + } + /** * An interface to a SQL database *transaction*, allowing you to run queries, * create savepoints, or roll back the transaction. @@ -151,9 +197,11 @@ object DbApi { def savepoint[T](block: DbApi.Savepoint => T): T /** - * Tolls back any active Savepoints and then rolls back this Transaction + * Rolls back any active Savepoints and then rolls back this Transaction */ def rollback(): Unit + + def addTransactionListener(listener: TransactionListener): Unit } /** @@ -187,9 +235,19 @@ object DbApi { connection: java.sql.Connection, config: Config, dialect: DialectConfig, - autoCommit: Boolean, - rollBack0: () => Unit + defaultListeners: Iterable[TransactionListener], + autoCommit: Boolean ) extends DbApi.Txn { + + val listeners = + collection.mutable.ArrayDeque.empty[TransactionListener].addAll(defaultListeners) + + override def addTransactionListener(listener: TransactionListener): Unit = { + if (autoCommit) + throw new IllegalStateException("Cannot add listener to auto-commit transaction") + listeners.append(listener) + } + def run[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( implicit qr: Queryable[Q, R], fileName: sourcecode.FileName, @@ -218,6 +276,7 @@ object DbApi { res.toVector.asInstanceOf[R] } } + } def stream[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)( @@ -229,8 +288,8 @@ object DbApi { streamFlattened0( r => { qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match { - case s: Seq[R] => s.head - case r: R => r + case s: Seq[R] @unchecked => s.head + case r: R @unchecked => r } }, flattened, @@ -545,8 +604,13 @@ object DbApi { } def rollback() = { - savepointStack.clear() - rollBack0() + try { + notifyListeners(listeners)(_.beforeRollback()) + } finally { + savepointStack.clear() + connection.rollback() + notifyListeners(listeners)(_.afterRollback()) + } } private def cast[T](t: Any): T = t.asInstanceOf[T] diff --git a/scalasql/core/src/DbClient.scala b/scalasql/core/src/DbClient.scala index 2524479d..30f7255f 100644 --- a/scalasql/core/src/DbClient.scala +++ b/scalasql/core/src/DbClient.scala @@ -35,9 +35,33 @@ trait DbClient { object DbClient { + /** + * Calls the given function for each listener, collecting any exceptions and throwing them + * as a single exception if any are thrown. + */ + private[core] def notifyListeners(listeners: Iterable[DbApi.TransactionListener])( + f: DbApi.TransactionListener => Unit + ): Unit = { + if (listeners.isEmpty) return + + var exception: Throwable = null + listeners.foreach { listener => + try { + f(listener) + } catch { + case e: Throwable => + if (exception == null) exception = e + else exception.addSuppressed(e) + } + } + if (exception != null) throw exception + } + class Connection( connection: java.sql.Connection, - config: Config = new Config {} + config: Config = new Config {}, + /** Listeners that are added to all transactions created by this connection */ + listeners: Seq[DbApi.TransactionListener] = Seq.empty )(implicit dialect: DialectConfig) extends DbClient { @@ -49,28 +73,57 @@ object DbClient { def transaction[T](block: DbApi.Txn => T): T = { connection.setAutoCommit(false) - val txn = - new DbApi.Impl(connection, config, dialect, false, () => connection.rollback()) - try block(txn) - catch { + val txn = new DbApi.Impl(connection, config, dialect, listeners, autoCommit = false) + var rolledBack = false + try { + notifyListeners(txn.listeners)(_.begin()) + val result = block(txn) + notifyListeners(txn.listeners)(_.beforeCommit()) + result + } catch { case e: Throwable => - connection.rollback() + rolledBack = true + try { + notifyListeners(txn.listeners)(_.beforeRollback()) + } catch { + case e2: Throwable => e.addSuppressed(e2) + } finally { + connection.rollback() + try { + notifyListeners(txn.listeners)(_.afterRollback()) + } catch { + case e3: Throwable => e.addSuppressed(e3) + } + } throw e - } finally connection.setAutoCommit(true) + } finally { + // this commits uncommitted operations, if any + connection.setAutoCommit(true) + if (!rolledBack) { + notifyListeners(txn.listeners)(_.afterCommit()) + } + } } def getAutoCommitClientConnection: DbApi = { connection.setAutoCommit(true) - new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ()) + new DbApi.Impl(connection, config, dialect, listeners, autoCommit = true) } } class DataSource( dataSource: javax.sql.DataSource, - config: Config = new Config {} + config: Config = new Config {}, + /** Listeners that are added to all transactions created through the [[DataSource]] */ + listeners: Seq[DbApi.TransactionListener] = Seq.empty )(implicit dialect: DialectConfig) extends DbClient { + /** Returns a new [[DataSource]] with the given listener added */ + def withTransactionListener(listener: DbApi.TransactionListener): DbClient = { + new DataSource(dataSource, config, listeners :+ listener) + } + def renderSql[Q, R](query: Q, castParams: Boolean = false)( implicit qr: Queryable[Q, R] ): String = { @@ -79,7 +132,7 @@ object DbClient { private def withConnection[T](f: DbClient.Connection => T): T = { val connection = dataSource.getConnection - try f(new DbClient.Connection(connection, config)) + try f(new DbClient.Connection(connection, config, listeners)) finally connection.close() } @@ -88,7 +141,7 @@ object DbClient { def getAutoCommitClientConnection: DbApi = { val connection = dataSource.getConnection connection.setAutoCommit(true) - new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ()) + new DbApi.Impl(connection, config, dialect, defaultListeners = Seq.empty, autoCommit = true) } } } diff --git a/scalasql/test/src/api/TransactionTests.scala b/scalasql/test/src/api/TransactionTests.scala index f95c1c39..a9c9902d 100644 --- a/scalasql/test/src/api/TransactionTests.scala +++ b/scalasql/test/src/api/TransactionTests.scala @@ -2,6 +2,7 @@ package scalasql.api import scalasql.Purchase import scalasql.utils.{ScalaSqlSuite, SqliteSuite} +import scalasql.DbApi import sourcecode.Text import utest._ @@ -12,6 +13,42 @@ trait TransactionTests extends ScalaSqlSuite { override def utestBeforeEach(path: Seq[String]): Unit = checker.reset() class FooException extends Exception + class ListenerException(message: String) extends Exception(message) + + class StubTransactionListener( + throwOnBeforeCommit: Boolean = false, + throwOnAfterCommit: Boolean = false, + throwOnBeforeRollback: Boolean = false, + throwOnAfterRollback: Boolean = false + ) extends DbApi.TransactionListener { + var beginCalled = false + var beforeCommitCalled = false + var afterCommitCalled = false + var beforeRollbackCalled = false + var afterRollbackCalled = false + + override def begin(): Unit = { + beginCalled = true + } + + override def beforeCommit(): Unit = { + beforeCommitCalled = true + if (throwOnBeforeCommit) throw new ListenerException("beforeCommit") + } + override def afterCommit(): Unit = { + afterCommitCalled = true + if (throwOnAfterCommit) throw new ListenerException("afterCommit") + } + override def beforeRollback(): Unit = { + beforeRollbackCalled = true + if (throwOnBeforeRollback) throw new ListenerException("beforeRollback") + } + override def afterRollback(): Unit = { + afterRollbackCalled = true + if (throwOnAfterRollback) throw new ListenerException("afterRollback") + } + } + def tests = Tests { test("simple") { test("commit") - checker.recorded( @@ -537,5 +574,77 @@ trait TransactionTests extends ScalaSqlSuite { } } } + + test("listener") { + test("beforeCommit and afterCommit are called under normal circumstances") { + val listener = new StubTransactionListener() + dbClient.withTransactionListener(listener).transaction { _ => + // do nothing + } + listener.beginCalled ==> true + listener.beforeCommitCalled ==> true + listener.afterCommitCalled ==> true + listener.beforeRollbackCalled ==> false + listener.afterRollbackCalled ==> false + } + + test("if beforeCommit causes an exception, {before,after}Rollback are called") { + val listener = new StubTransactionListener(throwOnBeforeCommit = true) + val e = intercept[ListenerException] { + dbClient.transaction { implicit txn => + txn.addTransactionListener(listener) + } + } + e.getMessage ==> "beforeCommit" + listener.beforeCommitCalled ==> true + listener.afterCommitCalled ==> false + listener.beforeRollbackCalled ==> true + listener.afterRollbackCalled ==> true + } + + test("if afterCommit causes an exception, the exception is propagated") { + val listener = new StubTransactionListener(throwOnAfterCommit = true) + val e = intercept[ListenerException] { + dbClient.transaction { implicit txn => + txn.addTransactionListener(listener) + } + } + e.getMessage ==> "afterCommit" + listener.beforeCommitCalled ==> true + listener.afterCommitCalled ==> true + listener.beforeRollbackCalled ==> false + listener.afterRollbackCalled ==> false + } + + test("if beforeRollback causes an exception, afterRollback is still called") { + val listener = new StubTransactionListener(throwOnBeforeRollback = true) + val e = intercept[FooException] { + dbClient.transaction { implicit txn => + txn.addTransactionListener(listener) + throw new FooException() + } + } + e.getSuppressed.head.getMessage ==> "beforeRollback" + listener.beforeCommitCalled ==> false + listener.afterCommitCalled ==> false + listener.beforeRollbackCalled ==> true + listener.afterRollbackCalled ==> true + } + + test("if afterRollback causes an exception, the exception is propagated") { + val listener = new StubTransactionListener(throwOnAfterRollback = true) + val e = intercept[FooException] { + dbClient.transaction { implicit txn => + txn.addTransactionListener(listener) + throw new FooException() + } + } + e.getSuppressed.head.getMessage ==> "afterRollback" + listener.beforeCommitCalled ==> false + listener.afterCommitCalled ==> false + listener.beforeRollbackCalled ==> true + listener.afterRollbackCalled ==> true + } + } } } diff --git a/scalasql/test/src/utils/TestChecker.scala b/scalasql/test/src/utils/TestChecker.scala index 86f31ec8..665d1f3e 100644 --- a/scalasql/test/src/utils/TestChecker.scala +++ b/scalasql/test/src/utils/TestChecker.scala @@ -6,7 +6,7 @@ import scalasql.query.SubqueryRef import scalasql.{DbClient, Queryable, Expr, UtestFramework} class TestChecker( - val dbClient: DbClient, + val dbClient: DbClient.DataSource, testSchemaFileName: String, testDataFileName: String, suiteName: String,