From 7cc24cf2b37aa49c299c549d0951afae6c6286f9 Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Mon, 14 Apr 2025 11:02:33 -0700 Subject: [PATCH 1/2] feat: last of basic refactoring and moving of existing hooks The existing hooks are now in hooks/v0, using the same tests from the api package. --- internal/api/admin.go | 64 +-- internal/api/anonymous.go | 4 +- internal/api/api.go | 10 +- internal/api/apierrors/apierrors.go | 35 +- internal/api/apierrors/apierrors_test.go | 180 ++++++++ internal/api/audit.go | 6 +- internal/api/auth.go | 22 +- internal/api/auth_test.go | 6 +- internal/api/errors.go | 36 -- internal/api/errors_test.go | 6 +- internal/api/external.go | 56 +-- internal/api/external_oauth.go | 18 +- internal/api/helpers.go | 4 +- internal/api/helpers_test.go | 8 +- internal/api/hooks_test.go | 54 +-- internal/api/identity.go | 34 +- internal/api/identity_test.go | 6 +- internal/api/invite.go | 4 +- internal/api/logout.go | 4 +- internal/api/magic_link.go | 10 +- internal/api/mail.go | 92 ++-- internal/api/mail_test.go | 4 +- internal/api/mfa.go | 160 +++---- internal/api/mfa_test.go | 2 +- internal/api/middleware.go | 12 +- internal/api/middleware_test.go | 2 +- internal/api/otp.go | 16 +- internal/api/password.go | 4 +- internal/api/phone.go | 27 +- internal/api/phone_test.go | 3 +- internal/api/pkce.go | 8 +- internal/api/reauthenticate.go | 16 +- internal/api/recover.go | 4 +- internal/api/resend.go | 16 +- internal/api/samlacs.go | 36 +- internal/api/signup.go | 32 +- internal/api/sso.go | 20 +- internal/api/ssoadmin.go | 40 +- internal/api/token.go | 66 +-- internal/api/token_oidc.go | 22 +- internal/api/token_refresh.go | 36 +- internal/api/user.go | 34 +- internal/api/verify.go | 56 +-- internal/api/verify_test.go | 8 +- internal/api/web3.go | 34 +- internal/hooks/hooks.go | 38 ++ internal/hooks/hooks_test.go | 139 ++++++ .../hooks.go => hooks/v0hooks/manager.go} | 433 ++++++++++-------- .../{auth_hooks.go => v0hooks/v0hooks.go} | 2 +- 49 files changed, 1180 insertions(+), 749 deletions(-) create mode 100644 internal/api/apierrors/apierrors_test.go create mode 100644 internal/hooks/hooks.go create mode 100644 internal/hooks/hooks_test.go rename internal/{api/hooks.go => hooks/v0hooks/manager.go} (63%) rename internal/hooks/{auth_hooks.go => v0hooks/v0hooks.go} (99%) diff --git a/internal/api/admin.go b/internal/api/admin.go index 9d3be8c77..f8188d578 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -54,7 +54,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -62,9 +62,9 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User not found") } - return nil, internalServerError("Database error loading user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err) } return withUser(ctx, u), nil @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex user := getUser(ctx) factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -85,9 +85,9 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex factor, err := user.FindOwnedFactorByID(db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") } - return nil, internalServerError("Database error loading factor").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error loading factor").WithInternalError(err) } return withFactor(ctx, factor), nil } @@ -109,19 +109,19 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter) if err != nil { - return internalServerError("Database error finding users").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err) } addPaginationHeaders(w, r, pageParams) @@ -170,7 +170,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } banDuration = &duration @@ -315,7 +315,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - return internalServerError("Error updating user").WithInternalError(err) + return apierrors.NewInternalServerError("Error updating user").WithInternalError(err) } return sendJSON(w, http.StatusOK, user) @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") } var providers []string @@ -349,9 +349,9 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { return err } if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { - return internalServerError("Database error checking email").WithInternalError(err) + return apierrors.NewInternalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } providers = append(providers, "email") } @@ -362,21 +362,21 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { return err } if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { - return internalServerError("Database error checking phone").WithInternalError(err) + return apierrors.NewInternalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user") } providers = append(providers, "phone") } if params.Password != nil && params.PasswordHash != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided") } if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" { password, err := password.Generate(64, 10, 0, false, true) if err != nil { - return internalServerError("Error generating password").WithInternalError(err) + return apierrors.NewInternalServerError("Error generating password").WithInternalError(err) } params.Password = &password } @@ -390,18 +390,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if err != nil { if errors.Is(err, bcrypt.ErrPasswordTooLong) { - return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } - return internalServerError("Error creating user").WithInternalError(err) + return apierrors.NewInternalServerError("Error creating user").WithInternalError(err) } if params.Id != "" { customId, err := uuid.FromString(params.Id) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format") } if customId == uuid.Nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid") } user.ID = customId } @@ -419,7 +419,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } banDuration = &duration @@ -501,7 +501,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - return internalServerError("Database error creating new user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error creating new user").WithInternalError(err) } return sendJSON(w, http.StatusOK, user) @@ -529,7 +529,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { "user_email": user.Email, "user_phone": user.Phone, }); terr != nil { - return internalServerError("Error recording audit log entry").WithInternalError(terr) + return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr) } if params.ShouldSoftDelete { @@ -538,24 +538,24 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { return nil } if terr := user.SoftDeleteUser(tx); terr != nil { - return internalServerError("Error soft deleting user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error soft deleting user").WithInternalError(terr) } if terr := user.SoftDeleteUserIdentities(tx); terr != nil { - return internalServerError("Error soft deleting user identities").WithInternalError(terr) + return apierrors.NewInternalServerError("Error soft deleting user identities").WithInternalError(terr) } // hard delete all associated factors if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil { - return internalServerError("Error deleting user's factors").WithInternalError(terr) + return apierrors.NewInternalServerError("Error deleting user's factors").WithInternalError(terr) } // hard delete all associated sessions if terr := models.Logout(tx, user.ID); terr != nil { - return internalServerError("Error deleting user's sessions").WithInternalError(terr) + return apierrors.NewInternalServerError("Error deleting user's sessions").WithInternalError(terr) } } else { if terr := tx.Destroy(user); terr != nil { - return internalServerError("Database error deleting user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error deleting user").WithInternalError(terr) } } @@ -581,7 +581,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro return terr } if terr := tx.Destroy(factor); terr != nil { - return internalServerError("Database error deleting factor").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error deleting factor").WithInternalError(terr) } return nil }) @@ -619,7 +619,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro if params.Phone != "" && factor.IsPhoneFactor() { phone, err := validatePhone(params.Phone) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } if terr := factor.UpdatePhone(tx, phone); terr != nil { return terr diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index b3bc61b92..33aa04aee 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -16,7 +16,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { aud := a.requestAud(ctx, r) if config.DisableSignup { - return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} @@ -48,7 +48,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { return nil }) if err != nil { - return internalServerError("Database error creating anonymous user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err) } metering.RecordLogin("anonymous", newUser.ID) diff --git a/internal/api/api.go b/internal/api/api.go index f184ce55c..534f168b2 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" @@ -32,6 +33,7 @@ type API struct { config *conf.GlobalConfiguration version string + hooksMgr *hooks.Manager hibpClient *hibp.PwnedClient // overrideTime can be used to override the clock used by handlers. Should only be used in tests! @@ -40,6 +42,9 @@ type API struct { limiterOpts *LimiterOptions } +func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config } +func (a *API) GetDB() *storage.Connection { return a.db } + func (a *API) Version() string { return a.version } @@ -81,6 +86,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if api.limiterOpts == nil { api.limiterOpts = NewLimiterOptions(globalConfig) } + if api.hooksMgr == nil { + api.hooksMgr = hooks.NewManager(db, globalConfig) + } if api.config.Password.HIBP.Enabled { httpClient := &http.Client{ // all HIBP API requests should finish quickly to avoid @@ -157,7 +165,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne } if params.Email == "" && params.Phone == "" { if !api.config.External.AnonymousUsers.Enabled { - return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") } if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil { return err diff --git a/internal/api/apierrors/apierrors.go b/internal/api/apierrors/apierrors.go index 7c6780d30..adab1d39c 100644 --- a/internal/api/apierrors/apierrors.go +++ b/internal/api/apierrors/apierrors.go @@ -2,6 +2,7 @@ package apierrors import ( "fmt" + "net/http" ) // OAuthError is the JSON handler for OAuth2 error responses @@ -30,7 +31,7 @@ func (e *OAuthError) WithInternalError(err error) *OAuthError { } // WithInternalMessage adds internal message information to the error -func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError { +func (e *OAuthError) WithInternalMessage(fmtString string, args ...any) *OAuthError { e.InternalMessage = fmt.Sprintf(fmtString, args...) return e } @@ -53,7 +54,7 @@ type HTTPError struct { ErrorID string `json:"error_id,omitempty"` } -func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { +func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...any) *HTTPError { return &HTTPError{ HTTPStatus: httpStatus, ErrorCode: errorCode, @@ -61,6 +62,34 @@ func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args .. } } +func NewBadRequestError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusBadRequest, errorCode, fmtString, args...) +} + +func NewNotFoundError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusNotFound, errorCode, fmtString, args...) +} + +func NewForbiddenError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusForbidden, errorCode, fmtString, args...) +} + +func NewUnprocessableEntityError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) +} + +func NewTooManyRequestsError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusTooManyRequests, errorCode, fmtString, args...) +} + +func NewInternalServerError(fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...) +} + +func NewConflictError(fmtString string, args ...any) *HTTPError { + return NewHTTPError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) +} + func (e *HTTPError) Error() string { if e.InternalMessage != "" { return e.InternalMessage @@ -87,7 +116,7 @@ func (e *HTTPError) WithInternalError(err error) *HTTPError { } // WithInternalMessage adds internal message information to the error -func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { +func (e *HTTPError) WithInternalMessage(fmtString string, args ...any) *HTTPError { e.InternalMessage = fmt.Sprintf(fmtString, args...) return e } diff --git a/internal/api/apierrors/apierrors_test.go b/internal/api/apierrors/apierrors_test.go new file mode 100644 index 000000000..515a79a9e --- /dev/null +++ b/internal/api/apierrors/apierrors_test.go @@ -0,0 +1,180 @@ +package apierrors + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHTTPErrors(t *testing.T) { + sentinel := errors.New("sentinel") + + tests := []struct { + from error + exp *HTTPError + }{ + + // Status, ErrorCode, fmtStr, args + { + from: NewHTTPError( + http.StatusBadRequest, + ErrorCodeBadJSON, + "Unable to parse JSON: %v", + errors.New("bad syntax"), + ), + exp: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + ErrorCode: ErrorCodeBadJSON, + Message: "Unable to parse JSON: bad syntax", + }, + }, + + // ErrorCode, fmtStr, args + { + from: NewBadRequestError( + ErrorCodeBadJSON, + "Unable to parse JSON: %v", + errors.New("bad syntax"), + ), + exp: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + ErrorCode: ErrorCodeBadJSON, + Message: "Unable to parse JSON: bad syntax", + }, + }, + { + from: NewNotFoundError( + ErrorCodeUnknown, + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusNotFound, + ErrorCode: ErrorCodeUnknown, + Message: "error: " + sentinel.Error(), + }, + }, + { + from: NewForbiddenError( + ErrorCodeUnknown, + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusForbidden, + ErrorCode: ErrorCodeUnknown, + Message: "error: " + sentinel.Error(), + }, + }, + { + from: NewUnprocessableEntityError( + ErrorCodeUnknown, + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusUnprocessableEntity, + ErrorCode: ErrorCodeUnknown, + Message: "error: " + sentinel.Error(), + }, + }, + { + from: NewTooManyRequestsError( + ErrorCodeUnknown, + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusTooManyRequests, + ErrorCode: ErrorCodeUnknown, + Message: "error: " + sentinel.Error(), + }, + }, + + // fmtStr, args + { + from: NewInternalServerError( + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: ErrorCodeUnexpectedFailure, + Message: "error: " + sentinel.Error(), + }, + }, + { + from: NewConflictError( + "error: %v", + sentinel, + ), + exp: &HTTPError{ + HTTPStatus: http.StatusConflict, + ErrorCode: ErrorCodeConflict, + Message: "error: " + sentinel.Error(), + }, + }, + } + + for idx, test := range tests { + t.Logf("tests #%v - from %v exp %#v", idx, test.from, test.exp) + + require.Error(t, test.exp) + require.Error(t, test.from) + + exp := test.exp + got, ok := test.from.(*HTTPError) + if !ok { + t.Fatalf("exp type %T, got %v", got, test.from) + } + + require.Equal(t, exp.HTTPStatus, got.HTTPStatus) + require.Equal(t, exp.ErrorCode, got.ErrorCode) + require.Equal(t, exp.Message, got.Message) + require.Equal(t, exp.Error(), got.Error()) + require.Equal(t, exp.Cause(), got.Cause()) + } + + // test Error() with internal message + { + err := NewHTTPError( + http.StatusBadRequest, + ErrorCodeBadJSON, + "Unable to parse JSON: %v", + errors.New("bad syntax"), + ).WithInternalError(sentinel).WithInternalMessage(sentinel.Error()) + + require.Equal(t, err.Error(), sentinel.Error()) + require.Equal(t, err.Cause(), sentinel) + require.Equal(t, err.Is(sentinel), true) + } +} + +func TestOAuthErrors(t *testing.T) { + sentinel := errors.New("sentinel") + + { + err := NewOAuthError( + "oauth error", + "oauth desc", + ) + + require.Error(t, err) + require.Equal(t, err.Error(), "oauth error: oauth desc") + require.Equal(t, err.Cause(), err) + } + + // test Error() with internal message + { + err := NewOAuthError( + "oauth error", + "oauth desc", + ).WithInternalError(sentinel).WithInternalMessage(sentinel.Error()) + + require.Error(t, err) + require.Equal(t, err.Error(), sentinel.Error()) + require.Equal(t, err.Cause(), sentinel) + } +} diff --git a/internal/api/audit.go b/internal/api/audit.go index e2d71bbb4..665af1dd7 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -21,7 +21,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -32,14 +32,14 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } logs, err := models.FindAuditLogEntries(db, col, qval, pageParams) if err != nil { - return internalServerError("Error searching for audit logs").WithInternalError(err) + return apierrors.NewInternalServerError("Error searching for audit logs").WithInternalError(err) } addPaginationHeaders(w, r, pageParams) diff --git a/internal/api/auth.go b/internal/api/auth.go index be2400d2f..7964b4df8 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -37,7 +37,7 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte ctx := r.Context() claims := getClaims(ctx) if claims.IsAnonymous { - return nil, forbiddenError(apierrors.ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") } return ctx, nil } @@ -46,7 +46,7 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { // Find the administrative user claims := getClaims(ctx) if claims == nil { - return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "Invalid token") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -56,14 +56,14 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil } - return nil, forbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", httpError(http.StatusUnauthorized, apierrors.ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") + return "", apierrors.NewHTTPError(http.StatusUnauthorized, apierrors.ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") } return matches[1], nil @@ -89,7 +89,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return nil, fmt.Errorf("missing kid") }) if err != nil { - return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -100,23 +100,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid token: missing claims") + return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid token: missing claims") } if claims.Subject == "" { - return nil, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, forbiddenError(apierrors.ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") + return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") } return ctx, err } @@ -127,12 +127,12 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) + return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil { if models.IsNotFoundError(err) { - return ctx, forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) + return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) } return ctx, err } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 4d75184da..d173ff1d9 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -185,7 +185,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: forbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim"), + ExpectedError: apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "invalid claim: missing sub claim"), ExpectedUser: nil, }, { @@ -207,7 +207,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: badRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), + ExpectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), ExpectedUser: nil, }, { @@ -256,7 +256,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { Role: "authenticated", SessionId: "73bf9ee0-9e8c-453b-b484-09cb93e2f341", }, - ExpectedError: forbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"), + ExpectedError: apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"), ExpectedUser: u, ExpectedSession: nil, }, diff --git a/internal/api/errors.go b/internal/api/errors.go index 2e1606a1a..7479f9f03 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -37,42 +37,6 @@ type ( OAuthError = apierrors.OAuthError ) -func oauthError(err string, description string) *OAuthError { - return apierrors.NewOAuthError(err, description) -} - -func httpError(httpStatus int, errorCode apierrors.ErrorCode, fmtString string, args ...any) *HTTPError { - return apierrors.NewHTTPError(httpStatus, errorCode, fmtString, args...) -} - -func badRequestError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, errorCode, fmtString, args...) -} - -func internalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, apierrors.ErrorCodeUnexpectedFailure, fmtString, args...) -} - -func notFoundError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, errorCode, fmtString, args...) -} - -func forbiddenError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusForbidden, errorCode, fmtString, args...) -} - -func unprocessableEntityError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) -} - -func tooManyRequestsError(errorCode apierrors.ErrorCode, fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) -} - -func conflictError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusConflict, apierrors.ErrorCodeConflict, fmtString, args...) -} - // Recoverer is a middleware that recovers from panics, logs the panic (and a // backtrace), and returns a HTTP 500 (Internal Server Error) status if // possible. Recoverer prints a request ID if one is provided. diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go index 7afdb9ca2..dc4f34dbf 100644 --- a/internal/api/errors_test.go +++ b/internal/api/errors_test.go @@ -21,17 +21,17 @@ func TestHandleResponseErrorWithHTTPError(t *testing.T) { ExpectedBody string }{ { - HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + HTTPError: apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), APIVersion: "", ExpectedBody: "{\"code\":400,\"error_code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", }, { - HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + HTTPError: apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), APIVersion: "2023-12-31", ExpectedBody: "{\"code\":400,\"error_code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", }, { - HTTPError: badRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), + HTTPError: apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Unable to parse JSON"), APIVersion: "2024-01-01", ExpectedBody: "{\"code\":\"" + apierrors.ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", }, diff --git a/internal/api/external.go b/internal/api/external.go index f6dc2384d..dfaa86a03 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,9 +64,9 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(apierrors.ErrorCodeUserNotFound, "User identified by token not found") + return "", apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User identified by token not found") } - return "", internalServerError("Database error finding user").WithInternalError(userErr) + return "", apierrors.NewInternalServerError("Database error finding user").WithInternalError(userErr) } } @@ -108,7 +108,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ tokenString, err := signJwt(&config.JWT, claims) if err != nil { - return "", internalServerError("Error creating state").WithInternalError(err) + return "", apierrors.NewInternalServerError("Error creating state").WithInternalError(err) } authUrlParams := make([]oauth2.AuthCodeOption, 0) @@ -175,7 +175,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re userData := data.userData if len(userData.Emails) <= 0 { - return internalServerError("Error getting user email from external provider") + return apierrors.NewInternalServerError("Error getting user email from external provider") } userData.Metadata.EmailVerified = false for _, email := range userData.Emails { @@ -196,9 +196,9 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) if models.IsNotFoundError(err) { - return unprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) } else if err != nil { - return internalServerError("Failed to find flow state").WithInternalError(err) + return apierrors.NewInternalServerError("Failed to find flow state").WithInternalError(err) } } @@ -234,7 +234,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re } if terr != nil { - return oauthError("server_error", terr.Error()) + return apierrors.NewOAuthError("server_error", terr.Error()) } return nil }) @@ -303,7 +303,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{ @@ -350,14 +350,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) + return nil, apierrors.NewInternalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) + return nil, apierrors.NewInternalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") } // TODO(hf): Expand this boolean with all providers that may not have emails (like X/Twitter, Discord). @@ -369,7 +369,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. // need to be removed when a new oauth identity is being added // to prevent pre-account takeover attacks from happening. if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { - return nil, internalServerError("Error updating user").WithInternalError(terr) + return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr) } if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm { if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ @@ -379,7 +379,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } // fall through to auto-confirm and issue token if terr = user.Confirm(tx); terr != nil { - return nil, internalServerError("Error updating user").WithInternalError(terr) + return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr) } } else { // Some providers, like web3 don't have email data. @@ -395,9 +395,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -415,9 +415,9 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(apierrors.ErrorCodeInviteNotFound, "Invite not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeInviteNotFound, "Invite not found") } - return nil, internalServerError("Database error finding user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } var emailData *provider.Email @@ -431,7 +431,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p } if emailData == nil { - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -451,7 +451,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p return nil, err } if err := user.UpdateUserMetaData(tx, identityData); err != nil { - return nil, internalServerError("Database error updating user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error updating user").WithInternalError(err) } if err := models.NewAuditLogEntry(r, tx, user, models.InviteAcceptedAction, "", map[string]interface{}{ @@ -467,7 +467,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p // potentially malicious door exists into their account; thus // the password and phone needs to be removed. if err := user.RemoveUnconfirmedIdentities(tx, identity); err != nil { - return nil, internalServerError("Error updating user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(err) } // confirm because they were able to respond to invite email @@ -486,7 +486,7 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C state = r.URL.Query().Get("state") } if state == "" { - return ctx, badRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing") + return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } config := a.config claims := ExternalProviderClaims{} @@ -506,10 +506,10 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C return nil, fmt.Errorf("missing kid") }) if err != nil { - return ctx, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) } if claims.Provider == "" { - return ctx, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") + return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -523,14 +523,14 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, unprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found") } - return nil, internalServerError("Database error loading user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err) } ctx = withTargetUser(ctx, u) } @@ -685,7 +685,7 @@ func (a *API) createNewIdentity(tx *storage.Connection, user *models.User, provi } if terr := tx.Create(identity); terr != nil { - return nil, internalServerError("Error creating identity").WithInternalError(terr) + return nil, apierrors.NewInternalServerError("Error creating identity").WithInternalError(terr) } return identity, nil diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 3b8f2f904..68befa5f8 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -40,7 +40,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con if err != nil { u, uerr := url.ParseRequestURI(a.config.SiteURL) if uerr != nil { - return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr) + return ctx, apierrors.NewInternalServerError("site url is improperly formatted").WithInternalError(uerr) } q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query()) @@ -61,17 +61,17 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s extError := rq.Get("error") if extError != "" { - return nil, oauthError(extError, rq.Get("error_description")) + return nil, apierrors.NewOAuthError(extError, rq.Get("error_description")) } oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r).Entry @@ -82,12 +82,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s token, err := oAuthProvider.GetOAuthToken(oauthCode) if err != nil { - return nil, internalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err) + return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err) } userData, err := oAuthProvider.GetUserData(ctx, token) if err != nil { - return nil, internalServerError("Error getting user profile from external provider").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Error getting user profile from external provider").WithInternalError(err) } switch externalProvider := oAuthProvider.(type) { @@ -113,7 +113,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) @@ -124,11 +124,11 @@ func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthPr Token: oauthToken, }, oauthVerifier) if err != nil { - return nil, internalServerError("Unable to retrieve access token").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Unable to retrieve access token").WithInternalError(err) } userData, err = twitterProvider.FetchUserData(ctx, accessToken) if err != nil { - return nil, internalServerError("Error getting user email from external provider").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Error getting user email from external provider").WithInternalError(err) } } diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 965e51f81..ae278c8c1 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -98,10 +98,10 @@ type RequestParams interface { func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { body, err := utilities.GetBodyBytes(r) if err != nil { - return internalServerError("Could not read body into byte slice").WithInternalError(err) + return apierrors.NewInternalServerError("Could not read body into byte slice").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError(apierrors.ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) } return nil } diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index b75cab833..057277c15 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -22,12 +22,12 @@ func TestIsValidCodeChallenge(t *testing.T) { { challenge: "invalid", isValid: false, - expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + expectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), }, { challenge: "codechallengecontainsinvalidcharacterslike@$^&*", isValid: false, - expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + expectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), }, { challenge: "validchallengevalidchallengevalidchallengevalidchallenge", @@ -62,12 +62,12 @@ func TestIsValidPKCEParams(t *testing.T) { { challengeMethod: "test", challenge: "", - expected: badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + expected: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, { challengeMethod: "", challenge: "test", - expected: badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + expected: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, } diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index c78ce5f2f..9a3097de8 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -61,16 +61,16 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { // setup mock requests for hooks defer gock.OffAll() - input := hooks.SendSMSInput{ + input := v0hooks.SendSMSInput{ User: ts.TestUser, - SMS: hooks.SMS{ + SMS: v0hooks.SMS{ OTP: "123456", }, } testURL := "http://localhost:54321/functions/v1/custom-sms-sender" ts.Config.Hook.SendSMS.URI = testURL - unsuccessfulResponse := hooks.AuthHookError{ + unsuccessfulResponse := v0hooks.AuthHookError{ HTTPCode: http.StatusUnprocessableEntity, Message: "test error", } @@ -78,12 +78,12 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { testCases := []struct { description string expectError bool - mockResponse hooks.AuthHookError + mockResponse v0hooks.AuthHookError }{ { description: "Hook returns success", expectError: false, - mockResponse: hooks.AuthHookError{}, + mockResponse: v0hooks.AuthHookError{}, }, { description: "Hook returns error", @@ -96,25 +96,25 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { Post("/"). MatchType("json"). Reply(http.StatusOK). - JSON(hooks.SendSMSOutput{}) + JSON(v0hooks.SendSMSOutput{}) gock.New(ts.Config.Hook.SendSMS.URI). Post("/"). MatchType("json"). Reply(http.StatusUnprocessableEntity). - JSON(hooks.SendSMSOutput{HookError: unsuccessfulResponse}) + JSON(v0hooks.SendSMSOutput{HookError: unsuccessfulResponse}) for _, tc := range testCases { ts.Run(tc.description, func() { req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil) - body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) if !tc.expectError { require.NoError(ts.T(), err) } else { require.Error(ts.T(), err) if body != nil { - var output hooks.SendSMSOutput + var output v0hooks.SendSMSOutput require.NoError(ts.T(), json.Unmarshal(body, &output)) require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode) require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message) @@ -128,9 +128,9 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { defer gock.OffAll() - input := hooks.SendSMSInput{ + input := v0hooks.SendSMSInput{ User: ts.TestUser, - SMS: hooks.SMS{ + SMS: v0hooks.SMS{ OTP: "123456", }, } @@ -148,16 +148,16 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { Post("/"). MatchType("json"). Reply(http.StatusOK). - JSON(hooks.SendSMSOutput{}).SetHeader("content-type", "application/json") + JSON(v0hooks.SendSMSOutput{}).SetHeader("content-type", "application/json") // Simulate the original HTTP request which triggered the hook req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) require.NoError(ts.T(), err) - body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) require.NoError(ts.T(), err) - var output hooks.SendSMSOutput + var output v0hooks.SendSMSOutput err = json.Unmarshal(body, &output) require.NoError(ts.T(), err, "Unmarshal should not fail") @@ -168,9 +168,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { defer gock.OffAll() - input := hooks.SendSMSInput{ + input := v0hooks.SendSMSInput{ User: ts.TestUser, - SMS: hooks.SMS{ + SMS: v0hooks.SMS{ OTP: "123456", }, } @@ -186,7 +186,7 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil) require.NoError(ts.T(), err) - _, err = ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + _, err = ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input) require.Error(ts.T(), err, "Expected an error due to wrong content type") require.Contains(ts.T(), err.Error(), "Invalid JSON response.") @@ -234,32 +234,32 @@ func (ts *HooksTestSuite) TestInvokeHookIntegration() { description: "HTTP endpoint success", conn: nil, request: httptest.NewRequest("POST", authEndpoint, nil), - input: &hooks.SendEmailInput{}, - output: &hooks.SendEmailOutput{}, + input: &v0hooks.SendEmailInput{}, + output: &v0hooks.SendEmailOutput{}, uri: testHTTPUri, }, { description: "HTTPS endpoint success", conn: nil, request: httptest.NewRequest("POST", authEndpoint, nil), - input: &hooks.SendEmailInput{}, - output: &hooks.SendEmailOutput{}, + input: &v0hooks.SendEmailInput{}, + output: &v0hooks.SendEmailOutput{}, uri: testHTTPSUri, }, { description: "PostgreSQL function success", conn: ts.API.db, request: httptest.NewRequest("POST", authEndpoint, nil), - input: &hooks.SendEmailInput{}, - output: &hooks.SendEmailOutput{}, + input: &v0hooks.SendEmailInput{}, + output: &v0hooks.SendEmailOutput{}, uri: testPGUri, }, { description: "Unsupported protocol error", conn: nil, request: httptest.NewRequest("POST", authEndpoint, nil), - input: &hooks.SendEmailInput{}, - output: &hooks.SendEmailOutput{}, + input: &v0hooks.SendEmailInput{}, + output: &v0hooks.SendEmailOutput{}, uri: "ftp://example.com/path", expectedError: errors.New("unsupported protocol: \"ftp://example.com/path\" only postgres hooks and HTTPS functions are supported at the moment"), }, @@ -273,7 +273,7 @@ func (ts *HooksTestSuite) TestInvokeHookIntegration() { require.NoError(ts.T(), ts.Config.Hook.SendEmail.PopulateExtensibilityPoint()) ts.Run(tc.description, func() { - err = ts.API.invokeHook(tc.conn, tc.request, tc.input, tc.output) + err = ts.API.hooksMgr.InvokeHook(tc.conn, tc.request, tc.input, tc.output) if tc.expectedError != nil { require.EqualError(ts.T(), err, tc.expectedError.Error()) } else { diff --git a/internal/api/identity.go b/internal/api/identity.go index a2664d72f..fc2c161ee 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -18,23 +18,23 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return internalServerError("Could not read claims") + return apierrors.NewInternalServerError("Could not read claims") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return notFoundError(apierrors.ErrorCodeValidationFailed, "identity_id must be an UUID") + return apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "identity_id must be an UUID") } aud := a.requestAud(ctx, r) audienceFromClaims, _ := claims.GetAudience() if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { - return forbiddenError(apierrors.ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") + return apierrors.NewForbiddenError(apierrors.ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") } user := getUser(ctx) if len(user.Identities) <= 1 { - return unprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -45,7 +45,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return unprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -54,31 +54,31 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { "provider": identityToBeDeleted.Provider, "provider_id": identityToBeDeleted.ProviderID, }); terr != nil { - return internalServerError("Error recording audit log entry").WithInternalError(terr) + return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr) } if terr := tx.Destroy(identityToBeDeleted); terr != nil { - return internalServerError("Database error deleting identity").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error deleting identity").WithInternalError(terr) } switch identityToBeDeleted.Provider { case "phone": user.PhoneConfirmedAt = nil if terr := user.SetPhone(tx, ""); terr != nil { - return internalServerError("Database error updating user phone").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user phone").WithInternalError(terr) } if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil { - return internalServerError("Database error updating user phone").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user phone").WithInternalError(terr) } default: if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return unprocessableEntityError(apierrors.ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) } - return internalServerError("Database error updating user email").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user email").WithInternalError(terr) } } if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { - return internalServerError("Database error updating user providers").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user providers").WithInternalError(terr) } return nil }) @@ -111,14 +111,14 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType) if terr != nil { if !models.IsNotFoundError(terr) { - return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr) + return nil, apierrors.NewInternalServerError("Database error finding identity for linking").WithInternalError(terr) } } if identity != nil { if identity.UserID == targetUser.ID { - return nil, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked") } - return nil, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr @@ -127,7 +127,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora if targetUser.GetEmail() == "" { if terr := targetUser.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return nil, badRequestError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } return nil, terr } @@ -135,7 +135,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora if terr := a.sendConfirmation(r, tx, targetUser, models.ImplicitFlow); terr != nil { return nil, terr } - return nil, storage.NewCommitWithError(unprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) + return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) } if terr := targetUser.Confirm(tx); terr != nil { return nil, terr diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index b04a5d74f..6258860cc 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -102,7 +102,7 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { }, } u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testExistingUserData, "email") - require.ErrorIs(ts.T(), err, unprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked")) + require.ErrorIs(ts.T(), err, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityAlreadyExists, "Identity is already linked")) require.Nil(ts.T(), u) } @@ -123,13 +123,13 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() { desc: "User must have at least 1 identity after unlinking", user: userWithOneIdentity, identityId: userWithOneIdentity.Identities[0].ID, - expectedError: unprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), + expectedError: apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), }, { desc: "Identity doesn't exist", user: userWithTwoIdentities, identityId: uuid.Must(uuid.NewV4()), - expectedError: unprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist"), + expectedError: apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist"), }, } diff --git a/internal/api/invite.go b/internal/api/invite.go index de35942b1..2c3f8cd7f 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -35,13 +35,13 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { aud := a.requestAud(ctx, r) user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil && !models.IsNotFoundError(err) { - return internalServerError("Database error finding user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/logout.go b/internal/api/logout.go index 1ab398e22..c7b8d8878 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } @@ -65,7 +65,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { return models.Logout(tx, u.ID) }) if err != nil { - return internalServerError("Error logging out user").WithInternalError(err) + return apierrors.NewInternalServerError("Error logging out user").WithInternalError(err) } w.WriteHeader(http.StatusNoContent) diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index fa0792e02..c9fde50d3 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -24,7 +24,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate(a *API) error { if p.Email == "" { - return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") } var err error p.Email, err = a.validateEmail(p.Email) @@ -44,18 +44,18 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") } if !config.External.Email.MagicLinkEnabled { - return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Login with magic link is disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Login with magic link is disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError(apierrors.ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(a); err != nil { @@ -75,7 +75,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { if models.IsNotFoundError(err) { isNewUser = true } else { - return internalServerError("Database error finding user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } } if user != nil { diff --git a/internal/api/mail.go b/internal/api/mail.go index 31fc309ab..f90d9a74c 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/v0hooks" mail "github.com/supabase/auth/internal/mailer" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -79,10 +79,10 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { panic(err) } case mail.RecoveryVerification, mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: - return notFoundError(apierrors.ErrorCodeUserNotFound, "User with this email not found") + return apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User with this email not found") } } else { - return internalServerError("Database error finding user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } } @@ -135,7 +135,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { case mail.InviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -188,10 +188,10 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { case mail.SignupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { - return internalServerError("Database error updating user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error updating user").WithInternalError(err) } } else { // you should never use SignupParams with @@ -225,16 +225,16 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } case mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") } params.NewEmail, terr = a.validateEmail(params.NewEmail) if terr != nil { return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { - return internalServerError("Database error checking email").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -265,7 +265,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } } default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -315,19 +315,19 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model if err = a.sendEmail(r, tx, u, mail.SignupVerification, otp, "", u.ConfirmationToken); err != nil { u.ConfirmationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending confirmation email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending confirmation email").WithInternalError(err) } u.ConfirmationSentAt = &now if err := tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"); err != nil { - return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error updating user for confirmation")) + return apierrors.NewInternalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error updating user for confirmation")) } if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken); err != nil { - return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error creating confirmation token")) + return apierrors.NewInternalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error creating confirmation token")) } return nil @@ -345,22 +345,22 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User if err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken); err != nil { u.ConfirmationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending invite email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending invite email").WithInternalError(err) } u.InvitedAt = &now u.ConfirmationSentAt = &now err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") if err != nil { - return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error updating user for invite")) + return apierrors.NewInternalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error updating user for invite")) } err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) if err != nil { - return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error creating confirmation token for invite")) + return apierrors.NewInternalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error creating confirmation token for invite")) } return nil @@ -383,20 +383,20 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m if err := a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { u.RecoveryToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending recovery email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending recovery email").WithInternalError(err) } u.RecoverySentAt = &now if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { - return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + return apierrors.NewInternalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) } if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { - return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + return apierrors.NewInternalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) } return nil @@ -420,19 +420,19 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u if err := a.sendEmail(r, tx, u, mail.ReauthenticationVerification, otp, "", u.ReauthenticationToken); err != nil { u.ReauthenticationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending reauthentication email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending reauthentication email").WithInternalError(err) } u.ReauthenticationSentAt = &now if err := tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"); err != nil { - return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error updating user for reauthentication")) + return apierrors.NewInternalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error updating user for reauthentication")) } if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken); err != nil { - return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error creating reauthentication token")) + return apierrors.NewInternalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error creating reauthentication token")) } return nil @@ -459,19 +459,19 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U if err = a.sendEmail(r, tx, u, mail.MagicLinkVerification, otp, "", u.RecoveryToken); err != nil { u.RecoveryToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending magic link email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending magic link email").WithInternalError(err) } u.RecoverySentAt = &now if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { - return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + return apierrors.NewInternalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) } if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { - return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + return apierrors.NewInternalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) } return nil @@ -505,11 +505,11 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models if err := a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } - return internalServerError("Error sending email change email").WithInternalError(err) + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(err) } u.EmailChangeSentAt = &now @@ -521,18 +521,18 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models "email_change_sent_at", "email_change_confirm_status", ); err != nil { - return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) } if u.EmailChangeTokenCurrent != "" { if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { - return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) } } if u.EmailChangeTokenNew != "" { if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { - return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) + return apierrors.NewInternalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) } } @@ -541,13 +541,13 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models func (a *API) validateEmail(email string) (string, error) { if email == "" { - return "", badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required") } if len(email) > 255 { - return "", badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long") } if err := checkmail.ValidateFormat(email); err != nil { - return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil @@ -555,7 +555,7 @@ func (a *API) validateEmail(email string) (string, error) { func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { - return tooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) } return nil } @@ -587,13 +587,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, if emailActionType != mail.EmailChangeVerification { if u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { - return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) } } else { // first check that the user can update their address to the // new one in u.EmailChange if u.EmailChange != "" && !a.checkEmailAddressAuthorization(u.EmailChange) { - return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.EmailChange) + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.EmailChange) } // if secure email change is enabled, check that the user @@ -601,7 +601,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, // address authorization restriction was enabled) can even // receive the confirmation message to the existing address if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { - return badRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) } } @@ -645,12 +645,12 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, emailData.TokenNew = otpNew emailData.TokenHashNew = u.EmailChangeTokenCurrent } - input := hooks.SendEmailInput{ + input := v0hooks.SendEmailInput{ User: u, EmailData: emailData, } - output := hooks.SendEmailOutput{} - return a.invokeHook(tx, r, &input, &output) + output := v0hooks.SendEmailOutput{} + return a.hooksMgr.InvokeHook(tx, r, &input, &output) } mr := a.Mailer() @@ -676,7 +676,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, case errors.Is(err, mail.ErrInvalidEmailAddress), errors.Is(err, mail.ErrInvalidEmailFormat), errors.Is(err, mail.ErrInvalidEmailDNS): - return badRequestError( + return apierrors.NewBadRequestError( apierrors.ErrorCodeEmailAddressInvalid, "Email address %q is invalid", u.GetEmail()) diff --git a/internal/api/mail_test.go b/internal/api/mail_test.go index c13c18e69..97ab2df89 100644 --- a/internal/api/mail_test.go +++ b/internal/api/mail_test.go @@ -72,14 +72,14 @@ func (ts *MailTestSuite) TestValidateEmail() { desc: "empty email should return error", email: "", expectedEmail: "", - expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required"), + expectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "An email address is required"), }, { desc: "email length exceeds 255 characters", // email has 256 characters email: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest@example.com", expectedEmail: "", - expectedError: badRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long"), + expectedError: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long"), }, } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 2f79ccec2..8ce34e52d 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -22,7 +22,7 @@ import ( "github.com/supabase/auth/internal/api/sms_provider" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/crypto" - "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/metering" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -142,7 +142,7 @@ func validateFactors(db *storage.Connection, user *models.User, newFactorName st for _, factor := range user.Factors { if factor.FriendlyName == newFactorName { - return unprocessableEntityError( + return apierrors.NewUnprocessableEntityError( apierrors.ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), ) @@ -153,15 +153,15 @@ func validateFactors(db *storage.Connection, user *models.User, newFactorName st } if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return unprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return unprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() { - return forbiddenError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") + return apierrors.NewForbiddenError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") } return nil @@ -173,19 +173,19 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * session := getSession(ctx) db := a.db.WithContext(ctx) if params.Phone == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Phone number required to enroll Phone factor") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Phone number required to enroll Phone factor") } phone, err := validatePhone(params.Phone) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } var factorsToDelete []models.Factor for _, factor := range user.Factors { if factor.IsPhoneFactor() && factor.Phone.String() == phone { if factor.IsVerified() { - return unprocessableEntityError( + return apierrors.NewUnprocessableEntityError( apierrors.ErrorCodeMFAVerifiedFactorExists, "A verified phone factor already exists, unenroll the existing factor to continue", ) @@ -196,7 +196,7 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * } if err := db.Destroy(&factorsToDelete); err != nil { - return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) + return apierrors.NewInternalServerError("Database error deleting unverified phone factors").WithInternalError(err) } if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { @@ -270,7 +270,7 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E if params.Issuer == "" { u, err := url.ParseRequestURI(config.SiteURL) if err != nil { - return internalServerError("site url is improperly formatted") + return apierrors.NewInternalServerError("site url is improperly formatted") } issuer = u.Host } else { @@ -289,7 +289,7 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E AccountName: user.GetEmail(), }) if err != nil { - return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + return apierrors.NewInternalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) } svgData := svg.New(&buf) @@ -297,7 +297,7 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E qs := goqrsvg.NewQrSVG(qrCode, DefaultQRSize) qs.StartQrSVG(svgData) if err = qs.WriteQrSVG(svgData); err != nil { - return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + return apierrors.NewInternalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) } svgData.End() @@ -341,7 +341,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { config := a.config if session == nil || user == nil { - return internalServerError("A valid session and a registered user are required to enroll a factor") + return apierrors.NewInternalServerError("A valid session and a registered user are required to enroll a factor") } params := &EnrollFactorParams{} if err := retrieveRequestParams(r, params); err != nil { @@ -351,21 +351,21 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { switch params.FactorType { case models.Phone: if !config.MFA.Phone.EnrollEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") } return a.enrollPhoneFactor(w, r, params) case models.TOTP: if !config.MFA.TOTP.EnrollEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") } return a.enrollTOTPFactor(w, r, params) case models.WebAuthn: if !config.MFA.WebAuthn.EnrollEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA enroll is disabled for WebAuthn") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA enroll is disabled for WebAuthn") } return a.enrollWebAuthnFactor(w, r, params) default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") } } @@ -386,12 +386,12 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(channel, config) { - return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) } if factor.IsPhoneFactor() && factor.LastChallengedAt != nil { if !factor.LastChallengedAt.Add(config.MFA.Phone.MaxFrequency).Before(time.Now()) { - return tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) } } @@ -399,35 +399,35 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) if err != nil { - return internalServerError("error creating SMS Challenge") + return apierrors.NewInternalServerError("error creating SMS Challenge") } message, err := generateSMSFromTemplate(config.MFA.Phone.SMSTemplate, otp) if err != nil { - return internalServerError("error generating sms template").WithInternalError(err) + return apierrors.NewInternalServerError("error generating sms template").WithInternalError(err) } if config.Hook.SendSMS.Enabled { - input := hooks.SendSMSInput{ + input := v0hooks.SendSMSInput{ User: user, - SMS: hooks.SMS{ + SMS: v0hooks.SMS{ OTP: otp, SMSType: "mfa", }, } - output := hooks.SendSMSOutput{} - err := a.invokeHook(a.db, r, &input, &output) + output := v0hooks.SendSMSOutput{} + err := a.hooksMgr.InvokeHook(a.db, r, &input, &output) if err != nil { - return internalServerError("error invoking hook") + return apierrors.NewInternalServerError("error invoking hook") } } else { smsProvider, err := sms_provider.GetSmsProvider(*config) if err != nil { - return internalServerError("Failed to get SMS provider").WithInternalError(err) + return apierrors.NewInternalServerError("Failed to get SMS provider").WithInternalError(err) } // We omit messageID for now, can consider reinstating if there are requests. if _, err = smsProvider.SendMessage(factor.Phone.String(), message, channel, otp); err != nil { - return internalServerError("error sending message").WithInternalError(err) + return apierrors.NewInternalServerError("error sending message").WithInternalError(err) } } if err := db.Transaction(func(tx *storage.Connection) error { @@ -499,7 +499,7 @@ func (a *API) challengeWebAuthnFactor(w http.ResponseWriter, r *http.Request) er return err } if params.WebAuthn == nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "web_authn config required") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "web_authn config required") } webAuthn, err := params.WebAuthn.ToConfig() if err != nil { @@ -511,7 +511,7 @@ func (a *API) challengeWebAuthnFactor(w http.ResponseWriter, r *http.Request) er if factor.IsUnverified() { options, session, err := webAuthn.BeginRegistration(user) if err != nil { - return internalServerError("Failed to generate WebAuthn registration data").WithInternalError(err) + return apierrors.NewInternalServerError("Failed to generate WebAuthn registration data").WithInternalError(err) } ws = &models.WebAuthnSessionData{ SessionData: session, @@ -557,20 +557,20 @@ func (a *API) validateChallenge(r *http.Request, db *storage.Connection, factor challenge, err := factor.FindChallengeByID(db, challengeID) if err != nil { if models.IsNotFoundError(err) { - return nil, unprocessableEntityError(apierrors.ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") } - return nil, internalServerError("Database error finding Challenge").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return nil, unprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch.") + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch.") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { if err := db.Destroy(challenge); err != nil { - return nil, internalServerError("Database error deleting challenge").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error deleting challenge").WithInternalError(err) } - return nil, unprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } return challenge, nil @@ -584,22 +584,22 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { switch factor.FactorType { case models.Phone: if !config.MFA.Phone.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") } return a.challengePhoneFactor(w, r) case models.TOTP: if !config.MFA.TOTP.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") } return a.challengeTOTPFactor(w, r) case models.WebAuthn: if !config.MFA.WebAuthn.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnVerifyDisabled, "MFA verification is disabled for WebAuthn") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnVerifyDisabled, "MFA verification is disabled for WebAuthn") } return a.challengeWebAuthnFactor(w, r) default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") } } @@ -619,7 +619,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V secret, shouldReEncrypt, err := factor.GetSecret(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) if err != nil { - return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + return apierrors.NewInternalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) } valid, verr := totp.ValidateCustom(params.Code, secret, time.Now().UTC(), totp.ValidateOpts{ @@ -630,28 +630,28 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V }) if config.Hook.MFAVerificationAttempt.Enabled { - input := hooks.MFAVerificationAttemptInput{ + input := v0hooks.MFAVerificationAttemptInput{ UserID: user.ID, FactorID: factor.ID, Valid: valid, } - output := hooks.MFAVerificationAttemptOutput{} - err := a.invokeHook(nil, r, &input, &output) + output := v0hooks.MFAVerificationAttemptOutput{} + err := a.hooksMgr.InvokeHook(nil, r, &input, &output) if err != nil { return err } - if output.Decision == hooks.HookRejection { + if output.Decision == v0hooks.HookRejection { if err := models.Logout(db, user.ID); err != nil { return err } if output.Message == "" { - output.Message = hooks.DefaultMFAHookRejectionMessage + output.Message = v0hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) + return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { @@ -664,7 +664,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V return err } } - return unprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) } var token *AccessTokenResponse @@ -709,10 +709,10 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V return terr } if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { - return internalServerError("Failed to update sessions. %s", terr) + return apierrors.NewInternalServerError("Failed to update sessions. %s", terr) } if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { - return internalServerError("Error removing unverified factors. %s", terr) + return apierrors.NewInternalServerError("Error removing unverified factors. %s", terr) } return nil }) @@ -739,14 +739,14 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return unprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { if err := db.Destroy(challenge); err != nil { - return internalServerError("Database error deleting challenge").WithInternalError(err) + return apierrors.NewInternalServerError("Database error deleting challenge").WithInternalError(err) } - return unprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } var valid bool var otpCode string @@ -754,43 +754,43 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * if config.Sms.IsTwilioVerifyProvider() { smsProvider, err := sms_provider.GetSmsProvider(*config) if err != nil { - return internalServerError("Failed to get SMS provider").WithInternalError(err) + return apierrors.NewInternalServerError("Failed to get SMS provider").WithInternalError(err) } if err := smsProvider.VerifyOTP(factor.Phone.String(), params.Code); err != nil { - return forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + return apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } valid = true } else { otpCode, shouldReEncrypt, err = challenge.GetOtpCode(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) if err != nil { - return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + return apierrors.NewInternalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) } valid = subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 } if config.Hook.MFAVerificationAttempt.Enabled { - input := hooks.MFAVerificationAttemptInput{ + input := v0hooks.MFAVerificationAttemptInput{ UserID: user.ID, FactorID: factor.ID, FactorType: factor.FactorType, Valid: valid, } - output := hooks.MFAVerificationAttemptOutput{} - err := a.invokeHook(nil, r, &input, &output) + output := v0hooks.MFAVerificationAttemptOutput{} + err := a.hooksMgr.InvokeHook(nil, r, &input, &output) if err != nil { return err } - if output.Decision == hooks.HookRejection { + if output.Decision == v0hooks.HookRejection { if err := models.Logout(db, user.ID); err != nil { return err } if output.Message == "" { - output.Message = hooks.DefaultMFAHookRejectionMessage + output.Message = v0hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) + return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { @@ -803,7 +803,7 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * return err } } - return unprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") } var token *AccessTokenResponse @@ -837,10 +837,10 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * return terr } if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { - return internalServerError("Failed to update sessions. %s", terr) + return apierrors.NewInternalServerError("Failed to update sessions. %s", terr) } if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { - return internalServerError("Error removing unverified factors. %s", terr) + return apierrors.NewInternalServerError("Error removing unverified factors. %s", terr) } return nil }) @@ -864,11 +864,11 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param switch { case params.WebAuthn == nil: - return badRequestError(apierrors.ErrorCodeValidationFailed, "WebAuthn config required") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "WebAuthn config required") case factor.IsVerified() && params.WebAuthn.AssertionResponse == nil: - return badRequestError(apierrors.ErrorCodeValidationFailed, "creation_response required to login") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "creation_response required to login") case factor.IsUnverified() && params.WebAuthn.CreationResponse == nil: - return badRequestError(apierrors.ErrorCodeValidationFailed, "assertion_response required to login") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "assertion_response required to login") default: webAuthn, err = params.WebAuthn.ToConfig() if err != nil { @@ -883,13 +883,13 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param webAuthnSession := *challenge.WebAuthnSessionData.SessionData // Once the challenge is validated, we consume the challenge if err := db.Destroy(challenge); err != nil { - return internalServerError("Database error deleting challenge").WithInternalError(err) + return apierrors.NewInternalServerError("Database error deleting challenge").WithInternalError(err) } if factor.IsUnverified() { parsedResponse, err := wbnprotocol.ParseCredentialCreationResponseBody(bytes.NewReader(params.WebAuthn.CreationResponse)) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_creation_response") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_creation_response") } credential, err = webAuthn.CreateCredential(user, webAuthnSession, parsedResponse) if err != nil { @@ -899,11 +899,11 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param } else if factor.IsVerified() { parsedResponse, err := wbnprotocol.ParseCredentialRequestResponseBody(bytes.NewReader(params.WebAuthn.AssertionResponse)) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_request_response") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid credential_request_response") } credential, err = webAuthn.ValidateLogin(user, webAuthnSession, parsedResponse) if err != nil { - return internalServerError("Failed to validate WebAuthn MFA response").WithInternalError(err) + return apierrors.NewInternalServerError("Failed to validate WebAuthn MFA response").WithInternalError(err) } } var token *AccessTokenResponse @@ -936,10 +936,10 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param return terr } if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { - return internalServerError("Failed to update session").WithInternalError(terr) + return apierrors.NewInternalServerError("Failed to update session").WithInternalError(terr) } if terr = models.DeleteUnverifiedFactors(tx, user, models.WebAuthn); terr != nil { - return internalServerError("Failed to remove unverified MFA WebAuthn factors").WithInternalError(terr) + return apierrors.NewInternalServerError("Failed to remove unverified MFA WebAuthn factors").WithInternalError(terr) } return nil }) @@ -961,28 +961,28 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { return err } if params.Code == "" && factor.FactorType != models.WebAuthn { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Code needs to be non-empty") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Code needs to be non-empty") } switch factor.FactorType { case models.Phone: if !config.MFA.Phone.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") } return a.verifyPhoneFactor(w, r, params) case models.TOTP: if !config.MFA.TOTP.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") } return a.verifyTOTPFactor(w, r, params) case models.WebAuthn: if !config.MFA.WebAuthn.VerifyEnabled { - return unprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA verification is disabled for WebAuthn") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeMFAWebAuthnEnrollDisabled, "MFA verification is disabled for WebAuthn") } return a.verifyWebAuthnFactor(w, r, params) default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") } } @@ -996,11 +996,11 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if factor == nil || session == nil || user == nil { - return internalServerError("A valid session and factor are required to unenroll a factor") + return apierrors.NewInternalServerError("A valid session and factor are required to unenroll a factor") } if factor.IsVerified() && !session.IsAAL2() { - return unprocessableEntityError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } err = db.Transaction(func(tx *storage.Connection) error { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 6d9a5a3c3..3778eca92 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -749,7 +749,7 @@ func (ts *MFATestSuite) TestChallengeFactorNotOwnedByUser() { w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", otherUsersPhoneFactor.ID), signUpResp.Token, buffer) - expectedError := notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") + expectedError := apierrors.NewNotFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found") var data HTTPError require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 43257e8d1..c974c7f91 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -67,7 +67,7 @@ func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error } else { err := tollbooth.LimitByKeys(lmt, []string{key}) if err != nil { - return tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") } } } @@ -100,7 +100,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") } return ctx, nil @@ -128,11 +128,11 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C verificationResult, err := security.VerifyRequest(body, utilities.GetIPAddress(req), strings.TrimSpace(config.Security.Captcha.Secret), config.Security.Captcha.Provider) if err != nil { - return nil, internalServerError("captcha verification process failed").WithInternalError(err) + return nil, apierrors.NewInternalServerError("captcha verification process failed").WithInternalError(err) } if !verificationResult.Success { - return nil, badRequestError(apierrors.ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -252,7 +252,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") } return ctx, nil } @@ -260,7 +260,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError(apierrors.ErrorCodeManualLinkingDisabled, "Manual linking is disabled") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeManualLinkingDisabled, "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index b0f319c18..68dbabb7c 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -328,7 +328,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { { desc: "SAML not enabled", isEnabled: false, - expectedErr: notFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), + expectedErr: apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), }, { desc: "SAML enabled", diff --git a/internal/api/otp.go b/internal/api/otp.go index 916d61deb..0e3e939f1 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -36,10 +36,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -54,7 +54,7 @@ func (p *SmsParams) Validate(config *conf.GlobalConfiguration) error { return err } if !sms_provider.IsValidMessageChannel(p.Channel, config) { - return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) } return nil } @@ -80,7 +80,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return unprocessableEntityError(apierrors.ErrorCodeOTPDisabled, "Signups not allowed for otp") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOTPDisabled, "Signups not allowed for otp") } else if err != nil { return err } @@ -91,7 +91,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { return a.SmsOtp(w, r) } - return badRequestError(apierrors.ErrorCodeValidationFailed, "One of email or phone must be set") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "One of email or phone must be set") } type SmsOtpResponse struct { @@ -105,7 +105,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Unsupported phone provider") + return apierrors.NewBadRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Unsupported phone provider") } var err error @@ -130,7 +130,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if models.IsNotFoundError(err) { isNewUser = true } else { - return internalServerError("Database error finding user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } } if user != nil { @@ -141,7 +141,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + return apierrors.NewInternalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ diff --git a/internal/api/password.go b/internal/api/password.go index 78291a49f..47cc6755d 100644 --- a/internal/api/password.go +++ b/internal/api/password.go @@ -29,7 +29,7 @@ func (a *API) checkPasswordStrength(ctx context.Context, password string) error config := a.config if len(password) > MaxPasswordLength { - return badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Password cannot be longer than %v characters", MaxPasswordLength)) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Password cannot be longer than %v characters", MaxPasswordLength)) } var messages, reasons []string @@ -53,7 +53,7 @@ func (a *API) checkPasswordStrength(ctx context.Context, password string) error pwned, err := a.hibpClient.Check(ctx, password) if err != nil { if config.Password.HIBP.FailClosed { - return internalServerError("Unable to perform password strength check with HaveIBeenPwned.org.").WithInternalError(err) + return apierrors.NewInternalServerError("Unable to perform password strength check with HaveIBeenPwned.org.").WithInternalError(err) } else { logrus.WithError(err).Warn("Unable to perform password strength check with HaveIBeenPwned.org, pwned passwords are being allowed") } diff --git a/internal/api/phone.go b/internal/api/phone.go index 5cede0ed0..9a8662dbb 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -8,12 +8,11 @@ import ( "text/template" "time" - "github.com/supabase/auth/internal/hooks" - "github.com/pkg/errors" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/sms_provider" "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" ) @@ -28,7 +27,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } return phone, nil } @@ -66,13 +65,13 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use sentAt = user.ReauthenticationSentAt includeFields = append(includeFields, "reauthentication_token", "reauthentication_sent_at") default: - return "", internalServerError("invalid otp type") + return "", apierrors.NewInternalServerError("invalid otp type") } // intentionally keeping this before the test OTP, so that the behavior // of regular and test OTPs is similar if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { - return "", tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) + return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) } now := time.Now() @@ -90,35 +89,35 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use if !config.Sms.Autoconfirm { // apply rate limiting before the sms is sent out if ok := a.limiterOpts.Phone.Allow(); !ok { - return "", tooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") + return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } } otp = crypto.GenerateOtp(config.Sms.OtpLength) if config.Hook.SendSMS.Enabled { - input := hooks.SendSMSInput{ + input := v0hooks.SendSMSInput{ User: user, - SMS: hooks.SMS{ + SMS: v0hooks.SMS{ OTP: otp, }, } - output := hooks.SendSMSOutput{} - err := a.invokeHook(tx, r, &input, &output) + output := v0hooks.SendSMSOutput{} + err := a.hooksMgr.InvokeHook(tx, r, &input, &output) if err != nil { return "", err } } else { smsProvider, err := sms_provider.GetSmsProvider(*config) if err != nil { - return "", internalServerError("Unable to get SMS provider").WithInternalError(err) + return "", apierrors.NewInternalServerError("Unable to get SMS provider").WithInternalError(err) } message, err := generateSMSFromTemplate(config.Sms.SMSTemplate, otp) if err != nil { - return "", internalServerError("error generating sms template").WithInternalError(err) + return "", apierrors.NewInternalServerError("error generating sms template").WithInternalError(err) } messageID, err := smsProvider.SendMessage(phone, message, channel, otp) if err != nil { - return messageID, unprocessableEntityError(apierrors.ErrorCodeSMSSendFailed, "Error sending %s OTP to provider: %v", otpType, err) + return messageID, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSMSSendFailed, "Error sending %s OTP to provider: %v", otpType, err) } } } @@ -154,7 +153,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use } } if ottErr != nil { - return messageID, internalServerError("error creating one time token").WithInternalError(ottErr) + return messageID, apierrors.NewInternalServerError("error creating one time token").WithInternalError(ottErr) } return messageID, nil } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index adc50f1a9..dc9180fea 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/sms_provider" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" @@ -99,7 +100,7 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { { desc: "send invalid otp type ", otpType: "invalid otp type", - expected: internalServerError("invalid otp type"), + expected: apierrors.NewInternalServerError("invalid otp type"), }, } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index b4feeae5b..f62c00625 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -22,9 +22,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -42,7 +42,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", unprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "No valid flow state found for user.") + return "", apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "No valid flow state found for user.") } else if err != nil { return "", err } @@ -64,7 +64,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index 57eb90505..a28705caa 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -22,16 +22,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return unprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Please verify your email first.") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailNotConfirmed, "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return unprocessableEntityError(apierrors.ErrorCodePhoneNotConfirmed, "Please verify your phone first.") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneNotConfirmed, "Please verify your phone first.") } } @@ -68,7 +68,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -78,7 +78,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + return apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -86,13 +86,13 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") } if !isValid { - return unprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { - return internalServerError("Error during reauthentication").WithInternalError(err) + return apierrors.NewInternalServerError("Error during reauthentication").WithInternalError(err) } return nil } diff --git a/internal/api/recover.go b/internal/api/recover.go index fed11d355..605b2fb5b 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -17,7 +17,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate(a *API) error { if p.Email == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Password recovery requires an email") } var err error if p.Email, err = a.validateEmail(p.Email); err != nil { @@ -52,7 +52,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { if models.IsNotFoundError(err) { return sendJSON(w, http.StatusOK, map[string]string{}) } - return internalServerError("Unable to process request").WithInternalError(err) + return apierrors.NewInternalServerError("Unable to process request").WithInternalError(err) } if isPKCEFlow(flowType) { if _, err := generateFlowState(db, models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)); err != nil { diff --git a/internal/api/resend.go b/internal/api/resend.go index bc3ed45cb..e0e049278 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -25,22 +25,22 @@ func (p *ResendConfirmationParams) Validate(a *API) error { break default: // type does not match one of the above - return badRequestError(apierrors.ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == mail.SignupVerification { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires an email address") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires a phone number") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") } p.Email, err = a.validateEmail(p.Email) if err != nil { @@ -48,7 +48,7 @@ func (p *ResendConfirmationParams) Validate(a *API) error { } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + return apierrors.NewBadRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -56,7 +56,7 @@ func (p *ResendConfirmationParams) Validate(a *API) error { } } else { // both email and phone are empty - return badRequestError(apierrors.ErrorCodeValidationFailed, "Missing email address or phone number") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Missing email address or phone number") } return nil } @@ -87,7 +87,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if models.IsNotFoundError(err) { return sendJSON(w, http.StatusOK, map[string]string{}) } - return internalServerError("Unable to process request").WithInternalError(err) + return apierrors.NewInternalServerError("Unable to process request").WithInternalError(err) } switch params.Type { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index a6dfc42aa..3f64fc235 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -48,7 +48,7 @@ func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { if err := a.handleSamlAcs(w, r); err != nil { u, uerr := url.Parse(a.config.SiteURL) if uerr != nil { - return internalServerError("site url is improperly formattted").WithInternalError(err) + return apierrors.NewInternalServerError("site url is improperly formattted").WithInternalError(err) } q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) @@ -81,17 +81,17 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return notFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") + return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { return err } if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod { if err := a.samlDestroyRelayState(ctx, relayState); err != nil { - return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) + return apierrors.NewInternalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return unprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -99,7 +99,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID) if err != nil { - return internalServerError("Unable to find SSO Provider from SAML RelayState") + return apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState") } initiatedBy = "sp" @@ -121,23 +121,23 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -145,12 +145,12 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return notFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") + return apierrors.NewNotFoundError(apierrors.ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { return err } @@ -190,10 +190,10 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) } - return badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -202,7 +202,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return badRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -214,19 +214,19 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return badRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } jsonClaims, err := json.Marshal(claims) if err != nil { - return internalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) + return apierrors.NewInternalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) } providerClaims := &provider.Claims{} if err := json.Unmarshal(jsonClaims, providerClaims); err != nil { - return internalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) + return apierrors.NewInternalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) } providerClaims.Subject = userID @@ -300,7 +300,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams) if terr != nil { - return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr) + return apierrors.NewInternalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr) } return nil diff --git a/internal/api/signup.go b/internal/api/signup.go index f3f89a77c..09ac43524 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -33,21 +33,21 @@ func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { config := a.config if p.Password == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Signup requires a valid password") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Signup requires a valid password") } if err := a.checkPasswordStrength(ctx, p.Password); err != nil { return err } if p.Email != "" && p.Phone != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") } if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config) { - return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) } // PKCE not needed as phone signups already return access token in body if p.Phone != "" && p.CodeChallenge != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "PKCE not supported for phone signups") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "PKCE not supported for phone signups") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -86,7 +86,7 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err user, err = models.NewUser("", params.Email, params.Password, params.Aud, params.Data) } if err != nil { - err = internalServerError("Database error creating user").WithInternalError(err) + err = apierrors.NewInternalServerError("Database error creating user").WithInternalError(err) return } user.IsSSOUser = isSSOUser @@ -113,7 +113,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if config.DisableSignup { - return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} @@ -140,7 +140,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { switch params.Provider { case "email": if !config.External.Email.Enabled { - return badRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email signups are disabled") + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailProviderDisabled, "Email signups are disabled") } params.Email, err = a.validateEmail(params.Email) if err != nil { @@ -149,7 +149,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) case "phone": if !config.External.Phone.Enabled { - return badRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone signups are disabled") + return apierrors.NewBadRequestError(apierrors.ErrorCodePhoneProviderDisabled, "Phone signups are disabled") } params.Phone, err = validatePhone(params.Phone) if err != nil { @@ -168,11 +168,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { msg = "Sign up with this provider not possible" } - return badRequestError(apierrors.ErrorCodeValidationFailed, msg) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { - return internalServerError("Database error finding user").WithInternalError(err) + return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } var signupUser *models.User @@ -230,7 +230,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return terr } if terr = user.Confirm(tx); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user").WithInternalError(terr) } } else { if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ @@ -257,7 +257,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return terr } if terr = user.ConfirmPhone(tx); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user").WithInternalError(terr) } } else { if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ @@ -288,7 +288,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return unprocessableEntityError(apierrors.ErrorCodeUserAlreadyExists, "User already registered") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserAlreadyExists, "User already registered") } sanitizedUser, err := sanitizeUser(user, params) if err != nil { @@ -369,10 +369,10 @@ func (a *API) signupNewUser(conn *storage.Connection, user *models.User) (*model err := conn.Transaction(func(tx *storage.Connection) error { var terr error if terr = tx.Create(user); terr != nil { - return internalServerError("Database error saving new user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error saving new user").WithInternalError(terr) } if terr = user.SetRole(tx, config.JWT.DefaultGroupName); terr != nil { - return internalServerError("Database error updating user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error updating user").WithInternalError(terr) } return nil }) @@ -384,7 +384,7 @@ func (a *API) signupNewUser(conn *storage.Connection, user *models.User) (*model // user data as it is being inserted. thus we load the user object // again to fetch those changes. if err := conn.Reload(user); err != nil { - return nil, internalServerError("Database error loading user after sign-up").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error loading user after sign-up").WithInternalError(err) } return user, nil diff --git a/internal/api/sso.go b/internal/api/sso.go index 39304fe78..83667cfec 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -28,9 +28,9 @@ func (p *SingleSignOnParams) validate() (bool, error) { hasDomain := p.Domain != "" if hasProviderID && hasDomain { - return hasProviderID, badRequestError(apierrors.ErrorCodeValidationFailed, "Only one of provider_id or domain supported") + return hasProviderID, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only one of provider_id or domain supported") } else if !hasProviderID && !hasDomain { - return hasProviderID, badRequestError(apierrors.ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") + return hasProviderID, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") } return hasProviderID, nil @@ -74,22 +74,22 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if hasProviderID { ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) if models.IsNotFoundError(err) { - return notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No such SSO provider") + return apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No such SSO provider") } else if err != nil { - return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) + return apierrors.NewInternalServerError("Unable to find SSO provider by ID").WithInternalError(err) } } else { ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) if models.IsNotFoundError(err) { - return notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") + return apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") } else if err != nil { - return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) + return apierrors.NewInternalServerError("Unable to find SSO provider by domain").WithInternalError(err) } } entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor() if err != nil { - return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) + return apierrors.NewInternalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) } serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */) @@ -100,7 +100,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { saml.HTTPPostBinding, ) if err != nil { - return internalServerError("Error creating SAML Authentication Request").WithInternalError(err) + return apierrors.NewInternalServerError("Error creating SAML Authentication Request").WithInternalError(err) } // Some IdPs do not support the use of the `persistent` NameID format, @@ -118,7 +118,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if err := db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(&relayState); terr != nil { - return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + return apierrors.NewInternalServerError("Error creating SAML relay state from sign up").WithInternalError(err) } return nil @@ -128,7 +128,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider) if err != nil { - return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) + return apierrors.NewInternalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) } skipHTTPRedirect := false diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index fff2813c7..682423a56 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -30,16 +30,16 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C idpID, err := uuid.FromString(idpParam) if err != nil { // idpParam is not UUIDv4 - return nil, notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } // idpParam is a UUIDv4 provider, err := models.FindSSOProviderByID(db, idpID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } else { - return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error finding SSO Identity Provider").WithInternalError(err) } } @@ -81,19 +81,19 @@ type CreateSSOProviderParams struct { func (p *CreateSSOProviderParams) validate(forUpdate bool) error { if !forUpdate && p.Type != "saml" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") } else if p.MetadataURL != "" && p.MetadataXML != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") } else if p.MetadataURL != "" { metadataURL, err := url.ParseRequestURI(p.MetadataURL) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a valid URL") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a valid URL") } if metadataURL.Scheme != "https" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") } } @@ -106,7 +106,7 @@ func (p *CreateSSOProviderParams) validate(forUpdate bool) error { // it's valid default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{ + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{ string(saml.PersistentNameIDFormat), string(saml.EmailAddressNameIDFormat), string(saml.TransientNameIDFormat), @@ -143,7 +143,7 @@ func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.E func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { if !utf8.Valid(rawMetadata) { - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") } metadata, err := samlsp.ParseMetadata(rawMetadata) @@ -152,15 +152,15 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { } if metadata.EntityID == "" { - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") } if len(metadata.IDPSSODescriptors) < 1 { - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") } if len(metadata.IDPSSODescriptors) > 1 { - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") } return metadata, nil @@ -169,7 +169,7 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Unable to create a request to metadata_url").WithInternalError(err) } req = req.WithContext(ctx) @@ -184,7 +184,7 @@ func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { defer utilities.SafeClose(resp.Body) if resp.StatusCode != http.StatusOK { - return nil, badRequestError(apierrors.ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) } data, err := io.ReadAll(resp.Body) @@ -219,7 +219,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return unprocessableEntityError(apierrors.ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) } provider := &models.SSOProvider{ @@ -246,7 +246,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) } provider.SSODomains = append(provider.SSODomains, models.SSODomain{ @@ -301,7 +301,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if provider.SAMLProvider.EntityID != metadata.EntityID { - return badRequestError(apierrors.ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + return apierrors.NewBadRequestError(apierrors.ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) } if params.MetadataURL != "" { @@ -330,7 +330,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er if existingProvider.ID == provider.ID { keepDomains[domain] = true } else { - return badRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) } } else { modified = true @@ -398,7 +398,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er return tx.Eager().Load(provider) }); err != nil { - return unprocessableEntityError(apierrors.ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) } } diff --git a/internal/api/token.go b/internal/api/token.go index 65ea94609..fd4631f55 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -13,7 +13,7 @@ import ( "github.com/xeipuuv/gojsonschema" "github.com/supabase/auth/internal/api/apierrors" - "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/metering" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" @@ -96,7 +96,7 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error { handler = a.Web3Grant limiter = a.limiterOpts.Web3 default: - return badRequestError(apierrors.ErrorCodeInvalidCredentials, "unsupported_grant_type") + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "unsupported_grant_type") } if err := a.performRateLimiting(limiter, r); err != nil { @@ -119,7 +119,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri config := a.config if params.Email != "" && params.Phone != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") } var user *models.User var grantParams models.GrantParams @@ -131,33 +131,33 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if params.Email != "" { provider = "email" if !config.External.Email.Enabled { - return unprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailProviderDisabled, "Email logins are disabled") } user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) } else if params.Phone != "" { provider = "phone" if !config.External.Phone.Enabled { - return unprocessableEntityError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } params.Phone = formatPhoneNumber(params.Phone) user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) } else { - return badRequestError(apierrors.ErrorCodeValidationFailed, "missing email or phone") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "missing email or phone") } if err != nil { if models.IsNotFoundError(err) { - return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) } - return internalServerError("Database error querying schema").WithInternalError(err) + return apierrors.NewInternalServerError("Database error querying schema").WithInternalError(err) } if !user.HasPassword() { - return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) } if user.IsBanned() { - return badRequestError(apierrors.ErrorCodeUserBanned, "User is banned") + return apierrors.NewBadRequestError(apierrors.ErrorCodeUserBanned, "User is banned") } isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) @@ -191,35 +191,35 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri } if config.Hook.PasswordVerificationAttempt.Enabled { - input := hooks.PasswordVerificationAttemptInput{ + input := v0hooks.PasswordVerificationAttemptInput{ UserID: user.ID, Valid: isValidPassword, } - output := hooks.PasswordVerificationAttemptOutput{} - if err := a.invokeHook(nil, r, &input, &output); err != nil { + output := v0hooks.PasswordVerificationAttemptOutput{} + if err := a.hooksMgr.InvokeHook(nil, r, &input, &output); err != nil { return err } - if output.Decision == hooks.HookRejection { + if output.Decision == v0hooks.HookRejection { if output.Message == "" { - output.Message = hooks.DefaultPasswordHookRejectionMessage + output.Message = v0hooks.DefaultPasswordHookRejectionMessage } if output.ShouldLogoutUser { if err := models.Logout(a.db, user.ID); err != nil { return err } } - return badRequestError(apierrors.ErrorCodeInvalidCredentials, output.Message) + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, output.Message) } } if !isValidPassword { - return badRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) } if params.Email != "" && !user.IsConfirmed() { - return badRequestError(apierrors.ErrorCodeEmailNotConfirmed, "Email not confirmed") + return apierrors.NewBadRequestError(apierrors.ErrorCodeEmailNotConfirmed, "Email not confirmed") } else if params.Phone != "" && !user.IsPhoneConfirmed() { - return badRequestError(apierrors.ErrorCodePhoneNotConfirmed, "Phone not confirmed") + return apierrors.NewBadRequestError(apierrors.ErrorCodePhoneNotConfirmed, "Phone not confirmed") } var token *AccessTokenResponse @@ -263,18 +263,18 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) } if params.AuthCode == "" || params.CodeVerifier == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") } flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) // Sanity check in case user ID was not set properly if models.IsNotFoundError(err) || flowState.UserID == nil { - return notFoundError(apierrors.ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") + return apierrors.NewNotFoundError(apierrors.ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") } else if err != nil { return err } if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { - return unprocessableEntityError(apierrors.ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") } user, err := models.FindUserByID(db, *flowState.UserID) @@ -282,7 +282,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return badRequestError(apierrors.ErrorCodeBadCodeVerifier, err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadCodeVerifier, err.Error()) } var token *AccessTokenResponse @@ -323,7 +323,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { config := a.config if sessionId == nil { - return "", 0, internalServerError("Session is required to issue access token") + return "", 0, apierrors.NewInternalServerError("Session is required to issue access token") } sid := sessionId.String() session, terr := models.FindSessionByID(tx, *sessionId, false) @@ -338,7 +338,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user issuedAt := time.Now().UTC() expiresAt := issuedAt.Add(time.Second * time.Duration(config.JWT.Exp)) - claims := &hooks.AccessTokenClaims{ + claims := &v0hooks.AccessTokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ Subject: user.ID.String(), Audience: jwt.ClaimStrings{user.Aud}, @@ -359,15 +359,15 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user var gotrueClaims jwt.Claims = claims if config.Hook.CustomAccessToken.Enabled { - input := hooks.CustomAccessTokenInput{ + input := v0hooks.CustomAccessTokenInput{ UserID: user.ID, Claims: claims, AuthenticationMethod: authenticationMethod.String(), } - output := hooks.CustomAccessTokenOutput{} + output := v0hooks.CustomAccessTokenOutput{} - err := a.invokeHook(tx, r, &input, &output) + err := a.hooksMgr.InvokeHook(tx, r, &input, &output) if err != nil { return "", 0, err } @@ -397,7 +397,7 @@ func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams) if terr != nil { - return internalServerError("Database error granting user").WithInternalError(terr) + return apierrors.NewInternalServerError("Database error granting user").WithInternalError(terr) } terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod) @@ -412,7 +412,7 @@ func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user if ok { return httpErr } - return internalServerError("error generating jwt token").WithInternalError(terr) + return apierrors.NewInternalServerError("error generating jwt token").WithInternalError(terr) } return nil }) @@ -439,7 +439,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, currentClaims := getClaims(ctx) sessionId, err := uuid.FromString(currentClaims.SessionId) if err != nil { - return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) } err = tx.Transaction(func(tx *storage.Connection) error { @@ -477,7 +477,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, if ok { return httpErr } - return internalServerError("error generating jwt token").WithInternalError(terr) + return apierrors.NewInternalServerError("error generating jwt token").WithInternalError(terr) } return nil }) @@ -495,7 +495,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, } func validateTokenClaims(outputClaims map[string]interface{}) error { - schemaLoader := gojsonschema.NewStringLoader(hooks.MinimumViableTokenSchema) + schemaLoader := gojsonschema.NewStringLoader(v0hooks.MinimumViableTokenSchema) documentLoader := gojsonschema.NewGoLoader(outputClaims) diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 650d1d96d..fbc4243ba 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -55,7 +55,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, false, "", nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, false, "", nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -102,7 +102,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, false, "", nil, badRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, false, "", nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } cfg = &conf.OAuthProviderConfiguration{ @@ -112,7 +112,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !cfg.Enabled { - return nil, false, "", nil, badRequestError(apierrors.ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + return nil, false, "", nil, apierrors.NewBadRequestError(apierrors.ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) @@ -136,11 +136,11 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } if params.IdToken == "" { - return oauthError("invalid request", "id_token required") + return apierrors.NewOAuthError("invalid request", "id_token required") } if params.Provider == "" && (params.ClientID == "" || params.Issuer == "") { - return oauthError("invalid request", "provider or client_id and issuer required") + return apierrors.NewOAuthError("invalid request", "provider or client_id and issuer required") } oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r) @@ -153,7 +153,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R AccessToken: params.AccessToken, }) if err != nil { - return oauthError("invalid request", "Bad ID token").WithInternalError(err) + return apierrors.NewOAuthError("invalid request", "Bad ID token").WithInternalError(err) } userData.Metadata.EmailVerified = false @@ -169,7 +169,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } if idToken.Subject == "" { - return oauthError("invalid request", "Missing sub claim in id_token") + return apierrors.NewOAuthError("invalid request", "Missing sub claim in id_token") } correctAudience := false @@ -191,7 +191,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } if !correctAudience { - return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) + return apierrors.NewOAuthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) } if !skipNonceCheck { @@ -199,12 +199,12 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R paramsHasNonce := params.Nonce != "" if tokenHasNonce != paramsHasNonce { - return oauthError("invalid request", "Passed nonce and nonce in id_token should either both exist or not.") + return apierrors.NewOAuthError("invalid request", "Passed nonce and nonce in id_token should either both exist or not.") } else if tokenHasNonce && paramsHasNonce { // verify nonce to mitigate replay attacks hash := fmt.Sprintf("%x", sha256.Sum256([]byte(params.Nonce))) if hash != idToken.Nonce { - return oauthError("invalid nonce", "Nonces mismatch") + return apierrors.NewOAuthError("invalid nonce", "Nonces mismatch") } } } @@ -246,7 +246,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R case *HTTPError: return err default: - return oauthError("server_error", "Internal Server Error").WithInternalError(err) + return apierrors.NewOAuthError("server_error", "Internal Server Error").WithInternalError(err) } } diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 373b64817..a1394e4ce 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -31,7 +31,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h } if params.RefreshToken == "" { - return oauthError("invalid_request", "refresh_token required") + return apierrors.NewOAuthError("invalid_request", "refresh_token required") } // A 5 second retry loop is used to make sure that refresh token @@ -48,21 +48,21 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) if err != nil { if models.IsNotFoundError(err) { - return badRequestError(apierrors.ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") + return apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") } - return internalServerError(err.Error()) + return apierrors.NewInternalServerError(err.Error()) } if user.IsBanned() { - return badRequestError(apierrors.ErrorCodeUserBanned, "Invalid Refresh Token: User Banned") + return apierrors.NewBadRequestError(apierrors.ErrorCodeUserBanned, "Invalid Refresh Token: User Banned") } if session == nil { // a refresh token won't have a session if it's created prior to the sessions table introduced if err := db.Destroy(token); err != nil { - return internalServerError("Error deleting refresh token with missing session").WithInternalError(err) + return apierrors.NewInternalServerError("Error deleting refresh token with missing session").WithInternalError(err) } - return badRequestError(apierrors.ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") } sessionValidityConfig := models.SessionValidityConfig{ @@ -78,13 +78,13 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h // do nothing case models.SessionTimedOut: - return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Inactivity)") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Inactivity)") case models.SessionLowAAL: - return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Low AAL: User Needs MFA Verification)") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Low AAL: User Needs MFA Verification)") default: - return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired") } // Basic checks above passed, now we need to serialize access @@ -113,7 +113,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h retry = true return terr } - return internalServerError(terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } if a.config.Sessions.SinglePerUser { @@ -129,7 +129,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h retry = true return terr } else if terr != nil { - return internalServerError(terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } sessionTag := session.DetermineTag(config.Sessions.Tags) @@ -163,7 +163,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) { // session is not the most // recently active one - return badRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") + return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") } } @@ -178,7 +178,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h if token.Revoked { activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) if terr != nil && !models.IsNotFoundError(terr) { - return internalServerError(terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { @@ -202,11 +202,11 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h if config.Security.RefreshTokenRotationEnabled { // Revoke all tokens in token family if err := models.RevokeTokenFamily(tx, token); err != nil { - return internalServerError(err.Error()) + return apierrors.NewInternalServerError(err.Error()) } } - return storage.NewCommitWithError(badRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) } } } @@ -230,7 +230,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h if ok { return httpErr } - return internalServerError("error generating jwt token").WithInternalError(terr) + return apierrors.NewInternalServerError("error generating jwt token").WithInternalError(terr) } refreshedAt := a.Now() @@ -251,7 +251,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h } if terr := session.UpdateOnlyRefreshInfo(tx); terr != nil { - return internalServerError("failed to update session information").WithInternalError(terr) + return apierrors.NewInternalServerError("failed to update session information").WithInternalError(terr) } newTokenResponse = &AccessTokenResponse{ @@ -280,5 +280,5 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return sendJSON(w, http.StatusOK, newTokenResponse) } - return conflictError("Too many concurrent token refresh requests on the same session or refresh token") + return apierrors.NewConflictError("Too many concurrent token refresh requests on the same session or refresh token") } diff --git a/internal/api/user.go b/internal/api/user.go index 815511970..8da2e82d1 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -45,7 +45,7 @@ func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) p.Channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(p.Channel, config) { - return badRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, InvalidChannelError) } } @@ -63,13 +63,13 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() claims := getClaims(ctx) if claims == nil { - return internalServerError("Could not read claims") + return apierrors.NewInternalServerError("Could not read claims") } aud := a.requestAud(ctx, r) audienceFromClaims, _ := claims.GetAudience() if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Token audience doesn't match request audience") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Token audience doesn't match request audience") } user := getUser(ctx) @@ -97,20 +97,20 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.AppData != nil && !isAdmin(user, config) { if !isAdmin(user, config) { - return forbiddenError(apierrors.ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") + return apierrors.NewForbiddenError(apierrors.ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") } } if user.HasMFAEnabled() && !session.IsAAL2() { if (params.Password != nil && *params.Password != "") || (params.Email != "" && user.GetEmail() != params.Email) || (params.Phone != "" && user.GetPhone() != params.Phone) { - return httpError(http.StatusUnauthorized, apierrors.ErrorCodeInsufficientAAL, "AAL2 session is required to update email or password when MFA is enabled.") + return apierrors.NewHTTPError(http.StatusUnauthorized, apierrors.ErrorCodeInsufficientAAL, "AAL2 session is required to update email or password when MFA is enabled.") } } if user.IsAnonymous { if params.Password != nil && *params.Password != "" { if params.Email == "" && params.Phone == "" { - return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Updating password of an anonymous user without an email or phone is not allowed") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Updating password of an anonymous user without an email or phone is not allowed") } } } @@ -124,23 +124,23 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") if updatingForbiddenFields { - return unprocessableEntityError(apierrors.ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") } } if params.Email != "" && user.GetEmail() != params.Email { if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { - return internalServerError("Database error checking email").WithInternalError(err) + return apierrors.NewInternalServerError("Database error checking email").WithInternalError(err) } else if duplicateUser != nil { - return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } } if params.Phone != "" && user.GetPhone() != params.Phone { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { - return internalServerError("Database error checking phone").WithInternalError(err) + return apierrors.NewInternalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(apierrors.ErrorCodePhoneExists, DuplicatePhoneMsg) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneExists, DuplicatePhoneMsg) } } @@ -150,7 +150,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { // we require reauthentication if the user hasn't signed in recently in the current session if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { if len(params.Nonce) == 0 { - return badRequestError(apierrors.ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") + return apierrors.NewBadRequestError(apierrors.ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") } if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { return err @@ -172,7 +172,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } if isSamePassword { - return unprocessableEntityError(apierrors.ErrorCodeSamePassword, "New password should be different from the old password.") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSamePassword, "New password should be different from the old password.") } } @@ -190,7 +190,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } if terr = user.UpdatePassword(tx, sessionID); terr != nil { - return internalServerError("Error during password storage").WithInternalError(terr) + return apierrors.NewInternalServerError("Error during password storage").WithInternalError(terr) } if terr := models.NewAuditLogEntry(r, tx, user, models.UserUpdatePasswordAction, "", nil); terr != nil { @@ -200,13 +200,13 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.Data != nil { if terr = user.UpdateUserMetaData(tx, params.Data); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error updating user").WithInternalError(terr) } } if params.AppData != nil { if terr = user.UpdateAppMetaData(tx, params.AppData); terr != nil { - return internalServerError("Error updating user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error updating user").WithInternalError(terr) } } @@ -254,7 +254,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } if terr = models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { - return internalServerError("Error recording audit log entry").WithInternalError(terr) + return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr) } return nil diff --git a/internal/api/verify.go b/internal/api/verify.go index 0d3c09b19..f86e6802e 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -49,18 +49,18 @@ type VerifyParams struct { func (p *VerifyParams) Validate(r *http.Request, a *API) error { var err error if p.Type == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type") } switch r.Method { case http.MethodGet: if p.Token == "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a token or a token hash") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a token or a token hash") } // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) p.TokenHash = p.Token case http.MethodPost: if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash") } if p.Token != "" { if isPhoneOtpVerification(p) { @@ -72,15 +72,15 @@ func (p *VerifyParams) Validate(r *http.Request, a *API) error { } else if isEmailOtpVerification(p) { p.Email, err = a.validateEmail(p.Email) if err != nil { - return unprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) } p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) } else { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") } } else if p.TokenHash != "" { if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { - return badRequestError(apierrors.ErrorCodeValidationFailed, "Only the token_hash and type should be provided") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only the token_hash and type should be provided") } } default: @@ -162,7 +162,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return nil } default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -187,7 +187,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } else if isPKCEFlow(flowType) { if authCode, terr = issueAuthCode(tx, user, authenticationMethod); terr != nil { - return badRequestError(apierrors.ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) + return apierrors.NewBadRequestError(apierrors.ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) } } return nil @@ -261,7 +261,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP case smsVerification, phoneChangeVerification: user, terr = a.smsVerify(r, tx, user, params) default: - return badRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -318,7 +318,7 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C var terr error if shouldUpdatePassword { if terr = user.UpdatePassword(tx, nil); terr != nil { - return internalServerError("Error storing password").WithInternalError(terr) + return apierrors.NewInternalServerError("Error storing password").WithInternalError(terr) } } @@ -327,7 +327,7 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C } if terr = user.Confirm(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error confirming user").WithInternalError(terr) } for _, identity := range user.Identities { @@ -338,7 +338,7 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C if terr = identity.UpdateIdentityData(tx, map[string]interface{}{ "email_verified": true, }); terr != nil { - return internalServerError("Error setting email_verified to true on identity").WithInternalError(terr) + return apierrors.NewInternalServerError("Error setting email_verified to true on identity").WithInternalError(terr) } } @@ -373,7 +373,7 @@ func (a *API) recoverVerify(r *http.Request, conn *storage.Connection, user *mod }) if err != nil { - return nil, internalServerError("Database error updating user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error updating user").WithInternalError(err) } return user, nil } @@ -387,7 +387,7 @@ func (a *API) smsVerify(r *http.Request, conn *storage.Connection, user *models. return terr } if terr := user.ConfirmPhone(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error confirming user").WithInternalError(terr) } } else if params.Type == phoneChangeVerification { if terr := models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { @@ -414,7 +414,7 @@ func (a *API) smsVerify(r *http.Request, conn *storage.Connection, user *models. } } if terr := user.ConfirmPhoneChange(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) + return apierrors.NewInternalServerError("Error confirming user").WithInternalError(terr) } } @@ -426,7 +426,7 @@ func (a *API) smsVerify(r *http.Request, conn *storage.Connection, user *models. } if terr := tx.Load(user, "Identities"); terr != nil { - return internalServerError("Error refetching identities").WithInternalError(terr) + return apierrors.NewInternalServerError("Error refetching identities").WithInternalError(terr) } return nil }) @@ -568,10 +568,10 @@ func (a *API) emailChangeVerify(r *http.Request, conn *storage.Connection, param } } if terr := tx.Load(user, "Identities"); terr != nil { - return internalServerError("Error refetching identities").WithInternalError(terr) + return apierrors.NewInternalServerError("Error refetching identities").WithInternalError(terr) } if terr := user.ConfirmEmailChange(tx, zeroConfirmation); terr != nil { - return internalServerError("Error confirm email").WithInternalError(terr) + return apierrors.NewInternalServerError("Error confirm email").WithInternalError(terr) } return nil @@ -599,18 +599,18 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* case mail.EmailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) default: - return nil, badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email verification type") + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid email verification type") } if err != nil { if models.IsNotFoundError(err) { - return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) } - return nil, internalServerError("Database error finding user from email link").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error finding user from email link").WithInternalError(err) } if user.IsBanned() { - return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") } var isExpired bool @@ -632,7 +632,7 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* } if isExpired { - return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") } return user, nil @@ -661,13 +661,13 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, if err != nil { if models.IsNotFoundError(err) { - return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } - return nil, internalServerError("Database error finding user").WithInternalError(err) + return nil, apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } if user.IsBanned() { - return nil, forbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") } var isValid bool @@ -710,7 +710,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, if !config.Hook.SendSMS.Enabled && config.Sms.IsTwilioVerifyProvider() { if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { - return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return user, nil } @@ -718,7 +718,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, } if !isValid { - return nil, forbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } return user, nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 4e45fa21c..e0dc5ed73 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -1198,7 +1198,7 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { for _, c := range cases { ts.Run(c.desc, func() { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rurl, err := ts.API.prepErrorRedirectURL(badRequestError(apierrors.ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) + rurl, err := ts.API.prepErrorRedirectURL(apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected, rurl) }) @@ -1248,7 +1248,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Token: "some-token", }, method: http.MethodPost, - expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), + expected: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), }, { desc: "Cannot send both TokenHash and Token", @@ -1258,7 +1258,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { TokenHash: "some-token-hash", }, method: http.MethodPost, - expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), + expected: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), }, { desc: "No verification type specified", @@ -1267,7 +1267,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Email: "email@example.com", }, method: http.MethodPost, - expected: badRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type"), + expected: apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Verify requires a verification type"), }, } diff --git a/internal/api/web3.go b/internal/api/web3.go index 6dc2c7c81..0764b3b9c 100644 --- a/internal/api/web3.go +++ b/internal/api/web3.go @@ -24,7 +24,7 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ config := a.config if !config.External.Web3Solana.Enabled { - return unprocessableEntityError(apierrors.ErrorCodeWeb3ProviderDisabled, "Web3 provider is disabled") + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeWeb3ProviderDisabled, "Web3 provider is disabled") } params := &Web3GrantParams{} @@ -33,7 +33,7 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ } if params.Chain != "solana" { - return badRequestError(apierrors.ErrorCodeWeb3UnsupportedChain, "Unsupported chain") + return apierrors.NewBadRequestError(apierrors.ErrorCodeWeb3UnsupportedChain, "Unsupported chain") } return a.web3GrantSolana(ctx, w, r, params) @@ -44,66 +44,66 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt db := a.db.WithContext(ctx) if len(params.Message) < 64 { - return badRequestError(apierrors.ErrorCodeValidationFailed, "message is too short") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "message is too short") } else if len(params.Message) > 20*1024 { - return badRequestError(apierrors.ErrorCodeValidationFailed, "message must not exceed 20KB") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "message must not exceed 20KB") } if len(params.Signature) != 86 && len(params.Signature) != 88 { - return badRequestError(apierrors.ErrorCodeValidationFailed, "signature must be 64 bytes encoded as base64 with or without padding") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "signature must be 64 bytes encoded as base64 with or without padding") } base64URLSignature := strings.ReplaceAll(strings.ReplaceAll(strings.TrimRight(params.Signature, "="), "+", "-"), "/", "_") signatureBytes, err := base64.RawURLEncoding.DecodeString(base64URLSignature) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, "signature does not contain valid base64 characters") + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "signature does not contain valid base64 characters") } parsedMessage, err := siws.ParseMessage(params.Message) if err != nil { - return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } if !parsedMessage.VerifySignature(signatureBytes) { - return oauthError("invalid_grant", "Signature does not match address in message") + return apierrors.NewOAuthError("invalid_grant", "Signature does not match address in message") } if parsedMessage.URI.Scheme != "https" { if parsedMessage.URI.Scheme == "http" && parsedMessage.URI.Hostname() != "localhost" { - return oauthError("invalid_grant", "Signed Solana message is using URI which uses HTTP and hostname is not localhost, only HTTPS is allowed") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message is using URI which uses HTTP and hostname is not localhost, only HTTPS is allowed") } else { - return oauthError("invalid_grant", "Signed Solana message is using URI which does not use HTTPS") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message is using URI which does not use HTTPS") } } if !utilities.IsRedirectURLValid(config, parsedMessage.URI.String()) { - return oauthError("invalid_grant", "Signed Solana message is using URI which is not allowed on this server, message was signed for another app") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message is using URI which is not allowed on this server, message was signed for another app") } if parsedMessage.URI.Host != parsedMessage.Domain || !utilities.IsRedirectURLValid(config, "https://"+parsedMessage.Domain+"/") { - return oauthError("invalid_grant", "Signed Solana message is using a Domain that does not match the one in URI which is not allowed on this server") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message is using a Domain that does not match the one in URI which is not allowed on this server") } now := a.Now() if !parsedMessage.NotBefore.IsZero() && now.Before(parsedMessage.NotBefore) { - return oauthError("invalid_grant", "Signed Solana message becomes valid in the future") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message becomes valid in the future") } if !parsedMessage.ExpirationTime.IsZero() && now.After(parsedMessage.ExpirationTime) { - return oauthError("invalid_grant", "Signed Solana message is expired") + return apierrors.NewOAuthError("invalid_grant", "Signed Solana message is expired") } latestExpiryAt := parsedMessage.IssuedAt.Add(config.External.Web3Solana.MaximumValidityDuration) if now.After(latestExpiryAt) { - return oauthError("invalid_grant", "Solana message was issued too long ago") + return apierrors.NewOAuthError("invalid_grant", "Solana message was issued too long ago") } earliestIssuedAt := parsedMessage.IssuedAt.Add(-config.External.Web3Solana.MaximumValidityDuration) if now.Before(earliestIssuedAt) { - return oauthError("invalid_grant", "Solana message was issued too far in the future") + return apierrors.NewOAuthError("invalid_grant", "Solana message was issued too far in the future") } providerId := strings.Join([]string{ @@ -162,7 +162,7 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt case *HTTPError: return err default: - return oauthError("server_error", "Internal Server Error").WithInternalError(err) + return apierrors.NewOAuthError("server_error", "Internal Server Error").WithInternalError(err) } } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 000000000..91b8fb162 --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,38 @@ +package hooks + +import ( + "net/http" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/storage" +) + +type Manager struct { + v0mgr *v0hooks.Manager +} + +func NewManager( + db *storage.Connection, + config *conf.GlobalConfiguration, +) *Manager { + return &Manager{ + v0mgr: v0hooks.NewManager(db, config), + } +} + +func (o *Manager) InvokeHook( + conn *storage.Connection, + r *http.Request, + input, output any, +) error { + return o.v0mgr.InvokeHook(conn, r, input, output) +} + +func (o *Manager) RunHTTPHook( + r *http.Request, + hookConfig conf.ExtensibilityPointConfiguration, + input any, +) ([]byte, error) { + return o.v0mgr.RunHTTPHook(r, hookConfig, input) +} diff --git a/internal/hooks/hooks_test.go b/internal/hooks/hooks_test.go new file mode 100644 index 000000000..656142400 --- /dev/null +++ b/internal/hooks/hooks_test.go @@ -0,0 +1,139 @@ +package hooks_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +const ( + apiTestVersion = "1" + apiTestConfig = "../../hack/test.env" +) + +func TestNewManager(t *testing.T) { + { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) + defer cancel() + + config := helpConfig(t, apiTestConfig) + conn := helpConn(t, config) + + hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + + fmt.Fprintln(w, `{}`) + }) + + ts := httptest.NewServer(hr) + defer ts.Close() + + config.Hook.SendEmail.Enabled = true + config.Hook.SendEmail.URI = ts.URL + "/SendEmail" + + a := newAPI(config, conn) + mgr := hooks.NewManager(a.GetDB(), a.GetConfig()) + + { + in := &v0hooks.SendEmailInput{ + User: &models.User{ + ID: uuid.Must(uuid.NewV4()), + }, + } + buf := new(bytes.Buffer) + err := json.NewEncoder(buf).Encode(in) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + out := &v0hooks.SendEmailOutput{} + req, err := http.NewRequestWithContext( + ctx, "POST", config.Hook.SendEmail.URI, buf) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + err = mgr.InvokeHook(nil, req, in, out) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if exp, got := "", out.HookError.Message; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } + + { + in := &v0hooks.SendEmailInput{ + User: &models.User{ + ID: uuid.Must(uuid.NewV4()), + }, + } + buf := new(bytes.Buffer) + err := json.NewEncoder(buf).Encode(in) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + req, err := http.NewRequestWithContext( + ctx, "POST", config.Hook.SendEmail.URI, buf) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + res, err := mgr.RunHTTPHook(req, config.Hook.SendEmail, in) + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + out := &v0hooks.SendEmailOutput{} + if err := json.Unmarshal(res, out); err != nil { + t.Fatalf("exp nil err; got %v", err) + } + if exp, got := "", out.HookError.Message; exp != got { + t.Fatalf("exp %v; got %v", exp, got) + } + } + } +} + +func newAPI( + config *conf.GlobalConfiguration, + conn *storage.Connection, +) *api.API { + limiterOpts := api.NewLimiterOptions(config) + return api.NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts) +} + +func helpConfig(tb testing.TB, configPath string) *conf.GlobalConfiguration { + tb.Helper() + + config, err := conf.LoadGlobal(configPath) + if err != nil { + tb.Fatalf("error loading config %q; got %v", configPath, err) + } + return config +} + +func helpConn(tb testing.TB, config *conf.GlobalConfiguration) *storage.Connection { + tb.Helper() + + conn, err := test.SetupDBConnection(config) + if err != nil { + tb.Fatalf("error setting up db connection: %v", err) + } + return conn +} diff --git a/internal/api/hooks.go b/internal/hooks/v0hooks/manager.go similarity index 63% rename from internal/api/hooks.go rename to internal/hooks/v0hooks/manager.go index 57d21477a..3d3cba1c9 100644 --- a/internal/api/hooks.go +++ b/internal/hooks/v0hooks/manager.go @@ -1,4 +1,4 @@ -package api +package v0hooks import ( "bytes" @@ -16,10 +16,10 @@ import ( "github.com/gofrs/uuid" "github.com/sirupsen/logrus" standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" + "github.com/xeipuuv/gojsonschema" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" ) @@ -31,178 +31,60 @@ const ( PayloadLimit = 200 * 1024 // 200KB ) -func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { - db := a.db.WithContext(ctx) - - request, err := json.Marshal(input) - if err != nil { - panic(err) - } - - var response []byte - invokeHookFunc := func(tx *storage.Connection) error { - // We rely on Postgres timeouts to ensure the function doesn't overrun - if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil { - return terr - } - - if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil { - return terr - } - - // reset the timeout - if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { - return terr - } - - return nil - } - - if tx != nil { - if err := invokeHookFunc(tx); err != nil { - return nil, err - } - } else { - if err := db.Transaction(invokeHookFunc); err != nil { - return nil, err - } - } - - if err := json.Unmarshal(response, output); err != nil { - return response, err - } - - return response, nil +type Manager struct { + db *storage.Connection + config *conf.GlobalConfiguration } -func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input any) ([]byte, error) { - ctx := r.Context() - client := http.Client{ - Timeout: DefaultHTTPHookTimeout, +func NewManager( + db *storage.Connection, + config *conf.GlobalConfiguration, +) *Manager { + return &Manager{ + db: db, + config: config, } - ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) - defer cancel() - - log := observability.GetLogEntry(r).Entry - requestURL := hookConfig.URI - hookLog := log.WithFields(logrus.Fields{ - "component": "auth_hook", - "url": requestURL, - }) - - inputPayload, err := json.Marshal(input) - if err != nil { - return nil, err - } - for i := 0; i < DefaultHTTPHookRetries; i++ { - if i == 0 { - hookLog.Debugf("invocation attempt: %d", i) - } else { - hookLog.Infof("invocation attempt: %d", i) - } - msgID := uuid.Must(uuid.NewV4()) - currentTime := time.Now() - signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) - if err != nil { - panic("Failed to make request object") - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("webhook-id", msgID.String()) - req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) - req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) - // By default, Go Client sets encoding to gzip, which does not carry a content length header. - req.Header.Set("Accept-Encoding", "identity") - - rsp, err := client.Do(req) - if err != nil && errors.Is(err, context.DeadlineExceeded) { - return nil, unprocessableEntityError(apierrors.ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) - - } else if err != nil { - if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { - hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) - time.Sleep(HTTPHookBackoffDuration) - continue - } else if i == DefaultHTTPHookRetries-1 { - return nil, unprocessableEntityError(apierrors.ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") - } else { - return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) - } - } +} - defer rsp.Body.Close() +func (o *Manager) InvokeHook( + conn *storage.Connection, + r *http.Request, + input, output any, +) error { + return o.invokeHook(conn, r, input, output) +} - switch rsp.StatusCode { - case http.StatusOK, http.StatusNoContent, http.StatusAccepted: - // Header.Get is case insensitive - contentType := rsp.Header.Get("Content-Type") - if contentType == "" { - return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header") - } - mediaType, _, err := mime.ParseMediaType(contentType) - if err != nil { - return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error())) - } - if mediaType != "application/json" { - return nil, badRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType) - } - if rsp.Body == nil { - return nil, nil - } - limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} - body, err := io.ReadAll(&limitedReader) - if err != nil { - return nil, err - } - if limitedReader.N <= 0 { - // check if the response body still has excess bytes to be read - if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 { - return nil, unprocessableEntityError(apierrors.ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit)) - } - } - return body, nil - case http.StatusTooManyRequests, http.StatusServiceUnavailable: - retryAfterHeader := rsp.Header.Get("retry-after") - // Check for truthy values to allow for flexibility to switch to time duration - if retryAfterHeader != "" { - continue - } - return nil, internalServerError("Service currently unavailable due to hook") - case http.StatusBadRequest: - return nil, internalServerError("Invalid payload sent to hook") - case http.StatusUnauthorized: - return nil, internalServerError("Hook requires authorization token") - default: - return nil, internalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode) - } - } - return nil, nil +func (o *Manager) RunHTTPHook( + r *http.Request, + hookConfig conf.ExtensibilityPointConfiguration, + input any, +) ([]byte, error) { + return o.runHTTPHook(r, hookConfig, input) } -// invokePostgresHook invokes the hook code. conn can be nil, in which case a new +// invokeHook invokes the hook code. conn can be nil, in which case a new // transaction is opened. If calling invokeHook within a transaction, always // pass the current transaction, as pool-exhaustion deadlocks are very easy to // trigger. -func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error { +func (o *Manager) invokeHook( + conn *storage.Connection, + r *http.Request, + input, output any, +) error { var err error var response []byte switch input.(type) { - case *hooks.SendSMSInput: - hookOutput, ok := output.(*hooks.SendSMSOutput) + case *SendSMSInput: + hookOutput, ok := output.(*SendSMSOutput) if !ok { panic("output should be *hooks.SendSMSOutput") } - if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output); err != nil { + if response, err = o.runHook(r, conn, o.config.Hook.SendSMS, input, output); err != nil { return err } if err := json.Unmarshal(response, hookOutput); err != nil { - return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(err) + return apierrors.NewInternalServerError("Error unmarshaling Send SMS output.").WithInternalError(err) } if hookOutput.IsError() { httpCode := hookOutput.HookError.HTTPCode @@ -210,23 +92,23 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu if httpCode == 0 { httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) } return nil - case *hooks.SendEmailInput: - hookOutput, ok := output.(*hooks.SendEmailOutput) + case *SendEmailInput: + hookOutput, ok := output.(*SendEmailOutput) if !ok { panic("output should be *hooks.SendEmailOutput") } - if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output); err != nil { + if response, err = o.runHook(r, conn, o.config.Hook.SendEmail, input, output); err != nil { return err } if err := json.Unmarshal(response, hookOutput); err != nil { - return internalServerError("Error unmarshaling Send Email output.").WithInternalError(err) + return apierrors.NewInternalServerError("Error unmarshaling Send Email output.").WithInternalError(err) } if hookOutput.IsError() { httpCode := hookOutput.HookError.HTTPCode @@ -235,7 +117,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: hookOutput.HookError.Message, } @@ -243,16 +125,16 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu return httpError.WithInternalError(&hookOutput.HookError) } return nil - case *hooks.MFAVerificationAttemptInput: - hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) + case *MFAVerificationAttemptInput: + hookOutput, ok := output.(*MFAVerificationAttemptOutput) if !ok { panic("output should be *hooks.MFAVerificationAttemptOutput") } - if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output); err != nil { + if response, err = o.runHook(r, conn, o.config.Hook.MFAVerificationAttempt, input, output); err != nil { return err } if err := json.Unmarshal(response, hookOutput); err != nil { - return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err) + return apierrors.NewInternalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err) } if hookOutput.IsError() { httpCode := hookOutput.HookError.HTTPCode @@ -261,7 +143,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: hookOutput.HookError.Message, } @@ -269,17 +151,17 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu return httpError.WithInternalError(&hookOutput.HookError) } return nil - case *hooks.PasswordVerificationAttemptInput: - hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput) + case *PasswordVerificationAttemptInput: + hookOutput, ok := output.(*PasswordVerificationAttemptOutput) if !ok { panic("output should be *hooks.PasswordVerificationAttemptOutput") } - if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output); err != nil { + if response, err = o.runHook(r, conn, o.config.Hook.PasswordVerificationAttempt, input, output); err != nil { return err } if err := json.Unmarshal(response, hookOutput); err != nil { - return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err) + return apierrors.NewInternalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err) } if hookOutput.IsError() { httpCode := hookOutput.HookError.HTTPCode @@ -288,7 +170,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: hookOutput.HookError.Message, } @@ -297,16 +179,16 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu } return nil - case *hooks.CustomAccessTokenInput: - hookOutput, ok := output.(*hooks.CustomAccessTokenOutput) + case *CustomAccessTokenInput: + hookOutput, ok := output.(*CustomAccessTokenOutput) if !ok { panic("output should be *hooks.CustomAccessTokenOutput") } - if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output); err != nil { + if response, err = o.runHook(r, conn, o.config.Hook.CustomAccessToken, input, output); err != nil { return err } if err := json.Unmarshal(response, hookOutput); err != nil { - return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err) + return apierrors.NewInternalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err) } if hookOutput.IsError() { @@ -316,7 +198,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: hookOutput.HookError.Message, } @@ -329,7 +211,7 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu if httpCode == 0 { httpCode = http.StatusInternalServerError } - httpError := &HTTPError{ + httpError := &apierrors.HTTPError{ HTTPStatus: httpCode, Message: err.Error(), } @@ -341,7 +223,12 @@ func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, outpu return nil } -func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { +func (o *Manager) runHook( + r *http.Request, + conn *storage.Connection, + hookConfig conf.ExtensibilityPointConfiguration, + input, output any, +) ([]byte, error) { ctx := r.Context() logEntry := observability.GetLogEntry(r) @@ -352,9 +239,9 @@ func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf switch { case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"): - response, err = a.runHTTPHook(r, hookConfig, input) + response, err = o.runHTTPHook(r, hookConfig, input) case strings.HasPrefix(hookConfig.URI, "pg-functions:"): - response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output) + response, err = o.runPostgresHook(ctx, conn, hookConfig, input, output) default: return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI) } @@ -369,7 +256,7 @@ func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf "duration": duration.Microseconds(), }).WithError(err).Warn("Hook errored out") - return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err) + return nil, apierrors.NewInternalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err) } logEntry.Entry.WithFields(logrus.Fields{ @@ -382,6 +269,168 @@ func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf return response, nil } +func (o *Manager) runPostgresHook( + ctx context.Context, + tx *storage.Connection, + hookConfig conf.ExtensibilityPointConfiguration, + input, output any, +) ([]byte, error) { + db := o.db.WithContext(ctx) + + request, err := json.Marshal(input) + if err != nil { + panic(err) + } + + var response []byte + invokeHookFunc := func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun + if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", DefaultTimeout)).Exec(); terr != nil { + return terr + } + + if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil { + return terr + } + + // reset the timeout + if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { + return terr + } + + return nil + } + + if tx != nil { + if err := invokeHookFunc(tx); err != nil { + return nil, err + } + } else { + if err := db.Transaction(invokeHookFunc); err != nil { + return nil, err + } + } + + if err := json.Unmarshal(response, output); err != nil { + return response, err + } + + return response, nil +} + +func (o *Manager) runHTTPHook( + r *http.Request, + hookConfig conf.ExtensibilityPointConfiguration, + input any, +) ([]byte, error) { + ctx := r.Context() + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + defer cancel() + + log := observability.GetLogEntry(r).Entry + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + for i := 0; i < DefaultHTTPHookRetries; i++ { + if i == 0 { + hookLog.Debugf("invocation attempt: %d", i) + } else { + hookLog.Infof("invocation attempt: %d", i) + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + panic("Failed to make request object") + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + // By default, Go Client sets encoding to gzip, which does not carry a content length header. + req.Header.Set("Accept-Encoding", "identity") + + rsp, err := client.Do(req) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) + + } else if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") + } else { + return nil, apierrors.NewInternalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + defer rsp.Body.Close() + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + // Header.Get is case insensitive + contentType := rsp.Header.Get("Content-Type") + if contentType == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header") + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error())) + } + if mediaType != "application/json" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType) + } + if rsp.Body == nil { + return nil, nil + } + limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + if limitedReader.N <= 0 { + // check if the response body still has excess bytes to be read + if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 { + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit)) + } + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to switch to time duration + if retryAfterHeader != "" { + continue + } + return nil, apierrors.NewInternalServerError("Service currently unavailable due to hook") + case http.StatusBadRequest: + return nil, apierrors.NewInternalServerError("Invalid payload sent to hook") + case http.StatusUnauthorized: + return nil, apierrors.NewInternalServerError("Hook requires authorization token") + default: + return nil, apierrors.NewInternalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode) + } + } + return nil, nil +} + func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { SymmetricSignaturePrefix := "v1," // TODO(joel): Handle asymmetric case once library has been upgraded @@ -404,3 +453,27 @@ func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time } return signatureList, nil } + +func validateTokenClaims(outputClaims map[string]interface{}) error { + schemaLoader := gojsonschema.NewStringLoader(MinimumViableTokenSchema) + + documentLoader := gojsonschema.NewGoLoader(outputClaims) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errorMessages string + + for _, desc := range result.Errors() { + errorMessages += fmt.Sprintf("- %s\n", desc) + fmt.Printf("- %s\n", desc) + } + return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages) + + } + + return nil +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/v0hooks/v0hooks.go similarity index 99% rename from internal/hooks/auth_hooks.go rename to internal/hooks/v0hooks/v0hooks.go index 1b881d36f..ac1d3d962 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/v0hooks/v0hooks.go @@ -1,4 +1,4 @@ -package hooks +package v0hooks import ( "github.com/gofrs/uuid" From 717ea5f22c7f27e56c2f9a77a9b0552297a48737 Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Mon, 14 Apr 2025 11:08:29 -0700 Subject: [PATCH 2/2] fix: missed an unused function when resolving conflicts --- internal/api/token.go | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/internal/api/token.go b/internal/api/token.go index fd4631f55..c2efcae1a 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -2,7 +2,6 @@ package api import ( "context" - "fmt" "net/http" "net/url" "strconv" @@ -10,7 +9,6 @@ import ( "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v5" - "github.com/xeipuuv/gojsonschema" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/hooks/v0hooks" @@ -493,27 +491,3 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, User: user, }, nil } - -func validateTokenClaims(outputClaims map[string]interface{}) error { - schemaLoader := gojsonschema.NewStringLoader(v0hooks.MinimumViableTokenSchema) - - documentLoader := gojsonschema.NewGoLoader(outputClaims) - - result, err := gojsonschema.Validate(schemaLoader, documentLoader) - if err != nil { - return err - } - - if !result.Valid() { - var errorMessages string - - for _, desc := range result.Errors() { - errorMessages += fmt.Sprintf("- %s\n", desc) - fmt.Printf("- %s\n", desc) - } - return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages) - - } - - return nil -}