Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.With(api.limitHandler(api.limiterOpts.Otp)).
With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)
// rate limiting applied in handler
r.With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
r.Get("/", api.Verify)
Expand Down
58 changes: 38 additions & 20 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,27 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error {

var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()

if limitHeader := a.config.RateLimitHeader; limitHeader != "" {
key := req.Header.Get(limitHeader)

if key == "" {
log := observability.GetLogEntry(req).Entry
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
return c, nil
} else {
err := tollbooth.LimitByKeys(lmt, []string{key})
if err != nil {
return c, tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}
func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
if limitHeader := a.config.RateLimitHeader; limitHeader != "" {
key := req.Header.Get(limitHeader)

if key == "" {
log := observability.GetLogEntry(req).Entry
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
} else {
err := tollbooth.LimitByKeys(lmt, []string{key})
if err != nil {
return tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}
}
return c, nil
}

return nil
}

func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
return req.Context(), a.performRateLimiting(lmt, req)
}
}

Expand Down Expand Up @@ -137,11 +139,27 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C
}

func isIgnoreCaptchaRoute(req *http.Request) bool {
// captcha shouldn't be enabled on the following grant_types
// id_token, refresh_token, pkce
if req.URL.Path == "/token" && req.FormValue("grant_type") != "password" {
if req.URL.Path != "/token" {
return false
}

switch req.FormValue("grant_type") {
case "pkce":
return true

case "refresh_token":
return true

case "id_token":
return true

case "password":
return false

case "web3":
return false
}

return false
}

Expand Down
6 changes: 6 additions & 0 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type LimiterOptions struct {
FactorChallenge *limiter.Limiter
SSO *limiter.Limiter
SAMLAssertion *limiter.Limiter
Web3 *limiter.Limiter
}

func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
Expand Down Expand Up @@ -85,6 +86,11 @@ func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Web3 = tollbooth.NewLimiter(gc.RateLimitWeb3/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

// These all use the OTP limit per 5 min with 1hour ttl and burst of 30.
o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp)
Expand Down
20 changes: 15 additions & 5 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,30 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
grantType := r.FormValue("grant_type")

handler := a.ResourceOwnerPasswordGrant
limiter := a.limiterOpts.Token

switch grantType {
case "password":
return a.ResourceOwnerPasswordGrant(ctx, w, r)
// set above
case "refresh_token":
return a.RefreshTokenGrant(ctx, w, r)
handler = a.RefreshTokenGrant
case "id_token":
return a.IdTokenGrant(ctx, w, r)
handler = a.IdTokenGrant
case "pkce":
return a.PKCE(ctx, w, r)
handler = a.PKCE
case "web3":
return a.Web3Grant(ctx, w, r)
handler = a.Web3Grant
limiter = a.limiterOpts.Web3
default:
return badRequestError(apierrors.ErrorCodeInvalidCredentials, "unsupported_grant_type")
}

if err := a.performRateLimiting(limiter, r); err != nil {
return err
}

return handler(ctx, w, r)
}

// ResourceOwnerPasswordGrant implements the password grant type flow
Expand Down
35 changes: 33 additions & 2 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,38 @@ func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() {

func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "1.2.3.4")

// It rate limits after 30 requests
for i := 0; i < 30; i++ {
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

// It ignores X-Forwarded-For by default
req.Header.Set("X-Forwarded-For", "1.1.1.1")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

// It doesn't rate limit a new value for the limited header
req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "5.6.7.8")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestRateLimitWeb3() {
var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "1.2.3.4")

Expand All @@ -246,7 +277,7 @@ func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

// It doesn't rate limit a new value for the limited header
req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "5.6.7.8")
w = httptest.NewRecorder()
Expand Down
1 change: 1 addition & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ type GlobalConfiguration struct {
RateLimitSso float64 `split_words:"true" default:"30"`
RateLimitAnonymousUsers float64 `split_words:"true" default:"30"`
RateLimitOtp float64 `split_words:"true" default:"30"`
RateLimitWeb3 float64 `split_words:"true" default:"30"`

SiteURL string `json:"site_url" split_words:"true" required:"true"`
URIAllowList []string `json:"uri_allow_list" split_words:"true"`
Expand Down
4 changes: 4 additions & 0 deletions internal/observability/request-logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry {
"referer": referrer,
}

if r.URL.Path == "/token" {
logFields["grant_type"] = r.FormValue("grant_type")
}

if reqID := utilities.GetRequestID(r.Context()); reqID != "" {
logFields["request_id"] = reqID
}
Expand Down
Loading