diff --git a/github/github.go b/github/github.go index fced10769dd..e59952e3578 100644 --- a/github/github.go +++ b/github/github.go @@ -77,6 +77,13 @@ type Client struct { Repositories *RepositoriesService Search *SearchService Users *UsersService + + // If true, calls to Do will only result in an API request if there is + // deemed to be room under the rate limit that the request would + // succeed. If there is no room under the rate limit, requests will + // block until there is quota. + ObeyRateLimit bool + wait func(time.Duration) // Called when rate-limit waiting happens } // ListOptions specifies the optional parameters to various List methods that @@ -138,6 +145,7 @@ func NewClient(httpClient *http.Client) *Client { c.Repositories = &RepositoriesService{client: c} c.Search = &SearchService{client: c} c.Users = &UsersService{client: c} + c.wait = time.Sleep return c } @@ -292,6 +300,10 @@ func (r *Response) populateRate() { // interface, the raw response body will be written to v, without attempting to // first decode it. func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { + if err := c.maybeWaitForRateLimit(req); err != nil { + return nil, err + } + resp, err := c.client.Do(req) if err != nil { return nil, err @@ -320,6 +332,41 @@ func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { return response, err } +func (c *Client) maybeWaitForRateLimit(req *http.Request) error { + // Don't wait to make requests for the current rate limit. + if req.URL.Path == "/rate_limit" { + return nil + } + if c.wait == nil { + c.wait = time.Sleep + } + + if c.ObeyRateLimit { + if c.Rate.Limit == 0 { + // Don't know our rate limits yet, check now + if _, _, err := c.RateLimits(); err != nil { + return err + } + } + // TODO(imjasonh): Keep track of the last time we checked our + // rate limits and refresh now if they're possibly out-of-date. + + // While we have <5% of our remaining requests left, wait to execute. + threshold := float64(c.Rate.Limit) * .05 + tokensPerSecond := float64(c.Rate.Limit) / 60 / 60 + for float64(c.Rate.Remaining) < threshold { + tokensNeeded := threshold - float64(c.Rate.Remaining) + waitDur := time.Duration(tokensNeeded*tokensPerSecond*1000) * time.Millisecond + c.wait(waitDur) + _, _, err := c.RateLimits() + if err != nil { + return err + } + } + } + return nil +} + /* An ErrorResponse reports one or more errors caused by an API request. diff --git a/github/ratelimit_test.go b/github/ratelimit_test.go new file mode 100644 index 00000000000..036ba3fd149 --- /dev/null +++ b/github/ratelimit_test.go @@ -0,0 +1,71 @@ +package github + +import ( + "fmt" + "io" + "net/http" + "testing" + "time" +) + +func TestObeyRateLimit(t *testing.T) { + setup() + defer teardown() + + limit := 60 + s := &rateLimitServer{limit, limit, time.Now()} + mux.Handle("/", s) + + client.wait = func(d time.Duration) { + t.Logf("waiting %s", d) + s.Remaining++ + } + + client.ObeyRateLimit = true + for i := 0; i < limit*2; i++ { + t.Logf("-- request %d", i) + if _, _, err := client.Octocat("foo"); err != nil { + t.Errorf("unexpected error: %v", err) + break + } + } + + // Turn off automatic rate limit enforcement and make too many requests + // to demonstrate that the alternative is over-quota errors. + client.ObeyRateLimit = false + var err error + var resp *Response + for i := 0; i < limit*2; i++ { + _, resp, err = client.Octocat("foo") + if err != nil { + break + } + } + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected to be rate limited") + } +} + +type rateLimitServer struct { + Limit, Remaining int + Reset time.Time +} + +func (rl *rateLimitServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set(headerRateLimit, fmt.Sprintf("%d", rl.Limit)) + w.Header().Set(headerRateRemaining, fmt.Sprintf("%d", rl.Remaining)) + + tokensNeeded := rl.Limit - rl.Remaining + tokensPerSecond := rl.Limit / 60 / 60 + timeUntilReset := time.Duration(tokensNeeded*tokensPerSecond) * time.Second + rl.Reset = time.Now().Add(timeUntilReset) + w.Header().Set(headerRateReset, fmt.Sprintf("%d", rl.Reset)) + + if r.URL.Path != "/rate_limit" { + if rl.Remaining <= 0 { + http.Error(w, "Over quota", http.StatusForbidden) + } + rl.Remaining-- + } + io.WriteString(w, "{}") +}