Skip to content

Generalize optional chaining detection used by #require(). #625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions Sources/TestingMacros/Support/ConditionArgumentParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,18 @@ private func _parseCondition(from expr: ClosureExprSyntax, for macro: some Frees
return Condition(expression: expr)
}

/// A class that walks a syntax tree looking for optional chaining expressions
/// such as `a?.b.c`.
private final class _OptionalChainFinder: SyntaxVisitor {
/// Whether or not any optional chaining was found.
var optionalChainFound = false

override func visit(_ node: OptionalChainingExprSyntax) -> SyntaxVisitorContinueKind {
optionalChainFound = true
return .skipChildren
}
}

/// Extract the underlying expression from an optional-chained expression as
/// well as the number of question marks required to reach it.
///
Expand Down Expand Up @@ -279,15 +291,9 @@ private func _exprFromOptionalChainedExpr(_ expr: some ExprSyntaxProtocol) -> (E
// the member accesses in the expression use optional chaining and, if one
// does, ensure we preserve optional chaining in the macro expansion.
if questionMarkCount == 0 {
func isOptionalChained(_ expr: some ExprSyntaxProtocol) -> Bool {
if expr.is(OptionalChainingExprSyntax.self) {
return true
} else if let memberAccessBaseExpr = expr.as(MemberAccessExprSyntax.self)?.base {
return isOptionalChained(memberAccessBaseExpr)
}
return false
}
if isOptionalChained(originalExpr) {
let optionalChainFinder = _OptionalChainFinder(viewMode: .sourceAccurate)
optionalChainFinder.walk(originalExpr)
if optionalChainFinder.optionalChainFound {
questionMarkCount = 1
}
}
Expand Down
4 changes: 4 additions & 0 deletions Tests/TestingMacrosTests/ConditionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ struct ConditionMacroTests {
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
##"#expect(a?.b.isB)"##:
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
##"#expect(a?.b().isB)"##:
##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
##"#expect(isolation: somewhere) {}"##:
##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: false, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
]
Expand Down Expand Up @@ -166,6 +168,8 @@ struct ConditionMacroTests {
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
##"#require(a?.b.isB)"##:
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
##"#require(a?.b().isB)"##:
##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
##"#require(isolation: somewhere) {}"##:
##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: true, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
]
Expand Down
5 changes: 5 additions & 0 deletions Tests/TestingTests/MiscellaneousTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ struct MultiLineSuite {
staticMultiLineTestDecl() async {}
}

@Test(.hidden) func complexOptionalChainingWithRequire() throws {
let x: String? = nil
_ = try #require(x?[...].last)
}

@Suite("Miscellaneous tests")
struct MiscellaneousTests {
@Test("Free function's name")
Expand Down