diff --git a/pipe/function.go b/pipe/function.go index bc5d0bd..f268d7a 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -32,19 +32,44 @@ func Function(name string, f StageFunc) Stage { // goStage is a `Stage` that does its work by running an arbitrary // `stageFunc` in a goroutine. type goStage struct { - name string - f StageFunc - done chan struct{} - err error + name string + f StageFunc + done chan struct{} + err error + panicHandler StagePanicHandler } func (s *goStage) Name() string { return s.name } +func (s *goStage) SetPanicHandler(ph StagePanicHandler) { + s.panicHandler = ph +} + func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { r, w := io.Pipe() + go func() { + defer func() { + if p := recover(); p != nil { + if s.panicHandler == nil { + // Nothing to do, just panic + panic(p) + } + + _ = w.Close() + if stdin != nil { + _ = stdin.Close() + } + close(s.done) + + err := FromPanicValue(p) + s.err = err + s.panicHandler(err) + } + }() + s.err = s.f(ctx, env, stdin, w) if err := w.Close(); err != nil && s.err == nil { s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err) diff --git a/pipe/panic.go b/pipe/panic.go new file mode 100644 index 0000000..e18700d --- /dev/null +++ b/pipe/panic.go @@ -0,0 +1,27 @@ +package pipe + +import "fmt" + +// StagePanicHandlerAware is an interface that Stages can implement to receive +// a panic handler from the pipeline. This is particularly useful for stages +// that execute work in a separate goroutine and need to manage panics occurring +// within that goroutine. +type StagePanicHandlerAware interface { + SetPanicHandler(StagePanicHandler) +} + +// StagePanicHandler is a function that handles panics in a pipeline's stages. +type StagePanicHandler func(err error) + +// FromPanicValue converts a panic value to an error. If the panic value is +// already an error, it returns it directly. Otherwise, it wraps the value in +// a generic error. +func FromPanicValue(p any) error { + var err error + if e, ok := p.(error); ok { + err = e + } else { + err = fmt.Errorf("%v", p) + } + return err +} diff --git a/pipe/pipeline.go b/pipe/pipeline.go index e591c63..9df1b51 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -64,6 +64,7 @@ type Pipeline struct { started uint32 eventHandler func(e *Event) + panicHandler StagePanicHandler } var emptyEventHandler = func(e *Event) {} @@ -179,6 +180,20 @@ func WithEventHandler(handler func(e *Event)) Option { } } +// WithStagePanicHandler sets a panic handler for the stages within a pipeline. +// When a pipeline stage panics, the provided handler will be invoked, allowing +// the client to handle the panic in whatever way they see fit. +// +// Note: +// - Only the Function stage supports this functionality. +// - The client is responsible for deciding whether to recover from the panic or panicking again. +// - If a panic handler is not set, the panic will be propagated normally. +func WithStagePanicHandler(ph StagePanicHandler) Option { + return func(p *Pipeline) { + p.panicHandler = ph + } +} + func (p *Pipeline) hasStarted() bool { return atomic.LoadUint32(&p.started) != 0 } @@ -265,6 +280,10 @@ func (p *Pipeline) Start(ctx context.Context) error { } for i, s := range p.stages { + if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil { + phs.SetPanicHandler(p.panicHandler) + } + var err error stdout, err := s.Start(ctx, p.env, nextStdin) if err != nil { diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index d925aee..326df55 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -436,28 +436,54 @@ func TestFunction(t *testing.T) { dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) - p.Add( - pipe.Print("hello world"), - pipe.Function( - "farewell", - func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { - buf, err := io.ReadAll(stdin) - if err != nil { + t.Run("successful function", func(t *testing.T) { + p := pipe.New(pipe.WithDir(dir)) + p.Add( + pipe.Print("hello world"), + pipe.Function( + "farewell", + func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + buf, err := io.ReadAll(stdin) + if err != nil { + return err + } + if string(buf) != "hello world" { + return fmt.Errorf("expected \"hello world\"; got %q", string(buf)) + } + _, err = stdout.Write([]byte("goodbye, cruel world")) return err - } - if string(buf) != "hello world" { - return fmt.Errorf("expected \"hello world\"; got %q", string(buf)) - } - _, err = stdout.Write([]byte("goodbye, cruel world")) - return err - }, - ), - ) + }, + ), + ) - out, err := p.Output(ctx) - assert.NoError(t, err) - assert.EqualValues(t, "goodbye, cruel world", out) + out, err := p.Output(ctx) + assert.NoError(t, err) + assert.EqualValues(t, "goodbye, cruel world", out) + }) + + t.Run("panic with handler", func(t *testing.T) { + panickedMessage := make(chan error, 1) + p := pipe.New( + pipe.WithDir(dir), + pipe.WithStagePanicHandler(func(err error) { + panickedMessage <- err + }), + ) + p.Add( + pipe.Print("hello world"), + pipe.Function( + "farewell", + func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + panic("this is a panic") + }, + ), + ) + + out, err := p.Output(ctx) + assert.Error(t, <-panickedMessage) + assert.Error(t, err) + assert.Empty(t, out) + }) } func TestPipelineWithFunction(t *testing.T) {