Skip to content

Allow detect whether it's in a database transaction for a context.Context #21756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/activities/action.go
Original file line number Diff line number Diff line change
@@ -572,7 +572,7 @@ func NotifyWatchers(actions ...*Action) error {

// NotifyWatchersActions creates batch of actions for every watcher.
func NotifyWatchersActions(acts []*Action) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
4 changes: 2 additions & 2 deletions models/activities/notification.go
Original file line number Diff line number Diff line change
@@ -142,7 +142,7 @@ func CountNotifications(opts *FindNotificationOptions) (int64, error) {

// CreateRepoTransferNotification creates notification for the user a repository was transferred to
func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_model.Repository) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
@@ -185,7 +185,7 @@ func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_
// for each watcher, or updates it if already exists
// receiverID > 0 just send to receiver, else send to all watcher
func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key.go
Original file line number Diff line number Diff line change
@@ -234,7 +234,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
return ErrGPGKeyAccessDenied{doer.ID, key.ID}
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_add.go
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_verify.go
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ import (

// VerifyGPGKey marks a GPG key as verified
func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
4 changes: 2 additions & 2 deletions models/asymkey/ssh_key.go
Original file line number Diff line number Diff line change
@@ -100,7 +100,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
@@ -321,7 +321,7 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
// Start session
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, err
}
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_deploy.go
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
accessMode = perm.AccessModeWrite
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_principals.go
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ import (

// AddPrincipalKey adds new principal to database and authorized_principals file.
func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*PublicKey, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_verify.go
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ import (

// VerifySSHKey marks a SSH key as verified
func VerifySSHKey(ownerID int64, fingerprint, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
4 changes: 2 additions & 2 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
@@ -201,7 +201,7 @@ type UpdateOAuth2ApplicationOptions struct {

// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
@@ -265,7 +265,7 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error {

// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(id, userid int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
4 changes: 2 additions & 2 deletions models/auth/session.go
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ func ReadSession(key string) (*Session, error) {
Key: key,
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
@@ -73,7 +73,7 @@ func DestroySession(key string) error {

// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion models/avatars/avatar.go
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ func saveEmailHash(email string) string {
Hash: emailHash,
}
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
if err := db.WithTx(func(ctx context.Context) error {
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
54 changes: 47 additions & 7 deletions models/db/context.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"context"
"database/sql"

"xorm.io/xorm"
"xorm.io/xorm/schemas"
)

@@ -86,7 +87,11 @@ type Committer interface {
}

// TxContext represents a transaction Context
func TxContext() (*Context, Committer, error) {
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
if InTransaction(parentCtx) {
return nil, nil, ErrAlreadyInTransaction
}

sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
@@ -97,14 +102,24 @@ func TxContext() (*Context, Committer, error) {
}

// WithTx represents executing database operations on a transaction
// you can optionally change the context to a parent one
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
parentCtx := DefaultContext
if len(stdCtx) != 0 && stdCtx[0] != nil {
// TODO: make sure parent context has no open session
parentCtx = stdCtx[0]
// This function will always open a new transaction, if a transaction exist in parentCtx return an error.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if InTransaction(parentCtx) {
return ErrAlreadyInTransaction
}
return txWithNoCheck(parentCtx, f)
}

// AutoTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func AutoTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if InTransaction(parentCtx) {
return f(newContext(parentCtx, GetEngine(parentCtx), true))
}
return txWithNoCheck(parentCtx, f)
}

func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
@@ -180,3 +195,28 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
}
return rows, err
}

// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
var e Engine
if engined, ok := ctx.(Engined); ok {
e = engined.Engine()
} else {
enginedInterface := ctx.Value(enginedContextKey)
if enginedInterface != nil {
e = enginedInterface.(Engined).Engine()
}
}
if e == nil {
return false
}

switch t := e.(type) {
case *xorm.Engine:
return false
case *xorm.Session:
return t.IsInTx()
default:
return false
}
}
33 changes: 33 additions & 0 deletions models/db/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package db_test

import (
"context"
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
)

func TestInTransaction(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.False(t, db.InTransaction(db.DefaultContext))
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))

ctx, committer, err := db.TxContext(db.DefaultContext)
assert.NoError(t, err)
defer committer.Close()
assert.True(t, db.InTransaction(ctx))
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
}
3 changes: 3 additions & 0 deletions models/db/error.go
Original file line number Diff line number Diff line change
@@ -5,11 +5,14 @@
package db

import (
"errors"
"fmt"

"code.gitea.io/gitea/modules/util"
)

var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")

// ErrCancelled represents an error due to context cancellation
type ErrCancelled struct {
Message string
8 changes: 4 additions & 4 deletions models/db/index_test.go
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 62, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
@@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 73, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
@@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 2, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
@@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 3, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 4, maxIndex)
2 changes: 1 addition & 1 deletion models/git/branches.go
Original file line number Diff line number Diff line change
@@ -544,7 +544,7 @@ func FindRenamedBranch(repoID int64, from string) (branch *RenamedBranch, exist

// RenameBranch rename a branch
func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(isDefault bool) error) (err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
2 changes: 1 addition & 1 deletion models/git/branches_test.go
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ func TestRenameBranch(t *testing.T) {
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
_isDefault := false

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
defer committer.Close()
assert.NoError(t, err)
assert.NoError(t, git_model.UpdateProtectBranch(ctx, repo1, &git_model.ProtectedBranch{
4 changes: 2 additions & 2 deletions models/git/commit_status.go
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ func GetNextCommitStatusIndex(repoID int64, sha string) (int64, error) {

// getNextCommitStatusIndex return the next index
func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
ctx, commiter, err := db.TxContext()
ctx, commiter, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
@@ -297,7 +297,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
return fmt.Errorf("generate commit status index failed: %w", err)
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %w", opts.Repo.ID, opts.Creator.ID, opts.SHA, err)
}
6 changes: 3 additions & 3 deletions models/git/lfs.go
Original file line number Diff line number Diff line change
@@ -137,7 +137,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}
func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) {
var err error

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
@@ -185,7 +185,7 @@ func RemoveLFSMetaObjectByOid(repoID int64, oid string) (int64, error) {
return 0, ErrLFSObjectNotExist
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
@@ -242,7 +242,7 @@ func LFSObjectIsAssociated(oid string) (bool, error) {

// LFSAutoAssociate auto associates accessible LFSMetaObjects
func LFSAutoAssociate(metas []*LFSMetaObject, user *user_model.User, repoID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
4 changes: 2 additions & 2 deletions models/git/lfs_lock.go
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ func cleanPath(p string) string {

// CreateLFSLock creates a new lock.
func CreateLFSLock(repo *repo_model.Repository, lock *LFSLock) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
@@ -137,7 +137,7 @@ func CountLFSLockByRepoID(repoID int64) (int64, error) {

// DeleteLFSLockByID deletes a lock by given ID.
func DeleteLFSLockByID(id int64, repo *repo_model.Repository, u *user_model.User, force bool) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion models/issues/assignees.go
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.U

// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *Comment, err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, nil, err
}
Loading