Skip to content

Commit 1f62c43

Browse files
Support SkipStatusCodePages when not using WebApplication builder and UseStatusCodePagesWithReExecute (#46109)
1 parent 83d6c56 commit 1f62c43

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

src/Middleware/Diagnostics/src/StatusCodePage/StatusCodePagesMiddleware.cs

+15-2
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ public async Task Invoke(HttpContext context)
4141
var statusCodeFeature = new StatusCodePagesFeature();
4242
context.Features.Set<IStatusCodePagesFeature>(statusCodeFeature);
4343
var endpoint = context.GetEndpoint();
44-
var skipStatusCodePageMetadata = endpoint?.Metadata.GetMetadata<ISkipStatusCodePagesMetadata>();
44+
var shouldCheckEndpointAgain = endpoint is null;
4545

46-
if (skipStatusCodePageMetadata is not null)
46+
if (HasSkipStatusCodePagesMetadata(endpoint))
4747
{
4848
statusCodeFeature.Enabled = false;
4949
}
@@ -57,6 +57,12 @@ public async Task Invoke(HttpContext context)
5757
return;
5858
}
5959

60+
if (shouldCheckEndpointAgain && HasSkipStatusCodePagesMetadata(context.GetEndpoint()))
61+
{
62+
// If the endpoint was null check the endpoint again since it could have been set by another middleware.
63+
return;
64+
}
65+
6066
// Do nothing if a response body has already been provided.
6167
if (context.Response.HasStarted
6268
|| context.Response.StatusCode < 400
@@ -70,4 +76,11 @@ public async Task Invoke(HttpContext context)
7076
var statusCodeContext = new StatusCodeContext(context, _options, _next);
7177
await _options.HandleAsync(statusCodeContext);
7278
}
79+
80+
private static bool HasSkipStatusCodePagesMetadata(Endpoint? endpoint)
81+
{
82+
var skipStatusCodePageMetadata = endpoint?.Metadata.GetMetadata<ISkipStatusCodePagesMetadata>();
83+
84+
return skipStatusCodePageMetadata is not null;
85+
}
7386
}

src/Middleware/Diagnostics/test/UnitTests/StatusCodeMiddlewareTest.cs

+76
Original file line numberDiff line numberDiff line change
@@ -313,4 +313,80 @@ public async Task SkipStatusCodePages_SupportsEndpoints()
313313
var content = await response.Content.ReadAsStringAsync();
314314
Assert.Empty(content);
315315
}
316+
317+
[Fact]
318+
public async Task SkipStatusCodePages_SupportsSkipIfUsedBeforeRouting()
319+
{
320+
using var host = new HostBuilder()
321+
.ConfigureWebHost(builder =>
322+
{
323+
builder.UseTestServer()
324+
.ConfigureServices(services => services.AddRouting())
325+
.Configure(app =>
326+
{
327+
app.UseStatusCodePagesWithReExecute("/status");
328+
app.UseRouting();
329+
330+
app.UseEndpoints(endpoints =>
331+
{
332+
endpoints.MapGet("/skip", [SkipStatusCodePages](c) =>
333+
{
334+
c.Response.StatusCode = 400;
335+
return Task.CompletedTask;
336+
});
337+
338+
endpoints.MapGet("/status", (HttpResponse response) => $"Status: {response.StatusCode}");
339+
});
340+
341+
app.Run(_ => throw new InvalidOperationException("Invalid input provided."));
342+
});
343+
}).Build();
344+
345+
await host.StartAsync();
346+
347+
using var server = host.GetTestServer();
348+
var client = server.CreateClient();
349+
var response = await client.GetAsync("/skip");
350+
var content = await response.Content.ReadAsStringAsync();
351+
352+
Assert.Empty(content);
353+
}
354+
355+
[Fact]
356+
public async Task SkipStatusCodePages_WorksIfUsedBeforeRouting()
357+
{
358+
using var host = new HostBuilder()
359+
.ConfigureWebHost(builder =>
360+
{
361+
builder.UseTestServer()
362+
.ConfigureServices(services => services.AddRouting())
363+
.Configure(app =>
364+
{
365+
app.UseStatusCodePagesWithReExecute("/status");
366+
app.UseRouting();
367+
368+
app.UseEndpoints(endpoints =>
369+
{
370+
endpoints.MapGet("/", (c) =>
371+
{
372+
c.Response.StatusCode = 400;
373+
return Task.CompletedTask;
374+
});
375+
376+
endpoints.MapGet("/status", (HttpResponse response) => $"Status: {response.StatusCode}");
377+
});
378+
379+
app.Run(_ => throw new InvalidOperationException("Invalid input provided."));
380+
});
381+
}).Build();
382+
383+
await host.StartAsync();
384+
385+
using var server = host.GetTestServer();
386+
var client = server.CreateClient();
387+
var response = await client.GetAsync("/");
388+
var content = await response.Content.ReadAsStringAsync();
389+
390+
Assert.Equal("Status: 400", content);
391+
}
316392
}

0 commit comments

Comments
 (0)