diff --git a/lambda/handler.go b/lambda/handler.go index e4cfaf7a..088ad5a5 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -22,12 +22,12 @@ type Handler interface { type handlerOptions struct { handlerFunc - baseContext context.Context - jsonResponseEscapeHTML bool - jsonResponseIndentPrefix string - jsonResponseIndentValue string - enableSIGTERM bool - sigtermCallbacks []func() + baseContext context.Context + jsonEncoderOptions []func(encoder *json.Encoder) + jsonDecoderOptions []func(decoder *json.Decoder) + setIndentUsed bool + enableSIGTERM bool + sigtermCallbacks []func() } type Option func(*handlerOptions) @@ -48,6 +48,44 @@ func WithContext(ctx context.Context) Option { }) } +// WithJSONEncoderOption allows setting arbitrary options on the underlying json encoder +// +// Usage: +// +// lambda.StartWithOptions( +// func () (string, error) { +// return "hello!>", nil +// }, +// lambda.WithJSONEncoderOption(func(encoder *json.Encoder) { +// encoder.SetEscapeHTML(true) +// encoder.SetIndent(">", " ") +// }), +// ) +func WithJSONEncoderOption(opt func(encoder *json.Encoder)) Option { + return Option(func(h *handlerOptions) { + h.jsonEncoderOptions = append(h.jsonEncoderOptions, opt) + }) +} + +// WithJSONDecoderOption allows setting arbitrary options on the underlying json decoder +// +// Usage: +// +// lambda.StartWithOptions( +// func (event any) (any, error) { +// return event, nil +// }, +// lambda.WithJSONEncoderOption(func(decoder *json.Decoder) { +// decoder.UseNumber() +// decoder.DisallowUnknownFields() +// }), +// ) +func WithJSONDecoderOption(opt func(decoder *json.Decoder)) Option { + return Option(func(h *handlerOptions) { + h.jsonDecoderOptions = append(h.jsonDecoderOptions, opt) + }) +} + // WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder // // Usage: @@ -59,8 +97,8 @@ func WithContext(ctx context.Context) Option { // lambda.WithSetEscapeHTML(true), // ) func WithSetEscapeHTML(escapeHTML bool) Option { - return Option(func(h *handlerOptions) { - h.jsonResponseEscapeHTML = escapeHTML + return WithJSONEncoderOption(func(encoder *json.Encoder) { + encoder.SetEscapeHTML(escapeHTML) }) } @@ -76,8 +114,13 @@ func WithSetEscapeHTML(escapeHTML bool) Option { // ) func WithSetIndent(prefix, indent string) Option { return Option(func(h *handlerOptions) { - h.jsonResponseIndentPrefix = prefix - h.jsonResponseIndentValue = indent + // back-compat, the encoder's trailing newline is stripped unless WithSetIndent was used + if prefix != "" || indent != "" { + h.setIndentUsed = true + } + h.jsonEncoderOptions = append(h.jsonEncoderOptions, func(encoder *json.Encoder) { + encoder.SetIndent(prefix, indent) + }) }) } @@ -176,10 +219,7 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { return h } h := &handlerOptions{ - baseContext: context.Background(), - jsonResponseEscapeHTML: false, - jsonResponseIndentPrefix: "", - jsonResponseIndentValue: "", + baseContext: context.Background(), } for _, option := range options { option(h) @@ -267,9 +307,13 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { out.Reset() in := bytes.NewBuffer(payload) decoder := json.NewDecoder(in) + for _, opt := range h.jsonDecoderOptions { + opt(decoder) + } encoder := json.NewEncoder(out) - encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) - encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue) + for _, opt := range h.jsonEncoderOptions { + opt(encoder) + } trace := handlertrace.FromContext(ctx) @@ -325,7 +369,7 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { } // back-compat, strip the encoder's trailing newline unless WithSetIndent was used - if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" { + if !h.setIndentUsed { out.Truncate(out.Len() - 1) } return out, nil diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 87942900..5330db7f 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -5,6 +5,7 @@ package lambda import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -285,6 +286,50 @@ func TestInvokes(t *testing.T) { return nil, messages.InvokeResponse_Error{Message: "message", Type: "type"} }, }, + { + name: "WithJSONDecoderOption(func(decoder *json.Decoder) { decoder.UseNumber() })", + input: `{ "i": 1 }`, + expected: expected{`null`, nil}, + handler: func(in struct { + I interface{} `json:"i"` + }) error { + if _, ok := in.I.(json.Number); !ok { + return fmt.Errorf("`i` was not of type json.Number: %T", in.I) + } + return nil + }, + options: []Option{WithJSONDecoderOption(func(decoder *json.Decoder) { decoder.UseNumber() })}, + }, + { + name: "WithJSONDecoderOption(func(decoder *json.Decoder) {})", + input: `{ "i": 1 }`, + expected: expected{`null`, nil}, + handler: func(in struct { + I interface{} `json:"i"` + }) error { + if _, ok := in.I.(float64); !ok { + return fmt.Errorf("`i` was not of type float64: %T", in.I) + } + return nil + }, + options: []Option{WithJSONDecoderOption(func(decoder *json.Decoder) {})}, + }, + { + name: "WithJSONEncoderOption(func(encoder *json.Encoder) { encoder.SetEscapeHTML(false) })", + expected: expected{`"html in json string!"`, nil}, + handler: func() (string, error) { + return "html in json string!", nil + }, + options: []Option{WithJSONEncoderOption(func(encoder *json.Encoder) { encoder.SetEscapeHTML(false) })}, + }, + { + name: "WithJSONEncoderOption(func(encoder *json.Encoder) { encoder.SetEscapeHTML(true) })", + expected: expected{`"\u003chtml\u003e\u003cbody\u003ehtml in json string!\u003c/body\u003e\u003c/html\u003e"`, nil}, + handler: func() (string, error) { + return "html in json string!", nil + }, + options: []Option{WithJSONEncoderOption(func(encoder *json.Encoder) { encoder.SetEscapeHTML(true) })}, + }, { name: "WithSetEscapeHTML(false)", expected: expected{`"html in json string!"`, nil},