diff --git a/server/jest.config.js b/server/jest.config.js new file mode 100644 index 000000000..f945574cc --- /dev/null +++ b/server/jest.config.js @@ -0,0 +1,18 @@ +export default { + preset: "ts-jest", + testEnvironment: "node", + roots: ["/src"], + testMatch: ["**/__tests__/**/*.test.ts", "**/*.test.ts"], + transform: { + "^.+\\.ts$": "ts-jest", + }, + moduleNameMapping: { + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + extensionsToTreatAsEsm: [".ts"], + globals: { + "ts-jest": { + useESM: true, + }, + }, +}; \ No newline at end of file diff --git a/server/package.json b/server/package.json index 39cc08f13..78d207e17 100644 --- a/server/package.json +++ b/server/package.json @@ -17,12 +17,19 @@ "build": "tsc", "start": "node build/index.js", "dev": "tsx watch --clear-screen=false src/index.ts", - "dev:windows": "tsx watch --clear-screen=false src/index.ts < NUL" + "dev:windows": "tsx watch --clear-screen=false src/index.ts < NUL", + "test": "jest" }, "devDependencies": { "@types/cors": "^2.8.19", "@types/express": "^4.17.23", + "@types/jest": "^29.5.14", + "@types/shell-quote": "^1.7.5", + "@types/supertest": "^6.0.2", "@types/ws": "^8.5.12", + "jest": "^29.7.0", + "supertest": "^7.0.0", + "ts-jest": "^29.4.0", "tsx": "^4.19.0", "typescript": "^5.6.2" }, @@ -30,6 +37,8 @@ "@modelcontextprotocol/sdk": "^1.17.5", "cors": "^2.8.5", "express": "^5.1.0", + "shell-quote": "^1.8.3", + "spawn-rx": "^5.1.2", "ws": "^8.18.0", "zod": "^3.25.76" } diff --git a/server/src/__tests__/mcpProxy.test.ts b/server/src/__tests__/mcpProxy.test.ts new file mode 100644 index 000000000..5aeaaffe1 --- /dev/null +++ b/server/src/__tests__/mcpProxy.test.ts @@ -0,0 +1,331 @@ +import { jest } from "@jest/globals"; +import mcpProxy from "../mcpProxy.js"; + +// Mock transport interface +interface MockTransport { + sessionId?: string; + onmessage: ((message: any) => void) | null; + onclose: (() => void) | null; + onerror: ((error: Error) => void) | null; + send: jest.Mock; + close: jest.Mock; +} + +// Create mock transport +function createMockTransport(sessionId?: string): MockTransport { + return { + sessionId, + onmessage: null, + onclose: null, + onerror: null, + send: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + }; +} + +describe("mcpProxy", () => { + let mockClientTransport: MockTransport; + let mockServerTransport: MockTransport; + let mockCleanup: jest.Mock; + + beforeEach(() => { + mockClientTransport = createMockTransport("client-session-123"); + mockServerTransport = createMockTransport("server-session-456"); + mockCleanup = jest.fn(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe("message forwarding", () => { + it("should forward messages from client to server", async () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testMessage = { + jsonrpc: "2.0" as const, + method: "test/method", + params: { test: "data" }, + id: 1, + }; + + // Simulate client message + mockClientTransport.onmessage!(testMessage); + + expect(mockServerTransport.send).toHaveBeenCalledWith(testMessage); + }); + + it("should forward messages from server to client", async () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testMessage = { + jsonrpc: "2.0" as const, + result: { test: "response" }, + id: 1, + }; + + // Simulate server message + mockServerTransport.onmessage!(testMessage); + + expect(mockClientTransport.send).toHaveBeenCalledWith(testMessage); + }); + }); + + describe("error handling", () => { + it("should send error response when server send fails for request", async () => { + const serverError = new Error("Server send failed"); + mockServerTransport.send.mockRejectedValue(serverError); + + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testRequest = { + jsonrpc: "2.0" as const, + method: "test/method", + params: { test: "data" }, + id: 1, + }; + + // Simulate client request that fails on server + mockClientTransport.onmessage!(testRequest); + + // Wait for the async error handling + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(mockClientTransport.send).toHaveBeenCalledWith({ + jsonrpc: "2.0", + id: 1, + error: { + code: -32001, + message: "Server send failed", + data: serverError, + }, + }); + }); + + it("should not send error response when client transport is closed", async () => { + const serverError = new Error("Server send failed"); + mockServerTransport.send.mockRejectedValue(serverError); + + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // Close client transport first + mockClientTransport.onclose!(); + + const testRequest = { + jsonrpc: "2.0" as const, + method: "test/method", + params: { test: "data" }, + id: 1, + }; + + // Now try to send message + mockClientTransport.onmessage!(testRequest); + + // Wait for the async error handling + await new Promise((resolve) => setTimeout(resolve, 0)); + + // Should not send error response since client transport is closed + expect(mockClientTransport.send).toHaveBeenCalledTimes(0); + }); + + it("should not send error response for notifications (no id)", async () => { + const serverError = new Error("Server send failed"); + mockServerTransport.send.mockRejectedValue(serverError); + + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testNotification = { + jsonrpc: "2.0" as const, + method: "test/notification", + params: { test: "data" }, + }; + + // Simulate client notification that fails on server + mockClientTransport.onmessage!(testNotification); + + // Wait for the async error handling + await new Promise((resolve) => setTimeout(resolve, 0)); + + // Should not send error response for notifications + expect(mockClientTransport.send).toHaveBeenCalledTimes(0); + }); + }); + + describe("connection cleanup", () => { + it("should call cleanup when client transport closes", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // Simulate client transport closing + mockClientTransport.onclose!(); + + expect(mockCleanup).toHaveBeenCalledTimes(1); + expect(mockServerTransport.close).toHaveBeenCalledTimes(1); + }); + + it("should call cleanup when server transport closes", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // Simulate server transport closing + mockServerTransport.onclose!(); + + expect(mockCleanup).toHaveBeenCalledTimes(1); + expect(mockClientTransport.close).toHaveBeenCalledTimes(1); + }); + + it("should not call cleanup twice if both transports close", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // Simulate both transports closing + mockClientTransport.onclose!(); + mockServerTransport.onclose!(); + + expect(mockCleanup).toHaveBeenCalledTimes(1); + }); + + it("should work without cleanup callback", () => { + expect(() => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + }); + + // Should not throw when cleanup is not provided + mockClientTransport.onclose!(); + }).not.toThrow(); + }); + + it("should handle cleanup callback errors gracefully", () => { + const errorCleanup = jest.fn().mockImplementation(() => { + throw new Error("Cleanup failed"); + }); + + expect(() => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: errorCleanup, + }); + + // Should not throw even if cleanup fails + mockClientTransport.onclose!(); + }).not.toThrow(); + + expect(errorCleanup).toHaveBeenCalledTimes(1); + }); + }); + + describe("transport close synchronization", () => { + it("should not close server transport if already closed by server", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // First, server transport closes + mockServerTransport.onclose!(); + + // Reset mock to check if close is called again + mockServerTransport.close.mockClear(); + + // Then client transport tries to close + mockClientTransport.onclose!(); + + // Server transport should not be closed again + expect(mockServerTransport.close).toHaveBeenCalledTimes(0); + expect(mockCleanup).toHaveBeenCalledTimes(1); + }); + + it("should not close client transport if already closed by client", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + // First, client transport closes + mockClientTransport.onclose!(); + + // Reset mock to check if close is called again + mockClientTransport.close.mockClear(); + + // Then server transport tries to close + mockServerTransport.onclose!(); + + // Client transport should not be closed again + expect(mockClientTransport.close).toHaveBeenCalledTimes(0); + expect(mockCleanup).toHaveBeenCalledTimes(1); + }); + }); + + describe("error handlers", () => { + it("should set error handlers on both transports", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + expect(mockClientTransport.onerror).toBeTruthy(); + expect(mockServerTransport.onerror).toBeTruthy(); + }); + + it("should handle client errors without throwing", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testError = new Error("Client error"); + expect(() => { + mockClientTransport.onerror!(testError); + }).not.toThrow(); + }); + + it("should handle server errors without throwing", () => { + mcpProxy({ + transportToClient: mockClientTransport as any, + transportToServer: mockServerTransport as any, + onCleanup: mockCleanup, + }); + + const testError = new Error("Server error"); + expect(() => { + mockServerTransport.onerror!(testError); + }).not.toThrow(); + }); + }); +}); \ No newline at end of file diff --git a/server/src/__tests__/transport-cleanup.test.ts b/server/src/__tests__/transport-cleanup.test.ts new file mode 100644 index 000000000..91968f7d4 --- /dev/null +++ b/server/src/__tests__/transport-cleanup.test.ts @@ -0,0 +1,328 @@ +import { jest } from "@jest/globals"; +import request from "supertest"; +import express from "express"; +import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; + +// Mock the SDK modules +jest.mock("@modelcontextprotocol/sdk/server/sse.js"); +jest.mock("@modelcontextprotocol/sdk/client/stdio.js"); +jest.mock("../mcpProxy.js"); + +const MockSSEServerTransport = SSEServerTransport as jest.MockedClass; +const MockStdioClientTransport = StdioClientTransport as jest.MockedClass; + +describe("Transport Cleanup Integration", () => { + let app: express.Application; + let mockWebAppTransports: Map; + let mockServerTransports: Map; + let mockSSETransport: any; + let mockStdioTransport: any; + + beforeEach(() => { + jest.clearAllMocks(); + + // Setup mock transports + mockSSETransport = { + sessionId: "test-session-123", + start: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + onmessage: null, + onclose: null, + onerror: null, + }; + + mockStdioTransport = { + start: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + stderr: { + on: jest.fn(), + }, + onmessage: null, + onclose: null, + onerror: null, + }; + + MockSSEServerTransport.mockImplementation(() => mockSSETransport); + MockStdioClientTransport.mockImplementation(() => mockStdioTransport); + + // Setup Express app with transport maps (similar to actual server) + app = express(); + app.use(express.json()); + + mockWebAppTransports = new Map(); + mockServerTransports = new Map(); + + // Mock the actual cleanup logic + const setupCleanupHandlers = (webAppTransport: any, serverTransport: any) => { + const cleanup = () => { + mockWebAppTransports.delete(webAppTransport.sessionId); + mockServerTransports.delete(webAppTransport.sessionId); + }; + + // Simulate the mcpProxy cleanup behavior + webAppTransport.onclose = cleanup; + serverTransport.onclose = cleanup; + + return cleanup; + }; + + // STDIO route handler (simplified) + app.get("/stdio", async (req, res) => { + try { + const serverTransport = new MockStdioClientTransport({} as any); + await serverTransport.start(); + + const webAppTransport = new MockSSEServerTransport("/message", res as any); + await webAppTransport.start(); + + // Add to maps + mockWebAppTransports.set(webAppTransport.sessionId, webAppTransport); + mockServerTransports.set(webAppTransport.sessionId, serverTransport); + + // Setup cleanup + setupCleanupHandlers(webAppTransport, serverTransport); + + res.status(200).json({ sessionId: webAppTransport.sessionId }); + } catch (error) { + res.status(500).json({ error: error.message }); + } + }); + + // Test endpoint to check transport state + app.get("/test/transports", (req, res) => { + res.json({ + webAppTransports: Array.from(mockWebAppTransports.keys()), + serverTransports: Array.from(mockServerTransports.keys()), + }); + }); + + // Test endpoint to simulate transport close + app.post("/test/close/:sessionId", (req, res) => { + const sessionId = req.params.sessionId; + const webAppTransport = mockWebAppTransports.get(sessionId); + + if (webAppTransport && webAppTransport.onclose) { + webAppTransport.onclose(); + res.json({ closed: true }); + } else { + res.status(404).json({ error: "Transport not found" }); + } + }); + }); + + describe("STDIO Transport Cleanup", () => { + it("should create and track STDIO transports", async () => { + const response = await request(app) + .get("/stdio") + .expect(200); + + expect(response.body.sessionId).toBe("test-session-123"); + + // Check that transports are tracked + const transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).toContain("test-session-123"); + expect(transportsResponse.body.serverTransports).toContain("test-session-123"); + }); + + it("should clean up transports when connection closes", async () => { + // First create the connection + await request(app) + .get("/stdio") + .expect(200); + + // Verify transports are tracked + let transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).toContain("test-session-123"); + expect(transportsResponse.body.serverTransports).toContain("test-session-123"); + + // Simulate connection close + await request(app) + .post("/test/close/test-session-123") + .expect(200); + + // Verify transports are cleaned up + transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).not.toContain("test-session-123"); + expect(transportsResponse.body.serverTransports).not.toContain("test-session-123"); + }); + + it("should handle multiple concurrent connections", async () => { + // Create multiple connections + const sessions = ["session-1", "session-2", "session-3"]; + + for (const sessionId of sessions) { + mockSSETransport.sessionId = sessionId; + await request(app) + .get("/stdio") + .expect(200); + } + + // Check all are tracked + const transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + for (const sessionId of sessions) { + expect(transportsResponse.body.webAppTransports).toContain(sessionId); + expect(transportsResponse.body.serverTransports).toContain(sessionId); + } + }); + + it("should handle cleanup of non-existent session gracefully", async () => { + await request(app) + .post("/test/close/non-existent-session") + .expect(404); + + // Should not affect existing transports + const transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body).toEqual({ + webAppTransports: [], + serverTransports: [], + }); + }); + }); + + describe("Transport State Management", () => { + it("should maintain consistent state between webApp and server transport maps", async () => { + // Create connection + await request(app) + .get("/stdio") + .expect(200); + + let transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).toEqual(["test-session-123"]); + expect(transportsResponse.body.serverTransports).toEqual(["test-session-123"]); + + // Close connection + await request(app) + .post("/test/close/test-session-123") + .expect(200); + + transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + // Both maps should be empty + expect(transportsResponse.body.webAppTransports).toEqual([]); + expect(transportsResponse.body.serverTransports).toEqual([]); + }); + + it("should handle rapid connect/disconnect cycles", async () => { + const sessionId = "rapid-test-session"; + + for (let i = 0; i < 5; i++) { + // Connect + mockSSETransport.sessionId = sessionId; + await request(app) + .get("/stdio") + .expect(200); + + // Verify connected + let transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).toContain(sessionId); + expect(transportsResponse.body.serverTransports).toContain(sessionId); + + // Disconnect + await request(app) + .post(`/test/close/${sessionId}`) + .expect(200); + + // Verify disconnected + transportsResponse = await request(app) + .get("/test/transports") + .expect(200); + + expect(transportsResponse.body.webAppTransports).not.toContain(sessionId); + expect(transportsResponse.body.serverTransports).not.toContain(sessionId); + } + }); + }); +}); + +// Test for the actual issue scenario +describe("STDIO Server Restart Issue", () => { + let mockWebAppTransports: Map; + let mockServerTransports: Map; + + beforeEach(() => { + mockWebAppTransports = new Map(); + mockServerTransports = new Map(); + }); + + it("should demonstrate the problem without cleanup", () => { + const sessionId = "problem-session"; + const mockTransport = { + sessionId, + send: jest.fn().mockRejectedValue(new Error("Not connected")), + close: jest.fn(), + }; + + // Add transport to maps (simulating connection) + mockWebAppTransports.set(sessionId, mockTransport); + mockServerTransports.set(sessionId, mockTransport); + + // Simulate disconnect without cleanup (the bug) + // Transport references remain in maps + + expect(mockWebAppTransports.has(sessionId)).toBe(true); + expect(mockServerTransports.has(sessionId)).toBe(true); + + // Attempt to use stale transport (this would cause "Not connected" error) + const staleTransport = mockWebAppTransports.get(sessionId); + expect(staleTransport.send()).rejects.toThrow("Not connected"); + }); + + it("should demonstrate the fix with proper cleanup", () => { + const sessionId = "fixed-session"; + const mockTransport = { + sessionId, + send: jest.fn(), + close: jest.fn(), + onclose: null as (() => void) | null, + }; + + // Add transport to maps (simulating connection) + mockWebAppTransports.set(sessionId, mockTransport); + mockServerTransports.set(sessionId, mockTransport); + + // Setup cleanup handler (the fix) + const cleanup = () => { + mockWebAppTransports.delete(sessionId); + mockServerTransports.delete(sessionId); + }; + + mockTransport.onclose = cleanup; + + // Verify transport is tracked + expect(mockWebAppTransports.has(sessionId)).toBe(true); + expect(mockServerTransports.has(sessionId)).toBe(true); + + // Simulate disconnect with cleanup + mockTransport.onclose(); + + // Verify cleanup worked + expect(mockWebAppTransports.has(sessionId)).toBe(false); + expect(mockServerTransports.has(sessionId)).toBe(false); + }); +}); \ No newline at end of file diff --git a/server/src/index.ts b/server/src/index.ts index c0fb3797a..3e5179fde 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -300,9 +300,18 @@ app.post( await webAppTransport.start(); + const cleanup = () => { + if (webAppTransport.sessionId) { + webAppTransports.delete(webAppTransport.sessionId); + serverTransports.delete(webAppTransport.sessionId); + console.log(`Transports cleaned up for sessionId ${webAppTransport.sessionId}`); + } + }; + mcpProxy({ transportToClient: webAppTransport, transportToServer: serverTransport, + onCleanup: cleanup, }); await (webAppTransport as StreamableHTTPServerTransport).handleRequest( @@ -470,9 +479,16 @@ app.get( } }); + const cleanup = () => { + webAppTransports.delete(webAppTransport.sessionId); + serverTransports.delete(webAppTransport.sessionId); + console.log(`Transports cleaned up for sessionId ${webAppTransport.sessionId}`); + }; + mcpProxy({ transportToClient: webAppTransport, transportToServer: serverTransport, + onCleanup: cleanup, }); } catch (error) { console.error("Error in /stdio route:", error); @@ -528,9 +544,16 @@ app.get( await webAppTransport.start(); + const cleanup = () => { + webAppTransports.delete(webAppTransport.sessionId); + serverTransports.delete(webAppTransport.sessionId); + console.log(`Transports cleaned up for sessionId ${webAppTransport.sessionId}`); + }; + mcpProxy({ transportToClient: webAppTransport, transportToServer: serverTransport, + onCleanup: cleanup, }); } } catch (error) { diff --git a/server/src/mcpProxy.ts b/server/src/mcpProxy.ts index 664f17119..c5d3f107a 100644 --- a/server/src/mcpProxy.ts +++ b/server/src/mcpProxy.ts @@ -18,9 +18,11 @@ function onServerError(error: Error) { export default function mcpProxy({ transportToClient, transportToServer, + onCleanup, }: { transportToClient: Transport; transportToServer: Transport; + onCleanup?: () => void; }) { let transportToClientClosed = false; let transportToServerClosed = false; @@ -65,6 +67,7 @@ export default function mcpProxy({ transportToClientClosed = true; transportToServer.close().catch(onServerError); + onCleanup?.(); }; transportToServer.onclose = () => { @@ -73,6 +76,7 @@ export default function mcpProxy({ } transportToServerClosed = true; transportToClient.close().catch(onClientError); + onCleanup?.(); }; transportToClient.onerror = onClientError; diff --git a/server/tsconfig.json b/server/tsconfig.json index b5a92612a..1ad6061cf 100644 --- a/server/tsconfig.json +++ b/server/tsconfig.json @@ -12,5 +12,5 @@ "resolveJsonModule": true }, "include": ["src/**/*"], - "exclude": ["node_modules", "packages", "**/*.spec.ts"] + "exclude": ["node_modules", "packages", "**/*.spec.ts", "**/*.test.ts", "**/__tests__/**/*"] }