Skip to content

Add TransactionListener interface #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions scalasql/core/src/DbApi.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scalasql.core

import DbClient.notifyListeners

import geny.Generator

import java.sql.{PreparedStatement, Statement}
Expand Down Expand Up @@ -136,6 +138,50 @@ object DbApi {
flattened.renderSql(castParams)
}

/**
* A listener that can be added to a [[DbApi.Txn]] to be notified of commit and rollback events.
*
* The default implementations of these methods do nothing, but you can override them to
* implement your own behavior.
*/
trait TransactionListener {

/**
* Called when a new transaction is started.
*/
def begin(): Unit = ()

/**
* Called before the transaction is committed.
*
* If this method throws an exception, the transaction will be rolled back and the exception
* will be propagated.
*/
def beforeCommit(): Unit = ()

/**
* Called after the transaction is committed.
*
* If this method throws an exception, it will be propagated.
*/
def afterCommit(): Unit = ()

/**
* Called before the transaction is rolled back.
*
* If this method throws an exception, the transaction will be rolled back and the exception
* will be propagated to the caller of rollback().
*/
def beforeRollback(): Unit = ()

/**
* Called after the transaction is rolled back.
*
* If this method throws an exception, it will be propagated to the caller of rollback().
*/
def afterRollback(): Unit = ()
}

/**
* An interface to a SQL database *transaction*, allowing you to run queries,
* create savepoints, or roll back the transaction.
Expand All @@ -151,9 +197,11 @@ object DbApi {
def savepoint[T](block: DbApi.Savepoint => T): T

/**
* Tolls back any active Savepoints and then rolls back this Transaction
* Rolls back any active Savepoints and then rolls back this Transaction
*/
def rollback(): Unit

def addTransactionListener(listener: TransactionListener): Unit
}

/**
Expand Down Expand Up @@ -187,9 +235,19 @@ object DbApi {
connection: java.sql.Connection,
config: Config,
dialect: DialectConfig,
autoCommit: Boolean,
rollBack0: () => Unit
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the rollBack0 argument since it was only used to handle the difference between autoCommit true/false behavior.

With the new interface, client code can now hook into the rollback process if needed.

defaultListeners: Iterable[TransactionListener],
autoCommit: Boolean
) extends DbApi.Txn {

val listeners =
collection.mutable.ArrayDeque.empty[TransactionListener].addAll(defaultListeners)

override def addTransactionListener(listener: TransactionListener): Unit = {
if (autoCommit)
throw new IllegalStateException("Cannot add listener to auto-commit transaction")
listeners.append(listener)
}

def run[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit qr: Queryable[Q, R],
fileName: sourcecode.FileName,
Expand Down Expand Up @@ -218,6 +276,7 @@ object DbApi {
res.toVector.asInstanceOf[R]
}
}

}

def stream[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
Expand All @@ -229,8 +288,8 @@ object DbApi {
streamFlattened0(
r => {
qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match {
case s: Seq[R] => s.head
case r: R => r
case s: Seq[R] @unchecked => s.head
case r: R @unchecked => r
}
},
flattened,
Expand Down Expand Up @@ -545,8 +604,13 @@ object DbApi {
}

def rollback() = {
savepointStack.clear()
rollBack0()
try {
notifyListeners(listeners)(_.beforeRollback())
} finally {
savepointStack.clear()
connection.rollback()
notifyListeners(listeners)(_.afterRollback())
}
}

private def cast[T](t: Any): T = t.asInstanceOf[T]
Expand Down
75 changes: 64 additions & 11 deletions scalasql/core/src/DbClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,33 @@ trait DbClient {

object DbClient {

/**
* Calls the given function for each listener, collecting any exceptions and throwing them
* as a single exception if any are thrown.
*/
private[core] def notifyListeners(listeners: Iterable[DbApi.TransactionListener])(
f: DbApi.TransactionListener => Unit
): Unit = {
if (listeners.isEmpty) return

var exception: Throwable = null
listeners.foreach { listener =>
try {
f(listener)
} catch {
case e: Throwable =>
if (exception == null) exception = e
else exception.addSuppressed(e)
}
}
if (exception != null) throw exception
}

class Connection(
connection: java.sql.Connection,
config: Config = new Config {}
config: Config = new Config {},
/** Listeners that are added to all transactions created by this connection */
listeners: Seq[DbApi.TransactionListener] = Seq.empty
)(implicit dialect: DialectConfig)
extends DbClient {

Expand All @@ -49,28 +73,57 @@ object DbClient {

def transaction[T](block: DbApi.Txn => T): T = {
connection.setAutoCommit(false)
val txn =
new DbApi.Impl(connection, config, dialect, false, () => connection.rollback())
try block(txn)
catch {
val txn = new DbApi.Impl(connection, config, dialect, listeners, autoCommit = false)
var rolledBack = false
try {
notifyListeners(txn.listeners)(_.begin())
val result = block(txn)
notifyListeners(txn.listeners)(_.beforeCommit())
result
} catch {
case e: Throwable =>
connection.rollback()
rolledBack = true
try {
notifyListeners(txn.listeners)(_.beforeRollback())
} catch {
case e2: Throwable => e.addSuppressed(e2)
} finally {
connection.rollback()
try {
notifyListeners(txn.listeners)(_.afterRollback())
} catch {
case e3: Throwable => e.addSuppressed(e3)
}
}
throw e
} finally connection.setAutoCommit(true)
} finally {
// this commits uncommitted operations, if any
connection.setAutoCommit(true)
if (!rolledBack) {
notifyListeners(txn.listeners)(_.afterCommit())
}
}
}

def getAutoCommitClientConnection: DbApi = {
connection.setAutoCommit(true)
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
new DbApi.Impl(connection, config, dialect, listeners, autoCommit = true)
}
}

class DataSource(
dataSource: javax.sql.DataSource,
config: Config = new Config {}
config: Config = new Config {},
/** Listeners that are added to all transactions created through the [[DataSource]] */
listeners: Seq[DbApi.TransactionListener] = Seq.empty
)(implicit dialect: DialectConfig)
extends DbClient {

/** Returns a new [[DataSource]] with the given listener added */
def withTransactionListener(listener: DbApi.TransactionListener): DbClient = {
new DataSource(dataSource, config, listeners :+ listener)
}

def renderSql[Q, R](query: Q, castParams: Boolean = false)(
implicit qr: Queryable[Q, R]
): String = {
Expand All @@ -79,7 +132,7 @@ object DbClient {

private def withConnection[T](f: DbClient.Connection => T): T = {
val connection = dataSource.getConnection
try f(new DbClient.Connection(connection, config))
try f(new DbClient.Connection(connection, config, listeners))
finally connection.close()
}

Expand All @@ -88,7 +141,7 @@ object DbClient {
def getAutoCommitClientConnection: DbApi = {
val connection = dataSource.getConnection
connection.setAutoCommit(true)
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
new DbApi.Impl(connection, config, dialect, defaultListeners = Seq.empty, autoCommit = true)
}
}
}
109 changes: 109 additions & 0 deletions scalasql/test/src/api/TransactionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scalasql.api

import scalasql.Purchase
import scalasql.utils.{ScalaSqlSuite, SqliteSuite}
import scalasql.DbApi
import sourcecode.Text
import utest._

Expand All @@ -12,6 +13,42 @@ trait TransactionTests extends ScalaSqlSuite {
override def utestBeforeEach(path: Seq[String]): Unit = checker.reset()
class FooException extends Exception

class ListenerException(message: String) extends Exception(message)

class StubTransactionListener(
throwOnBeforeCommit: Boolean = false,
throwOnAfterCommit: Boolean = false,
throwOnBeforeRollback: Boolean = false,
throwOnAfterRollback: Boolean = false
) extends DbApi.TransactionListener {
var beginCalled = false
var beforeCommitCalled = false
var afterCommitCalled = false
var beforeRollbackCalled = false
var afterRollbackCalled = false

override def begin(): Unit = {
beginCalled = true
}

override def beforeCommit(): Unit = {
beforeCommitCalled = true
if (throwOnBeforeCommit) throw new ListenerException("beforeCommit")
}
override def afterCommit(): Unit = {
afterCommitCalled = true
if (throwOnAfterCommit) throw new ListenerException("afterCommit")
}
override def beforeRollback(): Unit = {
beforeRollbackCalled = true
if (throwOnBeforeRollback) throw new ListenerException("beforeRollback")
}
override def afterRollback(): Unit = {
afterRollbackCalled = true
if (throwOnAfterRollback) throw new ListenerException("afterRollback")
}
}

def tests = Tests {
test("simple") {
test("commit") - checker.recorded(
Expand Down Expand Up @@ -537,5 +574,77 @@ trait TransactionTests extends ScalaSqlSuite {
}
}
}

test("listener") {
test("beforeCommit and afterCommit are called under normal circumstances") {
val listener = new StubTransactionListener()
dbClient.withTransactionListener(listener).transaction { _ =>
// do nothing
}
listener.beginCalled ==> true
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> true
listener.beforeRollbackCalled ==> false
listener.afterRollbackCalled ==> false
}

test("if beforeCommit causes an exception, {before,after}Rollback are called") {
val listener = new StubTransactionListener(throwOnBeforeCommit = true)
val e = intercept[ListenerException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
}
}
e.getMessage ==> "beforeCommit"
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}

test("if afterCommit causes an exception, the exception is propagated") {
val listener = new StubTransactionListener(throwOnAfterCommit = true)
val e = intercept[ListenerException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
}
}
e.getMessage ==> "afterCommit"
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> true
listener.beforeRollbackCalled ==> false
listener.afterRollbackCalled ==> false
}

test("if beforeRollback causes an exception, afterRollback is still called") {
val listener = new StubTransactionListener(throwOnBeforeRollback = true)
val e = intercept[FooException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
throw new FooException()
}
}
e.getSuppressed.head.getMessage ==> "beforeRollback"
listener.beforeCommitCalled ==> false
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}

test("if afterRollback causes an exception, the exception is propagated") {
val listener = new StubTransactionListener(throwOnAfterRollback = true)
val e = intercept[FooException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
throw new FooException()
}
}
e.getSuppressed.head.getMessage ==> "afterRollback"
listener.beforeCommitCalled ==> false
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}
}
}
}
2 changes: 1 addition & 1 deletion scalasql/test/src/utils/TestChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scalasql.query.SubqueryRef
import scalasql.{DbClient, Queryable, Expr, UtestFramework}

class TestChecker(
val dbClient: DbClient,
val dbClient: DbClient.DataSource,
testSchemaFileName: String,
testDataFileName: String,
suiteName: String,
Expand Down