diff --git a/Makefile b/Makefile index 6cb365d..731839d 100644 --- a/Makefile +++ b/Makefile @@ -16,14 +16,13 @@ vet: go vet ./... BIN := $(CURDIR)/bin -GO := GO $(BIN): - mkdir -p $(BIN) + mkdir -p $(BIN) # Run golang-ci lint on all source files: GOLANGCILINT := $(BIN)/golangci-lint $(BIN)/golangci-lint: - GOBIN=$(BIN) $(GO) install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + GOBIN=$(BIN) go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest .PHONY: fmt lint: | $(GOLANGCILINT) diff --git a/pipe/command.go b/pipe/command.go index 2c465e9..73d94c1 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -15,12 +15,17 @@ import ( "golang.org/x/sync/errgroup" ) -// commandStage is a pipeline `Stage` based on running an external -// command and piping the data through its stdin and stdout. +// commandStage is a pipeline `Stage` based on running an +// external command and piping the data through its stdin and stdout. +// It also implements `StageWithIO`. type commandStage struct { - name string - stdin io.Closer - cmd *exec.Cmd + name string + cmd *exec.Cmd + + // lateClosers is a list of things that have to be closed once the + // command has finished. + lateClosers []io.Closer + done chan struct{} wg errgroup.Group stderr bytes.Buffer @@ -30,11 +35,15 @@ type commandStage struct { ctxErr atomic.Value } -// Command returns a pipeline `Stage` based on the specified external -// `command`, run with the given command-line `args`. Its stdin and -// stdout are handled as usual, and its stderr is collected and -// included in any `*exec.ExitError` that the command might emit. -func Command(command string, args ...string) Stage { +var ( + _ StageWithIO = (*commandStage)(nil) +) + +// Command returns a pipeline `StageWithIO` based on the specified +// external `command`, run with the given command-line `args`. 
Its +// stdin and stdout are handled as usual, and its stderr is collected +// and included in any `*exec.ExitError` that the command might emit. +func Command(command string, args ...string) StageWithIO { if len(command) == 0 { panic("attempt to create command with empty command") } @@ -47,7 +56,7 @@ func Command(command string, args ...string) Stage { // the specified `cmd`. Its stdin and stdout are handled as usual, and // its stderr is collected and included in any `*exec.ExitError` that // the command might emit. -func CommandStage(name string, cmd *exec.Cmd) Stage { +func CommandStage(name string, cmd *exec.Cmd) StageWithIO { return &commandStage{ name: name, cmd: cmd, @@ -62,30 +71,101 @@ func (s *commandStage) Name() string { func (s *commandStage) Start( ctx context.Context, env Env, stdin io.ReadCloser, ) (io.ReadCloser, error) { + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + + if err := s.StartWithIO(ctx, env, stdin, pw); err != nil { + _ = pr.Close() + _ = pw.Close() + return nil, err + } + + // Now close our copy of the write end of the pipe (the subprocess + // has its own copy now and will keep it open as long as it is + // running). There's not much we can do now in the case of an + // error, so just ignore them. + _ = pw.Close() + + // The caller is responsible for closing `pr`. 
+ return pr, nil +} + +func (s *commandStage) Preferences() StagePreferences { + prefs := StagePreferences{ + StdinPreference: IOPreferenceFile, + StdoutPreference: IOPreferenceFile, + } + if s.cmd.Stdin != nil { + prefs.StdinPreference = IOPreferenceNil + } + if s.cmd.Stdout != nil { + prefs.StdoutPreference = IOPreferenceNil + } + + return prefs +} + +func (s *commandStage) StartWithIO( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { if s.cmd.Dir == "" { s.cmd.Dir = env.Dir } s.setupEnv(ctx, env) + // Things that have to be closed as soon as the command has + // started: + var earlyClosers []io.Closer + + // See the type comment for `Stage` and the long comment in + // `Pipeline.WithStdin()` for the explanation of this unwrapping + // and closing behavior. + if stdin != nil { - // See the long comment in `Pipeline.Start()` for the - // explanation of this special case. switch stdin := stdin.(type) { - case nopCloser: + case readerNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: s.cmd.Stdin = stdin.Reader - case nopCloserWriterTo: + case readerWriterToNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: s.cmd.Stdin = stdin.Reader + case *os.File: + // In this case, we can close stdin as soon as the command + // has started: + s.cmd.Stdin = stdin + earlyClosers = append(earlyClosers, stdin) default: + // In this case, we need to close `stdin`, but we should + // only do so after the command has finished: s.cmd.Stdin = stdin + s.lateClosers = append(s.lateClosers, stdin) } - // Also keep a copy so that we can close it when the command exits: - s.stdin = stdin } - stdout, err := s.cmd.StdoutPipe() - if err != nil { - return nil, err + if stdout != nil { + // See the long comment in `Pipeline.Start()` for the + // explanation of this special case. 
+ switch stdout := stdout.(type) { + case writerNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: + s.cmd.Stdout = stdout.Writer + case *os.File: + // In this case, we can close stdout as soon as the command + // has started: + s.cmd.Stdout = stdout + earlyClosers = append(earlyClosers, stdout) + default: + // In this case, we need to close `stdout`, but we should + // only do so after the command has finished: + s.cmd.Stdout = stdout + s.lateClosers = append(s.lateClosers, stdout) + } } // If the caller hasn't arranged otherwise, read the command's @@ -97,7 +177,7 @@ func (s *commandStage) Start( // can be sure. p, err := s.cmd.StderrPipe() if err != nil { - return nil, err + return err } s.wg.Go(func() error { _, err := io.Copy(&s.stderr, p) @@ -114,7 +194,11 @@ func (s *commandStage) Start( s.runInOwnProcessGroup() if err := s.cmd.Start(); err != nil { - return nil, err + return err + } + + for _, closer := range earlyClosers { + _ = closer.Close() } // Arrange for the process to be killed (gently) if the context @@ -128,7 +212,7 @@ func (s *commandStage) Start( } }() - return stdout, nil + return nil } // setupEnv sets or modifies the environment that will be passed to @@ -217,19 +301,18 @@ func (s *commandStage) Wait() error { // Make sure that any stderr is copied before `s.cmd.Wait()` // closes the read end of the pipe: - wErr := s.wg.Wait() + wgErr := s.wg.Wait() err := s.cmd.Wait() err = s.filterCmdError(err) - if err == nil && wErr != nil { - err = wErr + if err == nil && wgErr != nil { + err = wgErr } - if s.stdin != nil { - cErr := s.stdin.Close() - if cErr != nil && err == nil { - return cErr + for _, closer := range s.lateClosers { + if closeErr := closer.Close(); closeErr != nil && err == nil { + err = closeErr } } diff --git a/pipe/command_linux.go b/pipe/command_linux.go index c32ebc7..69cd82e 100644 --- a/pipe/command_linux.go +++ b/pipe/command_linux.go @@ -10,7 +10,7 @@ import ( ) // On linux, we can 
limit or observe memory usage in command stages. -var _ LimitableStage = (*commandStage)(nil) +var _ LimitableStageWithIO = (*commandStage)(nil) var ( errProcessInfoMissing = errors.New("cmd.Process is nil") diff --git a/pipe/command_test.go b/pipe/command_test.go index 92fd37a..67cd55e 100644 --- a/pipe/command_test.go +++ b/pipe/command_test.go @@ -79,7 +79,8 @@ func TestCopyEnvWithOverride(t *testing.T) { ex := ex t.Run(ex.label, func(t *testing.T) { assert.ElementsMatch(t, ex.expectedResult, - copyEnvWithOverrides(ex.env, ex.overrides)) + copyEnvWithOverrides(ex.env, ex.overrides), + ) }) } } diff --git a/pipe/export_test.go b/pipe/export_test.go new file mode 100644 index 0000000..2812292 --- /dev/null +++ b/pipe/export_test.go @@ -0,0 +1,4 @@ +package pipe + +// This file exports functions to be used only for testing. +var UnwrapNopCloser = unwrapNopCloser diff --git a/pipe/filter-error.go b/pipe/filter-error.go index 654796a..fa7a3f7 100644 --- a/pipe/filter-error.go +++ b/pipe/filter-error.go @@ -14,6 +14,9 @@ import ( type ErrorFilter func(err error) error func FilterError(s Stage, filter ErrorFilter) Stage { + if s, ok := s.(StageWithIO); ok { + return efStageWithIO{StageWithIO: s, filter: filter} + } return efStage{Stage: s, filter: filter} } @@ -26,6 +29,15 @@ func (s efStage) Wait() error { return s.filter(s.Stage.Wait()) } +type efStageWithIO struct { + StageWithIO + filter ErrorFilter +} + +func (s efStageWithIO) Wait() error { + return s.filter(s.StageWithIO.Wait()) +} + // ErrorMatcher decides whether its argument matches some class of // errors (e.g., errors that we want to ignore). The function will // only be invoked for non-nil errors. diff --git a/pipe/function.go b/pipe/function.go index bc5d0bd..627c036 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -9,7 +9,7 @@ import ( // StageFunc is a function that can be used to power a `goStage`. It // should read its input from `stdin` and write its output to // `stdout`. 
`stdin` and `stdout` will be closed automatically (if -// necessary) once the function returns. +// non-nil) once the function returns. // // Neither `stdin` nor `stdout` are necessarily buffered. If the // `StageFunc` requires buffering, it needs to arrange that itself. @@ -38,26 +38,65 @@ type goStage struct { err error } +var ( + _ StageWithIO = (*goStage)(nil) +) + func (s *goStage) Name() string { return s.name } +func (s *goStage) Preferences() StagePreferences { + return StagePreferences{ + StdinPreference: IOPreferenceUndefined, + StdoutPreference: IOPreferenceUndefined, + } +} + func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - r, w := io.Pipe() + pr, pw := io.Pipe() + + if err := s.StartWithIO(ctx, env, stdin, pw); err != nil { + _ = pr.Close() + _ = pw.Close() + return nil, err + } + + return pr, nil +} + +func (s *goStage) StartWithIO( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + var r io.Reader = stdin + if stdin, ok := stdin.(readerNopCloser); ok { + r = stdin.Reader + } + + var w io.Writer = stdout + if stdout, ok := stdout.(writerNopCloser); ok { + w = stdout.Writer + } + go func() { - s.err = s.f(ctx, env, stdin, w) - if err := w.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err) + s.err = s.f(ctx, env, r, w) + + if stdout != nil { + if err := stdout.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) + } } + if stdin != nil { if err := stdin.Close(); err != nil && s.err == nil { s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) } } + close(s.done) }() - return r, nil + return nil } func (s *goStage) Wait() error { diff --git a/pipe/iocopier.go b/pipe/iocopier.go index 78a9143..d48df78 100644 --- a/pipe/iocopier.go +++ b/pipe/iocopier.go @@ -7,8 +7,8 @@ import ( "os" ) -// ioCopier is a stage that 
copies its stdin to a specified -// `io.Writer`. It generates no stdout itself. +// ioCopier is a stage that copies its stdin to `w` then closes it. It +// generates no stdout itself. type ioCopier struct { w io.WriteCloser done chan struct{} @@ -56,6 +56,51 @@ func (s *ioCopier) Start(_ context.Context, _ Env, r io.ReadCloser) (io.ReadClos return nil, nil } +func (s *ioCopier) Preferences() StagePreferences { + return StagePreferences{ + StdinPreference: IOPreferenceUndefined, + StdoutPreference: IOPreferenceNil, + } +} + +// This method always returns `nil`. +func (s *ioCopier) StartWithIO( + _ context.Context, _ Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + if stdout != nil { + // We won't write anything to the supplied stdout, so if for + // some reason it wasn't nil, close it immediately: + _ = stdout.Close() + } + + go func() { + _, err := io.Copy(s.w, stdin) + // We don't consider `ErrClosed` an error (FIXME: is this + // correct?): + if err != nil && !errors.Is(err, os.ErrClosed) { + s.err = err + } + if err := stdin.Close(); err != nil && s.err == nil { + s.err = err + } + if err := s.w.Close(); err != nil && s.err == nil { + s.err = err + } + close(s.done) + }() + + // FIXME: if `s.w.Write()` is blocking (e.g., because there is a + // downstream process that is not reading from the other side), + // there's no way to terminate the copy when the context expires. + // This is not too bad, because the `io.Copy()` call will exit by + // itself when its input is closed. + // + // We could, however, be smarter about exiting more quickly if the + // context expires but `s.w.Write()` is not blocking. 
+ + return nil +} + func (s *ioCopier) Wait() error { <-s.done return s.err diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index f21ee15..7be7fe1 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -11,14 +11,14 @@ import ( const memoryPollInterval = time.Second -// ErrMemoryLimitExceeded is the error that will be used to kill a process, if -// necessary, from MemoryLimit. +// ErrMemoryLimitExceeded is the error that will be used to kill a +// process, if necessary, from MemoryLimit. var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") -// LimitableStage is the superset of Stage that must be implemented by stages -// passed to MemoryLimit and MemoryObserver. -type LimitableStage interface { - Stage +// LimitableStageWithIO is the superset of StageWithIO that must be +// implemented by stages passed to MemoryLimit and MemoryObserver. +type LimitableStageWithIO interface { + StageWithIO GetRSSAnon(context.Context) (uint64, error) Kill(error) @@ -26,9 +26,9 @@ type LimitableStage interface { // MemoryLimit watches the memory usage of the stage and stops it if it // exceeds the given limit. 
-func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { +func MemoryLimit(stage StageWithIO, byteLimit uint64, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) + limitableStage, ok := stage.(LimitableStageWithIO) if !ok { eventHandler(&Event{ Command: stage.Name(), @@ -46,7 +46,7 @@ func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Sta } func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { + return func(ctx context.Context, stage LimitableStageWithIO) { var consecutiveErrors int t := time.NewTicker(memoryPollInterval) @@ -91,8 +91,8 @@ func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc // MemoryObserver watches memory use of the stage and logs the maximum // value when the stage exits. -func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) +func MemoryObserver(stage StageWithIO, eventHandler func(e *Event)) Stage { + limitableStage, ok := stage.(LimitableStageWithIO) if !ok { eventHandler(&Event{ Command: stage.Name(), @@ -110,7 +110,7 @@ func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage { func logMaxRSS(eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { + return func(ctx context.Context, stage LimitableStageWithIO) { var ( maxRSS uint64 samples, errors, consecutiveErrors int @@ -161,26 +161,51 @@ func logMaxRSS(eventHandler func(e *Event)) memoryWatchFunc { type memoryWatchStage struct { nameSuffix string - stage LimitableStage + stage LimitableStageWithIO watch memoryWatchFunc cancel context.CancelFunc wg sync.WaitGroup } -type memoryWatchFunc func(context.Context, LimitableStage) +type memoryWatchFunc func(context.Context, LimitableStageWithIO) -var _ LimitableStage = (*memoryWatchStage)(nil) +var _ LimitableStageWithIO = 
(*memoryWatchStage)(nil) func (m *memoryWatchStage) Name() string { return m.stage.Name() + m.nameSuffix } -func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { +func (m *memoryWatchStage) Start( + ctx context.Context, env Env, stdin io.ReadCloser, +) (io.ReadCloser, error) { io, err := m.stage.Start(ctx, env, stdin) if err != nil { return nil, err } + m.monitor(ctx) + + return io, nil +} + +func (m *memoryWatchStage) Preferences() StagePreferences { + return m.stage.Preferences() +} + +func (m *memoryWatchStage) StartWithIO( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + if err := m.stage.StartWithIO(ctx, env, stdin, stdout); err != nil { + return err + } + + m.monitor(ctx) + + return nil +} + +// monitor starts up a goroutine that monitors the memory of `m`. +func (m *memoryWatchStage) monitor(ctx context.Context) { ctx, cancel := context.WithCancel(ctx) m.cancel = cancel m.wg.Add(1) @@ -189,8 +214,6 @@ func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadClos m.watch(ctx, m.stage) m.wg.Done() }() - - return io, nil } func (m *memoryWatchStage) Wait() error { diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go index 7501c80..ffecae2 100644 --- a/pipe/memorylimit_test.go +++ b/pipe/memorylimit_test.go @@ -8,6 +8,7 @@ import ( "log" "os" "strings" + "syscall" "testing" "time" @@ -48,7 +49,7 @@ func TestMemoryObserverTreeMem(t *testing.T) { require.Greater(t, rss, 400_000_000) } -func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { +func testMemoryObserver(t *testing.T, mbs int, stage pipe.StageWithIO) int { ctx := context.Background() stdinReader, stdinWriter := io.Pipe() @@ -112,54 +113,36 @@ func TestMemoryLimitTreeMem(t *testing.T) { require.ErrorContains(t, err, "memory limit exceeded") } -type closeWrapper struct { - io.Writer - close func() error -} - -func (w closeWrapper) Close() error { - return w.close() 
-} - -func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { +func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.StageWithIO) (string, error) { ctx := context.Background() - stdinReader, stdinWriter := io.Pipe() - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) require.NoError(t, err) - // io.Pipe doesn't know if anything is listening on the other end, so once - // our process is expectedly killed then we'll end up blocked trying to - // write to it. To workaround this, make sure we close the pipe reader when - // we've detected that the process has exited (i.e. when stdout has been - // closed). This will cause our write to immediately fail with this error. - closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - buf := &bytes.Buffer{} logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimit(stage, limit, LogEventHandler(logger))) + p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) + p.Add( + pipe.Function( + "write-to-less", + func(ctx context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + // Write some nonsense data to less. + var bytes [1_000_000]byte + for i := 0; i < mbs; i++ { + _, err := stdout.Write(bytes[:]) + if err != nil { + require.ErrorIs(t, err, syscall.EPIPE) + } + } + + return nil + }, + ), + pipe.MemoryLimit(stage, limit, LogEventHandler(logger)), + ) require.NoError(t, p.Start(ctx)) - // Write some nonsense data to less. 
- var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) err = p.Wait() return buf.String(), err diff --git a/pipe/nop_closer.go b/pipe/nop_closer.go index d435d0a..18cf7a9 100644 --- a/pipe/nop_closer.go +++ b/pipe/nop_closer.go @@ -6,29 +6,64 @@ package pipe import "io" -// newNopCloser returns a ReadCloser with a no-op Close method wrapping -// the provided io.Reader r. -// If r implements io.WriterTo, the returned io.ReadCloser will implement io.WriterTo -// by forwarding calls to r. -func newNopCloser(r io.Reader) io.ReadCloser { +// newReaderNopCloser returns a ReadCloser with a no-op Close method, +// wrapping the provided io.Reader `r`. If `r` implements +// `io.WriterTo`, the returned `io.ReadCloser` will also implement +// `io.WriterTo` by forwarding calls to `r`. +func newReaderNopCloser(r io.Reader) io.ReadCloser { if _, ok := r.(io.WriterTo); ok { - return nopCloserWriterTo{r} + return readerWriterToNopCloser{r} } - return nopCloser{r} + return readerNopCloser{r} } -type nopCloser struct { +// readerNopCloser is a ReadCloser that wraps a provided `io.Reader`, +// but whose `Close()` method does nothing. We don't need to check +// whether the wrapped reader also implements `io.WriterTo`, since +// it's always unwrapped before use. +type readerNopCloser struct { io.Reader } -func (nopCloser) Close() error { return nil } +func (readerNopCloser) Close() error { + return nil +} -type nopCloserWriterTo struct { +// readerWriterToNopCloser is like `readerNopCloser` except that it +// also implements `io.WriterTo` by delegating `WriteTo()` to the +// wrapped `io.Reader` (which must also implement `io.WriterTo`). 
+type readerWriterToNopCloser struct { io.Reader } -func (nopCloserWriterTo) Close() error { return nil } +func (readerWriterToNopCloser) Close() error { return nil } + +func (r readerWriterToNopCloser) WriteTo(w io.Writer) (n int64, err error) { + return r.Reader.(io.WriterTo).WriteTo(w) +} + +// writerNopCloser is a WriteCloser that wraps a provided `io.Writer`, +// but whose `Close()` method does nothing. +type writerNopCloser struct { + io.Writer +} + +func (w writerNopCloser) Close() error { + return nil +} -func (c nopCloserWriterTo) WriteTo(w io.Writer) (n int64, err error) { - return c.Reader.(io.WriterTo).WriteTo(w) +// unwrapNopCloser unwraps the object if it is some kind of nop +// closer, and returns the underlying object. This function is used +// only for testing. +func unwrapNopCloser(obj any) (any, bool) { + switch obj := obj.(type) { + case readerNopCloser: + return obj.Reader, true + case readerWriterToNopCloser: + return obj.Reader, true + case writerNopCloser: + return obj.Writer, true + default: + return nil, false + } } diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go new file mode 100644 index 0000000..4ff1c2c --- /dev/null +++ b/pipe/pipe_matching_test.go @@ -0,0 +1,553 @@ +package pipe_test + +import ( + "context" + "fmt" + "io" + "os" + "testing" + + "github.com/github/go-pipe/pipe" + "github.com/stretchr/testify/assert" +) + +// Tests that `Pipeline.Start()` uses the correct types of pipes in +// various situations. +// +// The type of pipe to use depends on both the source and the consumer +// of the data, including the overall pipeline's stdin and stdout. So +// there are a lot of possibilities to consider. 
+ +// Additional values used for the expected types of stdin/stdout: +const ( + IOPreferenceUndefinedNopCloser pipe.IOPreference = iota + 100 + IOPreferenceFileNopCloser +) + +func file(t *testing.T) *os.File { + f, err := os.Open(os.DevNull) + assert.NoError(t, err) + return f +} + +func readCloser() io.ReadCloser { + r, w := io.Pipe() + w.Close() + return r +} + +func writeCloser() io.WriteCloser { + r, w := io.Pipe() + r.Close() + return w +} + +func newPipeSniffingStage1( + retval io.ReadCloser, stdinExpectation pipe.IOPreference, +) *pipeSniffingStage1 { + return &pipeSniffingStage1{ + StdinExpectation: stdinExpectation, + retval: retval, + } +} + +type pipeSniffingStage1 struct { + StdinExpectation pipe.IOPreference + retval io.ReadCloser + stdin io.ReadCloser +} + +func newPipeSniffingFunc1(stdinExpectation pipe.IOPreference) *pipeSniffingStage1 { + return newPipeSniffingStage1(readCloser(), stdinExpectation) +} + +func newPipeSniffingCmd1(t *testing.T, stdinExpectation pipe.IOPreference) *pipeSniffingStage1 { + return newPipeSniffingStage1(file(t), stdinExpectation) +} + +func (*pipeSniffingStage1) Name() string { + return "pipe-sniffer" +} + +func (s *pipeSniffingStage1) Start( + _ context.Context, _ pipe.Env, stdin io.ReadCloser, +) (io.ReadCloser, error) { + s.stdin = stdin + if stdin != nil { + _ = stdin.Close() + } + + return s.retval, nil +} + +func (s *pipeSniffingStage1) Wait() error { + return nil +} + +func (s *pipeSniffingStage1) check(t *testing.T, i int) { + t.Helper() + + checkStdinExpectation(t, i, s.StdinExpectation, s.stdin) +} + +func newPipeSniffingStageWithIO( + stdinPreference, stdinExpectation pipe.IOPreference, + stdoutPreference, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStageWithIO { + return &pipeSniffingStageWithIO{ + prefs: pipe.StagePreferences{ + StdinPreference: stdinPreference, + StdoutPreference: stdoutPreference, + }, + expect: pipe.StagePreferences{ + StdinPreference: stdinExpectation, + StdoutPreference: 
stdoutExpectation, + }, + } +} + +func newPipeSniffingFuncWithIO( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStageWithIO { + return newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, stdinExpectation, + pipe.IOPreferenceUndefined, stdoutExpectation, + ) +} + +func newPipeSniffingCmdWithIO( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStageWithIO { + return newPipeSniffingStageWithIO( + pipe.IOPreferenceFile, stdinExpectation, + pipe.IOPreferenceFile, stdoutExpectation, + ) +} + +type pipeSniffingStageWithIO struct { + prefs pipe.StagePreferences + expect pipe.StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + +func (*pipeSniffingStageWithIO) Name() string { + return "pipe-sniffer" +} + +func (s *pipeSniffingStageWithIO) Start( + _ context.Context, _ pipe.Env, _ io.ReadCloser, +) (io.ReadCloser, error) { + panic("Start() called for a StageWithIO") +} + +func (s *pipeSniffingStageWithIO) Preferences() pipe.StagePreferences { + return s.prefs +} + +func (s *pipeSniffingStageWithIO) StartWithIO( + _ context.Context, _ pipe.Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + s.stdin = stdin + if stdin != nil { + _ = stdin.Close() + } + s.stdout = stdout + if stdout != nil { + _ = stdout.Close() + } + return nil +} + +func (s *pipeSniffingStageWithIO) check(t *testing.T, i int) { + t.Helper() + + checkStdinExpectation(t, i, s.expect.StdinPreference, s.stdin) + checkStdoutExpectation(t, i, s.expect.StdoutPreference, s.stdout) +} + +func (s *pipeSniffingStageWithIO) Wait() error { + return nil +} + +var _ pipe.StageWithIO = (*pipeSniffingStageWithIO)(nil) + +func ioTypeString(f any) string { + if f == nil { + return "nil" + } + if f, ok := pipe.UnwrapNopCloser(f); ok { + return fmt.Sprintf("nopCloser(%s)", ioTypeString(f)) + } + switch f := f.(type) { + case *os.File: + return "*os.File" + case io.Reader: + return "other" + case io.Writer: + return "other" + default: + return 
fmt.Sprintf("%T", f) + } +} + +func prefString(pref pipe.IOPreference) string { + switch pref { + case pipe.IOPreferenceUndefined: + return "other" + case pipe.IOPreferenceFile: + return "*os.File" + case pipe.IOPreferenceNil: + return "nil" + case IOPreferenceUndefinedNopCloser: + return "nopCloser(other)" + case IOPreferenceFileNopCloser: + return "nopCloser(*os.File)" + default: + panic(fmt.Sprintf("invalid IOPreference: %d", pref)) + } +} + +type ReaderNopCloser interface { + NopCloserReader() io.Reader +} + +func checkStdinExpectation(t *testing.T, i int, pref pipe.IOPreference, stdin io.ReadCloser) { + t.Helper() + + ioType := ioTypeString(stdin) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdin: expected %s, got %s (%T)", i, expType, ioType, stdin, + ) +} + +type WriterNopCloser interface { + NopCloserWriter() io.Writer +} + +func checkStdoutExpectation(t *testing.T, i int, pref pipe.IOPreference, stdout io.WriteCloser) { + t.Helper() + + ioType := ioTypeString(stdout) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdout: expected %s, got %s (%T)", i, expType, ioType, stdout, + ) +} + +type checker interface { + check(t *testing.T, i int) +} + +func TestPipeTypes(t *testing.T) { + ctx := context.Background() + + t.Parallel() + + for _, tc := range []struct { + name string + opts []pipe.Option + stages []pipe.Stage + stdin io.Reader + stdout io.Writer + }{ + { + name: "func2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceNil), + }, + }, + { + name: "func2-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(IOPreferenceFileNopCloser, pipe.IOPreferenceNil), + }, + }, + { + name: "func2-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(pipe.IOPreferenceNil, IOPreferenceFileNopCloser), 
+ }, + }, + { + name: "func2-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + }, + }, + { + name: "func2-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "cmd2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd2-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(IOPreferenceFileNopCloser, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd2-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(pipe.IOPreferenceNil, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd2-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + }, + }, + { + name: "cmd2-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func1", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFunc1(pipe.IOPreferenceNil), + }, + }, + { + name: "func1-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(IOPreferenceFileNopCloser), + }, + }, + { + name: "func1-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(pipe.IOPreferenceNil), + }, + 
}, + { + name: "func1-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(IOPreferenceUndefinedNopCloser), + }, + }, + { + name: "func2-func2", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(IOPreferenceFileNopCloser, pipe.IOPreferenceUndefined), + newPipeSniffingFuncWithIO(pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func2-cmd2", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + newPipeSniffingCmdWithIO(pipe.IOPreferenceFile, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd2-func2", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceFile), + newPipeSniffingFuncWithIO(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd2-cmd2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + newPipeSniffingCmdWithIO(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "func1-func2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFunc1(pipe.IOPreferenceNil), + newPipeSniffingFuncWithIO(pipe.IOPreferenceUndefined, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd1-func2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd1(t, pipe.IOPreferenceNil), + newPipeSniffingFuncWithIO(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "func1-cmd2", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(IOPreferenceUndefinedNopCloser), + newPipeSniffingCmdWithIO(pipe.IOPreferenceUndefined, 
pipe.IOPreferenceNil), + }, + }, + { + name: "cmd1-cmd2", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd1(t, IOPreferenceUndefinedNopCloser), + newPipeSniffingCmdWithIO(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "func1-func1", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(IOPreferenceFileNopCloser), + newPipeSniffingFunc1(pipe.IOPreferenceUndefined), + }, + }, + { + name: "cmd1-func1", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd1(t, IOPreferenceFileNopCloser), + newPipeSniffingFunc1(pipe.IOPreferenceFile), + }, + }, + { + name: "func1-cmd1", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc1(IOPreferenceFileNopCloser), + newPipeSniffingCmd1(t, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func2-func1", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFuncWithIO(pipe.IOPreferenceNil, pipe.IOPreferenceUndefined), + newPipeSniffingFunc1(pipe.IOPreferenceUndefined), + }, + }, + { + name: "cmd2-func1", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmdWithIO(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceFile), + newPipeSniffingFunc1(pipe.IOPreferenceFile), + }, + }, + { + name: "hybrid1", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + ), + newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + ), + }, + }, + { + name: "hybrid2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + 
newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + ), + newPipeSniffingStageWithIO( + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStageWithIO( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + ), + }, + }, + } { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := pipe.New(tc.opts...) + p.Add(tc.stages...) + assert.NoError(t, p.Run(ctx)) + for i, s := range tc.stages { + s.(checker).check(t, i) + } + }) + } +} diff --git a/pipe/pipeline.go b/pipe/pipeline.go index e591c63..af1219d 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "os" "sync/atomic" ) @@ -53,7 +54,7 @@ type ContextValuesFunc func(context.Context) []EnvVar type Pipeline struct { env Env - stdin io.Reader + stdin io.ReadCloser stdout io.WriteCloser stages []Stage cancel func() @@ -68,14 +69,6 @@ type Pipeline struct { var emptyEventHandler = func(e *Event) {} -type nopWriteCloser struct { - io.Writer -} - -func (w nopWriteCloser) Close() error { - return nil -} - type NewPipeFn func(opts ...Option) *Pipeline // NewPipeline returns a Pipeline struct with all of the `options` @@ -105,14 +98,58 @@ func WithDir(dir string) Option { // WithStdin assigns stdin to the first command in the pipeline. func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { - p.stdin = stdin + // We don't want the first stage to close `stdin`, and it is + // not even necessarily an `io.ReadCloser`. So wrap it in a + // fake `io.ReadCloser` whose `Close()` method doesn't do + // anything. + // + // We could use `io.NopCloser()` for this purpose, but that + // would have a subtle problem. If the first stage is a + // `Command`, then it wants to set the `exec.Cmd`'s `Stdin` to + // an `io.Reader` corresponding to `p.stdin`. 
If `Cmd.Stdin` + // is an `*os.File`, then `exec.Cmd` will pass the file + // descriptor to the subcommand directly; there is no need to + // create a pipe and copy the data into the input side of the + // pipe. But if `p.stdin` is not an `*os.File`, then this + // optimization is prevented. And even worse, it also has the + // side effect that the goroutine that copies from `Cmd.Stdin` + // into the pipe doesn't terminate until that fd is closed by + // the writing side. + // + // That isn't always what we want. Consider, for example, the + // following snippet, where the subcommand's stdin is set to + // the stdin of the enclosing Go program, but wrapped with + // `io.NopCloser`: + // + // cmd := exec.Command("ls") + // cmd.Stdin = io.NopCloser(os.Stdin) + // cmd.Stdout = os.Stdout + // cmd.Stderr = os.Stderr + // cmd.Run() + // + // In this case, we don't want the Go program to wait for + // `os.Stdin` to close (because `ls` isn't even trying to read + // from its stdin). But it does: `exec.Cmd` doesn't recognize + // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and + // copies the data itself, and this goroutine doesn't + // terminate until `cmd.Stdin` (i.e., the Go program's own + // stdin) is closed. But if, for example, the Go program is + // run from an interactive shell session, that might never + // happen, in which case the program will fail to terminate, + // even after `ls` exits. + // + // So instead, in this special case, we wrap `stdin` in our + // own `nopCloser`, which behaves like `io.NopCloser`, except + // that `pipe.CommandStage` knows how to unwrap it before + // passing it to `exec.Cmd`. + p.stdin = newReaderNopCloser(stdin) } } // WithStdout assigns stdout to the last command in the pipeline. 
func WithStdout(stdout io.Writer) Option { return func(p *Pipeline) { - p.stdout = nopWriteCloser{stdout} + p.stdout = writerNopCloser{stdout} } } @@ -204,6 +241,13 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } +type stageStarter struct { + stageWithIO StageWithIO + prefs StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + // Start starts the commands in the pipeline. If `Start()` exits // without an error, `Wait()` must also be called, to allow all // resources to be freed. @@ -215,89 +259,139 @@ func (p *Pipeline) Start(ctx context.Context) error { atomic.StoreUint32(&p.started, 1) ctx, p.cancel = context.WithCancel(ctx) - var nextStdin io.ReadCloser - if p.stdin != nil { - // We don't want the first stage to actually close this, and - // `p.stdin` is not even necessarily an `io.ReadCloser`. So - // wrap it in a fake `io.ReadCloser` whose `Close()` method - // doesn't do anything. - // - // We could use `io.NopCloser()` for this purpose, but it has - // a subtle problem. If the first stage is a `Command`, then - // it wants to set the `exec.Cmd`'s `Stdin` to an `io.Reader` - // corresponding to `p.stdin`. If `Cmd.Stdin` is an - // `*os.File`, then the file descriptor can be passed to the - // subcommand directly; there is no need for this process to - // create a pipe and copy the data into the input side of the - // pipe. But if `p.stdin` is not an `*os.File`, then this - // optimization is prevented. And even worse, it also has the - // side effect that the goroutine that copies from `Cmd.Stdin` - // into the pipe doesn't terminate until that fd is closed by - // the writing side. - // - // That isn't always what we want. 
Consider, for example, the - // following snippet, where the subcommand's stdin is set to - // the stdin of the enclosing Go program, but wrapped with - // `io.NopCloser`: - // - // cmd := exec.Command("ls") - // cmd.Stdin = io.NopCloser(os.Stdin) - // cmd.Stdout = os.Stdout - // cmd.Stderr = os.Stderr - // cmd.Run() - // - // In this case, we don't want the Go program to wait for - // `os.Stdin` to close (because `ls` isn't even trying to read - // from its stdin). But it does: `exec.Cmd` doesn't recognize - // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and - // copies the data itself, and this goroutine doesn't - // terminate until `cmd.Stdin` (i.e., the Go program's own - // stdin) is closed. But if, for example, the Go program is - // run from an interactive shell session, that might never - // happen, in which case the program will fail to terminate, - // even after `ls` exits. - // - // So instead, in this special case, we wrap `p.stdin` in our - // own `nopCloser`, which behaves like `io.NopCloser`, except - // that `pipe.CommandStage` knows how to unwrap it before - // passing it to `exec.Cmd`. - nextStdin = newNopCloser(p.stdin) - } + // We need to decide how to start the stages, especially whether + // to use `Stage.Start()` vs. `Stage.StartWithIO()`, and, if the + // latter, what pipes to use to connect adjacent stages + // (`os.Pipe()` vs. `io.Pipe()`) based on the two stages' + // preferences. + stageStarters := make([]stageStarter, len(p.stages), len(p.stages)+1) + // Collect information about each stage's type and preferences: for i, s := range p.stages { - var err error - stdout, err := s.Start(ctx, p.env, nextStdin) - if err != nil { - // Close the pipe that the previous stage was writing to. - // That should cause it to exit even if it's not minding - // its context. 
- if nextStdin != nil { - _ = nextStdin.Close() + ss := &stageStarters[i] + if s, ok := s.(StageWithIO); ok { + ss.stageWithIO = s + ss.prefs = s.Preferences() + } else { + ss.prefs = StagePreferences{ + StdinPreference: IOPreferenceUndefined, + StdoutPreference: IOPreferenceUndefined, } + } + } - // Kill and wait for any stages that have been started - // already to finish: - p.cancel() - for _, s := range p.stages[:i] { - _ = s.Wait() - } - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, + if p.stdin != nil { + // Arrange for the input of the 0th stage to come from + // `p.stdin`: + stageStarters[0].stdin = p.stdin + } + + // The handling of the last stage depends on whether it is a + // `Stage` or a `StageWithIO`. + if p.stdout != nil { + i := len(p.stages) - 1 + ss := &stageStarters[i] + + if ss.stageWithIO != nil { + ss.stdout = p.stdout + } else { + // If `p.stdout` is set but the last stage is not a + // `StageWithIO`, then we need to add an extra, synthetic stage + // to copy its output to `p.stdout`. + c := newIOCopier(p.stdout) + p.stages = append(p.stages, c) + stageStarters = append(stageStarters, stageStarter{ + stageWithIO: c, + prefs: c.Preferences(), }) - return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err) } - nextStdin = stdout } - // If the pipeline was configured with a `stdout`, add a synthetic - // stage to copy the last stage's stdout to that writer: - if p.stdout != nil { - c := newIOCopier(p.stdout) - p.stages = append(p.stages, c) - // `ioCopier.Start()` never fails: - _, _ = c.Start(ctx, p.env, nextStdin) + // Clean up any processes and pipes that have been created. `i` is + // the index of the stage that failed to start (whose output pipe + // has already been cleaned up if necessary). + abort := func(i int, err error) error { + // Close the pipe that the previous stage was writing to. + // That should cause it to exit even if it's not minding + // its context. 
+ if stageStarters[i].stdin != nil { + _ = stageStarters[i].stdin.Close() + } + + // Kill and wait for any stages that have been started + // already to finish: + p.cancel() + for _, s := range p.stages[:i] { + _ = s.Wait() + } + p.eventHandler(&Event{ + Command: p.stages[i].Name(), + Msg: "failed to start pipeline stage", + Err: err, + }) + return fmt.Errorf( + "starting pipeline stage %q: %w", p.stages[i].Name(), err, + ) + } + + // Loop over all but the last stage, starting them. By the time we + // get to a stage, its stdin will have already been determined, + // but we still need to figure out its stdout and set the stdin + // that will be used for the subsequent stage. + for i, s := range p.stages[:len(p.stages)-1] { + ss := &stageStarters[i] + + nextSS := &stageStarters[i+1] + + if ss.stageWithIO != nil { + // We need to generate a pipe pair for this stage to use + // to communicate with its successor: + if ss.prefs.StdoutPreference == IOPreferenceFile || + nextSS.prefs.StdinPreference == IOPreferenceFile { + // Use an OS-level pipe for the communication: + var err error + nextSS.stdin, ss.stdout, err = os.Pipe() + if err != nil { + return abort(i, err) + } + } else { + nextSS.stdin, ss.stdout = io.Pipe() + } + if err := ss.stageWithIO.StartWithIO(ctx, p.env, ss.stdin, ss.stdout); err != nil { + nextSS.stdin.Close() + ss.stdout.Close() + return abort(i, err) + } + } else { + // The stage will create its own stdout when we start + // it: + var err error + nextSS.stdin, err = s.Start(ctx, p.env, ss.stdin) + if err != nil { + return abort(i, err) + } + } + } + + // The last stage needs special handling, because its stdout + // doesn't need to flow into another stage (it's already set in + // `ss.stdout` if it's needed). 
+ { + i := len(p.stages) - 1 + s := p.stages[i] + ss := &stageStarters[i] + + if ss.stageWithIO != nil { + if err := ss.stageWithIO.StartWithIO(ctx, p.env, ss.stdin, ss.stdout); err != nil { + return abort(i, err) + } + } else { + var err error + _, err = s.Start(ctx, p.env, ss.stdin) + if err != nil { + return abort(i, err) + } + } } return nil @@ -305,7 +399,7 @@ func (p *Pipeline) Start(ctx context.Context) error { func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { var buf bytes.Buffer - p.stdout = nopWriteCloser{&buf} + p.stdout = writerNopCloser{&buf} err := p.Run(ctx) return buf.Bytes(), err } diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index d925aee..e85d7d1 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -87,7 +87,7 @@ func TestPipelineSingleCommandWithStdout(t *testing.T) { } } -func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { +func TestPipelineStdinOSPipeThatIsNeverClosed(t *testing.T) { t.Parallel() // Make sure that the subprocess terminates on its own, as opposed @@ -105,7 +105,10 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - p := pipe.New(pipe.WithStdin(r), pipe.WithStdout(&stdout)) + p := pipe.New( + pipe.WithStdin(r), + pipe.WithStdout(&stdout), + ) // Note that this command doesn't read from its stdin, so it will // terminate regardless of whether `w` gets closed: p.Add(pipe.Command("true")) @@ -115,7 +118,7 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { assert.NoError(t, p.Run(ctx)) } -func TestPipelineStdinThatIsNeverClosed(t *testing.T) { +func TestPipelineIOPipeStdinThatIsNeverClosed(t *testing.T) { t.Skip("test not run because it currently deadlocks") t.Parallel() @@ -131,8 +134,7 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - r, w, err := os.Pipe() - require.NoError(t, err) + r, w := io.Pipe() t.Cleanup(func() { _ = 
w.Close() _ = r.Close() @@ -140,10 +142,8 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - // The point here is to wrap `r` so that `exec.Cmd` doesn't - // recognize that it's an `*os.File`: p := pipe.New( - pipe.WithStdin(io.NopCloser(r)), + pipe.WithStdin(r), pipe.WithStdout(&stdout), ) // Note that this command doesn't read from its stdin, so it will @@ -159,9 +159,7 @@ func TestNontrivialPipeline(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("sed", "s/hello/goodbye/"), @@ -172,12 +170,13 @@ func TestNontrivialPipeline(t *testing.T) { } } -func TestPipelineReadFromSlowly(t *testing.T) { +func TestOSPipePipelineReadFromSlowly(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - r, w := io.Pipe() + r, w, err := os.Pipe() + require.NoError(t, err) var buf []byte readErr := make(chan error, 1) @@ -189,14 +188,34 @@ func TestPipelineReadFromSlowly(t *testing.T) { readErr <- err }() - p := pipe.New(pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdoutCloser(w)) p.Add(pipe.Command("echo", "hello world")) assert.NoError(t, p.Run(ctx)) - time.Sleep(100 * time.Millisecond) - // It's not super-intuitive, but `w` has to be closed here so that - // the `io.ReadAll()` call above knows that it's done: - _ = w.Close() + assert.NoError(t, <-readErr) + assert.Equal(t, "hello world\n", string(buf)) +} + +func TestIOPipePipelineReadFromSlowly(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + r, w := io.Pipe() + + var buf []byte + readErr := make(chan error, 1) + + go func() { + time.Sleep(200 * time.Millisecond) + var err error + buf, err = io.ReadAll(r) + readErr <- err + }() + + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.Command("echo", "hello world")) + 
assert.NoError(t, p.Run(ctx)) assert.NoError(t, <-readErr) assert.Equal(t, "hello world\n", string(buf)) @@ -211,8 +230,6 @@ func TestPipelineReadFromSlowly2(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - dir := t.TempDir() - r, w := io.Pipe() var buf []byte @@ -236,15 +253,10 @@ func TestPipelineReadFromSlowly2(t *testing.T) { } }() - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdoutCloser(w)) p.Add(pipe.Command("seq", "100")) assert.NoError(t, p.Run(ctx)) - time.Sleep(200 * time.Millisecond) - // It's not super-intuitive, but `w` has to be closed here so that - // the `io.ReadAll()` call above knows that it's done: - _ = w.Close() - assert.NoError(t, <-readErr) assert.Equal(t, 292, len(buf)) } @@ -253,9 +265,7 @@ func TestPipelineTwoCommandsPiping(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) assert.Panics(t, func() { p.Add(pipe.Command("")) }) out, err := p.Output(ctx) @@ -283,9 +293,7 @@ func TestPipelineExit(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("false"), pipe.Command("true"), @@ -316,11 +324,10 @@ func TestPipelineInterrupted(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) @@ -339,11 +346,10 @@ func TestPipelineCanceled(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithCancel(context.Background()) @@ 
-367,9 +373,8 @@ func TestLittleEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("sh", "-c", "sleep 1; echo foo"), pipe.Command("true"), @@ -391,9 +396,8 @@ func TestBigEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("seq", "100000"), pipe.Command("true"), @@ -415,9 +419,8 @@ func TestIgnoredSIGPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE), pipe.Command("echo", "foo"), @@ -434,9 +437,7 @@ func TestFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Print("hello world"), pipe.Function( @@ -464,9 +465,7 @@ func TestPipelineWithFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "-n", "hello world"), pipe.Function( @@ -528,9 +527,7 @@ func TestPipelineWithLinewiseFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( seqFunction(20), @@ -581,7 +578,7 @@ func TestScannerAlwaysFlushes(t *testing.T) { var length int64 - p := pipe.New(pipe.WithDir(".")) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( pipe.IgnoreError( @@ -629,7 +626,7 @@ func TestScannerFinishEarly(t *testing.T) { var length int64 - p := pipe.New(pipe.WithDir(".")) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( pipe.IgnoreError( @@ -670,9 +667,7 @@ func TestPrintln(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := 
pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Println("Look Ma, no hands!")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -684,9 +679,7 @@ func TestPrintf(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Printf("Strangely recursive: %T", p)) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -880,10 +873,8 @@ func TestErrors(t *testing.T) { func BenchmarkSingleProgram(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("true"), ) @@ -894,10 +885,8 @@ func BenchmarkSingleProgram(b *testing.B) { func BenchmarkTenPrograms(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("cat"), @@ -920,15 +909,13 @@ func BenchmarkTenPrograms(b *testing.B) { func BenchmarkTenFunctions(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Println("hello world"), pipe.Function("copy1", cp), @@ -951,15 +938,13 @@ func BenchmarkTenFunctions(b *testing.B) { func BenchmarkTenMixedStages(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Function("copy1", cp), @@ -979,6 +964,97 @@ func BenchmarkTenMixedStages(b *testing.B) { } } +func BenchmarkMoreDataUnbuffered(b *testing.B) { + ctx := 
context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(ctx context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + for i := 1; i <= 100000; i++ { + fmt.Fprintln(stdout, i) + } + return nil + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(ctx context.Context, _ pipe.Env, line []byte, stdout *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 100000, count) + } + } +} + +func BenchmarkMoreDataBuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(ctx context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + out := bufio.NewWriter(stdout) + for i := 1; i <= 1000000; i++ { + fmt.Fprintln(out, i) + } + return out.Flush() + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(ctx context.Context, _ pipe.Env, line []byte, stdout *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 1000000, count) + } + } +} + func genErr(err error) pipe.StageFunc { return func(_ 
context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { return err diff --git a/pipe/scanner.go b/pipe/scanner.go index b56b58c..5ec16e8 100644 --- a/pipe/scanner.go +++ b/pipe/scanner.go @@ -56,11 +56,7 @@ func ScannerFunction( return err } } - if err := scanner.Err(); err != nil { - return err - } - - return nil + return scanner.Err() // `p.AddFunction()` arranges for `stdout` to be closed. }, ) diff --git a/pipe/stage.go b/pipe/stage.go index f3d74d9..b68713c 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -5,7 +5,87 @@ import ( "io" ) -// Stage is an element of a `Pipeline`. +// +// From the point of view of the pipeline as a whole, if stdin is +// provided by the user (`WithStdin()`), then we don't want to close +// it at all, whether it's an `*os.File` or not. For this reason, +// stdin has to be wrapped using a `readerNopCloser` before being +// passed into the first stage. For efficiency reasons, it's +// advantageous for the first stage to unwrap its stdin +// argument before actually using it. If the wrapped value is an +// `*os.File` and the stage is a command stage, then unwrapping is +// also important to get the right semantics. +// +// For stdout, it depends on whether the user supplied it using +// `WithStdout()` or `WithStdoutCloser()`. If the former, then the +// considerations are the same as for stdin. +// +// [1] It's theoretically possible for a command to pass the open file +// descriptor to another, longer-lived process, in which case the +// file descriptor wouldn't necessarily get closed when the +// command finishes. But that's ill-behaved in a command that is +// being used in a pipeline, so we'll ignore that possibility. + +// Stage is an element of a `Pipeline`. It reads from standard input +// and writes to standard output. +// +// Who closes stdin and stdout? 
+// +// A `Stage` as a whole needs to be responsible for closing its end of +// stdin and stdout (assuming that `Start()` / `StartWithIO()` returns +// successfully). Its doing so tells the previous/next stage that it +// is done reading/writing data, which can affect their behavior. +// Therefore, it should close each one as soon as it is done with it. +// (If the caller wants to suppress the closing of stdin/stdout, it +// can always wrap the corresponding argument in a `nopCloser`.) +// +// Specifically, if a stage is started using `Start()`, then it is +// responsible for closing the stdin that is passed to it, and also +// for closing its end of the `io.Reader` that the method returns. If +// a stage implements `StageWithIO` and is started using +// `StartWithIO()`, then it is responsible for closing both the stdin +// and stdout that are passed in as arguments. How this should be done +// depends on the kind of stage and whether stdin/stdout are of type +// `*os.File`. +// +// If a stage is an external command, then the subprocess ultimately +// needs its own copies of `*os.File` file descriptors for its stdin +// and stdout. The external command will "always" [1] close those when +// it exits. +// +// If the stage is an external command and one of the arguments is an +// `*os.File`, then it can set the corresponding field of `exec.Cmd` +// to that argument directly. This has the result that `exec.Cmd` +// duplicates that file descriptor and passes the dup to the +// subprocess. 
Therefore, the stage can close its copy of that +// argument as soon as the external command has started, because the +// external command will keep its own copy open as long as necessary +// (and no longer!), in roughly the following sequence: +// +// cmd.Stdin = f // Similarly for stdout +// cmd.Start(…) +// f.Close() // close our copy +// cmd.Wait() +// +// If the stage is an external command and one of its arguments is not +// an `*os.File`, then `exec.Cmd` will take care of creating an +// `os.Pipe()`, copying from the provided argument in/out of the pipe, +// and eventually closing both ends of the pipe. The stage must close +// the argument itself, but only _after_ the external command has +// finished: +// +// cmd.Stdin = r // Similarly for stdout +// cmd.Start(…) +// cmd.Wait() +// r.Close() +// +// If the stage is a Go function, then it holds the only copy of +// stdin/stdout, so it must wait until the function is done before +// closing them (regardless of their underlying type): +// +// f(…, stdin, stdout) +// stdin.Close() +// stdout.Close() type Stage interface { // Name returns the name of the stage. Name() string @@ -16,12 +96,9 @@ type Stage interface { // might be the case for the first stage in a pipeline.) It // returns an `io.ReadCloser` from which the stage's output can be // read (or `nil` if it generates no output, which should only be - // the case for the last stage in a pipeline). It is the stages' - // responsibility to close `stdin` (if it is not nil) when it has - // read all of the input that it needs, and to close the write end - // of its output reader when it is done, as that is generally how - // the subsequent stage knows that it has received all of its - // input and can finish its work, too. + // the case for the last stage in a pipeline). See the `Stage` + // type comment for more information about responsibility for + // closing stdin and stdout. 
// // If `Start()` returns without an error, `Wait()` must also be // called, to allow all resources to be freed. @@ -32,3 +109,54 @@ type Stage interface { // the context passed to `Start()`. Wait() error } + +// StagePreferences is the way that a `StageWithIO` indicates its +// preferences about how it is run. This is used within +// `pipe.Pipeline` to decide when to use `os.Pipe()` vs. `io.Pipe()` +// for creating the pipes between stages. +type StagePreferences struct { + StdinPreference IOPreference + StdoutPreference IOPreference +} + +// StageWithIO is a `Stage` that can accept both stdin and stdout arguments +// when it is started. +type StageWithIO interface { + Stage + + // Preferences returns this stage's preferences regarding how it + // should be run. + Preferences() StagePreferences + + // StartWithIO starts the stage (like `Stage.Start()`), except that it + // allows the caller to pass in both stdin and stdout. + StartWithIO(ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser) error +} + +// IOPreference describes what type of stdin / stdout a stage would +// prefer. +// +// External commands prefer `*os.File`s (such as those produced by +// `os.Pipe()`) as their stdin and stdout, because those can be passed +// directly to the external process without any extra copying and also +// simplify the semantics around process termination. Go function +// stages are typically happy with any `io.ReadCloser` (such as one +// produced by `io.Pipe()`), which can be more efficient because +// traffic through an `io.Pipe()` happens entirely in userspace. +type IOPreference int + +const ( + // IOPreferenceUndefined indicates that the stage doesn't care + // what form the specified stdin / stdout takes (i.e., any old + // `io.ReadCloser` / `io.WriteCloser` is just fine). 
+ IOPreferenceUndefined IOPreference = iota + + // IOPreferenceFile indicates that the stage would prefer for the + // specified stdin / stdout to be an `*os.File`, to avoid copying. + IOPreferenceFile + + // IOPreferenceNil indicates that the stage does not use the + // specified stdin / stdout, so `nil` should be passed in. This + // should only happen at the beginning / end of a pipeline. + IOPreferenceNil +)