diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index 5dcec749f0f42..12f582874d7f5 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test { void SetUp() override { ::test::registerTestDialect(registry); context = std::make_unique(registry); + builder = std::make_unique(context.get()); } void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { // Set up IR and find func ops. OwningOpRef module = parseSourceString(kInput, context.get()); + ASSERT_TRUE(module); SymbolTable symbolTable(module.get()); auto opIterator = module->getBody(0)->getOperations().begin(); auto fooOp = cast(opIterator++); @@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test { ASSERT_TRUE(succeeded(res)); ASSERT_TRUE(succeeded(verify(module.get()))); - // Check that it got renamed. + // Check that callee of the call op got renamed. bool calleeFound = false; fooOp->walk([&](CallOpInterface callOp) { StringAttr callee = callOp.getCallableForCallee() @@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test { calleeFound = true; }); EXPECT_TRUE(calleeFound); + + // Check that module attribute did *not* get renamed. + auto moduleAttr = (*module)->getAttrOfType("test.attr"); + ASSERT_TRUE(moduleAttr); + EXPECT_EQ(moduleAttr.getValue(), StringRef("bar")); } std::unique_ptr context; + std::unique_ptr builder; private: constexpr static llvm::StringLiteral kInput = R"MLIR( - module { + module attributes { test.attr = @bar } { test.conversion_func_op private @foo() { "test.conversion_call_op"() { callee=@bar } : () -> () "test.return"() : () -> () @@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( - barOp, StringAttr::get(context.get(), "baz"), module); + barOp, builder->getStringAttr("baz"), module); }); } @@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( - StringAttr::get(context.get(), "bar"), - StringAttr::get(context.get(), "baz"), module); + builder->getStringAttr("bar"), builder->getStringAttr("baz"), module); }); } @@ -100,7 +107,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( - barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0)); + barOp, builder->getStringAttr("baz"), &module->getRegion(0)); }); } @@ -108,9 +115,9 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { // Symbol as `StringAttr`, rename within module body. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { - return symbolTable.replaceAllSymbolUses( - StringAttr::get(context.get(), "bar"), - StringAttr::get(context.get(), "baz"), &module->getRegion(0)); + return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"), + builder->getStringAttr("baz"), + &module->getRegion(0)); }); } @@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( - barOp, StringAttr::get(context.get(), "baz"), fooOp); + barOp, builder->getStringAttr("baz"), fooOp); }); } @@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( - StringAttr::get(context.get(), "bar"), - StringAttr::get(context.get(), "baz"), fooOp); + builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp); }); }