From 6c6600373459536aa45e033893df0a369375911c Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Sun, 6 Nov 2022 19:53:56 -0800 Subject: [PATCH 01/11] allow handler functions that return io.Reader --- lambda/handler.go | 78 ++++++++++++++++++++++--------- lambda/handler_test.go | 12 +++-- lambda/invoke_loop.go | 16 +++++-- lambda/rpc_function_test.go | 18 +++---- lambda/runtime_api_client.go | 13 +++--- lambda/runtime_api_client_test.go | 5 +- 6 files changed, 98 insertions(+), 44 deletions(-) diff --git a/lambda/handler.go b/lambda/handler.go index 0fc82d6e..764adccb 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -8,7 +8,9 @@ import ( "encoding/json" "errors" "fmt" + "io" "reflect" + "strings" "github.com/aws/aws-lambda-go/lambda/handlertrace" ) @@ -18,7 +20,7 @@ type Handler interface { } type handlerOptions struct { - Handler + handlerFunc baseContext context.Context jsonResponseEscapeHTML bool jsonResponseIndentPrefix string @@ -168,32 +170,57 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { if h.enableSIGTERM { enableSIGTERM(h.sigtermCallbacks) } - h.Handler = reflectHandler(handlerFunc, h) + h.handlerFunc = reflectHandler(handlerFunc, h) return h } -type bytesHandlerFunc func(context.Context, []byte) ([]byte, error) +type handlerFunc func(context.Context, []byte) (io.Reader, error) -func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) { - return h(ctx, payload) +// back-compat for the rpc mode +func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) { + response, err := h(ctx, payload) + if err != nil { + return nil, err + } + b, err := io.ReadAll(response) + if err != nil { + return nil, err + } + return b, nil } -func errorHandler(err error) Handler { - return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { + +func errorHandler(err error) handlerFunc { + return func(_ context.Context, _ []byte) (io.Reader, error) { return nil, err - }) + } } -func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { - if handlerFunc == nil { +type jsonOutBuffer struct { + *bytes.Buffer +} + +func (j *jsonOutBuffer) ContentType() string { + return contentTypeJSON +} + +func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { + if f == nil { return errorHandler(errors.New("handler is nil")) } - if handler, ok := handlerFunc.(Handler); ok { - return handler + // back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped + if handler, ok := f.(Handler); ok { + return func(ctx context.Context, payload []byte) (io.Reader, error) { + b, err := handler.Invoke(ctx, payload) + if err != nil { + return nil, err + } + return bytes.NewBuffer(b), nil + } } - handler := reflect.ValueOf(handlerFunc) - handlerType := reflect.TypeOf(handlerFunc) + handler := reflect.ValueOf(f) + handlerType := reflect.TypeOf(f) if handlerType.Kind() != reflect.Func { return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func)) } @@ -207,9 +234,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { return errorHandler(err) } - return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) { + out := &jsonOutBuffer{bytes.NewBuffer(nil)} + return func(ctx context.Context, payload []byte) (io.Reader, error) { + out.Reset() in := bytes.NewBuffer(payload) - out := bytes.NewBuffer(nil) decoder := json.NewDecoder(in) encoder := json.NewEncoder(out) encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) @@ -250,16 +278,24 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler { trace.ResponseEvent(ctx, val) } } + + // encode to JSON if err := encoder.Encode(val); err != nil { return nil, err } - responseBytes := out.Bytes() + // if response value is an io.Reader, return it as-is + // back-compat, don't return the reader if the value serialized to a non-empty json + if reader, ok := val.(io.Reader); ok { + if strings.HasPrefix(out.String(), "{}") { + return reader, nil + } + } + // back-compat, strip the encoder's trailing newline unless WithSetIndent was used if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" { - return responseBytes[:len(responseBytes)-1], nil + out.Truncate(out.Len() - 1) } - - return responseBytes, nil - }) + return out, nil + } } diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 3c3c51d4..a2eb29bc 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -83,6 +83,14 @@ func TestInvalidHandlers(t *testing.T) { } } +type staticHandler struct { + body []byte +} + +func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) { + return h.body, nil +} + type expected struct { val string err error @@ -232,9 +240,7 @@ func TestInvokes(t *testing.T) { { name: "Handler interface implementations are passthrough", expected: expected{`hello`, nil}, - handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) { - return []byte(`hello`), nil - }), + handler: &staticHandler{body: []byte(`hello`)}, }, } for i, testCase := range testCases { diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index f73689ba..829c9166 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -3,9 +3,11 @@ package lambda import ( + "bytes" "context" "encoding/json" "fmt" + "io" "log" "os" "strconv" @@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke) + response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc) if invokeErr != nil { if err := reportFailure(invoke, invokeErr); err != nil { return err @@ -80,7 +82,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { } return nil } - if err := invoke.success(response, contentTypeJSON); err != nil { + + contentType := contentTypeBytes + type ContentType interface{ ContentType() string } + if ct, ok := response.(ContentType); ok { + contentType = ct.ContentType() + } + if err := invoke.success(response, contentType); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err) } @@ -90,13 +98,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error { errorPayload := safeMarshal(invokeErr) log.Printf("%s", errorPayload) - if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { + if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil { return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) } return nil } -func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) { +func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) { defer func() { if err := recover(); err != nil { invokeErr = lambdaPanicResponse(err) diff --git a/lambda/rpc_function_test.go b/lambda/rpc_function_test.go index 6935084d..bcc08332 100644 --- a/lambda/rpc_function_test.go +++ b/lambda/rpc_function_test.go @@ -9,7 +9,10 @@ import ( "context" "encoding/json" "errors" + "io" "os" + "strconv" + "strings" "testing" "time" @@ -63,14 +66,13 @@ func TestInvoke(t *testing.T) { func TestInvokeWithContext(t *testing.T) { key := struct{}{} srv := NewFunction(&handlerOptions{ - Handler: testWrapperHandler( - func(ctx context.Context, input []byte) (interface{}, error) { - assert.Equal(t, "dummy", ctx.Value(key)) - if deadline, ok := ctx.Deadline(); ok { - return deadline.UnixNano(), nil - } - return nil, errors.New("!?!?!?!?!") - }), + handlerFunc: func(ctx context.Context, _ []byte) (io.Reader, error) { + assert.Equal(t, "dummy", ctx.Value(key)) + if deadline, ok := ctx.Deadline(); ok { + return strings.NewReader(strconv.FormatInt(deadline.UnixNano(), 10)), nil + } + return nil, errors.New("!?!?!?!?!") + }, baseContext: context.WithValue(context.Background(), key, "dummy"), }) deadline := time.Now() diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index 843f7ace..a83c3ce8 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -22,6 +22,7 @@ const ( headerClientContext = "Lambda-Runtime-Client-Context" headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" contentTypeJSON = "application/json" + contentTypeBytes = "application/octet-stream" apiVersion = "2018-06-01" ) @@ -51,9 +52,9 @@ type invoke struct { // success sends the response payload for an in-progress invocation. // Notes: // - An invoke is not complete until next() is called again! -func (i *invoke) success(payload []byte, contentType string) error { +func (i *invoke) success(body io.Reader, contentType string) error { url := i.client.baseURL + i.id + "/response" - return i.client.post(url, payload, contentType) + return i.client.post(url, body, contentType) } // failure sends the payload to the Runtime API. This marks the function's invoke as a failure. @@ -61,9 +62,9 @@ func (i *invoke) success(payload []byte, contentType string) error { // - The execution of the function process continues, and is billed, until next() is called again! // - A Lambda Function continues to be re-used for future invokes even after a failure. // If the error is fatal (panic, unrecoverable state), exit the process immediately after calling failure() -func (i *invoke) failure(payload []byte, contentType string) error { +func (i *invoke) failure(body io.Reader, contentType string) error { url := i.client.baseURL + i.id + "/error" - return i.client.post(url, payload, contentType) + return i.client.post(url, body, contentType) } // next connects to the Runtime API and waits for a new invoke Request to be available. @@ -104,8 +105,8 @@ func (c *runtimeAPIClient) next() (*invoke, error) { }, nil } -func (c *runtimeAPIClient) post(url string, payload []byte, contentType string) error { - req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload)) +func (c *runtimeAPIClient) post(url string, body io.Reader, contentType string) error { + req, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return fmt.Errorf("failed to construct POST request to %s: %v", url, err) } diff --git a/lambda/runtime_api_client_test.go b/lambda/runtime_api_client_test.go index 4693048b..7ccd47fb 100644 --- a/lambda/runtime_api_client_test.go +++ b/lambda/runtime_api_client_test.go @@ -3,6 +3,7 @@ package lambda import ( + "bytes" "fmt" "io/ioutil" //nolint: staticcheck "net/http" @@ -87,11 +88,11 @@ func TestClientDoneAndError(t *testing.T) { client: client, } t.Run(fmt.Sprintf("happy Done with payload[%d]", i), func(t *testing.T) { - err := invoke.success(payload, contentTypeJSON) + err := invoke.success(bytes.NewReader(payload), contentTypeJSON) assert.NoError(t, err) }) t.Run(fmt.Sprintf("happy Error with payload[%d]", i), func(t *testing.T) { - err := invoke.failure(payload, contentTypeJSON) + err := invoke.failure(bytes.NewReader(payload), contentTypeJSON) assert.NoError(t, err) }) } From 7d5df89fc6d39d6d112cb7dd1568161b54173be7 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Wed, 16 Nov 2022 14:25:49 -0800 Subject: [PATCH 02/11] add a couple test cases for io.Reader responses --- lambda/handler_test.go | 53 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/lambda/handler_test.go b/lambda/handler_test.go index a2eb29bc..6337df3c 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -6,11 +6,15 @@ import ( "context" "errors" "fmt" + "io" + "io/ioutil" + "strings" "testing" "github.com/aws/aws-lambda-go/lambda/handlertrace" "github.com/aws/aws-lambda-go/lambda/messages" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInvalidHandlers(t *testing.T) { @@ -242,18 +246,53 @@ func TestInvokes(t *testing.T) { expected: expected{`hello`, nil}, handler: &staticHandler{body: []byte(`hello`)}, }, + { + name: "io.Reader responses are passthrough", + expected: expected{`yolo`, nil}, + handler: func() (io.Reader, error) { + return strings.NewReader(`yolo`), nil + }, + }, + { + name: "io.Reader responses that are also json serializable, using the json, ignoring the reader", + expected: expected{`{"Yolo":"yolo"}`, nil}, + handler: func() (io.Reader, error) { + return struct { + io.Reader `json:"-"` + Yolo string + }{ + Reader: strings.NewReader(`yolo`), + Yolo: "yolo", + }, nil + }, + }, } for i, testCase := range testCases { testCase := testCase t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) { lambdaHandler := newHandler(testCase.handler, testCase.options...) - response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) - if testCase.expected.err != nil { - assert.Equal(t, testCase.expected.err, err) - } else { - assert.NoError(t, err) - assert.Equal(t, testCase.expected.val, string(response)) - } + t.Run("via Handler.Invoke", func(t *testing.T) { + response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) + if testCase.expected.err != nil { + assert.Equal(t, testCase.expected.err, err) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.expected.val, string(response)) + } + }) + t.Run("via handlerOptions.handlerFunc", func(t *testing.T) { + response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input)) + if testCase.expected.err != nil { + assert.Equal(t, testCase.expected.err, err) + } else { + assert.NoError(t, err) + require.NotNil(t, response) + responseBytes, err := ioutil.ReadAll(response) + assert.NoError(t, err) + assert.Equal(t, testCase.expected.val, string(responseBytes)) + } + }) + }) } } From c9eb67446ce934f81b19a443dbdcfd7831b39ffa Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Wed, 16 Nov 2022 16:18:13 -0800 Subject: [PATCH 03/11] add some validation of content type plumbing --- lambda/invoke_loop_test.go | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 54ec96cf..fd8474fa 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -86,6 +86,7 @@ func TestCustomErrorMarshaling(t *testing.T) { assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) for i := range errors { assert.JSONEq(t, expected[i], string(record.responses[i])) + assert.Equal(t, contentTypeJSON, record.contentTypes[i]) } } @@ -156,7 +157,25 @@ func TestReadPayload(t *testing.T) { endpoint := strings.Split(ts.URL, "://")[1] _ = startRuntimeAPILoop(endpoint, handler) assert.Equal(t, `"socat gnivarc ma I"`, string(record.responses[0])) + assert.Equal(t, contentTypeJSON, record.contentTypes[0]) +} + +func TestBinaryResponseDefaultContentType(t *testing.T) { + ts, record := runtimeAPIServer(`{"message": "I am craving tacos"}`, 1) + defer ts.Close() + handler := NewHandler(func(event struct{ Message string }) (io.Reader, error) { + length := utf8.RuneCountInString(event.Message) + reversed := make([]rune, length) + for i, v := range event.Message { + reversed[length-i-1] = v + } + return strings.NewReader(string(reversed)), nil + }) + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + assert.Equal(t, `socat gnivarc ma I`, string(record.responses[0])) + assert.Equal(t, contentTypeBytes, record.contentTypes[0]) } func TestContextDeserializationErrors(t *testing.T) { @@ -209,9 +228,10 @@ func TestSafeMarshal_SerializationError(t *testing.T) { } type requestRecord struct { - nGets int - nPosts int - responses [][]byte + nGets int + nPosts int + responses [][]byte + contentTypes []string } type eventMetadata struct { @@ -276,6 +296,7 @@ func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMeta _ = r.Body.Close() w.WriteHeader(http.StatusAccepted) record.responses = append(record.responses, response.Bytes()) + record.contentTypes = append(record.contentTypes, r.Header.Get("Content-Type")) default: w.WriteHeader(http.StatusBadRequest) } From 0d32a0e791689d823bcb6d3661778561518efac6 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 17 Nov 2022 11:56:26 -0800 Subject: [PATCH 04/11] update docs --- lambda/entry.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lambda/entry.go b/lambda/entry.go index 6c1d7194..d865e472 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -35,6 +35,9 @@ import ( // // Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library. // See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves +// +// "TOut" may also implement the io.Reader interface. +// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used. func Start(handler interface{}) { StartWithOptions(handler) } From 6c1bbac35deff6f66c4d9e76c5d29582fea683fc Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 17 Nov 2022 12:12:48 -0800 Subject: [PATCH 05/11] fix linter error --- lambda/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 6337df3c..7a8c9382 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "io/ioutil" //nolint: staticcheck "strings" "testing" From 30909ad146a6b7e9dcbfd4fd34a7fda47f5e9ae5 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 17 Nov 2022 14:07:04 -0800 Subject: [PATCH 06/11] fix build on older go versions --- lambda/handler.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lambda/handler.go b/lambda/handler.go index 764adccb..88b1416c 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" // nolint:staticcheck "reflect" "strings" @@ -182,7 +183,7 @@ func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) if err != nil { return nil, err } - b, err := io.ReadAll(response) + b, err := ioutil.ReadAll(response) if err != nil { return nil, err } From 681ae3887e7de613c4be8be2c9b8c92221fa03e8 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Fri, 25 Nov 2022 15:22:02 -0800 Subject: [PATCH 07/11] cover the case where the return value is a reader and not json serializable --- lambda/handler.go | 6 +++++- lambda/handler_test.go | 39 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/lambda/handler.go b/lambda/handler.go index 88b1416c..26c1a6ed 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -282,12 +282,16 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { // encode to JSON if err := encoder.Encode(val); err != nil { + // if response is not JSON serializable, but the response type is a reader, return it as-is + if reader, ok := val.(io.Reader); ok { + return reader, nil + } return nil, err } // if response value is an io.Reader, return it as-is - // back-compat, don't return the reader if the value serialized to a non-empty json if reader, ok := val.(io.Reader); ok { + // back-compat, don't return the reader if the value serialized to a non-empty json if strings.HasPrefix(out.String(), "{}") { return reader, nil } diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 7a8c9382..cc1fac1e 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -87,6 +87,15 @@ func TestInvalidHandlers(t *testing.T) { } } +type arbitraryJSON struct { + json []byte + err error +} + +func (a arbitraryJSON) MarshalJSON() ([]byte, error) { + return a.json, a.err +} + type staticHandler struct { body []byte } @@ -254,7 +263,7 @@ func TestInvokes(t *testing.T) { }, }, { - name: "io.Reader responses that are also json serializable, using the json, ignoring the reader", + name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader", expected: expected{`{"Yolo":"yolo"}`, nil}, handler: func() (io.Reader, error) { return struct { @@ -266,6 +275,30 @@ func TestInvokes(t *testing.T) { }, nil }, }, + { + name: "types that are not json serializable result in an error", + expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")}, + handler: func() (any, error) { + return struct { + arbitraryJSON + }{ + arbitraryJSON{nil, errors.New("barf")}, + }, nil + }, + }, + { + name: "io.Reader responses that not json serializable remain passthrough", + expected: expected{`wat`, nil}, + handler: func() (io.Reader, error) { + return struct { + arbitraryJSON + io.Reader + }{ + arbitraryJSON{nil, errors.New("barf")}, + strings.NewReader("wat"), + }, nil + }, + }, } for i, testCase := range testCases { testCase := testCase @@ -274,7 +307,7 @@ func TestInvokes(t *testing.T) { t.Run("via Handler.Invoke", func(t *testing.T) { response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input)) if testCase.expected.err != nil { - assert.Equal(t, testCase.expected.err, err) + assert.EqualError(t, err, testCase.expected.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.expected.val, string(response)) @@ -283,7 +316,7 @@ func TestInvokes(t *testing.T) { t.Run("via handlerOptions.handlerFunc", func(t *testing.T) { response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input)) if testCase.expected.err != nil { - assert.Equal(t, testCase.expected.err, err) + assert.EqualError(t, err, testCase.expected.err.Error()) } else { assert.NoError(t, err) require.NotNil(t, response) From bb4df18989d5b389431ccf1ec71bc3c2dc201b80 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Fri, 25 Nov 2022 15:41:42 -0800 Subject: [PATCH 08/11] fix unit tests on older go versions --- lambda/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lambda/handler_test.go b/lambda/handler_test.go index cc1fac1e..e67eb7e7 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -278,7 +278,7 @@ func TestInvokes(t *testing.T) { { name: "types that are not json serializable result in an error", expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")}, - handler: func() (any, error) { + handler: func() (interface{}, error) { return struct { arbitraryJSON }{ From 06edcdd2c0495e699967be2443386cb8d1b8758f Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 1 Dec 2022 11:36:16 -0800 Subject: [PATCH 09/11] handle the case when the returned reader needs to be closed by the caller --- lambda/handler.go | 4 ++++ lambda/invoke_loop.go | 10 ++++++++-- lambda/invoke_loop_test.go | 39 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/lambda/handler.go b/lambda/handler.go index 26c1a6ed..90f33176 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -183,6 +183,10 @@ func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) if err != nil { return nil, err } + // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak + if response, ok := response.(io.Closer); ok { + defer response.Close() + } b, err := ioutil.ReadAll(response) if err != nil { return nil, err diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index 829c9166..9e2d6598 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -82,12 +82,18 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { } return nil } + // if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak + if response, ok := response.(io.Closer); ok { + defer response.Close() + } + // if the response defines a content-type, plumb it through contentType := contentTypeBytes type ContentType interface{ ContentType() string } - if ct, ok := response.(ContentType); ok { - contentType = ct.ContentType() + if response, ok := response.(ContentType); ok { + contentType = response.ContentType() } + if err := invoke.success(response, contentType); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err) } diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index fd8474fa..fab800b9 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -160,6 +160,20 @@ func TestReadPayload(t *testing.T) { assert.Equal(t, contentTypeJSON, record.contentTypes[0]) } +type readCloser struct { + closed bool + reader *strings.Reader +} + +func (r *readCloser) Read(p []byte) (int, error) { + return r.reader.Read(p) +} + +func (r *readCloser) Close() error { + r.closed = true + return nil +} + func TestBinaryResponseDefaultContentType(t *testing.T) { ts, record := runtimeAPIServer(`{"message": "I am craving tacos"}`, 1) defer ts.Close() @@ -178,6 +192,31 @@ func TestBinaryResponseDefaultContentType(t *testing.T) { assert.Equal(t, contentTypeBytes, record.contentTypes[0]) } +func TestBinaryResponseDoesNotLeakResources(t *testing.T) { + numResponses := 3 + responses := make([]*readCloser, numResponses) + for i := 0; i < numResponses; i++ { + responses[i] = &readCloser{closed: false, reader: strings.NewReader(fmt.Sprintf("hello %d", i))} + } + timesCalled := 0 + handler := NewHandler(func() (res io.Reader, _ error) { + res = responses[timesCalled] + timesCalled++ + return + }) + + ts, record := runtimeAPIServer(`{}`, numResponses) + defer ts.Close() + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + + for i := 0; i < numResponses; i++ { + assert.Equal(t, contentTypeBytes, record.contentTypes[i]) + assert.Equal(t, fmt.Sprintf("hello %d", i), string(record.responses[i])) + assert.True(t, responses[i].closed) + } +} + func TestContextDeserializationErrors(t *testing.T) { badClientContext := defaultInvokeMetadata() badClientContext.clientContext = `{ not json }` From 33d74d4dc6e089f52c89c99ab05adffcd159bf72 Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 22 Dec 2022 15:20:06 -0800 Subject: [PATCH 10/11] eliminate a buffer copy for handlers using rpc-mode, add extra line coverage to rpc mode logic --- lambda/handler.go | 7 +++++ lambda/handler_test.go | 20 +++++++++++--- lambda/rpc_function_test.go | 53 +++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/lambda/handler.go b/lambda/handler.go index 90f33176..ee273577 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -187,6 +187,13 @@ func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) if response, ok := response.(io.Closer); ok { defer response.Close() } + // optimization: if the response is a *bytes.Buffer, a copy can be eliminated + switch response := response.(type) { + case *jsonOutBuffer: + return response.Bytes(), nil + case *bytes.Buffer: + return response.Bytes(), nil + } b, err := ioutil.ReadAll(response) if err != nil { return nil, err diff --git a/lambda/handler_test.go b/lambda/handler_test.go index e67eb7e7..17303827 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -3,6 +3,7 @@ package lambda import ( + "bytes" "context" "errors" "fmt" @@ -127,10 +128,8 @@ func TestInvokes(t *testing.T) { }{ { input: `"Lambda"`, - expected: expected{`"Hello Lambda!"`, nil}, - handler: func(name string) (string, error) { - return hello(name), nil - }, + expected: expected{`null`, nil}, + handler: func(_ string) {}, }, { input: `"Lambda"`, @@ -139,6 +138,12 @@ func TestInvokes(t *testing.T) { return hello(name), nil }, }, + { + expected: expected{`"Hello No Value!"`, nil}, + handler: func(ctx context.Context) (string, error) { + return hello("No Value"), nil + }, + }, { input: `"Lambda"`, expected: expected{`"Hello Lambda!"`, nil}, @@ -262,6 +267,13 @@ func TestInvokes(t *testing.T) { return strings.NewReader(`yolo`), nil }, }, + { + name: "io.Reader responses that are byte buffers are passthrough", + expected: expected{`yolo`, nil}, + handler: func() (*bytes.Buffer, error) { + return bytes.NewBuffer([]byte(`yolo`)), nil + }, + }, { name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader", expected: expected{`{"Yolo":"yolo"}`, nil}, diff --git a/lambda/rpc_function_test.go b/lambda/rpc_function_test.go index bcc08332..b54b5306 100644 --- a/lambda/rpc_function_test.go +++ b/lambda/rpc_function_test.go @@ -233,3 +233,56 @@ func TestXAmznTraceID(t *testing.T) { } } + +type closeableResponse struct { + reader io.Reader + closed bool +} + +func (c *closeableResponse) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func (c *closeableResponse) Close() error { + c.closed = true + return nil +} + +type readerError struct { + err error +} + +func (r *readerError) Read(_ []byte) (int, error) { + return 0, r.err +} + +func TestRPCModeInvokeClosesCloserIfResponseIsCloser(t *testing.T) { + handlerResource := &closeableResponse{ + reader: strings.NewReader(""), + closed: false, + } + srv := NewFunction(newHandler(func() (any, error) { + return handlerResource, nil + })) + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{}, &response) + require.NoError(t, err) + assert.Equal(t, "", string(response.Payload)) + assert.True(t, handlerResource.closed) +} + +func TestRPCModeInvokeReaderErrorPropogated(t *testing.T) { + handlerResource := &closeableResponse{ + reader: &readerError{errors.New("yolo")}, + closed: false, + } + srv := NewFunction(newHandler(func() (any, error) { + return handlerResource, nil + })) + var response messages.InvokeResponse + err := srv.Invoke(&messages.InvokeRequest{}, &response) + require.NoError(t, err) + assert.Equal(t, "", string(response.Payload)) + assert.Equal(t, "yolo", response.Error.Message) + assert.True(t, handlerResource.closed) +} From a6eae083c071d632d48852c7f99f7a7fc585088d Mon Sep 17 00:00:00 2001 From: Bryan Moffatt Date: Thu, 22 Dec 2022 15:24:52 -0800 Subject: [PATCH 11/11] any -> interface{} --- lambda/rpc_function_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lambda/rpc_function_test.go b/lambda/rpc_function_test.go index b54b5306..515fc62a 100644 --- a/lambda/rpc_function_test.go +++ b/lambda/rpc_function_test.go @@ -261,7 +261,7 @@ func TestRPCModeInvokeClosesCloserIfResponseIsCloser(t *testing.T) { reader: strings.NewReader(""), closed: false, } - srv := NewFunction(newHandler(func() (any, error) { + srv := NewFunction(newHandler(func() (interface{}, error) { return handlerResource, nil })) var response messages.InvokeResponse @@ -276,7 +276,7 @@ func TestRPCModeInvokeReaderErrorPropogated(t *testing.T) { reader: &readerError{errors.New("yolo")}, closed: false, } - srv := NewFunction(newHandler(func() (any, error) { + srv := NewFunction(newHandler(func() (interface{}, error) { return handlerResource, nil })) var response messages.InvokeResponse