diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index ab30b6aecfc8..80afaac46aad 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -558,6 +558,12 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, var feature = httpContext.Features.Get(); if (feature?.CanHaveBody == true) { + if (!httpContext.Request.HasJsonContentType()) + { + Log.UnexpectedContentType(httpContext, httpContext.Request.ContentType); + httpContext.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType; + return; + } try { bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); @@ -590,6 +596,12 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, var feature = httpContext.Features.Get(); if (feature?.CanHaveBody == true) { + if (!httpContext.Request.HasJsonContentType()) + { + Log.UnexpectedContentType(httpContext, httpContext.Request.ContentType); + httpContext.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType; + return; + } try { bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); @@ -603,7 +615,7 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, { Log.RequestBodyInvalidDataException(httpContext, ex, factoryContext.ThrowOnBadRequest); - httpContext.Response.StatusCode = 400; + httpContext.Response.StatusCode = StatusCodes.Status400BadRequest; return; } } @@ -1204,6 +1216,14 @@ public static void RequiredParameterNotProvided(HttpContext httpContext, string [LoggerMessage(4, LogLevel.Debug, RequiredParameterNotProvidedLogMessage, EventName = "RequiredParameterNotProvided")] private static partial void RequiredParameterNotProvided(ILogger logger, string parameterType, string parameterName, string source); + public static void UnexpectedContentType(HttpContext httpContext, string? contentType) + => UnexpectedContentType(GetLogger(httpContext), contentType ?? "(none)"); + + [LoggerMessage(6, LogLevel.Debug, + "Expected a supported JSON media type but got \"{ContentType}\".", + EventName = "UnexpectedContentType")] + private static partial void UnexpectedContentType(ILogger logger, string contentType); + private static ILogger GetLogger(HttpContext httpContext) { var loggerFactory = httpContext.RequestServices.GetRequiredService(); diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 9f9fd442589b..b35764bf4490 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -2360,7 +2360,65 @@ public async Task CanExecuteRequestDelegateWithResultsExtension() Assert.False(httpContext.RequestAborted.IsCancellationRequested); var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); Assert.Equal(@"""Hello Tester. This is from an extension method.""", decodedResponseBody); + } + + [Fact] + public async Task RequestDelegateRejectsNonJsonContent() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Headers["Content-Type"] = "application/xml"; + httpContext.Request.Headers["Content-Length"] = "1"; + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var factoryResult = RequestDelegateFactory.Create((HttpContext context, Todo todo) => + { + }); + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + + Assert.Equal(415, httpContext.Response.StatusCode); + var logMessage = Assert.Single(TestSink.Writes); + Assert.Equal(new EventId(6, "UnexpectedContentType"), logMessage.EventId); + Assert.Equal(LogLevel.Debug, logMessage.LogLevel); + } + + [Fact] + public async Task RequestDelegateWithBindAndImplicitBodyRejectsNonJsonContent() + { + Todo originalTodo = new() + { + Name = "Write more tests!" + }; + + var httpContext = new DefaultHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "application/xml"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + + var factoryResult = RequestDelegateFactory.Create((HttpContext context, JsonTodo customTodo, Todo todo) => + { + }); + var requestDelegate = factoryResult.RequestDelegate; + + await requestDelegate(httpContext); + Assert.Equal(415, httpContext.Response.StatusCode); + var logMessage = Assert.Single(TestSink.Writes); + Assert.Equal(new EventId(6, "UnexpectedContentType"), logMessage.EventId); + Assert.Equal(LogLevel.Debug, logMessage.LogLevel); } private DefaultHttpContext CreateHttpContext() @@ -2399,6 +2457,17 @@ private class CustomTodo : Todo } } + private class JsonTodo : Todo + { + public static async ValueTask BindAsync(HttpContext context, ParameterInfo parameter) + { + // manually call deserialize so we don't check content type + var body = await JsonSerializer.DeserializeAsync(context.Request.Body); + context.Request.Body.Position = 0; + return body; + } + } + private record struct TodoStruct(int Id, string? Name, bool IsComplete) : ITodo; private interface ITodo