Skip to content

Commit

Permalink
Apply changes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
ktoso committed Feb 6, 2024
1 parent c0bce0d commit 9d22af9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
24 changes: 6 additions & 18 deletions Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {

override func visit(_ node: GenericArgumentSyntax) -> SyntaxVisitorContinueKind {
guard let baseName = node.argument.as(IdentifierTypeSyntax.self)?.name else {
// Handle error
return .visitChildren
return .skipChildren
}

guard let genericParameterClause = macro.genericParameterClause else {
Expand Down Expand Up @@ -300,7 +299,6 @@ extension MacroDeclSyntax {
private final class MacroExpansionRewriter: SyntaxRewriter {
let parameterReplacements: [DeclReferenceExprSyntax: Int]
let arguments: [ExprSyntax]
// let genericParameterReplacements: [DeclReferenceExprSyntax: Int]
let genericParameterReplacements: [GenericArgumentSyntax: Int]
let genericArguments: [TypeSyntax]

Expand Down Expand Up @@ -336,17 +334,9 @@ private final class MacroExpansionRewriter: SyntaxRewriter {
}

// Swap in the argument for type parameter
return GenericArgumentSyntax(
leadingTrivia: node.leadingTrivia,
node.unexpectedBeforeArgument,
argument: genericArguments[parameterIndex].trimmed,
node.unexpectedBetweenArgumentAndTrailingComma,
trailingComma: node.trailingComma,
node.unexpectedAfterTrailingComma
// TODO: seems we're getting spurious trailing " " here,
// skipping trailing trivia for now
// trailingTrivia: node.trailingTrivia
)
var node = node
node.argument = genericArguments[parameterIndex].trimmed
return node
}
}

Expand Down Expand Up @@ -380,11 +370,9 @@ extension MacroDeclSyntax {
uniquingKeysWith: { l, r in l }
)
let genericArguments: [TypeSyntax] =
genericArgumentList?.arguments.map { element in
element.argument
} ?? []
genericArgumentList?.arguments.map { $0.argument } ?? []

let rewriter: MacroExpansionRewriter = MacroExpansionRewriter(
let rewriter = MacroExpansionRewriter(
parameterReplacements: parameterReplacements,
arguments: arguments,
genericReplacements: genericReplacements,
Expand Down
36 changes: 18 additions & 18 deletions Tests/SwiftSyntaxMacroExpansionTest/MacroReplacementTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ final class MacroReplacementTests: XCTestCase {
macro expand1(a: Int, b: Int) = #otherMacro(first: b, second: ["a": a], third: [3.14159, 2.71828], fourth: 4)
"""

let definition = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
let definition = try macro.cast(MacroDeclSyntax.self).checkDefinition()
guard case let .expansion(_, replacements, _) = definition else {
XCTFail("not an expansion definition")
fatalError()
Expand All @@ -43,7 +43,7 @@ final class MacroReplacementTests: XCTestCase {

let diags: [Diagnostic]
do {
_ = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
_ = try macro.cast(MacroDeclSyntax.self).checkDefinition()
XCTFail("should have failed with an error")
fatalError()
} catch let diagError as DiagnosticsError {
Expand All @@ -69,7 +69,7 @@ final class MacroReplacementTests: XCTestCase {

let diags: [Diagnostic]
do {
_ = try macro.as(MacroDeclSyntax.self)!.checkDefinition()
_ = try macro.cast(MacroDeclSyntax.self).checkDefinition()
XCTFail("should have failed with an error")
fatalError()
} catch let diagError as DiagnosticsError {
Expand All @@ -94,7 +94,7 @@ final class MacroReplacementTests: XCTestCase {
#expand1(a: 5, b: 17)
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand All @@ -115,7 +115,7 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroGenericArgumentExpansion_base() throws {
func testMacroGenericArgumentExpansionBase() throws {
let macro: DeclSyntax =
"""
macro gen<A, B>(a: A, b: B) = #otherMacro<A, B>(first: a, second: b)
Expand All @@ -126,7 +126,7 @@ final class MacroReplacementTests: XCTestCase {
#gen<Int, String>(a: 5, b: "Hello")
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand Down Expand Up @@ -158,7 +158,7 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroGenericArgumentExpansion_ignoreTrivia() throws {
func testMacroGenericArgumentExpansionIgnoreTrivia() throws {
let macro: DeclSyntax =
"""
macro gen<A, B /* some comment */>(a: A, b: B) = #otherMacro<A, B>(first: a, second: b)
Expand All @@ -169,7 +169,7 @@ final class MacroReplacementTests: XCTestCase {
#gen<Int, String>(a: 5, b: "Hello")
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand Down Expand Up @@ -200,7 +200,7 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroGenericArgumentExpansion_notVisitGenericParameterArguments() throws {
func testMacroGenericArgumentExpansionNotVisitGenericParameterArguments() throws {
let macro: DeclSyntax =
"""
macro gen(a: Array<Int>) = #otherMacro(first: a)
Expand All @@ -211,7 +211,7 @@ final class MacroReplacementTests: XCTestCase {
#gen(a: [1, 2, 3])
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand All @@ -234,7 +234,7 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroGenericArgumentExpansion_replaceInner() throws {
func testMacroGenericArgumentExpansionReplaceInner() throws {
let macro: DeclSyntax =
"""
macro gen<A>(a: Array<A>) = #reduce<A>(first: a)
Expand All @@ -245,7 +245,7 @@ final class MacroReplacementTests: XCTestCase {
#gen<Int>(a: [1, 2, 3])
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand All @@ -268,18 +268,18 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroGenericArgumentExpansion_array() throws {
func testMacroGenericArgumentExpansionArray() throws {
let macro: DeclSyntax =
"""
macro gen(a: Array<Int>) = #other<A>(first: a)
"""

let use: ExprSyntax =
"""
#otheren<Int>(a: [1, 2, 3])
#gen<Int>(a: [1, 2, 3])
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand All @@ -302,18 +302,18 @@ final class MacroReplacementTests: XCTestCase {
)
}

func testMacroExpansion_dontCrashOnDuplicates() throws {
func testMacroExpansionDontCrashOnDuplicates() throws {
let macro: DeclSyntax =
"""
macro gen(a: Array<Int>) = #other<A>(first: a)
"""

let use: ExprSyntax =
"""
#otheren<Int>(a: [1, 2, 3])
#gen<Int>(a: [1, 2, 3])
"""

let macroDecl = macro.as(MacroDeclSyntax.self)!
let macroDecl = macro.cast(MacroDeclSyntax.self)
let definition = try macroDecl.checkDefinition()
guard case let .expansion(expansion, replacements, genericReplacements) = definition else {
XCTFail("not a normal expansion")
Expand Down

0 comments on commit 9d22af9

Please sign in to comment.