Skip to content

Commit 26aa364

Browse files
authored
Support setting an optional base context for functions. (#287)
* Support setting an optional base context for functions. * refactor: remove exported context related Function methods
1 parent c67fade commit 26aa364

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

lambda/entry.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package lambda
44

55
import (
6+
"context"
67
"log"
78
"net"
89
"net/rpc"
@@ -37,8 +38,12 @@ import (
3738
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
3839
// See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves
3940
func Start(handler interface{}) {
40-
wrappedHandler := NewHandler(handler)
41-
StartHandler(wrappedHandler)
41+
StartWithContext(context.Background(), handler)
42+
}
43+
44+
// StartWithContext is the same as Start except sets the base context for the function.
45+
func StartWithContext(ctx context.Context, handler interface{}) {
46+
StartHandlerWithContext(ctx, NewHandler(handler))
4247
}
4348

4449
// StartHandler takes in a Handler wrapper interface which can be implemented either by a
@@ -48,15 +53,26 @@ func Start(handler interface{}) {
4853
//
4954
// func Invoke(context.Context, []byte) ([]byte, error)
5055
func StartHandler(handler Handler) {
56+
StartHandlerWithContext(context.Background(), handler)
57+
}
58+
59+
// StartHandlerWithContext is the same as StartHandler except sets the base context for the function.
60+
//
61+
// Handler implementation requires a single "Invoke()" function:
62+
//
63+
// func Invoke(context.Context, []byte) ([]byte, error)
64+
func StartHandlerWithContext(ctx context.Context, handler Handler) {
5165
port := os.Getenv("_LAMBDA_SERVER_PORT")
5266
lis, err := net.Listen("tcp", "localhost:"+port)
5367
if err != nil {
5468
log.Fatal(err)
5569
}
56-
err = rpc.Register(NewFunction(handler))
57-
if err != nil {
70+
71+
fn := NewFunction(handler).withContext(ctx)
72+
if err := rpc.Register(fn); err != nil {
5873
log.Fatal("failed to register handler function")
5974
}
75+
6076
rpc.Accept(lis)
6177
log.Fatal("accept should not have returned")
6278
}

lambda/function.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
// Function struct which wrap the Handler
1717
type Function struct {
1818
handler Handler
19+
ctx context.Context
1920
}
2021

2122
// NewFunction which creates a Function with a given Handler
@@ -44,7 +45,7 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok
4445
}()
4546

4647
deadline := time.Unix(req.Deadline.Seconds, req.Deadline.Nanos).UTC()
47-
invokeContext, cancel := context.WithDeadline(context.Background(), deadline)
48+
invokeContext, cancel := context.WithDeadline(fn.context(), deadline)
4849
defer cancel()
4950

5051
lc := &lambdacontext.LambdaContext{
@@ -75,6 +76,30 @@ func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.Invok
7576
return nil
7677
}
7778

79+
// context returns the base context used for the fn.
80+
func (fn *Function) context() context.Context {
81+
if fn.ctx == nil {
82+
return context.Background()
83+
}
84+
85+
return fn.ctx
86+
}
87+
88+
// withContext returns a shallow copy of Function with its context changed
89+
// to the provided ctx. If the provided ctx is non-nil a Background context is set.
90+
func (fn *Function) withContext(ctx context.Context) *Function {
91+
if ctx == nil {
92+
ctx = context.Background()
93+
}
94+
95+
fn2 := new(Function)
96+
*fn2 = *fn
97+
98+
fn2.ctx = ctx
99+
100+
return fn2
101+
}
102+
78103
func getErrorType(err interface{}) string {
79104
errorType := reflect.TypeOf(err)
80105
if errorType.Kind() == reflect.Ptr {

lambda/function_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,30 @@ func TestInvoke(t *testing.T) {
5858
assert.Equal(t, deadline.UnixNano(), responseValue)
5959
}
6060

61+
func TestInvokeWithContext(t *testing.T) {
62+
key := struct{}{}
63+
srv := NewFunction(testWrapperHandler(
64+
func(ctx context.Context, input []byte) (interface{}, error) {
65+
assert.Equal(t, "dummy", ctx.Value(key))
66+
if deadline, ok := ctx.Deadline(); ok {
67+
return deadline.UnixNano(), nil
68+
}
69+
return nil, errors.New("!?!?!?!?!")
70+
}))
71+
srv = srv.withContext(context.WithValue(context.Background(), key, "dummy"))
72+
deadline := time.Now()
73+
var response messages.InvokeResponse
74+
err := srv.Invoke(&messages.InvokeRequest{
75+
Deadline: messages.InvokeRequest_Timestamp{
76+
Seconds: deadline.Unix(),
77+
Nanos: int64(deadline.Nanosecond()),
78+
}}, &response)
79+
assert.NoError(t, err)
80+
var responseValue int64
81+
assert.NoError(t, json.Unmarshal(response.Payload, &responseValue))
82+
assert.Equal(t, deadline.UnixNano(), responseValue)
83+
}
84+
6185
type CustomError struct{}
6286

6387
func (e CustomError) Error() string { return fmt.Sprintf("Something bad happened!") }

0 commit comments

Comments
 (0)