diff --git a/src/Middleware/Diagnostics/src/StatusCodePage/StatusCodePagesMiddleware.cs b/src/Middleware/Diagnostics/src/StatusCodePage/StatusCodePagesMiddleware.cs index 8c7ff70a8ca6..87994f6c7011 100644 --- a/src/Middleware/Diagnostics/src/StatusCodePage/StatusCodePagesMiddleware.cs +++ b/src/Middleware/Diagnostics/src/StatusCodePage/StatusCodePagesMiddleware.cs @@ -41,9 +41,9 @@ public async Task Invoke(HttpContext context) var statusCodeFeature = new StatusCodePagesFeature(); context.Features.Set(statusCodeFeature); var endpoint = context.GetEndpoint(); - var skipStatusCodePageMetadata = endpoint?.Metadata.GetMetadata(); + var shouldCheckEndpointAgain = endpoint is null; - if (skipStatusCodePageMetadata is not null) + if (HasSkipStatusCodePagesMetadata(endpoint)) { statusCodeFeature.Enabled = false; } @@ -57,6 +57,12 @@ public async Task Invoke(HttpContext context) return; } + if (shouldCheckEndpointAgain && HasSkipStatusCodePagesMetadata(context.GetEndpoint())) + { + // If the endpoint was null check the endpoint again since it could have been set by another middleware. + return; + } + // Do nothing if a response body has already been provided. if (context.Response.HasStarted || context.Response.StatusCode < 400 @@ -70,4 +76,11 @@ public async Task Invoke(HttpContext context) var statusCodeContext = new StatusCodeContext(context, _options, _next); await _options.HandleAsync(statusCodeContext); } + + private static bool HasSkipStatusCodePagesMetadata(Endpoint? endpoint) + { + var skipStatusCodePageMetadata = endpoint?.Metadata.GetMetadata(); + + return skipStatusCodePageMetadata is not null; + } } diff --git a/src/Middleware/Diagnostics/test/UnitTests/StatusCodeMiddlewareTest.cs b/src/Middleware/Diagnostics/test/UnitTests/StatusCodeMiddlewareTest.cs index 9d771df6f90e..52214d97ff90 100644 --- a/src/Middleware/Diagnostics/test/UnitTests/StatusCodeMiddlewareTest.cs +++ b/src/Middleware/Diagnostics/test/UnitTests/StatusCodeMiddlewareTest.cs @@ -313,4 +313,80 @@ public async Task SkipStatusCodePages_SupportsEndpoints() var content = await response.Content.ReadAsStringAsync(); Assert.Empty(content); } + + [Fact] + public async Task SkipStatusCodePages_SupportsSkipIfUsedBeforeRouting() + { + using var host = new HostBuilder() + .ConfigureWebHost(builder => + { + builder.UseTestServer() + .ConfigureServices(services => services.AddRouting()) + .Configure(app => + { + app.UseStatusCodePagesWithReExecute("/status"); + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/skip", [SkipStatusCodePages](c) => + { + c.Response.StatusCode = 400; + return Task.CompletedTask; + }); + + endpoints.MapGet("/status", (HttpResponse response) => $"Status: {response.StatusCode}"); + }); + + app.Run(_ => throw new InvalidOperationException("Invalid input provided.")); + }); + }).Build(); + + await host.StartAsync(); + + using var server = host.GetTestServer(); + var client = server.CreateClient(); + var response = await client.GetAsync("/skip"); + var content = await response.Content.ReadAsStringAsync(); + + Assert.Empty(content); + } + + [Fact] + public async Task SkipStatusCodePages_WorksIfUsedBeforeRouting() + { + using var host = new HostBuilder() + .ConfigureWebHost(builder => + { + builder.UseTestServer() + .ConfigureServices(services => services.AddRouting()) + .Configure(app => + { + app.UseStatusCodePagesWithReExecute("/status"); + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/", (c) => + { + c.Response.StatusCode = 400; + return Task.CompletedTask; + }); + + endpoints.MapGet("/status", (HttpResponse response) => $"Status: {response.StatusCode}"); + }); + + app.Run(_ => throw new InvalidOperationException("Invalid input provided.")); + }); + }).Build(); + + await host.StartAsync(); + + using var server = host.GetTestServer(); + var client = server.CreateClient(); + var response = await client.GetAsync("/"); + var content = await response.Content.ReadAsStringAsync(); + + Assert.Equal("Status: 400", content); + } }