Skip to content

Commit c548dde

Browse files
authored
More refactoring of db.DefaultContext (#27083)
Next step of #27065
1 parent f8a1094 commit c548dde

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+336
-320
lines changed

cmd/admin_user_create.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func runCreateUser(c *cli.Context) error {
156156
UID: u.ID,
157157
}
158158

159-
if err := auth_model.NewAccessToken(t); err != nil {
159+
if err := auth_model.NewAccessToken(ctx, t); err != nil {
160160
return err
161161
}
162162

cmd/admin_user_generate_access_token.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func runGenerateAccessToken(c *cli.Context) error {
6363
UID: user.ID,
6464
}
6565

66-
exist, err := auth_model.AccessTokenByNameExists(t)
66+
exist, err := auth_model.AccessTokenByNameExists(ctx, t)
6767
if err != nil {
6868
return err
6969
}
@@ -79,7 +79,7 @@ func runGenerateAccessToken(c *cli.Context) error {
7979
t.Scope = accessTokenScope
8080

8181
// create the token
82-
if err := auth_model.NewAccessToken(t); err != nil {
82+
if err := auth_model.NewAccessToken(ctx, t); err != nil {
8383
return err
8484
}
8585

models/auth/token.go

+16-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package auth
66

77
import (
8+
"context"
89
"crypto/subtle"
910
"encoding/hex"
1011
"fmt"
@@ -95,7 +96,7 @@ func init() {
9596
}
9697

9798
// NewAccessToken creates new access token.
98-
func NewAccessToken(t *AccessToken) error {
99+
func NewAccessToken(ctx context.Context, t *AccessToken) error {
99100
salt, err := util.CryptoRandomString(10)
100101
if err != nil {
101102
return err
@@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
108109
t.Token = hex.EncodeToString(token)
109110
t.TokenHash = HashToken(t.Token, t.TokenSalt)
110111
t.TokenLastEight = t.Token[len(t.Token)-8:]
111-
_, err = db.GetEngine(db.DefaultContext).Insert(t)
112+
_, err = db.GetEngine(ctx).Insert(t)
112113
return err
113114
}
114115

@@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
137138
}
138139

139140
// GetAccessTokenBySHA returns access token by given token value
140-
func GetAccessTokenBySHA(token string) (*AccessToken, error) {
141+
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
141142
if token == "" {
142143
return nil, ErrAccessTokenEmpty{}
143144
}
@@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
158159
TokenLastEight: lastEight,
159160
}
160161
// Re-get the token from the db in case it has been deleted in the intervening period
161-
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
162+
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
162163
if err != nil {
163164
return nil, err
164165
}
@@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
169170
}
170171

171172
var tokens []AccessToken
172-
err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
173+
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
173174
if err != nil {
174175
return nil, err
175176
} else if len(tokens) == 0 {
@@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
189190
}
190191

191192
// AccessTokenByNameExists checks if a token name has been used already by a user.
192-
func AccessTokenByNameExists(token *AccessToken) (bool, error) {
193-
return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
193+
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
194+
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
194195
}
195196

196197
// ListAccessTokensOptions contain filter options
@@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
201202
}
202203

203204
// ListAccessTokens returns a list of access tokens belongs to given user.
204-
func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
205-
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
205+
func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
206+
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
206207

207208
if len(opts.Name) != 0 {
208209
sess = sess.Where("name=?", opts.Name)
@@ -222,23 +223,23 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
222223
}
223224

224225
// UpdateAccessToken updates information of access token.
225-
func UpdateAccessToken(t *AccessToken) error {
226-
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
226+
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
227+
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
227228
return err
228229
}
229230

230231
// CountAccessTokens count access tokens belongs to given user by options
231-
func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
232-
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
232+
func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
233+
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
233234
if len(opts.Name) != 0 {
234235
sess = sess.Where("name=?", opts.Name)
235236
}
236237
return sess.Count(&AccessToken{})
237238
}
238239

239240
// DeleteAccessTokenByID deletes access token by given ID.
240-
func DeleteAccessTokenByID(id, userID int64) error {
241-
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
241+
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
242+
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
242243
UID: userID,
243244
})
244245
if err != nil {

models/auth/token_test.go

+18-17
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88

99
auth_model "code.gitea.io/gitea/models/auth"
10+
"code.gitea.io/gitea/models/db"
1011
"code.gitea.io/gitea/models/unittest"
1112

1213
"github.com/stretchr/testify/assert"
@@ -18,15 +19,15 @@ func TestNewAccessToken(t *testing.T) {
1819
UID: 3,
1920
Name: "Token C",
2021
}
21-
assert.NoError(t, auth_model.NewAccessToken(token))
22+
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
2223
unittest.AssertExistsAndLoadBean(t, token)
2324

2425
invalidToken := &auth_model.AccessToken{
2526
ID: token.ID, // duplicate
2627
UID: 2,
2728
Name: "Token F",
2829
}
29-
assert.Error(t, auth_model.NewAccessToken(invalidToken))
30+
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
3031
}
3132

3233
func TestAccessTokenByNameExists(t *testing.T) {
@@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
3940
}
4041

4142
// Check to make sure it doesn't exists already
42-
exist, err := auth_model.AccessTokenByNameExists(token)
43+
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
4344
assert.NoError(t, err)
4445
assert.False(t, exist)
4546

4647
// Save it to the database
47-
assert.NoError(t, auth_model.NewAccessToken(token))
48+
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
4849
unittest.AssertExistsAndLoadBean(t, token)
4950

5051
// This token must be found by name in the DB now
51-
exist, err = auth_model.AccessTokenByNameExists(token)
52+
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
5253
assert.NoError(t, err)
5354
assert.True(t, exist)
5455

@@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {
5960

6061
// Name matches but different user ID, this shouldn't exists in the
6162
// database
62-
exist, err = auth_model.AccessTokenByNameExists(user4Token)
63+
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
6364
assert.NoError(t, err)
6465
assert.False(t, exist)
6566
}
6667

6768
func TestGetAccessTokenBySHA(t *testing.T) {
6869
assert.NoError(t, unittest.PrepareTestDatabase())
69-
token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
70+
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
7071
assert.NoError(t, err)
7172
assert.Equal(t, int64(1), token.UID)
7273
assert.Equal(t, "Token A", token.Name)
7374
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
7475
assert.Equal(t, "e4efbf36", token.TokenLastEight)
7576

76-
_, err = auth_model.GetAccessTokenBySHA("notahash")
77+
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
7778
assert.Error(t, err)
7879
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
7980

80-
_, err = auth_model.GetAccessTokenBySHA("")
81+
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
8182
assert.Error(t, err)
8283
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
8384
}
8485

8586
func TestListAccessTokens(t *testing.T) {
8687
assert.NoError(t, unittest.PrepareTestDatabase())
87-
tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
88+
tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
8889
assert.NoError(t, err)
8990
if assert.Len(t, tokens, 2) {
9091
assert.Equal(t, int64(1), tokens[0].UID)
@@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
9394
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
9495
}
9596

96-
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
97+
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
9798
assert.NoError(t, err)
9899
if assert.Len(t, tokens, 1) {
99100
assert.Equal(t, int64(2), tokens[0].UID)
100101
assert.Equal(t, "Token A", tokens[0].Name)
101102
}
102103

103-
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
104+
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
104105
assert.NoError(t, err)
105106
assert.Empty(t, tokens)
106107
}
107108

108109
func TestUpdateAccessToken(t *testing.T) {
109110
assert.NoError(t, unittest.PrepareTestDatabase())
110-
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
111+
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
111112
assert.NoError(t, err)
112113
token.Name = "Token Z"
113114

114-
assert.NoError(t, auth_model.UpdateAccessToken(token))
115+
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
115116
unittest.AssertExistsAndLoadBean(t, token)
116117
}
117118

118119
func TestDeleteAccessTokenByID(t *testing.T) {
119120
assert.NoError(t, unittest.PrepareTestDatabase())
120121

121-
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
122+
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
122123
assert.NoError(t, err)
123124
assert.Equal(t, int64(1), token.UID)
124125

125-
assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
126+
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
126127
unittest.AssertNotExistsBean(t, token)
127128

128-
err = auth_model.DeleteAccessTokenByID(100, 100)
129+
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
129130
assert.Error(t, err)
130131
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
131132
}

models/auth/twofactor.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package auth
55

66
import (
7+
"context"
78
"crypto/md5"
89
"crypto/subtle"
910
"encoding/base32"
@@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
121122
}
122123

123124
// NewTwoFactor creates a new two-factor authentication token.
124-
func NewTwoFactor(t *TwoFactor) error {
125-
_, err := db.GetEngine(db.DefaultContext).Insert(t)
125+
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
126+
_, err := db.GetEngine(ctx).Insert(t)
126127
return err
127128
}
128129

129130
// UpdateTwoFactor updates a two-factor authentication token.
130-
func UpdateTwoFactor(t *TwoFactor) error {
131-
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
131+
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
132+
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
132133
return err
133134
}
134135

135136
// GetTwoFactorByUID returns the two-factor authentication token associated with
136137
// the user, if any.
137-
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
138+
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
138139
twofa := &TwoFactor{}
139-
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
140+
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
140141
if err != nil {
141142
return nil, err
142143
} else if !has {
@@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
147148

148149
// HasTwoFactorByUID returns the two-factor authentication token associated with
149150
// the user, if any.
150-
func HasTwoFactorByUID(uid int64) (bool, error) {
151-
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
151+
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
152+
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
152153
}
153154

154155
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
155-
func DeleteTwoFactorByID(id, userID int64) error {
156-
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
156+
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
157+
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
157158
UID: userID,
158159
})
159160
if err != nil {

models/issues/comment.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,12 @@ func (c *Comment) LoadPoster(ctx context.Context) (err error) {
359359
}
360360

361361
// AfterDelete is invoked from XORM after the object is deleted.
362-
func (c *Comment) AfterDelete() {
362+
func (c *Comment) AfterDelete(ctx context.Context) {
363363
if c.ID <= 0 {
364364
return
365365
}
366366

367-
_, err := repo_model.DeleteAttachmentsByComment(c.ID, true)
367+
_, err := repo_model.DeleteAttachmentsByComment(ctx, c.ID, true)
368368
if err != nil {
369369
log.Info("Could not delete files for comment %d on issue #%d: %s", c.ID, c.IssueID, err)
370370
}

models/issues/pull_list.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ type PullRequestsOptions struct {
2727
MilestoneID int64
2828
}
2929

30-
func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
31-
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID)
30+
func listPullRequestStatement(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
31+
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", baseRepoID)
3232

3333
sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
3434
switch opts.State {
@@ -115,21 +115,21 @@ func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch
115115
}
116116

117117
// GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
118-
func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
118+
func GetPullRequestIDsByCheckStatus(ctx context.Context, status PullRequestStatus) ([]int64, error) {
119119
prs := make([]int64, 0, 10)
120-
return prs, db.GetEngine(db.DefaultContext).Table("pull_request").
120+
return prs, db.GetEngine(ctx).Table("pull_request").
121121
Where("status=?", status).
122122
Cols("pull_request.id").
123123
Find(&prs)
124124
}
125125

126126
// PullRequests returns all pull requests for a base Repo by the given conditions
127-
func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
127+
func PullRequests(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
128128
if opts.Page <= 0 {
129129
opts.Page = 1
130130
}
131131

132-
countSession, err := listPullRequestStatement(baseRepoID, opts)
132+
countSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
133133
if err != nil {
134134
log.Error("listPullRequestStatement: %v", err)
135135
return nil, 0, err
@@ -140,7 +140,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
140140
return nil, maxResults, err
141141
}
142142

143-
findSession, err := listPullRequestStatement(baseRepoID, opts)
143+
findSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
144144
applySorts(findSession, opts.SortType, 0)
145145
if err != nil {
146146
log.Error("listPullRequestStatement: %v", err)

models/issues/pull_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) {
6060

6161
func TestPullRequestsNewest(t *testing.T) {
6262
assert.NoError(t, unittest.PrepareTestDatabase())
63-
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
63+
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
6464
ListOptions: db.ListOptions{
6565
Page: 1,
6666
},
@@ -107,7 +107,7 @@ func TestLoadRequestedReviewers(t *testing.T) {
107107

108108
func TestPullRequestsOldest(t *testing.T) {
109109
assert.NoError(t, unittest.PrepareTestDatabase())
110-
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
110+
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
111111
ListOptions: db.ListOptions{
112112
Page: 1,
113113
},

0 commit comments

Comments
 (0)