Skip to content

Commit 736e4fd

Browse files
committed
add test
Signed-off-by: Ben Ye <[email protected]>
1 parent 06292ab commit 736e4fd

File tree

3 files changed

+255
-37
lines changed

3 files changed

+255
-37
lines changed

pkg/frontend/transport/handler.go

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ import (
1616

1717
"github.com/go-kit/log"
1818
"github.com/go-kit/log/level"
19-
v1 "github.com/prometheus/client_golang/api/prometheus/v1"
2019
"github.com/prometheus/client_golang/prometheus"
2120
"github.com/prometheus/client_golang/prometheus/promauto"
2221
"github.com/weaveworks/common/httpgrpc"
23-
"google.golang.org/grpc/codes"
2422
"google.golang.org/grpc/status"
2523

2624
querier_stats "github.com/cortexproject/cortex/pkg/querier/stats"
@@ -462,30 +460,7 @@ func writeError(logger log.Logger, w http.ResponseWriter, err error, additionalH
462460
headers.Set(k, value)
463461
}
464462
}
465-
resp, ok := httpgrpc.HTTPResponseFromError(err)
466-
if ok {
467-
code := int(resp.Code)
468-
var errTyp v1.ErrorType
469-
switch resp.Code {
470-
case http.StatusBadRequest, http.StatusRequestEntityTooLarge:
471-
errTyp = v1.ErrBadData
472-
case StatusClientClosedRequest:
473-
errTyp = v1.ErrCanceled
474-
case http.StatusGatewayTimeout:
475-
errTyp = v1.ErrTimeout
476-
case http.StatusUnprocessableEntity:
477-
errTyp = v1.ErrExec
478-
case int32(codes.PermissionDenied):
479-
// Convert gRPC status code to HTTP status code.
480-
code = http.StatusUnprocessableEntity
481-
errTyp = v1.ErrBadData
482-
default:
483-
errTyp = v1.ErrServer
484-
}
485-
util_api.RespondError(logger, w, errTyp, string(resp.Body), code)
486-
} else {
487-
util_api.RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError)
488-
}
463+
util_api.RespondFromGRPCError(logger, w, err)
489464
}
490465

491466
func writeServiceTimingHeader(queryResponseTime time.Duration, headers http.Header, stats *querier_stats.QueryStats) {
@@ -506,7 +481,7 @@ func statsValue(name string, d time.Duration) string {
506481
func getStatusCodeFromError(err error) int {
507482
switch err {
508483
case context.Canceled:
509-
return StatusClientClosedRequest
484+
return util_api.StatusClientClosedRequest
510485
case context.DeadlineExceeded:
511486
return http.StatusGatewayTimeout
512487
default:

pkg/util/api/response.go

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@ package api
22

33
import (
44
"encoding/json"
5-
"net/http"
6-
"unsafe"
7-
85
"github.com/go-kit/log"
96
"github.com/go-kit/log/level"
107
v1 "github.com/prometheus/client_golang/api/prometheus/v1"
8+
"github.com/weaveworks/common/httpgrpc"
9+
"google.golang.org/grpc/codes"
10+
"net/http"
11+
)
12+
13+
const (
14+
// StatusClientClosedRequest is the status code for when a client request cancellation of a http request
15+
StatusClientClosedRequest = 499
1116
)
1217

18+
// Response defines the Prometheus response format.
1319
type Response struct {
1420
Status string `json:"status"`
1521
Data interface{} `json:"data"`
@@ -18,15 +24,45 @@ type Response struct {
1824
Warnings []string `json:"warnings,omitempty"`
1925
}
2026

27+
// RespondFromGRPCError writes gRPC error in Prometheus response format.
28+
// If error is not a valid gRPC error, use server_error instead.
29+
func RespondFromGRPCError(logger log.Logger, w http.ResponseWriter, err error) {
30+
resp, ok := httpgrpc.HTTPResponseFromError(err)
31+
if ok {
32+
code := int(resp.Code)
33+
var errTyp v1.ErrorType
34+
switch resp.Code {
35+
case http.StatusBadRequest, http.StatusRequestEntityTooLarge:
36+
errTyp = v1.ErrBadData
37+
case StatusClientClosedRequest:
38+
errTyp = v1.ErrCanceled
39+
case http.StatusGatewayTimeout:
40+
errTyp = v1.ErrTimeout
41+
case http.StatusUnprocessableEntity:
42+
errTyp = v1.ErrExec
43+
case int32(codes.PermissionDenied):
44+
// Convert gRPC status code to HTTP status code.
45+
code = http.StatusUnprocessableEntity
46+
errTyp = v1.ErrBadData
47+
default:
48+
errTyp = v1.ErrServer
49+
}
50+
RespondError(logger, w, errTyp, string(resp.Body), code)
51+
} else {
52+
RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError)
53+
}
54+
}
55+
56+
// RespondError writes error in Prometheus response format using provided error type and message.
2157
func RespondError(logger log.Logger, w http.ResponseWriter, errorType v1.ErrorType, msg string, statusCode int) {
2258
var (
2359
res Response
2460
b []byte
2561
err error
2662
)
27-
b = yoloBuf(msg)
63+
b = []byte(msg)
2864
// Try to deserialize response and see if it is already in Prometheus error format.
29-
if err := json.Unmarshal(b, &res); err != nil {
65+
if err = json.Unmarshal(b, &res); err != nil {
3066
b, err = json.Marshal(&Response{
3167
Status: "error",
3268
ErrorType: errorType,
@@ -47,8 +83,3 @@ func RespondError(logger log.Logger, w http.ResponseWriter, errorType v1.ErrorTy
4783
level.Error(logger).Log("msg", "error writing response", "bytesWritten", n, "err", err)
4884
}
4985
}
50-
51-
// yoloBuf will return an unsafe pointer to a string, as the name yolo.yoloBuf implies use at your own risk.
52-
func yoloBuf(s string) []byte {
53-
return *((*[]byte)(unsafe.Pointer(&s)))
54-
}

pkg/util/api/response_test.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package api
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"github.com/weaveworks/common/httpgrpc"
7+
"google.golang.org/grpc/codes"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
13+
"github.com/go-kit/log"
14+
v1 "github.com/prometheus/client_golang/api/prometheus/v1"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestRespondFromGRPCError(t *testing.T) {
19+
logger := log.NewNopLogger()
20+
for _, tc := range []struct {
21+
name string
22+
err error
23+
expectedResp *Response
24+
code int
25+
}{
26+
{
27+
name: "non grpc error",
28+
err: errors.New("test"),
29+
expectedResp: &Response{
30+
Status: "error",
31+
ErrorType: v1.ErrServer,
32+
Error: "test",
33+
},
34+
code: 500,
35+
},
36+
{
37+
name: "bad data",
38+
err: httpgrpc.Errorf(http.StatusBadRequest, "bad_data"),
39+
expectedResp: &Response{
40+
Status: "error",
41+
ErrorType: v1.ErrBadData,
42+
Error: "bad_data",
43+
},
44+
code: http.StatusBadRequest,
45+
},
46+
{
47+
name: "413",
48+
err: httpgrpc.Errorf(http.StatusRequestEntityTooLarge, "bad_data"),
49+
expectedResp: &Response{
50+
Status: "error",
51+
ErrorType: v1.ErrBadData,
52+
Error: "bad_data",
53+
},
54+
code: http.StatusRequestEntityTooLarge,
55+
},
56+
{
57+
name: "499",
58+
err: httpgrpc.Errorf(StatusClientClosedRequest, "bad_data"),
59+
expectedResp: &Response{
60+
Status: "error",
61+
ErrorType: v1.ErrCanceled,
62+
Error: "bad_data",
63+
},
64+
code: StatusClientClosedRequest,
65+
},
66+
{
67+
name: "504",
68+
err: httpgrpc.Errorf(http.StatusGatewayTimeout, "bad_data"),
69+
expectedResp: &Response{
70+
Status: "error",
71+
ErrorType: v1.ErrTimeout,
72+
Error: "bad_data",
73+
},
74+
code: http.StatusGatewayTimeout,
75+
},
76+
{
77+
name: "422",
78+
err: httpgrpc.Errorf(http.StatusUnprocessableEntity, "bad_data"),
79+
expectedResp: &Response{
80+
Status: "error",
81+
ErrorType: v1.ErrExec,
82+
Error: "bad_data",
83+
},
84+
code: http.StatusUnprocessableEntity,
85+
},
86+
{
87+
name: "grpc status code",
88+
err: httpgrpc.Errorf(int(codes.PermissionDenied), "bad_data"),
89+
expectedResp: &Response{
90+
Status: "error",
91+
ErrorType: v1.ErrBadData,
92+
Error: "bad_data",
93+
},
94+
code: http.StatusUnprocessableEntity,
95+
},
96+
{
97+
name: "other status code defaults to err server",
98+
err: httpgrpc.Errorf(http.StatusTooManyRequests, "bad_data"),
99+
expectedResp: &Response{
100+
Status: "error",
101+
ErrorType: v1.ErrServer,
102+
Error: "bad_data",
103+
},
104+
code: http.StatusTooManyRequests,
105+
},
106+
} {
107+
t.Run(tc.name, func(t *testing.T) {
108+
writer := httptest.NewRecorder()
109+
RespondFromGRPCError(logger, writer, tc.err)
110+
output, err := io.ReadAll(writer.Body)
111+
require.NoError(t, err)
112+
var res Response
113+
err = json.Unmarshal(output, &res)
114+
require.NoError(t, err)
115+
116+
require.Equal(t, tc.expectedResp.Status, res.Status)
117+
require.Equal(t, tc.expectedResp.Error, res.Error)
118+
require.Equal(t, tc.expectedResp.ErrorType, res.ErrorType)
119+
120+
require.Equal(t, tc.code, writer.Code)
121+
})
122+
}
123+
}
124+
125+
func TestRespondError(t *testing.T) {
126+
logger := log.NewNopLogger()
127+
for _, tc := range []struct {
128+
name string
129+
errorType v1.ErrorType
130+
msg string
131+
status string
132+
code int
133+
expectedResp *Response
134+
}{
135+
{
136+
name: "bad data",
137+
errorType: v1.ErrBadData,
138+
msg: "test_msg",
139+
status: "error",
140+
code: 400,
141+
},
142+
{
143+
name: "server error",
144+
errorType: v1.ErrServer,
145+
msg: "test_msg",
146+
status: "error",
147+
code: 500,
148+
},
149+
{
150+
name: "canceled",
151+
errorType: v1.ErrCanceled,
152+
msg: "test_msg",
153+
status: "error",
154+
code: 499,
155+
},
156+
{
157+
name: "timeout",
158+
errorType: v1.ErrTimeout,
159+
msg: "test_msg",
160+
status: "error",
161+
code: 502,
162+
},
163+
{
164+
name: "prometheus_format_error",
165+
expectedResp: &Response{
166+
Status: "error",
167+
ErrorType: v1.ErrServer,
168+
Error: "server_error",
169+
},
170+
code: 400,
171+
status: "error",
172+
errorType: v1.ErrBadData,
173+
},
174+
{
175+
// If the input Prometheus error cannot be unmarshalled,
176+
// use the error type and message provided in the function.
177+
name: "bad_prometheus_format_error",
178+
msg: `"status":"error","data":null,"errorType":"bad_data","error":"bad_data"}`,
179+
code: 500,
180+
status: "error",
181+
errorType: v1.ErrServer,
182+
},
183+
} {
184+
t.Run(tc.name, func(t *testing.T) {
185+
msg := tc.msg
186+
if tc.expectedResp != nil {
187+
output, err := json.Marshal(tc.expectedResp)
188+
require.NoError(t, err)
189+
msg = string(output)
190+
}
191+
writer := httptest.NewRecorder()
192+
RespondError(logger, writer, tc.errorType, msg, tc.code)
193+
output, err := io.ReadAll(writer.Body)
194+
require.NoError(t, err)
195+
var res Response
196+
err = json.Unmarshal(output, &res)
197+
require.NoError(t, err)
198+
199+
if tc.expectedResp == nil {
200+
require.Equal(t, tc.status, res.Status)
201+
require.Equal(t, tc.msg, res.Error)
202+
require.Equal(t, tc.errorType, res.ErrorType)
203+
} else {
204+
require.Equal(t, tc.expectedResp.Status, res.Status)
205+
require.Equal(t, tc.expectedResp.Error, res.Error)
206+
require.Equal(t, tc.expectedResp.ErrorType, res.ErrorType)
207+
}
208+
209+
require.Equal(t, tc.code, writer.Code)
210+
})
211+
}
212+
}

0 commit comments

Comments
 (0)