diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 3d9d832ac..8da8f6bc0 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -4,15 +4,23 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "io" "os" "os/exec" "sync" + "sync/atomic" + "time" "github.com/mark3labs/mcp-go/mcp" ) +const ( + readyTimeout = 5 * time.Second + readyCheckTimeout = 1 * time.Second +) + // Stdio implements the transport layer of the MCP protocol using stdio communication. // It launches a subprocess and communicates with it via standard input/output streams // using JSON-RPC messages. The client handles message routing between requests and @@ -31,6 +39,7 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + exitErr atomic.Value } // NewIO returns a new stdio-based transport using existing input, output, and @@ -73,14 +82,13 @@ func (c *Stdio) Start(ctx context.Context) error { return err } + // Start reading responses in a goroutine and wait for it to be ready ready := make(chan struct{}) go func() { close(ready) c.readResponses() }() - <-ready - - return nil + return waitUntilReadyOrExit(ready, c.done, readyTimeout) } // spawnCommand spawns a new process running c.command. @@ -95,7 +103,6 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { mergedEnv = append(mergedEnv, c.env...) cmd.Env = mergedEnv - stdin, err := cmd.StdinPipe() if err != nil { return fmt.Errorf("failed to create stdin pipe: %w", err) @@ -120,20 +127,45 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { return fmt.Errorf("failed to start command: %w", err) } + go func() { + err := cmd.Wait() + if err != nil { + c.exitErr.Store(err) + } + tryCloseDone(c.done) + }() return nil } -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. -func (c *Stdio) Close() error { +func tryCloseDone(done chan struct{}) { select { - case <-c.done: - return nil + case <-done: + return default: } - // cancel all in-flight request - close(c.done) + close(done) +} +func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout time.Duration) error { + select { + case <-exited: + return errors.New("process exited before signalling readiness") + case <-ready: + select { + case <-exited: + return errors.New("process exited after readiness") + case <-time.After(readyCheckTimeout): + return nil + } + case <-time.After(timeout): + return errors.New("timeout waiting for process ready") + } +} +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *Stdio) Close() error { + // cancel all in-flight request + tryCloseDone(c.done) if err := c.stdin.Close(); err != nil { return fmt.Errorf("failed to close stdin: %w", err) } @@ -141,10 +173,9 @@ func (c *Stdio) Close() error { return fmt.Errorf("failed to close stderr: %w", err) } - if c.cmd != nil { - return c.cmd.Wait() + if err, ok := c.exitErr.Load().(error); ok && err != nil { + return err } - return nil } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index cb25bf796..51e90b0f0 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -407,5 +407,34 @@ func TestStdioErrors(t *testing.T) { t.Errorf("Expected error when sending request after close, got nil") } }) + t.Run("SubprocessStartsAndExitsImmediately", func(t *testing.T) { + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) + } + //defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + stdio.env = append(stdio.env, "MOCK_FAIL_IMMEDIATELY=1") + defer stdio.Close() + // Start the transport + ctx := context.Background() + if startErr := stdio.Start(ctx); startErr == nil { + t.Fatalf("Expected error when starting Stdio transport, got nil") + } + }) } diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 9f13d5547..86bfb3071 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -26,6 +26,10 @@ type JSONRPCResponse struct { } func main() { + if os.Getenv("MOCK_FAIL_IMMEDIATELY") == "1" { + fmt.Fprintln(os.Stderr, "mock server: simulated startup failure") + os.Exit(1) + } logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) logger.Info("launch successful") scanner := bufio.NewScanner(os.Stdin)