Skip to content

[Syntax] Add init(unsafeCasting: Syntax) to concrete node types #2924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ let syntaxCollectionsFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
"""
)

DeclSyntax(
"""
@_transparent
init(unsafeCasting node: Syntax) {
self._syntaxNode = node
}
"""
)

DeclSyntax("public static let syntaxKind = SyntaxKind.\(node.memberCallName)")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ func syntaxNode(nodesStartingWith: [Character]) -> SourceFileSyntax {
"""
)

DeclSyntax(
"""
@_transparent
init(unsafeCasting node: Syntax) {
self._syntaxNode = node
}
"""
)

let initSignature = InitSignature(node)

try! InitializerDeclSyntax(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
"""
/// Rewrite `node`, keeping its parent unless `detach` is `true`.
public func rewrite(_ node: some SyntaxProtocol, detach: Bool = false) -> Syntax {
var rewritten = Syntax(node)
self.dispatchVisit(&rewritten)
let rewritten = self.visitImpl(Syntax(node))
if detach {
return rewritten
}
Expand All @@ -87,11 +86,20 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {

DeclSyntax(
"""
/// Visit a ``TokenSyntax``.
/// - Parameter token: the token that is being visited
/// Visit any Syntax node.
/// - Parameter node: the node that is being visited
/// - Returns: the rewritten node
open func visit(_ token: TokenSyntax) -> TokenSyntax {
return token
@available(*, deprecated, renamed: "rewrite(_:detach:)")
public func visit(_ node: Syntax) -> Syntax {
return visitImpl(node)
}
"""
)

DeclSyntax(
"""
public func visit<T: SyntaxChildChoices>(_ node: T) -> T {
visitImpl(Syntax(node)).cast(T.self)
}
"""
)
Expand Down Expand Up @@ -133,24 +141,11 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {

DeclSyntax(
"""
/// Visit any Syntax node.
/// - Parameter node: the node that is being visited
/// Visit a ``TokenSyntax``.
/// - Parameter token: the token that is being visited
/// - Returns: the rewritten node
@available(*, deprecated, renamed: "rewrite(_:detach:)")
public func visit(_ node: Syntax) -> Syntax {
var rewritten = node
dispatchVisit(&rewritten)
return rewritten
}
"""
)

DeclSyntax(
"""
public func visit<T: SyntaxChildChoices>(_ node: T) -> T {
var rewritten = Syntax(node)
dispatchVisit(&rewritten)
return rewritten.cast(T.self)
open func visit(_ token: TokenSyntax) -> TokenSyntax {
return token
}
"""
)
Expand All @@ -164,7 +159,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
/// - Returns: the rewritten node
\(node.apiAttributes())\
open func visit(_ node: \(node.kind.syntaxType)) -> \(node.kind.syntaxType) {
return visitChildren(node._syntaxNode).cast(\(node.kind.syntaxType).self)
return \(node.kind.syntaxType)(unsafeCasting: visitChildren(node._syntaxNode))
}
"""
)
Expand All @@ -176,7 +171,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
/// - Returns: the rewritten node
\(node.apiAttributes())\
open func visit(_ node: \(node.kind.syntaxType)) -> \(node.baseType.syntaxBaseName) {
return \(node.baseType.syntaxBaseName)(visitChildren(node._syntaxNode).cast(\(node.kind.syntaxType).self))
return \(node.baseType.syntaxBaseName)(\(node.kind.syntaxType)(unsafeCasting: visitChildren(node._syntaxNode)))
}
"""
)
Expand All @@ -193,32 +188,35 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
/// - Returns: the rewritten node
\(baseNode.apiAttributes())\
public func visit(_ node: \(baseKind.syntaxType)) -> \(baseKind.syntaxType) {
var node: Syntax = Syntax(node)
dispatchVisit(&node)
return node.cast(\(baseKind.syntaxType).self)
visitImpl(Syntax(node)).cast(\(baseKind.syntaxType).self)
}
"""
)
}

// NOTE: '@inline(never)' because perf tests showed the best results.
// It keeps 'dispatchVisit(_:)' function small, and make all 'case' bodies exactly the same pattern.
// Which enables some optimizations.
DeclSyntax(
"""
/// Interpret `node` as a node of type `nodeType`, visit it, calling
/// the `visit` to transform the node.
@inline(__always)
private func visitImpl<NodeType: SyntaxProtocol>(
_ node: inout Syntax,
_ nodeType: NodeType.Type,
_ visit: (NodeType) -> some SyntaxProtocol
) {
let origNode = node
visitPre(origNode)
node = visitAny(origNode) ?? Syntax(visit(origNode.cast(NodeType.self)))
visitPost(origNode)
@inline(never)
private func visitTokenSyntaxImpl(_ node: Syntax) -> Syntax {
Syntax(visit(TokenSyntax(unsafeCasting: node)))
}
"""
)

for node in NON_BASE_SYNTAX_NODES {
DeclSyntax(
"""
@inline(never)
private func visit\(node.kind.syntaxType)Impl(_ node: Syntax) -> Syntax {
Syntax(visit(\(node.kind.syntaxType)(unsafeCasting: node)))
}
"""
)
}

try IfConfigDeclSyntax(
leadingTrivia:
"""
Expand Down Expand Up @@ -255,26 +253,26 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
/// that determines the correct visitation function will be popped of the
/// stack before the function is being called, making the switch's stack
/// space transient instead of having it linger in the call stack.
private func visitationFunc(for node: Syntax) -> ((inout Syntax) -> Void)
private func visitationFunc(for node: Syntax) -> (Syntax) -> Syntax
"""
) {
try SwitchExprSyntax("switch node.raw.kind") {
SwitchCaseSyntax("case .token:") {
StmtSyntax("return { self.visitImpl(&$0, TokenSyntax.self, self.visit) }")
StmtSyntax("return self.visitTokenSyntaxImpl(_:)")
}

for node in NON_BASE_SYNTAX_NODES {
SwitchCaseSyntax("case .\(node.enumCaseCallName):") {
StmtSyntax("return { self.visitImpl(&$0, \(node.kind.syntaxType).self, self.visit) }")
StmtSyntax("return self.visit\(node.kind.syntaxType)Impl(_:)")
}
}
}
}

DeclSyntax(
"""
private func dispatchVisit(_ node: inout Syntax) {
visitationFunc(for: node)(&node)
private func dispatchVisit(_ node: Syntax) -> Syntax {
visitationFunc(for: node)(node)
}
"""
)
Expand All @@ -285,15 +283,15 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
poundKeyword: .poundElseToken(),
elements: .statements(
CodeBlockItemListSyntax {
try! FunctionDeclSyntax("private func dispatchVisit(_ node: inout Syntax)") {
try! FunctionDeclSyntax("private func dispatchVisit(_ node: Syntax) -> Syntax") {
try SwitchExprSyntax("switch node.raw.kind") {
SwitchCaseSyntax("case .token:") {
StmtSyntax("return visitImpl(&node, TokenSyntax.self, visit)")
StmtSyntax("return visitTokenSyntaxImpl(node)")
}

for node in NON_BASE_SYNTAX_NODES {
SwitchCaseSyntax("case .\(node.enumCaseCallName):") {
StmtSyntax("return visitImpl(&node, \(node.kind.syntaxType).self, visit)")
StmtSyntax("return visit\(node.kind.syntaxType)Impl(node)")
}
}
}
Expand All @@ -304,6 +302,16 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
}
)

DeclSyntax(
"""
private func visitImpl(_ node: Syntax) -> Syntax {
visitPre(node)
defer { visitPost(node) }
return visitAny(node) ?? dispatchVisit(node)
}
"""
)

DeclSyntax(
"""
private func visitChildren(_ node: Syntax) -> Syntax {
Expand All @@ -325,9 +333,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
for case let (child?, info) in RawSyntaxChildren(node) where viewMode.shouldTraverse(node: child) {

// Build the Syntax node to rewrite
var childNode = nodeFactory.create(parent: node, raw: child, absoluteInfo: info)

dispatchVisit(&childNode)
var childNode = visitImpl(nodeFactory.create(parent: node, raw: child, absoluteInfo: info))
if childNode.raw.id != child.id {
// The node was rewritten, let's handle it

Expand Down
Loading