Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,

userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
if err != nil {
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
}

observability.LogEntrySetField(r, "user_id", userID)

u, err := models.FindUserByID(db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err)
}

return withUser(ctx, u), nil
Expand All @@ -77,17 +77,17 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
user := getUser(ctx)
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

factor, err := user.FindOwnedFactorByID(db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
return nil, apierrors.NewInternalServerError("Database error loading factor").WithInternalError(err)
}
return withFactor(ctx, factor), nil
}
Expand All @@ -109,19 +109,19 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {

pageParams, err := paginate(r)
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
}

filter := r.URL.Query().Get("filter")

users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter)
if err != nil {
return internalServerError("Database error finding users").WithInternalError(err)
return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err)
}
addPaginationHeaders(w, r, pageParams)

Expand Down Expand Up @@ -170,7 +170,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
banDuration = &duration
Expand Down Expand Up @@ -315,7 +315,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
})

if err != nil {
return internalServerError("Error updating user").WithInternalError(err)
return apierrors.NewInternalServerError("Error updating user").WithInternalError(err)
}

return sendJSON(w, http.StatusOK, user)
Expand All @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}

if params.Email == "" && params.Phone == "" {
return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
}

var providers []string
Expand All @@ -349,9 +349,9 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
return err
}
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
return internalServerError("Database error checking email").WithInternalError(err)
return apierrors.NewInternalServerError("Database error checking email").WithInternalError(err)
} else if user != nil {
return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
}
providers = append(providers, "email")
}
Expand All @@ -362,21 +362,21 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
return err
}
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
return internalServerError("Database error checking phone").WithInternalError(err)
return apierrors.NewInternalServerError("Database error checking phone").WithInternalError(err)
} else if exists {
return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
}
providers = append(providers, "phone")
}

if params.Password != nil && params.PasswordHash != "" {
return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
}

if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" {
password, err := password.Generate(64, 10, 0, false, true)
if err != nil {
return internalServerError("Error generating password").WithInternalError(err)
return apierrors.NewInternalServerError("Error generating password").WithInternalError(err)
}
params.Password = &password
}
Expand All @@ -390,18 +390,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {

if err != nil {
if errors.Is(err, bcrypt.ErrPasswordTooLong) {
return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
}
return internalServerError("Error creating user").WithInternalError(err)
return apierrors.NewInternalServerError("Error creating user").WithInternalError(err)
}

if params.Id != "" {
customId, err := uuid.FromString(params.Id)
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
}
if customId == uuid.Nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
}
user.ID = customId
}
Expand All @@ -419,7 +419,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
banDuration = &duration
Expand Down Expand Up @@ -501,7 +501,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
})

if err != nil {
return internalServerError("Database error creating new user").WithInternalError(err)
return apierrors.NewInternalServerError("Database error creating new user").WithInternalError(err)
}

return sendJSON(w, http.StatusOK, user)
Expand Down Expand Up @@ -529,7 +529,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
"user_email": user.Email,
"user_phone": user.Phone,
}); terr != nil {
return internalServerError("Error recording audit log entry").WithInternalError(terr)
return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr)
}

if params.ShouldSoftDelete {
Expand All @@ -538,24 +538,24 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
return nil
}
if terr := user.SoftDeleteUser(tx); terr != nil {
return internalServerError("Error soft deleting user").WithInternalError(terr)
return apierrors.NewInternalServerError("Error soft deleting user").WithInternalError(terr)
}

if terr := user.SoftDeleteUserIdentities(tx); terr != nil {
return internalServerError("Error soft deleting user identities").WithInternalError(terr)
return apierrors.NewInternalServerError("Error soft deleting user identities").WithInternalError(terr)
}

// hard delete all associated factors
if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil {
return internalServerError("Error deleting user's factors").WithInternalError(terr)
return apierrors.NewInternalServerError("Error deleting user's factors").WithInternalError(terr)
}
// hard delete all associated sessions
if terr := models.Logout(tx, user.ID); terr != nil {
return internalServerError("Error deleting user's sessions").WithInternalError(terr)
return apierrors.NewInternalServerError("Error deleting user's sessions").WithInternalError(terr)
}
} else {
if terr := tx.Destroy(user); terr != nil {
return internalServerError("Database error deleting user").WithInternalError(terr)
return apierrors.NewInternalServerError("Database error deleting user").WithInternalError(terr)
}
}

Expand All @@ -581,7 +581,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
return terr
}
if terr := tx.Destroy(factor); terr != nil {
return internalServerError("Database error deleting factor").WithInternalError(terr)
return apierrors.NewInternalServerError("Database error deleting factor").WithInternalError(terr)
}
return nil
})
Expand Down Expand Up @@ -619,7 +619,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
if params.Phone != "" && factor.IsPhoneFactor() {
phone, err := validatePhone(params.Phone)
if err != nil {
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
}
if terr := factor.UpdatePhone(tx, phone); terr != nil {
return terr
Expand Down
4 changes: 2 additions & 2 deletions internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
aud := a.requestAud(ctx, r)

if config.DisableSignup {
return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
}

params := &SignupParams{}
Expand Down Expand Up @@ -48,7 +48,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
return nil
})
if err != nil {
return internalServerError("Database error creating anonymous user").WithInternalError(err)
return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err)
}

metering.RecordLogin("anonymous", newUser.ID)
Expand Down
10 changes: 9 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks"
"github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
Expand All @@ -32,6 +33,7 @@ type API struct {
config *conf.GlobalConfiguration
version string

hooksMgr *hooks.Manager
hibpClient *hibp.PwnedClient

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
Expand All @@ -40,6 +42,9 @@ type API struct {
limiterOpts *LimiterOptions
}

func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config }
func (a *API) GetDB() *storage.Connection { return a.db }

func (a *API) Version() string {
return a.version
}
Expand Down Expand Up @@ -81,6 +86,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if api.limiterOpts == nil {
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.hooksMgr == nil {
api.hooksMgr = hooks.NewManager(db, globalConfig)
}
if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
// all HIBP API requests should finish quickly to avoid
Expand Down Expand Up @@ -157,7 +165,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
}
if params.Email == "" && params.Phone == "" {
if !api.config.External.AnonymousUsers.Enabled {
return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
}
if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil {
return err
Expand Down
35 changes: 32 additions & 3 deletions internal/api/apierrors/apierrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apierrors

import (
"fmt"
"net/http"
)

// OAuthError is the JSON handler for OAuth2 error responses
Expand Down Expand Up @@ -30,7 +31,7 @@ func (e *OAuthError) WithInternalError(err error) *OAuthError {
}

// WithInternalMessage adds internal message information to the error
func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError {
func (e *OAuthError) WithInternalMessage(fmtString string, args ...any) *OAuthError {
e.InternalMessage = fmt.Sprintf(fmtString, args...)
return e
}
Expand All @@ -53,14 +54,42 @@ type HTTPError struct {
ErrorID string `json:"error_id,omitempty"`
}

func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError {
func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return &HTTPError{
HTTPStatus: httpStatus,
ErrorCode: errorCode,
Message: fmt.Sprintf(fmtString, args...),
}
}

func NewBadRequestError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusBadRequest, errorCode, fmtString, args...)
}

func NewNotFoundError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusNotFound, errorCode, fmtString, args...)
}

func NewForbiddenError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusForbidden, errorCode, fmtString, args...)
}

func NewUnprocessableEntityError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusUnprocessableEntity, errorCode, fmtString, args...)
}

func NewTooManyRequestsError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusTooManyRequests, errorCode, fmtString, args...)
}

func NewInternalServerError(fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...)
}

func NewConflictError(fmtString string, args ...any) *HTTPError {
return NewHTTPError(http.StatusConflict, ErrorCodeConflict, fmtString, args...)
}

func (e *HTTPError) Error() string {
if e.InternalMessage != "" {
return e.InternalMessage
Expand All @@ -87,7 +116,7 @@ func (e *HTTPError) WithInternalError(err error) *HTTPError {
}

// WithInternalMessage adds internal message information to the error
func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError {
func (e *HTTPError) WithInternalMessage(fmtString string, args ...any) *HTTPError {
e.InternalMessage = fmt.Sprintf(fmtString, args...)
return e
}
Loading
Loading