Skip to content

Fix stdio encoding issue: Enforce explicit UTF-8 for correct Unicode handling #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/ModelContextProtocol/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ internal static partial class Log
internal static partial void TransportNotConnected(this ILogger logger, string endpointName);

[LoggerMessage(Level = LogLevel.Information, Message = "Transport sending message for {endpointName} with ID {messageId}, JSON {json}")]
internal static partial void TransportSendingMessage(this ILogger logger, string endpointName, string messageId, string json);
internal static partial void TransportSendingMessage(this ILogger logger, string endpointName, string messageId, string? json = null);

[LoggerMessage(Level = LogLevel.Information, Message = "Transport message sent for {endpointName} with ID {messageId}")]
internal static partial void TransportSentMessage(this ILogger logger, string endpointName, string messageId);
Expand Down Expand Up @@ -347,4 +347,35 @@ public static partial void SSETransportPostNotAccepted(
string endpointName,
string messageId,
string responseContent);

/// <summary>
/// Logs the byte representation of a message in UTF-8 encoding.
/// </summary>
/// <param name="logger">The logger to use.</param>
/// <param name="endpointName">The name of the endpoint.</param>
/// <param name="byteRepresentation">The byte representation as a hex string.</param>
[LoggerMessage(EventId = 39000, Level = LogLevel.Trace, Message = "Transport {EndpointName}: Message bytes (UTF-8): {ByteRepresentation}")]
private static partial void TransportMessageBytes(this ILogger logger, string endpointName, string byteRepresentation);

/// <summary>
/// Logs the byte representation of a message for diagnostic purposes.
/// This is useful for diagnosing encoding issues with non-ASCII characters.
/// </summary>
/// <param name="logger">The logger to use.</param>
/// <param name="endpointName">The name of the endpoint.</param>
/// <param name="message">The message to log bytes for.</param>
internal static void TransportMessageBytesUtf8(this ILogger logger, string endpointName, string message)
{
if (logger.IsEnabled(LogLevel.Trace))
{
var bytes = System.Text.Encoding.UTF8.GetBytes(message);
var byteRepresentation =
#if NET
Convert.ToHexString(bytes);
#else
BitConverter.ToString(bytes).Replace("-", " ");
#endif
logger.TransportMessageBytes(endpointName, byteRepresentation);
}
}
}
84 changes: 57 additions & 27 deletions src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using System.Diagnostics;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Configuration;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Diagnostics;
using System.Text;
using System.Text.Json;

namespace ModelContextProtocol.Protocol.Transport;

Expand Down Expand Up @@ -59,6 +60,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)

_shutdownCts = new CancellationTokenSource();

UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false);

var startInfo = new ProcessStartInfo
{
FileName = _options.Command,
Expand All @@ -68,6 +71,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
UseShellExecute = false,
CreateNoWindow = true,
WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
StandardOutputEncoding = noBomUTF8,
StandardErrorEncoding = noBomUTF8,
#if NET
StandardInputEncoding = noBomUTF8,
#endif
};

if (!string.IsNullOrWhiteSpace(_options.Arguments))
Expand All @@ -92,13 +100,35 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
// Set up error logging
_process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)");

if (!_process.Start())
// We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core,
// we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but
// StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks
// up the encoding from Console.InputEncoding. As such, when not targeting .NET Core,
// we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start
// call, to ensure it picks up the correct encoding.
#if NET
_processStarted = _process.Start();
#else
Encoding originalInputEncoding = Console.InputEncoding;
try
{
Console.InputEncoding = noBomUTF8;
_processStarted = _process.Start();
}
finally
{
Console.InputEncoding = originalInputEncoding;
}
#endif

if (!_processStarted)
{
_logger.TransportProcessStartFailed(EndpointName);
throw new McpTransportException("Failed to start MCP server process");
}

_logger.TransportProcessStarted(EndpointName, _process.Id);
_processStarted = true;

_process.BeginErrorReadLine();

// Start reading messages in the background
Expand Down Expand Up @@ -134,9 +164,10 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
{
var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.TransportSendingMessage(EndpointName, id, json);
_logger.TransportMessageBytesUtf8(EndpointName, json);

// Write the message followed by a newline
await _process!.StandardInput.WriteLineAsync(json.AsMemory(), cancellationToken).ConfigureAwait(false);
// Write the message followed by a newline using our UTF-8 writer
await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false);
await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false);

_logger.TransportSentMessage(EndpointName, id);
Expand All @@ -161,12 +192,10 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
{
_logger.TransportEnteringReadMessagesLoop(EndpointName);

using var reader = _process!.StandardOutput;

while (!cancellationToken.IsCancellationRequested && !_process.HasExited)
while (!cancellationToken.IsCancellationRequested && !_process!.HasExited)
{
_logger.TransportWaitingForMessage(EndpointName);
var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false);
var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false);
if (line == null)
{
_logger.TransportEndOfStream(EndpointName);
Expand All @@ -179,6 +208,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}

_logger.TransportReceivedMessage(EndpointName, line);
_logger.TransportMessageBytesUtf8(EndpointName, line);

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
}
Expand Down Expand Up @@ -230,28 +260,27 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
private async Task CleanupAsync(CancellationToken cancellationToken)
{
_logger.TransportCleaningUp(EndpointName);
if (_process != null && _processStarted && !_process.HasExited)

if (_process is Process process && _processStarted && !process.HasExited)
{
try
{
// Try to close stdin to signal the process to exit
_logger.TransportClosingStdin(EndpointName);
_process.StandardInput.Close();

// Wait for the process to exit
_logger.TransportWaitingForShutdown(EndpointName);

// Kill the while process tree because the process may spawn child processes
// and Node.js does not kill its children when it exits properly
_process.KillTree(_options.ShutdownTimeout);
process.KillTree(_options.ShutdownTimeout);
}
catch (Exception ex)
{
_logger.TransportShutdownFailed(EndpointName, ex);
}

_process.Dispose();
_process = null;
finally
{
process.Dispose();
_process = null;
}
}

if (_shutdownCts is { } shutdownCts)
Expand All @@ -261,29 +290,30 @@ private async Task CleanupAsync(CancellationToken cancellationToken)
_shutdownCts = null;
}

if (_readTask != null)
if (_readTask is Task readTask)
{
try
{
_logger.TransportWaitingForReadTask(EndpointName);
await _readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
}
catch (TimeoutException)
{
_logger.TransportCleanupReadTaskTimeout(EndpointName);
// Continue with cleanup
}
catch (OperationCanceledException)
{
_logger.TransportCleanupReadTaskCancelled(EndpointName);
// Ignore cancellation
}
catch (Exception ex)
{
_logger.TransportCleanupReadTaskFailed(EndpointName, ex);
}
_readTask = null;
_logger.TransportReadTaskCleanedUp(EndpointName);
finally
{
_logger.TransportReadTaskCleanedUp(EndpointName);
_readTask = null;
}
}

SetConnected(false);
Expand Down
81 changes: 63 additions & 18 deletions src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using System.Text;
using System.Text.Json;

namespace ModelContextProtocol.Protocol.Transport;

Expand All @@ -15,12 +16,14 @@ namespace ModelContextProtocol.Protocol.Transport;
/// </summary>
public sealed class StdioServerTransport : TransportBase, IServerTransport
{
private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();

private readonly string _serverName;
private readonly ILogger _logger;

private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions;
private readonly TextReader _stdin = Console.In;
private readonly TextWriter _stdout = Console.Out;
private readonly TextReader _stdInReader;
private readonly Stream _stdOutStream;

private Task? _readTask;
private CancellationTokenSource? _shutdownCts;
Expand Down Expand Up @@ -83,16 +86,50 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n

_serverName = serverName;
_logger = (ILogger?)loggerFactory?.CreateLogger<StdioClientTransport>() ?? NullLogger.Instance;

// Get raw console streams and wrap them with UTF-8 encoding
_stdInReader = new StreamReader(Console.OpenStandardInput(), Encoding.UTF8);
_stdOutStream = new BufferedStream(Console.OpenStandardOutput());
}

/// <summary>
/// Initializes a new instance of the <see cref="StdioServerTransport"/> class with explicit input/output streams.
/// </summary>
/// <param name="serverName">The name of the server.</param>
/// <param name="stdinStream">The input TextReader to use.</param>
/// <param name="stdoutStream">The output TextWriter to use.</param>
/// <param name="loggerFactory">Optional logger factory used for logging employed by the transport.</param>
/// <exception cref="ArgumentNullException"><paramref name="serverName"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// This constructor is useful for testing scenarios where you want to redirect input/output.
/// </para>
/// </remarks>
public StdioServerTransport(string serverName, Stream stdinStream, Stream stdoutStream, ILoggerFactory? loggerFactory = null)
: base(loggerFactory)
{
Throw.IfNull(serverName);
Throw.IfNull(stdinStream);
Throw.IfNull(stdoutStream);

_serverName = serverName;
_logger = (ILogger?)loggerFactory?.CreateLogger<StdioClientTransport>() ?? NullLogger.Instance;

_stdInReader = new StreamReader(stdinStream, Encoding.UTF8);
_stdOutStream = stdoutStream;
}

/// <inheritdoc/>
public Task StartListeningAsync(CancellationToken cancellationToken = default)
{
_logger.LogDebug("Starting StdioServerTransport listener for {EndpointName}", EndpointName);

_shutdownCts = new CancellationTokenSource();

_readTask = Task.Run(async () => await ReadMessagesAsync(_shutdownCts.Token).ConfigureAwait(false), CancellationToken.None);

SetConnected(true);
_logger.LogDebug("StdioServerTransport now connected for {EndpointName}", EndpointName);

return Task.CompletedTask;
}
Expand All @@ -114,11 +151,11 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio

try
{
var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.TransportSendingMessage(EndpointName, id, json);
_logger.TransportSendingMessage(EndpointName, id);

await _stdout.WriteLineAsync(json.AsMemory(), cancellationToken).ConfigureAwait(false);
await _stdout.FlushAsync(cancellationToken).ConfigureAwait(false);
await JsonSerializer.SerializeAsync(_stdOutStream, message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
await _stdOutStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _stdOutStream.FlushAsync(cancellationToken).ConfigureAwait(false);;

_logger.TransportSentMessage(EndpointName, id);
}
Expand Down Expand Up @@ -146,7 +183,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
{
_logger.TransportWaitingForMessage(EndpointName);

var reader = _stdin;
var reader = _stdInReader;
var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false);
if (line == null)
{
Expand All @@ -160,6 +197,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}

_logger.TransportReceivedMessage(EndpointName, line);
_logger.TransportMessageBytesUtf8(EndpointName, line);

try
{
Expand Down Expand Up @@ -207,19 +245,20 @@ private async Task CleanupAsync(CancellationToken cancellationToken)
{
_logger.TransportCleaningUp(EndpointName);

if (_shutdownCts != null)
if (_shutdownCts is { } shutdownCts)
{
await _shutdownCts.CancelAsync().ConfigureAwait(false);
_shutdownCts.Dispose();
await shutdownCts.CancelAsync().ConfigureAwait(false);
shutdownCts.Dispose();

_shutdownCts = null;
}

if (_readTask != null)
if (_readTask is { } readTask)
{
try
{
_logger.TransportWaitingForReadTask(EndpointName);
await _readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
}
catch (TimeoutException)
{
Expand All @@ -235,10 +274,16 @@ private async Task CleanupAsync(CancellationToken cancellationToken)
{
_logger.TransportCleanupReadTaskFailed(EndpointName, ex);
}
_readTask = null;
_logger.TransportReadTaskCleanedUp(EndpointName);
finally
{
_logger.TransportReadTaskCleanedUp(EndpointName);
_readTask = null;
}
}

_stdInReader?.Dispose();
_stdOutStream?.Dispose();

SetConnected(false);
_logger.TransportCleanedUp(EndpointName);
}
Expand Down
Loading