@@ -133,6 +133,8 @@ const (
133
133
mediaTypeContentAttachmentsPreview = "application/vnd.github.corsair-preview+json"
134
134
)
135
135
136
+ var errNonNilContext = errors .New ("context must be non-nil" )
137
+
136
138
// A Client manages communication with the GitHub API.
137
139
type Client struct {
138
140
clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func.
@@ -531,7 +533,7 @@ func parseRate(r *http.Response) Rate {
531
533
// canceled or times out, ctx.Err() will be returned.
532
534
func (c * Client ) BareDo (ctx context.Context , req * http.Request ) (* Response , error ) {
533
535
if ctx == nil {
534
- return nil , errors . New ( "context must be non-nil" )
536
+ return nil , errNonNilContext
535
537
}
536
538
req = withContext (ctx , req )
537
539
@@ -654,6 +656,20 @@ func (c *Client) checkRateLimitBeforeDo(req *http.Request, rateLimitCategory rat
654
656
return nil
655
657
}
656
658
659
+ // compareHttpResponse returns whether two http.Response objects are equal or not.
660
+ // Currently, only StatusCode is checked. This function is used when implementing the
661
+ // Is(error) bool interface for the custom error types in this package.
662
+ func compareHttpResponse (r1 , r2 * http.Response ) bool {
663
+ if r1 == nil && r2 == nil {
664
+ return true
665
+ }
666
+
667
+ if r1 != nil && r2 != nil {
668
+ return r1 .StatusCode == r2 .StatusCode
669
+ }
670
+ return false
671
+ }
672
+
657
673
/*
658
674
An ErrorResponse reports one or more errors caused by an API request.
659
675
@@ -682,6 +698,50 @@ func (r *ErrorResponse) Error() string {
682
698
r .Response .StatusCode , r .Message , r .Errors )
683
699
}
684
700
701
+ // Is returns whether the provided error equals this error.
702
+ func (r * ErrorResponse ) Is (target error ) bool {
703
+ v , ok := target .(* ErrorResponse )
704
+ if ! ok {
705
+ return false
706
+ }
707
+
708
+ if r .Message != v .Message || (r .DocumentationURL != v .DocumentationURL ) ||
709
+ ! compareHttpResponse (r .Response , v .Response ) {
710
+ return false
711
+ }
712
+
713
+ // Compare Errors.
714
+ if len (r .Errors ) != len (v .Errors ) {
715
+ return false
716
+ }
717
+ for idx := range r .Errors {
718
+ if r .Errors [idx ] != v .Errors [idx ] {
719
+ return false
720
+ }
721
+ }
722
+
723
+ // Compare Block.
724
+ if (r .Block != nil && v .Block == nil ) || (r .Block == nil && v .Block != nil ) {
725
+ return false
726
+ }
727
+ if r .Block != nil && v .Block != nil {
728
+ if r .Block .Reason != v .Block .Reason {
729
+ return false
730
+ }
731
+ if (r .Block .CreatedAt != nil && v .Block .CreatedAt == nil ) || (r .Block .CreatedAt ==
732
+ nil && v .Block .CreatedAt != nil ) {
733
+ return false
734
+ }
735
+ if r .Block .CreatedAt != nil && v .Block .CreatedAt != nil {
736
+ if * (r .Block .CreatedAt ) != * (v .Block .CreatedAt ) {
737
+ return false
738
+ }
739
+ }
740
+ }
741
+
742
+ return true
743
+ }
744
+
685
745
// TwoFactorAuthError occurs when using HTTP Basic Authentication for a user
686
746
// that has two-factor authentication enabled. The request can be reattempted
687
747
// by providing a one-time password in the request.
@@ -703,6 +763,18 @@ func (r *RateLimitError) Error() string {
703
763
r .Response .StatusCode , r .Message , formatRateReset (time .Until (r .Rate .Reset .Time )))
704
764
}
705
765
766
+ // Is returns whether the provided error equals this error.
767
+ func (r * RateLimitError ) Is (target error ) bool {
768
+ v , ok := target .(* RateLimitError )
769
+ if ! ok {
770
+ return false
771
+ }
772
+
773
+ return r .Rate == v .Rate &&
774
+ r .Message == v .Message &&
775
+ compareHttpResponse (r .Response , v .Response )
776
+ }
777
+
706
778
// AcceptedError occurs when GitHub returns 202 Accepted response with an
707
779
// empty body, which means a job was scheduled on the GitHub side to process
708
780
// the information needed and cache it.
@@ -718,6 +790,15 @@ func (*AcceptedError) Error() string {
718
790
return "job scheduled on GitHub side; try again later"
719
791
}
720
792
793
+ // Is returns whether the provided error equals this error.
794
+ func (ae * AcceptedError ) Is (target error ) bool {
795
+ v , ok := target .(* AcceptedError )
796
+ if ! ok {
797
+ return false
798
+ }
799
+ return bytes .Compare (ae .Raw , v .Raw ) == 0
800
+ }
801
+
721
802
// AbuseRateLimitError occurs when GitHub returns 403 Forbidden response with the
722
803
// "documentation_url" field value equal to "https://docs.github.com/en/free-pro-team@latest/rest/reference/#abuse-rate-limits".
723
804
type AbuseRateLimitError struct {
@@ -736,6 +817,18 @@ func (r *AbuseRateLimitError) Error() string {
736
817
r .Response .StatusCode , r .Message )
737
818
}
738
819
820
+ // Is returns whether the provided error equals this error.
821
+ func (r * AbuseRateLimitError ) Is (target error ) bool {
822
+ v , ok := target .(* AbuseRateLimitError )
823
+ if ! ok {
824
+ return false
825
+ }
826
+
827
+ return r .Message == v .Message &&
828
+ r .RetryAfter == v .RetryAfter &&
829
+ compareHttpResponse (r .Response , v .Response )
830
+ }
831
+
739
832
// sanitizeURL redacts the client_secret parameter from the URL which may be
740
833
// exposed to the user.
741
834
func sanitizeURL (uri * url.URL ) * url.URL {
0 commit comments