diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Compilation.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Compilation.cs index bf0ea7662fec0b..6a542d8d00a748 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Compilation.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Compilation.cs @@ -105,12 +105,12 @@ public bool CanInline(MethodDesc caller, MethodDesc callee) public bool CanReferenceConstructedMethodTable(TypeDesc type) { - return NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); + return NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type.NormalizeInstantiation()); } public bool CanReferenceConstructedTypeOrCanonicalFormOfType(TypeDesc type) { - return NodeFactory.DevirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(type); + return NodeFactory.DevirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(type.NormalizeInstantiation()); } public DelegateCreationInfo GetDelegateCtor(TypeDesc delegateType, MethodDesc target, TypeDesc constrainedType, bool followVirtualDispatch) @@ -266,9 +266,7 @@ public bool NeedsRuntimeLookup(ReadyToRunHelperId lookupKind, object targetOfLoo public ReadyToRunHelperId GetLdTokenHelperForType(TypeDesc type) { - bool canConstructPerWholeProgramAnalysis = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); - bool creationAllowed = ConstructedEETypeNode.CreationAllowed(type); - return (canConstructPerWholeProgramAnalysis && creationAllowed) + return (ConstructedEETypeNode.CreationAllowed(type) && NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type.NormalizeInstantiation())) ? ReadyToRunHelperId.TypeHandle : ReadyToRunHelperId.NecessaryTypeHandle; } diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs index 97820c425b33a2..5bf78fb125efc1 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs @@ -703,10 +703,18 @@ protected override MethodDesc ResolveVirtualMethod(MethodDesc declMethod, DefTyp } public override bool CanReferenceConstructedMethodTable(TypeDesc type) - => _constructedMethodTables.Contains(type); + { + Debug.Assert(type.NormalizeInstantiation() == type); + Debug.Assert(ConstructedEETypeNode.CreationAllowed(type)); + return _constructedMethodTables.Contains(type); + } public override bool CanReferenceConstructedTypeOrCanonicalFormOfType(TypeDesc type) - => _constructedMethodTables.Contains(type) || _canonConstructedMethodTables.Contains(type); + { + Debug.Assert(type.NormalizeInstantiation() == type); + Debug.Assert(ConstructedEETypeNode.CreationAllowed(type)); + return _constructedMethodTables.Contains(type) || _canonConstructedMethodTables.Contains(type); + } public override TypeDesc[] GetImplementingClasses(TypeDesc type) { diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/SubstitutedILProvider.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/SubstitutedILProvider.cs index 248bdf891243cd..04db2b136d4bc1 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/SubstitutedILProvider.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/SubstitutedILProvider.cs @@ -22,11 +22,13 @@ public class SubstitutedILProvider : ILProvider { private readonly ILProvider _nestedILProvider; private readonly SubstitutionProvider _substitutionProvider; + private readonly DevirtualizationManager _devirtualizationManager; - public SubstitutedILProvider(ILProvider nestedILProvider, SubstitutionProvider substitutionProvider) + public SubstitutedILProvider(ILProvider nestedILProvider, SubstitutionProvider substitutionProvider, DevirtualizationManager devirtualizationManager) { _nestedILProvider = nestedILProvider; _substitutionProvider = substitutionProvider; + _devirtualizationManager = devirtualizationManager; } public override MethodIL GetMethodIL(MethodDesc method) @@ -871,7 +873,26 @@ private static bool TryExpandTypeIs(MethodIL methodIL, byte[] body, OpcodeFlags[ return true; } - private static bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, string op, out int constant) + private bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, string op, out int constant) + { + if (TryExpandTypeEquality_TokenToken(methodIL, body, flags, offset, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 1, expectGetType: false, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 2, expectGetType: false, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 3, expectGetType: false, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 1, expectGetType: true, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 2, expectGetType: true, out constant) + || TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 3, expectGetType: true, out constant)) + { + if (op == "op_Inequality") + constant ^= 1; + + return true; + } + + return false; + } + + private static bool TryExpandTypeEquality_TokenToken(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, out int constant) { // We expect to see a sequence: // ldtoken Foo @@ -919,9 +940,108 @@ private static bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, Opcode constant = equality.Value ? 1 : 0; - if (op == "op_Inequality") - constant ^= 1; + return true; + } + + private bool TryExpandTypeEquality_TokenOther(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, int ldInstructionSize, bool expectGetType, out int constant) + { + // We expect to see a sequence: + // ldtoken Foo + // call GetTypeFromHandle + // ldloc.X/ldloc_s X/ldarg.X/ldarg_s X + // [optional] call Object.GetType + // -> offset points here + // + // The ldtoken part can potentially be in the second argument position + + constant = 0; + int sequenceLength = 5 + 5 + ldInstructionSize + (expectGetType ? 5 : 0); + if (offset < sequenceLength) + return false; + + if ((flags[offset - sequenceLength] & OpcodeFlags.InstructionStart) == 0) + return false; + + ILReader reader = new ILReader(body, offset - sequenceLength); + + TypeDesc knownType = null; + + // Is the ldtoken in the first position? + if (reader.PeekILOpcode() == ILOpcode.ldtoken) + { + knownType = ReadLdToken(ref reader, methodIL, flags); + if (knownType == null) + return false; + + if (!ReadGetTypeFromHandle(ref reader, methodIL, flags)) + return false; + } + + ILOpcode opcode = reader.ReadILOpcode(); + if (ldInstructionSize == 1 && opcode is (>= ILOpcode.ldloc_0 and <= ILOpcode.ldloc_3) or (>= ILOpcode.ldarg_0 and <= ILOpcode.ldarg_3)) + { + // Nothing to read + } + else if (ldInstructionSize == 2 && opcode is ILOpcode.ldloc_s or ILOpcode.ldarg_s) + { + reader.ReadILByte(); + } + else if (ldInstructionSize == 3 && opcode is ILOpcode.ldloc or ILOpcode.ldarg) + { + reader.ReadILUInt16(); + } + else + { + return false; + } + + if ((flags[reader.Offset] & OpcodeFlags.BasicBlockStart) != 0) + return false; + + if (expectGetType) + { + if (reader.ReadILOpcode() is not ILOpcode.callvirt and not ILOpcode.call) + return false; + + // We don't actually mind if this is not Object.GetType + reader.ReadILToken(); + + if ((flags[reader.Offset] & OpcodeFlags.BasicBlockStart) != 0) + return false; + } + + // If the ldtoken wasn't in the first position, it must be in the other + if (knownType == null) + { + knownType = ReadLdToken(ref reader, methodIL, flags); + if (knownType == null) + return false; + + if (!ReadGetTypeFromHandle(ref reader, methodIL, flags)) + return false; + } + + // No value in making this work for definitions + if (knownType.IsGenericDefinition) + return false; + + // Dataflow runs on top of uninstantiated IL and we can't answer some questions there. + // Unfortunately this means dataflow will still see code that the rest of the system + // might have optimized away. It should not be a problem in practice. + if (knownType.ContainsSignatureVariables()) + return false; + + if (knownType.IsCanonicalDefinitionType(CanonicalFormKind.Any)) + return false; + // We don't track types without a constructed MethodTable very well. + if (!ConstructedEETypeNode.CreationAllowed(knownType)) + return false; + + if (_devirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(knownType.NormalizeInstantiation())) + return false; + + constant = 0; return true; } diff --git a/src/coreclr/tools/aot/ILCompiler.RyuJit/Compiler/RyuJitCompilation.cs b/src/coreclr/tools/aot/ILCompiler.RyuJit/Compiler/RyuJitCompilation.cs index d25a14f5765312..eafbb5058e5edb 100644 --- a/src/coreclr/tools/aot/ILCompiler.RyuJit/Compiler/RyuJitCompilation.cs +++ b/src/coreclr/tools/aot/ILCompiler.RyuJit/Compiler/RyuJitCompilation.cs @@ -72,7 +72,8 @@ public override IEETypeNode NecessaryTypeSymbolIfPossible(TypeDesc type) // information proving that it isn't, give RyuJIT the constructed symbol even // though we just need the unconstructed one. // https://github.com/dotnet/runtimelab/issues/1128 - bool canPotentiallyConstruct = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); + bool canPotentiallyConstruct = ConstructedEETypeNode.CreationAllowed(type) + && NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); if (canPotentiallyConstruct) return _nodeFactory.MaximallyConstructableType(type); @@ -81,7 +82,8 @@ public override IEETypeNode NecessaryTypeSymbolIfPossible(TypeDesc type) public FrozenRuntimeTypeNode NecessaryRuntimeTypeIfPossible(TypeDesc type) { - bool canPotentiallyConstruct = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); + bool canPotentiallyConstruct = ConstructedEETypeNode.CreationAllowed(type) + && NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type); if (canPotentiallyConstruct) return _nodeFactory.SerializedMaximallyConstructableRuntimeTypeObject(type); diff --git a/src/coreclr/tools/aot/ILCompiler.Trimming.Tests/TestCasesRunner/TrimmingDriver.cs b/src/coreclr/tools/aot/ILCompiler.Trimming.Tests/TestCasesRunner/TrimmingDriver.cs index 689c1327f538f3..83b5d7ebe12da2 100644 --- a/src/coreclr/tools/aot/ILCompiler.Trimming.Tests/TestCasesRunner/TrimmingDriver.cs +++ b/src/coreclr/tools/aot/ILCompiler.Trimming.Tests/TestCasesRunner/TrimmingDriver.cs @@ -113,7 +113,7 @@ public ILScanResults Trim (ILCompilerOptions options, TrimmingCustomizations? cu } SubstitutionProvider substitutionProvider = new SubstitutionProvider(logger, featureSwitches, substitutions); - ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider); + ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider, new DevirtualizationManager()); CompilerGeneratedState compilerGeneratedState = new CompilerGeneratedState (ilProvider, logger); diff --git a/src/coreclr/tools/aot/ILCompiler/Program.cs b/src/coreclr/tools/aot/ILCompiler/Program.cs index 91b4f8f02cf9af..39ef3af5dc4cea 100644 --- a/src/coreclr/tools/aot/ILCompiler/Program.cs +++ b/src/coreclr/tools/aot/ILCompiler/Program.cs @@ -378,7 +378,8 @@ public int Run() } SubstitutionProvider substitutionProvider = new SubstitutionProvider(logger, featureSwitches, substitutions); - ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider); + ILProvider unsubstitutedILProvider = ilProvider; + ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider, new DevirtualizationManager()); CompilerGeneratedState compilerGeneratedState = new CompilerGeneratedState(ilProvider, logger); @@ -492,10 +493,17 @@ void RunScanner() if (scanDgmlLogFileName != null) scanResults.WriteDependencyLog(scanDgmlLogFileName); + DevirtualizationManager devirtualizationManager = scanResults.GetDevirtualizationManager(); + metadataManager = ((UsageBasedMetadataManager)metadataManager).ToAnalysisBasedMetadataManager(); interopStubManager = scanResults.GetInteropStubManager(interopStateManager, pinvokePolicy); + ilProvider = new SubstitutedILProvider(unsubstitutedILProvider, substitutionProvider, devirtualizationManager); + + // Use a more precise IL provider that uses whole program analysis for dead branch elimination + builder.UseILProvider(ilProvider); + // If we have a scanner, feed the vtable analysis results to the compilation. // This could be a command line switch if we really wanted to. builder.UseVTableSliceProvider(scanResults.GetVTableLayoutInfo()); @@ -507,7 +515,7 @@ void RunScanner() // If we have a scanner, we can drive devirtualization using the information // we collected at scanning time (effectively sealing unsealed types if possible). // This could be a command line switch if we really wanted to. - builder.UseDevirtualizationManager(scanResults.GetDevirtualizationManager()); + builder.UseDevirtualizationManager(devirtualizationManager); // If we use the scanner's result, we need to consult it to drive inlining. // This prevents e.g. devirtualizing and inlining methods on types that were diff --git a/src/tests/nativeaot/SmokeTests/TrimmingBehaviors/DeadCodeElimination.cs b/src/tests/nativeaot/SmokeTests/TrimmingBehaviors/DeadCodeElimination.cs index eb03716ce3d18f..9cdeab80e8e36f 100644 --- a/src/tests/nativeaot/SmokeTests/TrimmingBehaviors/DeadCodeElimination.cs +++ b/src/tests/nativeaot/SmokeTests/TrimmingBehaviors/DeadCodeElimination.cs @@ -346,19 +346,91 @@ sealed class Gen { } sealed class Never { } - static Type s_type = null; + class Never2 { } + class Canary2 { } + class Never3 { } + class Canary3 { } + + class Maybe1 { } + + [MethodImpl(MethodImplOptions.NoInlining)] + static Type GetTheType() => null; + + [MethodImpl(MethodImplOptions.NoInlining)] + static Type GetThePointerType() => typeof(void*); + + [MethodImpl(MethodImplOptions.NoInlining)] + static object GetTheObject() => new object(); + + static volatile object s_sink; public static void Run() { // This was asserting the BCL because Never would not have reflection metadata // despite the typeof - Console.WriteLine(s_type == typeof(Never)); + Console.WriteLine(GetTheType() == typeof(Never)); // This was a compiler crash Console.WriteLine(typeof(object) == typeof(Gen<>)); #if !DEBUG ThrowIfPresent(typeof(TestTypeEquals), nameof(Never)); + + { + RunCheck(GetTheType()); + + static void RunCheck(Type t) + { + if (t == typeof(Never2)) + { + s_sink = new Canary2(); + } + } + + ThrowIfPresentWithUsableMethodTable(typeof(TestTypeEquals), nameof(Canary2)); + } + + { + + RunCheck(GetTheObject()); + + static void RunCheck(object o) + { + if (o.GetType() == typeof(Never3)) + { + s_sink = new Canary3(); + } + } + + ThrowIfPresentWithUsableMethodTable(typeof(TestTypeEquals), nameof(Canary3)); + } + + { + RunCheck(GetThePointerType()); + + static void RunCheck(Type t) + { + if (t == typeof(void*)) + { + return; + } + throw new Exception(); + } + } + + { + RunCheck(typeof(Maybe1)); + + [MethodImpl(MethodImplOptions.NoInlining)] + static void RunCheck(Type t) + { + if (t == typeof(Maybe1)) + { + return; + } + throw new Exception(); + } + } #endif } }