@@ -1294,27 +1294,19 @@ VarDecl *PatternBindingInitializer::getInitializedLazyVar() const {
1294
1294
return nullptr ;
1295
1295
}
1296
1296
1297
- static bool patternContainsVarDeclBinding (const Pattern *P, const VarDecl *VD) {
1298
- bool Result = false ;
1299
- P->forEachVariable ([&](VarDecl *FoundVD) {
1300
- Result |= FoundVD == VD;
1301
- });
1302
- return Result;
1303
- }
1304
-
1305
1297
unsigned PatternBindingDecl::getPatternEntryIndexForVarDecl (const VarDecl *VD) const {
1306
1298
assert (VD && " Cannot find a null VarDecl" );
1307
1299
1308
1300
auto List = getPatternList ();
1309
1301
if (List.size () == 1 ) {
1310
- assert (patternContainsVarDeclBinding ( List[0 ].getPattern (), VD) &&
1302
+ assert (List[0 ].getPattern ()-> containsVarDecl ( VD) &&
1311
1303
" Single entry PatternBindingDecl is set up wrong" );
1312
1304
return 0 ;
1313
1305
}
1314
1306
1315
1307
unsigned Result = 0 ;
1316
1308
for (auto entry : List) {
1317
- if (patternContainsVarDeclBinding ( entry.getPattern (), VD))
1309
+ if (entry.getPattern ()-> containsVarDecl ( VD))
1318
1310
return Result;
1319
1311
++Result;
1320
1312
}
@@ -4927,12 +4919,6 @@ SourceRange VarDecl::getTypeSourceRangeForDiagnostics() const {
4927
4919
return SourceRange ();
4928
4920
}
4929
4921
4930
- static bool isVarInPattern (const VarDecl *vd, Pattern *p) {
4931
- bool foundIt = false ;
4932
- p->forEachVariable ([&](VarDecl *foundFD) { foundIt |= foundFD == vd; });
4933
- return foundIt;
4934
- }
4935
-
4936
4922
static Optional<std::pair<CaseStmt *, Pattern *>>
4937
4923
findParentPatternCaseStmtAndPattern (const VarDecl *inputVD) {
4938
4924
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
@@ -4946,7 +4932,7 @@ findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
4946
4932
4947
4933
// Then check the rest of our case label items.
4948
4934
for (auto &item : cs->getMutableCaseLabelItems ()) {
4949
- if (isVarInPattern (inputVD, item.getPattern ())) {
4935
+ if (item.getPattern ()-> containsVarDecl (inputVD )) {
4950
4936
return item.getPattern ();
4951
4937
}
4952
4938
}
@@ -5039,15 +5025,15 @@ Pattern *VarDecl::getParentPattern() const {
5039
5025
// In a case statement, search for the pattern that contains it. This is
5040
5026
// a bit silly, because you can't have something like "case x, y:" anyway.
5041
5027
for (auto items : cs->getCaseLabelItems ()) {
5042
- if (isVarInPattern ( this , items.getPattern ()))
5028
+ if (items.getPattern ()-> containsVarDecl ( this ))
5043
5029
return items.getPattern ();
5044
5030
}
5045
5031
}
5046
5032
5047
5033
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
5048
5034
for (auto &elt : LCS->getCond ())
5049
5035
if (auto pat = elt.getPatternOrNull ())
5050
- if (isVarInPattern (this , pat ))
5036
+ if (pat-> containsVarDecl (this ))
5051
5037
return pat;
5052
5038
}
5053
5039
@@ -5066,6 +5052,55 @@ Pattern *VarDecl::getParentPattern() const {
5066
5052
return nullptr ;
5067
5053
}
5068
5054
5055
+ NullablePtr<VarDecl>
5056
+ VarDecl::getCorrespondingFirstCaseLabelItemVarDecl () const {
5057
+ if (!hasName ())
5058
+ return nullptr ;
5059
+
5060
+ auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt ());
5061
+ if (!caseStmt)
5062
+ return nullptr ;
5063
+
5064
+ auto *pattern = caseStmt->getCaseLabelItems ().front ().getPattern ();
5065
+ SmallVector<VarDecl *, 8 > vars;
5066
+ pattern->collectVariables (vars);
5067
+ for (auto *vd : vars) {
5068
+ if (vd->hasName () && vd->getName () == getName ())
5069
+ return vd;
5070
+ }
5071
+ return nullptr ;
5072
+ }
5073
+
5074
+ bool VarDecl::isCaseBodyVariable () const {
5075
+ auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt ());
5076
+ if (!caseStmt)
5077
+ return false ;
5078
+ return llvm::any_of (caseStmt->getCaseBodyVariablesOrEmptyArray (),
5079
+ [&](VarDecl *vd) { return vd == this ; });
5080
+ }
5081
+
5082
+ NullablePtr<VarDecl> VarDecl::getCorrespondingCaseBodyVariable () const {
5083
+ // Only var decls associated with case statements can have child var decls.
5084
+ auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt ());
5085
+ if (!caseStmt)
5086
+ return nullptr ;
5087
+
5088
+ // If this var decl doesn't have a name, it can not have a corresponding case
5089
+ // body variable.
5090
+ if (!hasName ())
5091
+ return nullptr ;
5092
+
5093
+ auto name = getName ();
5094
+
5095
+ // A var decl associated with a case stmt implies that the case stmt has body
5096
+ // var decls. So we can access the optional value here without worry.
5097
+ auto caseBodyVars = *caseStmt->getCaseBodyVariables ();
5098
+ auto result = llvm::find_if (caseBodyVars, [&](VarDecl *caseBodyVar) {
5099
+ return caseBodyVar->getName () == name;
5100
+ });
5101
+ return (result != caseBodyVars.end ()) ? *result : nullptr ;
5102
+ }
5103
+
5069
5104
bool VarDecl::isSelfParameter () const {
5070
5105
if (isa<ParamDecl>(this )) {
5071
5106
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(getDeclContext ()))
0 commit comments