Skip to content

Fix: handle subprocess failing after startup #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
59 changes: 45 additions & 14 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"

"github.com/mark3labs/mcp-go/mcp"
)

const (
readyTimeout = 5 * time.Second
readyCheckTimeout = 1 * time.Second
)

// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
Expand All @@ -31,6 +39,7 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
exitErr atomic.Value
}

// NewIO returns a new stdio-based transport using existing input, output, and
Expand Down Expand Up @@ -73,14 +82,13 @@ func (c *Stdio) Start(ctx context.Context) error {
return err
}

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

return nil
return waitUntilReadyOrExit(ready, c.done, readyTimeout)
}

// spawnCommand spawns a new process running c.command.
Expand All @@ -95,7 +103,6 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
mergedEnv = append(mergedEnv, c.env...)

cmd.Env = mergedEnv

stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %w", err)
Expand All @@ -120,31 +127,55 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
return fmt.Errorf("failed to start command: %w", err)
}

go func() {
err := cmd.Wait()
if err != nil {
c.exitErr.Store(err)
}
tryCloseDone(c.done)
}()
return nil
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *Stdio) Close() error {
func tryCloseDone(done chan struct{}) {
select {
case <-c.done:
return nil
case <-done:
return
default:
}
// cancel all in-flight request
close(c.done)
close(done)
}
func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout time.Duration) error {
select {
case <-exited:
return errors.New("process exited before signalling readiness")
case <-ready:
select {
case <-exited:
return errors.New("process exited after readiness")
case <-time.After(readyCheckTimeout):
return nil
}
case <-time.After(timeout):
return errors.New("timeout waiting for process ready")
}
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *Stdio) Close() error {
// cancel all in-flight request
tryCloseDone(c.done)
if err := c.stdin.Close(); err != nil {
return fmt.Errorf("failed to close stdin: %w", err)
}
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}

if c.cmd != nil {
return c.cmd.Wait()
if err, ok := c.exitErr.Load().(error); ok && err != nil {
return err
}

return nil
}

Expand Down
29 changes: 29 additions & 0 deletions client/transport/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,5 +407,34 @@ func TestStdioErrors(t *testing.T) {
t.Errorf("Expected error when sending request after close, got nil")
}
})
t.Run("SubprocessStartsAndExitsImmediately", func(t *testing.T) {
// Create a temporary file for the mock server
tempFile, err := os.CreateTemp("", "mockstdio_server")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
//defer os.Remove(mockServerPath)

// Create a new Stdio transport
stdio := NewStdio(mockServerPath, nil)
stdio.env = append(stdio.env, "MOCK_FAIL_IMMEDIATELY=1")
defer stdio.Close()
// Start the transport
ctx := context.Background()
if startErr := stdio.Start(ctx); startErr == nil {
t.Fatalf("Expected error when starting Stdio transport, got nil")
}
})
}
4 changes: 4 additions & 0 deletions testdata/mockstdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type JSONRPCResponse struct {
}

func main() {
if os.Getenv("MOCK_FAIL_IMMEDIATELY") == "1" {
fmt.Fprintln(os.Stderr, "mock server: simulated startup failure")
os.Exit(1)
}
logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
logger.Info("launch successful")
scanner := bufio.NewScanner(os.Stdin)
Expand Down