diff --git a/internal/api/api.go b/internal/api/api.go index d852b5c65..f184ce55c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -184,8 +184,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.With(api.limitHandler(api.limiterOpts.Otp)). With(api.verifyCaptcha).Post("/otp", api.Otp) - r.With(api.limitHandler(api.limiterOpts.Token)). - With(api.verifyCaptcha).Post("/token", api.Token) + // rate limiting applied in handler + r.With(api.verifyCaptcha).Post("/token", api.Token) r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) { r.Get("/", api.Verify) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 2754824fa..43257e8d1 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -57,25 +57,27 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error { var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") -func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { - return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { - c := req.Context() - - if limitHeader := a.config.RateLimitHeader; limitHeader != "" { - key := req.Header.Get(limitHeader) - - if key == "" { - log := observability.GetLogEntry(req).Entry - log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") - return c, nil - } else { - err := tollbooth.LimitByKeys(lmt, []string{key}) - if err != nil { - return c, tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") - } +func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error { + if limitHeader := a.config.RateLimitHeader; limitHeader != "" { + key := req.Header.Get(limitHeader) + + if key == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") + } else { + err := tollbooth.LimitByKeys(lmt, []string{key}) + if err != nil { + return tooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") } } - return c, nil + } + + return nil +} + +func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { + return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { + return req.Context(), a.performRateLimiting(lmt, req) } } @@ -137,11 +139,27 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } func isIgnoreCaptchaRoute(req *http.Request) bool { - // captcha shouldn't be enabled on the following grant_types - // id_token, refresh_token, pkce - if req.URL.Path == "/token" && req.FormValue("grant_type") != "password" { + if req.URL.Path != "/token" { + return false + } + + switch req.FormValue("grant_type") { + case "pkce": + return true + + case "refresh_token": return true + + case "id_token": + return true + + case "password": + return false + + case "web3": + return false } + return false } diff --git a/internal/api/options.go b/internal/api/options.go index 9053c2f97..13663152f 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -30,6 +30,7 @@ type LimiterOptions struct { FactorChallenge *limiter.Limiter SSO *limiter.Limiter SAMLAssertion *limiter.Limiter + Web3 *limiter.Limiter } func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } @@ -85,6 +86,11 @@ func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { DefaultExpirationTTL: time.Hour, }).SetBurst(30) + o.Web3 = tollbooth.NewLimiter(gc.RateLimitWeb3/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + // These all use the OTP limit per 5 min with 1hour ttl and burst of 30. o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp) o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp) diff --git a/internal/api/token.go b/internal/api/token.go index b67601213..65ea94609 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -80,20 +80,30 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() grantType := r.FormValue("grant_type") + handler := a.ResourceOwnerPasswordGrant + limiter := a.limiterOpts.Token + switch grantType { case "password": - return a.ResourceOwnerPasswordGrant(ctx, w, r) + // set above case "refresh_token": - return a.RefreshTokenGrant(ctx, w, r) + handler = a.RefreshTokenGrant case "id_token": - return a.IdTokenGrant(ctx, w, r) + handler = a.IdTokenGrant case "pkce": - return a.PKCE(ctx, w, r) + handler = a.PKCE case "web3": - return a.Web3Grant(ctx, w, r) + handler = a.Web3Grant + limiter = a.limiterOpts.Web3 default: return badRequestError(apierrors.ErrorCodeInvalidCredentials, "unsupported_grant_type") } + + if err := a.performRateLimiting(limiter, r); err != nil { + return err + } + + return handler(ctx, w, r) } // ResourceOwnerPasswordGrant implements the password grant type flow diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 540e28272..166bfc900 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -225,7 +225,38 @@ func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() { func (ts *TokenTestSuite) TestRateLimitTokenRefresh() { var buffer bytes.Buffer - req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + + // It rate limits after 30 requests + for i := 0; i < 30; i++ { + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + } + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It ignores X-Forwarded-For by default + req.Header.Set("X-Forwarded-For", "1.1.1.1") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It doesn't rate limit a new value for the limited header + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "5.6.7.8") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestRateLimitWeb3() { + var buffer bytes.Buffer + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) req.Header.Set("Content-Type", "application/json") req.Header.Set("My-Custom-Header", "1.2.3.4") @@ -246,7 +277,7 @@ func (ts *TokenTestSuite) TestRateLimitTokenRefresh() { assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) // It doesn't rate limit a new value for the limited header - req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=web3", &buffer) req.Header.Set("Content-Type", "application/json") req.Header.Set("My-Custom-Header", "5.6.7.8") w = httptest.NewRecorder() diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 1961a96f2..783a6552a 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -257,6 +257,7 @@ type GlobalConfiguration struct { RateLimitSso float64 `split_words:"true" default:"30"` RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` RateLimitOtp float64 `split_words:"true" default:"30"` + RateLimitWeb3 float64 `split_words:"true" default:"30"` SiteURL string `json:"site_url" split_words:"true" required:"true"` URIAllowList []string `json:"uri_allow_list" split_words:"true"` diff --git a/internal/observability/request-logger.go b/internal/observability/request-logger.go index 6eeffd6ea..2c9612a71 100644 --- a/internal/observability/request-logger.go +++ b/internal/observability/request-logger.go @@ -55,6 +55,10 @@ func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { "referer": referrer, } + if r.URL.Path == "/token" { + logFields["grant_type"] = r.FormValue("grant_type") + } + if reqID := utilities.GetRequestID(r.Context()); reqID != "" { logFields["request_id"] = reqID }