Skip to content

added the configurable items sse_read_timeout and headers to mcp-client #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 84 additions & 48 deletions client/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,59 @@ import (
// while sending requests over regular HTTP POST calls. The client handles
// automatic reconnection and message routing between requests and responses.
type SSEMCPClient struct {
baseURL *url.URL
endpoint *url.URL
httpClient *http.Client
requestID atomic.Int64
responses map[int64]chan RPCResponse
mu sync.RWMutex
done chan struct{}
initialized bool
notifications []func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
endpointChan chan struct{}
capabilities mcp.ServerCapabilities
baseURL *url.URL
endpoint *url.URL
httpClient *http.Client
requestID atomic.Int64
responses map[int64]chan RPCResponse
mu sync.RWMutex
done chan struct{}
initialized bool
notifications []func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
endpointChan chan struct{}
capabilities mcp.ServerCapabilities
headers map[string]string
sseReadTimeout time.Duration
}

type ClientOption func(*SSEMCPClient)

func WithHeaders(headers map[string]string) ClientOption {
return func(sc *SSEMCPClient) {
sc.headers = headers
}
}

func WithSSEReadTimeout(timeout time.Duration) ClientOption {
return func(sc *SSEMCPClient) {
sc.sseReadTimeout = timeout
}
}

// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
// Returns an error if the URL is invalid.
func NewSSEMCPClient(baseURL string) (*SSEMCPClient, error) {
func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) {
parsedURL, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}

return &SSEMCPClient{
baseURL: parsedURL,
httpClient: &http.Client{},
responses: make(map[int64]chan RPCResponse),
done: make(chan struct{}),
endpointChan: make(chan struct{}),
}, nil
smc := &SSEMCPClient{
baseURL: parsedURL,
httpClient: &http.Client{},
responses: make(map[int64]chan RPCResponse),
done: make(chan struct{}),
endpointChan: make(chan struct{}),
sseReadTimeout: 30 * time.Second,
headers: make(map[string]string),
}

for _, opt := range options {
opt(smc)
}

return smc, nil
}

// Start initiates the SSE connection to the server and waits for the endpoint information.
Expand Down Expand Up @@ -104,41 +128,49 @@ func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
br := bufio.NewReader(reader)
var event, data string

ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
defer cancel()

for {
line, err := br.ReadString('\n')
if err != nil {
if err == io.EOF {
// Process any pending event before exit
select {
case <-ctx.Done():
return
default:
line, err := br.ReadString('\n')
if err != nil {
if err == io.EOF {
// Process any pending event before exit
if event != "" && data != "" {
c.handleSSEEvent(event, data)
}
break
}
select {
case <-c.done:
return
default:
fmt.Printf("SSE stream error: %v\n", err)
return
}
}

// Remove only newline markers
line = strings.TrimRight(line, "\r\n")
if line == "" {
// Empty line means end of event
if event != "" && data != "" {
c.handleSSEEvent(event, data)
event = ""
data = ""
}
break
}
select {
case <-c.done:
return
default:
fmt.Printf("SSE stream error: %v\n", err)
return
continue
}
}

// Remove only newline markers
line = strings.TrimRight(line, "\r\n")
if line == "" {
// Empty line means end of event
if event != "" && data != "" {
c.handleSSEEvent(event, data)
event = ""
data = ""
if strings.HasPrefix(line, "event:") {
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
continue
}

if strings.HasPrefix(line, "event:") {
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
}
}
Expand Down Expand Up @@ -269,6 +301,10 @@ func (c *SSEMCPClient) sendRequest(
}

req.Header.Set("Content-Type", "application/json")
// set custom http headers
for k, v := range c.headers {
req.Header.Set(k, v)
}

resp, err := c.httpClient.Do(req)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (t Tool) MarshalJSON() ([]byte, error) {

type ToolInputSchema struct {
Type string `json:"type"`
Properties map[string]interface{} `json:"properties,omitempty"`
Properties map[string]interface{} `json:"properties"`
Required []string `json:"required,omitempty"`
}

Expand Down