Skip to content

Support all return types from handler in filters #41310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 26, 2022
167 changes: 152 additions & 15 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ public static partial class RequestDelegateFactory
{
private static readonly ParameterBindingMethodCache ParameterBindingMethodCache = new();

private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
Expand All @@ -47,6 +49,8 @@ public static partial class RequestDelegateFactory
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!;
private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo TaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(TaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ValueTaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ValueTaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo PopulateMetadataForParameterMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForParameter), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo PopulateMetadataForEndpointMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForEndpoint), BindingFlags.NonPublic | BindingFlags.Static)!;

Expand Down Expand Up @@ -258,24 +262,40 @@ private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo method
// httpContext.Response.StatusCode >= 400
// ? Task.CompletedTask
// : {
// target = targetFactory(httpContext);
// handler is ((Type)target).MethodName(parameters);
// handler((string)context.Parameters[0], (int)context.Parameters[1]);
// handlerInvocation
// }
var filteredInvocation = Expression.Lambda<RouteHandlerFilterDelegate>(
Expression.Condition(
Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)),
CompletedValueTaskExpr,
Expression.Block(
// To generate the handler invocation, we first create the
// target of the handler provided to the route.
// target = targetFactory(httpContext);
// This target is then used to generate the handler invocation like so;
// ((Type)target).MethodName(parameters);
// When `handler` returns an object, we generate the following wrapper
// to convert it to `ValueTask<object?>` as expected in the filter
// pipeline.
// ValueTask<object?>.FromResult(handler((string)context.Parameters[0], (int)context.Parameters[1]));
// When the `handler` is a generic Task or ValueTask we await the task and
// create a `ValueTask<object?> from the resulting value.
// new ValueTask<object?>(await handler((string)context.Parameters[0], (int)context.Parameters[1]));
// When the `handler` returns a void or a void-returning Task, then we return an EmptyHttpResult
// to as a ValueTask<object?>
// }
var handlerReturnMapping = MapHandlerReturnTypeToValueTask(
targetExpression is null
? Expression.Call(methodInfo, factoryContext.ContextArgAccess)
: Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess),
methodInfo.ReturnType);
var handlerInvocation = Expression.Block(
new[] { TargetExpr },
targetFactory == null
? Expression.Empty()
: Expression.Assign(TargetExpr, Expression.Invoke(targetFactory, FilterContextHttpContextExpr)),
Expression.Call(WrapObjectAsValueTaskMethod,
targetExpression is null
? Expression.Call(methodInfo, factoryContext.ContextArgAccess)
: Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess))
)),
handlerReturnMapping
);
var filteredInvocation = Expression.Lambda<RouteHandlerFilterDelegate>(
Expression.Condition(
Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)),
CompletedValueTaskExpr,
handlerInvocation),
FilterContextExpr).Compile();
var routeHandlerContext = new RouteHandlerContext(
methodInfo,
Expand All @@ -292,6 +312,72 @@ targetExpression is null
return filteredInvocation;
}

private static Expression MapHandlerReturnTypeToValueTask(Expression methodCall, Type returnType)
{
if (returnType == typeof(void))
{
return Expression.Block(methodCall, Expression.Constant(new ValueTask<object?>(EmptyHttpResult.Instance)));
}
else if (returnType == typeof(Task))
{
return Expression.Call(ExecuteTaskWithEmptyResultMethod, methodCall);
}
else if (returnType == typeof(ValueTask))
{
return Expression.Call(ExecuteValueTaskWithEmptyResultMethod, methodCall);
}
else if (returnType == typeof(ValueTask<object?>))
{
return methodCall;
}
else if (returnType.IsGenericType &&
returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
var typeArg = returnType.GetGenericArguments()[0];
return Expression.Call(ValueTaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall);
}
else if (returnType.IsGenericType &&
returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
var typeArg = returnType.GetGenericArguments()[0];
return Expression.Call(TaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall);
}
else
{
return Expression.Call(WrapObjectAsValueTaskMethod, methodCall);
}
}

private static ValueTask<object?> ValueTaskOfTToValueTaskOfObject<T>(ValueTask<T> valueTask)
{
static async ValueTask<object?> ExecuteAwaited(ValueTask<T> valueTask)
{
return await valueTask;
}

if (valueTask.IsCompletedSuccessfully)
{
return new ValueTask<object?>(valueTask.Result);
}

return ExecuteAwaited(valueTask);
}

private static ValueTask<object?> TaskOfTToValueTaskOfObject<T>(Task<T> task)
{
static async ValueTask<object?> ExecuteAwaited(Task<T> task)
{
return await task;
}

if (task.IsCompletedSuccessfully)
{
return new ValueTask<object?>(task.Result);
}

return ExecuteAwaited(task);
}

private static void AddTypeProvidedMetadata(MethodInfo methodInfo, List<object> metadata, IServiceProvider? services)
{
object?[]? invokeArgs = null;
Expand Down Expand Up @@ -1649,7 +1735,7 @@ private static async Task ExecuteObjectReturn(object? obj, HttpContext httpConte
}
}

private static Task ExecuteTask<T>(Task<T> task, HttpContext httpContext)
private static Task ExecuteTaskOfT<T>(Task<T> task, HttpContext httpContext)
{
EnsureRequestTaskNotNull(task);

Expand Down Expand Up @@ -1707,6 +1793,39 @@ static async Task ExecuteAwaited(ValueTask task)
return ExecuteAwaited(task);
}

private static ValueTask<object?> ExecuteTaskWithEmptyResult(Task task)
{
static async ValueTask<object?> ExecuteAwaited(Task task)
{
await task;
return EmptyHttpResult.Instance;
}

if (task.IsCompletedSuccessfully)
{
return new ValueTask<object?>(EmptyHttpResult.Instance);
}

return ExecuteAwaited(task);
}

private static ValueTask<object?> ExecuteValueTaskWithEmptyResult(ValueTask valueTask)
{
static async ValueTask<object?> ExecuteAwaited(ValueTask task)
{
await task;
return EmptyHttpResult.Instance;
}

if (valueTask.IsCompletedSuccessfully)
{
valueTask.GetAwaiter().GetResult();
return new ValueTask<object?>(EmptyHttpResult.Instance);
}

return ExecuteAwaited(valueTask);
}

private static Task ExecuteValueTaskOfT<T>(ValueTask<T> task, HttpContext httpContext)
{
static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
Expand Down Expand Up @@ -2041,4 +2160,22 @@ private static void FormatTrackedParameters(FactoryContext factoryContext, Strin
errorMessage.AppendLine(FormattableString.Invariant($"{kv.Key,-19} | {kv.Value,-15}"));
}
}

// Due to cyclic references between Http.Extensions and
// Http.Results, we define our own instance of the `EmptyHttpResult`
// type here.
private sealed class EmptyHttpResult : IResult
{
private EmptyHttpResult()
{
}

public static EmptyHttpResult Instance { get; } = new();

/// <inheritdoc/>
public Task ExecuteAsync(HttpContext httpContext)
{
return Task.CompletedTask;
}
}
}
Loading