Skip to content

Add SELECT FOR UPDATE variants for Postgres and MySql #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -10243,7 +10243,7 @@ db.run(OptDataTypes.select) ==> Seq(rowSome, rowNone)
Operations specific to working with Postgres Databases
### PostgresDialect.distinctOn

ScalaSql's Postgres dialect provides teh `.distinctOn` operator, which translates
ScalaSql's Postgres dialect provides the `.distinctOn` operator, which translates
into a SQL `DISTINCT ON` clause

```scala
Expand Down Expand Up @@ -10276,6 +10276,38 @@ Purchase.select.distinctOn(_.shippingInfoId).sortBy(_.shippingInfoId).desc



### PostgresDialect.forUpdate

ScalaSql's Postgres dialect provides the `.forUpdate` operator, which translates
into a SQL `SELECT ... FOR UPDATE` clause

```scala
Invoice.select.filter(_.id === 1).forUpdate
```


*
```sql
SELECT
invoice0.id AS id,
invoice0.total AS total,
invoice0.vendor_name AS vendor_name
FROM otherschema.invoice invoice0
WHERE (invoice0.id = ?)
FOR UPDATE
```



*
```scala
Seq(
Invoice[Sc](1, 150.4, "Siemens")
)
```



### PostgresDialect.ltrim2


Expand Down Expand Up @@ -10480,6 +10512,38 @@ db.random

## MySqlDialect
Operations specific to working with MySql Databases
### MySqlDialect.forUpdate

ScalaSql's MySql dialect provides the `.forUpdate` operator, which translates
into a SQL `SELECT ... FOR UPDATE` clause

```scala
Buyer.select.filter(_.id === 1).forUpdate
```


*
```sql
SELECT
buyer0.id AS id,
buyer0.name AS name,
buyer0.date_of_birth AS date_of_birth
FROM buyer buyer0
WHERE (buyer0.id = ?)
FOR UPDATE
```



*
```scala
Seq(
Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03"))
)
```



### MySqlDialect.reverse


Expand Down
19 changes: 16 additions & 3 deletions scalasql/query/src/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,24 @@ trait Select[Q, R]
protected def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
where: Seq[Expr[?]],
groupBy0: Option[GroupBy]
)(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] =
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def qr: Queryable.Row[Q, R]

/**
* Causes this [[Select]] to ignore duplicate rows, translates into SQL `SELECT DISTINCT`
*/
def distinct: Select[Q, R] = selectWithExprPrefix(true, _ => sql"DISTINCT")

protected def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R]
protected def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R]

protected def subqueryRef(implicit qr: Queryable.Row[Q, R]) = new SubqueryRef(this)

Expand Down Expand Up @@ -227,7 +230,7 @@ trait Select[Q, R]
* in this [[Select]]
*/
def subquery: SimpleSelect[Q, R] = {
newSimpleSelect(expr, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect)
newSimpleSelect(expr, None, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect)
}

/**
Expand Down Expand Up @@ -278,19 +281,23 @@ object Select {
lhs: Select[Q, R],
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
where: Seq[Expr[?]],
groupBy0: Option[GroupBy]
)(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] =
lhs.newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
lhs.newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def toSimpleFrom[Q, R](s: Select[Q, R]) = s.selectToSimpleSelect()

def withExprPrefix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) =
s.selectWithExprPrefix(preserveAll, str)

def withExprSuffix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) =
s.selectWithExprSuffix(preserveAll, str)

implicit class ExprSelectOps[T](s: Select[Expr[T], T]) {
def sorted(implicit tm: TypeMapper[T]): Select[Expr[T], T] = s.sortBy(identity)
}
Expand All @@ -303,6 +310,12 @@ object Select {
): Select[Q, R] =
selectToSimpleSelect().selectWithExprPrefix(preserveAll, s)

override protected def selectWithExprSuffix(
preserveAll: Boolean,
s: Context => SqlStr
): Select[Q, R] =
selectToSimpleSelect().selectWithExprSuffix(preserveAll, s)

override def map[Q2, R2](f: Q => Q2)(implicit qr: Queryable.Row[Q2, R2]): Select[Q2, R2] =
selectToSimpleSelect().map(f)

Expand Down
12 changes: 10 additions & 2 deletions scalasql/query/src/SimpleSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scalasql.renderer.JoinsToSql
class SimpleSelect[Q, R](
val expr: Q,
val exprPrefix: Option[Context => SqlStr],
val exprSuffix: Option[Context => SqlStr],
val preserveAll: Boolean,
val from: Seq[Context.From],
val joins: Seq[Join],
Expand All @@ -33,17 +34,21 @@ class SimpleSelect[Q, R](
protected def copy[Q, R](
expr: Q = this.expr,
exprPrefix: Option[Context => SqlStr] = this.exprPrefix,
exprSuffix: Option[Context => SqlStr] = this.exprSuffix,
preserveAll: Boolean = this.preserveAll,
from: Seq[Context.From] = this.from,
joins: Seq[Join] = this.joins,
where: Seq[Expr[?]] = this.where,
groupBy0: Option[GroupBy] = this.groupBy0
)(implicit qr: Queryable.Row[Q, R]) =
newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] =
this.copy(exprPrefix = Some(s), preserveAll = preserveAll)

def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] =
this.copy(exprSuffix = Some(s), preserveAll = preserveAll)

def aggregateExpr[V: TypeMapper](
f: Q => Context => SqlStr
)(implicit qr2: Queryable.Row[Expr[V], V]): Expr[V] = {
Expand Down Expand Up @@ -111,6 +116,7 @@ class SimpleSelect[Q, R](
copy(
expr = newExpr,
exprPrefix = exprPrefix,
exprSuffix = exprSuffix,
joins = joins ++ newJoins,
where = where ++ newWheres
)
Expand Down Expand Up @@ -178,6 +184,7 @@ class SimpleSelect[Q, R](
copy(
expr = newExpr,
exprPrefix = exprPrefix,
exprSuffix = exprSuffix,
from = Seq(this.subqueryRef),
joins = Nil,
where = Nil,
Expand Down Expand Up @@ -287,11 +294,12 @@ object SimpleSelect {
)

lazy val exprPrefix = SqlStr.opt(query.exprPrefix) { p => p(context) + sql" " }
lazy val exprSuffix = SqlStr.opt(query.exprSuffix) { p => p(context) }

val tables = SqlStr
.join(query.from.map(renderedFroms(_)), SqlStr.commaSep)

sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt
sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt + exprSuffix
}

}
Expand Down
1 change: 1 addition & 0 deletions scalasql/query/src/WithCte.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object WithCte {
lhs,
expr = WithSqlExpr.get(lhs),
exprPrefix = None,
exprSuffix = None,
preserveAll = false,
from = Seq(lhsSubQueryRef),
joins = Nil,
Expand Down
6 changes: 5 additions & 1 deletion scalasql/src/dialects/H2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ object H2Dialect extends H2Dialect {
new SimpleSelect(
Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]],
None,
None,
false,
Seq(ref),
Nil,
Expand Down Expand Up @@ -132,6 +133,7 @@ object H2Dialect extends H2Dialect {
override def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -141,13 +143,14 @@ object H2Dialect extends H2Dialect {
implicit qr: Queryable.Row[Q, R],
dialect: scalasql.core.DialectTypeMappers
): scalasql.query.SimpleSelect[Q, R] = {
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)
}
}

class SimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -157,6 +160,7 @@ object H2Dialect extends H2Dialect {
extends scalasql.query.SimpleSelect(
expr,
exprPrefix,
exprSuffix,
preserveAll,
from,
joins,
Expand Down
39 changes: 38 additions & 1 deletion scalasql/src/dialects/MySqlDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import java.sql.PreparedStatement
import java.time.{Instant, LocalDateTime}
import java.util.UUID
import scala.reflect.ClassTag
import scalasql.query.Select

trait MySqlDialect extends Dialect {
protected def dialectCastParams = false
Expand Down Expand Up @@ -116,6 +117,38 @@ trait MySqlDialect extends Dialect {
implicit def ExprAggOpsConv[T](v: Aggregatable[Expr[T]]): operations.ExprAggOps[T] =
new MySqlDialect.ExprAggOps(v)

implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) {

/**
* SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from
* modifying or locking the selected rows, which is for managing concurrent transactions
* and ensuring data consistency in multi-step operations.
*/
def forUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE")

/**
* SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions
* to read but not modify the locked rows
*/
def forShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR SHARE")

/**
* SELECT ... FOR UPDATE NOWAIT: Immediately returns an error if the selected rows are
* already locked, instead of waiting
*/
def forUpdateNoWait: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE NOWAIT")

/**
* SELECT ... FOR UPDATE SKIP LOCKED: Skips any rows that are already locked by other
* transactions, instead of waiting
*/
def forUpdateSkipLocked: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE SKIP LOCKED")
}

override implicit def DbApiOpsConv(db: => DbApi): MySqlDialect.DbApiOps =
new MySqlDialect.DbApiOps(this)

Expand Down Expand Up @@ -207,6 +240,7 @@ object MySqlDialect extends MySqlDialect {
new SimpleSelect(
Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]],
None,
None,
false,
Seq(ref),
Nil,
Expand Down Expand Up @@ -309,6 +343,7 @@ object MySqlDialect extends MySqlDialect {
override def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -318,13 +353,14 @@ object MySqlDialect extends MySqlDialect {
implicit qr: Queryable.Row[Q, R],
dialect: scalasql.core.DialectTypeMappers
): scalasql.query.SimpleSelect[Q, R] = {
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)
}
}

class SimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -334,6 +370,7 @@ object MySqlDialect extends MySqlDialect {
extends scalasql.query.SimpleSelect(
expr,
exprPrefix,
exprSuffix,
preserveAll,
from,
joins,
Expand Down
29 changes: 29 additions & 0 deletions scalasql/src/dialects/PostgresDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps {
}
}

implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) {

/**
* SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from
* modifying or locking the selected rows, which is for managing concurrent transactions
* and ensuring data consistency in multi-step operations.
*/
def forUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE")

/**
* SELECT ... FOR NO KEY UPDATE: A weaker lock that doesn't block inserts into child
* tables with foreign key references
*/
def forNoKeyUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR NO KEY UPDATE")

/**
* SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions
* to read but not modify the locked rows.
*/
def forShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR SHARE")

/** SELECT ... FOR KEY SHARE: The weakest lock, only conflicts with FOR UPDATE */
def forKeyShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR KEY SHARE")
}

override implicit def DbApiOpsConv(db: => DbApi): PostgresDialect.DbApiOps =
new PostgresDialect.DbApiOps(this)
}
Expand Down
Loading
Loading