Skip to content

Commit d8cd468

Browse files
committed
Generalize optional chaining detection used by #require().
This PR generalizes the optional chaining detection used in the implementation of unwrapping `#require()` such that it works with more than just member access expressions. Resolves #623.
1 parent cac6314 commit d8cd468

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,18 @@ private func _parseCondition(from expr: ClosureExprSyntax, for macro: some Frees
242242
return Condition(expression: expr)
243243
}
244244

245+
/// A class that walks a syntax tree looking for optional chaining expressions
246+
/// such as `a?.b.c`.
247+
private final class _OptionalChainFinder: SyntaxVisitor {
248+
/// Whether or not any optional chaining was found.
249+
var optionalChainFound = false
250+
251+
override func visit(_ node: OptionalChainingExprSyntax) -> SyntaxVisitorContinueKind {
252+
optionalChainFound = true
253+
return .skipChildren
254+
}
255+
}
256+
245257
/// Extract the underlying expression from an optional-chained expression as
246258
/// well as the number of question marks required to reach it.
247259
///
@@ -279,15 +291,9 @@ private func _exprFromOptionalChainedExpr(_ expr: some ExprSyntaxProtocol) -> (E
279291
// the member accesses in the expression use optional chaining and, if one
280292
// does, ensure we preserve optional chaining in the macro expansion.
281293
if questionMarkCount == 0 {
282-
func isOptionalChained(_ expr: some ExprSyntaxProtocol) -> Bool {
283-
if expr.is(OptionalChainingExprSyntax.self) {
284-
return true
285-
} else if let memberAccessBaseExpr = expr.as(MemberAccessExprSyntax.self)?.base {
286-
return isOptionalChained(memberAccessBaseExpr)
287-
}
288-
return false
289-
}
290-
if isOptionalChained(originalExpr) {
294+
let optionalChainFinder = _OptionalChainFinder(viewMode: .sourceAccurate)
295+
optionalChainFinder.walk(originalExpr)
296+
if optionalChainFinder.optionalChainFound {
291297
questionMarkCount = 1
292298
}
293299
}

Tests/TestingMacrosTests/ConditionMacroTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ struct ConditionMacroTests {
8888
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
8989
##"#expect(a?.b.isB)"##:
9090
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
91+
##"#expect(a?.b().isB)"##:
92+
##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
9193
##"#expect(isolation: somewhere) {}"##:
9294
##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: false, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
9395
]
@@ -166,6 +168,8 @@ struct ConditionMacroTests {
166168
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
167169
##"#require(a?.b.isB)"##:
168170
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
171+
##"#require(a?.b().isB)"##:
172+
##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
169173
##"#require(isolation: somewhere) {}"##:
170174
##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: true, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
171175
]

Tests/TestingTests/MiscellaneousTests.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ struct MultiLineSuite {
222222
staticMultiLineTestDecl() async {}
223223
}
224224

225+
@Test(.hidden) func complexOptionalChainingWithRequire() throws {
226+
let x: String? = nil
227+
_ = try #require(x?[...].last)
228+
}
229+
225230
@Suite("Miscellaneous tests")
226231
struct MiscellaneousTests {
227232
@Test("Free function's name")

0 commit comments

Comments
 (0)