diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/DependencyInjectionSpecificationTests.cs b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/DependencyInjectionSpecificationTests.cs index aef2cdbc..8901f69b 100644 --- a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/DependencyInjectionSpecificationTests.cs +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/DependencyInjectionSpecificationTests.cs @@ -655,5 +655,30 @@ public void ServiceContainerPicksConstructorWithLongestMatches( Assert.Same(expected.MultipleService, actual.MultipleService); Assert.Same(expected.ScopedService, actual.ScopedService); } + + [Fact] + public void DisposesInReverseOrderOfCreation() + { + // Arrange + var serviceCollection = new TestServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddTransient(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddTransient(); + serviceCollection.AddSingleton(); + var serviceProvider = CreateServiceProvider(serviceCollection); + + var callback = serviceProvider.GetService(); + var outer = serviceProvider.GetService(); + + // Act + ((IDisposable)serviceProvider).Dispose(); + + // Assert + Assert.Equal(outer, callback.Disposed[0]); + Assert.Equal(outer.MultipleServices.Reverse(), callback.Disposed.Skip(1).Take(3).OfType()); + Assert.Equal(outer.SingleService, callback.Disposed[4]); + } } } diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackInnerService.cs b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackInnerService.cs new file mode 100644 index 00000000..c8581330 --- /dev/null +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackInnerService.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class FakeDisposableCallbackInnerService : FakeDisposableCallbackService, IFakeMultipleService + { + public FakeDisposableCallbackInnerService(FakeDisposeCallback callback) : base(callback) + { + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackOuterService.cs b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackOuterService.cs new file mode 100644 index 00000000..d400c122 --- /dev/null +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackOuterService.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class FakeDisposableCallbackOuterService : FakeDisposableCallbackService, IFakeOuterService + { + public FakeDisposableCallbackOuterService( + IFakeService singleService, + IEnumerable multipleServices, + FakeDisposeCallback callback) : base(callback) + { + SingleService = singleService; + MultipleServices = multipleServices; + } + + public IFakeService SingleService { get; } + public IEnumerable MultipleServices { get; } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackService.cs b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackService.cs new file mode 100644 index 00000000..53e09579 --- /dev/null +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposableCallbackService.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class FakeDisposableCallbackService: IDisposable + { + private static int _globalId; + private readonly int _id; + private readonly FakeDisposeCallback _callback; + + public FakeDisposableCallbackService(FakeDisposeCallback callback) + { + _id = _globalId++; + _callback = callback; + } + + public void Dispose() + { + _callback.Disposed.Add(this); + } + + public override string ToString() + { + return _id.ToString(); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposeCallback.cs b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposeCallback.cs new file mode 100644 index 00000000..4fab8d63 --- /dev/null +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Fakes/FakeDisposeCallback.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes +{ + public class FakeDisposeCallback + { + public List Disposed { get; } = new List(); + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Microsoft.Extensions.DependencyInjection.Specification.Tests.csproj b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Microsoft.Extensions.DependencyInjection.Specification.Tests.csproj index 1b5893e7..90dea721 100644 --- a/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Microsoft.Extensions.DependencyInjection.Specification.Tests.csproj +++ b/src/Microsoft.Extensions.DependencyInjection.Specification.Tests/Microsoft.Extensions.DependencyInjection.Specification.Tests.csproj @@ -19,4 +19,8 @@ + + + + diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteExpressionBuilder.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteExpressionBuilder.cs index e8748ef9..4632cacc 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteExpressionBuilder.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteExpressionBuilder.cs @@ -146,16 +146,24 @@ protected override Expression VisitClosedIEnumerable(ClosedIEnumerableCallSite c protected override Expression VisitTransient(TransientCallSite callSite, ParameterExpression provider) { var implType = callSite.Service.ImplementationType; - // Elide calls to GetCaptureDisposable if the implemenation type isn't disposable + return TryCaptureDisposible( + implType, + provider, + VisitCallSite(callSite.ServiceCallSite, provider)); + } + + private Expression TryCaptureDisposible(Type implType, ParameterExpression provider, Expression service) + { + if (implType != null && !typeof(IDisposable).GetTypeInfo().IsAssignableFrom(implType.GetTypeInfo())) { - return VisitCallSite(callSite.ServiceCallSite, provider); + return service; } return Expression.Invoke(GetCaptureDisposable(provider), - VisitCallSite(callSite.ServiceCallSite, provider)); + service); } protected override Expression VisitConstructor(ConstructorCallSite callSite, ParameterExpression provider) @@ -184,7 +192,7 @@ protected override Expression VisitScoped(ScopedCallSite callSite, ParameterExpr callSite.Key, typeof(object)); - var resolvedExpression = Expression.Variable(typeof(object), "resolved"); + var resolvedVariable = Expression.Variable(typeof(object), "resolved"); var resolvedServices = GetResolvedServices(provider); @@ -192,26 +200,32 @@ protected override Expression VisitScoped(ScopedCallSite callSite, ParameterExpr resolvedServices, TryGetValueMethodInfo, keyExpression, - resolvedExpression); + resolvedVariable); + + var service = VisitCallSite(callSite.ServiceCallSite, provider); + var captureDisposible = TryCaptureDisposible(callSite.Key.ImplementationType, provider, service); var assignExpression = Expression.Assign( - resolvedExpression, VisitCallSite(callSite.ServiceCallSite, provider)); + resolvedVariable, + captureDisposible); var addValueExpression = Expression.Call( resolvedServices, AddMethodInfo, keyExpression, - resolvedExpression); + resolvedVariable); var blockExpression = Expression.Block( typeof(object), new[] { - resolvedExpression + resolvedVariable }, Expression.IfThen( Expression.Not(tryGetValueExpression), - Expression.Block(assignExpression, addValueExpression)), - resolvedExpression); + Expression.Block( + assignExpression, + addValueExpression)), + resolvedVariable); return blockExpression; } diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteRuntimeResolver.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteRuntimeResolver.cs index 81e9e9ff..b0f8d798 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteRuntimeResolver.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteRuntimeResolver.cs @@ -49,6 +49,7 @@ protected override object VisitScoped(ScopedCallSite scopedCallSite, ServiceProv if (!provider.ResolvedServices.TryGetValue(scopedCallSite.Key, out resolved)) { resolved = VisitCallSite(scopedCallSite.ServiceCallSite, provider); + provider.CaptureDisposable(resolved); provider.ResolvedServices.Add(scopedCallSite.Key, resolved); } } diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs index f55224db..39c7b18e 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs @@ -19,7 +19,6 @@ internal class ServiceProvider : IServiceProvider, IDisposable { private readonly CallSiteValidator _callSiteValidator; private readonly ServiceTable _table; - private readonly ServiceProviderOptions _options; private bool _disposeCalled; private List _transientDisposables; @@ -43,7 +42,6 @@ public ServiceProvider(IEnumerable serviceDescriptors, Servic _callSiteValidator = new CallSiteValidator(); } - _options = options; _table = new ServiceTable(serviceDescriptors); _table.Add(typeof(IServiceProvider), new ServiceProviderService()); @@ -169,22 +167,15 @@ public void Dispose() _disposeCalled = true; if (_transientDisposables != null) { - foreach (var disposable in _transientDisposables) + for (int i = _transientDisposables.Count - 1; i >= 0; i--) { + var disposable = _transientDisposables[i]; disposable.Dispose(); } _transientDisposables.Clear(); } - // PERF: We've enumerating the dictionary so that we don't allocate to enumerate. - // .Values allocates a ValueCollection on the heap, enumerating the dictionary allocates - // a struct enumerator - foreach (var entry in ResolvedServices) - { - (entry.Value as IDisposable)?.Dispose(); - } - ResolvedServices.Clear(); } } diff --git a/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs b/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs index 03e1146b..dcd3cae4 100644 --- a/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs +++ b/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs @@ -4,7 +4,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Linq.Expressions; using Microsoft.Extensions.DependencyInjection.ServiceLookup; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -121,12 +120,19 @@ public void BuiltExpressionCanResolveNestedScopedService() Assert.Equal(serviceC, Invoke(callSite, provider)); } - [Fact] - public void BuildExpressionElidesDisposableCaptureForNonDisposableServices() + [Theory] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + // We are not testing singleton here because singleton resolutions always got through + // runtime resolver and there is no sense to eliminating call from there + public void BuildExpressionElidesDisposableCaptureForNonDisposableServices(ServiceLifetime lifetime) { - var descriptors = new ServiceCollection(); - descriptors.AddTransient(); - descriptors.AddTransient(); + IServiceCollection descriptors = new ServiceCollection(); + descriptors.Add(ServiceDescriptor.Describe(typeof(ServiceA), typeof(ServiceA), lifetime)); + descriptors.Add(ServiceDescriptor.Describe(typeof(ServiceB), typeof(ServiceB), lifetime)); + descriptors.Add(ServiceDescriptor.Describe(typeof(ServiceC), typeof(ServiceC), lifetime)); + + descriptors.AddScoped(); descriptors.AddTransient(); var disposables = new List(); @@ -143,12 +149,16 @@ public void BuildExpressionElidesDisposableCaptureForNonDisposableServices() Assert.Equal(0, disposables.Count); } - [Fact] - public void BuildExpressionElidesDisposableCaptureForEnumerableServices() + [Theory] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + // We are not testing singleton here because singleton resolutions always got through + // runtime resolver and there is no sense to eliminating call from there + public void BuildExpressionElidesDisposableCaptureForEnumerableServices(ServiceLifetime lifetime) { - var descriptors = new ServiceCollection(); - descriptors.AddTransient(); - descriptors.AddTransient(); + IServiceCollection descriptors = new ServiceCollection(); + descriptors.Add(ServiceDescriptor.Describe(typeof(ServiceA), typeof(ServiceA), lifetime)); + descriptors.Add(ServiceDescriptor.Describe(typeof(ServiceD), typeof(ServiceD), lifetime)); var disposables = new List(); var provider = new ServiceProvider(descriptors, ServiceProviderOptions.Default);