Skip to content

Commit 3039dc4

Browse files
authored
Support all return types from handler in filters (#41310)
* Support all return types from handler in filters * Address feedback from peer review * Fix up Task and ValueTask handling * ExecuteAwait to ExecuteAwaited * Actually await void-returning Tasks and ValueTasks * Update comment for new logic * Polish and update tests * Tweaks and test fixes * Tweak await of void-returning Task * Fix correct ExecuteTask invocation
1 parent 46ebda5 commit 3039dc4

File tree

2 files changed

+437
-15
lines changed

2 files changed

+437
-15
lines changed

src/Http/Http.Extensions/src/RequestDelegateFactory.cs

Lines changed: 152 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ public static partial class RequestDelegateFactory
3333
{
3434
private static readonly ParameterBindingMethodCache ParameterBindingMethodCache = new();
3535

36-
private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!;
36+
private static readonly MethodInfo ExecuteTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!;
37+
private static readonly MethodInfo ExecuteValueTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!;
38+
private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!;
3739
private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
3840
private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!;
3941
private static readonly MethodInfo ExecuteValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
@@ -47,6 +49,8 @@ public static partial class RequestDelegateFactory
4749
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
4850
private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!;
4951
private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
52+
private static readonly MethodInfo TaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(TaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!;
53+
private static readonly MethodInfo ValueTaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ValueTaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!;
5054
private static readonly MethodInfo PopulateMetadataForParameterMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForParameter), BindingFlags.NonPublic | BindingFlags.Static)!;
5155
private static readonly MethodInfo PopulateMetadataForEndpointMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForEndpoint), BindingFlags.NonPublic | BindingFlags.Static)!;
5256

@@ -258,24 +262,40 @@ private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo method
258262
// httpContext.Response.StatusCode >= 400
259263
// ? Task.CompletedTask
260264
// : {
261-
// target = targetFactory(httpContext);
262-
// handler is ((Type)target).MethodName(parameters);
263-
// handler((string)context.Parameters[0], (int)context.Parameters[1]);
265+
// handlerInvocation
264266
// }
265-
var filteredInvocation = Expression.Lambda<RouteHandlerFilterDelegate>(
266-
Expression.Condition(
267-
Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)),
268-
CompletedValueTaskExpr,
269-
Expression.Block(
267+
// To generate the handler invocation, we first create the
268+
// target of the handler provided to the route.
269+
// target = targetFactory(httpContext);
270+
// This target is then used to generate the handler invocation like so;
271+
// ((Type)target).MethodName(parameters);
272+
// When `handler` returns an object, we generate the following wrapper
273+
// to convert it to `ValueTask<object?>` as expected in the filter
274+
// pipeline.
275+
// ValueTask<object?>.FromResult(handler((string)context.Parameters[0], (int)context.Parameters[1]));
276+
// When the `handler` is a generic Task or ValueTask we await the task and
277+
// create a `ValueTask<object?> from the resulting value.
278+
// new ValueTask<object?>(await handler((string)context.Parameters[0], (int)context.Parameters[1]));
279+
// When the `handler` returns a void or a void-returning Task, then we return an EmptyHttpResult
280+
// to as a ValueTask<object?>
281+
// }
282+
var handlerReturnMapping = MapHandlerReturnTypeToValueTask(
283+
targetExpression is null
284+
? Expression.Call(methodInfo, factoryContext.ContextArgAccess)
285+
: Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess),
286+
methodInfo.ReturnType);
287+
var handlerInvocation = Expression.Block(
270288
new[] { TargetExpr },
271289
targetFactory == null
272290
? Expression.Empty()
273291
: Expression.Assign(TargetExpr, Expression.Invoke(targetFactory, FilterContextHttpContextExpr)),
274-
Expression.Call(WrapObjectAsValueTaskMethod,
275-
targetExpression is null
276-
? Expression.Call(methodInfo, factoryContext.ContextArgAccess)
277-
: Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess))
278-
)),
292+
handlerReturnMapping
293+
);
294+
var filteredInvocation = Expression.Lambda<RouteHandlerFilterDelegate>(
295+
Expression.Condition(
296+
Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)),
297+
CompletedValueTaskExpr,
298+
handlerInvocation),
279299
FilterContextExpr).Compile();
280300
var routeHandlerContext = new RouteHandlerContext(
281301
methodInfo,
@@ -292,6 +312,72 @@ targetExpression is null
292312
return filteredInvocation;
293313
}
294314

315+
private static Expression MapHandlerReturnTypeToValueTask(Expression methodCall, Type returnType)
316+
{
317+
if (returnType == typeof(void))
318+
{
319+
return Expression.Block(methodCall, Expression.Constant(new ValueTask<object?>(EmptyHttpResult.Instance)));
320+
}
321+
else if (returnType == typeof(Task))
322+
{
323+
return Expression.Call(ExecuteTaskWithEmptyResultMethod, methodCall);
324+
}
325+
else if (returnType == typeof(ValueTask))
326+
{
327+
return Expression.Call(ExecuteValueTaskWithEmptyResultMethod, methodCall);
328+
}
329+
else if (returnType == typeof(ValueTask<object?>))
330+
{
331+
return methodCall;
332+
}
333+
else if (returnType.IsGenericType &&
334+
returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
335+
{
336+
var typeArg = returnType.GetGenericArguments()[0];
337+
return Expression.Call(ValueTaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall);
338+
}
339+
else if (returnType.IsGenericType &&
340+
returnType.GetGenericTypeDefinition() == typeof(Task<>))
341+
{
342+
var typeArg = returnType.GetGenericArguments()[0];
343+
return Expression.Call(TaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall);
344+
}
345+
else
346+
{
347+
return Expression.Call(WrapObjectAsValueTaskMethod, methodCall);
348+
}
349+
}
350+
351+
private static ValueTask<object?> ValueTaskOfTToValueTaskOfObject<T>(ValueTask<T> valueTask)
352+
{
353+
static async ValueTask<object?> ExecuteAwaited(ValueTask<T> valueTask)
354+
{
355+
return await valueTask;
356+
}
357+
358+
if (valueTask.IsCompletedSuccessfully)
359+
{
360+
return new ValueTask<object?>(valueTask.Result);
361+
}
362+
363+
return ExecuteAwaited(valueTask);
364+
}
365+
366+
private static ValueTask<object?> TaskOfTToValueTaskOfObject<T>(Task<T> task)
367+
{
368+
static async ValueTask<object?> ExecuteAwaited(Task<T> task)
369+
{
370+
return await task;
371+
}
372+
373+
if (task.IsCompletedSuccessfully)
374+
{
375+
return new ValueTask<object?>(task.Result);
376+
}
377+
378+
return ExecuteAwaited(task);
379+
}
380+
295381
private static void AddTypeProvidedMetadata(MethodInfo methodInfo, List<object> metadata, IServiceProvider? services)
296382
{
297383
object?[]? invokeArgs = null;
@@ -1649,7 +1735,7 @@ private static async Task ExecuteObjectReturn(object? obj, HttpContext httpConte
16491735
}
16501736
}
16511737

1652-
private static Task ExecuteTask<T>(Task<T> task, HttpContext httpContext)
1738+
private static Task ExecuteTaskOfT<T>(Task<T> task, HttpContext httpContext)
16531739
{
16541740
EnsureRequestTaskNotNull(task);
16551741

@@ -1707,6 +1793,39 @@ static async Task ExecuteAwaited(ValueTask task)
17071793
return ExecuteAwaited(task);
17081794
}
17091795

1796+
private static ValueTask<object?> ExecuteTaskWithEmptyResult(Task task)
1797+
{
1798+
static async ValueTask<object?> ExecuteAwaited(Task task)
1799+
{
1800+
await task;
1801+
return EmptyHttpResult.Instance;
1802+
}
1803+
1804+
if (task.IsCompletedSuccessfully)
1805+
{
1806+
return new ValueTask<object?>(EmptyHttpResult.Instance);
1807+
}
1808+
1809+
return ExecuteAwaited(task);
1810+
}
1811+
1812+
private static ValueTask<object?> ExecuteValueTaskWithEmptyResult(ValueTask valueTask)
1813+
{
1814+
static async ValueTask<object?> ExecuteAwaited(ValueTask task)
1815+
{
1816+
await task;
1817+
return EmptyHttpResult.Instance;
1818+
}
1819+
1820+
if (valueTask.IsCompletedSuccessfully)
1821+
{
1822+
valueTask.GetAwaiter().GetResult();
1823+
return new ValueTask<object?>(EmptyHttpResult.Instance);
1824+
}
1825+
1826+
return ExecuteAwaited(valueTask);
1827+
}
1828+
17101829
private static Task ExecuteValueTaskOfT<T>(ValueTask<T> task, HttpContext httpContext)
17111830
{
17121831
static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
@@ -2041,4 +2160,22 @@ private static void FormatTrackedParameters(FactoryContext factoryContext, Strin
20412160
errorMessage.AppendLine(FormattableString.Invariant($"{kv.Key,-19} | {kv.Value,-15}"));
20422161
}
20432162
}
2163+
2164+
// Due to cyclic references between Http.Extensions and
2165+
// Http.Results, we define our own instance of the `EmptyHttpResult`
2166+
// type here.
2167+
private sealed class EmptyHttpResult : IResult
2168+
{
2169+
private EmptyHttpResult()
2170+
{
2171+
}
2172+
2173+
public static EmptyHttpResult Instance { get; } = new();
2174+
2175+
/// <inheritdoc/>
2176+
public Task ExecuteAsync(HttpContext httpContext)
2177+
{
2178+
return Task.CompletedTask;
2179+
}
2180+
}
20442181
}

0 commit comments

Comments
 (0)