Skip to content

Commit d1adf81

Browse files
Add support for default implementation of static virtuals with method constraints (#89061)
- The major problem was the logic which incorrectly would instantiate the methods when it wasn't necessary - As the number of flags to the implementation functions has grown very large, this change also includes logic converting them all to a single flags variable when passing them around Fixes #73658 Fixes #78865
1 parent ef4860a commit d1adf81

13 files changed

+488
-41
lines changed

src/coreclr/inc/enum_class_flags.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
#ifndef ENUM_CLASS_FLAGS_OPERATORS
5+
#define ENUM_CLASS_FLAGS_OPERATORS
6+
7+
template <typename T>
8+
inline auto operator& (T left, T right) -> decltype(T::support_use_as_flags)
9+
{
10+
return static_cast<T>(static_cast<int>(left) & static_cast<int>(right));
11+
}
12+
13+
template <typename T>
14+
inline auto operator| (T left, T right) -> decltype(T::support_use_as_flags)
15+
{
16+
return static_cast<T>(static_cast<int>(left) | static_cast<int>(right));
17+
}
18+
19+
template <typename T>
20+
inline auto operator^ (T left, T right) -> decltype(T::support_use_as_flags)
21+
{
22+
return static_cast<T>(static_cast<int>(left) ^ static_cast<int>(right));
23+
}
24+
25+
template <typename T>
26+
inline auto operator~ (T value) -> decltype(T::support_use_as_flags)
27+
{
28+
return static_cast<T>(~static_cast<int>(value));
29+
}
30+
31+
template <typename T>
32+
inline auto operator |= (T& left, T right) -> const decltype(T::support_use_as_flags)&
33+
{
34+
left = left | right;
35+
return left;
36+
}
37+
38+
template <typename T>
39+
inline auto operator &= (T& left, T right) -> const decltype(T::support_use_as_flags)&
40+
{
41+
left = left & right;
42+
return left;
43+
}
44+
45+
template <typename T>
46+
inline auto operator ^= (T& left, T right) -> const decltype(T::support_use_as_flags)&
47+
{
48+
left = left ^ right;
49+
return left;
50+
}
51+
52+
#endif /* ENUM_CLASS_FLAGS_OPERATORS */

src/coreclr/vm/genericdict.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,9 +1144,9 @@ Dictionary::PopulateEntry(
11441144
pResolvedMD = constraintType.GetMethodTable()->ResolveVirtualStaticMethod(
11451145
ownerType.GetMethodTable(),
11461146
pMethod,
1147-
/* allowNullResult */ TRUE,
1148-
/* verifyImplemented */ FALSE,
1149-
/* allowVariantMatches */ TRUE,
1147+
ResolveVirtualStaticMethodFlags::AllowNullResult |
1148+
ResolveVirtualStaticMethodFlags::AllowVariantMatches |
1149+
ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc,
11501150
&uniqueResolution);
11511151

11521152
// If we couldn't get an exact result, fall back to using a stub to make the exact function call

src/coreclr/vm/methodtable.cpp

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6210,22 +6210,26 @@ MethodTable::FindDispatchImpl(
62106210

62116211
// Try exact match first
62126212
MethodDesc *pDefaultMethod = NULL;
6213+
6214+
FindDefaultInterfaceImplementationFlags flags = FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
6215+
if (throwOnConflict)
6216+
flags = flags | FindDefaultInterfaceImplementationFlags::ThrowOnConflict;
6217+
62136218
BOOL foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
62146219
pIfcMD, // the interface method being resolved
62156220
pIfcMT, // the interface being resolved
62166221
&pDefaultMethod,
6217-
FALSE, // allowVariance
6218-
throwOnConflict);
6222+
flags);
62196223

62206224
// If there's no exact match, try a variant match
62216225
if (!foundDefaultInterfaceImplementation && pIfcMT->HasVariance())
62226226
{
6227+
flags = flags | FindDefaultInterfaceImplementationFlags::AllowVariance;
62236228
foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
62246229
pIfcMD, // the interface method being resolved
62256230
pIfcMT, // the interface being resolved
62266231
&pDefaultMethod,
6227-
TRUE, // allowVariance
6228-
throwOnConflict);
6232+
flags);
62296233
}
62306234

62316235
if (foundDefaultInterfaceImplementation)
@@ -6324,10 +6328,13 @@ namespace
63246328
MethodTable *pMT,
63256329
MethodDesc *interfaceMD,
63266330
MethodTable *interfaceMT,
6327-
BOOL allowVariance,
6331+
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
63286332
MethodDesc **candidateMD,
63296333
ClassLoadLevel level)
63306334
{
6335+
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
6336+
bool instantiateMethodInstantiation = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc) != FindDefaultInterfaceImplementationFlags::None;
6337+
63316338
*candidateMD = NULL;
63326339

63336340
MethodDesc *candidateMaybe = NULL;
@@ -6418,11 +6425,20 @@ namespace
64186425
else
64196426
{
64206427
// Static virtual methods don't record MethodImpl slots so they need special handling
6428+
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags = ResolveVirtualStaticMethodFlags::None;
6429+
if (allowVariance)
6430+
{
6431+
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::AllowVariantMatches;
6432+
}
6433+
if (instantiateMethodInstantiation)
6434+
{
6435+
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
6436+
}
6437+
64216438
candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
64226439
interfaceMT,
64236440
interfaceMD,
6424-
/* verifyImplemented */ FALSE,
6425-
/* allowVariance */ allowVariance,
6441+
resolveVirtualStaticMethodFlags,
64266442
/* level */ level);
64276443
}
64286444
}
@@ -6461,8 +6477,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
64616477
MethodDesc *pInterfaceMD,
64626478
MethodTable *pInterfaceMT,
64636479
MethodDesc **ppDefaultMethod,
6464-
BOOL allowVariance,
6465-
BOOL throwOnConflict,
6480+
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
64666481
ClassLoadLevel level
64676482
)
64686483
{
@@ -6478,12 +6493,13 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
64786493
} CONTRACT_END;
64796494

64806495
#ifdef FEATURE_DEFAULT_INTERFACES
6496+
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
64816497
CQuickArray<MatchCandidate> candidates;
64826498
unsigned candidatesCount = 0;
64836499

64846500
// Check the current method table itself
64856501
MethodDesc *candidateMaybe = NULL;
6486-
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, allowVariance, &candidateMaybe, level))
6502+
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &candidateMaybe, level))
64876503
{
64886504
_ASSERTE(candidateMaybe != NULL);
64896505

@@ -6523,7 +6539,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
65236539
MethodTable *pCurMT = it.GetInterface(pMT, level);
65246540

65256541
MethodDesc *pCurMD = NULL;
6526-
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, allowVariance, &pCurMD, level))
6542+
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &pCurMD, level))
65276543
{
65286544
//
65296545
// Found a match. But is it a more specific match (we want most specific interfaces)
@@ -6619,6 +6635,8 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
66196635
}
66206636
else if (pBestCandidateMT != candidates[i].pMT)
66216637
{
6638+
bool throwOnConflict = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::ThrowOnConflict) != FindDefaultInterfaceImplementationFlags::None;
6639+
66226640
if (throwOnConflict)
66236641
ThrowExceptionForConflictingOverride(this, pInterfaceMT, pInterfaceMD);
66246642

@@ -8875,12 +8893,15 @@ MethodDesc *
88758893
MethodTable::ResolveVirtualStaticMethod(
88768894
MethodTable* pInterfaceType,
88778895
MethodDesc* pInterfaceMD,
8878-
BOOL allowNullResult,
8879-
BOOL verifyImplemented,
8880-
BOOL allowVariantMatches,
8896+
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
88818897
BOOL* uniqueResolution,
88828898
ClassLoadLevel level)
88838899
{
8900+
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;
8901+
bool allowVariantMatches = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
8902+
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
8903+
bool allowNullResult = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowNullResult) != ResolveVirtualStaticMethodFlags::None;
8904+
88848905
if (uniqueResolution != nullptr)
88858906
{
88868907
*uniqueResolution = TRUE;
@@ -8912,7 +8933,7 @@ MethodTable::ResolveVirtualStaticMethod(
89128933
// Search for match on a per-level in the type hierarchy
89138934
for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
89148935
{
8915-
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
8936+
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
89168937
if (pMD != nullptr)
89178938
{
89188939
return pMD;
@@ -8956,7 +8977,7 @@ MethodTable::ResolveVirtualStaticMethod(
89568977
{
89578978
// Variant or equivalent matching interface found
89588979
// Attempt to resolve on variance matched interface
8959-
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
8980+
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
89608981
if (pMD != nullptr)
89618982
{
89628983
return pMD;
@@ -8970,12 +8991,25 @@ MethodTable::ResolveVirtualStaticMethod(
89708991
BOOL allowVariantMatchInDefaultImplementationLookup = FALSE;
89718992
do
89728993
{
8994+
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags = FindDefaultInterfaceImplementationFlags::None;
8995+
if (allowVariantMatchInDefaultImplementationLookup)
8996+
{
8997+
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::AllowVariance;
8998+
}
8999+
if (uniqueResolution == nullptr)
9000+
{
9001+
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::ThrowOnConflict;
9002+
}
9003+
if (instantiateMethodParameters)
9004+
{
9005+
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
9006+
}
9007+
89739008
BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
89749009
pInterfaceMD,
89759010
pInterfaceType,
89769011
&pMDDefaultImpl,
8977-
/* allowVariance */ allowVariantMatchInDefaultImplementationLookup,
8978-
/* throwOnConflict */ uniqueResolution == nullptr,
9012+
findDefaultImplementationFlags,
89799013
level);
89809014
if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
89819015
{
@@ -9018,8 +9052,12 @@ MethodTable::ResolveVirtualStaticMethod(
90189052
// Try to locate the appropriate MethodImpl matching a given interface static virtual method.
90199053
// Returns nullptr on failure.
90209054
MethodDesc*
9021-
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level)
9055+
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level)
90229056
{
9057+
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
9058+
bool allowVariance = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
9059+
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;
9060+
90239061
HRESULT hr = S_OK;
90249062
IMDInternalImport* pMDInternalImport = GetMDImport();
90259063
HENUMInternalMethodImplHolder hEnumMethodImpl(pMDInternalImport);
@@ -9148,7 +9186,7 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
91489186
COMPlusThrow(kTypeLoadException, E_FAIL);
91499187
}
91509188

9151-
if (!verifyImplemented)
9189+
if (!verifyImplemented && instantiateMethodParameters)
91529190
{
91539191
pMethodImpl = pMethodImpl->FindOrCreateAssociatedMethodDesc(
91549192
pMethodImpl,
@@ -9202,9 +9240,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
92029240
!ResolveVirtualStaticMethod(
92039241
pInterfaceMT,
92049242
pMD,
9205-
/* allowNullResult */ TRUE,
9206-
/* verifyImplemented */ TRUE,
9207-
/* allowVariantMatches */ FALSE,
9243+
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented,
92089244
/* uniqueResolution */ &uniqueResolution,
92099245
/* level */ CLASS_LOAD_EXACTPARENTS)))
92109246
{
@@ -9240,12 +9276,18 @@ MethodTable::TryResolveConstraintMethodApprox(
92409276
_ASSERTE(!thInterfaceType.IsTypeDesc());
92419277
_ASSERTE(thInterfaceType.IsInterface());
92429278
BOOL uniqueResolution;
9279+
9280+
ResolveVirtualStaticMethodFlags flags = ResolveVirtualStaticMethodFlags::AllowVariantMatches
9281+
| ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
9282+
if (pfForceUseRuntimeLookup != NULL)
9283+
{
9284+
flags |= ResolveVirtualStaticMethodFlags::AllowNullResult;
9285+
}
9286+
92439287
MethodDesc *result = ResolveVirtualStaticMethod(
92449288
thInterfaceType.GetMethodTable(),
92459289
pInterfaceMD,
9246-
/* allowNullResult */pfForceUseRuntimeLookup != NULL,
9247-
/* verifyImplemented */ FALSE,
9248-
/* allowVariantMatches */ TRUE,
9290+
flags,
92499291
&uniqueResolution);
92509292
if (result == NULL || !uniqueResolution)
92519293
{

src/coreclr/vm/methodtable.h

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "contractimpl.h"
2626
#include "generics.h"
2727
#include "gcinfotypes.h"
28+
#include "enum_class_flags.h"
2829

2930
/*
3031
* Forward Declarations
@@ -63,6 +64,28 @@ class ClassFactoryBase;
6364
class ArgDestination;
6465
enum class WellKnownAttribute : DWORD;
6566

67+
enum class ResolveVirtualStaticMethodFlags
68+
{
69+
None = 0,
70+
AllowNullResult = 1,
71+
VerifyImplemented = 2,
72+
AllowVariantMatches = 4,
73+
InstantiateResultOverFinalMethodDesc = 8,
74+
75+
support_use_as_flags // Enable the template functions in enum_class_flags.h
76+
};
77+
78+
79+
enum class FindDefaultInterfaceImplementationFlags
80+
{
81+
None,
82+
AllowVariance = 1,
83+
ThrowOnConflict = 2,
84+
InstantiateFoundMethodDesc = 4,
85+
86+
support_use_as_flags // Enable the template functions in enum_class_flags.h
87+
};
88+
6689
//============================================================================
6790
// This is the in-memory structure of a class and it will evolve.
6891
//============================================================================
@@ -2084,7 +2107,6 @@ class MethodTable
20842107
MethodDesc *GetMethodDescForComInterfaceMethod(MethodDesc *pItfMD, bool fNullOk);
20852108
#endif // FEATURE_COMINTEROP
20862109

2087-
20882110
// Resolve virtual static interface method pInterfaceMD on this type.
20892111
//
20902112
// Specify allowNullResult to return NULL instead of throwing if the there is no implementation
@@ -2096,9 +2118,7 @@ class MethodTable
20962118
MethodDesc *ResolveVirtualStaticMethod(
20972119
MethodTable* pInterfaceType,
20982120
MethodDesc* pInterfaceMD,
2099-
BOOL allowNullResult,
2100-
BOOL verifyImplemented = FALSE,
2101-
BOOL allowVariantMatches = TRUE,
2121+
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
21022122
BOOL *uniqueResolution = NULL,
21032123
ClassLoadLevel level = CLASS_LOADED);
21042124

@@ -2178,8 +2198,7 @@ class MethodTable
21782198
MethodDesc *pInterfaceMD,
21792199
MethodTable *pObjectMT,
21802200
MethodDesc **ppDefaultMethod,
2181-
BOOL allowVariance,
2182-
BOOL throwOnConflict,
2201+
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
21832202
ClassLoadLevel level = CLASS_LOADED);
21842203
#endif // DACCESS_COMPILE
21852204

@@ -2219,7 +2238,7 @@ class MethodTable
22192238

22202239
// Try to resolve a given static virtual method override on this type. Return nullptr
22212240
// when not found.
2222-
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level);
2241+
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level);
22232242

22242243
public:
22252244
static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);

src/coreclr/vm/runtimehandles.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,9 +1102,8 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat
11021102
pResult = typeHandle.GetMethodTable()->ResolveVirtualStaticMethod(
11031103
thOwnerOfMD.GetMethodTable(),
11041104
pMD,
1105-
/* allowNullResult */ TRUE,
1106-
/* verifyImplemented*/ FALSE,
1107-
/* allowVariantMatches */ TRUE);
1105+
ResolveVirtualStaticMethodFlags::AllowNullResult |
1106+
ResolveVirtualStaticMethodFlags::AllowVariantMatches);
11081107
}
11091108
else
11101109
{

src/coreclr/vm/typedesc.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,8 +1611,9 @@ BOOL TypeVarTypeDesc::SatisfiesConstraints(SigTypeContext *pTypeContextOfConstra
16111611
if (pMD->IsVirtual() &&
16121612
pMD->IsStatic() &&
16131613
(pMD->IsAbstract() && !thElem.AsMethodTable()->ResolveVirtualStaticMethod(
1614-
pInterfaceMT, pMD, /* allowNullResult */ TRUE, /* verifyImplemented */ TRUE,
1615-
/*allowVariantMatches*/ TRUE, /*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
1614+
pInterfaceMT, pMD,
1615+
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented | ResolveVirtualStaticMethodFlags::AllowVariantMatches,
1616+
/*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
16161617
{
16171618
virtualStaticResolutionCheckFailed = true;
16181619
break;

0 commit comments

Comments
 (0)