diff --git a/docs/design/datacontracts/Loader.md b/docs/design/datacontracts/Loader.md index 46fa6b716fb651..a231d53a6242fd 100644 --- a/docs/design/datacontracts/Loader.md +++ b/docs/design/datacontracts/Loader.md @@ -74,6 +74,7 @@ TargetPointer GetILBase(ModuleHandle handle); TargetPointer GetAssemblyLoadContext(ModuleHandle handle); ModuleLookupTables GetLookupTables(ModuleHandle handle); TargetPointer GetModuleLookupMapElement(TargetPointer table, uint token, out TargetNUInt flags); +IEnumerable<(TargetPointer, uint)> EnumerateModuleLookupMap(TargetPointer table); bool IsCollectible(ModuleHandle handle); bool IsAssemblyLoaded(ModuleHandle handle); TargetPointer GetGlobalLoaderAllocator(); @@ -474,6 +475,33 @@ TargetPointer GetModuleLookupMapElement(TargetPointer table, uint token, out Tar return TargetPointer.Null; } +IEnumerable<(TargetPointer, uint)> EnumerateModuleLookupMap(TargetPointer table) +{ + Data.ModuleLookupMap lookupMap = new Data.ModuleLookupMap(table); + // have to read lookupMap an extra time upfront because only the first map + // has valid supportedFlagsMask + TargetNUInt supportedFlagsMask = target.ReadNUInt(table + /* ModuleLookupMap::SupportedFlagsMask */); + uint index = 1; // zero is invalid + do + { + uint count = target.Read(table + /*ModuleLookupMap::Count*/); + if (index < count) + { + TargetPointer entryAddress = target.ReadPointer(table + /*ModuleLookupMap::TableData*/) + (ulong)(index * target.PointerSize); + TargetPointer rawValue = target.ReadPointer(entryAddress); + ulong maskedValue = rawValue & ~(supportedFlagsMask.Value); + if (maskedValue != 0) + yield return (new TargetPointer(maskedValue), index); + index++; + } + else + { + table = target.ReadPointer(table + /*ModuleLookupMap::Next*/); + index -= count; + } + } while (table != TargetPointer.Null); +} + bool IsCollectible(ModuleHandle handle) { TargetPointer assembly = target.ReadPointer(handle.Address + /*Module::Assembly*/); diff --git a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/Contracts/ILoader.cs b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/Contracts/ILoader.cs index 8822dfee7ba384..0ef7c19a2f7c61 100644 --- a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/Contracts/ILoader.cs +++ b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/Contracts/ILoader.cs @@ -97,6 +97,7 @@ public interface ILoader : IContract TargetPointer GetAssemblyLoadContext(ModuleHandle handle) => throw new NotImplementedException(); ModuleLookupTables GetLookupTables(ModuleHandle handle) => throw new NotImplementedException(); TargetPointer GetModuleLookupMapElement(TargetPointer table, uint token, out TargetNUInt flags) => throw new NotImplementedException(); + IEnumerable<(TargetPointer, uint)> EnumerateModuleLookupMap(TargetPointer table) => throw new NotImplementedException(); bool IsCollectible(ModuleHandle handle) => throw new NotImplementedException(); bool IsAssemblyLoaded(ModuleHandle handle) => throw new NotImplementedException(); diff --git a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Contracts/Contracts/Loader_1.cs b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Contracts/Contracts/Loader_1.cs index 2d2562625e4688..d2cfe716f6d723 100644 --- a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Contracts/Contracts/Loader_1.cs +++ b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Contracts/Contracts/Loader_1.cs @@ -352,27 +352,23 @@ ModuleLookupTables ILoader.GetLookupTables(ModuleHandle handle) module.MethodDefToILCodeVersioningStateMap); } - TargetPointer ILoader.GetModuleLookupMapElement(TargetPointer table, uint token, out TargetNUInt flags) + private static (bool Done, uint NextIndex) IterateLookupMap(uint index) => (false, index + 1); + private static (bool Done, uint NextIndex) SearchLookupMap(uint index) => (true, index); + private delegate (bool Done, uint NextIndex) Delegate(uint index); + private IEnumerable<(TargetPointer, uint)> IterateModuleLookupMap(TargetPointer table, uint index, Delegate iterator) { - uint rid = EcmaMetadataUtils.GetRowId(token); - ArgumentOutOfRangeException.ThrowIfZero(rid); - flags = new TargetNUInt(0); - if (table == TargetPointer.Null) - return TargetPointer.Null; - uint index = rid; - Data.ModuleLookupMap lookupMap = _target.ProcessedData.GetOrAdd(table); - // have to read lookupMap an extra time upfront because only the first map - // has valid supportedFlagsMask - TargetNUInt supportedFlagsMask = lookupMap.SupportedFlagsMask; + bool doneIterating; do { - lookupMap = _target.ProcessedData.GetOrAdd(table); + Data.ModuleLookupMap lookupMap = _target.ProcessedData.GetOrAdd(table); if (index < lookupMap.Count) { TargetPointer entryAddress = lookupMap.TableData + (ulong)(index * _target.PointerSize); TargetPointer rawValue = _target.ReadPointer(entryAddress); - flags = new TargetNUInt(rawValue & supportedFlagsMask.Value); - return rawValue & ~(supportedFlagsMask.Value); + yield return (rawValue, index); + (doneIterating, index) = iterator(index); + if (doneIterating) + yield break; } else { @@ -380,7 +376,40 @@ TargetPointer ILoader.GetModuleLookupMapElement(TargetPointer table, uint token, index -= lookupMap.Count; } } while (table != TargetPointer.Null); - return TargetPointer.Null; + } + + TargetPointer ILoader.GetModuleLookupMapElement(TargetPointer table, uint token, out TargetNUInt flags) + { + if (table == TargetPointer.Null) + { + flags = new TargetNUInt(0); + return TargetPointer.Null; + } + + Data.ModuleLookupMap lookupMap = _target.ProcessedData.GetOrAdd(table); + ulong supportedFlagsMask = lookupMap.SupportedFlagsMask.Value; + + uint rid = EcmaMetadataUtils.GetRowId(token); + ArgumentOutOfRangeException.ThrowIfZero(rid); + (TargetPointer rval, uint _) = IterateModuleLookupMap(table, rid, SearchLookupMap).FirstOrDefault(); + flags = new TargetNUInt(rval & supportedFlagsMask); + return rval & ~supportedFlagsMask; + } + + IEnumerable<(TargetPointer, uint)> ILoader.EnumerateModuleLookupMap(TargetPointer table) + { + if (table == TargetPointer.Null) + yield break; + Data.ModuleLookupMap lookupMap = _target.ProcessedData.GetOrAdd(table); + ulong supportedFlagsMask = lookupMap.SupportedFlagsMask.Value; + TargetNUInt flags = new TargetNUInt(0); + uint index = 1; // zero is invalid + foreach ((TargetPointer targetPointer, uint idx) in IterateModuleLookupMap(table, index, IterateLookupMap)) + { + TargetPointer rval = targetPointer & ~supportedFlagsMask; + if (rval != TargetPointer.Null) + yield return (rval, idx); + } } bool ILoader.IsCollectible(ModuleHandle handle) diff --git a/src/native/managed/cdac/mscordaccore_universal/Legacy/ISOSDacInterface.cs b/src/native/managed/cdac/mscordaccore_universal/Legacy/ISOSDacInterface.cs index f2663d4363cce4..6a5d3aa40157a1 100644 --- a/src/native/managed/cdac/mscordaccore_universal/Legacy/ISOSDacInterface.cs +++ b/src/native/managed/cdac/mscordaccore_universal/Legacy/ISOSDacInterface.cs @@ -123,6 +123,12 @@ internal struct DacpModuleData public ulong dwModuleIndex; // Always 0 - .NET no longer has this } +internal enum ModuleMapType +{ + TYPEDEFTOMETHODTABLE = 0x0, + TYPEREFTOMETHODTABLE = 0x1 +} + internal struct DacpMethodTableData { public int bIsFree; // everything else is NULL if this is true. @@ -280,7 +286,7 @@ internal unsafe partial interface ISOSDacInterface [PreserveSig] int GetModuleData(ClrDataAddress moduleAddr, DacpModuleData* data); [PreserveSig] - int TraverseModuleMap(/*ModuleMapType*/ int mmt, ClrDataAddress moduleAddr, /*MODULEMAPTRAVERSE*/ void* pCallback, void* token); + int TraverseModuleMap(ModuleMapType mmt, ClrDataAddress moduleAddr, delegate* unmanaged[Stdcall] pCallback, void* token); [PreserveSig] int GetAssemblyModuleList(ClrDataAddress assembly, uint count, [In, Out, MarshalUsing(CountElementName = nameof(count))] ClrDataAddress[] modules, uint* pNeeded); [PreserveSig] diff --git a/src/native/managed/cdac/mscordaccore_universal/Legacy/SOSDacImpl.cs b/src/native/managed/cdac/mscordaccore_universal/Legacy/SOSDacImpl.cs index 6d3d6b730262fd..fe0749c21d6017 100644 --- a/src/native/managed/cdac/mscordaccore_universal/Legacy/SOSDacImpl.cs +++ b/src/native/managed/cdac/mscordaccore_universal/Legacy/SOSDacImpl.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices.Marshalling; using System.Text; @@ -2142,8 +2143,78 @@ int ISOSDacInterface.TraverseEHInfo(ClrDataAddress ip, void* pCallback, void* to => _legacyImpl is not null ? _legacyImpl.TraverseEHInfo(ip, pCallback, token) : HResults.E_NOTIMPL; int ISOSDacInterface.TraverseLoaderHeap(ClrDataAddress loaderHeapAddr, void* pCallback) => _legacyImpl is not null ? _legacyImpl.TraverseLoaderHeap(loaderHeapAddr, pCallback) : HResults.E_NOTIMPL; - int ISOSDacInterface.TraverseModuleMap(int mmt, ClrDataAddress moduleAddr, void* pCallback, void* token) - => _legacyImpl is not null ? _legacyImpl.TraverseModuleMap(mmt, moduleAddr, pCallback, token) : HResults.E_NOTIMPL; + +#if DEBUG + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] + private static void TraverseModuleMapCallback(uint index, ulong moduleAddr, void* expectedElements) + { + var expectedElementsDict = (Dictionary)GCHandle.FromIntPtr((nint)expectedElements).Target!; + if (expectedElementsDict.TryGetValue(moduleAddr, out uint expectedIndex) && expectedIndex == index) + { + expectedElementsDict[default]++; // Increment the count for verification + } + else + { + Debug.Assert(false, $"Unexpected module address {moduleAddr:x} at index {index}"); + } + } +#endif + int ISOSDacInterface.TraverseModuleMap(ModuleMapType mmt, ClrDataAddress moduleAddr, delegate* unmanaged[Stdcall] pCallback, void* token) + { + int hr = HResults.S_OK; + IEnumerable<(TargetPointer Address, uint Index)> elements = Enumerable.Empty<(TargetPointer, uint)>(); + if (moduleAddr == 0) + hr = HResults.E_INVALIDARG; + else + { + try + { + Contracts.ILoader loader = _target.Contracts.Loader; + TargetPointer moduleAddrPtr = moduleAddr.ToTargetPointer(_target); + Contracts.ModuleHandle moduleHandle = loader.GetModuleHandleFromModulePtr(moduleAddrPtr); + Contracts.ModuleLookupTables lookupTables = loader.GetLookupTables(moduleHandle); + switch (mmt) + { + case ModuleMapType.TYPEDEFTOMETHODTABLE: + elements = loader.EnumerateModuleLookupMap(lookupTables.TypeDefToMethodTable); + break; + case ModuleMapType.TYPEREFTOMETHODTABLE: + elements = loader.EnumerateModuleLookupMap(lookupTables.TypeRefToMethodTable); + break; + default: + hr = HResults.E_INVALIDARG; + break; + } + if (hr == HResults.S_OK) + { + foreach ((TargetPointer element, uint index) in elements) + { + // Call the callback with each element + pCallback(index, element.ToClrDataAddress(_target).Value, token); + } + } + } + catch (System.Exception ex) + { + hr = ex.HResult; + } + } +#if DEBUG + if (_legacyImpl is not null) + { + Dictionary expectedElements = elements.ToDictionary(tuple => tuple.Address.ToClrDataAddress(_target).Value, tuple => tuple.Index); + expectedElements.Add(default, 0); + void* tokenDebug = GCHandle.ToIntPtr(GCHandle.Alloc(expectedElements)).ToPointer(); + delegate* unmanaged[Stdcall] callbackDebugPtr = &TraverseModuleMapCallback; + + int hrLocal = _legacyImpl.TraverseModuleMap(mmt, moduleAddr, callbackDebugPtr, tokenDebug); + Debug.Assert(hrLocal == hr, $"cDAC: {hr:x}, DAC: {hrLocal:x}"); + Debug.Assert(expectedElements[default] == elements.Count(), $"cDAC: {elements.Count()} elements, DAC: {expectedElements[default]} elements"); + GCHandle.FromIntPtr((nint)tokenDebug).Free(); + } +#endif + return hr; + } int ISOSDacInterface.TraverseRCWCleanupList(ClrDataAddress cleanupListPtr, void* pCallback, void* token) => _legacyImpl is not null ? _legacyImpl.TraverseRCWCleanupList(cleanupListPtr, pCallback, token) : HResults.E_NOTIMPL; int ISOSDacInterface.TraverseVirtCallStubHeap(ClrDataAddress pAppDomain, int heaptype, void* pCallback)