Skip to content

Commit 997be7a

Browse files
committed
💄
1 parent 6871fc4 commit 997be7a

File tree

2 files changed

+56
-37
lines changed

2 files changed

+56
-37
lines changed

components/gitpod-db/go/personal_access_token.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
"github.com/google/uuid"
1616
"gorm.io/gorm"
17-
"gorm.io/gorm/clause"
1817
)
1918

2019
type PersonalAccessToken struct {
@@ -105,14 +104,24 @@ func UpdatePersonalAccessToken(ctx context.Context, conn *gorm.DB, tokenID uuid.
105104
db := conn.WithContext(ctx)
106105

107106
var token PersonalAccessToken
108-
db = db.
109-
Model(&token).
110-
Clauses(clause.Returning{}).
111-
Where("id = ?", tokenID).
112-
Where("userId = ?", userID).
113-
Where("deleted = ?", 0).
114-
Select("hash", "expirationTime").Updates(PersonalAccessToken{Hash: hash, ExpirationTime: expirationTime})
115-
if db.Error != nil {
107+
err := db.Transaction(func(tx *gorm.DB) error {
108+
txErr := tx.
109+
Where("id = ?", tokenID).
110+
Where("userId = ?", userID).
111+
Where("deleted = ?", 0).
112+
Select("hash", "expirationTime").Updates(PersonalAccessToken{Hash: hash, ExpirationTime: expirationTime}).Error
113+
if txErr != nil {
114+
return txErr
115+
}
116+
117+
txErr = tx.Where("id = ?", tokenID).Where("userId = ?", userID).Where("deleted = ?", 0).First(&token).Error
118+
if txErr != nil {
119+
return txErr
120+
}
121+
122+
return nil
123+
})
124+
if err != nil {
116125
if errors.Is(db.Error, gorm.ErrRecordNotFound) {
117126
return PersonalAccessToken{}, fmt.Errorf("Token with ID %s does not exist: %w", tokenID, ErrorNotFound)
118127
}

components/public-api-server/pkg/apiv1/tokens_test.go

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ import (
2222
protocol "github.com/gitpod-io/gitpod/gitpod-protocol"
2323
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
2424
"github.com/golang/mock/gomock"
25-
"github.com/google/go-cmp/cmp"
2625
"github.com/google/uuid"
2726
"github.com/stretchr/testify/require"
28-
"google.golang.org/protobuf/testing/protocmp"
2927
"google.golang.org/protobuf/types/known/timestamppb"
3028
"gorm.io/gorm"
3129
)
@@ -368,6 +366,33 @@ func TestTokensService_RegeneratePersonalAccessToken(t *testing.T) {
368366
require.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
369367
})
370368

369+
t.Run("responds with not found when token is not found", func(t *testing.T) {
370+
serverMock, dbConn, client := setupTokensService(t, withTokenFeatureEnabled)
371+
372+
someTokenId := uuid.New().String()
373+
374+
dbtest.CreatePersonalAccessTokenRecords(t, dbConn,
375+
dbtest.NewPersonalAccessToken(t, db.PersonalAccessToken{
376+
UserID: uuid.MustParse(user.ID),
377+
}),
378+
dbtest.NewPersonalAccessToken(t, db.PersonalAccessToken{
379+
UserID: uuid.MustParse(user2.ID),
380+
}),
381+
)
382+
383+
serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
384+
385+
newTime, err := time.Parse(time.RFC3339, "2023-01-02T15:04:05Z")
386+
require.NoError(t, err)
387+
388+
newTimestamp := timestamppb.New(newTime)
389+
_, err = client.RegeneratePersonalAccessToken(context.Background(), connect.NewRequest(&v1.RegeneratePersonalAccessTokenRequest{
390+
Id: someTokenId,
391+
ExpirationTime: newTimestamp,
392+
}))
393+
require.Error(t, err, fmt.Errorf("Token with ID %s does not exist: not found", someTokenId))
394+
})
395+
371396
t.Run("regenerate correct token", func(t *testing.T) {
372397
serverMock, dbConn, client := setupTokensService(t, withTokenFeatureEnabled)
373398

@@ -380,36 +405,30 @@ func TestTokensService_RegeneratePersonalAccessToken(t *testing.T) {
380405
}),
381406
)
382407

383-
serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
408+
serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil).MaxTimes(2)
384409

385410
origResponse, err := client.GetPersonalAccessToken(context.Background(), connect.NewRequest(&v1.GetPersonalAccessTokenRequest{
386411
Id: tokens[0].ID.String(),
387412
}))
413+
require.NoError(t, err)
388414

415+
newTime, err := time.Parse(time.RFC3339, "2023-01-02T15:04:05Z")
389416
require.NoError(t, err)
390417

391-
newTimestamp := timestamppb.Now()
418+
newTimestamp := timestamppb.New(newTime)
392419
response, err := client.RegeneratePersonalAccessToken(context.Background(), connect.NewRequest(&v1.RegeneratePersonalAccessTokenRequest{
393420
Id: tokens[0].ID.String(),
394421
ExpirationTime: newTimestamp,
395422
}))
396-
397423
require.NoError(t, err)
398424

399-
// require.Equal(t, response.Msg.Token.ExpirationTime, newTimestamp)
400-
// require.NotEqual(t, response.Msg.Token.Value, origResponse.Msg.Token.Value)
401-
402-
requireNotEqualProto(t, &v1.RegeneratePersonalAccessTokenResponse{
403-
Token: personalAccessTokenToAPI(tokens[0], origResponse.Msg.Token.Value),
404-
}, response.Msg)
405-
406-
// fake code
407-
tokens[0].Hash = response.Msg.Token.Value
408-
tokens[0].ExpirationTime = response.Msg.Token.GetExpirationTime().AsTime()
409-
410-
requireEqualProto(t, &v1.RegeneratePersonalAccessTokenResponse{
411-
Token: personalAccessTokenToAPI(tokens[0], ""),
412-
}, response.Msg)
425+
require.Equal(t, origResponse.Msg.Token.Id, response.Msg.Token.Id)
426+
require.NotEqual(t, "", response.Msg.Token.Value)
427+
require.Equal(t, origResponse.Msg.Token.Name, response.Msg.Token.Name)
428+
require.Equal(t, origResponse.Msg.Token.Description, response.Msg.Token.Description)
429+
require.Equal(t, origResponse.Msg.Token.Scopes, response.Msg.Token.Scopes)
430+
require.Equal(t, newTimestamp.AsTime(), response.Msg.Token.ExpirationTime.AsTime())
431+
require.Equal(t, origResponse.Msg.Token.CreatedAt, response.Msg.Token.CreatedAt)
413432
})
414433
}
415434

@@ -536,12 +555,3 @@ func setupTokensService(t *testing.T, expClient experiments.Client) (*protocol.M
536555

537556
return serverMock, dbConn, client
538557
}
539-
540-
func requireNotEqualProto(t *testing.T, expected interface{}, actual interface{}) {
541-
t.Helper()
542-
543-
diff := cmp.Diff(expected, actual, protocmp.Transform())
544-
if diff == "" {
545-
require.Fail(t, diff, "they should not equal")
546-
}
547-
}

0 commit comments

Comments
 (0)