Skip to content

Extends the Function Stage to Handle Panics #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,41 @@ func Function(name string, f StageFunc) Stage {
// goStage is a `Stage` that does its work by running an arbitrary
// `stageFunc` in a goroutine.
type goStage struct {
name string
f StageFunc
done chan struct{}
err error
name string
f StageFunc
done chan struct{}
err error
panicHandler StagePanicHandler
}

func (s *goStage) Name() string {
return s.name
}

func (s *goStage) SetPanicHandler(ph StagePanicHandler) {
s.panicHandler = ph
}

func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
r, w := io.Pipe()

go func() {
s.err = s.f(ctx, env, stdin, w)
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
}
if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
defer func() {
// Cleanup resources on exit
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
}
}
close(s.done)
if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
}
close(s.done)
}()

defer s.recoverPanic()

s.err = s.f(ctx, env, stdin, w)
}()

return r, nil
Expand All @@ -64,3 +76,13 @@ func (s *goStage) Wait() error {
<-s.done
return s.err
}

func (s *goStage) recoverPanic() {
if s.panicHandler == nil {
return
}

if p := recover(); p != nil {
s.err = s.panicHandler(p)
}
}
12 changes: 12 additions & 0 deletions pipe/panic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package pipe

// StagePanicHandlerAware is an interface that Stages can implement to receive
// a panic handler from the pipeline. This is particularly useful for stages
// that execute work in a separate goroutine and need to manage panics occurring
// within that goroutine.
type StagePanicHandlerAware interface {
SetPanicHandler(StagePanicHandler)
}

// StagePanicHandler is a function that handles panics in a pipeline's stages.
type StagePanicHandler func(p any) error
19 changes: 19 additions & 0 deletions pipe/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type Pipeline struct {
started uint32

eventHandler func(e *Event)
panicHandler StagePanicHandler
}

var emptyEventHandler = func(e *Event) {}
Expand Down Expand Up @@ -179,6 +180,20 @@ func WithEventHandler(handler func(e *Event)) Option {
}
}

// WithStagePanicHandler sets a panic handler for the stages within a pipeline.
// When a pipeline stage panics, the provided handler will be invoked, allowing
// the client to handle the panic in whatever way they see fit.
//
// Note:
// - Only the Function stage supports this functionality.
// - The client is responsible for deciding whether to recover from the panic or panicking again.
// - If a panic handler is not set, the panic will be propagated normally.
func WithStagePanicHandler(ph StagePanicHandler) Option {
return func(p *Pipeline) {
p.panicHandler = ph
}
}

func (p *Pipeline) hasStarted() bool {
return atomic.LoadUint32(&p.started) != 0
}
Expand Down Expand Up @@ -265,6 +280,10 @@ func (p *Pipeline) Start(ctx context.Context) error {
}

for i, s := range p.stages {
if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil {
phs.SetPanicHandler(p.panicHandler)
}

var err error
stdout, err := s.Start(ctx, p.env, nextStdin)
if err != nil {
Expand Down
63 changes: 44 additions & 19 deletions pipe/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,28 +436,53 @@ func TestFunction(t *testing.T) {

dir := t.TempDir()

p := pipe.New(pipe.WithDir(dir))
p.Add(
pipe.Print("hello world"),
pipe.Function(
"farewell",
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
buf, err := io.ReadAll(stdin)
if err != nil {
t.Run("successful function", func(t *testing.T) {
p := pipe.New(pipe.WithDir(dir))
p.Add(
pipe.Print("hello world"),
pipe.Function(
"farewell",
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
buf, err := io.ReadAll(stdin)
if err != nil {
return err
}
if string(buf) != "hello world" {
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
}
_, err = stdout.Write([]byte("goodbye, cruel world"))
return err
}
if string(buf) != "hello world" {
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
}
_, err = stdout.Write([]byte("goodbye, cruel world"))
},
),
)

out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "goodbye, cruel world", out)
})

t.Run("panic with handler", func(t *testing.T) {
p := pipe.New(
pipe.WithDir(dir),
pipe.WithStagePanicHandler(func(p any) error {
err := fmt.Errorf("panic handled: %v", p)
return err
},
),
)
}),
)
p.Add(
pipe.Print("hello world"),
pipe.Function(
"farewell",
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
panic("this is a panic")
},
),
)

out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "goodbye, cruel world", out)
out, err := p.Output(ctx)
assert.ErrorContains(t, err, "panic handled")
assert.Empty(t, out)
})
}

func TestPipelineWithFunction(t *testing.T) {
Expand Down
Loading