diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index 68c89a386e7c18..61b41eb631f022 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -613,14 +613,22 @@ private ConstructorCallSite CreateConstructorCallSite( { if (serviceIdentifier.ServiceKey != null && attribute is ServiceKeyAttribute) { - // Check if the parameter type matches - if (parameterType != serviceIdentifier.ServiceKey.GetType()) + // Even though the parameter may be strongly typed, support 'object' if AnyKey is used. + + if (serviceIdentifier.ServiceKey == KeyedService.AnyKey) + { + parameterType = typeof(object); + } + else if (parameterType != serviceIdentifier.ServiceKey.GetType() + && parameterType != typeof(object)) { throw new InvalidOperationException(SR.InvalidServiceKeyType); } + callSite = new ConstantCallSite(parameterType, serviceIdentifier.ServiceKey); break; } + if (attribute is FromKeyedServicesAttribute keyed) { var parameterSvcId = new ServiceIdentifier(keyed.Key, parameterType); diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index 312043e56d4774..3feb626c175331 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -146,6 +146,131 @@ public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledO Assert.IsType(actual); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildServiceProvider_AnyKey_ServiceKeyWithStronglyTypedArgument(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedTransient(KeyedService.AnyKey); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actual = serviceProvider.GetKeyedService(42); + + // Assert + Assert.Equal(42, actual.Key); + Assert.Throws(() => serviceProvider.GetKeyedService("Hello")); + } + + private class ServiceKeyWithStronglyTypedArgument + { + public int Key { get; set; } + + public ServiceKeyWithStronglyTypedArgument([ServiceKey] int key) + { + Key = key; + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildServiceProvider_AnyKey_ServiceKeyWithObjectTypedArgument(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedTransient(KeyedService.AnyKey); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actualInt = serviceProvider.GetKeyedService(42); + var actualString = serviceProvider.GetKeyedService("hello"); + + // Assert + Assert.Equal(42, actualInt.Key); + Assert.Equal("hello", actualString.Key); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildServiceProvider_ServiceKeyWithObjectTypedArgument(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedTransient(42); + serviceCollection.AddKeyedTransient("hello"); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actualInt = serviceProvider.GetKeyedService(42); + var actualString = serviceProvider.GetKeyedService("hello"); + var notFound = serviceProvider.GetKeyedService(false); + + // Assert + Assert.Equal(42, actualInt.Key); + Assert.Equal("hello", actualString.Key); + Assert.Null(notFound); + } + + private class ServiceKeyWithObjectTypedArgument + { + public object Key { get; set; } + + public ServiceKeyWithObjectTypedArgument([ServiceKey] object key) + { + Key = key; + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void BuildServiceProvider_AnyKey_ServiceKeyWithObjectAndIntTypedArguments(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedTransient(KeyedService.AnyKey); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actual = serviceProvider.GetKeyedService(42); + + // Assert + Assert.Equal(42, actual.Key); + Assert.Equal(42, actual.IntKey); + } + + private class ServiceKeyWithObjectAndIntTypedArguments + { + public object Key { get; set; } + public int IntKey { get; set; } + + public ServiceKeyWithObjectAndIntTypedArguments([ServiceKey] object key, [ServiceKey] int intKey) + { + Key = key; + IntKey = intKey; + } + } + [Fact] public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot() {